@vanatteveldt
Last active June 17, 2018 14:50
import csv
import random

import spacy


def get_cats(label, labelset):
    """Turn a single gold label into spaCy's textcat annotation format."""
    return {'cats': {l: l == label for l in labelset}}


def evaluate(textcat, tokenizer, texts, gold_labels):
    """Accuracy: fraction of documents whose highest-scoring category matches the gold label."""
    ncorrect = 0
    for j, doc in enumerate(textcat.pipe(tokenizer(text) for text in texts)):
        top_label = sorted(doc.cats, key=doc.cats.get)[-1]
        if top_label == gold_labels[j]:
            ncorrect += 1
    return ncorrect / len(gold_labels)


# Get data, split into 1000 train and ~250 test
data = [(d['text'], d['label']) for d in csv.DictReader(open('issues2.csv'))]
random.shuffle(data)
train_data = data[:1000]
test_data = data[1000:]
train_texts = [t for (t, l) in train_data]
train_labels = [l for (t, l) in train_data]
test_texts = [t for (t, l) in test_data]
test_labels = [l for (t, l) in test_data]
labelset = set(test_labels) | set(train_labels)

# Set up the spaCy pipeline: load the small Dutch model and add a textcat component
nlp = spacy.load('nl_core_news_sm')
pretrained = nlp.pipe_names
textcat = nlp.create_pipe('textcat')
nlp.add_pipe(textcat, last=True)
for l in labelset:
    textcat.add_label(l)

# Train and test: only the textcat component is updated, the pretrained pipes stay frozen
with nlp.disable_pipes(*pretrained):
    optimizer = nlp.begin_training()
    for i in range(200):
        losses = {}
        annotations = [get_cats(label, labelset) for label in train_labels]
        nlp.update(train_texts, annotations, sgd=optimizer, drop=0.2,
                   losses=losses)
        with textcat.model.use_params(optimizer.averages):
            acc_train = evaluate(textcat, nlp.tokenizer, train_texts, train_labels)
            acc_test = evaluate(textcat, nlp.tokenizer, test_texts, test_labels)
        print("Iter: {i}, Acc(train): {acc_train:1.3f}, Acc(test): {acc_test:1.3f}, Losses: {loss}"
              .format(loss=losses['textcat'], **locals()))
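For reference, the annotation format that get_cats feeds into nlp.update is a dict with a 'cats' entry mapping every known label to a boolean, with only the gold label set to True. A minimal sketch with made-up labels (the real categories come from issues2.csv):

# Illustration only: hypothetical labels, not the actual categories in the data
example_labels = {'economy', 'environment', 'health'}
print(get_cats('economy', example_labels))
# {'cats': {'economy': True, 'environment': False, 'health': False}}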
$ env/bin/python initial_classifier.py
Warning: Unnamed vectors -- this won't allow multiple vectors models to be loaded. (Shape: (0, 0))
Iter: 0, Acc(train): 0.058, Acc(test): 0.058, Losses: 5.487778186798096
Iter: 1, Acc(train): 0.059, Acc(test): 0.041, Losses: 3.338900327682495
Iter: 2, Acc(train): 0.066, Acc(test): 0.041, Losses: 2.46976375579834
Iter: 3, Acc(train): 0.060, Acc(test): 0.025, Losses: 1.892720103263855
Iter: 4, Acc(train): 0.067, Acc(test): 0.017, Losses: 1.457034945487976
Iter: 5, Acc(train): 0.096, Acc(test): 0.025, Losses: 1.340357780456543
Iter: 6, Acc(train): 0.116, Acc(test): 0.037, Losses: 1.246807336807251
Iter: 7, Acc(train): 0.158, Acc(test): 0.079, Losses: 1.1653318405151367
Iter: 8, Acc(train): 0.151, Acc(test): 0.100, Losses: 1.0578923225402832
Iter: 9, Acc(train): 0.157, Acc(test): 0.100, Losses: 1.0061776638031006
Iter: 10, Acc(train): 0.160, Acc(test): 0.108, Losses: 0.9997687935829163
Iter: 11, Acc(train): 0.167, Acc(test): 0.141, Losses: 0.9994023442268372
Iter: 12, Acc(train): 0.176, Acc(test): 0.141, Losses: 0.971940815448761
Iter: 13, Acc(train): 0.181, Acc(test): 0.149, Losses: 0.9356969594955444
Iter: 14, Acc(train): 0.188, Acc(test): 0.158, Losses: 0.9266849160194397
Iter: 15, Acc(train): 0.170, Acc(test): 0.133, Losses: 0.9230220317840576
Iter: 16, Acc(train): 0.184, Acc(test): 0.137, Losses: 0.9157535433769226
Iter: 17, Acc(train): 0.204, Acc(test): 0.145, Losses: 0.8926990032196045
Iter: 18, Acc(train): 0.222, Acc(test): 0.154, Losses: 0.8765964508056641
Iter: 19, Acc(train): 0.235, Acc(test): 0.183, Losses: 0.8826152086257935
Iter: 20, Acc(train): 0.245, Acc(test): 0.187, Losses: 0.8612433075904846
Iter: 21, Acc(train): 0.249, Acc(test): 0.178, Losses: 0.8500571250915527
Iter: 22, Acc(train): 0.260, Acc(test): 0.178, Losses: 0.8523365259170532
Iter: 23, Acc(train): 0.264, Acc(test): 0.178, Losses: 0.8429845571517944
Iter: 24, Acc(train): 0.271, Acc(test): 0.183, Losses: 0.830218493938446
Iter: 25, Acc(train): 0.281, Acc(test): 0.178, Losses: 0.8112003803253174
Iter: 26, Acc(train): 0.301, Acc(test): 0.187, Losses: 0.8008423447608948
Iter: 27, Acc(train): 0.310, Acc(test): 0.203, Losses: 0.797116756439209
Iter: 28, Acc(train): 0.316, Acc(test): 0.216, Losses: 0.8070734739303589
Iter: 29, Acc(train): 0.321, Acc(test): 0.228, Losses: 0.7884078621864319
Iter: 30, Acc(train): 0.323, Acc(test): 0.232, Losses: 0.7829582691192627
Iter: 31, Acc(train): 0.326, Acc(test): 0.237, Losses: 0.764626681804657
Iter: 32, Acc(train): 0.331, Acc(test): 0.241, Losses: 0.7590939402580261
Iter: 33, Acc(train): 0.331, Acc(test): 0.241, Losses: 0.7535281181335449
Iter: 34, Acc(train): 0.330, Acc(test): 0.241, Losses: 0.7417029142379761
Iter: 35, Acc(train): 0.332, Acc(test): 0.237, Losses: 0.7463305592536926
Iter: 36, Acc(train): 0.335, Acc(test): 0.232, Losses: 0.7417045831680298
Iter: 37, Acc(train): 0.349, Acc(test): 0.261, Losses: 0.7345207333564758
Iter: 38, Acc(train): 0.354, Acc(test): 0.257, Losses: 0.7268350720405579
Iter: 39, Acc(train): 0.362, Acc(test): 0.257, Losses: 0.7184979319572449
Iter: 40, Acc(train): 0.366, Acc(test): 0.257, Losses: 0.7169594764709473
Iter: 41, Acc(train): 0.369, Acc(test): 0.270, Losses: 0.7110732793807983
Iter: 42, Acc(train): 0.377, Acc(test): 0.274, Losses: 0.698552668094635
Iter: 43, Acc(train): 0.383, Acc(test): 0.274, Losses: 0.7015926241874695
Iter: 44, Acc(train): 0.386, Acc(test): 0.282, Losses: 0.6926952600479126
Iter: 45, Acc(train): 0.389, Acc(test): 0.286, Losses: 0.6951528191566467
Iter: 46, Acc(train): 0.395, Acc(test): 0.290, Losses: 0.6891776323318481
Iter: 47, Acc(train): 0.398, Acc(test): 0.290, Losses: 0.6850417256355286
Iter: 48, Acc(train): 0.407, Acc(test): 0.290, Losses: 0.6794711947441101
Iter: 49, Acc(train): 0.408, Acc(test): 0.295, Losses: 0.6727749705314636
Iter: 50, Acc(train): 0.411, Acc(test): 0.295, Losses: 0.6712129712104797
Iter: 51, Acc(train): 0.415, Acc(test): 0.290, Losses: 0.6682124733924866
Iter: 52, Acc(train): 0.418, Acc(test): 0.290, Losses: 0.6585105657577515
Iter: 53, Acc(train): 0.423, Acc(test): 0.286, Losses: 0.6576476097106934
Iter: 54, Acc(train): 0.426, Acc(test): 0.286, Losses: 0.6593191623687744
Iter: 55, Acc(train): 0.431, Acc(test): 0.290, Losses: 0.6528293490409851
Iter: 56, Acc(train): 0.434, Acc(test): 0.290, Losses: 0.6522123217582703
Iter: 57, Acc(train): 0.437, Acc(test): 0.295, Losses: 0.64777672290802
Iter: 58, Acc(train): 0.438, Acc(test): 0.299, Losses: 0.638460099697113
Iter: 59, Acc(train): 0.441, Acc(test): 0.303, Losses: 0.6383422613143921
Iter: 60, Acc(train): 0.441, Acc(test): 0.303, Losses: 0.6370445489883423
Iter: 61, Acc(train): 0.441, Acc(test): 0.303, Losses: 0.6358642578125
Iter: 62, Acc(train): 0.442, Acc(test): 0.303, Losses: 0.629823625087738
Iter: 63, Acc(train): 0.444, Acc(test): 0.303, Losses: 0.6301221251487732
Iter: 64, Acc(train): 0.447, Acc(test): 0.303, Losses: 0.6261898279190063
Iter: 65, Acc(train): 0.447, Acc(test): 0.303, Losses: 0.620913028717041
Iter: 66, Acc(train): 0.447, Acc(test): 0.303, Losses: 0.6252844333648682
Iter: 67, Acc(train): 0.448, Acc(test): 0.311, Losses: 0.6245530247688293
Iter: 68, Acc(train): 0.448, Acc(test): 0.320, Losses: 0.6190335750579834
Iter: 69, Acc(train): 0.460, Acc(test): 0.336, Losses: 0.6150144934654236
Iter: 70, Acc(train): 0.460, Acc(test): 0.336, Losses: 0.6181594133377075
Iter: 71, Acc(train): 0.460, Acc(test): 0.336, Losses: 0.6155303716659546
Iter: 72, Acc(train): 0.460, Acc(test): 0.336, Losses: 0.6143075227737427
Iter: 73, Acc(train): 0.461, Acc(test): 0.340, Losses: 0.6117486357688904
Iter: 74, Acc(train): 0.463, Acc(test): 0.340, Losses: 0.6136998534202576
Iter: 75, Acc(train): 0.464, Acc(test): 0.340, Losses: 0.6056715250015259
Iter: 76, Acc(train): 0.464, Acc(test): 0.340, Losses: 0.6040197610855103
Iter: 77, Acc(train): 0.464, Acc(test): 0.340, Losses: 0.6047747731208801
Iter: 78, Acc(train): 0.465, Acc(test): 0.340, Losses: 0.6025795936584473
Iter: 79, Acc(train): 0.467, Acc(test): 0.340, Losses: 0.6032374501228333
Iter: 80, Acc(train): 0.467, Acc(test): 0.340, Losses: 0.6023434996604919
Iter: 81, Acc(train): 0.468, Acc(test): 0.340, Losses: 0.5972998142242432
Iter: 82, Acc(train): 0.469, Acc(test): 0.340, Losses: 0.598336398601532
Iter: 83, Acc(train): 0.469, Acc(test): 0.340, Losses: 0.599162757396698
Iter: 84, Acc(train): 0.470, Acc(test): 0.340, Losses: 0.5977776646614075
Iter: 85, Acc(train): 0.470, Acc(test): 0.344, Losses: 0.5968936085700989
Iter: 86, Acc(train): 0.472, Acc(test): 0.344, Losses: 0.5973110795021057
Iter: 87, Acc(train): 0.474, Acc(test): 0.349, Losses: 0.5929310321807861
Iter: 88, Acc(train): 0.475, Acc(test): 0.349, Losses: 0.5919730067253113
Iter: 89, Acc(train): 0.475, Acc(test): 0.349, Losses: 0.5929781198501587
Iter: 90, Acc(train): 0.475, Acc(test): 0.349, Losses: 0.5918078422546387
Iter: 91, Acc(train): 0.476, Acc(test): 0.349, Losses: 0.5889279246330261
Iter: 92, Acc(train): 0.477, Acc(test): 0.353, Losses: 0.5871965289115906
Iter: 93, Acc(train): 0.479, Acc(test): 0.349, Losses: 0.5851335525512695
Iter: 94, Acc(train): 0.480, Acc(test): 0.353, Losses: 0.5843682289123535
Iter: 95, Acc(train): 0.480, Acc(test): 0.353, Losses: 0.5843235850334167
Iter: 96, Acc(train): 0.481, Acc(test): 0.353, Losses: 0.5814635157585144
Iter: 97, Acc(train): 0.481, Acc(test): 0.353, Losses: 0.5815912485122681
Iter: 98, Acc(train): 0.485, Acc(test): 0.357, Losses: 0.5797308087348938
Iter: 99, Acc(train): 0.486, Acc(test): 0.365, Losses: 0.5755237340927124
Iter: 100, Acc(train): 0.488, Acc(test): 0.361, Losses: 0.5744960904121399
Iter: 101, Acc(train): 0.488, Acc(test): 0.361, Losses: 0.5685713887214661
Iter: 102, Acc(train): 0.491, Acc(test): 0.361, Losses: 0.5711052417755127
Iter: 103, Acc(train): 0.493, Acc(test): 0.365, Losses: 0.5703095197677612
Iter: 104, Acc(train): 0.498, Acc(test): 0.365, Losses: 0.5621322989463806
Iter: 105, Acc(train): 0.501, Acc(test): 0.369, Losses: 0.5626657009124756
Iter: 106, Acc(train): 0.502, Acc(test): 0.369, Losses: 0.5563108921051025
Iter: 107, Acc(train): 0.504, Acc(test): 0.369, Losses: 0.5532985925674438
Iter: 108, Acc(train): 0.506, Acc(test): 0.369, Losses: 0.5492045879364014
Iter: 109, Acc(train): 0.511, Acc(test): 0.369, Losses: 0.5541149377822876
Iter: 110, Acc(train): 0.513, Acc(test): 0.369, Losses: 0.545274019241333
Iter: 111, Acc(train): 0.517, Acc(test): 0.369, Losses: 0.5426530838012695
Iter: 112, Acc(train): 0.521, Acc(test): 0.365, Losses: 0.5389185547828674
Iter: 113, Acc(train): 0.525, Acc(test): 0.369, Losses: 0.5396145582199097
Iter: 114, Acc(train): 0.528, Acc(test): 0.369, Losses: 0.5320292711257935
Iter: 115, Acc(train): 0.532, Acc(test): 0.369, Losses: 0.532733142375946
Iter: 116, Acc(train): 0.536, Acc(test): 0.369, Losses: 0.5283084511756897
Iter: 117, Acc(train): 0.546, Acc(test): 0.369, Losses: 0.5251138210296631
Iter: 118, Acc(train): 0.550, Acc(test): 0.373, Losses: 0.5185801982879639
Iter: 119, Acc(train): 0.556, Acc(test): 0.378, Losses: 0.5126572251319885
Iter: 120, Acc(train): 0.563, Acc(test): 0.390, Losses: 0.505986213684082
Iter: 121, Acc(train): 0.566, Acc(test): 0.390, Losses: 0.4964945614337921
Iter: 122, Acc(train): 0.567, Acc(test): 0.394, Losses: 0.4896056056022644
Iter: 123, Acc(train): 0.576, Acc(test): 0.394, Losses: 0.48184967041015625
Iter: 124, Acc(train): 0.581, Acc(test): 0.398, Losses: 0.4802703261375427
Iter: 125, Acc(train): 0.583, Acc(test): 0.386, Losses: 0.4719477593898773
Iter: 126, Acc(train): 0.591, Acc(test): 0.394, Losses: 0.46586158871650696
Iter: 127, Acc(train): 0.596, Acc(test): 0.394, Losses: 0.4637313187122345
Iter: 128, Acc(train): 0.600, Acc(test): 0.390, Losses: 0.4600047171115875
Iter: 129, Acc(train): 0.602, Acc(test): 0.390, Losses: 0.4488006830215454
Iter: 130, Acc(train): 0.605, Acc(test): 0.407, Losses: 0.4483371675014496
Iter: 131, Acc(train): 0.611, Acc(test): 0.407, Losses: 0.4436570107936859
Iter: 132, Acc(train): 0.618, Acc(test): 0.402, Losses: 0.4431220293045044
Iter: 133, Acc(train): 0.620, Acc(test): 0.402, Losses: 0.4422287344932556
Iter: 134, Acc(train): 0.622, Acc(test): 0.411, Losses: 0.4363781809806824
Iter: 135, Acc(train): 0.626, Acc(test): 0.411, Losses: 0.43279650807380676
Iter: 136, Acc(train): 0.627, Acc(test): 0.411, Losses: 0.4308914840221405
Iter: 137, Acc(train): 0.629, Acc(test): 0.415, Losses: 0.4287142753601074
Iter: 138, Acc(train): 0.632, Acc(test): 0.415, Losses: 0.4285946488380432
Iter: 139, Acc(train): 0.634, Acc(test): 0.415, Losses: 0.426947683095932
Iter: 140, Acc(train): 0.634, Acc(test): 0.415, Losses: 0.42426609992980957
Iter: 141, Acc(train): 0.635, Acc(test): 0.419, Losses: 0.4218948483467102
Iter: 142, Acc(train): 0.638, Acc(test): 0.423, Losses: 0.41673773527145386
Iter: 143, Acc(train): 0.641, Acc(test): 0.427, Losses: 0.41734179854393005
Iter: 144, Acc(train): 0.644, Acc(test): 0.423, Losses: 0.41206833720207214
Iter: 145, Acc(train): 0.647, Acc(test): 0.427, Losses: 0.4142393171787262
Iter: 146, Acc(train): 0.647, Acc(test): 0.423, Losses: 0.41005298495292664
Iter: 147, Acc(train): 0.647, Acc(test): 0.423, Losses: 0.4082860052585602
Iter: 148, Acc(train): 0.648, Acc(test): 0.419, Losses: 0.41000762581825256
Iter: 149, Acc(train): 0.648, Acc(test): 0.423, Losses: 0.4108700454235077
Iter: 150, Acc(train): 0.648, Acc(test): 0.423, Losses: 0.40720683336257935
Iter: 151, Acc(train): 0.648, Acc(test): 0.423, Losses: 0.4079752266407013
Iter: 152, Acc(train): 0.648, Acc(test): 0.423, Losses: 0.4078514873981476
Iter: 153, Acc(train): 0.648, Acc(test): 0.423, Losses: 0.4077819585800171
Iter: 154, Acc(train): 0.648, Acc(test): 0.427, Losses: 0.4062652289867401
Iter: 155, Acc(train): 0.648, Acc(test): 0.427, Losses: 0.4057643413543701
Iter: 156, Acc(train): 0.648, Acc(test): 0.432, Losses: 0.4049816429615021
Iter: 157, Acc(train): 0.648, Acc(test): 0.432, Losses: 0.4030606746673584
Iter: 158, Acc(train): 0.648, Acc(test): 0.432, Losses: 0.40520617365837097
Iter: 159, Acc(train): 0.648, Acc(test): 0.432, Losses: 0.40413394570350647
Iter: 160, Acc(train): 0.648, Acc(test): 0.432, Losses: 0.4066201448440552
Iter: 161, Acc(train): 0.648, Acc(test): 0.432, Losses: 0.40534573793411255
Iter: 162, Acc(train): 0.648, Acc(test): 0.432, Losses: 0.4028932452201843
Iter: 163, Acc(train): 0.648, Acc(test): 0.432, Losses: 0.40276166796684265
Iter: 164, Acc(train): 0.648, Acc(test): 0.436, Losses: 0.4033645689487457
Iter: 165, Acc(train): 0.648, Acc(test): 0.440, Losses: 0.40249937772750854
Iter: 166, Acc(train): 0.648, Acc(test): 0.444, Losses: 0.40261489152908325
Iter: 167, Acc(train): 0.648, Acc(test): 0.448, Losses: 0.40067049860954285
Iter: 168, Acc(train): 0.648, Acc(test): 0.444, Losses: 0.40328094363212585
Iter: 169, Acc(train): 0.648, Acc(test): 0.444, Losses: 0.4011622965335846
Iter: 170, Acc(train): 0.648, Acc(test): 0.444, Losses: 0.4006740152835846
Iter: 171, Acc(train): 0.648, Acc(test): 0.448, Losses: 0.40162408351898193
Iter: 172, Acc(train): 0.648, Acc(test): 0.448, Losses: 0.4008123278617859
Iter: 173, Acc(train): 0.649, Acc(test): 0.448, Losses: 0.40172067284584045
Iter: 174, Acc(train): 0.649, Acc(test): 0.448, Losses: 0.39933106303215027
Iter: 175, Acc(train): 0.649, Acc(test): 0.448, Losses: 0.40000712871551514
Iter: 176, Acc(train): 0.649, Acc(test): 0.448, Losses: 0.4002104699611664
Iter: 177, Acc(train): 0.649, Acc(test): 0.448, Losses: 0.39877083897590637
Iter: 178, Acc(train): 0.650, Acc(test): 0.448, Losses: 0.3976871967315674
Iter: 179, Acc(train): 0.650, Acc(test): 0.448, Losses: 0.39784955978393555
Iter: 180, Acc(train): 0.650, Acc(test): 0.452, Losses: 0.3991497755050659
Iter: 181, Acc(train): 0.650, Acc(test): 0.452, Losses: 0.39902374148368835
Iter: 182, Acc(train): 0.650, Acc(test): 0.448, Losses: 0.3975830078125
Iter: 183, Acc(train): 0.651, Acc(test): 0.448, Losses: 0.39719539880752563
Iter: 184, Acc(train): 0.651, Acc(test): 0.452, Losses: 0.39893510937690735
Iter: 185, Acc(train): 0.652, Acc(test): 0.452, Losses: 0.39921483397483826
Iter: 186, Acc(train): 0.652, Acc(test): 0.452, Losses: 0.3992139399051666
Iter: 187, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.3986184597015381
Iter: 188, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.3977349102497101
Iter: 189, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.3991568386554718
Iter: 190, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.39813148975372314
Iter: 191, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.3962554335594177
Iter: 192, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.39831027388572693
Iter: 193, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.3954155743122101
Iter: 194, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.39497339725494385
Iter: 195, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.3949982225894928
Iter: 196, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.39554205536842346
Iter: 197, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.394592821598053
Iter: 198, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.39566856622695923
Iter: 199, Acc(train): 0.653, Acc(test): 0.452, Losses: 0.39701879024505615
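Training and test accuracy plateau at roughly 0.65 and 0.45 respectively. After the loop, the trained textcat is part of the nlp pipeline, so a new document can be classified by processing it and reading doc.cats; a minimal sketch with a made-up example sentence:

# Sketch: classify a new (hypothetical) Dutch sentence with the trained pipeline
doc = nlp("De regering wil meer geld uitgeven aan onderwijs.")
predicted = max(doc.cats, key=doc.cats.get)  # label with the highest score
print(predicted, doc.cats[predicted])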