"How you might create a Scala matrix library in a functional programming style"
Sure, you can program in Scala just as you would in Java and still get the advantages of its cleaner syntax. But if you really want to explore the power of Scala, you should try some functional programming.
One sign that you are programming in a largely functional manner is that you prefer val declarations to var declarations, and immutable classes to mutable ones.
To implement matrices, let's use the immutable lists created by Scala's default List(…) call.
Arbitrarily choosing a row-major representation, we will represent a row of a matrix as a list of doubles. For convenience, let's create a type alias for that:
type Row = List[Double]
A matrix can then be a list of rows:
type Matrix = List[Row]
One fundamental operation we will be doing is a dot product. Let's implement this in map-reduce style:
def dotProd(v1:Row,v2:Row) =
  v1.zip( v2 ).
    map{ t:(Double,Double) => t._1 * t._2 }.
    reduceLeft(_ + _)
The zip method combines two parallel lists of numbers into a single list of pairs. The map method multiplies the two elements of each pair together, and the reduceLeft method sums the resulting list to produce a single number.
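To make the three stages concrete, here is a sketch that names the intermediate results of the same zip/map/reduceLeft pipeline. The object name DotProdDemo is purely illustrative, and the pattern-matching literal { case (a, b) => a * b } is an equivalent, commonly used alternative to the t._1 * t._2 form.

```scala
object DotProdDemo {
  type Row = List[Double]

  // Same zip / map / reduceLeft pipeline as dotProd, with each stage named.
  // Example values in the comments are for v1 = List(1.0, 2.0, 3.0), v2 = List(4.0, 5.0, 6.0).
  def dotProd(v1: Row, v2: Row): Double = {
    val pairs    = v1.zip(v2)                         // List((1.0,4.0), (2.0,5.0), (3.0,6.0))
    val products = pairs.map { case (a, b) => a * b } // List(4.0, 10.0, 18.0)
    products.reduceLeft(_ + _)                        // 32.0
  }
}
```

Two properties of this pipeline are worth knowing: zip silently truncates to the shorter list, so vectors of mismatched length do not raise an error, and reduceLeft throws an exception on an empty list (products.sum would return 0.0 instead).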
OK, now let's see how we can create some of the standard matrix functions.
First here is how we can transpose a matrix:
def transpose(m:Matrix):Matrix =
  if(m.head.isEmpty) Nil
  else m.map(_.head) :: transpose(m.map(_.tail))
Notice that this is a recursive function. It uses the construct firstRow :: remainderOfRows to build the list of rows that makes up the output matrix. It computes the first output row by using map to take the first element of each input row, and it computes the remaining rows by recursively transposing the matrix formed from the tail of each input row (i.e. every element except the first). The recursion stops once the input rows have been used up, i.e. when the first row is empty.
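The following sketch traces the recursion on a small 2x3 matrix; the object name TransposeDemo is just for illustration. It also checks the result against the transpose method that Scala's standard List already provides for nested lists, which should agree for rectangular input.

```scala
object TransposeDemo {
  type Row = List[Double]
  type Matrix = List[Row]

  // The recursive transpose from the text, unchanged.
  def transpose(m: Matrix): Matrix =
    if (m.head.isEmpty) Nil
    else m.map(_.head) :: transpose(m.map(_.tail))

  val m: Matrix = List(List(1.0, 2.0, 3.0),
                       List(4.0, 5.0, 6.0))

  // How the recursion unfolds for m:
  //   heads List(1.0, 4.0), tails List(List(2.0, 3.0), List(5.0, 6.0))
  //   heads List(2.0, 5.0), tails List(List(3.0), List(6.0))
  //   heads List(3.0, 6.0), tails List(List(), List())  -> first row empty, so Nil
  // giving List(List(1.0, 4.0), List(2.0, 5.0), List(3.0, 6.0)).
  val mT: Matrix = transpose(m)
}
```

Note that, as written, transpose(Nil) throws, because Nil.head is undefined; a guard such as m.isEmpty || m.head.isEmpty would handle the empty matrix.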
How about matrix multiplication? That turns out to be pretty straightforward using standard Scala "for comprehensions":
def mXm( m1:Matrix, m2:Matrix ) =
  for( m1row <- m1 ) yield
    for( m2col <- transpose(m2) ) yield
      dotProd( m1row, m2col )
Here we directly implement the matrix multiplication formula, iterating through the rows of the first matrix and the columns of the second matrix, and calculating the dot-product of each. Note that to iterate over the columns of a matrix we actually iterate over the rows of the transpose of the matrix.
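One detail of mXm: because the inner generator transpose(m2) sits inside the outer for, and the comprehension desugars to m1.map(m1row => transpose(m2).map(...)), the transpose is recomputed for every row of m1. A sketch of a variant that transposes once and rejects incompatible shapes up front (the object name MultiplyDemo and the require check are additions, not part of the original library):

```scala
object MultiplyDemo {
  type Row = List[Double]
  type Matrix = List[Row]

  def dotProd(v1: Row, v2: Row): Double =
    v1.zip(v2).map { case (a, b) => a * b }.reduceLeft(_ + _)

  def transpose(m: Matrix): Matrix =
    if (m.head.isEmpty) Nil
    else m.map(_.head) :: transpose(m.map(_.tail))

  // Variant of mXm: transpose m2 once, and fail fast on incompatible shapes.
  def mXm(m1: Matrix, m2: Matrix): Matrix = {
    require(m1.forall(_.length == m2.length),
            "each row of m1 must have as many entries as m2 has rows")
    val m2cols = transpose(m2) // evaluated once, not once per row of m1
    for (m1row <- m1) yield
      for (m2col <- m2cols) yield dotProd(m1row, m2col)
  }
}
```

The body is still a pure function of its inputs; the local val only avoids repeated work, and the require turns the silent truncation of zip into an IllegalArgumentException.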
However, it would be nice to be able to use standard mathematical notation like A*B and Aᵀ when coding with matrices. Well, we can, using the power of Scala's operator identifiers, infix syntax, and implicit conversions:
case class RichMatrix(m:Matrix){
  def T = transpose(m)
  def *(that:RichMatrix) = mXm( this.m, that.m )
}
implicit def pimp(m:Matrix) = new RichMatrix(m)
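Two practical notes on compiling this. In Scala 2 a def such as pimp cannot sit at the top level of a source file; it has to live inside an object (or package object) that client code imports. And since Scala 2.10, the wrapper-class-plus-implicit-def pairing is usually written as a single implicit class, which does not trigger the feature warning that a bare implicit def conversion gets unless scala.language.implicitConversions is imported. The sketch below shows that packaging under the hypothetical object name MatrixLib; it is an alternative formulation, not the original code.

```scala
object MatrixLib {
  type Row = List[Double]
  type Matrix = List[Row]

  def dotProd(v1: Row, v2: Row): Double =
    v1.zip(v2).map { case (a, b) => a * b }.reduceLeft(_ + _)

  def transpose(m: Matrix): Matrix =
    if (m.head.isEmpty) Nil
    else m.map(_.head) :: transpose(m.map(_.tail))

  def mXm(m1: Matrix, m2: Matrix): Matrix =
    for (m1row <- m1) yield
      for (m2col <- transpose(m2)) yield dotProd(m1row, m2col)

  // An implicit class declares the wrapper and the implicit conversion in one step.
  implicit class RichMatrix(val m: Matrix) {
    def T: Matrix = transpose(m)
    def *(that: Matrix): Matrix = mXm(m, that)
  }
}
```

Client code then writes import MatrixLib._ and can use M.T and A * B directly on List[List[Double]] values.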
The user of this library never needs to instantiate RichMatrix objects explicitly: they are created automatically from List[List[Double]] objects whenever the T or * method is called on one. This allows you to write checks like the following (must_== is a BDD-style equality assertion):
val M = List(
List( 1.0, 2.0, 3.0 ),
List( 4.0, 5.0, 6.0 )
)
val MT = List(
List( 1.0, 4.0 ),
List( 2.0, 5.0 ),
List( 3.0, 6.0 )
)
M.T must_== MT
val A = List(List( 2.0, 0.0 ),
List( 3.0,-1.0 ),
List( 0.0, 1.0 ),
List( 1.0, 1.0 ))
val B = List(List( 1.0, 0.0, 2.0 ),
List( 4.0, -1.0, 0.0 ))
val C = List(List( 2.0, 0.0, 4.0 ),
List( -1.0, 1.0, 6.0 ),
List( 4.0, -1.0, 0.0 ),
List( 5.0, -1.0, 2.0 ))
A * B must_== C
While we are at it, we might as well add a few more convenience methods to RichMatrix:
case class RichMatrix(m:Matrix){
  ...
  def apply(i:Int,j:Int) = m(i)(j)
  def rowCount = m.length
  def colCount = m.head.length
  def toStr = "\n"+m.map{
    _.map{"\t" + _}.reduceLeft(_ + _)+"\n"
  }.reduceLeft(_ + _)
}
The apply method is special: it is invoked when parentheses are applied to an object. For example, it lets you write c(1,3) to get the number in the second row, fourth column of the matrix (indices are 0-based). Note that this method should not be used in performance-sensitive inner loops, because it is not efficient: the matrix uses linked lists as storage, so accessing element (i, j) takes O(i + j) time rather than the O(1) it would take if the matrix used arrays.
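If a computation needs many indexed accesses, one option is to copy the lists into arrays once, paying O(rows × cols) up front, and then index in constant time. The helpers toArrays and trace below are hypothetical, not part of the library; trace (the sum of the diagonal) simply stands in for any index-heavy loop.

```scala
object FastAccessDemo {
  type Matrix = List[List[Double]]

  // Hypothetical helper: a single O(rows * cols) copy into nested arrays.
  def toArrays(m: Matrix): Array[Array[Double]] = m.map(_.toArray).toArray

  // Sum of the diagonal. Each a(i)(i) lookup is O(1) on arrays,
  // whereas m(i)(i) on nested lists takes time proportional to i.
  def trace(m: Matrix): Double = {
    val a = toArrays(m)
    require(a.forall(_.length == a.length), "trace needs a square matrix")
    (0 until a.length).map(i => a(i)(i)).sum
  }
}
```

Arrays are mutable, so this gives up some functional purity for speed; keeping the array local to the function, as here, means callers still see a function from an immutable matrix to a number.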
The rowCount and colCount methods are, I hope, self-explanatory.
The toStr method returns a string representation of the matrix in a tabular format that is much easier to read than the default toString of List[List[Double]]. Note how this is done with two nested map/reduce pairs.
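For comparison, the standard library's mkString(start, sep, end) can produce the same tabular string, and unlike reduceLeft it does not throw on an empty matrix or an empty row. A sketch with the hypothetical names ToStrDemo and toStrMk:

```scala
object ToStrDemo {
  type Matrix = List[List[Double]]

  // The toStr logic from the text: two nested map / reduceLeft pairs.
  def toStr(m: Matrix): String =
    "\n" + m.map { row => row.map("\t" + _).reduceLeft(_ + _) + "\n" }.reduceLeft(_ + _)

  // Same output for non-empty input, built with mkString(start, sep, end).
  def toStrMk(m: Matrix): String =
    m.map(_.mkString("\t", "\t", "\n")).mkString("\n", "", "")
}
```

For List(List(1.0, 2.0), List(3.0, 4.0)) both functions return "\n\t1.0\t2.0\n\t3.0\t4.0\n".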
Now, creating a matrix using the List( List(…), List(…), … ) syntax is fine for small matrices, but what if we want to generate large matrices? One way to do this is:
object Matrix{
  def apply(rowCount:Int, colCount:Int)(f:(Int,Int) => Double) =
    (
      for(i <- 1 to rowCount) yield
        (for(j <- 1 to colCount) yield f(i,j)).toList
    ).toList
}
This allows you to do things like:
// 100x200 matrix containing random elements
val A = Matrix(100,200) { (i:Int,j:Int) =>
  scala.util.Random.nextDouble()
}
// 5x5 identity matrix
val I = Matrix(5,5) { (i:Int,j:Int) =>
  if(i==j) 1.0 else 0.0
}
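Note that Matrix.apply passes f indices that start at 1, whereas the apply(i,j) accessor on RichMatrix is 0-based. Scala's standard List.tabulate builds the same kind of nested list but passes 0-based indices. The sketch below compares the two; the names FactoryDemo, build, and tabulated are hypothetical.

```scala
object FactoryDemo {
  type Row = List[Double]
  type Matrix = List[Row]

  // Same construction as Matrix.apply in the text: i and j start at 1.
  def build(rowCount: Int, colCount: Int)(f: (Int, Int) => Double): Matrix =
    (for (i <- 1 to rowCount) yield
      (for (j <- 1 to colCount) yield f(i, j)).toList).toList

  // Standard-library alternative: List.tabulate passes 0-based i and j.
  def tabulated(rowCount: Int, colCount: Int)(f: (Int, Int) => Double): Matrix =
    List.tabulate(rowCount, colCount)(f)
}
```

For index-independent generators such as the identity matrix the two agree; for anything that uses i or j numerically, the 1-based and 0-based conventions give different matrices.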
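Note that the library code above is purely functional: all the objects are immutable and each function is implemented as a single expression. (The random-matrix example is the exception, since each call to the random-number generator has a side effect.)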
That's it for now. There is obviously a lot more we would have to add to make this a useful, production-quality library.
(If you want to try this yourself, you can download the complete source code and the BDD-style regression test that demonstrates its use.)
[Originally published on my old blog on 10 Jan 2010.]