Skip to content

Instantly share code, notes, and snippets.

@mdraw
Created February 5, 2019 03:07
Show Gist options
  • Save mdraw/f5d9b24d8c43731756be9245ac4eeffb to your computer and use it in GitHub Desktop.
Save mdraw/f5d9b24d8c43731756be9245ac4eeffb to your computer and use it in GitHub Desktop.
(WIP, not working) LinearScheduler TorchScript (JIT) module for DropBlock. Compilation fails at line 28 with "unexpected expression on left-hand side of assignment".
import torch
def _torch_version_tuple(version):
    """Return ``(major, minor, patch)`` parsed from a version string.

    Anything after the leading numeric components (local build tags such as
    ``+cu113``, pre-release or dev suffixes) is ignored. Missing components
    count as 0, and a string that does not start with a number yields
    ``(0, 0, 0)``.
    """
    import re

    match = re.match(r'(\d+)(?:\.(\d+))?(?:\.(\d+))?', str(version))
    if match is None:
        return (0, 0, 0)
    return tuple(int(part) if part else 0 for part in match.groups())


# Use TorchScript's ScriptModule / script_method on torch >= 1.0; otherwise
# fall back to a plain nn.Module and an identity decorator so the same class
# definitions below work on both.
# Compare parsed integer tuples: comparing version *strings* is lexicographic
# and not a valid version ordering (e.g. '1.9.0' < '1.10.0' is False).
if _torch_version_tuple(torch.__version__) >= (1, 0, 0):
    Module = torch.jit.ScriptModule
    script_method = torch.jit.script_method
else:
    Module = torch.nn.Module

    def script_method(x):
        """Identity decorator used when TorchScript is not available."""
        return x
class LinearScheduler(Module):
    """Wrap a DropBlock-style module and step its ``drop_prob`` linearly.

    ``forward`` simply delegates to the wrapped module. Each call to
    :meth:`step` sets ``dropblock.drop_prob`` to the next value of
    ``torch.linspace(start_value, stop_value, nr_steps)``. Once the schedule
    is exhausted, further calls leave ``drop_prob`` at its last value.

    Args:
        dropblock: module whose ``drop_prob`` attribute is scheduled.
        start_value: first value of the schedule.
        stop_value: last value of the schedule.
        nr_steps: number of schedule points (coerced to ``int``).
    """

    def __init__(self, dropblock, start_value, stop_value, nr_steps):
        super(LinearScheduler, self).__init__()
        # NOTE(review): when Module is torch.jit.ScriptModule, the scripted
        # forward presumably needs `dropblock` to be scriptable too - confirm.
        self.dropblock = dropblock
        # Step counter kept as a buffer so it travels with state_dict().
        self.register_buffer('i', torch.tensor(0, dtype=torch.int64))
        # torch.linspace requires an integer number of steps.
        self.register_buffer(
            'drop_values',
            torch.linspace(start_value, stop_value, int(nr_steps)))

    @script_method
    def forward(self, x):
        """Apply the wrapped module to ``x`` unchanged."""
        return self.dropblock(x)

    def step(self):
        """Advance the schedule by one step.

        Intentionally a plain Python method, not ``@script_method``. The
        TorchScript compiler rejected the attribute assignments made here
        (the "unexpected expression on left-hand side of assignment" error).
        Running in Python lets it update the wrapped module's attribute
        directly.
        """
        idx = int(self.i)
        if idx < self.drop_values.numel():
            # .item() hands the wrapped module a Python float instead of a
            # 0-dim tensor.
            self.dropblock.drop_prob = self.drop_values[idx].item()
            # In-place increment keeps the same registered buffer tensor; the
            # counter stops advancing once the schedule is exhausted.
            self.i.add_(1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment