Created
June 23, 2017 05:37
-
-
Save iwiwi/9e6d609c69df412fa38fa33654e5c61b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from typing import * | |
class Variable(object):
    """A scalar value that can remember which Function produced it."""

    def __init__(self, data):
        # type: (float) -> None
        """Wrap ``data``; a fresh variable has no creator."""
        self.data = data
        # Set by Function.__call__ on the variables it returns; stays None
        # for variables constructed directly.
        self.creator = None  # type: Optional[Function]

    def __add__(self, y):
        # type: (Variable) -> Variable
        """Return a new Variable holding the sum of both payloads.

        The result is built directly, so its ``creator`` is None.
        """
        return Variable(self.data + y.data)

    def __mul__(self, y):
        # type: (Variable) -> Variable
        """Return a new Variable holding the product of both payloads.

        The result is built directly, so its ``creator`` is None.
        """
        # Previously this added the operands (copy-paste of __add__).
        return Variable(self.data * y.data)
class Function(object):
    """Base class for operations mapping a list of Variables to one Variable.

    Subclasses override ``forward`` to supply the actual arithmetic.
    """

    def __init__(self):
        # type: () -> None
        """Start with no recorded inputs."""
        self.inputs = []  # type: List[Variable]

    def __call__(self, inputs):
        # type: (List[Variable]) -> Variable
        """Apply the operation and return the result tagged with this creator.

        The given ``inputs`` list is stored on the instance, the raw payloads
        are handed to ``forward``, and the wrapped result gets its ``creator``
        attribute pointed back at this Function.
        """
        self.inputs = inputs
        raw_values = [variable.data for variable in inputs]
        result = Variable(self.forward(raw_values))
        result.creator = self
        return result

    def forward(self, xs):
        # type: (List[float]) -> float
        """Compute the raw result; the base implementation returns 0.0."""
        return 0.0
class Add(Function):
    """Function whose forward pass adds its first two inputs."""

    def forward(self, xs):
        # type: (List[float]) -> float
        """Return ``xs[0] + xs[1]``; any further elements are ignored."""
        lhs = xs[0]
        rhs = xs[1]
        return lhs + rhs
class Mul(Function):
    """Function whose forward pass multiplies its first two inputs."""

    def forward(self, xs):
        # type: (List[float]) -> float
        """Return ``xs[0] * xs[1]``; any further elements are ignored."""
        lhs = xs[0]
        rhs = xs[1]
        return lhs * rhs
class RNN(object):
    """Scalar recurrent network with fixed parameters.

    For each input ``x`` the state is updated as ``h = w * x + b + v * h``
    and the output accumulates ``y = y + u * h``.
    """

    def __init__(self):
        # type: () -> None
        """Create the four scalar parameters with their fixed values."""
        self.w = Variable(2.0)
        self.v = Variable(1.0)
        self.b = Variable(0.5)
        self.u = Variable(2.0)

    def add(self, a, b):  # move to global
        # type: (Variable, Variable) -> Variable
        """Return ``a + b`` computed through a fresh Add function object."""
        op = Add()
        return op.__call__([a, b])

    def mul(self, a, b):  # move to global
        # type: (Variable, Variable) -> Variable
        """Return ``a * b`` computed through a fresh Mul function object."""
        op = Mul()
        return op.__call__([a, b])

    def __call__(self, xs):
        # type: (List[Variable]) -> Variable
        """Run the recurrence over ``xs`` and return the accumulated output.

        An empty ``xs`` yields a Variable holding 0.0.
        """
        total = Variable(0.0)
        state = Variable(0.0)  # TODO: None
        for x in xs:
            weighted_input = self.mul(self.w, x)
            biased = self.add(weighted_input, self.b)
            carried = self.mul(self.v, state)
            # The new hidden state is the pre-output activation itself.
            state = self.add(biased, carried)
            total = self.add(total, self.mul(self.u, state))
        return total
def run():
    # type: () -> int
    """Drive the RNN repeatedly over a fixed-length sequence.

    Builds one RNN, then 100000 times constructs a fresh 100-element
    sequence of Variables holding 0.0 .. 99.0 and feeds it through the
    network. The network output is discarded; the function returns 0.
    """
    model = RNN()
    iterations = 100000
    sequence_length = 100
    for _ in range(iterations):
        sequence = [Variable(float(step)) for step in range(sequence_length)]
        model.__call__(sequence)
    return 0  # TODO: void
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment