@mariokostelac
Created February 4, 2021 09:04
Debug layer for Trax: an identity layer that prints the shape of the data passing through it.
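from trax.layers import base


class Debug(base.Layer):
    """Identity layer that prints the shape of its input when `debug` is True."""

    def __init__(self, msg=""):
        super().__init__(name='Debug')
        self.msg = msg      # label printed in front of the shape
        self.debug = False  # printing is off by default

    def forward(self, x):
        if self.debug:
            print(f"{self.msg} {x.shape}")
        return x  # pass the input through unchanged

    def init_weights_and_state(self, input_signature):
        pass  # no weights or state to initialize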
@mariokostelac
Author

After training, turn on debug mode with

model.sublayers[0].debug = True

(or whichever sublayer index points at the Debug layer in your model).
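
If a model contains several Debug layers, or they are nested inside combinators, indexing into `sublayers` by hand gets tedious. Below is a minimal sketch of a helper that walks the layer tree and flips the flag everywhere. The name `set_debug` and its parameters are hypothetical and not part of Trax. It only assumes that each layer exposes its children through a `sublayers` attribute, as Trax layers do.

```python
def set_debug(layer, layer_type, enabled=True):
    """Recursively set `.debug` on every `layer_type` instance under `layer`.

    Hypothetical helper, not part of Trax. Returns the number of matching
    layers visited (a layer shared in several places is counted each time).
    """
    count = 0
    if isinstance(layer, layer_type):
        layer.debug = enabled
        count += 1
    # Children are read from `.sublayers`; getattr with an empty default
    # keeps the walk safe for objects that lack the attribute.
    for child in getattr(layer, "sublayers", ()):
        count += set_debug(child, layer_type, enabled)
    return count
```

With the `Debug` class above this would be called as `set_debug(model, Debug)` to switch printing on, and `set_debug(model, Debug, enabled=False)` to switch it off again.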
