3 people like it.

Parallel Strassen Matrix Multiplication

Strassen's matrix multiplication algorithm runs in about O(n^2.81) time instead of the O(n^3) of the standard approach, so it outperforms the standard method on large matrices. This program implements a parallel version of it.

 1: 
 2: 
 3: 
 4: 
 5: 
 6: 
 7: 
 8: 
 9: 
10: 
11: 
12: 
13: 
14: 
15: 
16: 
17: 
18: 
19: 
20: 
21: 
22: 
23: 
24: 
25: 
26: 
27: 
28: 
29: 
30: 
31: 
32: 
33: 
34: 
35: 
36: 
37: 
38: 
39: 
40: 
41: 
42: 
43: 
44: 
45: 
46: 
47: 
48: 
49: 
50: 
open Microsoft.FSharp.Math
open System.Threading.Tasks
open System.IO

/// Cuts `original` into its four quadrants and returns them as
/// (top-left, top-right, bottom-left, bottom-right).
/// The cut is made after rows / 2 rows and cols / 2 columns (rounded down),
/// so when a dimension is odd the bottom/right quadrants receive the extra
/// row/column.
let splitMatrix (original : Matrix<float>) =
    let (rowCount, colCount) = original.Dimensions
    let midRow = rowCount / 2
    let midCol = colCount / 2
    // Inclusive index ranges of each half.
    let topRows, bottomRows = (0, midRow - 1), (midRow, rowCount - 1)
    let leftCols, rightCols = (0, midCol - 1), (midCol, colCount - 1)
    let block (r0, r1) (c0, c1) = original.[r0..r1, c0..c1]
    (block topRows leftCols, block topRows rightCols,
     block bottomRows leftCols, block bottomRows rightCols)
    

/// Reassembles a matrix from its four quadrants laid out as
///     [ a11  a12 ]
///     [ a21  a22 ]
/// This is the inverse of splitMatrix. Quadrants may differ in size (as
/// splitMatrix produces for odd dimensions) provided blocks in the same
/// block-row share a row count and blocks in the same block-column share a
/// column count. The result is (rows of a11 + rows of a21) x
/// (cols of a11 + cols of a12).
/// Raises ArgumentException (via invalidArg) when the quadrant sizes do not line up.
let integrateMatrix(a11 : Matrix<float>, a12 : Matrix<float>, a21 : Matrix<float>, a22 : Matrix<float>) =
    let (topRows, leftCols) = a11.Dimensions
    let (bottomRows, rightCols) = a22.Dimensions
    // The off-diagonal blocks must fit between a11 and a22; otherwise the
    // element lookups below would read out of range or skip entries.
    let consistent =
        a12.Dimensions = (topRows, rightCols) && a21.Dimensions = (bottomRows, leftCols)
    if not consistent then
        invalidArg "a12, a21" "integrateMatrix: quadrant sizes do not line up with a11 and a22"
    let helper i j =
        match i < topRows, j < leftCols with
        | true, true -> a11.[i, j]
        | true, false -> a12.[i, j - leftCols]
        | false, true -> a21.[i - topRows, j]
        | false, false -> a22.[i - topRows, j - leftCols]
    Matrix.init (topRows + bottomRows) (leftCols + rightCols) helper

/// Returns `m` padded with zero rows at the bottom and zero columns at the
/// right so that it is `rows` x `cols`. Returns `m` itself when it already
/// has that size.
let padMatrix (m : Matrix<float>) (rows : int) (cols : int) =
    let (mRows, mCols) = m.Dimensions
    if mRows = rows && mCols = cols then m
    else Matrix.init rows cols (fun i j -> if i < mRows && j < mCols then m.[i, j] else 0.0)

/// Multiplies A (m x k) by B (k x n) using Strassen's algorithm.
/// - Blocks with any dimension <= 64 are multiplied with the ordinary `*`
///   operator.
/// - Larger blocks are split into quadrants. Six of the seven Strassen
///   sub-products run as tasks started with Task.Factory.StartNew; the
///   seventh runs on the calling thread.
/// - An odd dimension is padded with one zero row/column before splitting,
///   and the padding is trimmed from the result, so any sizes work.
let rec Strassen (A : Matrix<float>, B : Matrix<float>) =
    let (rows, inner) = A.Dimensions
    let (bRows, cols) = B.Dimensions
    // Incompatible inner dimensions are also delegated to `*`, which
    // reports the mismatch.
    if rows <= 64 || inner <= 64 || cols <= 64 || inner <> bRows then
        A * B
    else
        let evenUp d = if d % 2 = 0 then d else d + 1
        // The Strassen identities need four equally sized quadrants. Zero
        // padding leaves the top-left rows x cols part of the product unchanged.
        let a = padMatrix A (evenUp rows) (evenUp inner)
        let b = padMatrix B (evenUp inner) (evenUp cols)
        let (a11, a12, a21, a22) = splitMatrix(a)
        let (b11, b12, b21, b22) = splitMatrix(b)
        let M1 = Task.Factory.StartNew(fun () -> Strassen ((a11 + a22), (b11 + b22)))
        let M2 = Task.Factory.StartNew(fun () -> Strassen ((a21 + a22), b11))
        let M3 = Task.Factory.StartNew(fun () -> Strassen (a11, (b12 - b22)))
        let M4 = Task.Factory.StartNew(fun () -> Strassen (a22, (b21 - b11)))
        let M5 = Task.Factory.StartNew(fun () -> Strassen ((a11 + a12), b22))
        let M6 = Task.Factory.StartNew(fun () -> Strassen ((a21 - a11), (b11 + b12)))
        let m7 = Strassen ((a12 - a22), (b21 + b22))
        // Block until every task has produced its sub-product.
        let m1, m2, m3 = M1.Result, M2.Result, M3.Result
        let m4, m5, m6 = M4.Result, M5.Result, M6.Result
        let c11 = m1 + m4 - m5 + m7
        let c12 = m3 + m5
        let c21 = m2 + m4
        let c22 = m1 - m2 + m3 + m6
        let product = integrateMatrix(c11, c12, c21, c22)
        // Drop any padding row/column added above.
        if product.Dimensions = (rows, cols) then product
        else product.[0..(rows - 1), 0..(cols - 1)]
namespace Microsoft
namespace Microsoft.FSharp
namespace System
namespace System.Threading
namespace System.Threading.Tasks
namespace System.IO
val splitMatrix : original:'a -> 'b * 'c * 'd * 'e

Full name: Script.splitMatrix
val original : 'a
Multiple items
val float : value:'T -> float (requires member op_Explicit)

Full name: Microsoft.FSharp.Core.Operators.float

--------------------
type float = System.Double

Full name: Microsoft.FSharp.Core.float

--------------------
type float<'Measure> = float

Full name: Microsoft.FSharp.Core.float<_>
val oldRows : int
val oldCols : int
val newRows : int
val newCols : int
val c11 : 'b
val c12 : 'c
val c21 : 'd
val c22 : 'e
val integrateMatrix : a11:'a * a12:'b * a21:'c * a22:'d -> 'e

Full name: Script.integrateMatrix
val a11 : 'a
val a12 : 'b
val a21 : 'c
val a22 : 'd
val rows : int
val cols : int
val helper : (int -> int -> 'f)
val i : int
val j : int
val Strassen : A:int * B:int -> int

Full name: Script.Strassen
val A : int
val B : int
val cols : obj
val a11 : int
val a12 : int
val a21 : int
val a22 : int
val b11 : int
val b12 : int
val b21 : int
val b22 : int
val M1 : Task<int>
Multiple items
type Task =
  new : action:Action -> Task + 7 overloads
  member AsyncState : obj
  member ContinueWith : continuationAction:Action<Task> -> Task + 9 overloads
  member CreationOptions : TaskCreationOptions
  member Dispose : unit -> unit
  member Exception : AggregateException
  member Id : int
  member IsCanceled : bool
  member IsCompleted : bool
  member IsFaulted : bool
  ...

Full name: System.Threading.Tasks.Task

--------------------
type Task<'TResult> =
  inherit Task
  new : function:Func<'TResult> -> Task<'TResult> + 7 overloads
  member ContinueWith : continuationAction:Action<Task<'TResult>> -> Task + 9 overloads
  member Result : 'TResult with get, set
  static member Factory : TaskFactory<'TResult>

Full name: System.Threading.Tasks.Task<_>

--------------------
Task(action: System.Action) : unit
Task(action: System.Action, cancellationToken: System.Threading.CancellationToken) : unit
Task(action: System.Action, creationOptions: TaskCreationOptions) : unit
Task(action: System.Action<obj>, state: obj) : unit
Task(action: System.Action, cancellationToken: System.Threading.CancellationToken, creationOptions: TaskCreationOptions) : unit
Task(action: System.Action<obj>, state: obj, cancellationToken: System.Threading.CancellationToken) : unit
Task(action: System.Action<obj>, state: obj, creationOptions: TaskCreationOptions) : unit
Task(action: System.Action<obj>, state: obj, cancellationToken: System.Threading.CancellationToken, creationOptions: TaskCreationOptions) : unit

--------------------
Task(function: System.Func<'TResult>) : unit
Task(function: System.Func<'TResult>, cancellationToken: System.Threading.CancellationToken) : unit
Task(function: System.Func<'TResult>, creationOptions: TaskCreationOptions) : unit
Task(function: System.Func<obj,'TResult>, state: obj) : unit
Task(function: System.Func<'TResult>, cancellationToken: System.Threading.CancellationToken, creationOptions: TaskCreationOptions) : unit
Task(function: System.Func<obj,'TResult>, state: obj, cancellationToken: System.Threading.CancellationToken) : unit
Task(function: System.Func<obj,'TResult>, state: obj, creationOptions: TaskCreationOptions) : unit
Task(function: System.Func<obj,'TResult>, state: obj, cancellationToken: System.Threading.CancellationToken, creationOptions: TaskCreationOptions) : unit
Multiple items
property Task.Factory: TaskFactory

--------------------
property Task.Factory: TaskFactory<'TResult>
Multiple items
TaskFactory.StartNew<'TResult>(function: System.Func<'TResult>) : Task<'TResult>
   (+0 other overloads)
TaskFactory.StartNew(action: System.Action) : Task
   (+0 other overloads)
TaskFactory.StartNew<'TResult>(function: System.Func<obj,'TResult>, state: obj) : Task<'TResult>
   (+0 other overloads)
TaskFactory.StartNew<'TResult>(function: System.Func<'TResult>, creationOptions: TaskCreationOptions) : Task<'TResult>
   (+0 other overloads)
TaskFactory.StartNew<'TResult>(function: System.Func<'TResult>, cancellationToken: System.Threading.CancellationToken) : Task<'TResult>
   (+0 other overloads)
TaskFactory.StartNew(action: System.Action<obj>, state: obj) : Task
   (+0 other overloads)
TaskFactory.StartNew(action: System.Action, creationOptions: TaskCreationOptions) : Task
   (+0 other overloads)
TaskFactory.StartNew(action: System.Action, cancellationToken: System.Threading.CancellationToken) : Task
   (+0 other overloads)
TaskFactory.StartNew<'TResult>(function: System.Func<obj,'TResult>, state: obj, creationOptions: TaskCreationOptions) : Task<'TResult>
   (+0 other overloads)
TaskFactory.StartNew<'TResult>(function: System.Func<obj,'TResult>, state: obj, cancellationToken: System.Threading.CancellationToken) : Task<'TResult>
   (+0 other overloads)

--------------------
TaskFactory.StartNew(function: System.Func<'TResult>) : Task<'TResult>
TaskFactory.StartNew(function: System.Func<obj,'TResult>, state: obj) : Task<'TResult>
TaskFactory.StartNew(function: System.Func<'TResult>, creationOptions: TaskCreationOptions) : Task<'TResult>
TaskFactory.StartNew(function: System.Func<'TResult>, cancellationToken: System.Threading.CancellationToken) : Task<'TResult>
TaskFactory.StartNew(function: System.Func<obj,'TResult>, state: obj, creationOptions: TaskCreationOptions) : Task<'TResult>
TaskFactory.StartNew(function: System.Func<obj,'TResult>, state: obj, cancellationToken: System.Threading.CancellationToken) : Task<'TResult>
TaskFactory.StartNew(function: System.Func<'TResult>, cancellationToken: System.Threading.CancellationToken, creationOptions: TaskCreationOptions, scheduler: TaskScheduler) : Task<'TResult>
TaskFactory.StartNew(function: System.Func<obj,'TResult>, state: obj, cancellationToken: System.Threading.CancellationToken, creationOptions: TaskCreationOptions, scheduler: TaskScheduler) : Task<'TResult>
val M2 : Task<int>
val M3 : Task<int>
val M4 : Task<int>
val M5 : Task<int>
val M6 : Task<int>
val M7 : int
val c11 : int
property Task.Result: int
val c12 : int
val c21 : int
val c22 : int
Raw view Test code New version

More information

Link:http://fssnip.net/7N
Posted:13 years ago
Author:Zhuobo Feng
Tags: strassen , matrix multiplication