Skip to content

Instantly share code, notes, and snippets.

@t-abe
Last active December 8, 2016 09:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save t-abe/0c50afdb72166f6bab86cc611e6c2a03 to your computer and use it in GitHub Desktop.
Save t-abe/0c50afdb72166f6bab86cc611e6c2a03 to your computer and use it in GitHub Desktop.
class ForgetVariable(chainer.Variable):
    """Variable that drops its array and recomputes it from its creator on demand.

    Wraps an existing ``chainer.Variable`` ``x``. When ``x`` has a creator,
    this object takes ``x``'s slot in ``x.creator.outputs`` (as a weak
    reference) and does not store the array itself. Instead, ``data`` is
    rebuilt by calling ``creator.forward`` on the creator's inputs whenever
    it is read and is not the array currently held in the class-wide cache.
    When ``x`` has no creator, the array is simply kept in ``_data``.

    All instances share one single-slot cache (``data_cache`` /
    ``last_accessed_id``), so only the most recently produced array is kept
    alive through it. This trades recomputation for memory.
    """

    # Class-wide single-slot cache shared by every instance: the array most
    # recently produced (or seeded in __init__) and id() of the instance it
    # belongs to.
    data_cache = None
    last_accessed_id = None

    def __init__(self, x):
        """Take over the role of ``x`` in the computational graph.

        Args:
            x (chainer.Variable): Variable to replace. Its ``rank``,
                ``_volatile``, ``_grad``, ``creator`` and ``name`` are copied.

        Raises:
            ValueError: If ``x`` has a creator but is not found among that
                creator's outputs.
        """
        # Copy state field by field instead of calling the base __init__,
        # so this object does not take ownership of the data array.
        self.rank = x.rank
        self._volatile = x._volatile
        self._grad = x._grad
        self.creator = x.creator
        self.name = x.name
        self._data = None

        if self.creator is None:
            # Nothing to recompute from: keep the array on the instance.
            self._data = x.data
            return

        outputs = self.creator.outputs
        # Position of x among the creator's outputs; used by `data` to pick
        # the matching element of creator.forward(...).
        self.output_index = next(
            (i for i, ref in enumerate(outputs) if ref() is x), None)
        if self.output_index is None:
            # Explicit check instead of `assert`: under -O an assert would be
            # stripped, leaving output_index unusable when `data` is read.
            raise ValueError(
                'x is not an output of its creator; cannot replace it')

        # Re-point the creator at this object (held only weakly), so x's
        # array is no longer reachable through the graph.
        new_outputs = list(outputs)
        new_outputs[self.output_index] = weakref.ref(self)
        self.creator.outputs = tuple(new_outputs)

        # Seed the cache with x's current array so the first read of `data`
        # does not trigger a recompute.
        ForgetVariable.data_cache = x.data
        ForgetVariable.last_accessed_id = id(self)

    @property
    def data(self):
        """Return the array, recomputing it via ``creator.forward`` if needed.

        Lookup order:

        1. The array stored on this instance (creator-less case).
        2. The class-wide cache, if it currently belongs to this instance.
        3. A fresh ``creator.forward`` call, whose selected output becomes
           the new cache entry.

        NOTE: ``forward`` runs again on every cache miss, so any side effect
        it has is repeated. If ``forward`` is not deterministic (presumably
        e.g. dropout), the recomputed array can differ from the original;
        confirm before using with such functions. If an input is itself a
        ForgetVariable, reading its ``data`` recurses and may evict the cache.

        NOTE(review): no setter is defined, so ``var.data = ...`` raises
        AttributeError on this class.
        """
        if self._data is not None:
            return self._data
        if ForgetVariable.last_accessed_id == id(self):
            return ForgetVariable.data_cache
        in_data = tuple(v.data for v in self.creator.inputs)
        out_data = self.creator.forward(in_data)[self.output_index]
        ForgetVariable.data_cache = out_data
        ForgetVariable.last_accessed_id = id(self)
        return out_data
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment