Skip to content

Instantly share code, notes, and snippets.

@astojilj
Created January 28, 2019 20:03
Show Gist options
  • Save astojilj/971be3037c115a9fa5fbb5d204d4da9e to your computer and use it in GitHub Desktop.
diff --git a/integration_tests/benchmarks/benchmark_test.ts b/integration_tests/benchmarks/benchmark_test.ts
index 8148ea22..d90a4ca3 100644
--- a/integration_tests/benchmarks/benchmark_test.ts
+++ b/integration_tests/benchmarks/benchmark_test.ts
@@ -18,6 +18,7 @@
import {ConvGPUBenchmark, RegularConvParams} from './conv_benchmarks';
import {MatmulGPUBenchmark} from './matmul_benchmarks';
import {MobileNetV1GPUBenchmark} from './mobilenet_benchmarks';
+import {TransposeGPUBenchmark} from './transpose_benchmarks';
import * as test_util from './test_util';
const BENCHMARK_RUNS = 100;
@@ -39,6 +40,18 @@ describe('benchmarks', () => {
done();
});
+  it('transpose', async () => {
+    // Sizes swept for the last axis of the rank-6 input; the remaining
+    // dimensions are fixed inside TransposeGPUBenchmark.run.
+    const colSizes = [1, 96, 960];
+
+    const benchmark = new TransposeGPUBenchmark();
+
+    // The spec returns the awaited promise instead of also taking a `done`
+    // callback. That way a rejected benchmark is reported by Jasmine as a spec
+    // failure with the real error, rather than as a `done` timeout.
+    await test_util.benchmarkAndLog(
+        'transpose', cols => benchmark.run(cols), colSizes,
+        cols => `Cols=${cols}`, BENCHMARK_RUNS);
+  });
+
it('conv2d', async done => {
const sizes = [10, 100, 227];
const convParams: RegularConvParams =
diff --git a/integration_tests/benchmarks/transpose_benchmarks.ts b/integration_tests/benchmarks/transpose_benchmarks.ts
new file mode 100644
index 00000000..3d9aef51
--- /dev/null
+++ b/integration_tests/benchmarks/transpose_benchmarks.ts
@@ -0,0 +1,38 @@
+/**
+ * @license
+ * Copyright 2017 Google Inc. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+import * as tf from '@tensorflow/tfjs-core';
+import {BenchmarkTest} from './types';
+import * as util from './util';
+
+
+/**
+ * Benchmarks `tf.transpose` of a rank-6 tensor on the WebGL backend with
+ * texture packing (`WEBGL_PACK`) disabled for the duration of the run.
+ */
+export class TransposeGPUBenchmark implements BenchmarkTest {
+  /**
+   * Times transposing a tensor of shape [1, 19, 4, 19, 4, cols] with the
+   * permutation [2, 4, 0, 1, 3, 5].
+   *
+   * @param cols Size of the last axis of the input tensor.
+   * @returns The value reported by `util.warmupAndBenchmarkGPU`.
+   */
+  async run(cols: number): Promise<number> {
+    tf.setBackend('webgl');
+    // Remember the global packing flag so it can be restored afterwards.
+    // Otherwise every benchmark that runs after this one would also execute
+    // with packing disabled.
+    const previousPack = tf.ENV.get('WEBGL_PACK');
+    tf.ENV.set('WEBGL_PACK', false);
+
+    const t: tf.Tensor<tf.Rank.R6> = tf.randomNormal([1, 19, 4, 19, 4, cols]);
+    // NOTE(review): the transpose input is derived via add() rather than using
+    // `t` directly; presumably so the input is a GPU-computed texture. Confirm
+    // the intent.
+    const a = t.add(1.0);
+    // `t` is no longer needed once `a` exists; free it rather than leak it.
+    t.dispose();
+
+    try {
+      const benchmark = () => tf.transpose(a, [2, 4, 0, 1, 3, 5]) as tf.Tensor;
+      return await util.warmupAndBenchmarkGPU(benchmark);
+    } finally {
+      // Release the input and restore the flag even if benchmarking throws.
+      a.dispose();
+      tf.ENV.set('WEBGL_PACK', previousPack);
+    }
+  }
+}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment