{
"cells": [
{
"cell_type": "markdown",
"id": "blessed-horror",
"metadata": {},
"source": [
"## Introduction \n",
"\n",
"Part of speech tagging is the process of determining the syntactic category of a word from the words in its surrounding context. It is often used to help disambiguate natural language phrases because it can be done quickly with high accuracy. Tagging can be used for many NLP tasks like determining correct pronunciation during speech synthesis (for example, dis-count as a noun vs dis-count as a verb), for information retrieval, and for word sense disambiguation.\n",
"\n",
"Here, I use the Pomegranate library to build a hidden Markov model for part of speech tagging. I train the HMM in a supervised fashion using a tagged set of sentences. \n",
"\n"
]
},
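{
"cell_type": "markdown",
"id": "hmm-factorization",
"metadata": {},
"source": [
"Concretely, the HMM factorizes the joint probability of a tag sequence $t_{1:n}$ and a word sequence $w_{1:n}$ as\n",
"\n",
"$$P(t_{1:n}, w_{1:n}) = P(t_1)\\,\\prod_{i=2}^{n} P(t_i \\mid t_{i-1})\\,\\prod_{i=1}^{n} P(w_i \\mid t_i),$$\n",
"\n",
"where $P(t_1)$ is the initial probability of a tag, $P(t_i \\mid t_{i-1})$ is a transition probability between tags, and $P(w_i \\mid t_i)$ is the probability of a tag emitting a word. The sections below estimate each factor from the tagged training sentences by maximum likelihood, i.e., as normalized counts."
]
},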
{
"cell_type": "markdown",
"id": "statistical-portsmouth",
"metadata": {},
"source": [
"## Fetch a corpus of tagged sentences"
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "suitable-corrections",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"[nltk_data] Downloading package treebank to\n",
"[nltk_data] /Users/petermchale/nltk_data...\n",
"[nltk_data] Package treebank is already up-to-date!\n",
"[nltk_data] Downloading package universal_tagset to\n",
"[nltk_data] /Users/petermchale/nltk_data...\n",
"[nltk_data] Package universal_tagset is already up-to-date!\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"sample tagged sentence:\n",
"[('Pierre', 'NOUN'), ('Vinken', 'NOUN'), (',', '.'), ('61', 'NUM'), ('years', 'NOUN'), ('old', 'ADJ'), (',', '.'), ('will', 'VERB'), ('join', 'VERB'), ('the', 'DET'), ('board', 'NOUN'), ('as', 'ADP'), ('a', 'DET'), ('nonexecutive', 'ADJ'), ('director', 'NOUN'), ('Nov.', 'NOUN'), ('29', 'NUM'), ('.', '.')]\n",
"\n",
"unique tags in corpus:\n",
"{'NOUN', 'ADP', 'ADJ', 'CONJ', 'DET', 'PRT', 'ADV', 'VERB', 'NUM', 'X', 'PRON', '.'}\n",
"\n",
"number of unique words in training set: 11052\n"
]
}
],
"source": [
"import nltk\n",
" \n",
"# download the treebank corpus from nltk\n",
"nltk.download('treebank')\n",
" \n",
"# download the universal tagset from nltk\n",
"nltk.download('universal_tagset')\n",
" \n",
"# read the treebank tagged sentences\n",
"nltk_data = list(nltk.corpus.treebank.tagged_sents(tagset='universal'))\n",
" \n",
"# print sample sentence, along with tags\n",
"print('sample tagged sentence:')\n",
"print(nltk_data[0])\n",
"\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# split data into training and test sets \n",
"train_set, test_set = train_test_split(nltk_data, train_size=0.80, test_size=0.20, random_state=101)\n",
"\n",
"# use \"set\" datatype to check how many unique tags are present \n",
"unique_tags = {tag for word, tag in [tup for sent in nltk_data for tup in sent]}\n",
"print()\n",
"print('unique tags in corpus:')\n",
"print(unique_tags)\n",
"\n",
"unique_words_in_training_set = {word for word, tag in [tup for sent in train_set for tup in sent]}\n",
"print() \n",
"print('number of unique words in training set: {}'.format(len(unique_words_in_training_set)))"
]
},
{
"cell_type": "markdown",
"id": "seasonal-planning",
"metadata": {},
"source": [
"## Estimate emission probabilities associated with hidden states of HMM"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "anticipated-secretary",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Given that hidden state is \"NOUN\", probability of emitting \"Drink\" is 4.3542628233040145e-05\n"
]
}
],
"source": [
"from collections import defaultdict\n",
"\n",
"def estimate_emission_probabilities(): \n",
" train_tagged_words = [tup for sent in train_set for tup in sent]\n",
"\n",
" # compute the number of times each (tag, word) tuple, and each tag, occurred in training part of corpus \n",
" # https://stackoverflow.com/a/5900634\n",
" tag_word_counts = defaultdict(lambda: defaultdict(int))\n",
" tag_counts = defaultdict(int)\n",
" for word, tag in train_tagged_words:\n",
" tag_word_counts[tag][word] += 1 \n",
" tag_counts[tag] +=1 \n",
"\n",
" emission_probabilities = defaultdict(lambda: defaultdict(int))\n",
" for tag, tag_dict in tag_word_counts.items(): \n",
" for word, word_count in tag_dict.items():\n",
" emission_probabilities[tag][word] = word_count/tag_counts[tag]\n",
"\n",
" return emission_probabilities \n",
"\n",
"def illustrate_emission_probabilities(): \n",
" tag = 'NOUN'\n",
" word = 'Drink'\n",
" print('Given that hidden state is \"{}\", probability of emitting \"{}\" is {}'.format(\n",
" tag, \n",
" word,\n",
" estimate_emission_probabilities()[tag][word]\n",
" ))\n",
" \n",
"_ = illustrate_emission_probabilities()"
]
},
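{
"cell_type": "markdown",
"id": "emission-sanity-check-text",
"metadata": {},
"source": [
"As a quick sanity check (a minimal sketch using only the function defined above): each tag's emission probabilities are normalized counts, so they should sum to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "emission-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"# verify that each hidden state's emission probabilities form a proper distribution\n",
"emission_probabilities = estimate_emission_probabilities()\n",
"for tag in unique_tags:\n",
"    assert abs(sum(emission_probabilities[tag].values()) - 1) < 1e-9, tag\n",
"print('emission probabilities sum to 1 for every tag')"
]
},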
{
"cell_type": "code",
"execution_count": 12,
"id": "anticipated-wrapping",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"{\n",
" \"class\" : \"State\",\n",
" \"distribution\" : {\n",
" \"class\" : \"Distribution\",\n",
" \"dtype\" : \"str\",\n",
" \"name\" : \"DiscreteDistribution\",\n",
" \"parameters\" : [\n",
" {\n",
" \"its\" : 0.11981776765375854,\n",
" \"us\" : 0.007744874715261959,\n",
" \"it\" : 0.17084282460136674,\n",
" \"You\" : 0.008200455580865604,\n",
" \"his\" : 0.04464692482915718,\n",
" \"whom\" : 0.002733485193621868,\n",
" \"their\" : 0.06697038724373576,\n",
" \"them\" : 0.027790432801822324,\n",
" \"they\" : 0.07699316628701594,\n",
" \"we\" : 0.01640091116173121,\n",
" \"he\" : 0.08382687927107062,\n",
" \"Your\" : 0.0031890660592255125,\n",
" \"He\" : 0.025968109339407745,\n",
" \"We\" : 0.01685649202733485,\n",
" \"They\" : 0.019134396355353075,\n",
" \"who\" : 0.05740318906605923,\n",
" \"It\" : 0.03963553530751709,\n",
" \"She\" : 0.008200455580865604,\n",
" \"her\" : 0.020045558086560365,\n",
" \"she\" : 0.026879271070615034,\n",
" \"what\" : 0.015945330296127564,\n",
" \"His\" : 0.002277904328018223,\n",
" \"I\" : 0.04510250569476082,\n",
" \"our\" : 0.008200455580865604,\n",
" \"themselves\" : 0.004555808656036446,\n",
" \"What\" : 0.00683371298405467,\n",
" \"you\" : 0.025968109339407745,\n",
" \"my\" : 0.00592255125284738,\n",
" \"Who\" : 0.0018223234624145787,\n",
" \"him\" : 0.006378132118451025,\n",
" \"one\" : 0.00045558086560364467,\n",
" \"yourself\" : 0.001366742596810934,\n",
" \"me\" : 0.0031890660592255125,\n",
" \"your\" : 0.008200455580865604,\n",
" \"itself\" : 0.0031890660592255125,\n",
" \"IT\" : 0.00045558086560364467,\n",
" \"himself\" : 0.002733485193621868,\n",
" \"whose\" : 0.006378132118451025,\n",
" \"My\" : 0.0009111617312072893,\n",
" \"Its\" : 0.004100227790432802,\n",
" \"Their\" : 0.001366742596810934,\n",
" \"Her\" : 0.001366742596810934\n",
" }\n",
" ],\n",
" \"frozen\" : false\n",
" },\n",
" \"name\" : \"PRON\",\n",
" \"weight\" : 1.0\n",
"}\n"
]
}
],
"source": [
"from pomegranate import State, DiscreteDistribution, HiddenMarkovModel\n",
"\n",
"states = {tag: \n",
" State(\n",
" DiscreteDistribution(dict(estimate_emission_probabilities()[tag])), \n",
" name=tag\n",
" ) \n",
" for tag in unique_tags}\n",
"\n",
"model = HiddenMarkovModel()\n",
"\n",
"for tag in unique_tags: \n",
" model.add_state(states[tag])\n",
" \n",
"print(states['PRON']) "
]
},
{
"cell_type": "markdown",
"id": "documentary-zimbabwe",
"metadata": {},
"source": [
"## Estimate transition probabilites of HMM"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "swedish-blank",
"metadata": {
"scrolled": true
},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>NOUN</th>\n",
" <th>VERB</th>\n",
" <th>ADP</th>\n",
" <th>ADJ</th>\n",
" <th>.</th>\n",
" <th>NUM</th>\n",
" <th>X</th>\n",
" <th>PRON</th>\n",
" <th>DET</th>\n",
" <th>CONJ</th>\n",
" <th>PRT</th>\n",
" <th>ADV</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>NOUN</th>\n",
" <td>0.262330</td>\n",
" <td>0.110589</td>\n",
" <td>0.323585</td>\n",
" <td>0.696893</td>\n",
" <td>0.183677</td>\n",
" <td>0.351554</td>\n",
" <td>0.061707</td>\n",
" <td>0.212756</td>\n",
" <td>0.635998</td>\n",
" <td>0.349067</td>\n",
" <td>0.250489</td>\n",
" <td>0.032208</td>\n",
" </tr>\n",
" <tr>\n",
" <th>VERB</th>\n",
" <td>0.149224</td>\n",
" <td>0.167956</td>\n",
" <td>0.008482</td>\n",
" <td>0.011456</td>\n",
" <td>0.129105</td>\n",
" <td>0.020722</td>\n",
" <td>0.206459</td>\n",
" <td>0.484738</td>\n",
" <td>0.040253</td>\n",
" <td>0.150384</td>\n",
" <td>0.401174</td>\n",
" <td>0.339154</td>\n",
" </tr>\n",
" <tr>\n",
" <th>.</th>\n",
" <td>0.240153</td>\n",
" <td>0.034807</td>\n",
" <td>0.038739</td>\n",
" <td>0.066019</td>\n",
" <td>0.099163</td>\n",
" <td>0.118971</td>\n",
" <td>0.160900</td>\n",
" <td>0.041913</td>\n",
" <td>0.017395</td>\n",
" <td>0.035126</td>\n",
" <td>0.045010</td>\n",
" <td>0.139309</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ADP</th>\n",
" <td>0.176847</td>\n",
" <td>0.092357</td>\n",
" <td>0.016964</td>\n",
" <td>0.080583</td>\n",
" <td>0.072601</td>\n",
" <td>0.037513</td>\n",
" <td>0.142253</td>\n",
" <td>0.022323</td>\n",
" <td>0.009919</td>\n",
" <td>0.055982</td>\n",
" <td>0.019569</td>\n",
" <td>0.119519</td>\n",
" </tr>\n",
" <tr>\n",
" <th>CONJ</th>\n",
" <td>0.042436</td>\n",
" <td>0.005433</td>\n",
" <td>0.000886</td>\n",
" <td>0.016893</td>\n",
" <td>0.063748</td>\n",
" <td>0.014291</td>\n",
" <td>0.010381</td>\n",
" <td>0.005011</td>\n",
" <td>0.000431</td>\n",
" <td>0.000549</td>\n",
" <td>0.002348</td>\n",
" <td>0.006985</td>\n",
" </tr>\n",
" <tr>\n",
" <th>NUM</th>\n",
" <td>0.009150</td>\n",
" <td>0.022836</td>\n",
" <td>0.063299</td>\n",
" <td>0.021748</td>\n",
" <td>0.113168</td>\n",
" <td>0.184352</td>\n",
" <td>0.003076</td>\n",
" <td>0.006834</td>\n",
" <td>0.022858</td>\n",
" <td>0.040615</td>\n",
" <td>0.056751</td>\n",
" <td>0.029880</td>\n",
" </tr>\n",
" <tr>\n",
" <th>PRT</th>\n",
" <td>0.043961</td>\n",
" <td>0.030663</td>\n",
" <td>0.001266</td>\n",
" <td>0.011456</td>\n",
" <td>0.003381</td>\n",
" <td>0.026081</td>\n",
" <td>0.185121</td>\n",
" <td>0.014123</td>\n",
" <td>0.000288</td>\n",
" <td>0.004391</td>\n",
" <td>0.001174</td>\n",
" <td>0.014746</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ADV</th>\n",
" <td>0.016905</td>\n",
" <td>0.083886</td>\n",
" <td>0.014559</td>\n",
" <td>0.005243</td>\n",
" <td>0.052318</td>\n",
" <td>0.003573</td>\n",
" <td>0.025759</td>\n",
" <td>0.036902</td>\n",
" <td>0.012076</td>\n",
" <td>0.057080</td>\n",
" <td>0.009393</td>\n",
" <td>0.081490</td>\n",
" </tr>\n",
" <tr>\n",
" <th>ADJ</th>\n",
" <td>0.012548</td>\n",
" <td>0.066390</td>\n",
" <td>0.107102</td>\n",
" <td>0.063301</td>\n",
" <td>0.048133</td>\n",
" <td>0.035370</td>\n",
" <td>0.017686</td>\n",
" <td>0.070615</td>\n",
" <td>0.206440</td>\n",
" <td>0.113611</td>\n",
" <td>0.082975</td>\n",
" <td>0.130772</td>\n",
" </tr>\n",
" <tr>\n",
" <th>X</th>\n",
" <td>0.028843</td>\n",
" <td>0.215930</td>\n",
" <td>0.034561</td>\n",
" <td>0.020971</td>\n",
" <td>0.027849</td>\n",
" <td>0.202572</td>\n",
" <td>0.075740</td>\n",
" <td>0.088383</td>\n",
" <td>0.045141</td>\n",
" <td>0.009330</td>\n",
" <td>0.012133</td>\n",
" <td>0.022895</td>\n",
" </tr>\n",
" <tr>\n",
" <th>DET</th>\n",
" <td>0.012984</td>\n",
" <td>0.133610</td>\n",
" <td>0.321053</td>\n",
" <td>0.005243</td>\n",
" <td>0.142788</td>\n",
" <td>0.003573</td>\n",
" <td>0.056709</td>\n",
" <td>0.009567</td>\n",
" <td>0.005894</td>\n",
" <td>0.123491</td>\n",
" <td>0.101370</td>\n",
" <td>0.071013</td>\n",
" </tr>\n",
" <tr>\n",
" <th>PRON</th>\n",
" <td>0.004618</td>\n",
" <td>0.035543</td>\n",
" <td>0.069502</td>\n",
" <td>0.000194</td>\n",
" <td>0.064070</td>\n",
" <td>0.001429</td>\n",
" <td>0.054210</td>\n",
" <td>0.006834</td>\n",
" <td>0.003306</td>\n",
" <td>0.060373</td>\n",
" <td>0.017613</td>\n",
" <td>0.012029</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" NOUN VERB ADP ADJ . NUM X \\\n",
"NOUN 0.262330 0.110589 0.323585 0.696893 0.183677 0.351554 0.061707 \n",
"VERB 0.149224 0.167956 0.008482 0.011456 0.129105 0.020722 0.206459 \n",
". 0.240153 0.034807 0.038739 0.066019 0.099163 0.118971 0.160900 \n",
"ADP 0.176847 0.092357 0.016964 0.080583 0.072601 0.037513 0.142253 \n",
"CONJ 0.042436 0.005433 0.000886 0.016893 0.063748 0.014291 0.010381 \n",
"NUM 0.009150 0.022836 0.063299 0.021748 0.113168 0.184352 0.003076 \n",
"PRT 0.043961 0.030663 0.001266 0.011456 0.003381 0.026081 0.185121 \n",
"ADV 0.016905 0.083886 0.014559 0.005243 0.052318 0.003573 0.025759 \n",
"ADJ 0.012548 0.066390 0.107102 0.063301 0.048133 0.035370 0.017686 \n",
"X 0.028843 0.215930 0.034561 0.020971 0.027849 0.202572 0.075740 \n",
"DET 0.012984 0.133610 0.321053 0.005243 0.142788 0.003573 0.056709 \n",
"PRON 0.004618 0.035543 0.069502 0.000194 0.064070 0.001429 0.054210 \n",
"\n",
" PRON DET CONJ PRT ADV \n",
"NOUN 0.212756 0.635998 0.349067 0.250489 0.032208 \n",
"VERB 0.484738 0.040253 0.150384 0.401174 0.339154 \n",
". 0.041913 0.017395 0.035126 0.045010 0.139309 \n",
"ADP 0.022323 0.009919 0.055982 0.019569 0.119519 \n",
"CONJ 0.005011 0.000431 0.000549 0.002348 0.006985 \n",
"NUM 0.006834 0.022858 0.040615 0.056751 0.029880 \n",
"PRT 0.014123 0.000288 0.004391 0.001174 0.014746 \n",
"ADV 0.036902 0.012076 0.057080 0.009393 0.081490 \n",
"ADJ 0.070615 0.206440 0.113611 0.082975 0.130772 \n",
"X 0.088383 0.045141 0.009330 0.012133 0.022895 \n",
"DET 0.009567 0.005894 0.123491 0.101370 0.071013 \n",
"PRON 0.006834 0.003306 0.060373 0.017613 0.012029 "
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import numpy as np \n",
"\n",
"def estimate_transition_probabilities():\n",
" tag_transition_counts = defaultdict(lambda: defaultdict(int))\n",
" tag_counts = defaultdict(int)\n",
" for sent in train_set: \n",
" # print(sent)\n",
" for start_tup, end_tup in zip(sent[:-1], sent[1:]):\n",
" # print(start_tup[1], '->', end_tup[1])\n",
" tag_transition_counts[start_tup[1]][end_tup[1]] += 1\n",
" tag_counts[start_tup[1]] += 1\n",
"\n",
" transition_probabilities = defaultdict(lambda: defaultdict(int))\n",
" for start_tag, tag_dict in tag_transition_counts.items(): \n",
" for end_tag, end_tag_count in tag_dict.items():\n",
" transition_probabilities[start_tag][end_tag] = end_tag_count/tag_counts[start_tag]\n",
" \n",
"# print(np.sum(list(transition_probabilities['NOUN'].values())))\n",
" return transition_probabilities\n",
" \n",
"import pandas as pd \n",
"\n",
"def illustrate_transition_probabilities(): \n",
" return pd.DataFrame(estimate_transition_probabilities())\n",
" \n",
"illustrate_transition_probabilities()"
]
},
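{
"cell_type": "markdown",
"id": "transition-sanity-check-text",
"metadata": {},
"source": [
"Each column of the table above (one per start tag) should sum to 1, since outgoing transition probabilities are normalized counts. A minimal sanity check, using only the function defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "transition-sanity-check-code",
"metadata": {},
"outputs": [],
"source": [
"# verify that each start tag's outgoing transition probabilities form a proper distribution\n",
"transition_probabilities = estimate_transition_probabilities()\n",
"for start_tag in unique_tags:\n",
"    assert abs(sum(transition_probabilities[start_tag].values()) - 1) < 1e-9, start_tag\n",
"print('outgoing transition probabilities sum to 1 for every tag')"
]
},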
{
"cell_type": "code",
"execution_count": 5,
"id": "incredible-radical",
"metadata": {},
"outputs": [],
"source": [
"for start_tag in unique_tags: \n",
" for end_tag in unique_tags: \n",
" model.add_transition(\n",
" states[start_tag], \n",
" states[end_tag], \n",
" estimate_transition_probabilities()[start_tag][end_tag]\n",
" )"
]
},
{
"cell_type": "markdown",
"id": "vocal-render",
"metadata": {},
"source": [
"## Estimate initial probabilities of hidden states of HMM"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "trying-somewhere",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'NOUN': 0.2884062599808368,\n",
" 'ADP': 0.13318428617055253,\n",
" 'DET': 0.2312360268284893,\n",
" 'CONJ': 0.053018205046311086,\n",
" 'PRON': 0.07824976045991695,\n",
" '.': 0.07920792079207921,\n",
" 'VERB': 0.01085915043117215,\n",
" 'ADJ': 0.042159054615138934,\n",
" 'NUM': 0.00830405621207282,\n",
" 'X': 0.021079527307569467,\n",
" 'ADV': 0.05269881826892367,\n",
" 'PRT': 0.0015969338869370809}"
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from collections import Counter\n",
"\n",
"def estimate_initial_probabilities(): \n",
" initial_tags = [sent[0][1] for sent in train_set]\n",
" counts = Counter(initial_tags)\n",
" initial_probabilities = {}\n",
" for tag, count in counts.items(): \n",
" initial_probabilities[tag] = count/sum(counts.values())\n",
" return initial_probabilities\n",
"\n",
"estimate_initial_probabilities()"
]
},
{
"cell_type": "code",
"execution_count": 7,
"id": "occasional-thousand",
"metadata": {},
"outputs": [],
"source": [
"for tag in unique_tags: \n",
" model.add_transition(\n",
" model.start, \n",
" states[tag], \n",
" estimate_initial_probabilities()[tag]\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "educated-recommendation",
"metadata": {},
"outputs": [],
"source": [
"model.bake()"
]
},
{
"cell_type": "markdown",
"id": "chicken-marking",
"metadata": {},
"source": [
"## Use trained HMM to tag test sentences and custom sentences"
]
},
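{
"cell_type": "markdown",
"id": "viterbi-objective",
"metadata": {},
"source": [
"Given a sentence $w_{1:n}$, Viterbi decoding finds the most probable tag sequence under the model,\n",
"\n",
"$$\\hat{t}_{1:n} = \\underset{t_{1:n}}{\\arg\\max}\\; P(t_{1:n}, w_{1:n}),$$\n",
"\n",
"by dynamic programming in $O(n\\,|T|^2)$ time, where $|T| = 12$ is the number of tags. Words not seen during training are first replaced by the sentinel string 'nan', which pomegranate ignores during computation."
]
},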
{
"cell_type": "code",
"execution_count": 9,
"id": "extended-warrant",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"The Michigan Democrat's proposal, which *T*-4 is expected *-2 today, is described *-1 by government sources and lobbyists as significantly weaker than the Bush administration's plan * to cut utility emissions that *T*-3 lead to acid rain.\n"
]
},
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>word</th>\n",
" <th>actual_tag</th>\n",
" <th>predicted_tag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>The</td>\n",
" <td>DET</td>\n",
" <td>DET</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>Michigan</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>Democrat</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>'s</td>\n",
" <td>PRT</td>\n",
" <td>PRT</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>proposal</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>,</td>\n",
" <td>.</td>\n",
" <td>.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>which</td>\n",
" <td>DET</td>\n",
" <td>DET</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>*T*-4</td>\n",
" <td>X</td>\n",
" <td>X</td>\n",
" </tr>\n",
" <tr>\n",
" <th>8</th>\n",
" <td>is</td>\n",
" <td>VERB</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>9</th>\n",
" <td>expected</td>\n",
" <td>VERB</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>10</th>\n",
" <td>*-2</td>\n",
" <td>X</td>\n",
" <td>X</td>\n",
" </tr>\n",
" <tr>\n",
" <th>11</th>\n",
" <td>today</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>12</th>\n",
" <td>,</td>\n",
" <td>.</td>\n",
" <td>.</td>\n",
" </tr>\n",
" <tr>\n",
" <th>13</th>\n",
" <td>is</td>\n",
" <td>VERB</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>14</th>\n",
" <td>described</td>\n",
" <td>VERB</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>15</th>\n",
" <td>*-1</td>\n",
" <td>X</td>\n",
" <td>X</td>\n",
" </tr>\n",
" <tr>\n",
" <th>16</th>\n",
" <td>by</td>\n",
" <td>ADP</td>\n",
" <td>ADP</td>\n",
" </tr>\n",
" <tr>\n",
" <th>17</th>\n",
" <td>government</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>18</th>\n",
" <td>sources</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>19</th>\n",
" <td>and</td>\n",
" <td>CONJ</td>\n",
" <td>CONJ</td>\n",
" </tr>\n",
" <tr>\n",
" <th>20</th>\n",
" <td>lobbyists</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>21</th>\n",
" <td>as</td>\n",
" <td>ADV</td>\n",
" <td>ADP</td>\n",
" </tr>\n",
" <tr>\n",
" <th>22</th>\n",
" <td>significantly</td>\n",
" <td>ADV</td>\n",
" <td>ADV</td>\n",
" </tr>\n",
" <tr>\n",
" <th>23</th>\n",
" <td>weaker</td>\n",
" <td>ADJ</td>\n",
" <td>ADJ</td>\n",
" </tr>\n",
" <tr>\n",
" <th>24</th>\n",
" <td>than</td>\n",
" <td>ADP</td>\n",
" <td>ADP</td>\n",
" </tr>\n",
" <tr>\n",
" <th>25</th>\n",
" <td>the</td>\n",
" <td>DET</td>\n",
" <td>DET</td>\n",
" </tr>\n",
" <tr>\n",
" <th>26</th>\n",
" <td>Bush</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>27</th>\n",
" <td>administration</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>28</th>\n",
" <td>'s</td>\n",
" <td>PRT</td>\n",
" <td>PRT</td>\n",
" </tr>\n",
" <tr>\n",
" <th>29</th>\n",
" <td>plan</td>\n",
" <td>NOUN</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>30</th>\n",
" <td>*</td>\n",
" <td>X</td>\n",
" <td>X</td>\n",
" </tr>\n",
" <tr>\n",
" <th>31</th>\n",
" <td>to</td>\n",
" <td>PRT</td>\n",
" <td>PRT</td>\n",
" </tr>\n",
" <tr>\n",
" <th>32</th>\n",
" <td>cut</td>\n",
" <td>VERB</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>33</th>\n",
" <td>utility</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>34</th>\n",
" <td>emissions</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>35</th>\n",
" <td>that</td>\n",
" <td>DET</td>\n",
" <td>ADP</td>\n",
" </tr>\n",
" <tr>\n",
" <th>36</th>\n",
" <td>*T*-3</td>\n",
" <td>X</td>\n",
" <td>X</td>\n",
" </tr>\n",
" <tr>\n",
" <th>37</th>\n",
" <td>lead</td>\n",
" <td>VERB</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>38</th>\n",
" <td>to</td>\n",
" <td>PRT</td>\n",
" <td>PRT</td>\n",
" </tr>\n",
" <tr>\n",
" <th>39</th>\n",
" <td>acid</td>\n",
" <td>ADJ</td>\n",
" <td>ADJ</td>\n",
" </tr>\n",
" <tr>\n",
" <th>40</th>\n",
" <td>rain</td>\n",
" <td>NOUN</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>41</th>\n",
" <td>.</td>\n",
" <td>.</td>\n",
" <td>.</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" word actual_tag predicted_tag\n",
"0 The DET DET\n",
"1 Michigan NOUN NOUN\n",
"2 Democrat NOUN NOUN\n",
"3 's PRT PRT\n",
"4 proposal NOUN NOUN\n",
"5 , . .\n",
"6 which DET DET\n",
"7 *T*-4 X X\n",
"8 is VERB VERB\n",
"9 expected VERB VERB\n",
"10 *-2 X X\n",
"11 today NOUN NOUN\n",
"12 , . .\n",
"13 is VERB VERB\n",
"14 described VERB VERB\n",
"15 *-1 X X\n",
"16 by ADP ADP\n",
"17 government NOUN NOUN\n",
"18 sources NOUN NOUN\n",
"19 and CONJ CONJ\n",
"20 lobbyists NOUN NOUN\n",
"21 as ADV ADP\n",
"22 significantly ADV ADV\n",
"23 weaker ADJ ADJ\n",
"24 than ADP ADP\n",
"25 the DET DET\n",
"26 Bush NOUN NOUN\n",
"27 administration NOUN NOUN\n",
"28 's PRT PRT\n",
"29 plan NOUN VERB\n",
"30 * X X\n",
"31 to PRT PRT\n",
"32 cut VERB VERB\n",
"33 utility NOUN NOUN\n",
"34 emissions NOUN NOUN\n",
"35 that DET ADP\n",
"36 *T*-3 X X\n",
"37 lead VERB VERB\n",
"38 to PRT PRT\n",
"39 acid ADJ ADJ\n",
"40 rain NOUN NOUN\n",
"41 . . ."
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from nltk.tokenize.treebank import TreebankWordDetokenizer\n",
"\n",
"def replace_unknown(words):\n",
" \"\"\"Return a copy of the input list of words where each unknown word is replaced\n",
" by the literal string value 'nan'. Pomegranate will ignore these values\n",
" during computation.\n",
" \"\"\"\n",
" return [word if word in unique_words_in_training_set else 'nan' for word in words]\n",
"\n",
"def predict_on_test_set(sent_idx):\n",
" words = [word for word, tag in test_set[sent_idx]]\n",
" words = replace_unknown(words) \n",
" tags = [tag for word, tag in test_set[sent_idx]]\n",
" logp, path = model.viterbi(words)\n",
" predicted_tags = [state.name for _, state in path[1:]]\n",
" print(TreebankWordDetokenizer().detokenize(words))\n",
" return pd.DataFrame(zip(words, tags, predicted_tags), columns=['word','actual_tag','predicted_tag'])\n",
"\n",
"predict_on_test_set(12)"
]
},
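{
"cell_type": "markdown",
"id": "test-set-accuracy-text",
"metadata": {},
"source": [
"To quantify performance beyond a single sentence, here is a minimal sketch of overall test-set accuracy, using only the model and helper functions defined above. It assumes that `model.viterbi` returns `None` for the path of a sentence with zero probability under the model; such sentences are skipped:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "test-set-accuracy-code",
"metadata": {},
"outputs": [],
"source": [
"def accuracy_on_test_set():\n",
"    # fraction of test-set words whose Viterbi tag matches the corpus tag\n",
"    correct, total = 0, 0\n",
"    for sent in test_set:\n",
"        words = replace_unknown([word for word, tag in sent])\n",
"        actual_tags = [tag for word, tag in sent]\n",
"        logp, path = model.viterbi(words)\n",
"        if path is None:  # assumed: no tag sequence has nonzero probability, so skip\n",
"            continue\n",
"        predicted_tags = [state.name for _, state in path[1:]]\n",
"        correct += sum(p == a for p, a in zip(predicted_tags, actual_tags))\n",
"        total += len(actual_tags)\n",
"    return correct/total\n",
"\n",
"print('fraction of test-set words tagged correctly: {:.3f}'.format(accuracy_on_test_set()))"
]
},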
{
"cell_type": "code",
"execution_count": 10,
"id": "aerial-network",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>word</th>\n",
" <th>predicted_tag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>Jack</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>and</td>\n",
" <td>CONJ</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>I</td>\n",
" <td>PRON</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>went</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>up</td>\n",
" <td>ADP</td>\n",
" </tr>\n",
" <tr>\n",
" <th>5</th>\n",
" <td>the</td>\n",
" <td>DET</td>\n",
" </tr>\n",
" <tr>\n",
" <th>6</th>\n",
" <td>hill</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>7</th>\n",
" <td>.</td>\n",
" <td>.</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" word predicted_tag\n",
"0 Jack NOUN\n",
"1 and CONJ\n",
"2 I PRON\n",
"3 went VERB\n",
"4 up ADP\n",
"5 the DET\n",
"6 hill NOUN\n",
"7 . ."
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from nltk.tokenize import wordpunct_tokenize\n",
"\n",
"def predict_on_custom_sentence(sent):\n",
" words = wordpunct_tokenize(sent)\n",
" words = replace_unknown(words)\n",
" logp, path = model.viterbi(words)\n",
" predicted_tags = [state.name for _, state in path[1:]]\n",
" return pd.DataFrame(zip(words, predicted_tags), columns=['word', 'predicted_tag'])\n",
"\n",
"# The tagset consists of the following 12 coarse tags:\n",
"\n",
"# VERB - verbs (all tenses and modes)\n",
"# NOUN - nouns (common and proper)\n",
"# PRON - pronouns\n",
"# ADJ - adjectives\n",
"# ADV - adverbs\n",
"# ADP - adpositions (prepositions and postpositions)\n",
"# CONJ - conjunctions\n",
"# DET - determiners\n",
"# NUM - cardinal numbers\n",
"# PRT - particles or other function words\n",
"# X - other: foreign words, typos, abbreviations\n",
"# . - punctuation\n",
"\n",
"predict_on_custom_sentence('Jack and I went up the hill.')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"id": "developed-drama",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>word</th>\n",
" <th>predicted_tag</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>nan</td>\n",
" <td>NOUN</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>is</td>\n",
" <td>VERB</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>nan</td>\n",
" <td>X</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>.</td>\n",
" <td>.</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" word predicted_tag\n",
"0 nan NOUN\n",
"1 is VERB\n",
"2 nan X\n",
"3 . ."
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# https://github.com/jmschrei/pomegranate/issues/347\n",
"predict_on_custom_sentence('NLP is fun.')"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.8.8"
}
},
"nbformat": 4,
"nbformat_minor": 5
}