@astariul
Created February 22, 2019 01:35
import torch

def _process_tensors(self, data):
    # Truncate to the fixed padding lengths; -2 leaves room for the
    # [START] and [STOP] tokens added below
    article_seq = [d.article_seq[:self.enc_max_len] for d in data]
    abstract_seq = [d.abstract_seq[:self.dec_max_len - 2] for d in data]

    # Add [START] and [STOP] to the target abstract
    for s in abstract_seq:
        s.insert(0, START_TOKEN_ID)
        s.append(STOP_TOKEN_ID)

    # Pad
    article_seq = [s + [PAD_TOKEN_ID] * (self.enc_max_len - len(s))
                   for s in article_seq]
    abstract_seq = [s + [PAD_TOKEN_ID] * (self.dec_max_len - len(s))
                    for s in abstract_seq]

    # Transpose (batch, seq_len) -> (seq_len, batch). The original
    # .view(-1, len(data)) only reshapes the row-major buffer and would
    # scramble tokens across the batch; .t() keeps each sequence intact
    return (torch.tensor(article_seq, device=DEVICE).t().contiguous(),
            torch.tensor(abstract_seq, device=DEVICE).t().contiguous())
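
For context, here is a hypothetical invocation of the method above. The `Example` container, the token-ID constants, the `Batcher` host class, and the length limits are all illustrative assumptions, not part of the gist:

    import collections
    import torch

    # Hypothetical stand-ins for the gist's module-level constants
    PAD_TOKEN_ID, START_TOKEN_ID, STOP_TOKEN_ID = 0, 1, 2
    DEVICE = torch.device("cpu")

    Example = collections.namedtuple("Example", ["article_seq", "abstract_seq"])

    class Batcher:
        # Minimal host class; the length limits are made up for the demo
        enc_max_len, dec_max_len = 6, 5
        _process_tensors = _process_tensors  # reuse the function defined above

    batch = [
        Example(article_seq=[5, 6, 7, 8], abstract_seq=[5, 6]),
        Example(article_seq=[9, 10, 11], abstract_seq=[9]),
    ]
    articles, abstracts = Batcher()._process_tensors(batch)
    print(articles.shape)   # torch.Size([6, 2]) -> (enc_max_len, batch)
    print(abstracts.shape)  # torch.Size([5, 2]) -> (dec_max_len, batch)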
@seanie12

Using

    enc_max_length = max([len(seq) for seq in article_seq])
    dec_max_length = max([len(seq) for seq in abstract_seq])

instead of self.enc_max_len and self.dec_max_len might be a better option, since each batch is then padded only to its own longest sequence :)
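
A minimal sketch of that suggestion, padding each batch to its longest sequence; the constants and the (seq_len, batch) layout are carried over from the gist, and the transpose mirrors the fix above:

    def _process_tensors(self, data):
        # Truncate to hard limits first; -2 leaves room for [START]/[STOP]
        article_seq = [d.article_seq[:self.enc_max_len] for d in data]
        abstract_seq = [[START_TOKEN_ID] + d.abstract_seq[:self.dec_max_len - 2] + [STOP_TOKEN_ID]
                        for d in data]

        # Pad only up to the longest sequence in this batch
        enc_max_length = max(len(seq) for seq in article_seq)
        dec_max_length = max(len(seq) for seq in abstract_seq)
        article_seq = [s + [PAD_TOKEN_ID] * (enc_max_length - len(s)) for s in article_seq]
        abstract_seq = [s + [PAD_TOKEN_ID] * (dec_max_length - len(s)) for s in abstract_seq]

        return (torch.tensor(article_seq, device=DEVICE).t().contiguous(),
                torch.tensor(abstract_seq, device=DEVICE).t().contiguous())

This keeps per-batch tensors smaller when most sequences fall well short of the hard limits, at the cost of variable tensor shapes across batches.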
