@ChunML
Last active June 12, 2019 03:44
for e in range(50):
    batches = get_batches(in_text, out_text, flags.batch_size, flags.seq_size)
    state_h, state_c = net.zero_state(flags.batch_size)

    # Transfer data to GPU
    state_h = state_h.to(device)
    state_c = state_c.to(device)
    for x, y in batches:
        iteration += 1

        # Tell it we are in training mode
        net.train()

        # Reset all gradients
        optimizer.zero_grad()

        # Transfer data to GPU
        x = torch.tensor(x).to(device)
        y = torch.tensor(y).to(device)

        logits, (state_h, state_c) = net(x, (state_h, state_c))
        loss = criterion(logits.transpose(1, 2), y)

        state_h = state_h.detach()
        state_c = state_c.detach()

        loss_value = loss.item()

        # Perform back-propagation
        loss.backward()

        # Update the network's parameters
        optimizer.step()
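
Two steps in the loop above are easy to misread: the `transpose(1, 2)` before the loss, and the `detach()` calls on the LSTM state. The sketch below isolates both on a tiny stand-in model. The sizes, the bare `nn.LSTM` and `nn.Linear` layers, and the random inputs are assumptions made for illustration; they are not taken from the gist.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical sizes, chosen only to keep the example small.
batch_size, seq_size, n_vocab, hidden = 4, 5, 11, 8

lstm = nn.LSTM(input_size=hidden, hidden_size=hidden, batch_first=True)
dense = nn.Linear(hidden, n_vocab)
criterion = nn.CrossEntropyLoss()

x = torch.randn(batch_size, seq_size, hidden)           # stand-in for embedded tokens
y = torch.randint(0, n_vocab, (batch_size, seq_size))   # target token ids
state = (torch.zeros(1, batch_size, hidden),
         torch.zeros(1, batch_size, hidden))

output, (state_h, state_c) = lstm(x, state)
logits = dense(output)                                   # (batch, seq, vocab)

# CrossEntropyLoss expects the class dimension second: (batch, vocab, seq),
# with targets shaped (batch, seq). Hence the transpose.
loss = criterion(logits.transpose(1, 2), y)

# The state returned by the LSTM is still attached to this batch's graph.
attached_before = state_h.requires_grad

# Detaching keeps the values for the next batch but cuts the graph, so the
# next backward pass does not reach back into earlier batches.
state_h = state_h.detach()
state_c = state_c.detach()

loss.backward()
```

Without the `detach()` calls, carrying the state into the next batch would make the following `loss.backward()` try to traverse a graph whose buffers have already been freed, which PyTorch reports as a runtime error.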
@PyExtreme

state_h, state_c = net.zero_state(flags.batch_size) gives
AttributeError: 'RNNModule' object has no attribute 'zero_state'
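
The snippet calls `net.zero_state(...)` but does not show the model class, so this error suggests that the `RNNModule` in use does not define that method (or that the method ended up outside the class body). Below is a minimal sketch of a module that would satisfy the loop. The constructor arguments, layer choices, and single-layer LSTM are assumptions, not the gist author's confirmed code; the only thing taken from the snippet is that `zero_state(batch_size)` returns two tensors and that `net(x, (state_h, state_c))` returns `logits, (state_h, state_c)`.

```python
import torch
import torch.nn as nn


class RNNModule(nn.Module):
    # Hypothetical constructor signature, for illustration only.
    def __init__(self, n_vocab, embedding_size, lstm_size):
        super().__init__()
        self.lstm_size = lstm_size
        self.embedding = nn.Embedding(n_vocab, embedding_size)
        self.lstm = nn.LSTM(embedding_size, lstm_size, batch_first=True)
        self.dense = nn.Linear(lstm_size, n_vocab)

    def forward(self, x, prev_state):
        embed = self.embedding(x)                     # (batch, seq, embedding)
        output, state = self.lstm(embed, prev_state)  # state is (h, c)
        logits = self.dense(output)                   # (batch, seq, vocab)
        return logits, state

    def zero_state(self, batch_size):
        # Initial hidden and cell states for a single-layer LSTM:
        # shape (num_layers, batch, hidden) with num_layers == 1.
        return (torch.zeros(1, batch_size, self.lstm_size),
                torch.zeros(1, batch_size, self.lstm_size))


net = RNNModule(n_vocab=20, embedding_size=8, lstm_size=16)
state_h, state_c = net.zero_state(4)
tokens = torch.randint(0, 20, (4, 5))
logits, (state_h, state_c) = net(tokens, (state_h, state_c))
```

If the class already has a `zero_state` definition, it is worth checking that it is indented inside the class, since a method defined at module level would produce the same `AttributeError`.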
