Skip to content

Instantly share code, notes, and snippets.

@ngimel
Last active December 26, 2019 22:27
Show Gist options
  • Save ngimel/780f09d431906a522c4df2da3cd16fbd to your computer and use it in GitHub Desktop.
Save ngimel/780f09d431906a522c4df2da3cd16fbd to your computer and use it in GitHub Desktop.
int cublasSgemm_v2(addr, int, int, int, int, int, float*, addr, int, addr, int, float*, addr, int);
int cublasGemmEx(addr, int, int, int, int, int, float*, addr, int, int, addr, int, int, float*, addr, int, int, int, int);
int cublasGemmBatchedEx(addr, int, int, int, int, int, float*, addr, int, int, addr, int, int, float*, addr, int, int, int, int, int);
int cublasSgemmStridedBatched(addr, int, int, int, int, int, float*, addr, int, int, addr, int, int, float*, addr, int, int, int);
int cublasGemmStridedBatchedEx(addr, int, int, int, int, int, float*, addr, int, int, int, addr, int, int, int, float*, addr, int, int, int, int, int, int);
import torch


def run_bmm_empty_inner_repro(device="cuda"):
    """Run ``torch.bmm`` with an empty contraction dimension into a preset ``out``.

    ``a`` is (10, 1, 0) and ``b`` is (10, 0, 43200). Each batch's product
    therefore sums over zero terms, so every element of the (10, 1, 43200)
    result should be 0. ``out`` is pre-filled with ones so that an
    implementation which skips writing the result (instead of zero-filling it)
    leaves visible non-zero values behind.

    Args:
        device: Device on which all three tensors are allocated. Defaults to
            ``"cuda"``, the path this repro was written to exercise.

    Returns:
        The ``out`` tensor after ``torch.bmm(a, b, out=out)``.
    """
    a = torch.randn(10, 1, 0, device=device)
    b = torch.randn(10, 0, 43200, device=device)
    # Allocate directly on the target device, consistent with a and b, rather
    # than building on the host and copying over.
    out = torch.ones(10, 1, 43200, device=device)
    torch.bmm(a, b, out=out)
    return out


if __name__ == "__main__":
    result = run_bmm_empty_inner_repro()
    print(result)
    # print() elides the middle of a 432,000-element tensor, so count the
    # non-zero entries explicitly instead of relying on the visible edges.
    nonzero = int((result != 0).sum())
    print(f"non-zero entries (expected 0): {nonzero}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment