open System
open System.Threading.Tasks

type Async with
    /// Awaits a non-generic Task from within an F# async workflow.
    /// Faults: when the task faults with exactly one inner exception, that exception
    /// is raised as-is instead of the wrapping AggregateException; with more than one
    /// inner exception the AggregateException itself is raised.
    /// Cancellation: a canceled task raises a TaskCanceledException whose Task
    /// property is the awaited task. It is delivered through the error continuation,
    /// so callers can catch it with an ordinary try/with.
    static member AwaitTaskCorrect(task : Task) : Async<unit> =
        // The third (cancellation) continuation is intentionally unused:
        // task cancellation is reported as a catchable error instead.
        Async.FromContinuations(fun (sc, ec, _) ->
            let onCompleted (completed : Task) =
                if completed.IsFaulted then
                    let agg = completed.Exception
                    if agg.InnerExceptions.Count = 1 then ec agg.InnerExceptions.[0]
                    else ec (agg :> exn)
                elif completed.IsCanceled then
                    ec (TaskCanceledException(completed) :> exn)
                else
                    sc ()
            // Use the default (thread-pool) scheduler rather than inheriting
            // TaskScheduler.Current, which may be a custom or UI scheduler when
            // this is invoked from code already running inside a task.
            task.ContinueWith(Action<Task>(onCompleted), TaskScheduler.Default)
            |> ignore)

    /// Awaits a Task<'T> from within an F# async workflow and yields its result.
    /// Faults: when the task faults with exactly one inner exception, that exception
    /// is raised as-is instead of the wrapping AggregateException; with more than one
    /// inner exception the AggregateException itself is raised.
    /// Cancellation: a canceled task raises a TaskCanceledException whose Task
    /// property is the awaited task. It is delivered through the error continuation,
    /// so callers can catch it with an ordinary try/with.
    static member AwaitTaskCorrect(task : Task<'T>) : Async<'T> =
        // The third (cancellation) continuation is intentionally unused:
        // task cancellation is reported as a catchable error instead.
        Async.FromContinuations(fun (sc, ec, _) ->
            let onCompleted (completed : Task<'T>) =
                if completed.IsFaulted then
                    let agg = completed.Exception
                    if agg.InnerExceptions.Count = 1 then ec agg.InnerExceptions.[0]
                    else ec (agg :> exn)
                elif completed.IsCanceled then
                    ec (TaskCanceledException(completed) :> exn)
                else
                    // Safe: the task has completed successfully, so Result does not block.
                    sc completed.Result
            // Use the default (thread-pool) scheduler rather than inheriting
            // TaskScheduler.Current, which may be a custom or UI scheduler when
            // this is invoked from code already running inside a task.
            task.ContinueWith(Action<Task<'T>>(onCompleted), TaskScheduler.Default)
            |> ignore)


// examples

/// Starts a task via Task.Factory whose body throws `error`,
/// so the returned Task<'T> ends up in the Faulted state.
let mkFailingTask (error : exn) : Task<_> =
    let factory = Task.Factory
    factory.StartNew<_>(fun () ->
        // Throwing inside the task body is what faults the task.
        raise error)

/// Awaits, via `taskAwaiter`, a task that faults with an ArgumentException.
/// Returns true if the ArgumentException itself reached the try/with, and false
/// if it arrived wrapped in an AggregateException. Any other exception propagates.
let test1 taskAwaiter =
    let workflow =
        async {
            try
                let failing = mkFailingTask (ArgumentException())
                return! taskAwaiter failing
            with
            | :? ArgumentException -> return true
            | :? AggregateException -> return false
        }
    Async.RunSynchronously workflow


// Print the results: a bare top-level bool is silently discarded when the script
// is run (FS0020), so the comparison would otherwise be invisible.
printfn "test1 Async.AwaitTask        = %b (expected false)" (test1 Async.AwaitTask)
printfn "test1 Async.AwaitTaskCorrect = %b (expected true)" (test1 Async.AwaitTaskCorrect)


/// Awaits, via `taskAwaiter`, a task whose body throws AggregateException("kaboom!").
/// Returns true only if the caught AggregateException is that original one (message
/// "kaboom!"), and false for any other AggregateException, e.g. an outer wrapper.
/// Non-AggregateException errors propagate.
let test2 taskAwaiter =
    let workflow =
        async {
            try
                let failing = mkFailingTask (AggregateException("kaboom!"))
                return! taskAwaiter failing
            with
            | :? AggregateException as agg -> return agg.Message = "kaboom!"
        }
    Async.RunSynchronously workflow

// Print the results: a bare top-level bool is silently discarded when the script
// is run (FS0020), so the comparison would otherwise be invisible.
printfn "test2 Async.AwaitTask        = %b (expected false)" (test2 Async.AwaitTask)
printfn "test2 Async.AwaitTaskCorrect = %b (expected true)" (test2 Async.AwaitTaskCorrect)