//type Expr<'T> =
//    Const : 'T -> Expr<'T>
//    Add : Expr<int> -> Expr<int> -> Expr<int>
//    IfThenElse : Expr<bool> -> Expr<'T> -> Expr<'T> -> Expr<'T>
//    App : Expr<'T -> 'S> -> Expr<'T> -> Expr<'S>
//    Lam : (Expr<'T> -> Expr<'S>) -> Expr<'T -> 'S>
//    Fix : Expr<('T -> 'S) -> 'T -> 'S> -> Expr<'T -> 'S>

/// Typed expression node. 'T is the type of the value the expression denotes.
/// The constructor is internal, so the set of cases is fixed by the subclasses
/// declared inside this assembly.
[<AbstractClass>]
type Expr<'T> internal () =
    /// Invokes the member of the supplied matcher that corresponds to this
    /// node's case and returns whatever that member returns.
    abstract Match : IPatternMatch<'T, 'R> -> 'R

// Instances of IPatternMatch encode a match expression: one member per case of
// Expr<'T>, each receiving that case's payload and producing a result of type 'R.

and IPatternMatch<'T, 'R> =
    /// Literal case: receives the wrapped value.
    abstract Const : 'T -> 'R
    /// Addition case: receives the two integer operand expressions.
    abstract Add : Expr<int> -> Expr<int> -> 'R
    /// Conditional case: receives the condition and the two branch expressions.
    abstract IfThenElse : Expr<bool> -> Expr<'T> -> Expr<'T> -> 'R
    /// Application case: receives a function-typed expression and its argument;
    /// 'S is the (otherwise hidden) argument type.
    abstract App<'S> : Expr<'S -> 'T> -> Expr<'S> -> 'R
    /// Lambda case: receives an F# function from argument expression to body
    /// expression. The signature does not relate 'T to 'T1 -> 'T2, so an
    /// implementation that needs that equality has to recover it itself.
    abstract Lam<'T1, 'T2> : (Expr<'T1> -> Expr<'T2>) -> 'R
    /// Fixpoint case: receives the functional whose fixed point is taken.
    /// As with Lam, 'T is not tied to 'T1 -> 'T2 by the signature.
    abstract Fix<'T1, 'T2> : Expr<('T1 -> 'T2) -> 'T1 -> 'T2> -> 'R

// concrete case implementations

/// Literal node: carries a ready-made value of type 'T.
type internal Const<'T>(value : 'T) =
    inherit Expr<'T>()

    // Hand the wrapped value to the matcher's Const member.
    override self.Match (cases : IPatternMatch<'T, 'R>) : 'R =
        cases.Const value

/// Addition node over two integer-typed operand expressions.
type internal Add(left : Expr<int>, right : Expr<int>) =
    inherit Expr<int>()

    // Hand both operands, in order, to the matcher's Add member.
    override self.Match (cases : IPatternMatch<int, 'R>) : 'R =
        cases.Add left right

/// Conditional node: a boolean condition plus two branches of the same type 'T.
type internal IfThenElse<'T>(b : Expr<bool>, l : Expr<'T>, r : Expr<'T>) =
    inherit Expr<'T>()

    // Hand condition, first branch and second branch to the matcher's IfThenElse member.
    override self.Match (cases : IPatternMatch<'T, 'R>) : 'R =
        cases.IfThenElse b l r

/// Application node: a function-typed expression applied to an argument
/// expression; the node itself has the function's result type 'T.
type internal App<'T,'S> (f : Expr<'S -> 'T>, x : Expr<'S>) =
    inherit Expr<'T>()

    // Hand function and argument to the matcher's App member.
    override self.Match (cases : IPatternMatch<'T, 'R>) : 'R =
        cases.App f x

/// Lambda node: the body is represented as an F# function from the argument
/// expression to the result expression.
type internal Lam<'T1,'T2> (f : Expr<'T1> -> Expr<'T2>) =
    inherit Expr<'T1 -> 'T2>()

    // Hand the body function to the matcher's Lam member.
    override self.Match (cases : IPatternMatch<'T1 -> 'T2, 'R>) : 'R =
        cases.Lam f

/// Fixpoint node: wraps a functional of type ('T -> 'S) -> 'T -> 'S and
/// itself has type 'T -> 'S.
type internal Fix<'T, 'S> (f : Expr<('T -> 'S) -> 'T -> 'S>) =
    inherit Expr<'T -> 'S>()

    // Hand the functional to the matcher's Fix member.
    override self.Match (cases : IPatternMatch<'T -> 'S, 'R>) : 'R =
        cases.Fix f

// constructor api

/// Builds a literal node from a plain value.
let constant x =
    let node = Const<_>(x)
    node :> Expr<_>

/// Builds an addition node from two integer expressions.
let add x y =
    let node = Add(x, y)
    node :> Expr<_>

/// Builds a conditional node from a condition and two branches.
let ifThenElse b l r =
    let node = IfThenElse<_>(b, l, r)
    node :> Expr<_>

/// Builds an application node from a function expression and its argument.
let app f x =
    let node = App<_,_>(f, x)
    node :> Expr<_>

/// Builds a lambda node from an F# function over expressions.
let lam f =
    let node = Lam<_,_>(f)
    node :> Expr<_>

/// Builds a fixpoint node from a functional expression.
let fix f =
    let node = Fix<_,_>(f)
    node :> Expr<_>

/// Function-style entry point for matching: runs `pattern` against `x` by
/// delegating to the node's own Match member and returns the matcher's result.
let pmatch (x : Expr<'T>) (pattern : IPatternMatch<'T, 'R>) : 'R =
    x.Match(pattern)


// example 1 : pretty print

/// Renders an expression as a string.
///
/// Const is printed with %A inside parentheses, Add as "(x + y)",
/// IfThenElse as "if b then x else y" and App as "(f x)"; sub-expressions
/// are rendered recursively.
///
/// Raises InvalidOperationException for Lam and Fix nodes: a Lam body is an
/// F# function over expressions and there is no variable case to apply it to,
/// so these cases are not printable here.
let rec pretty<'T> (expr : Expr<'T>) : string =
    pmatch expr {
        new IPatternMatch<'T, string> with
            member __.Const x = sprintf "(%A)" x
            member __.Add x y = sprintf "(%s + %s)" (pretty x) (pretty y)
            member __.IfThenElse b x y = sprintf "if %s then %s else %s" (pretty b) (pretty x) (pretty y)
            member __.App f x = sprintf "(%s %s)" (pretty f) (pretty x)
            // unhandled cases: fail with a message that says which case was hit
            // instead of an empty InvalidOperationException
            member __.Lam _ = invalidOp "pretty: Lam expressions cannot be pretty-printed"
            member __.Fix _ = invalidOp "pretty: Fix expressions cannot be pretty-printed"
    }

// Sample use of pretty: a conditional over option-typed literals.
ifThenElse (constant true) (constant (Some 12)) (constant None)
|> pretty

// example 2 : eval

// Unchecked conversion from 'T to 'S through obj. It is meant for places where
// two type variables are known to denote the same type but the type checker
// cannot see it; if the runtime types do not agree the unbox fails.
let cast<'S,'T> (x : 'T) : 'S =
    x |> box |> unbox<'S>

/// Evaluates an expression to the F# value it denotes.
///
/// The matcher returns 'T for every case. In the Add, Lam and Fix cases the
/// natural result type (int, 'S1 -> 'S2) is not known to the type checker to
/// be 'T, so the result goes through `cast`.
let rec eval<'T> (expr : Expr<'T>) : 'T =
    let evaluator =
        { new IPatternMatch<'T, 'T> with
            member __.Const value = value

            member __.Add lhs rhs =
                let sum = eval lhs + eval rhs
                cast<_, int> sum

            member __.IfThenElse cond whenTrue whenFalse =
                match eval cond with
                | true -> eval whenTrue
                | false -> eval whenFalse

            member __.App fn arg = eval fn (eval arg)

            // Wrap the incoming value as a literal, run the body on it and
            // evaluate the resulting expression.
            member __.Lam<'S1,'S2> body =
                cast<_, 'S1 -> 'S2> (fun v -> eval (body (constant v)))

            // Feed the functional a closure that re-evaluates the fixpoint
            // only when it is actually called.
            member __.Fix<'S1,'S2> step =
                cast<_, 'S1 -> 'S2> (eval step (fun arg -> eval (fix step) arg)) }
    pmatch expr evaluator

// Sample use of eval: apply a lambda that selects between 12 and 42 to `false`,
// which takes the second branch (42).
constant false
|> app (lam (fun flag -> ifThenElse flag (constant 12) (constant 42)))
|> eval

/// Multiplication built from `fix`, `lam`, `ifThenElse` and `add`:
/// for n = 0 the result is a function returning 0, otherwise a function
/// m -> m + self (n - 1) m. The zero test is done by evaluating `n` while
/// the expression is being constructed.
let multiply =
    let recursiveStep self =
        lam (fun n ->
            let nIsZero = constant (eval n = 0)
            let zeroCase = lam (fun _ -> constant 0)
            let successorCase =
                lam (fun m ->
                    let recursiveCall = app (app self (add n (constant -1))) m
                    add m recursiveCall)
            ifThenElse nIsZero zeroCase successorCase)
    fix (lam recursiveStep)

// Sample use of multiply: apply it to 6 and then to 7 and evaluate the result.
app (app multiply (constant 6)) (constant 7)
|> eval