@AtomicVar
Created October 30, 2023 07:00
Batched matrix multiplication is equivalent to a special kind of Conv1D
"""
Shows that batched matrix multiplication is equivalent to a special kind of Conv1D.
"""
import torch
N = 2
Cin = 3
Lin = 4
Cout = 5
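# Use conv1d with kernel size 1
x = torch.rand(size=(N, Cin, Lin))     # input:  (N, Cin, Lin)
w = torch.rand(size=(Cout, Cin, 1))    # weight: (Cout, Cin, kernel_size=1)
y0 = torch.nn.functional.conv1d(x, w)  # output: (N, Cout, Lin)
print(y0)
# Use batched matrix multiplication
# Drop the kernel dimension and add a leading batch dimension of 1;
# torch.matmul broadcasts it against the N samples in x.
w_batch = w.reshape(1, Cout, Cin)
y1 = torch.matmul(w_batch, x)  # (1, Cout, Cin) @ (N, Cin, Lin) -> (N, Cout, Lin)
print(y1)
# Print the mean absolute error (expected to be ~0, up to float32 rounding)
print(torch.mean(torch.abs(y0 - y1)))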
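Why the two results agree can be seen from the index arithmetic. The sketch below assumes PyTorch's definition of conv1d as a cross-correlation with stride 1, no padding, dilation 1 and no bias (the defaults used in the call above); the symbols W and X_n are introduced here only for illustration.

```latex
% conv1d with kernel size K (stride 1, no padding, no bias):
y_{n,o,l} = \sum_{c=0}^{C_\text{in}-1} \sum_{k=0}^{K-1} w_{o,c,k}\, x_{n,c,l+k},
\qquad l = 0, \dots, L_\text{in} - K
% With K = 1 the inner sum has a single term, and L_out = L_in:
y_{n,o,l} = \sum_{c=0}^{C_\text{in}-1} w_{o,c,0}\, x_{n,c,l}
% With W = w_{:,:,0} \in \mathbb{R}^{C_\text{out} \times C_\text{in}}
% and X_n = x_{n,:,:} \in \mathbb{R}^{C_\text{in} \times L_\text{in}}:
Y_n = W X_n, \qquad n = 0, \dots, N-1
```

The last line is exactly what `torch.matmul(w_batch, x)` computes, with the same `W` broadcast across all `N` samples.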
@AtomicVar (Author):

A Conv1D with kernel size 1 is really just a batched matrix multiplication, and vice versa.
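The script above only uses one weight matrix shared across the whole batch. For the "vice versa" direction with a general `torch.bmm`, where every batch element has its own matrix, one possible sketch is to fold the batch into the channel axis and use a grouped, kernel-size-1 `conv1d` (`groups=B`). The helper name `bmm_via_conv1d` is hypothetical and not part of the original gist.

```python
import torch


def bmm_via_conv1d(a: torch.Tensor, b: torch.Tensor) -> torch.Tensor:
    """Compute torch.bmm(a, b) with a single grouped conv1d of kernel size 1.

    a: (B, M, K) per-batch matrices, used as the conv weights.
    b: (B, K, L) per-batch inputs, stacked along the channel axis.
    Returns a (B, M, L) tensor.
    """
    B, M, K = a.shape
    _, _, L = b.shape
    # Fold the batch into channels: one sample with B*K input channels.
    x = b.reshape(1, B * K, L)
    # Weight layout is (out_channels, in_channels // groups, kernel_size).
    w = a.reshape(B * M, K, 1)
    # With groups=B, output channels [g*M, (g+1)*M) only read
    # input channels [g*K, (g+1)*K), i.e. batch element g.
    y = torch.nn.functional.conv1d(x, w, groups=B)  # (1, B*M, L)
    return y.reshape(B, M, L)


torch.manual_seed(0)
a = torch.rand(2, 5, 3)
b = torch.rand(2, 3, 4)
print(torch.allclose(bmm_via_conv1d(a, b), torch.bmm(a, b), atol=1e-6))
```

Without `groups`, every output channel would mix the inputs of all batch elements; the grouping is what keeps each batch element's matrix product separate.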
