@harpone
Created September 8, 2022 19:58
Parallel NN layer
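Written out, the `forward` pass takes an input $x \in \mathbb{R}^N$ and an $M \times N$ weight matrix $W$ and returns a single scalar. Each $z_i$ is the dot product of row $i$ of $W$ with $x$, and $y$ is the sum of those products. Each $z_i$ depends only on its own row and on $x$, so the $M$ dot products are independent and can run concurrently. Only the final sum has to wait for all of them, and that wait is what `p.sync()` enforces.

$$
z_i = \sum_{j=1}^{N} W_{ij}\, x_j \quad (i = 1, \dots, M), \qquad y = \sum_{i=1}^{M} z_i = \mathbf{1}_M^{\top} W x
$$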
# PARALLEL:
from contractpool import ContractPool  # imaginary 'contractpool' library, similar to Python's `multiprocessing`

W = parameter(M, N)  # M x N weight matrix

def forward(x_: float32[N]) -> float32:
    # matrix-vector multiplication:
    zs: float32[M] = empty(float32[M])  # let's imagine we have a float32 dtype in Vyper
    with ContractPool(dot, M) as p:
        p.map(W, x_, out=zs)  # launches M subcontracts asynchronously; subcontract i writes dot(W[i], x_) into zs[i]
        p.sync()  # pause main contract execution until all subcontracts have finished writing to `zs`, then continue
    # summation:
    y: float32 = sum(zs)  # scalar
    return y
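The next snippet is a runnable approximation in plain Python. It shows the same fork / sync / reduce pattern as the pseudocode and rests on a few assumptions. The real `multiprocessing.pool.ThreadPool` stands in for the imaginary `ContractPool`. Each task computes one row's dot product, and `pool.map` blocks until every result is back, which plays the role of `p.sync()`. The final `sum` is the reduction. The names `dot` and `forward` and the `workers` parameter are illustrative only. Python floats are 64-bit rather than float32, and nothing here runs on-chain.

```python
from functools import partial
from multiprocessing.pool import ThreadPool
from typing import List, Sequence


def dot(row: Sequence[float], x: Sequence[float]) -> float:
    """Dot product of one weight row with the input (one 'subcontract')."""
    if len(row) != len(x):
        raise ValueError("row and x must have the same length")
    return sum(w * xi for w, xi in zip(row, x))


def forward(W: Sequence[Sequence[float]], x: Sequence[float], workers: int = 4) -> float:
    """Parallel matrix-vector product, then a sum over the M row results."""
    # Fan out: one task per row of W, analogous to launching M subcontracts.
    with ThreadPool(processes=workers) as pool:
        # map() blocks until every task has finished (the role of p.sync())
        # and returns results in row order, so zs[i] == dot(W[i], x).
        zs: List[float] = pool.map(partial(dot, x=x), W)
    # Fan in: reduce the M partial results to a single scalar.
    return sum(zs)


if __name__ == "__main__":
    W = [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]  # M = 2 rows, N = 3 columns
    x = [1.0, 0.0, -1.0]
    print(forward(W, x))  # prints -4.0
```

A thread pool avoids pickling concerns. However, CPython's GIL means pure-Python arithmetic will not actually run faster on threads. Swapping in `multiprocessing.Pool` gives process-level parallelism, and the `partial` over a module-level `dot` stays picklable. The cost is that `W` and `x` get copied to the worker processes.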