//[snippet:HKT Encoding]
// Marker interface: every type used as a "brand" (the 'F in App<'F, _>) must implement it.
type HKT = interface end

// App<'F, 't> stands for the brand 'F applied to the type argument 't.
// Both type parameters are phantom: the wrapped value is stored untyped in `payload`,
// and the case constructor is private.
type App<'F, 't when 'F :> HKT> = private App of payload : obj

// Multi-argument applications are abbreviations over the single-argument App:
// the extra type arguments are folded into one TCons<...> type argument.
type App<'F, 't1, 't2 when 'F :> HKT> = App<'F, TCons<'t1, 't2>>
and  App<'F, 't1, 't2, 't3 when 'F :> HKT> = App<'F, TCons<'t1, 't2, 't3>>
and  App<'F, 't1, 't2, 't3, 't4 when 'F :> HKT> = App<'F, TCons<'t1, 't2, 't3, 't4>>

// TCons<'T1, 'T2> is an empty class acting as a type-level pair.
and  TCons<'T1, 'T2> = class end
// Higher arities are abbreviations that nest pairs to the left.
and  TCons<'T1, 'T2, 'T3> = TCons<TCons<'T1, 'T2>, 'T3>
and  TCons<'T1, 'T2, 'T3, 'T4> = TCons<TCons<'T1, 'T2, 'T3>, 'T4>

// Conversions between a concrete representation 'Fa and its encoded form App<'F, 'a>.
// ModuleSuffix lets this module share its name with the HKT interface type.
[<CompilationRepresentation(CompilationRepresentationFlags.ModuleSuffix)>]
module HKT =
    // Wraps `value` as App<'F, 'a>.
    // The static member constraint ties 'Fa to App<'F, 'a> through the brand's `Assign`
    // signature; the body never calls Assign, it only stores the value as the payload.
    let inline pack (value : 'Fa) : App<'F, 'a>
        when 'F : (static member Assign : App<'F, 'a> * 'Fa -> unit) =
        App value
        
    // Inverse of `pack`: extracts the payload and downcasts it to 'Fa.
    // 'Fa is again determined by the brand's `Assign` signature.
    let inline unpack (App value : App<'F, 'a>) : 'Fa
        when 'F : (static member Assign : App<'F, 'a> * 'Fa -> unit) =
        value :?> _
        
    // Total active pattern that applies `unpack`, so members can unpack directly in their parameters.
    let inline (|Unpack|) app = unpack app
//[/snippet]

//[snippet:Defining our Language]
// A term of the embedded language: a function that, given an Algebra<'F>,
// produces the term's representation App<'F, 'T> under that algebra.
type Symantics<'F, 'T when 'F :> HKT> = S of (Algebra<'F> -> App<'F, 'T>)

// The operations an interpreter with brand 'F must provide.
and Algebra<'F when 'F :> HKT> =
    // Embeds a host value of any type 'T as a term.
    abstract Lit : 'T -> App<'F, 'T>
    // Combines two int terms into an int term.
    abstract Add  : App<'F, int> -> App<'F, int> -> App<'F, int>
    // Compares two terms of a comparable type, yielding a bool term.
    abstract Leq : App<'F, 'a> -> App<'F, 'a> -> App<'F, bool> when 'a : comparison
    // Conditional; both branches are passed as thunks so the interpreter decides which ones to force.
    abstract IfThenElse : App<'F, bool> -> (unit -> App<'F, 'a>) -> (unit -> App<'F, 'a>) -> App<'F, 'a>

    // Turns a host-level function between terms into a function-typed term.
    abstract Lam : (App<'F, 'a> -> App<'F,'b>) -> App<'F, 'a -> 'b>
    // Applies a function-typed term to an argument term.
    abstract App : App<'F, 'a -> 'b> -> App<'F, 'a> -> App<'F, 'b>
    // Fixed point: maps a term of type ('a -> 'b) -> 'a -> 'b to a term of type 'a -> 'b.
    abstract Fix : App<'F, ('a -> 'b) -> 'a -> 'b> -> App<'F, 'a -> 'b> 

// Smart constructors: each one builds a term by deferring to the matching Algebra operation.

// Literal term carrying the host value `value`.
let lit value = S (fun alg -> alg.Lit value)

// Addition term; both operands are interpreted with the same algebra, left operand first.
let add (S lhs) (S rhs) =
    S (fun alg ->
        let l = lhs alg
        let r = rhs alg
        alg.Add l r)

// Less-than-or-equal term; operands are interpreted left to right.
let leq (S lhs) (S rhs) =
    S (fun alg ->
        let l = lhs alg
        let r = rhs alg
        alg.Leq l r)

// Conditional term; the branches are handed to the algebra as thunks.
let ifThenElse (S cond) (S onTrue) (S onFalse) =
    S (fun alg ->
        let thenBranch () = onTrue alg
        let elseBranch () = onFalse alg
        alg.IfThenElse (cond alg) thenBranch elseBranch)

// Lambda term: `f` receives a term standing for the bound variable and returns the body term.
// The variable term ignores the algebra it is given and returns the value supplied by alg.Lam.
let lam f =
    S (fun alg ->
        alg.Lam (fun x ->
            match f (S (fun _ -> x)) with
            | S body -> body alg))

// Application term: interprets the function term, then the argument term.
let app (S fn) (S arg) =
    S (fun alg ->
        let interpretedFn = fn alg
        let interpretedArg = arg alg
        alg.App interpretedFn interpretedArg)

// Fixed-point term built from a term of type ('a -> 'b) -> 'a -> 'b.
let fix (S step) =
    S (fun alg ->
        let interpretedStep = step alg
        alg.Fix interpretedStep)

// Interprets `term` with the algebra `alg`, then unpacks the result into the
// representation that the algebra's brand assigns to the term's type.
let inline eval alg (S term) = HKT.unpack (term alg)

// Example terms.

// (fun x -> x + 1) applied to 1
let g () =
    let increment = lam (fun x -> add x (lit 1))
    app increment (lit 1)

// Fibonacci as a term: fix (fun f x -> if x <= 1 then x else f (x + -2) + f (x + -1)).
let fib () =
    // One unrolling of the recursion; `self` is the term used for the recursive calls.
    let step self =
        lam (fun n ->
            // Recursive call on n + offset (the offsets used below are negative).
            let previous offset = app self (add n (lit offset))
            ifThenElse
                (leq n (lit 1))
                n
                (add (previous (-2)) (previous (-1))))
    fix (lam step)
//[/snippet]

//[snippet:Evaluator Interpretation]
// Identity brand: App<Id, 'a> wraps a plain value of type 'a.
type Id =
    interface HKT
    // Only the signature matters: it is what the constraints on HKT.pack/unpack match against.
    static member Assign(_ : App<Id, 'a>, _ : 'a) = ()

// Direct interpreter: under the Id brand every term is represented by the host value it denotes.
type Evaluator() =
    interface Algebra<Id> with
        member __.Lit value = HKT.pack value
        member __.Add (HKT.Unpack lhs) (HKT.Unpack rhs) = (lhs + rhs) |> HKT.pack
        member __.Leq (HKT.Unpack lhs) (HKT.Unpack rhs) = (lhs <= rhs) |> HKT.pack
        // Only the selected branch thunk is forced.
        member __.IfThenElse (HKT.Unpack cond) onTrue onFalse =
            match cond with
            | true -> onTrue ()
            | false -> onFalse ()

        // A lambda is a host function: pack the argument, run the body, unpack the result.
        member __.Lam f = HKT.pack (HKT.pack >> f >> HKT.unpack)
        member __.App (HKT.Unpack fn) (HKT.Unpack arg) = fn arg |> HKT.pack
        // Ties the recursive knot with a host-level recursive function.
        member __.Fix (HKT.Unpack step) =
            let rec loop x = step loop x
            HKT.pack loop

// Run the example terms with the direct interpreter.
eval (Evaluator()) (g ()) // evaluates to 2
eval (Evaluator()) (fib ()) 10 // the interpreted term is a host function; applied to 10 it yields 55
//[/snippet]

//[snippet:Staged Interpretation]
open FSharp.Quotations

// Quotation brand: App<Expr, 'a> wraps a typed quotation Expr<'a>.
type Expr =
    interface HKT
    // Only the signature matters: it is what the constraints on HKT.pack/unpack match against.
    static member Assign(_ : App<Expr, 'a>, _ : Expr<'a>) = ()

// Staging interpreter: under the Expr brand every term is represented by a typed quotation,
// so interpreting a term builds code instead of computing a value.
type StagedEvaluator() =
    // Counter used to give every generated lambda parameter a distinct name.
    // A mutable field replaces the original ref cell: the ref-cell operators
    // (`!`, `incr`) trigger informational messages from F# 6 onwards.
    let mutable counter = 0

    // Turns a host-level function between quotations into a quoted lambda:
    // creates a fresh variable named "_<n>_", applies `f` to the quotation of that
    // variable, and wraps the resulting body in a lambda over the variable.
    let run (f : Expr<'T> -> Expr<'S>) : Expr<'T -> 'S> =
        counter <- counter + 1
        let v = new Var(sprintf "_%d_" counter, typeof<'T>)
        let ev = Expr.Var v |> Expr.Cast
        let body = f ev
        Expr.Lambda(v, body) |> Expr.Cast

    interface Algebra<Expr> with
        // The host value is captured inside the quotation.
        member __.Lit t = HKT.pack <@ t @>
        member __.Add (HKT.Unpack a) (HKT.Unpack b) = HKT.pack <@ %a + %b @>
        member __.Leq (HKT.Unpack a) (HKT.Unpack b) = HKT.pack <@ %a <= %b @>
        // Both branch thunks are forced here, because both branches must appear in the generated code.
        member __.IfThenElse (HKT.Unpack c) a b = HKT.pack <@ if %c then %(HKT.unpack (a())) else %(HKT.unpack (b())) @>

        member __.Lam f = run (HKT.pack >> f >> HKT.unpack) |> HKT.pack
        member __.App (HKT.Unpack a) (HKT.Unpack b) = HKT.pack <@ (%a) %b @>
        // The recursion is emitted as a quoted `let rec`.
        member __.Fix (HKT.Unpack F) = HKT.pack <@ let rec aux x = (%F) aux x in aux @>
        
// Stage the example terms into quotations and pass them to Unquote's `decompile`
// (external library, not shown here; presumably it renders a quotation as source text).
eval (StagedEvaluator()) (g ()) |> Swensen.Unquote.Operators.decompile
eval (StagedEvaluator()) (fib ()) |> Swensen.Unquote.Operators.decompile
//[/snippet]

//[snippet:Partial Evaluator]
// A partially evaluated term: either a value in the direct representation (brand Id)
// or code in the quotation representation (brand Expr).
type Partial<'a> = Value of App<Id, 'a> | Expr of App<Expr, 'a>

// Brand for the partial evaluator: App<PartialExpr, 'a> wraps a Partial<'a>.
and PartialExpr =
    interface HKT
    // Only the signature matters: it is what the constraints on HKT.pack/unpack match against.
    static member Assign(_ : App<PartialExpr, 'a>, _ : Partial<'a>) = ()

// Converts a partial result into a quotation: code is returned unchanged,
// a known value is captured inside a new quotation.
let embed (p:Partial<'a>) =
    match p with
    | Expr (HKT.Unpack quoted) -> quoted
    | Value (HKT.Unpack known) -> <@ known @>

// Partial evaluator: computes with the direct interpreter while operands are known values,
// and falls back to the staging interpreter (embedding known values as code) otherwise.
type PartialEvaluator() =
    let direct = Evaluator() :> Algebra<_>
    let staged = StagedEvaluator() :> Algebra<_>
    // Total active pattern: views any partial result as staged code, embedding it if necessary.
    let (|EExpr|) (p : Partial<'a>) : App<Expr,'a> = HKT.pack (embed p)

    interface Algebra<PartialExpr> with
        member __.Lit value = HKT.pack (Value (HKT.pack value))

        member __.Add (HKT.Unpack lhs) (HKT.Unpack rhs) =
            let combined =
                match lhs, rhs with
                | Value x, Value y -> Value (direct.Add x y)
                | EExpr x, EExpr y -> Expr (staged.Add x y)
            HKT.pack combined

        member __.Leq (HKT.Unpack lhs) (HKT.Unpack rhs) =
            let combined =
                match lhs, rhs with
                | Value x, Value y -> Value (direct.Leq x y)
                | EExpr x, EExpr y -> Expr (staged.Leq x y)
            HKT.pack combined

        // A known condition selects a branch immediately; otherwise both branches are staged.
        member __.IfThenElse (HKT.Unpack cond) onTrue onFalse =
            match cond with
            | Value (HKT.Unpack true) -> onTrue ()
            | Value (HKT.Unpack false) -> onFalse ()
            | EExpr quotedCond ->
                let quotedTrue () =
                    match onTrue () with
                    | HKT.Unpack (EExpr branch) -> branch
                let quotedFalse () =
                    match onFalse () with
                    | HKT.Unpack (EExpr branch) -> branch
                HKT.pack (Expr (staged.IfThenElse quotedCond quotedTrue quotedFalse))

        // TODO expand: Lam, App and Fix always produce staged code, even for known operands.
        member __.Lam f =
            let body arg =
                let (HKT.Unpack (EExpr quoted)) = f (HKT.pack (Expr arg))
                quoted
            HKT.pack (Expr (staged.Lam body))

        member __.App (HKT.Unpack (EExpr fn)) (HKT.Unpack (EExpr arg)) =
            HKT.pack (Expr (staged.App fn arg))

        member __.Fix (HKT.Unpack (EExpr step)) =
            HKT.pack (Expr (staged.Fix step))

// The condition and the `then` branch consist of literals only, so they are computed up front;
// the result is staged code equivalent to `fun x -> 3`.
eval (PartialEvaluator()) (lam (fun x -> (ifThenElse (leq (lit 1) (lit 2)) (add (lit 1) (lit 2)) x))) // fun x -> 3
//[/snippet]