Skip to content

Instantly share code, notes, and snippets.

@prrao87
Created January 13, 2019 00:04
Show Gist options
  • Save prrao87/cf5a771cc94671185fd67a337b50e689 to your computer and use it in GitHub Desktop.
Save prrao87/cf5a771cc94671185fd67a337b50e689 to your computer and use it in GitHub Desktop.
split tweet data into training, validation and test sets for the transformer
def stance(data_dir, topic=None):
    """Load the stance data files and split them into train/validation/test.

    Reads the training and test files found under ``data_dir`` through the
    ``_stance`` helper, then holds out 20% of the training examples as a
    validation set.

    Args:
        data_dir: Directory containing
            ``semeval2016-task6-trainingdata.txt`` and
            ``SemEval2016-Task6-subtaskA-testdata.txt``.
        topic: Passed straight through to ``_stance`` as its ``topic``
            keyword (presumably a filter on the tweet target -- confirm
            against ``_stance``).

    Returns:
        A 3-tuple ``((trX, trY), (vaX, vaY), (teX,))`` where ``trX`` and
        ``vaX`` are lists of the training/validation inputs, ``trY`` and
        ``vaY`` are the matching labels as ``np.int32`` arrays, and ``teX``
        is the first value returned by ``_stance`` for the test file.
    """
    path = Path(data_dir)
    trainfile = 'semeval2016-task6-trainingdata.txt'
    testfile = 'SemEval2016-Task6-subtaskA-testdata.txt'

    X, Y = _stance(path / trainfile, topic=topic)
    # Only the inputs of the test file are kept; its second return value
    # is not used by this function.
    teX, _ = _stance(path / testfile, topic=topic)

    # ``seed`` is a name defined outside this function; it fixes the split
    # so repeated calls produce the same partition.
    tr_text, va_text, tr_sent, va_sent = train_test_split(
        X, Y, test_size=0.2, random_state=seed)

    # The split outputs are already aligned pairwise, so plain copies are
    # enough -- no need to walk them in lockstep and append one by one.
    trX = list(tr_text)
    vaX = list(va_text)
    trY = np.asarray(list(tr_sent), dtype=np.int32)
    vaY = np.asarray(list(va_sent), dtype=np.int32)
    return (trX, trY), (vaX, vaY), (teX, )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment