This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Residual(t.nn.Module):
    """Wrap a chain of modules with a skip (residual) connection.

    The modules passed to the constructor are chained in a
    ``t.nn.Sequential``; calling the block returns that chain's output
    added to the unmodified input.

    NOTE(review): ``t`` is presumably the file's alias for ``torch``
    (``import torch as t``); the import is not visible in this excerpt --
    confirm.
    """

    def __init__(self, *args: t.nn.Module):
        """Build the block.

        :param args: modules to apply in order before the input is added
            back. With no modules the delegate is an empty ``Sequential``.
        """
        super().__init__()
        self.delegate = t.nn.Sequential(*args)

    def forward(self, inputs):
        """Return ``self.delegate(inputs) + inputs``.

        The delegate's output has to be addable to ``inputs``, since the two
        are combined with ``+``.
        """
        return self.delegate(inputs) + inputs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import tensorflow as tf | |
def message_passing(nodes, edges, edge_features, message_fn, edge_keep_prob=1.0): | |
""" | |
Pass messages between nodes and sum the incoming messages at each node. | |
Implements equation 1 and 2 in the paper, i.e. m_{.j}^t &= \sum_{i \in N(j)} f(h_i^{t-1}, h_j^{t-1}) | |
:param nodes: (n_nodes, n_features) tensor of node hidden states. | |
:param edges: (n_edges, 2) tensor of indices (i, j) indicating an edge from nodes[i] to nodes[j]. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"source": { | |
"source": "show flights saturday evening from st. louis to burbank", | |
"date": "1994-09-16" | |
}, | |
"target": { | |
"toloc.city_name": "bur", | |
"depart_time.time": "18:00", | |
"fromloc.city_name": "stl", | |
"depart_date": "1994-09-17" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import random | |
import string | |
import matplotlib | |
import numpy as np | |
from scipy.spatial.distance import cityblock | |
matplotlib.use('Agg') | |
import matplotlib.pyplot as plt |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Encoder: embed the source token ids and run them through an LSTM.
# The input length is left unspecified (shape=(None,)).
source = Input(shape=(None,), dtype='int32', name='source')
# mask_zero=True -- presumably id 0 is the padding token; verify against the
# code that builds `train`.
embedded = Embedding(output_dim=128, input_dim=train.source_vocab_size(), mask_zero=True)(source)
# No return_sequences here, so a single vector per example is passed on.
last_hid = LSTM(output_dim=128)(embedded)

# Decoder: feed the encoder vector once per target position.
# NOTE(review): train.target.padded.shape[1] looks like the padded target
# sequence length -- confirm.
repeated = RepeatVector(train.target.padded.shape[1])(last_hid)
decoder = LSTM(output_dim=128, return_sequences=True)(repeated)
# One softmax over train.target_vocab_size() classes at every decoder step.
output = TimeDistributed(Dense(output_dim=train.target_vocab_size(), activation='softmax'))(decoder)

model = Model([source], output=[output])
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from lasagne import * | |
from lasagne.layers import * | |
from lasagne.random import get_rng | |
from lasagne.utils import * | |
import numpy as np | |
import theano.tensor as T | |
from theano.tensor.shared_randomstreams import RandomStreams | |
class DropoutLSTMLayer(MergeLayer): |
NewerOlder