--cb_dro demo for vowpal wabbit using covertype. To see the lift, note the "since last acc" column with and without --cb_dro.
{ | |
"cells": [ | |
{ | |
"cell_type": "code", | |
"execution_count": 1, | |
"metadata": { | |
"code_folding": [] | |
}, | |
"outputs": [], | |
"source": [ | |
"class EasyAcc:\n", | |
" \"\"\"Tiny streaming accumulator: a count of observations and a running sum.\n", | |
"\n", | |
" Used below to report both cumulative and since-last-checkpoint averages.\n", | |
" \"\"\"\n", | |
" def __init__(self):\n", | |
" self.n = 0\n", | |
" self.sum = 0\n", | |
"\n", | |
" def __iadd__(self, other):\n", | |
" # acc += x : record one observation with value x\n", | |
" self.n += 1\n", | |
" self.sum += other\n", | |
" return self\n", | |
" \n", | |
" def __isub__(self, other):\n", | |
" # acc -= x : record one observation with value -x (still increments the count)\n", | |
" self.n += 1\n", | |
" self.sum -= other\n", | |
" return self\n", | |
"\n", | |
" def mean(self):\n", | |
" # average of observations so far; max(self.n, 1) avoids ZeroDivisionError when empty\n", | |
" return self.sum / max(self.n, 1)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 2, | |
"metadata": { | |
"code_folding": [ | |
10, | |
30, | |
72 | |
] | |
}, | |
"outputs": [], | |
"source": [ | |
"def cb_explore_adf_covertype_demo(usedro=False):\n", | |
" \"\"\"Covertype contextual-bandit demo: online pretraining, then off-policy learning.\n", | |
"\n", | |
" Phase 1 trains a --cb_explore_adf logging policy online. Phase 2 freezes that\n", | |
" policy (it is only used to sample actions) and trains a separate --cb_adf\n", | |
" learner purely off-policy from the logged (action, cost, prob) triples,\n", | |
" optionally with --cb_dro when usedro=True.\n", | |
" \"\"\"\n", | |
" from collections import Counter\n", | |
" from sklearn.datasets import fetch_covtype\n", | |
" from sklearn.decomposition import PCA\n", | |
" from vowpalwabbit import pyvw\n", | |
" from math import ceil\n", | |
" import numpy as np\n", | |
" \n", | |
" np.random.seed(31337)\n", | |
"\n", | |
" if True:\n", | |
" # ad-hoc record constructor: Object(a=1).a == 1\n", | |
" Object = lambda **kwargs: type(\"Object\", (), kwargs)()\n", | |
"\n", | |
" cov = fetch_covtype()\n", | |
" # whiten the features; shift labels from 1..7 to 0..6\n", | |
" cov.data = PCA(whiten=True).fit_transform(cov.data)\n", | |
" cov.target -= 1\n", | |
" assert 7 == len(Counter(cov.target))\n", | |
" # 10% of a random permutation for pretraining, the rest for off-policy learning\n", | |
" npretrain = ceil(0.1 * cov.data.shape[0])\n", | |
" order = np.random.permutation(cov.data.shape[0])\n", | |
" pretrain = Object(data = cov.data[order[:npretrain]], target = cov.target[order[:npretrain]])\n", | |
" offpolicylearn = Object(data = cov.data[order[npretrain:]], target = cov.target[order[npretrain:]])\n", | |
" \n", | |
" print('****** pretraining phase (online learning) ******')\n", | |
" loggingacc, plog, piacc, sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(5) ]\n", | |
" print('{:<5s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}'.format(\n", | |
" 'n', 'log acc', 'plog', 'since last plog', 'pi acc', 'since last acc'\n", | |
" ),\n", | |
" flush=True)\n", | |
" \n", | |
" # logging policy: online cb_explore_adf with cubic/quadratic feature interactions\n", | |
" vw = pyvw.vw('--cb_explore_adf --cubic axx -q ax --ignore_linear x --noconstant')\n", | |
" for exno, (ex, label) in enumerate(zip(pretrain.data, pretrain.target)):\n", | |
" # shared (context) namespace x, plus one one-hot action line per arm below\n", | |
" sharedfeat = ' '.join([ 'shared |x'] + [ f'{k}:{v}' for k, v in enumerate(ex) if v != 0 ])\n", | |
" exstr = '\\n'.join([ sharedfeat ] + [ f' |a {k+1}:1' for k in range(7) ])\n", | |
" pred = vw.predict(exstr)\n", | |
" # predictions are used as per-action probabilities: clip negatives, renormalize, sample\n", | |
" probs = np.clip(np.array(pred), a_min=0, a_max=None)\n", | |
" probs /= np.sum(probs)\n", | |
" action = np.random.choice(7, p=probs)\n", | |
" loggingacc += 1 if action == label else 0\n", | |
" plog += probs[action]\n", | |
" sincelastplog += probs[action]\n", | |
" \n", | |
" argmaxaction = np.argmax(probs)\n", | |
" piacc += 1 if argmaxaction == label else 0\n", | |
" sincelastpiacc += 1 if argmaxaction == label else 0\n", | |
" \n", | |
" # CB label 'action:cost:prob' attached only to the sampled arm (cost -1 iff correct;\n", | |
" # the leading action id is unused here, hence 0); other arms get an empty label\n", | |
" labelexstr = '\\n'.join([ sharedfeat ] + [ f' {l} |a {k+1}:1' \n", | |
" for k in range(7)\n", | |
" for l in (f'0:{-1 if action == label else 0}:{probs[k]}' if action == k else '',)\n", | |
" ])\n", | |
" \n", | |
" vw.learn(labelexstr)\n", | |
"\n", | |
" # log progress at powers of two: exno & (exno - 1) == 0\n", | |
" if (exno & (exno - 1) == 0):\n", | |
" print('{:<5d}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}'.format(\n", | |
" loggingacc.n,\n", | |
" loggingacc.mean(),\n", | |
" plog.mean(),\n", | |
" sincelastplog.mean(),\n", | |
" piacc.mean(),\n", | |
" sincelastpiacc.mean()\n", | |
" ),\n", | |
" flush=True)\n", | |
" sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(2) ]\n", | |
" \n", | |
" print('****** off-policy learning phase ******')\n", | |
" loggingacc, plog, piacc, sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(5) ]\n", | |
" print('{:<5s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}'.format(\n", | |
" 'n', 'log acc', 'plog', 'since last plog', 'pi acc', 'since last acc'\n", | |
" ),\n", | |
" flush=True)\n", | |
" \n", | |
" # off-policy learner; --cb_dro is toggled by the usedro argument\n", | |
" offpolicyvw = pyvw.vw(f'--cb_adf --cubic axx -q ax --ignore_linear x --noconstant {\"--cb_dro\" if usedro else \"\"}') \n", | |
" for exno, (ex, label) in enumerate(zip(offpolicylearn.data, offpolicylearn.target)):\n", | |
" sharedfeat = ' '.join([ 'shared |x'] + [ f'{k}:{v}' for k, v in enumerate(ex) if v != 0 ])\n", | |
" exstr = '\\n'.join([ sharedfeat ] + [ f' |a {k+1}:1' for k in range(7) ])\n", | |
" # actions are still chosen by the pretrained logging policy vw, which is\n", | |
" # never updated in this phase (no vw.learn call below)\n", | |
" pred = vw.predict(exstr)\n", | |
" probs = np.clip(np.array(pred), a_min=0, a_max=None)\n", | |
" probs /= np.sum(probs)\n", | |
" action = np.random.choice(7, p=probs)\n", | |
" loggingacc += 1 if action == label else 0\n", | |
" plog += probs[action]\n", | |
" sincelastplog += probs[action]\n", | |
" \n", | |
" offpred = offpolicyvw.predict(exstr)\n", | |
" # cb_adf predicts per-action costs, so the greedy action is the argmin\n", | |
" # (variable name kept from the argmax-over-probabilities phase above)\n", | |
" argmaxaction = np.argmin(offpred)\n", | |
" piacc += 1 if argmaxaction == label else 0\n", | |
" sincelastpiacc += 1 if argmaxaction == label else 0\n", | |
" \n", | |
" # same logged-label format as the pretraining phase\n", | |
" labelexstr = '\\n'.join([ sharedfeat ] + [ f' {l} |a {k+1}:1' \n", | |
" for k in range(7)\n", | |
" for l in (f'0:{-1 if action == label else 0}:{probs[k]}' if action == k else '',)\n", | |
" ])\n", | |
" \n", | |
" offpolicyvw.learn(labelexstr)\n", | |
"\n", | |
" if (exno & (exno - 1) == 0):\n", | |
" print('{:<5d}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}'.format(\n", | |
" loggingacc.n,\n", | |
" loggingacc.mean(),\n", | |
" plog.mean(),\n", | |
" sincelastplog.mean(),\n", | |
" piacc.mean(),\n", | |
" sincelastpiacc.mean()\n", | |
" ),\n", | |
" flush=True)\n", | |
" sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(2) ]" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 3, | |
"metadata": { | |
"code_folding": [ | |
10, | |
31 | |
] | |
}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"****** pretraining phase (online learning) ******\n", | |
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n", | |
"1 \t0.00000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n", | |
"2 \t0.50000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n", | |
"3 \t0.33333 \t0.41429 \t0.95714 \t0.00000 \t0.00000 \n", | |
"5 \t0.40000 \t0.63143 \t0.95714 \t0.20000 \t0.50000 \n", | |
"9 \t0.44444 \t0.68823 \t0.75923 \t0.44444 \t0.75000 \n", | |
"17 \t0.35294 \t0.66576 \t0.64048 \t0.47059 \t0.50000 \n", | |
"33 \t0.27273 \t0.68709 \t0.70975 \t0.33333 \t0.18750 \n", | |
"65 \t0.35385 \t0.77619 \t0.86808 \t0.36923 \t0.40625 \n", | |
"129 \t0.44961 \t0.84633 \t0.91756 \t0.45736 \t0.54688 \n", | |
"257 \t0.53307 \t0.89043 \t0.93488 \t0.53307 \t0.60938 \n", | |
"513 \t0.51462 \t0.90150 \t0.91261 \t0.52437 \t0.51562 \n", | |
"1025 \t0.52976 \t0.90983 \t0.91818 \t0.54634 \t0.56836 \n", | |
"2049 \t0.56418 \t0.91493 \t0.92003 \t0.58077 \t0.61523 \n", | |
"4097 \t0.59629 \t0.91702 \t0.91911 \t0.61801 \t0.65527 \n", | |
"8193 \t0.63066 \t0.91737 \t0.91771 \t0.65507 \t0.69214 \n", | |
"16385\t0.65621 \t0.91864 \t0.91992 \t0.68093 \t0.70679 \n", | |
"32769\t0.66535 \t0.91777 \t0.91690 \t0.69166 \t0.70239 \n", | |
"****** off-policy learning phase ******\n", | |
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n", | |
"1 \t0.00000 \t0.95714 \t0.95714 \t1.00000 \t1.00000 \n", | |
"2 \t0.50000 \t0.95714 \t0.95714 \t0.50000 \t0.00000 \n", | |
"3 \t0.33333 \t0.95714 \t0.95714 \t0.33333 \t0.00000 \n", | |
"5 \t0.60000 \t0.95714 \t0.95714 \t0.20000 \t0.00000 \n", | |
"9 \t0.77778 \t0.95714 \t0.95714 \t0.11111 \t0.00000 \n", | |
"17 \t0.76471 \t0.95714 \t0.95714 \t0.17647 \t0.25000 \n", | |
"33 \t0.72727 \t0.89957 \t0.83839 \t0.09091 \t0.00000 \n", | |
"65 \t0.70769 \t0.91330 \t0.92746 \t0.29231 \t0.50000 \n", | |
"129 \t0.68217 \t0.91296 \t0.91261 \t0.32558 \t0.35938 \n", | |
"257 \t0.63424 \t0.88691 \t0.86066 \t0.42023 \t0.51562 \n", | |
"513 \t0.64522 \t0.89233 \t0.89777 \t0.50487 \t0.58984 \n", | |
"1025 \t0.66146 \t0.90987 \t0.92746 \t0.56195 \t0.61914 \n", | |
"2049 \t0.67155 \t0.90846 \t0.90705 \t0.61689 \t0.67188 \n", | |
"4097 \t0.67708 \t0.91216 \t0.91586 \t0.63730 \t0.65771 \n", | |
"8193 \t0.68742 \t0.91331 \t0.91447 \t0.66337 \t0.68945 \n", | |
"16385\t0.69509 \t0.91522 \t0.91713 \t0.68459 \t0.70581 \n", | |
"32769\t0.69676 \t0.91618 \t0.91713 \t0.69398 \t0.70337 \n", | |
"65537\t0.69745 \t0.91590 \t0.91563 \t0.69762 \t0.70126 \n", | |
"131073\t0.69714 \t0.91613 \t0.91637 \t0.70680 \t0.71599 \n", | |
"262145\t0.69574 \t0.91587 \t0.91561 \t0.71269 \t0.71858 \n" | |
] | |
} | |
], | |
"source": [ | |
"cb_explore_adf_covertype_demo()" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": 4, | |
"metadata": {}, | |
"outputs": [ | |
{ | |
"name": "stdout", | |
"output_type": "stream", | |
"text": [ | |
"****** pretraining phase (online learning) ******\n", | |
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n", | |
"1 \t0.00000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n", | |
"2 \t0.50000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n", | |
"3 \t0.33333 \t0.41429 \t0.95714 \t0.00000 \t0.00000 \n", | |
"5 \t0.40000 \t0.63143 \t0.95714 \t0.20000 \t0.50000 \n", | |
"9 \t0.44444 \t0.68823 \t0.75923 \t0.44444 \t0.75000 \n", | |
"17 \t0.35294 \t0.66576 \t0.64048 \t0.47059 \t0.50000 \n", | |
"33 \t0.27273 \t0.68709 \t0.70975 \t0.33333 \t0.18750 \n", | |
"65 \t0.35385 \t0.77619 \t0.86808 \t0.36923 \t0.40625 \n", | |
"129 \t0.44961 \t0.84633 \t0.91756 \t0.45736 \t0.54688 \n", | |
"257 \t0.53307 \t0.89043 \t0.93488 \t0.53307 \t0.60938 \n", | |
"513 \t0.51462 \t0.90150 \t0.91261 \t0.52437 \t0.51562 \n", | |
"1025 \t0.52976 \t0.90983 \t0.91818 \t0.54634 \t0.56836 \n", | |
"2049 \t0.56418 \t0.91493 \t0.92003 \t0.58077 \t0.61523 \n", | |
"4097 \t0.59629 \t0.91702 \t0.91911 \t0.61801 \t0.65527 \n", | |
"8193 \t0.63066 \t0.91737 \t0.91771 \t0.65507 \t0.69214 \n", | |
"16385\t0.65621 \t0.91864 \t0.91992 \t0.68093 \t0.70679 \n", | |
"32769\t0.66535 \t0.91777 \t0.91690 \t0.69166 \t0.70239 \n", | |
"****** off-policy learning phase ******\n", | |
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n", | |
"1 \t0.00000 \t0.95714 \t0.95714 \t1.00000 \t1.00000 \n", | |
"2 \t0.50000 \t0.95714 \t0.95714 \t0.50000 \t0.00000 \n", | |
"3 \t0.33333 \t0.95714 \t0.95714 \t0.33333 \t0.00000 \n", | |
"5 \t0.60000 \t0.95714 \t0.95714 \t0.20000 \t0.00000 \n", | |
"9 \t0.77778 \t0.95714 \t0.95714 \t0.11111 \t0.00000 \n", | |
"17 \t0.76471 \t0.95714 \t0.95714 \t0.11765 \t0.12500 \n", | |
"33 \t0.72727 \t0.89957 \t0.83839 \t0.06061 \t0.00000 \n", | |
"65 \t0.70769 \t0.91330 \t0.92746 \t0.12308 \t0.18750 \n", | |
"129 \t0.68217 \t0.91296 \t0.91261 \t0.17054 \t0.21875 \n", | |
"257 \t0.63424 \t0.88691 \t0.86066 \t0.30350 \t0.43750 \n", | |
"513 \t0.64522 \t0.89233 \t0.89777 \t0.45419 \t0.60547 \n", | |
"1025 \t0.66146 \t0.90987 \t0.92746 \t0.53854 \t0.62305 \n", | |
"2049 \t0.67155 \t0.90846 \t0.90705 \t0.59834 \t0.65820 \n", | |
"4097 \t0.67708 \t0.91216 \t0.91586 \t0.63241 \t0.66650 \n", | |
"8193 \t0.68742 \t0.91331 \t0.91447 \t0.66301 \t0.69360 \n", | |
"16385\t0.69509 \t0.91522 \t0.91713 \t0.68727 \t0.71155 \n", | |
"32769\t0.69676 \t0.91618 \t0.91713 \t0.70094 \t0.71460 \n", | |
"65537\t0.69745 \t0.91590 \t0.91563 \t0.70963 \t0.71832 \n", | |
"131073\t0.69714 \t0.91613 \t0.91637 \t0.71910 \t0.72856 \n", | |
"262145\t0.69574 \t0.91587 \t0.91561 \t0.72653 \t0.73396 \n" | |
] | |
} | |
], | |
"source": [ | |
"cb_explore_adf_covertype_demo(usedro=True)" | |
] | |
}, | |
{ | |
"cell_type": "code", | |
"execution_count": null, | |
"metadata": {}, | |
"outputs": [], | |
"source": [] | |
} | |
], | |
"metadata": { | |
"kernelspec": { | |
"display_name": "Python 3", | |
"language": "python", | |
"name": "python3" | |
}, | |
"language_info": { | |
"codemirror_mode": { | |
"name": "ipython", | |
"version": 3 | |
}, | |
"file_extension": ".py", | |
"mimetype": "text/x-python", | |
"name": "python", | |
"nbconvert_exporter": "python", | |
"pygments_lexer": "ipython3", | |
"version": "3.7.4" | |
} | |
}, | |
"nbformat": 4, | |
"nbformat_minor": 2 | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
This comment has been minimized.
In this gist I demonstrate the --cb_dro flag for Vowpal Wabbit on the covertype dataset.
Using the --cb_dro flag
improves the trained off-policy model from 71.8% to 73.3% accuracy.