Skip to content

Instantly share code, notes, and snippets.

@cbalint13
Created July 24, 2019 20:23
Show Gist options
  • Save cbalint13/d2506c572d76a0bd7dc7c378fcb0c881 to your computer and use it in GitHub Desktop.
Save cbalint13/d2506c572d76a0bd7dc7c378fcb0c881 to your computer and use it in GitHub Desktop.
TVM_Winograd_OLD_vs_NEW_2_4_6
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
""" Utility functions for implementing Winograd convolutions
[*] Fast Algorithms for Convolutional Neural Networks
Andrew Lavin, Scott Gray
https://arxiv.org/abs/1509.09308
https://github.com/andravin/wincnn
"""
from operator import mul
from functools import reduce
import numpy as np
from ..util import const_matrix
# pylint: disable=invalid-name
def _cook_toom_convolution(a, n, r):
"""Compute Cook-Toom convolution A,B,G matrices"""
def _F_m(a, n):
f = lambda j, i: reduce(mul, ((a[i]-a[k] if k != i else 1) for k in range(0, n-1)), 1)
F = np.fromfunction(np.vectorize(f), (1, n-1), dtype=int)
F = np.diagflat(F)
F = np.append(F, np.zeros((n-1, 1), dtype=int), axis=1)
f = lambda i, j: (1 if j == (n-1) else 0)
z = np.fromfunction(np.vectorize(f), (1, n), dtype=int)
return np.append(F, z, axis=0)
def _A_m(a, m, n):
f = lambda i, j: a[i]**j
A = np.fromfunction(np.vectorize(f), (m-1, n), dtype=int)
f = lambda i, j: (1 if j == (n-1) else 0)
z = np.fromfunction(np.vectorize(f), (1, n), dtype=int)
return np.append(A, z, axis=0)
def _B_m(a, n):
f = lambda j, i: reduce(mul, ((a[i]-a[k] if k != i else 1) for k in range(0, n-1)), 1)
Ff = np.fromfunction(np.vectorize(f), (1, n-1), dtype=int)
f = lambda i, nth: (reduce(mul, [(np.poly1d([1, -a[k]]) if k != i else 1) \
for k in range(0, n-1)], 1)).coef[n-1-nth-1]/Ff[0, i]
F = np.fromfunction(np.vectorize(f), (n-1, n-1), dtype=int)
f = lambda i, j: -a[i]**(n-1)
t = np.fromfunction(np.vectorize(f), (n-1, 1), dtype=int)
T = np.append(np.eye(n-1), t, axis=1)
return np.append(F.T.dot(T), np.array([np.eye(n)[n-1]]), axis=0)
alpha = n + r - 1
f = _F_m(a, alpha)
if f[0, 0] < 0:
f[0, :] *= -1
A = _A_m(a, alpha, n)
G = _A_m(a, alpha, r).T
G = G.dot(np.linalg.inv(f)).T
B = _B_m(a, alpha)
B = B.dot(f.T)
return (A, B, G)
def _interpolation_points(degree):
"""Propose filter points"""
assert 2 < degree < 18
# Default interpolation lookup table
#
# [1] Error Analysis and Improving the Accuracy of Winograd Convolution for Deep Neural Networks
# Barbara Barabasz, Andrew Anderson, Kirk M. Soodhalter, David Gregg
# https://arxiv.org/abs/1803.10986
#
# pylint: disable=bad-whitespace,line-too-long
in_pts = [
# {invalid}
[],
#01 {E=4.63E-08 on conv2d [1]}
[],
#02 {E=7.65E-08 on F( 2,3) [1]}
[0, -1, 1],
#03 {E=2.35E-07 on F( 3,3) [1]}
[0, -1, 1, 1/2],
#04 {E=3.29E-07 on F( 4,3) [1]}
[0, -1, 1, 1/2, -2],
#05 {E=6.81E-07 on F( 5,3) [1]}
[0, -1, 1, 1/2, -2, -1/2],
#06 {E=8.79E-07 on F( 6,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2],
#07 {E=3.71E-06 on F( 7,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4],
#08 {E=7.35E-06 on F( 8,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4],
#09 {E=2.20E-05 on F( 9,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 3/4, -4/3],
#10 {E=3.22E-05 on F(10,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 3/4, -4/3],
#11 {E=1.09E-04 on F(11,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 3/4, -4/3, 1/4],
#12 {E=1.99E-04 on F(12,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, -4],
#13 {E=5.54E-04 on F(13,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, 3/4, -4/3],
#14 {E=8.80E-04 on F(14,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, -4, 3/4, -4/3],
#15 {E=1.07E-02 on F(15,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, -4, 2/3, -3/2, 3/2],
#16 {E=1.93E-02 on F(16,3) [1]}
[0, -1, 1, 1/2, -1/2, 2, -2, -1/4, 4, 1/4, -3/4, 4/3, -4, 2/3, -3/2, -3/2, 3/2]
] # pylint: enable=bad-whitespace,line-too-long
return np.array(in_pts[degree-1], dtype=np.float64)
def new_winograd_transform_matrices(tile_size, kernel_size, out_dtype):
    """Generate the A, B and G Winograd transform matrices for a tile/kernel pair.

    Parameters
    ----------
    tile_size : int
        Output tile size; must satisfy ``1 < tile_size < 9``.
    kernel_size : int
        Kernel size; must satisfy ``2 < kernel_size < 8``.
    out_dtype : str or numpy dtype
        Type the matrix data is cast to before being wrapped.

    Returns
    -------
    tuple
        ``(A, B, G)``, each produced by ``const_matrix``.

    Raises
    ------
    ValueError
        If ``tile_size`` or ``kernel_size`` is outside the supported range.
    """
    tile_supported = 1 < tile_size < 9
    if not tile_supported:
        raise ValueError("Unsupported tile size for Winograd: {}".format(tile_size))

    kernel_supported = 2 < kernel_size < 8
    if not kernel_supported:
        raise ValueError("Unsupported kernel size for Winograd: {}".format(kernel_size))

    points = _interpolation_points(tile_size + kernel_size - 2)
    raw_matrices = _cook_toom_convolution(points, tile_size, kernel_size)

    # The generator yields the matrices in (A, B, G) order.
    return tuple(
        const_matrix(data.astype(out_dtype), label)
        for data, label in zip(raw_matrices, ("A", "B", "G"))
    )
def old_winograd_transform_matrices(tile_size, kernel_size, out_dtype):
    """Return the hard-coded A, B and G Winograd matrices for a tile size.

    Parameters
    ----------
    tile_size : int
        Output tile size; only 2, 4 and 6 have tables.
    kernel_size : int
        Accepted for signature parity with ``new_winograd_transform_matrices``
        but never consulted; every G table has three columns.
    out_dtype : str or numpy dtype
        Element type of the generated matrices.

    Returns
    -------
    tuple
        ``(A, B, G)``, each produced by ``const_matrix``.

    Raises
    ------
    ValueError
        If ``tile_size`` has no table (previously this surfaced as an
        ``UnboundLocalError`` at the return statement).
    """
    if tile_size == 2:
        g_data = np.array(
            [
                [1, 0, 0],
                [1.0 / 2, 1.0 / 2, 1.0 / 2],
                [1.0 / 2, -1.0 / 2, 1.0 / 2],
                [0, 0, 1],
            ],
            dtype=out_dtype,
        )

        b_data = np.array(
            [
                [1, 0, 0, 0],
                [0, 1, -1, 1],
                [-1, 1, 1, 0],
                [0, 0, 0, -1],
            ],
            dtype=out_dtype,
        )

        a_data = np.array(
            [[1, 0], [1, 1], [1, -1], [0, -1]],
            dtype=out_dtype
        )
    elif tile_size == 4:
        g_data = np.array(
            [
                [1 / 4.0, 0, 0],
                [-1 / 6.0, -1 / 6.0, -1 / 6.0],
                [-1 / 6.0, 1 / 6.0, -1 / 6.0],
                [1 / 24.0, 1 / 12.0, 1 / 6.0],
                [1 / 24.0, -1 / 12.0, 1 / 6.0],
                [0, 0, 1],
            ],
            dtype=out_dtype,
        )

        b_data = np.array(
            [
                [4, 0, 0, 0, 0, 0],
                [0, -4, 4, -2, 2, 4],
                [-5, -4, -4, -1, -1, 0],
                [0, 1, -1, 2, -2, -5],
                [1, 1, 1, 1, 1, 0],
                [0, 0, 0, 0, 0, 1],
            ],
            dtype=out_dtype,
        )

        a_data = np.array(
            [
                [1, 0, 0, 0],
                [1, 1, 1, 1],
                [1, -1, 1, -1],
                [1, 2, 4, 8],
                [1, -2, 4, -8],
                [0, 0, 0, 1],
            ],
            dtype=out_dtype,
        )
    elif tile_size == 6:
        g_data = np.array(
            [
                [1, 0, 0],
                [-2 / 9, -2 / 9, -2 / 9],
                [-2 / 9, 2 / 9, -2 / 9],
                [1 / 90, 1 / 45, 2 / 45],
                [1 / 90, -1 / 45, 2 / 45],
                [1 / 45, 1 / 90, 1 / 180],
                [1 / 45, -1 / 90, 1 / 180],
                [0, 0, 1],
            ],
            dtype=out_dtype,
        )

        # The tile-6 B and A tables are written transposed and flipped back here.
        b_data = np.array(
            [
                [1, 0, -21 / 4, 0, 21 / 4, 0, -1, 0],
                [0, 1, 1, -17 / 4, -17 / 4, 1, 1, 0],
                [0, -1, 1, 17 / 4, -17 / 4, -1, 1, 0],
                [0, 1 / 2, 1 / 4, -5 / 2, -5 / 4, 2, 1, 0],
                [0, -1 / 2, 1 / 4, 5 / 2, -5 / 4, -2, 1, 0],
                [0, 2, 4, -5 / 2, -5, 1 / 2, 1, 0],
                [0, -2, 4, 5 / 2, -5, -1 / 2, 1, 0],
                [0, -1, 0, 21 / 4, 0, -21 / 4, 0, 1],
            ],
            dtype=out_dtype,
        ).T

        a_data = np.array(
            [
                [1, 1, 1, 1, 1, 32, 32, 0],
                [0, 1, -1, 2, -2, 16, -16, 0],
                [0, 1, 1, 4, 4, 8, 8, 0],
                [0, 1, -1, 8, -8, 4, -4, 0],
                [0, 1, 1, 16, 16, 2, 2, 0],
                [0, 1, -1, 32, -32, 1, -1, 1],
            ],
            dtype=out_dtype,
        ).T
    else:
        # Without this branch the return below fails with UnboundLocalError.
        raise ValueError("Unsupported tile size for Winograd: {}".format(tile_size))

    return (
        const_matrix(a_data.astype(out_dtype), "A"),
        const_matrix(b_data.astype(out_dtype), "B"),
        const_matrix(g_data.astype(out_dtype), "G"),
    )
def get_state():
    """Read the Winograd selection flag from ``/tmp/winograd.state``.

    Returns
    -------
    int or None
        ``None`` when the state file does not exist; otherwise the first
        character of the file converted with ``int``.

    Raises
    ------
    ValueError
        If the file is empty or its first character is not a digit.
    """
    import os

    state_path = "/tmp/winograd.state"
    if not os.path.exists(state_path):
        return None

    # ``with`` guarantees the handle is closed even when ``int`` raises.
    with open(state_path, "r") as state_file:
        return int(state_file.read(1))
def winograd_transform_matrices(tile_size, kernel_size, out_dtype):
    """Pick the NEW or OLD matrix generator based on ``get_state()``.

    A truthy state selects ``new_winograd_transform_matrices``; a falsy state
    (including ``None``) selects ``old_winograd_transform_matrices``. The
    choice is printed before the selected generator is called with the
    unchanged arguments, and its result is returned.
    """
    use_new = get_state()
    if not use_new:
        print("OLD winograd")
        return old_winograd_transform_matrices(tile_size, kernel_size, out_dtype)

    print("NEW winograd")
    return new_winograd_transform_matrices(tile_size, kernel_size, out_dtype)
test_topi_winograd_accuracy.test_conv2d_nchw ...
Workload: (1, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 2.4511020708111478e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 5.9650443671943314e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 2.0420007757612526e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 2.216157317856342e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 2.452783259657915e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 5.9633445700087504e-05
Workload: (1, 128, 28, 128, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 4.1373971768319884e-05
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 4.339034444266431e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 4.942436637602618e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 5.21192471347139e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 5.564884791569795e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 0.00012102896301259292
Workload: (1, 256, 14, 256, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00011305222658656708
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00011576440550874003
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00013714976314316624
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00014060975788450814
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00011305222658656708
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00011576440550874003
Workload: (1, 512, 7, 512, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.0002809650421867533
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.00028360036244900116
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.0002860461807265869
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.0003887513282542221
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.0002809650421867533
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.00028360036244900116
Workload: (2, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.447055001551466e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 6.0288877128664076e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.0493271110943525e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.2388423065866966e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.4475851214172145e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 6.027167807264623e-05
Workload: (2, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.4624417662568877e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.974867470739931e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.0865389932417506e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.2606775374583014e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.4628841312083634e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.9756536660941226e-05
Workload: (2, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.475992100214023e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.9628189104542476e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.0617673387602555e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.2405094435067317e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.475242074032856e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.96491143918555e-05
Workload: (2, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.453766274316886e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 6.031242275714051e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.074526701181002e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.2646847359168817e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.454845169151851e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 6.031650512950276e-05
Workload: (1, 1, 1, 1, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 1.8215774666430207e-08
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 3.3146134725825505e-09
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 9.420877686849849e-07
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 5.6290031302808075e-08
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 1.8215774666430207e-08
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 3.3146134725825505e-09
Workload: (3, 3, 3, 3, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 2.990957390921382e-07
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 3.4686324717008247e-07
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 5.934091036796377e-07
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 4.886136621315327e-07
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 2.990957390921382e-07
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 3.4686324717008247e-07
Workload: (2, 13, 71, 59, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 1.588651238681398e-06
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 1.869170484743837e-06
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
NEW winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 3.287651425819812e-06
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 4
OLD winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 3.78126471758971e-06
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 1.588651238681398e-06
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 1.869170484743837e-06
ok
test_topi_winograd_accuracy.test_conv2d_nchw ...
Workload: (1, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 2.4282570039177755e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 5.9233052460490975e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 2.874101823673433e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 2.9824263463133686e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 2.4278512789239528e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (1, 64, 56, 56)
ABSdiff: 5.923121114905956e-05
Workload: (1, 128, 28, 128, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 4.086823490468969e-05
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 4.295053575566694e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 6.377403168790755e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 6.615732833516908e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 5.610743231937542e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (1, 128, 28, 28)
ABSdiff: 0.00011906119297902599
Workload: (1, 256, 14, 256, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00011452753281651969
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00011711756063344491
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00018067234542971906
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00018482461536374872
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00011452753281651969
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (1, 256, 14, 14)
ABSdiff: 0.00011711756063344491
Workload: (1, 512, 7, 512, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.00028562422717012646
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.00028898457758124026
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.00035035983085944435
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.0003539133945559881
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.00028562422717012646
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (1, 512, 7, 7)
ABSdiff: 0.00028898457758124026
Workload: (2, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.4295486720469287e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 6.016031071156796e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.8758394443551472e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.9855934781453023e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.429293860432221e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 6.017109940054672e-05
Workload: (2, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.4376490881431137e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.987216617158271e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.89711684989349e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 3.0101267847319094e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.437644818154733e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.983677381832141e-05
Workload: (2, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.467622402628273e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.976416940094564e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.879268482485703e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.9990959888087458e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.467363882866482e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.976702062699013e-05
Workload: (2, 64, 56, 64, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.4978470027669167e-05
Running on target: cuda
[DBG] [cuda] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.9295613778137555e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.870490887443283e-05
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.985468871842946e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
NEW winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 2.499327669831756e-05
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 4
OLD winograd
Output shape: (2, 64, 56, 56)
ABSdiff: 5.929320151113948e-05
Workload: (1, 1, 1, 1, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 1.0126328220394498e-08
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 6.973097299578512e-08
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 1.2933561777117575e-07
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 1.2933561777117575e-07
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 1.0126328220394498e-08
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (1, 1, 1, 1)
ABSdiff: 6.973097299578512e-08
Workload: (3, 3, 3, 3, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 2.538112472183633e-07
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 3.021806142704887e-07
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 1.1635106824505584e-06
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 1.3849481189530543e-06
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 2.538112472183633e-07
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (3, 3, 3, 3)
ABSdiff: 3.021806142704887e-07
Workload: (2, 13, 71, 59, 3, 1, 1, 1)
Running on target: cuda
[DBG] [cuda] Tile size = 2
NEW winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 1.5785498161180955e-06
Running on target: cuda
[DBG] [cuda] Tile size = 2
OLD winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 1.8478274250846262e-06
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
NEW winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 6.105624850066014e-06
Running on target: llvm -device=arm_cpu
[DBG] [arm_cpu] Tile size = 6
OLD winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 6.367171834081284e-06
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
NEW winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 1.5785498161180955e-06
Running on target: opencl -device=mali
[DBG] [mali] Tile size = 2
OLD winograd
Output shape: (2, 59, 71, 71)
ABSdiff: 1.8478274250846262e-06
ok
----------------------------------------------------------------------
Ran 1 test in 241.348s
OK
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example code to do convolution."""
import numpy as np
import tvm
from tvm import autotvm
from tvm.autotvm.task.space import FallbackConfigEntity
import topi
import topi.testing
from tvm.contrib.pickle_memoize import memoize
from topi.util import get_const_tuple
def verify_conv2d_nchw(batch, in_channel, in_size, num_filter, kernel, stride, padding,
                       dilation=1, add_bias=False, add_relu=False,
                       devices=('cuda', 'llvm -device=arm_cpu', 'opencl -device=mali')):
    """Compare a TOPI NCHW conv2d against the NumPy reference on each device.

    Every device is checked twice: first with ``/tmp/winograd.state`` set to
    1 and then with it set to 0. For each run the output shape and the mean
    absolute difference to the reference are printed, and the result is
    asserted to match the reference with ``rtol=1e-5``.

    Parameters
    ----------
    batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation : int
        Workload description; the input is square (``in_size`` x ``in_size``).
    add_bias : bool
        Add a ``(num_filter, 1, 1)`` bias after the convolution.
    add_relu : bool
        Apply ReLU after the convolution (and bias, if any).
    devices : iterable of str
        Target strings to run on; unavailable ones are skipped.
    """
    print("\nWorkload: (%d, %d, %d, %d, %d, %d, %d, %d)" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))

    in_height = in_width = in_size

    A = tvm.placeholder((batch, in_channel, in_height, in_width), name='A')
    W = tvm.placeholder((num_filter, in_channel, kernel, kernel), name='W')
    bias = tvm.placeholder((num_filter, 1, 1), name='bias')

    a_shape = get_const_tuple(A.shape)
    w_shape = get_const_tuple(W.shape)
    bias_shape = get_const_tuple(bias.shape)
    dtype = A.dtype

    def set_state(state):
        """Write ``state`` as a single integer digit to the state file."""
        with open("/tmp/winograd.state", "w") as state_file:
            state_file.write("%i" % state)

    #@memoize("topi.tests.test_topi_conv2d_nchw.verify_conv2d_nchw")
    def get_ref_data():
        """Draw random inputs and compute the NumPy reference output."""
        a_np = np.random.uniform(size=a_shape).astype(dtype)
        w_np = np.random.uniform(size=w_shape).astype(dtype)
        b_np = np.random.uniform(size=bias_shape).astype(dtype)
        dw_np = topi.testing.dilate_python(w_np, (1, 1, dilation, dilation))
        c_np = topi.testing.conv2d_nchw_python(a_np, dw_np, stride, padding)
        if add_bias:
            b_np = np.random.uniform(size=bias_shape).astype(dtype)
            c_np += b_np
        if add_relu:
            c_np = np.maximum(c_np, 0)
        return a_np, w_np, b_np, c_np

    a_np, w_np, b_np, c_np = get_ref_data()

    def check_device(device):
        """Build and run the convolution on ``device`` and compare to the reference."""
        ctx = tvm.context(device, 0)
        if not ctx.exist:
            print("Skip because %s is not enabled" % device)
            return
        print("Running on target: %s" % device)
        with tvm.target.create(device):
            C = topi.nn.conv2d(A, W, stride, padding, dilation, layout='NCHW', out_dtype=dtype)
            if add_bias:
                C = topi.add(C, bias)
            if add_relu:
                C = topi.nn.relu(C)
            s = topi.generic.schedule_conv2d_nchw([C])

        a = tvm.nd.array(a_np, ctx)
        w = tvm.nd.array(w_np, ctx)
        b = tvm.nd.array(b_np, ctx)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), ctx)
        if add_bias:
            func = tvm.build(s, [A, W, bias, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
            func(a, w, b, c)
        else:
            func = tvm.build(s, [A, W, C], device, name="relu_%d_%d_%d_%d_%d_%d_%d_%d" % (batch, in_channel, in_size, num_filter, kernel, stride, padding, dilation))
            func(a, w, c)

        out_np = c.asnumpy()
        # Element-wise absolute error; differencing the absolute values
        # instead would hide sign errors.
        abs_diff = np.sum(np.abs(out_np - c_np))
        print("Output shape: ", out_np.shape)
        print("ABSdiff: ", abs_diff / np.size(out_np))
        tvm.testing.assert_allclose(out_np, c_np, rtol=1e-5)

    for device in devices:
        set_state(True)
        check_device(device)
        set_state(False)
        check_device(device)
class WinogradFallback(autotvm.FallbackContext):
    """Fallback context whose generated configs carry the 'winograd' template key."""

    def _query_inside(self, target, workload):
        """Return the config cached for ``(target, workload)``, creating it on first use."""
        lookup_key = (target, workload)
        if key_missing := lookup_key not in self.memory:
            fallback_cfg = FallbackConfigEntity()
            fallback_cfg.template_key = 'winograd'
            self.memory[lookup_key] = fallback_cfg
            return fallback_cfg
        return self.memory[lookup_key]
def test_conv2d_nchw():
    """Run ``verify_conv2d_nchw`` over a fixed list of workloads under ``WinogradFallback``."""
    autotvm.DispatchContext.current.silent = False

    # (positional arguments, keyword arguments) for each verify_conv2d_nchw call,
    # in execution order.
    workloads = [
        # resnet 18 workloads
        ((1, 64, 56, 64, 3, 1, 1), {}),
        ((1, 128, 28, 128, 3, 1, 1), {}),
        ((1, 256, 14, 256, 3, 1, 1), {}),
        ((1, 512, 7, 512, 3, 1, 1), {}),
        # batch size = 2
        ((2, 64, 56, 64, 3, 1, 1), {}),
        # relu, bias
        ((2, 64, 56, 64, 3, 1, 1), {"add_bias": True}),
        ((2, 64, 56, 64, 3, 1, 1), {"add_relu": True}),
        ((2, 64, 56, 64, 3, 1, 1), {"add_relu": True, "add_bias": True}),
        # weird workloads
        ((1, 1, 1, 1, 3, 1, 1), {}),
        ((3, 3, 3, 3, 3, 1, 1), {}),
        ((2, 13, 71, 59, 3, 1, 1), {}),
    ]

    with WinogradFallback():
        for positional, keyword in workloads:
            verify_conv2d_nchw(*positional, **keyword)


if __name__ == "__main__":
    test_conv2d_nchw()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment