@LeeSinLiang
Last active May 24, 2023 13:49
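Enables nn.Sequential to accept multiple inputs. Whenever a module returns a tuple, its elements are passed to the next module as separate positional arguments; any other output is passed on as a single argument.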
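import torch.nn as nn


class MultiInputSequential(nn.Sequential):
    """nn.Sequential variant whose forward() accepts any number of inputs."""

    def forward(self, *inputs):
        # Iterating an nn.Sequential yields its child modules in order.
        for module in self:
            # A tuple (the initial *inputs, or a module that returned several
            # values) is unpacked into positional arguments for the next module;
            # any other output is passed through as a single argument.
            if isinstance(inputs, tuple):
                inputs = module(*inputs)
            else:
                inputs = module(inputs)
        return inputs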
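A short usage sketch, assuming PyTorch is installed. The class is repeated in compact form so the snippet runs on its own. `Concat`, `SplitHalves` and `Add` are hypothetical modules written only for this illustration: they show a two-input module at the front of the chain, a single-input stock module (`nn.ReLU`) in the middle, and a tuple-returning module feeding a two-input module.

```python
import torch
import torch.nn as nn


class MultiInputSequential(nn.Sequential):
    # Compact copy of the class above so this example is self-contained.
    def forward(self, *inputs):
        for module in self:
            inputs = module(*inputs) if isinstance(inputs, tuple) else module(inputs)
        return inputs


# Hypothetical helper modules, defined only for this example.
class Concat(nn.Module):
    """Takes two tensors and joins them along the feature dimension."""
    def forward(self, a, b):
        return torch.cat([a, b], dim=1)


class SplitHalves(nn.Module):
    """Takes one tensor and returns a tuple of its two feature halves."""
    def forward(self, x):
        half = x.shape[1] // 2
        return x[:, :half], x[:, half:]


class Add(nn.Module):
    """Takes two tensors and returns their element-wise sum."""
    def forward(self, x, y):
        return x + y


model = MultiInputSequential(
    Concat(),       # (a, b) -> one tensor
    nn.ReLU(),      # one tensor -> one tensor
    SplitHalves(),  # one tensor -> (left, right)
    Add(),          # (left, right) -> one tensor
)

a = torch.tensor([[-1.0, 2.0]])
b = torch.tensor([[3.0, -4.0]])
out = model(a, b)
print(out)  # tensor([[3., 2.]])
```

Because every tuple output is unpacked, a stock module that returns a tuple, such as `nn.LSTM` with its `(output, (h_n, c_n))` result, will hand two positional arguments to whatever module comes next, so that module must be written to accept them.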