Skip to content

Instantly share code, notes, and snippets.

@tteofili
Last active June 1, 2022 08:22
Show Gist options
  • Save tteofili/b4c81a3de6aef40e8dfa27eaf22f116d to your computer and use it in GitHub Desktop.
Save tteofili/b4c81a3de6aef40e8dfa27eaf22f116d to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import torch\n",
"\n",
"from certa.explain import CertaExplainer\n",
"from certa.models.bert import EMTERModel\n",
"from certa.models.ditto.ditto import DittoModel\n",
"from certa.models.ditto.summarize import Summarizer\n",
"from certa.utils import merge_sources"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"# Root folder of the Beer dataset used throughout this notebook.\n",
"datadir = 'datasets/Beer'\n",
"\n",
"# The two source tables whose records are paired up.\n",
"lsource = pd.read_csv(f'{datadir}/tableA.csv')\n",
"rsource = pd.read_csv(f'{datadir}/tableB.csv')\n",
"\n",
"# Train / validation / test CSVs shipped alongside the source tables.\n",
"gt = pd.read_csv(f'{datadir}/train.csv')\n",
"valid = pd.read_csv(f'{datadir}/valid.csv')\n",
"test = pd.read_csv(f'{datadir}/test.csv')"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"# Load the saved Ditto checkpoint for the Beer dataset.\n",
"# map_location='cpu' remaps tensors that were saved from a GPU so the load also\n",
"# works on CPU-only machines; the model the weights go into is created with\n",
"# device='cpu'.\n",
"pt_model_dict = torch.load(\n",
"    'ditto/Structured/Beer/model.pt',\n",
"    map_location='cpu',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.weight']\n",
"- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"# Build the Ditto model with the 'distilbert' language model on the CPU; the\n",
"# checkpoint weights are loaded into it in the following cell.\n",
"ditto_model = DittoModel(\n",
"    lm='distilbert',\n",
"    device='cpu',\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "<All keys matched successfully>"
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Copy the checkpoint's 'model' entry into the Ditto model.\n",
"# The call stays last so that its key-matching report remains the cell output.\n",
"ditto_weights = pt_model_dict['model']\n",
"ditto_model.load_state_dict(ditto_weights)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"# Ditto summarizer built over both source tables for the 'distilbert' language\n",
"# model; it is handed to EMTERModel further down.\n",
"summarizer = Summarizer(lsource, rsource, 'distilbert')"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"# CERTA explainer constructed from the left and right source tables.\n",
"certa_explainer = CertaExplainer(\n",
"    lsource,\n",
"    rsource,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.weight']\n",
"- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
"- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n"
]
}
],
"source": [
"# EMTER wrapper configured for Ditto; it receives the summarizer built above.\n",
"# NOTE(review): dk='general' presumably selects Ditto's general domain-knowledge\n",
"# injection -- confirm against certa.\n",
"model = EMTERModel(ditto=True, summarizer=summarizer, dk='general')"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Swap the wrapper's underlying network for the Ditto model whose checkpoint\n",
"# weights were loaded above.\n",
"model.model = ditto_model"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"def predict_fn(x):\n",
"    \"\"\"Run the Ditto-backed EMTER model on ``x`` and return its predictions.\n",
"\n",
"    This is the prediction callback handed to ``certa_explainer.explain``.\n",
"    ``len=64`` is forwarded unchanged to ``model.predict``\n",
"    (presumably the maximum sequence length -- confirm against certa).\n",
"    \"\"\"\n",
"    predictions = model.predict(x, len=64)\n",
"    return predictions"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [],
"source": [
"# Build the test frame: merge_sources combines the pairs in `test` with the\n",
"# 'ltable_' / 'rtable_' prefixed attributes of the two source tables.\n",
"# NOTE(review): the last two list arguments are passed positionally; their exact\n",
"# meaning is defined by certa.utils.merge_sources -- confirm there.\n",
"test_df = merge_sources(\n",
"    test,\n",
"    'ltable_',\n",
"    'rtable_',\n",
"    lsource,\n",
"    rsource,\n",
"    ['label'],\n",
"    [],\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": "label 1\nltable_ABV 7.40 %\nltable_Beer_Name Honey Basil Amber\nltable_Brew_Factory_Name Rude Hippo Brewing Company\nltable_Style American Amber / Red Ale\nName: 2, dtype: object"
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pick one pair from the test frame to explain.\n",
"idx = 2\n",
"rand_row = test_df.iloc[idx]\n",
"\n",
"# NOTE(review): 'ltable_id' / 'rtable_id' presumably reference the 'id' column\n",
"# of the source tables. Look the records up by that id rather than by row\n",
"# position, so the right record is returned even when ids and positions differ\n",
"# (an unknown id now fails loudly instead of silently picking another row).\n",
"l_id = int(rand_row['ltable_id'])\n",
"l_tuple = lsource.loc[lsource['id'] == l_id].iloc[0]\n",
"r_id = int(rand_row['rtable_id'])\n",
"r_tuple = rsource.loc[rsource['id'] == r_id].iloc[0]\n",
"\n",
"# Show the first fields of the selected pair.\n",
"rand_row.head()"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": "id 3917\nBeer_Name Honey Basil Amber\nBrew_Factory_Name Rude Hippo Brewing Company\nStyle American Amber / Red Ale\nABV 7.40 %\nName: 3917, dtype: object"
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"l_tuple"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"data": {
"text/plain": "id 2224\nBeer_Name Rude Hippo Honey Basil Amber\nBrew_Factory_Name 18th Street Brewery\nStyle Amber Ale\nABV 7.40 %\nName: 2224, dtype: object"
},
"execution_count": 14,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"r_tuple"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {
"pycharm": {
"name": "#%%\n"
}
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n",
" warnings.warn(Warnings.W108)\n",
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n",
" warnings.warn(Warnings.W108)\n",
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n",
" warnings.warn(Warnings.W108)\n",
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n",
" warnings.warn(Warnings.W108)\n",
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n",
" warnings.warn(Warnings.W108)\n"
]
}
],
"source": [
"# Explain the model's decision on the (l_tuple, r_tuple) pair. explain() returns\n",
"# five results, unpacked below; num_triangles=100 is passed through as-is.\n",
"(\n",
"    saliency_df,\n",
"    cf_summary,\n",
"    counterfactual_examples,\n",
"    triangles,\n",
"    lattices,\n",
") = certa_explainer.explain(\n",
"    l_tuple,\n",
"    r_tuple,\n",
"    predict_fn,\n",
"    num_triangles=100,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": "{'ltable_Beer_Name': {0: 0.5113636363636364},\n 'ltable_Brew_Factory_Name': {0: 0.375},\n 'ltable_Style': {0: 0.3181818181818182},\n 'ltable_ABV': {0: 0.2954545454545454},\n 'rtable_Beer_Name': {0: 0.5113636363636364},\n 'rtable_Brew_Factory_Name': {0: 0.28409090909090906},\n 'rtable_Style': {0: 0.28409090909090906},\n 'rtable_ABV': {0: 0.28409090909090906}}"
},
"execution_count": 16,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"saliency_df.to_dict()"
]
},
{
"cell_type": "code",
"execution_count": 17,
"outputs": [
{
"data": {
"text/plain": "ltable_Beer_Name 0.5\nltable_Brew_Factory_Name 0.5\nltable_Style 0.0\nltable_ABV 0.0\nrtable_Beer_Name 0.5\nrtable_Brew_Factory_Name 0.0\nrtable_Style 0.0\nrtable_ABV 0.0\nltable_Beer_Name/ltable_Brew_Factory_Name 0.5\nltable_Beer_Name/ltable_Style 0.5\nltable_Beer_Name/ltable_ABV 0.5\nltable_Brew_Factory_Name/ltable_Style 0.2\nltable_Brew_Factory_Name/ltable_ABV 0.0\nltable_Style/ltable_ABV 0.0\nrtable_Beer_Name/rtable_Brew_Factory_Name 0.5\nrtable_Beer_Name/rtable_Style 0.5\nrtable_Beer_Name/rtable_ABV 0.5\nrtable_Brew_Factory_Name/rtable_Style 0.0\nrtable_Brew_Factory_Name/rtable_ABV 0.0\nrtable_Style/rtable_ABV 0.0\nltable_Beer_Name/ltable_Brew_Factory_Name/ltable_Style 0.5\nltable_Beer_Name/ltable_Brew_Factory_Name/ltable_ABV 0.5\nltable_Beer_Name/ltable_Style/ltable_ABV 0.5\nltable_Brew_Factory_Name/ltable_Style/ltable_ABV 0.1\nrtable_Beer_Name/rtable_Brew_Factory_Name/rtable_Style 0.5\nrtable_Beer_Name/rtable_Brew_Factory_Name/rtable_ABV 0.5\nrtable_Beer_Name/rtable_Style/rtable_ABV 0.5\nrtable_Brew_Factory_Name/rtable_Style/rtable_ABV 0.0\ndtype: float64"
},
"execution_count": 17,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cf_summary"
],
"metadata": {
"collapsed": false,
"pycharm": {
"name": "#%%\n"
}
}
},
{
"cell_type": "code",
"execution_count": 18,
"metadata": {
"pycharm": {
"name": "#%%\n"
},
"scrolled": true
},
"outputs": [
{
"data": {
"text/plain": " label ltable_Beer_Name \\\n0 T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å 3... \n8 T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å 3... \n10 Honey Basil Amber \n12 Honey Basil Amber \n11 Honey Basil Amber \n14 Honey Basil Amber \n13 Honey Basil Amber \n\n ltable_Brew_Factory_Name ltable_Style ltable_ABV \\\n0 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n8 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n10 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n12 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n11 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n14 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n13 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n\n rtable_Beer_Name rtable_Brew_Factory_Name \\\n0 Rude Hippo Honey Basil Amber 18th Street Brewery \n8 Rude Hippo Honey Basil Amber 18th Street Brewery \n10 te de Jade Brigantine Ambrà © e 18th Street Brewery \n12 _ i_lavar Barley Wine 18th Street Brewery \n11 Cà ´ te de Jade Brigantine Ambrà © 18th Street Brewery \n14 Cà ´ te de Jade Brigantine Ambrà © e 18th Street Brewery \n13 Cà ´ te de Jade Brigantine Ambrà 18th Street Brewery \n\n rtable_Style rtable_ABV match_score nomatch_score \\\n0 Amber Ale 7.40 % 0.19255195558071136 0.8074480444192886 \n8 Amber Ale 7.40 % 0.17068031430244446 0.8293196856975555 \n10 Amber Ale 7.40 % 0.10978817939758301 0.890211820602417 \n12 Amber Ale 7.40 % 0.09088551998138428 0.9091144800186157 \n11 Amber Ale 7.40 % 0.10560467094182968 0.8943953290581703 \n14 Amber Ale 7.40 % 0.08469007909297943 0.9153099209070206 \n13 Amber Ale 7.40 % 0.10777609050273895 0.892223909497261 \n\n alteredAttributes droppedValues \\\n0 ('ltable_Beer_Name',) ['Honey Basil Amber'] \n8 ('ltable_Beer_Name',) ['Honey Basil Amber'] \n10 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] \n12 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] \n11 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] 
\n14 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] \n13 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] \n\n copiedValues triangle \\\n0 ['T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å... 0@3917 1@2224 0@72990 \n8 ['T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å... 0@3917 1@2224 0@72978 \n10 ['te de Jade Brigantine Ambrà © e'] 1@2224 0@3917 1@25064 \n12 ['_ i_lavar Barley Wine'] 1@2224 0@3917 1@42959 \n11 ['Cà ´ te de Jade Brigantine Ambrà ©'] 1@2224 0@3917 1@25077 \n14 ['Cà ´ te de Jade Brigantine Ambrà © e'] 1@2224 0@3917 1@25100 \n13 ['Cà ´ te de Jade Brigantine AmbrÃ'] 1@2224 0@3917 1@25075 \n\n attr_count \n0 2 \n8 2 \n10 2 \n12 2 \n11 2 \n14 2 \n13 2 ",
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>label</th>\n <th>ltable_Beer_Name</th>\n <th>ltable_Brew_Factory_Name</th>\n <th>ltable_Style</th>\n <th>ltable_ABV</th>\n <th>rtable_Beer_Name</th>\n <th>rtable_Brew_Factory_Name</th>\n <th>rtable_Style</th>\n <th>rtable_ABV</th>\n <th>match_score</th>\n <th>nomatch_score</th>\n <th>alteredAttributes</th>\n <th>droppedValues</th>\n <th>copiedValues</th>\n <th>triangle</th>\n <th>attr_count</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td></td>\n <td>T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å 3...</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>Rude Hippo Honey Basil Amber</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.19255195558071136</td>\n <td>0.8074480444192886</td>\n <td>('ltable_Beer_Name',)</td>\n <td>['Honey Basil Amber']</td>\n <td>['T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å...</td>\n <td>0@3917 1@2224 0@72990</td>\n <td>2</td>\n </tr>\n <tr>\n <th>8</th>\n <td></td>\n <td>T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å 3...</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>Rude Hippo Honey Basil Amber</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.17068031430244446</td>\n <td>0.8293196856975555</td>\n <td>('ltable_Beer_Name',)</td>\n <td>['Honey Basil Amber']</td>\n <td>['T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å...</td>\n <td>0@3917 1@2224 0@72978</td>\n <td>2</td>\n </tr>\n <tr>\n <th>10</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>te de 
Jade Brigantine Ambrà © e</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.10978817939758301</td>\n <td>0.890211820602417</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['te de Jade Brigantine Ambrà © e']</td>\n <td>1@2224 0@3917 1@25064</td>\n <td>2</td>\n </tr>\n <tr>\n <th>12</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>_ i_lavar Barley Wine</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.09088551998138428</td>\n <td>0.9091144800186157</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['_ i_lavar Barley Wine']</td>\n <td>1@2224 0@3917 1@42959</td>\n <td>2</td>\n </tr>\n <tr>\n <th>11</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>Cà ´ te de Jade Brigantine Ambrà ©</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.10560467094182968</td>\n <td>0.8943953290581703</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['Cà ´ te de Jade Brigantine Ambrà ©']</td>\n <td>1@2224 0@3917 1@25077</td>\n <td>2</td>\n </tr>\n <tr>\n <th>14</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>Cà ´ te de Jade Brigantine Ambrà © e</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.08469007909297943</td>\n <td>0.9153099209070206</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['Cà ´ te de Jade Brigantine Ambrà © e']</td>\n <td>1@2224 0@3917 1@25100</td>\n <td>2</td>\n </tr>\n <tr>\n <th>13</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 
%</td>\n <td>Cà ´ te de Jade Brigantine AmbrÃ</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.10777609050273895</td>\n <td>0.892223909497261</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['Cà ´ te de Jade Brigantine AmbrÃ']</td>\n <td>1@2224 0@3917 1@25075</td>\n <td>2</td>\n </tr>\n </tbody>\n</table>\n</div>"
},
"execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"counterfactual_examples"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.8"
}
},
"nbformat": 4,
"nbformat_minor": 4
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment