@AtomicVar
Created October 30, 2023 07:00
Batched matrix multiplication is equivalent to a special case of Conv1D
"""
表明:批次矩阵乘法 等价于一种特殊的 Conv1D
"""
import torch
N = 2
Cin = 3
Lin = 4
Cout = 5
# Use conv1d
x = torch.rand(size=(N, Cin, Lin))
w = torch.rand(size=(Cout, Cin, 1))
y0 = torch.nn.functional.conv1d(x, w)
print(y0)
# Use batched matrix multiplication
w_batch = w.reshape(1, Cout, Cin)
y1 = torch.matmul(w_batch, x)
print(y1)
# Print mean error
print(torch.mean(torch.abs(y0 - y1)))
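
Why the two results match: with kernel size 1 and no bias, conv1d computes

    y0[n, c_out, l] = sum over c_in of w[c_out, c_in, 0] * x[n, c_in, l]

which is exactly element [c_out, l] of the (Cout, Cin) weight matrix multiplied with the n-th (Cin, Lin) slice of x, i.e. the batched matrix product computed above as y1.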
@AtomicVar (Author)

Output:

tensor([[[0.3988, 0.3916, 0.3875, 0.4721],
         [0.1478, 0.1654, 0.1706, 0.1835],
         [0.4223, 0.2860, 0.2408, 0.4557],
         [0.8344, 1.0262, 1.0423, 0.8535],
         [0.3489, 0.4997, 0.5332, 0.4030]],

        [[0.2723, 0.4004, 0.4566, 0.5023],
         [0.1185, 0.1746, 0.1601, 0.1963],
         [0.1833, 0.2602, 0.5420, 0.4750],
         [0.6062, 1.0481, 0.8460, 0.9907],
         [0.3251, 0.5337, 0.3234, 0.4617]]])
tensor([[[0.3988, 0.3916, 0.3875, 0.4721],
         [0.1478, 0.1654, 0.1706, 0.1835],
         [0.4223, 0.2860, 0.2408, 0.4557],
         [0.8344, 1.0262, 1.0423, 0.8535],
         [0.3489, 0.4997, 0.5332, 0.4030]],

        [[0.2723, 0.4004, 0.4566, 0.5023],
         [0.1185, 0.1746, 0.1601, 0.1963],
         [0.1833, 0.2602, 0.5420, 0.4750],
         [0.6062, 1.0481, 0.8460, 0.9907],
         [0.3251, 0.5337, 0.3234, 0.4617]]])
tensor(7.2643e-09)

@AtomicVar (Author)
A Conv1D with kernel size 1 is in effect just a batched matrix multiplication, and vice versa.
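
The reverse direction can be sketched the same way (a minimal sketch; the variable names here, such as w_mat, are made up for illustration): applying one (Cout, Cin) matrix to every batch element is the same as a Conv1D whose filters are the matrix rows with a trailing kernel-size-1 dimension:

import torch

N, Cin, Lin, Cout = 2, 3, 4, 5
x = torch.rand(N, Cin, Lin)    # batch of (Cin, Lin) matrices
w_mat = torch.rand(Cout, Cin)  # one shared (Cout, Cin) matrix

# Batched matmul: (Cout, Cin) broadcasts over the batch -> (N, Cout, Lin)
y_mm = torch.matmul(w_mat, x)

# Same computation as Conv1D: add a trailing kernel-size-1 dim -> (Cout, Cin, 1)
y_conv = torch.nn.functional.conv1d(x, w_mat.unsqueeze(-1))

print(torch.allclose(y_mm, y_conv, atol=1e-6))  # expected: True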
