@viswanathgs · Last active August 7, 2020 18:42
diff --git a/platform/pybmi/pybmi/torch/__init__.py b/platform/pybmi/pybmi/torch/__init__.py
index fa5c5f69a..71e24ea74 100644
--- a/platform/pybmi/pybmi/torch/__init__.py
+++ b/platform/pybmi/pybmi/torch/__init__.py
@@ -3,5 +3,5 @@ from .modules import (
     GRU, LSTM, Conv1d, Conv2d, MaxPool2d, Module, Permute, Permuted,
     RotationInvariantMLP, Sequential, SkipConnection, Slice, StackTime,
     Stateless, StatelessWrapper, TdsFullyConnectedBlock, TdsConv2dTimeBlock,
-    TdsBlock, DistributedDataParallel, save)
+    TdsBlock, DistributedDataParallel, save, BatchNorm1d)
 from .online import TorchTransformer
diff --git a/platform/pybmi/pybmi/torch/modules.py b/platform/pybmi/pybmi/torch/modules.py
index c318729d6..8f80e3886 100644
--- a/platform/pybmi/pybmi/torch/modules.py
+++ b/platform/pybmi/pybmi/torch/modules.py
@@ -103,6 +103,17 @@ class StatelessWrapper(Stateless):
         return self.child.forward(inputs)
 
 
+class BatchNorm1d(Stateless):
+    def __init__(self, num_features: int) -> None:
+        super().__init__()
+        self.bn = nn.BatchNorm1d(num_features)
+        # self.train(False)
+
+    def forward(self, inputs: torch.Tensor) -> torch.Tensor:
+        x = inputs.transpose(1, 2)
+        return self.bn(x).transpose(1, 2)
+
+
 class Sequential(Module):
     """A linear chain of :py:class:`pybmi.torch.Module` objects
@@ -803,6 +814,7 @@ class TdsConv2dTimeBlock(Stateless):
         self.rpad = rpad
         self.pad, conv_pad = maybe_asymetric_padding(self.kw, self.rpad)
+        conv_pad = 0  # todo slog
         self.conv2d = nn.Conv2d(
             in_channels=c,
             out_channels=c,
@@ -820,8 +832,10 @@
         x = self.pad(x)
         x = self.conv2d(x)
         x = self.relu(x)
-        x = x.transpose(1, 2).reshape(N, T, C)  # N, T, C
-        x = x + inputs
+        # todo slog
+        x = x.transpose(1, 2)
+        x = x.reshape(N, x.shape[1], C)  # N, T', C
+        x = x + inputs[:, :x.shape[1], :]  # todo slog
         x = self.layer_norm(x)
         return x
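Net effect of the two hunks above: conv_pad = 0 discards the padding that maybe_asymetric_padding computed for the convolution itself, so nn.Conv2d runs "valid" along time and shortens the sequence; any remaining padding comes only from self.pad. Because the output length T' can now differ from the input length T, the forward pass reshapes with the actual x.shape[1] rather than T and truncates the residual input to its first T' steps so the skip connection still lines up. A sketch of the shape arithmetic, assuming a (kw, 1) time kernel (the kernel size is not shown in this diff), no self.pad, and made-up sizes:

import torch
import torch.nn as nn

N, T, c, w, kw = 2, 50, 8, 4, 3
conv = nn.Conv2d(c, c, kernel_size=(kw, 1), padding=0)

x = torch.randn(N, c, T, w)                # conv input: (N, C, T, W)
y = conv(x)
assert y.shape == (N, c, T - (kw - 1), w)  # time axis shrinks by kw - 1

inputs = torch.randn(N, T, c * w)          # block input: (N, T, C)
out = y.transpose(1, 2).reshape(N, y.shape[2], c * w)  # (N, T', C)
out = out + inputs[:, :out.shape[1], :]    # truncate residual to T' steps
assert out.shape == (N, T - (kw - 1), c * w)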
diff --git a/platform/pybmi/pybmi/torch/tests/test_modules.py b/platform/pybmi/pybmi/torch/tests/test_modules.py
index 462140f2e..0ba8b47d9 100644
--- a/platform/pybmi/pybmi/torch/tests/test_modules.py
+++ b/platform/pybmi/pybmi/torch/tests/test_modules.py
@@ -88,11 +88,12 @@ FACTORIES: Mapping[str, Tuple[str, Callable[[], pt.Module]]] = {
     # "tds_fully_connected_block": ("3d", lambda: pt.TdsFullyConnectedBlock(
     #     INPUT_SIZE,
     # )),
-    "tds_conv2d_time_block": ("3d", lambda: pt.TdsConv2dTimeBlock(
-        c=INPUT_SIZE // 2,
-        w=INPUT_SIZE // 2,
-        kw=3,
-    )),
+    # "tds_conv2d_time_block": ("3d", lambda: pt.TdsConv2dTimeBlock(
+    #     c=INPUT_SIZE // 2,
+    #     w=INPUT_SIZE // 2,
+    #     kw=3,
+    # )),
+    "batch_norm": ("3d", lambda: pt.BatchNorm1d(INPUT_SIZE)),
 }  # yapf: disable
 FACTORIES_1D = {
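This mirrors the module changes: tds_conv2d_time_block is commented out of FACTORIES (with conv_pad = 0 the block no longer preserves the time length, which a shared shape-checking harness would presumably flag), and a batch_norm entry is registered against the newly exported pt.BatchNorm1d. Roughly what that entry exercises; the harness around FACTORIES is not shown here, so the input shape and the check below are assumptions:

import torch
import torch.nn as nn

INPUT_SIZE = 8
bn = nn.BatchNorm1d(INPUT_SIZE)
x = torch.randn(2, 32, INPUT_SIZE)         # a "3d" input: (N, T, C)
y = bn(x.transpose(1, 2)).transpose(1, 2)  # what pt.BatchNorm1d does internally
assert y.shape == x.shape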