Skip to content

Instantly share code, notes, and snippets.

@you74674
Created May 9, 2020 13:47
Show Gist options
  • Save you74674/48dd48ba2b21cc5af40aa512a43b4f32 to your computer and use it in GitHub Desktop.
Save you74674/48dd48ba2b21cc5af40aa512a43b4f32 to your computer and use it in GitHub Desktop.
PyTorch JIT: method overloads with inheritance
import torch
from torch.nn.utils.rnn import PackedSequence
from typing import overload, Optional
class Base(torch.nn.Module):
def __init__(self):
super().__init__()
@overload
@torch._jit_internal._overload_method
def forward(self, inputs, hx=None):
# type: (PackedSequence, Optional[Tensor]) -> PackedSequence
pass
@overload
@torch._jit_internal._overload_method
def forward(self, inputs, hx=None):
# type: (Tensor, Optional[Tensor]) -> Tensor
pass
def forward(self, inputs, hx=None):
return inputs
class Derive(Base):
    """Subclass that inherits ``Base.forward`` and its overloads unchanged.

    Nothing is overridden, so eager calls behave exactly like ``Base``.
    Scripting an instance with ``torch.jit.script`` is expected to fail; that
    failure is what this example demonstrates.
    NOTE(review): presumably because ``_overload_method`` records overloads
    under the class that declared them (``Base``), so none are found when the
    compiler looks them up for ``Derive`` -- confirm against torch._jit_internal.
    """
class Derive2(Base):
    """Subclass that re-declares the ``forward`` overloads and delegates to ``Base``.

    The two overload stubs repeat ``Base``'s declarations, presumably so that
    the overloads are registered for this class as well. The implementation
    simply calls ``Base.forward``, so eager calls return ``inputs`` unchanged.
    Scripting an instance is also expected to fail.
    NOTE(review): the likely culprit is the explicit ``Base.forward(self, ...)``
    call, which TorchScript may be unable to compile (parent-method calls are
    not generally supported) -- confirm.
    """

    # Overload stubs: bare ``pass`` bodies, signature given by the comment on
    # the first body line (same declarations as in Base).
    @overload
    @torch._jit_internal._overload_method
    def forward(self, inputs, hx=None):  # noqa: F811
        # type: (PackedSequence, Optional[Tensor]) -> PackedSequence
        pass

    @overload
    @torch._jit_internal._overload_method
    def forward(self, inputs, hx=None):  # noqa: F811
        # type: (Tensor, Optional[Tensor]) -> Tensor
        pass

    def forward(self, inputs, hx=None):  # noqa: F811
        """Delegate to ``Base.forward`` with the same arguments."""
        return Base.forward(self, inputs, hx)
# Scripting Base is expected to succeed, so any error here is left to propagate.
torch.jit.script(Base())

# Both subclasses are expected to fail to script. Each case is tried on its own,
# and every outcome is labeled, so one run shows which case did what (including
# an unexpected success).
for _cls in (Derive, Derive2):
    try:
        torch.jit.script(_cls())
    except Exception as e:  # TorchScript raises several unrelated error types
        print(f"{_cls.__name__}: {e}")
    else:
        print(f"{_cls.__name__}: scripted successfully")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment