(*
Max Flow - push-relabel algorithm
*)
// Domain types shared by the push-relabel routines below.
type VType = S | T | V                                             // vertex role: S = source; T = sink; V = internal vertex
type V = {Id:string; T : VType}                                    // vertex in input graph: string id plus its role
type E = {Target:string; C:float}                                  // edge in input graph: Target = id of the destination vertex; C = capacity
type GInp = (V * E list) list                                      // input graph - adjacency list (each vertex paired with its outgoing edges)
type Ef = {Src:string; Tgt:string; C:float; Fwd:float; Bck:float}  // flow record for one edge of the residual graph; printFlow reports Fwd - Bck as the net Src --> Tgt flow
type Gf = Map<string*string,Ef>                                    // residual graph keyed by the ordered id pair from `order`; NOTE: edges A->B and B->A therefore share a single key
type Vst = {H:int; E:float}                                        // per-node state: Height (H) and excess/overflow (E)
type HF = Map<string,Vst>                                          // height and overflow state kept by node id
type PQ = Set<int*string>                                          // Set used as a poor man's priority queue of (-height, node id); Set.minElement then yields the highest node
type G = Map<string, (V * E list * string list)>                   // cached graph by node id: (vertex, outgoing edges, ids of vertices with an edge into it); built once by initialize
type S = G * (Gf * HF * PQ)                                        // full algorithm state: static graph plus (residual graph, node state, active queue)
type Dir = Fwd | Bck                                               // direction of a push relative to the stored edge: Fwd = Src --> Tgt, Bck = Tgt --> Src

//canonical form of a pair of vertex ids: the lexically smaller id comes first
//(used as the key of the residual graph map)
let order (a:string,b:string) =
    if b <= a then (b,a) else (a,b)

//store (or replace) a flow record in the residual graph under the ordered (Src,Tgt) key
let setFlow (ef:Ef) (gf:Gf) =
    let key = order (ef.Src, ef.Tgt)
    Map.add key ef gf
    
//look up the flow record between two node ids (in either order); None when no such record exists
let getFlow id tgt (gf:Gf) =
    let key = order (id, tgt)
    Map.tryFind key gf
    
//create initial state from input graph
// - every edge gets a flow record; edges leaving the source start saturated (Fwd = capacity)
// - the source is placed at height = number of vertices, every other node at height 0
// - each source neighbour starts with excess equal to the capacity of the source edge
// - only internal neighbours that actually received excess are queued as active
let initialize (g:GInp) : S =
    // reverse adjacency: target id -> ids of the vertices that have an edge into it
    let incoming = 
        (Map.empty,g) 
        ||> List.fold (fun acc (v,es) -> 
            (acc,es) 
            ||> List.fold (fun acc e -> 
                let ls =
                    acc 
                    |> Map.tryFind e.Target
                    |> Option.defaultValue []
                acc |> Map.add e.Target (v.Id::ls)))
    let g = (Map.empty,g) ||> List.fold (fun acc (v,es) -> acc |> Map.add v.Id (v,es, incoming |> Map.tryFind v.Id |> Option.defaultValue []))
    let vertices = Map.count g
    let st =
        ((Map.empty,Map.empty,Set.empty),g)
        ||> Map.fold (fun (gf,hf,pq) _ (n,es,_) ->
            match n.T with            
            | V | T ->
                let gf = 
                    (gf,es) 
                    ||> List.fold (fun gf e -> gf |> setFlow {Src=n.Id; Tgt=e.Target; C=e.C; Fwd=0.0; Bck=0.0})
                // Map.fold visits nodes in id order, so the source may already have been
                // processed and have given this node its initial excess - do not reset it
                let hf = if hf |> Map.containsKey n.Id then hf else hf |> Map.add n.Id {H=0; E=0.0}
                gf,hf,pq
            | S -> 
                let gf =
                    (gf,es) 
                    ||> List.fold (fun gf e -> gf |> setFlow {Src=n.Id; Tgt=e.Target; C=e.C; Fwd=e.C; Bck=0.0})
                let hf = hf |> Map.add n.Id {H=vertices; E=0.0} 
                let hf = (hf,es) ||> List.fold (fun hf e -> hf |> Map.add e.Target {H=0; E=e.C})
                // the sink is never active, and a node without excess has nothing to push;
                // queueing either would make the main loop try to relabel it and fail
                let pq =
                    (pq,es)
                    ||> List.fold (fun pq e ->
                        match Map.tryFind e.Target g with
                        | Some (tv,_,_) when tv.T = V && e.C > 0.0 -> pq |> Set.add (-0,e.Target)
                        | _ -> pq)
                gf,hf,pq)
    g,st
    
//true when the flow record can still take more flow in the given direction:
//forward while Fwd is below the capacity C, backward while Bck is below Fwd
let hasCap (dir:Dir, f:Ef) =
    match dir with
    | Fwd -> f.Fwd < f.C
    | Bck -> f.Bck < f.Fwd

//set the height of internal node `id` to h2 and re-key its priority queue entry
//fails for the source and the sink, whose heights must not change
let increaseHeight id h2 (s:S) =
    let g,(gf,hf,pq) = s
    let (v,_,_) = g.[id]
    match v.T with
    | V -> ()
    | S | T -> failwith "cannot change height of source and sink"
    let current = hf.[id]
    //queue entries are (-height, id): drop the entry for the old height, add one for the new
    let pq' = pq |> Set.remove (-current.H,id) |> Set.add (-h2,id)
    let hf' = Map.add id {current with H=h2} hf
    g,(gf,hf',pq')

//lists the (direction, flow record) pairs along which the given node could still push:
//Fwd entries for its outgoing edges, followed by Bck entries for edges coming into it
let unsatEs id (s:S) = 
    let g,(gf,_,_) = s
    let (_,outEdges,inNbrs) = g.[id]
    //keep the pair only when a flow record exists and has room in that direction
    let candidate dir other =
        match getFlow id other gf with
        | Some f when hasCap (dir,f) -> Some (dir,f)
        | _ -> None
    let fwd = outEdges |> List.choose (fun e -> candidate Fwd e.Target)
    let bck = inNbrs |> List.choose (candidate Bck)
    List.append fwd bck

//push deltaFlow from src to tgt (does not check constraints)
// - updates the flow record between the two nodes
// - moves deltaFlow of excess from src to tgt
// - removes src from the active queue once it has no excess left, and queues tgt
//   when it is an internal vertex that now holds excess
//fails if either node is unknown or no flow record exists between them
let pushFlow src tgt deltaFlow (s:S) =
    let g,(gf,hf,pq) = s
    let (vTgt,_,_) = g.[tgt]
    let srcHf = hf.[src]
    let tgtHf = hf.[tgt]
    let f = getFlow src tgt gf |> Option.get 
    let f =
        if f.Src = src then 
            {f with Fwd = f.Fwd + deltaFlow}
        else
            // A push against the edge direction cancels forward flow. Recording it by
            // lowering Fwd (rather than accumulating it in Bck) keeps Fwd - Bck equal to
            // the net flow AND makes the freed capacity visible again to the forward
            // checks, which compare Fwd with C. Accumulating in Bck left the edge looking
            // saturated forever, so the result could fall short of the maximum flow.
            {f with Fwd = f.Fwd - deltaFlow}
    let gf = setFlow f gf
    let srcHf = {srcHf with E = srcHf.E - deltaFlow}
    let tgtHf = {tgtHf with E = tgtHf.E + deltaFlow}
    let hf = hf |> Map.add src srcHf |> Map.add tgt tgtHf
    let pq = if srcHf.E <= 0.0 then pq |> Set.remove (-srcHf.H,src) else pq
    let pq = match vTgt.T with V when tgtHf.E > 0.0 -> pq |> Set.add (-tgtHf.H,tgt) | _ -> pq
    g,(gf,hf,pq)

//push step of the push-relabel algorithm
//picks the first unsaturated edge of node `id` that leads to a lower node and pushes as much
//of the node's excess along it as the edge allows; returns None when no such edge exists
let push id (s:S) =
    let _,(_,hf,_) = s
    let srcHf = hf.[id]
    //does the edge, taken in the given direction, lead to a node lower than `id`?
    let isDownhill (dir,f:Ef) =
        match dir with
        | Fwd -> srcHf.H > hf.[f.Tgt].H
        | Bck -> srcHf.H > hf.[f.Src].H
    match unsatEs id s |> List.tryFind isDownhill with
    | None -> None
    | Some (Fwd,f) -> Some (pushFlow f.Src f.Tgt (min srcHf.E (f.C - f.Fwd)) s)
    | Some (Bck,f) -> Some (pushFlow f.Tgt f.Src (min srcHf.E (f.Fwd - f.Bck)) s)
                
//relabel (increase height) of the given node
//the new height is one more than the lowest node reachable over an unsaturated edge
//raises ArgumentException when the node has no unsaturated edge at all
let relabel id (s:S) =
    let _,(_,hf,_) = s
    //height of the node at the far end of every edge that still has capacity
    let nbrHeights =
        unsatEs id s
        |> List.map (function Fwd,f -> hf.[f.Tgt].H | Bck,f -> hf.[f.Src].H)
    match nbrHeights with
    | [] -> invalidArg "id" (sprintf "cannot relabel %s: it has no edge with residual capacity" id)
    | hs -> increaseHeight id (List.min hs + 1) s
        
//one step of the algorithm: take the active node with the greatest height and push from it,
//or relabel it when no push is possible (the active queue must not be empty)
let pushRelabel (s:S) =
    let _,(_,_,pq) = s
    //entries are (-height, id), so the smallest element belongs to the highest node
    let active = pq |> Set.minElement |> snd
    push active s |> Option.defaultWith (fun () -> relabel active s)

//push-relabel main loop: keep applying pushRelabel until no active node is left,
//then return the final state
let maxFlow (s:S) =
    let hasActive ((_,(_,_,pq)):S) = not (Set.isEmpty pq)
    let mutable state = s
    while hasActive state do
        state <- pushRelabel state
    state

//print the total flow leaving the source, then every edge that carries a positive net flow
//fails if the graph has no source vertex
let printFlow (s:S) =
    let g,(gf,_,_) = s
    //the source vertex together with its outgoing edges
    let (src,srcEdges,_) = g |> Map.toSeq |> Seq.map snd |> Seq.find (fun (v,_,_) -> v.T = S)
    let srcFlows = srcEdges |> List.choose (fun e -> getFlow src.Id e.Target gf)
    let total = srcFlows |> List.sumBy (fun f -> f.Fwd - f.Bck)
    printfn "** Maxflow: %0.02f **" total
    gf
    |> Map.iter (fun _ f ->
        let net = f.Fwd - f.Bck
        if net > 0.0 then
            printfn "%s --> %s : %0.02f" f.Src f.Tgt net)


(**********
Test case
**********)

// Sample network: Vancouver is the source, Winnipeg the sink, and the other cities
// are internal vertices. Each entry pairs a vertex with its outgoing edges and capacities.
let ginp1 =
    [
        {Id="Vancouver"; T=S}, [{Target="Edmonton"; C=16.0}; {Target="Calgary"; C=13.0}]
        {Id="Edmonton"; T=V}, [{Target="Saskatoon"; C=12.}]
        {Id="Calgary"; T=V}, [{Target="Edmonton";C=4.0}; {Target="Regina"; C=14.0}]
        {Id="Regina"; T=V}, [{Target="Saskatoon"; C=7.0}; {Target="Winnipeg"; C=4.0;}]
        {Id="Saskatoon";T=V}, [{Target="Winnipeg"; C=20.}; {Target="Calgary";C=9.0}]
        {Id="Winnipeg"; T=T}, []
    ]

// build the initial state (source edges saturated, source neighbours queued as active)
let s1 = initialize ginp1

// run push/relabel steps until no active node remains
let s1g = maxFlow s1

// print the total flow and the per-edge net flows; the block comment that follows
// appears to be the output of such a run
printFlow s1g

(*
** Maxflow: 23.00 **
Calgary --> Regina : 11.00
Vancouver --> Calgary : 11.00
Edmonton --> Saskatoon : 12.00
Vancouver --> Edmonton : 12.00
Regina --> Saskatoon : 7.00
Regina --> Winnipeg : 4.00
Saskatoon --> Winnipeg : 19.00
*)