Last active
June 1, 2022 08:22
-
-
Save tteofili/b4c81a3de6aef40e8dfa27eaf22f116d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import pandas as pd\n", | |
"from certa.explain import CertaExplainer\n", | |
"from certa.utils import merge_sources\n", | |
"from certa.models.ditto.ditto import DittoModel" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"# Read the two source tables and the train/valid/test CSVs from one directory.\n", | |
"datadir = 'datasets/Beer'\n", | |
"lsource = pd.read_csv(f'{datadir}/tableA.csv')\n", | |
"rsource = pd.read_csv(f'{datadir}/tableB.csv')\n", | |
"gt = pd.read_csv(f'{datadir}/train.csv')\n", | |
"valid = pd.read_csv(f'{datadir}/valid.csv')\n", | |
"test = pd.read_csv(f'{datadir}/test.csv')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"import torch\n", | |
"\n", | |
"# The Ditto model in this notebook is created with device='cpu', so map the\n", | |
"# checkpoint tensors to CPU while loading; without map_location a checkpoint\n", | |
"# saved from a GPU cannot be deserialized on a CPU-only machine.\n", | |
"# NOTE(review): torch.load unpickles the file -- only load checkpoints you trust.\n", | |
"pt_model_dict = torch.load('ditto/Structured/Beer/model.pt', map_location='cpu')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.weight']\n", | |
"- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
"- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" | |
] | |
} | |
], | |
"source": [ | |
"ditto_model = DittoModel(lm='distilbert', device='cpu')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 5, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "<All keys matched successfully>" | |
}, | |
"execution_count": 5, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"ditto_model.load_state_dict(pt_model_dict['model'])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 6, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"from certa.models.ditto.summarize import Summarizer\n", | |
"\n", | |
"summarizer = Summarizer(lsource, rsource, 'distilbert')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 7, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"certa_explainer = CertaExplainer(lsource, rsource)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 8, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"Some weights of the model checkpoint at distilbert-base-uncased were not used when initializing DistilBertModel: ['vocab_layer_norm.bias', 'vocab_transform.bias', 'vocab_layer_norm.weight', 'vocab_projector.weight', 'vocab_projector.bias', 'vocab_transform.weight']\n", | |
"- This IS expected if you are initializing DistilBertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n", | |
"- This IS NOT expected if you are initializing DistilBertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n" | |
] | |
} | |
], | |
"source": [ | |
"from certa.models.bert import EMTERModel\n", | |
"\n", | |
"model = EMTERModel(ditto=True, summarizer=summarizer, dk='general')" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 9, | |
"metadata": {}, | |
"outputs": [], | |
"source": [ | |
"model.model = ditto_model" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 10, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"def predict_fn(x):\n", | |
"    \"\"\"Prediction callback passed to the explainer; forwards `x` to `model.predict` with `len=64`.\"\"\"\n", | |
"    predictions = model.predict(x, len=64)\n", | |
"    return predictions" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 11, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [], | |
"source": [ | |
"test_df = merge_sources(test, 'ltable_', 'rtable_', lsource, rsource, ['label'], [])" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 12, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "label 1\nltable_ABV 7.40 %\nltable_Beer_Name Honey Basil Amber\nltable_Brew_Factory_Name Rude Hippo Brewing Company\nltable_Style American Amber / Red Ale\nName: 2, dtype: object" | |
}, | |
"execution_count": 12, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"# Pick one labelled pair from the merged test set and fetch both of its records.\n", | |
"idx = 2\n", | |
"rand_row = test_df.iloc[idx]\n", | |
"l_id = int(rand_row['ltable_id'])\n", | |
"r_id = int(rand_row['rtable_id'])\n", | |
"# Select by the value of the `id` column rather than by row position:\n", | |
"# `.iloc[l_id]` only returns the right record while every id equals its row\n", | |
"# offset. A missing id now raises IndexError instead of returning another row.\n", | |
"l_tuple = lsource.loc[lsource['id'] == l_id].iloc[0]\n", | |
"r_tuple = rsource.loc[rsource['id'] == r_id].iloc[0]\n", | |
"rand_row.head()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 13, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "id 3917\nBeer_Name Honey Basil Amber\nBrew_Factory_Name Rude Hippo Brewing Company\nStyle American Amber / Red Ale\nABV 7.40 %\nName: 3917, dtype: object" | |
}, | |
"execution_count": 13, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"l_tuple" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 14, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "id 2224\nBeer_Name Rude Hippo Honey Basil Amber\nBrew_Factory_Name 18th Street Brewery\nStyle Amber Ale\nABV 7.40 %\nName: 2224, dtype: object" | |
}, | |
"execution_count": 14, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"r_tuple" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 15, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
}, | |
"outputs": [ | |
{ | |
"name": "stderr", | |
"output_type": "stream", | |
"text": [ | |
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n", | |
" warnings.warn(Warnings.W108)\n", | |
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n", | |
" warnings.warn(Warnings.W108)\n", | |
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n", | |
" warnings.warn(Warnings.W108)\n", | |
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n", | |
" warnings.warn(Warnings.W108)\n", | |
"/home/tteofili/.local/lib/python3.6/site-packages/spacy/pipeline/lemmatizer.py:211: UserWarning: [W108] The rule-based lemmatizer did not find POS annotation for one or more tokens. Check that your pipeline includes components that assign token.pos, typically 'tagger'+'attribute_ruler' or 'morphologizer'.\n", | |
" warnings.warn(Warnings.W108)\n" | |
] | |
} | |
], | |
"source": [ | |
"saliency_df, cf_summary, counterfactual_examples, triangles, lattices = certa_explainer.explain(l_tuple, r_tuple, predict_fn, num_triangles=100)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 16, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "{'ltable_Beer_Name': {0: 0.5113636363636364},\n 'ltable_Brew_Factory_Name': {0: 0.375},\n 'ltable_Style': {0: 0.3181818181818182},\n 'ltable_ABV': {0: 0.2954545454545454},\n 'rtable_Beer_Name': {0: 0.5113636363636364},\n 'rtable_Brew_Factory_Name': {0: 0.28409090909090906},\n 'rtable_Style': {0: 0.28409090909090906},\n 'rtable_ABV': {0: 0.28409090909090906}}" | |
}, | |
"execution_count": 16, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"saliency_df.to_dict()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 17, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": "ltable_Beer_Name 0.5\nltable_Brew_Factory_Name 0.5\nltable_Style 0.0\nltable_ABV 0.0\nrtable_Beer_Name 0.5\nrtable_Brew_Factory_Name 0.0\nrtable_Style 0.0\nrtable_ABV 0.0\nltable_Beer_Name/ltable_Brew_Factory_Name 0.5\nltable_Beer_Name/ltable_Style 0.5\nltable_Beer_Name/ltable_ABV 0.5\nltable_Brew_Factory_Name/ltable_Style 0.2\nltable_Brew_Factory_Name/ltable_ABV 0.0\nltable_Style/ltable_ABV 0.0\nrtable_Beer_Name/rtable_Brew_Factory_Name 0.5\nrtable_Beer_Name/rtable_Style 0.5\nrtable_Beer_Name/rtable_ABV 0.5\nrtable_Brew_Factory_Name/rtable_Style 0.0\nrtable_Brew_Factory_Name/rtable_ABV 0.0\nrtable_Style/rtable_ABV 0.0\nltable_Beer_Name/ltable_Brew_Factory_Name/ltable_Style 0.5\nltable_Beer_Name/ltable_Brew_Factory_Name/ltable_ABV 0.5\nltable_Beer_Name/ltable_Style/ltable_ABV 0.5\nltable_Brew_Factory_Name/ltable_Style/ltable_ABV 0.1\nrtable_Beer_Name/rtable_Brew_Factory_Name/rtable_Style 0.5\nrtable_Beer_Name/rtable_Brew_Factory_Name/rtable_ABV 0.5\nrtable_Beer_Name/rtable_Style/rtable_ABV 0.5\nrtable_Brew_Factory_Name/rtable_Style/rtable_ABV 0.0\ndtype: float64" | |
}, | |
"execution_count": 17, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"cf_summary" | |
], | |
"metadata": { | |
"collapsed": false, | |
"pycharm": { | |
"name": "#%%\n" | |
} | |
} | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 18, | |
"metadata": { | |
"pycharm": { | |
"name": "#%%\n" | |
}, | |
"scrolled": true | |
}, | |
"outputs": [ | |
{ | |
"data": { | |
"text/plain": " label ltable_Beer_Name \\\n0 T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å 3... \n8 T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å 3... \n10 Honey Basil Amber \n12 Honey Basil Amber \n11 Honey Basil Amber \n14 Honey Basil Amber \n13 Honey Basil Amber \n\n ltable_Brew_Factory_Name ltable_Style ltable_ABV \\\n0 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n8 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n10 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n12 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n11 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n14 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n13 Rude Hippo Brewing Company American Amber / Red Ale 7.40 % \n\n rtable_Beer_Name rtable_Brew_Factory_Name \\\n0 Rude Hippo Honey Basil Amber 18th Street Brewery \n8 Rude Hippo Honey Basil Amber 18th Street Brewery \n10 te de Jade Brigantine Ambrà © e 18th Street Brewery \n12 _ i_lavar Barley Wine 18th Street Brewery \n11 Cà ´ te de Jade Brigantine Ambrà © 18th Street Brewery \n14 Cà ´ te de Jade Brigantine Ambrà © e 18th Street Brewery \n13 Cà ´ te de Jade Brigantine Ambrà 18th Street Brewery \n\n rtable_Style rtable_ABV match_score nomatch_score \\\n0 Amber Ale 7.40 % 0.19255195558071136 0.8074480444192886 \n8 Amber Ale 7.40 % 0.17068031430244446 0.8293196856975555 \n10 Amber Ale 7.40 % 0.10978817939758301 0.890211820602417 \n12 Amber Ale 7.40 % 0.09088551998138428 0.9091144800186157 \n11 Amber Ale 7.40 % 0.10560467094182968 0.8943953290581703 \n14 Amber Ale 7.40 % 0.08469007909297943 0.9153099209070206 \n13 Amber Ale 7.40 % 0.10777609050273895 0.892223909497261 \n\n alteredAttributes droppedValues \\\n0 ('ltable_Beer_Name',) ['Honey Basil Amber'] \n8 ('ltable_Beer_Name',) ['Honey Basil Amber'] \n10 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] \n12 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] \n11 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] 
\n14 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] \n13 ('rtable_Beer_Name',) ['Rude Hippo Honey Basil Amber'] \n\n copiedValues triangle \\\n0 ['T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å... 0@3917 1@2224 0@72990 \n8 ['T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å... 0@3917 1@2224 0@72978 \n10 ['te de Jade Brigantine Ambrà © e'] 1@2224 0@3917 1@25064 \n12 ['_ i_lavar Barley Wine'] 1@2224 0@3917 1@42959 \n11 ['Cà ´ te de Jade Brigantine Ambrà ©'] 1@2224 0@3917 1@25077 \n14 ['Cà ´ te de Jade Brigantine Ambrà © e'] 1@2224 0@3917 1@25100 \n13 ['Cà ´ te de Jade Brigantine AmbrÃ'] 1@2224 0@3917 1@25075 \n\n attr_count \n0 2 \n8 2 \n10 2 \n12 2 \n11 2 \n14 2 \n13 2 ", | |
"text/html": "<div>\n<style scoped>\n .dataframe tbody tr th:only-of-type {\n vertical-align: middle;\n }\n\n .dataframe tbody tr th {\n vertical-align: top;\n }\n\n .dataframe thead th {\n text-align: right;\n }\n</style>\n<table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: right;\">\n <th></th>\n <th>label</th>\n <th>ltable_Beer_Name</th>\n <th>ltable_Brew_Factory_Name</th>\n <th>ltable_Style</th>\n <th>ltable_ABV</th>\n <th>rtable_Beer_Name</th>\n <th>rtable_Brew_Factory_Name</th>\n <th>rtable_Style</th>\n <th>rtable_ABV</th>\n <th>match_score</th>\n <th>nomatch_score</th>\n <th>alteredAttributes</th>\n <th>droppedValues</th>\n <th>copiedValues</th>\n <th>triangle</th>\n <th>attr_count</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <th>0</th>\n <td></td>\n <td>T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å 3...</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>Rude Hippo Honey Basil Amber</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.19255195558071136</td>\n <td>0.8074480444192886</td>\n <td>('ltable_Beer_Name',)</td>\n <td>['Honey Basil Amber']</td>\n <td>['T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å...</td>\n <td>0@3917 1@2224 0@72990</td>\n <td>2</td>\n </tr>\n <tr>\n <th>8</th>\n <td></td>\n <td>T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å 3...</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>Rude Hippo Honey Basil Amber</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.17068031430244446</td>\n <td>0.8293196856975555</td>\n <td>('ltable_Beer_Name',)</td>\n <td>['Honey Basil Amber']</td>\n <td>['T_ebonickà © Geronimo Polotmavà 1/2 Rà 1/2 Å...</td>\n <td>0@3917 1@2224 0@72978</td>\n <td>2</td>\n </tr>\n <tr>\n <th>10</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>te de 
Jade Brigantine Ambrà © e</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.10978817939758301</td>\n <td>0.890211820602417</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['te de Jade Brigantine Ambrà © e']</td>\n <td>1@2224 0@3917 1@25064</td>\n <td>2</td>\n </tr>\n <tr>\n <th>12</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>_ i_lavar Barley Wine</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.09088551998138428</td>\n <td>0.9091144800186157</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['_ i_lavar Barley Wine']</td>\n <td>1@2224 0@3917 1@42959</td>\n <td>2</td>\n </tr>\n <tr>\n <th>11</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>Cà ´ te de Jade Brigantine Ambrà ©</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.10560467094182968</td>\n <td>0.8943953290581703</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['Cà ´ te de Jade Brigantine Ambrà ©']</td>\n <td>1@2224 0@3917 1@25077</td>\n <td>2</td>\n </tr>\n <tr>\n <th>14</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 %</td>\n <td>Cà ´ te de Jade Brigantine Ambrà © e</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.08469007909297943</td>\n <td>0.9153099209070206</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['Cà ´ te de Jade Brigantine Ambrà © e']</td>\n <td>1@2224 0@3917 1@25100</td>\n <td>2</td>\n </tr>\n <tr>\n <th>13</th>\n <td></td>\n <td>Honey Basil Amber</td>\n <td>Rude Hippo Brewing Company</td>\n <td>American Amber / Red Ale</td>\n <td>7.40 
%</td>\n <td>Cà ´ te de Jade Brigantine AmbrÃ</td>\n <td>18th Street Brewery</td>\n <td>Amber Ale</td>\n <td>7.40 %</td>\n <td>0.10777609050273895</td>\n <td>0.892223909497261</td>\n <td>('rtable_Beer_Name',)</td>\n <td>['Rude Hippo Honey Basil Amber']</td>\n <td>['Cà ´ te de Jade Brigantine AmbrÃ']</td>\n <td>1@2224 0@3917 1@25075</td>\n <td>2</td>\n </tr>\n </tbody>\n</table>\n</div>" | |
}, | |
"execution_count": 18, | |
"metadata": {}, | |
"output_type": "execute_result" | |
} | |
], | |
"source": [ | |
"counterfactual_examples" | |
] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.6.8" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 4 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment