8 people like it.

Monadic state lifting combinators

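The generic model for stateful computation, S → S × R, provides a convenient mechanism for threading the results of stateful computations, since the functor λR. S → S × R is a monad. But what happens if we want to thread the state itself? Unfortunately, the mapping λS. S → S × R is not even a functor: S occurs both as the argument (contravariantly) and in the result (covariantly), so a single function S → T cannot be used to map it. It turns out, however, that with a bit of trickery it can be made to behave like one: given a pair of maps in opposite directions, or a decomposition of a larger state into the smaller state plus a remainder, a computation can be transported from one state type to another. This snippet demonstrates a way to lift, project or inject stateful computations into ambient state monads.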

  1: 
  2: 
  3: 
  4: 
  5: 
  6: 
  7: 
  8: 
  9: 
 10: 
 11: 
 12: 
 13: 
 14: 
 15: 
 16: 
 17: 
 18: 
 19: 
 20: 
 21: 
 22: 
 23: 
 24: 
 25: 
 26: 
 27: 
 28: 
 29: 
 30: 
 31: 
 32: 
 33: 
 34: 
 35: 
 36: 
 37: 
 38: 
 39: 
 40: 
 41: 
 42: 
 43: 
 44: 
 45: 
 46: 
 47: 
 48: 
 49: 
 50: 
 51: 
 52: 
 53: 
 54: 
 55: 
 56: 
 57: 
 58: 
 59: 
 60: 
 61: 
 62: 
 63: 
 64: 
 65: 
 66: 
 67: 
 68: 
 69: 
 70: 
 71: 
 72: 
 73: 
 74: 
 75: 
 76: 
 77: 
 78: 
 79: 
 80: 
 81: 
 82: 
 83: 
 84: 
 85: 
 86: 
 87: 
 88: 
 89: 
 90: 
 91: 
 92: 
 93: 
 94: 
 95: 
 96: 
 97: 
 98: 
 99: 
100: 
101: 
102: 
103: 
104: 
105: 
106: 
107: 
108: 
109: 
110: 
111: 
112: 
113: 
114: 
115: 
116: 
117: 
118: 
119: 
/// A stateful computation: a transition function that consumes the current
/// state and produces the next state paired with a result of type 'A.
type State<'S,'A> = Stateful of ('S -> 'S * 'A)

/// Computation-expression builder for State<'S,'A>.
/// Delay invokes its argument immediately; this is safe because every State
/// value is itself a function of the state, so nothing stateful runs until the
/// resulting computation is applied to an initial state.
and StateBuilder() =
    // unwrap the transition function of a computation
    let runState (Stateful step) = step
    // computation that leaves the state alone and yields value
    let ret value = Stateful(fun s -> s, value)
    // run m, then feed its result to k and run the resulting computation
    // on the state m left behind
    let bind m k =
        Stateful(fun s0 ->
            let s1, a = runState m s0
            runState (k a) s1)

    member __.Return x = ret x
    member __.Bind (f, g) = bind f g
    member __.ReturnFrom f = f
    member __.Zero () = ret ()
    member __.Combine (f, g) = bind f (fun () -> g)
    member __.Delay f = f ()

/// The `state { ... }` computation-expression builder.
let state = StateBuilder()

/// Primitive operations on State<'S,'A>, plus combinators that move a stateful
/// computation between a "small" state type and a larger, ambient one.
module State =

    /// Runs the computation from the initial state x and returns only its
    /// result; the final state is discarded.
    let run x (Stateful f) = snd (f x)

    /// Yields the current state as the result, leaving the state untouched.
    let get () = Stateful(fun current -> current, current)

    /// Overwrites the state with t, ignoring the previous value.
    let set t = Stateful(fun _ -> t, ())

    /// Replaces the state with f applied to it (a "modify" operation,
    /// despite the name).
    let swap f = Stateful(fun current -> f current, ())

    /// Yields f applied to the current state, leaving the state untouched.
    let extract f = Stateful(fun current -> current, f current)

    /// Basic lifting combinator: transports a computation over 'T to one over
    /// 'S, using g to convert the incoming state and f to convert the outgoing
    /// state. Kept for completeness; inject/project are the practical forms.
    let lift (f : 'T -> 'S) (g : 'S -> 'T) (Stateful h) : State<'S,'A> =
        Stateful(fun outer ->
            let inner, result = h (g outer)
            f inner, result)

    /// For a decomposition 'T ~= 'S * 'S0 given by split/assemble, embeds a
    /// computation over the component 'S into one over the whole 'T
    /// (State<'S,'A> -> State<'T,'A>). The 'S0 component passes through unchanged.
    let inject (split : 'T -> 'S * 'S0) (assemble : 'S -> 'S0 -> 'T) (Stateful f) : State<'T,'A> =
        Stateful(fun whole ->
            let part, rest = split whole
            let part', result = f part
            assemble part' rest, result)

    /// For a decomposition 'T ~= 'S * 'S0, turns a computation over the whole 'T
    /// into one over the component 'S ('S0 -> State<'T,'A> -> State<'S,'A>).
    /// Each run starts the 'S0 component at s0; whatever 'S0 value the inner
    /// computation ends with is dropped.
    let project (split : 'T -> 'S * 'S0) (assemble : 'S -> 'S0 -> 'T) s0 (Stateful f) : State<'S,'A> =
        Stateful(fun part ->
            let whole', result = f (assemble part s0)
            fst (split whole'), result)

    /// Runs a computation over its own state, started at s0, as a computation
    /// with unit state, hiding that state from the caller.
    let init s0 f = project (fun s -> (), s) (fun () s -> s) s0 f

//
// example : generic level-order tree traversal
//
/// A binary tree carrying a value of type 'U at every leaf and every inner node.
type Tree<'U> =
    | Leaf of 'U
    | Node of Tree<'U> * 'U * Tree<'U>

// need an immutable queue implementation!
/// Immutable FIFO queue built from two lists: elements are dequeued from
/// `front`, newly enqueued elements are pushed onto `back` (newest first), and
/// `back` is reversed into `front` whenever `front` runs out.
type Queue<'T> = private { back : 'T list ; front : 'T list }
with
    /// Creates a queue whose dequeue order is the order of xs.
    static member ofList xs = { back = [] ; front = xs }

    /// Returns a new queue with the elements of ts appended in order.
    member self.Enqueue ts =
        // pushing each element onto `back` keeps it in newest-first order
        { self with back = List.fold (fun acc t -> t :: acc) self.back ts }

    /// Returns the remaining queue paired with the oldest element.
    /// Fails with "queue underflow!" when the queue is empty.
    member self.Dequeue () =
        match self.front, self.back with
        | x :: rest, _ -> { self with front = rest }, x
        | [], [] -> failwith "queue underflow!"
        | [], pending -> { back = [] ; front = List.rev pending }.Dequeue()

    /// True when the queue holds no elements.
    member self.IsEmpty = List.isEmpty self.front && List.isEmpty self.back


// we want to define a higher-order breadth-first tree traversal using our state monad.
// the higher-order function threads its own internal state, namely a queue of all nodes
// waiting to be traversed. naturally, we do not want to expose this internal state to the
// input function, as this may mess up the traversal pattern. enter state lifting.
/// Breadth-first (level-order) traversal of t: foldF runs on each node's value
/// in level order, threading the caller's state 'S through those calls.
/// The queue of subtrees still to be visited is internal state. It is paired
/// with 'S while traversing and projected away at the end, so foldF can neither
/// observe nor disturb the traversal order.
let levelorder (foldF : 'U -> State<'S,unit>) (t : Tree<'U>) =
    // state lifting rules: the caller's state is 'S, the internal state is
    // 'S * Queue<Tree<'U>>
    let injectLeft f = State.inject id (fun x y -> x, y) f
    let projectLeft q0 f = State.project id (fun x y -> x, y) q0 f

    // replace the queue component, keeping the caller's component as-is
    let updateQueue (q : Queue<_>) = State.swap (fun (s, _) -> s, q)

    let rec traverse () =
        state {
            let! _, (pending : Queue<_>) = State.get ()

            if pending.IsEmpty then
                return ()
            else
                let rest, subtree = pending.Dequeue()
                // leaves and inner nodes differ only in which children get
                // queued, so the visit/enqueue/recurse steps are shared
                let value, children =
                    match subtree with
                    | Leaf u -> u, []
                    | Node (l, u, r) -> u, [l; r]

                do! foldF value |> injectLeft
                do! updateQueue (rest.Enqueue children)
                do! traverse ()
        }

    traverse () |> projectLeft (Queue<_>.ofList [t])


/// Collects the values of t in level order, as a computation with unit state.
let test t =
    state {
        // Prepend (O(1)) while traversing and reverse once at the end.
        // Appending with `vs @ [v]` would cost O(n) per node, O(n^2) overall.
        do! levelorder (fun v -> State.swap (fun vs -> v :: vs)) t

        return! State.extract List.rev
    } |> State.init []

// Sample tree numbered so that its level-order traversal visits 1..11 in order.
let tree =
    Node(
        Node(Node(Leaf 8, 4, Leaf 9), 2, Leaf 5),
        1,
        Node(Leaf 6, 3, Node(Leaf 10, 7, Leaf 11)))

// evaluates to [1; 2; 3; 4; 5; 6; 7; 8; 9; 10; 11]
State.run () (test tree)
union case State.Stateful: ('S -> 'S * 'A) -> State<'S,'A>
Multiple items
type StateBuilder =
  new : unit -> StateBuilder
  member Bind : f:State<'f,'g> * g:('g -> State<'f,'h>) -> State<'f,'h>
  member Combine : f:State<'b,unit> * g:State<'b,'c> -> State<'b,'c>
  member Delay : f:(unit -> 'a) -> 'a
  member Return : x:'i -> State<'j,'i>
  member ReturnFrom : f:'e -> 'e
  member Zero : unit -> State<'d,unit>

Full name: Script.StateBuilder

--------------------
new : unit -> StateBuilder
val f : ('a -> 'a * 'b)
Multiple items
val unit : ('a -> State<'b,'a>)

--------------------
type unit = Unit

Full name: Microsoft.FSharp.Core.unit
val x : 'a
val s : 'b
val f : State<'a,'b>
val g : ('b -> State<'a,'c>)
val s : 'a
val s' : 'a
val a : 'b
member StateBuilder.Return : x:'i -> State<'j,'i>

Full name: Script.StateBuilder.Return
val x : 'i
val __ : StateBuilder
member StateBuilder.Bind : f:State<'f,'g> * g:('g -> State<'f,'h>) -> State<'f,'h>

Full name: Script.StateBuilder.Bind
val f : State<'f,'g>
val g : ('g -> State<'f,'h>)
member StateBuilder.ReturnFrom : f:'e -> 'e

Full name: Script.StateBuilder.ReturnFrom
val f : 'e
member StateBuilder.Zero : unit -> State<'d,unit>

Full name: Script.StateBuilder.Zero
val s : 'd
member StateBuilder.Combine : f:State<'b,unit> * g:State<'b,'c> -> State<'b,'c>

Full name: Script.StateBuilder.Combine
val f : State<'b,unit>
val g : State<'b,'c>
member StateBuilder.Delay : f:(unit -> 'a) -> 'a

Full name: Script.StateBuilder.Delay
val f : (unit -> 'a)
val state : StateBuilder

Full name: Script.state
type State<'S,'A> = | Stateful of ('S -> 'S * 'A)

Full name: Script.State<_,_>
val run : x:'a -> State<'a,'b> -> 'b

Full name: Script.State.run
val snd : tuple:('T1 * 'T2) -> 'T2

Full name: Microsoft.FSharp.Core.Operators.snd
val get : unit -> State<'a,'a>

Full name: Script.State.get
val t : 'a
val set : t:'a -> State<'a,unit>

Full name: Script.State.set
val swap : f:('a -> 'a) -> State<'a,unit>

Full name: Script.State.swap
val f : ('a -> 'a)
val extract : f:('a -> 'b) -> State<'a,'b>

Full name: Script.State.extract
val f : ('a -> 'b)
val lift : f:('T -> 'S) -> g:('S -> 'T) -> State<'T,'A> -> State<'S,'A>

Full name: Script.State.lift


 this is the basic state lifting combinator which given a pair
 of opposing arrows induces a natural lifting on states.
 not very practical in real life, but added for the sake of completeness.
val f : ('T -> 'S)
val g : ('S -> 'T)
val h : ('T -> 'T * 'A)
val s : 'S
val t : 'T
val r : 'A
val inject : split:('T -> 'S * 'S0) -> assemble:('S -> 'S0 -> 'T) -> State<'S,'A> -> State<'T,'A>

Full name: Script.State.inject


 for any decomposition 'T ~= 'S * 'S0, returns the natural embedding
 State<'S,'A> -> State<'T,'A>
val split : ('T -> 'S * 'S0)
val assemble : ('S -> 'S0 -> 'T)
val f : ('S -> 'S * 'A)
val s0 : 'S0
val s' : 'S
val a : 'A
val project : split:('T -> 'S * 'S0) -> assemble:('S -> 'S0 -> 'T) -> s0:'S0 -> State<'T,'A> -> State<'S,'A>

Full name: Script.State.project


 for any decomposition 'T ~= 'S * 'S0, returns the natural projection
 'S0 -> State<'T,'A> -> State<'S,'A>
val f : ('T -> 'T * 'A)
val init : s0:'a -> f:State<'a,'b> -> State<unit,'b>

Full name: Script.State.init
val s0 : 'a
type Tree<'U> =
  | Leaf of 'U
  | Node of Tree<'U> * 'U * Tree<'U>

Full name: Script.Tree<_>
union case Tree.Leaf: 'U -> Tree<'U>
union case Tree.Node: Tree<'U> * 'U * Tree<'U> -> Tree<'U>
type Queue<'T> =
  private {back: 'T list;
           front: 'T list;}
  member Dequeue : unit -> Queue<'T> * 'T
  member Enqueue : ts:'T list -> Queue<'T>
  member IsEmpty : bool
  static member ofList : xs:'a list -> Queue<'a>

Full name: Script.Queue<_>
Queue.back: 'T list
type 'T list = List<'T>

Full name: Microsoft.FSharp.Collections.list<_>
Queue.front: 'T list
static member Queue.ofList : xs:'a list -> Queue<'a>

Full name: Script.Queue`1.ofList
val xs : 'a list
val self : Queue<'T>
member Queue.Enqueue : ts:'T list -> Queue<'T>

Full name: Script.Queue`1.Enqueue
val ts : 'T list
val ts' : 'T list
member Queue.Dequeue : unit -> Queue<'T> * 'T

Full name: Script.Queue`1.Dequeue
val failwith : message:string -> 'T

Full name: Microsoft.FSharp.Core.Operators.failwith
val ys : 'T list
Multiple items
module List

from Microsoft.FSharp.Collections

--------------------
type List<'T> =
  | ( [] )
  | ( :: ) of Head: 'T * Tail: 'T list
  interface IEnumerable
  interface IEnumerable<'T>
  member Head : 'T
  member IsEmpty : bool
  member Item : index:int -> 'T with get
  member Length : int
  member Tail : 'T list
  static member Cons : head:'T * tail:'T list -> 'T list
  static member Empty : 'T list

Full name: Microsoft.FSharp.Collections.List<_>
val rev : list:'T list -> 'T list

Full name: Microsoft.FSharp.Collections.List.rev
val x : 'T
val xs : 'T list
member Queue.IsEmpty : bool

Full name: Script.Queue`1.IsEmpty
val levelorder : foldF:('U -> State<'S,unit>) -> t:Tree<'U> -> State<'S,unit>

Full name: Script.levelorder
val foldF : ('U -> State<'S,unit>)
Multiple items
module State

from Script

--------------------
type State<'S,'A> = | Stateful of ('S -> 'S * 'A)

Full name: Script.State<_,_>
type unit = Unit

Full name: Microsoft.FSharp.Core.unit
val t : Tree<'U>
val injectLeft : (State<'a,'b> -> State<('a * 'c),'b>)
val id : x:'T -> 'T

Full name: Microsoft.FSharp.Core.Operators.id
val y : 'c
val projectLeft : ('a -> State<('b * 'a),'c> -> State<'b,'c>)
val q0 : 'a
val f : State<('b * 'a),'c>
val x : 'b
val y : 'a
val updateQueue : (Queue<'a> -> State<('b * Queue<'a>),unit>)
val q : Queue<'a>
val traverse : (unit -> State<('S * Queue<Tree<'U>>),unit>)
val q : Queue<Tree<'U>>
property Queue.IsEmpty: bool
member Queue.Dequeue : unit -> Queue<'T> * 'T
val u : 'U
val l : Tree<'U>
val r : Tree<'U>
member Queue.Enqueue : ts:'T list -> Queue<'T>
val test : t:Tree<'a> -> State<unit,'a list>

Full name: Script.test
val t : Tree<'a>
val v : 'a
val vs : 'a list
val tree : Tree<int>

Full name: Script.tree

More information

Link:http://fssnip.net/cL
Posted:12 years ago
Author:Eirik Tsarpalis
Tags: state monad , lifting combinator , monad , split