Skip to content

Instantly share code, notes, and snippets.

@lowjoel
Created March 25, 2015 04:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lowjoel/adc6a0a554feb81df017 to your computer and use it in GitHub Desktop.
Save lowjoel/adc6a0a554feb81df017 to your computer and use it in GitHub Desktop.
Matrix multiplication in X10
import x10.io.Console;
/**
 * Divide-and-conquer matrix multiplication, using finish/async to compute
 * the quadrant products in parallel.
 */
public class MatrixMultiply {
    /**
     * A two-dimensional, mutable matrix indexed as (row y, column x).
     */
    private interface Matrix[T] {
        /** The fill value used when allocating matrices of the same kind. */
        public def initial(): T;
        /** Number of columns. */
        public def width(): Long;
        /** Valid column indices: 0..(width() - 1). */
        public def widthRange(): LongRange;
        /** Number of rows. */
        public def height(): Long;
        /** Valid row indices: 0..(height() - 1). */
        public def heightRange(): LongRange;
        /**
         * Accesses a matrix element.
         */
        public operator this(y: Long, x: Long): T;
        /**
         * Sets a matrix element.
         */
        public operator this(y: Long, x: Long) = (newVal: T): void;
        /**
         * Obtain a submatrix of this matrix. The ranges are given in this
         * matrix's coordinates. (The implementations in this file return a
         * view that reads and writes the original storage.)
         */
        public operator this(yRange: LongRange, xRange: LongRange): Matrix[T];
    }

    /**
     * A dense matrix stored row-major in a single Rail:
     * element (y, x) lives at index y * width + x.
     */
    private static class RealMatrix[T] implements Matrix[T] {
        /**
         * A rectangular window onto the enclosing RealMatrix. Element access
         * is forwarded to the parent with the window's origin added, so reads
         * and writes go straight to the parent's storage.
         */
        private class MatrixView implements Matrix[T] {
            // Row and column spans of this window, in the PARENT's coordinates.
            // (Named differently from heightRange()/widthRange(), which return
            // this view's LOCAL index ranges.)
            private val rows: LongRange;
            private val cols: LongRange;

            public def initial(): T = RealMatrix.this.initial();
            public def height(): Long = rows.max - rows.min + 1;
            public def heightRange(): LongRange = 0..(height() - 1);
            public def width(): Long = cols.max - cols.min + 1;
            public def widthRange(): LongRange = 0..(width() - 1);

            public def this(heightRange: LongRange, widthRange: LongRange) {
                this.rows = heightRange;
                this.cols = widthRange;
            }

            public operator this(y: Long, x: Long): T {
                return RealMatrix.this(y + rows.min, x + cols.min);
            }

            public operator this(y: Long, x: Long) = (newVal: T): void {
                RealMatrix.this(y + rows.min, x + cols.min) = newVal;
            }

            /**
             * Sub-view of this view. yRange/xRange are relative to this view,
             * so BOTH ends of each range are shifted by this view's origin.
             * The result is a new window directly onto the parent matrix.
             */
            public operator this(yRange: LongRange, xRange: LongRange): Matrix[T] {
                return new RealMatrix.MatrixView((rows.min + yRange.min)..(rows.min + yRange.max),
                        (cols.min + xRange.min)..(cols.min + xRange.max));
            }
        }

        private val initial: T;
        private val _height: Long;
        private val _width: Long;
        // Row-major element storage; see offsetOf.
        private val elements: Rail[T];

        public def initial(): T = initial;
        public def height(): Long = _height;
        public def heightRange(): LongRange = 0..(_height - 1);
        public def width(): Long = _width;
        public def widthRange(): LongRange = 0..(_width - 1);

        /**
         * Creates a height x width matrix with every element set to initial.
         */
        public def this(height: Long, width: Long, initial: T) {
            _width = width;
            _height = height;
            this.initial = initial;
            elements = new Rail[T](height * width, initial);
        }

        public operator this(y: Long, x: Long): T {
            return elements(offsetOf(y, x));
        }

        public operator this(y: Long, x: Long) = (newVal: T): void {
            elements(offsetOf(y, x)) = newVal;
        }

        public operator this(yRange: LongRange, xRange: LongRange): Matrix[T] {
            return new MatrixView(yRange, xRange);
        }

        /**
         * Computes the offset of the given element in the Rail.
         */
        private def offsetOf(y: Long, x: Long) = y * _width + x;

        /**
         * Renders the matrix one row per line, each element followed by a space.
         */
        public def toString(): String {
            var result: String = "";
            for (y in heightRange()) {
                for (x in widthRange()) {
                    result += this(y, x).toString() + " ";
                }
                result += "\n";
            }
            return result;
        }
    }

    /**
     * Multiplies 2x2 matrices: result = a * b.
     * Only indices 0 and 1 of each operand are read or written.
     */
    private static def multPrimitive[T](result: Matrix[T], a: Matrix[T], b: Matrix[T]) {T <: Arithmetic[T]} {
        result(0, 0) = a(0, 0) * b(0, 0) + a(0, 1) * b(1, 0);
        result(0, 1) = a(0, 0) * b(0, 1) + a(0, 1) * b(1, 1);
        result(1, 0) = a(1, 0) * b(0, 0) + a(1, 1) * b(1, 0);
        result(1, 1) = a(1, 0) * b(0, 1) + a(1, 1) * b(1, 1);
    }

    /**
     * Schoolbook multiplication result = a * b for arbitrary shapes.
     * Expects a.width() == b.height() and result shaped a.height() x b.width().
     */
    private static def multNaive[T](result: Matrix[T], a: Matrix[T], b: Matrix[T]) {T <: Arithmetic[T]} {
        val inner = a.width();
        // With an empty inner dimension there are no terms to sum, and T has
        // no generic zero, so result is left untouched in that case.
        if (inner > 0) {
            for (y in result.heightRange()) {
                for (x in result.widthRange()) {
                    var sum: T = a(y, 0) * b(0, x);
                    for (p in 1..(inner - 1)) {
                        sum = sum + a(y, p) * b(p, x);
                    }
                    result(y, x) = sum;
                }
            }
        }
    }

    /**
     * Returns quadrant (qy, qx) of mat, where qy and qx are 0 or 1.
     * Rows are split at height / 2 and columns at width / 2, so for odd
     * sizes the second half receives the extra row or column.
     */
    private static def quadrant[T](mat: Matrix[T], qy: Long, qx: Long): Matrix[T] {
        val midY = mat.height() / 2;
        val midX = mat.width() / 2;
        val ys = qy == 0 ? 0..(midY - 1) : midY..(mat.height() - 1);
        val xs = qx == 0 ? 0..(midX - 1) : midX..(mat.width() - 1);
        return mat(ys, xs);
    }

    /**
     * Computes result = a * b by splitting every operand into quadrants and
     * computing the eight quadrant products in parallel.
     *
     * Precondition (not checked): a.width() == b.height(),
     * result.height() == a.height() and result.width() == b.width().
     * Any sizes are accepted. Operands with a dimension of at most 2 are
     * handled directly, so the recursion always terminates.
     */
    private static def mult[T](result: Matrix[T], a: Matrix[T], b: Matrix[T]) {T <: Arithmetic[T]} {
        val m = a.height();
        val k = a.width();
        val n = b.width();
        if (m == 2 && k == 2 && n == 2) {
            multPrimitive[T](result, a, b);
        } else if (m <= 2 || k <= 2 || n <= 2) {
            // Splitting here would produce empty or non-shrinking quadrants
            // (e.g. a 1-wide matrix splits into 0- and 1-wide halves forever).
            multNaive[T](result, a, b);
        } else {
            val a11: Matrix[T] = quadrant[T](a, 0, 0);
            val a12: Matrix[T] = quadrant[T](a, 0, 1);
            val a21: Matrix[T] = quadrant[T](a, 1, 0);
            val a22: Matrix[T] = quadrant[T](a, 1, 1);
            val b11: Matrix[T] = quadrant[T](b, 0, 0);
            val b12: Matrix[T] = quadrant[T](b, 0, 1);
            val b21: Matrix[T] = quadrant[T](b, 1, 0);
            val b22: Matrix[T] = quadrant[T](b, 1, 1);
            val c11: Matrix[T] = quadrant[T](result, 0, 0);
            val c12: Matrix[T] = quadrant[T](result, 0, 1);
            val c21: Matrix[T] = quadrant[T](result, 1, 0);
            val c22: Matrix[T] = quadrant[T](result, 1, 1);

            // Each C quadrant is the sum of two products. Compute all eight
            // products into scratch matrices shaped like their C quadrant.
            val a11b11 = new RealMatrix[T](c11.height(), c11.width(), c11.initial());
            val a12b21 = new RealMatrix[T](c11.height(), c11.width(), c11.initial());
            val a11b12 = new RealMatrix[T](c12.height(), c12.width(), c12.initial());
            val a12b22 = new RealMatrix[T](c12.height(), c12.width(), c12.initial());
            val a21b11 = new RealMatrix[T](c21.height(), c21.width(), c21.initial());
            val a22b21 = new RealMatrix[T](c21.height(), c21.width(), c21.initial());
            val a21b12 = new RealMatrix[T](c22.height(), c22.width(), c22.initial());
            val a22b22 = new RealMatrix[T](c22.height(), c22.width(), c22.initial());

            finish {
                async mult[T](a11b11, a11, b11);
                async mult[T](a12b21, a12, b21);
                async mult[T](a11b12, a11, b12);
                async mult[T](a12b22, a12, b22);
                async mult[T](a21b11, a21, b11);
                async mult[T](a22b21, a22, b21);
                async mult[T](a21b12, a21, b12);
                mult[T](a22b22, a22, b22);
            }
            // All products are complete; the four sums write disjoint quadrants.
            finish {
                async add[T](c11, a11b11, a12b21);
                async add[T](c12, a11b12, a12b22);
                async add[T](c21, a21b11, a22b21);
                add[T](c22, a21b12, a22b22);
            }
        }
    }

    /**
     * Element-wise sum over result's shape: result(y, x) = a(y, x) + b(y, x).
     */
    private static def add[T](result: Matrix[T], a: Matrix[T], b: Matrix[T]) {T <: Arithmetic[T]} {
        // y indexes rows (height) and x indexes columns (width), matching
        // the (y, x) convention of the Matrix accessors.
        for (y in result.heightRange()) {
            for (x in result.widthRange()) {
                result(y, x) = a(y, x) + b(y, x);
            }
        }
    }

    /**
     * Demo: multiplies the 4x4 identity by the 4x4 zero matrix and prints
     * the product.
     */
    public static def main(args: Rail[String]) {
        val identity = new RealMatrix[Float](4, 4, 0);
        identity(0, 0) = 1;
        identity(1, 1) = 1;
        identity(2, 2) = 1;
        identity(3, 3) = 1;
        val zero = new RealMatrix[Float](4, 4, 0);
        val result = new RealMatrix[Float](4, 4, 0);
        mult(result, identity, zero);
        Console.OUT.print(result);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment