@astojilj
Created February 3, 2019 14:31
diff --git a/integration_tests/benchmarks/matmul_benchmarks.ts b/integration_tests/benchmarks/matmul_benchmarks.ts
index 8f67ed9..5e242fb 100644
--- a/integration_tests/benchmarks/matmul_benchmarks.ts
+++ b/integration_tests/benchmarks/matmul_benchmarks.ts
@@ -43,8 +43,8 @@ export class MatmulGPUBenchmark implements BenchmarkTest {
async run(size: number): Promise<number> {
tf.setBackend('webgl');
- const a: tf.Tensor2D = tf.randomNormal([size, size]);
- const b: tf.Tensor2D = tf.randomNormal([size, size]);
+ const a: tf.Tensor3D = tf.randomNormal([3, size, size]);
+ const b: tf.Tensor3D = tf.randomNormal([3, size, size]);
const benchmark = () => tf.matMul(a, b);
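This diff changes the WebGL matmul benchmark's operands. It replaces two 2D `[size, size]` tensors with two rank-3 `[3, size, size]` tensors, so `tf.matMul` now computes a batch of three independent matrix products. The sketch below is a plain TypeScript reference for those batched semantics, plus a rough operation count. `batchMatMul` and `matMulFlops` are hypothetical helpers. They are not part of tfjs or this benchmark suite, and they say nothing about how the WebGL backend actually schedules the work.

```typescript
// Hypothetical reference for batched matrix multiplication in plain
// TypeScript (no tfjs). Inputs have shape [batch, m, k] and [batch, k, n];
// each batch entry is multiplied independently, giving [batch, m, n].
type Batch = number[][][];

function batchMatMul(a: Batch, b: Batch): Batch {
  if (a.length !== b.length) {
    throw new Error(`batch size mismatch: ${a.length} vs ${b.length}`);
  }
  return a.map((aMat, batchIdx) => {
    const bMat = b[batchIdx];
    const m = aMat.length;
    const k = aMat[0].length;
    const n = bMat[0].length;
    if (bMat.length !== k) {
      throw new Error(`inner dimension mismatch in batch ${batchIdx}`);
    }
    const out: number[][] = [];
    for (let i = 0; i < m; i++) {
      const row = new Array<number>(n).fill(0);
      // Accumulate row i of aMat times bMat (i-p-j loop order).
      for (let p = 0; p < k; p++) {
        const aip = aMat[i][p];
        for (let j = 0; j < n; j++) {
          row[j] += aip * bMat[p][j];
        }
      }
      out.push(row);
    }
    return out;
  });
}

// Approximate floating-point operation count: each of the `batch` products
// of two size x size matrices does size^3 multiply-add steps, about
// 2 * size^3 operations.
function matMulFlops(batch: number, size: number): number {
  return batch * 2 * size ** 3;
}
```

Because the batch dimension is 3, each timed call does roughly three times the arithmetic of the previous 2D version at the same `size`. Results from before and after this change are therefore not directly comparable unless they are normalized, for example by an operation count such as `matMulFlops(3, size)`.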