-
-
Save mdraw/f5d9b24d8c43731756be9245ac4eeffb to your computer and use it in GitHub Desktop.
(WIP, not working) LinearScheduler JIT module for DropBlock. Compilation fails at line 28 with "unexpected expression on left-hand side of assignment."
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch

# Compatibility shim: choose the module base class and the method decorator.
# Releases older than 1.0.0 fall back to a plain nn.Module with a no-op
# decorator; newer releases use the TorchScript ScriptModule API.
if torch.__version__ < '1.0.0':
    Module = torch.nn.Module

    def script_method(x):
        """Return ``x`` unchanged (stand-in for ``torch.jit.script_method``)."""
        return x
else:
    Module = torch.jit.ScriptModule
    script_method = torch.jit.script_method
class LinearScheduler(Module):
    """Wrap a DropBlock module and anneal its ``drop_prob`` along a linear ramp.

    Each call to :meth:`step` copies the next value of a precomputed
    ``torch.linspace`` ramp into ``dropblock.drop_prob``. Once the ramp is
    exhausted, further calls leave ``drop_prob`` unchanged.

    Args:
        dropblock: Module applied in :meth:`forward`. It must expose a
            writable ``drop_prob`` attribute. When ``Module`` is
            ``torch.jit.ScriptModule``, it presumably must also be
            scriptable so that ``forward`` compiles (TODO confirm).
        start_value: First value of the ramp.
        stop_value: Last value of the ramp.
        nr_steps: Number of values in the ramp (coerced with ``int``).
    """

    def __init__(self, dropblock, start_value, stop_value, nr_steps):
        super(LinearScheduler, self).__init__()
        self.dropblock = dropblock
        # Step counter, registered as a buffer so it travels with the module
        # (state_dict, .to(device)).
        self.register_buffer('i', torch.tensor(0, dtype=torch.int64))
        # torch.linspace requires an integer step count; accept e.g. 5.0.
        self.register_buffer(
            'drop_values',
            torch.linspace(start_value, stop_value, int(nr_steps)),
        )

    @script_method
    def forward(self, x):
        """Apply the wrapped DropBlock module to ``x``."""
        return self.dropblock(x)

    # NOTE: deliberately NOT a script method. TorchScript cannot compile the
    # augmented assignment to a module attribute (``self.i += 1``) or an
    # assignment to a submodule attribute. Schedule bookkeeping is driven from
    # Python between iterations anyway.
    def step(self):
        """Advance the schedule by one step.

        Sets ``dropblock.drop_prob`` to the current ramp value (as a Python
        float) while values remain, then increments the step counter in place.
        """
        idx = int(self.i)
        if idx < len(self.drop_values):
            # .item() keeps drop_prob a plain float rather than a 0-d tensor.
            self.dropblock.drop_prob = self.drop_values[idx].item()
        # In-place update keeps the registered buffer object intact.
        self.i.add_(1)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment