Skip to content

Instantly share code, notes, and snippets.

@pmineiro

pmineiro/CbDroDemo.ipynb

Last active Dec 7, 2020
Embed
What would you like to do?
--cb_dro demo for Vowpal Wabbit using covertype. To see the lift, note the "since last acc" column with and without --cb_dro.
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"code_folding": []
},
"outputs": [],
"source": [
"class EasyAcc:\n",
" def __init__(self):\n",
" self.n = 0\n",
" self.sum = 0\n",
"\n",
" def __iadd__(self, other):\n",
" self.n += 1\n",
" self.sum += other\n",
" return self\n",
" \n",
" def __isub__(self, other):\n",
" self.n += 1\n",
" self.sum -= other\n",
" return self\n",
"\n",
" def mean(self):\n",
" return self.sum / max(self.n, 1)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"code_folding": [
10,
30,
72
]
},
"outputs": [],
"source": [
"def cb_explore_adf_covertype_demo(usedro=False):\n",
" from collections import Counter\n",
" from sklearn.datasets import fetch_covtype\n",
" from sklearn.decomposition import PCA\n",
" from vowpalwabbit import pyvw\n",
" from math import ceil\n",
" import numpy as np\n",
" \n",
" np.random.seed(31337)\n",
"\n",
" if True:\n",
" Object = lambda **kwargs: type(\"Object\", (), kwargs)()\n",
"\n",
" cov = fetch_covtype()\n",
" cov.data = PCA(whiten=True).fit_transform(cov.data)\n",
" cov.target -= 1\n",
" assert 7 == len(Counter(cov.target))\n",
" npretrain = ceil(0.1 * cov.data.shape[0])\n",
" order = np.random.permutation(cov.data.shape[0])\n",
" pretrain = Object(data = cov.data[order[:npretrain]], target = cov.target[order[:npretrain]])\n",
" offpolicylearn = Object(data = cov.data[order[npretrain:]], target = cov.target[order[npretrain:]])\n",
" \n",
" print('****** pretraining phase (online learning) ******')\n",
" loggingacc, plog, piacc, sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(5) ]\n",
" print('{:<5s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}'.format(\n",
" 'n', 'log acc', 'plog', 'since last plog', 'pi acc', 'since last acc'\n",
" ),\n",
" flush=True)\n",
" \n",
" vw = pyvw.vw('--cb_explore_adf --cubic axx -q ax --ignore_linear x --noconstant')\n",
" for exno, (ex, label) in enumerate(zip(pretrain.data, pretrain.target)):\n",
" sharedfeat = ' '.join([ 'shared |x'] + [ f'{k}:{v}' for k, v in enumerate(ex) if v != 0 ])\n",
" exstr = '\\n'.join([ sharedfeat ] + [ f' |a {k+1}:1' for k in range(7) ])\n",
" pred = vw.predict(exstr)\n",
" probs = np.clip(np.array(pred), a_min=0, a_max=None)\n",
" probs /= np.sum(probs)\n",
" action = np.random.choice(7, p=probs)\n",
" loggingacc += 1 if action == label else 0\n",
" plog += probs[action]\n",
" sincelastplog += probs[action]\n",
" \n",
" argmaxaction = np.argmax(probs)\n",
" piacc += 1 if argmaxaction == label else 0\n",
" sincelastpiacc += 1 if argmaxaction == label else 0\n",
" \n",
" labelexstr = '\\n'.join([ sharedfeat ] + [ f' {l} |a {k+1}:1' \n",
" for k in range(7)\n",
" for l in (f'0:{-1 if action == label else 0}:{probs[k]}' if action == k else '',)\n",
" ])\n",
" \n",
" vw.learn(labelexstr)\n",
"\n",
" if (exno & (exno - 1) == 0):\n",
" print('{:<5d}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}'.format(\n",
" loggingacc.n,\n",
" loggingacc.mean(),\n",
" plog.mean(),\n",
" sincelastplog.mean(),\n",
" piacc.mean(),\n",
" sincelastpiacc.mean()\n",
" ),\n",
" flush=True)\n",
" sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(2) ]\n",
" \n",
" print('****** off-policy learning phase ******')\n",
" loggingacc, plog, piacc, sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(5) ]\n",
" print('{:<5s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}\\t{:<9s}'.format(\n",
" 'n', 'log acc', 'plog', 'since last plog', 'pi acc', 'since last acc'\n",
" ),\n",
" flush=True)\n",
" \n",
" offpolicyvw = pyvw.vw(f'--cb_adf --cubic axx -q ax --ignore_linear x --noconstant {\"--cb_dro\" if usedro else \"\"}') \n",
" for exno, (ex, label) in enumerate(zip(offpolicylearn.data, offpolicylearn.target)):\n",
" sharedfeat = ' '.join([ 'shared |x'] + [ f'{k}:{v}' for k, v in enumerate(ex) if v != 0 ])\n",
" exstr = '\\n'.join([ sharedfeat ] + [ f' |a {k+1}:1' for k in range(7) ])\n",
" pred = vw.predict(exstr)\n",
" probs = np.clip(np.array(pred), a_min=0, a_max=None)\n",
" probs /= np.sum(probs)\n",
" action = np.random.choice(7, p=probs)\n",
" loggingacc += 1 if action == label else 0\n",
" plog += probs[action]\n",
" sincelastplog += probs[action]\n",
" \n",
" offpred = offpolicyvw.predict(exstr)\n",
" argmaxaction = np.argmin(offpred)\n",
" piacc += 1 if argmaxaction == label else 0\n",
" sincelastpiacc += 1 if argmaxaction == label else 0\n",
" \n",
" labelexstr = '\\n'.join([ sharedfeat ] + [ f' {l} |a {k+1}:1' \n",
" for k in range(7)\n",
" for l in (f'0:{-1 if action == label else 0}:{probs[k]}' if action == k else '',)\n",
" ])\n",
" \n",
" offpolicyvw.learn(labelexstr)\n",
"\n",
" if (exno & (exno - 1) == 0):\n",
" print('{:<5d}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}\\t{:<9.5f}'.format(\n",
" loggingacc.n,\n",
" loggingacc.mean(),\n",
" plog.mean(),\n",
" sincelastplog.mean(),\n",
" piacc.mean(),\n",
" sincelastpiacc.mean()\n",
" ),\n",
" flush=True)\n",
" sincelastplog, sincelastpiacc = [ EasyAcc() for _ in range(2) ]"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"code_folding": [
10,
31
]
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"****** pretraining phase (online learning) ******\n",
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n",
"1 \t0.00000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n",
"2 \t0.50000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n",
"3 \t0.33333 \t0.41429 \t0.95714 \t0.00000 \t0.00000 \n",
"5 \t0.40000 \t0.63143 \t0.95714 \t0.20000 \t0.50000 \n",
"9 \t0.44444 \t0.68823 \t0.75923 \t0.44444 \t0.75000 \n",
"17 \t0.35294 \t0.66576 \t0.64048 \t0.47059 \t0.50000 \n",
"33 \t0.27273 \t0.68709 \t0.70975 \t0.33333 \t0.18750 \n",
"65 \t0.35385 \t0.77619 \t0.86808 \t0.36923 \t0.40625 \n",
"129 \t0.44961 \t0.84633 \t0.91756 \t0.45736 \t0.54688 \n",
"257 \t0.53307 \t0.89043 \t0.93488 \t0.53307 \t0.60938 \n",
"513 \t0.51462 \t0.90150 \t0.91261 \t0.52437 \t0.51562 \n",
"1025 \t0.52976 \t0.90983 \t0.91818 \t0.54634 \t0.56836 \n",
"2049 \t0.56418 \t0.91493 \t0.92003 \t0.58077 \t0.61523 \n",
"4097 \t0.59629 \t0.91702 \t0.91911 \t0.61801 \t0.65527 \n",
"8193 \t0.63066 \t0.91737 \t0.91771 \t0.65507 \t0.69214 \n",
"16385\t0.65621 \t0.91864 \t0.91992 \t0.68093 \t0.70679 \n",
"32769\t0.66535 \t0.91777 \t0.91690 \t0.69166 \t0.70239 \n",
"****** off-policy learning phase ******\n",
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n",
"1 \t0.00000 \t0.95714 \t0.95714 \t1.00000 \t1.00000 \n",
"2 \t0.50000 \t0.95714 \t0.95714 \t0.50000 \t0.00000 \n",
"3 \t0.33333 \t0.95714 \t0.95714 \t0.33333 \t0.00000 \n",
"5 \t0.60000 \t0.95714 \t0.95714 \t0.20000 \t0.00000 \n",
"9 \t0.77778 \t0.95714 \t0.95714 \t0.11111 \t0.00000 \n",
"17 \t0.76471 \t0.95714 \t0.95714 \t0.17647 \t0.25000 \n",
"33 \t0.72727 \t0.89957 \t0.83839 \t0.09091 \t0.00000 \n",
"65 \t0.70769 \t0.91330 \t0.92746 \t0.29231 \t0.50000 \n",
"129 \t0.68217 \t0.91296 \t0.91261 \t0.32558 \t0.35938 \n",
"257 \t0.63424 \t0.88691 \t0.86066 \t0.42023 \t0.51562 \n",
"513 \t0.64522 \t0.89233 \t0.89777 \t0.50487 \t0.58984 \n",
"1025 \t0.66146 \t0.90987 \t0.92746 \t0.56195 \t0.61914 \n",
"2049 \t0.67155 \t0.90846 \t0.90705 \t0.61689 \t0.67188 \n",
"4097 \t0.67708 \t0.91216 \t0.91586 \t0.63730 \t0.65771 \n",
"8193 \t0.68742 \t0.91331 \t0.91447 \t0.66337 \t0.68945 \n",
"16385\t0.69509 \t0.91522 \t0.91713 \t0.68459 \t0.70581 \n",
"32769\t0.69676 \t0.91618 \t0.91713 \t0.69398 \t0.70337 \n",
"65537\t0.69745 \t0.91590 \t0.91563 \t0.69762 \t0.70126 \n",
"131073\t0.69714 \t0.91613 \t0.91637 \t0.70680 \t0.71599 \n",
"262145\t0.69574 \t0.91587 \t0.91561 \t0.71269 \t0.71858 \n"
]
}
],
"source": [
"cb_explore_adf_covertype_demo()"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"****** pretraining phase (online learning) ******\n",
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n",
"1 \t0.00000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n",
"2 \t0.50000 \t0.14286 \t0.14286 \t0.00000 \t0.00000 \n",
"3 \t0.33333 \t0.41429 \t0.95714 \t0.00000 \t0.00000 \n",
"5 \t0.40000 \t0.63143 \t0.95714 \t0.20000 \t0.50000 \n",
"9 \t0.44444 \t0.68823 \t0.75923 \t0.44444 \t0.75000 \n",
"17 \t0.35294 \t0.66576 \t0.64048 \t0.47059 \t0.50000 \n",
"33 \t0.27273 \t0.68709 \t0.70975 \t0.33333 \t0.18750 \n",
"65 \t0.35385 \t0.77619 \t0.86808 \t0.36923 \t0.40625 \n",
"129 \t0.44961 \t0.84633 \t0.91756 \t0.45736 \t0.54688 \n",
"257 \t0.53307 \t0.89043 \t0.93488 \t0.53307 \t0.60938 \n",
"513 \t0.51462 \t0.90150 \t0.91261 \t0.52437 \t0.51562 \n",
"1025 \t0.52976 \t0.90983 \t0.91818 \t0.54634 \t0.56836 \n",
"2049 \t0.56418 \t0.91493 \t0.92003 \t0.58077 \t0.61523 \n",
"4097 \t0.59629 \t0.91702 \t0.91911 \t0.61801 \t0.65527 \n",
"8193 \t0.63066 \t0.91737 \t0.91771 \t0.65507 \t0.69214 \n",
"16385\t0.65621 \t0.91864 \t0.91992 \t0.68093 \t0.70679 \n",
"32769\t0.66535 \t0.91777 \t0.91690 \t0.69166 \t0.70239 \n",
"****** off-policy learning phase ******\n",
"n \tlog acc \tplog \tsince last plog\tpi acc \tsince last acc\n",
"1 \t0.00000 \t0.95714 \t0.95714 \t1.00000 \t1.00000 \n",
"2 \t0.50000 \t0.95714 \t0.95714 \t0.50000 \t0.00000 \n",
"3 \t0.33333 \t0.95714 \t0.95714 \t0.33333 \t0.00000 \n",
"5 \t0.60000 \t0.95714 \t0.95714 \t0.20000 \t0.00000 \n",
"9 \t0.77778 \t0.95714 \t0.95714 \t0.11111 \t0.00000 \n",
"17 \t0.76471 \t0.95714 \t0.95714 \t0.11765 \t0.12500 \n",
"33 \t0.72727 \t0.89957 \t0.83839 \t0.06061 \t0.00000 \n",
"65 \t0.70769 \t0.91330 \t0.92746 \t0.12308 \t0.18750 \n",
"129 \t0.68217 \t0.91296 \t0.91261 \t0.17054 \t0.21875 \n",
"257 \t0.63424 \t0.88691 \t0.86066 \t0.30350 \t0.43750 \n",
"513 \t0.64522 \t0.89233 \t0.89777 \t0.45419 \t0.60547 \n",
"1025 \t0.66146 \t0.90987 \t0.92746 \t0.53854 \t0.62305 \n",
"2049 \t0.67155 \t0.90846 \t0.90705 \t0.59834 \t0.65820 \n",
"4097 \t0.67708 \t0.91216 \t0.91586 \t0.63241 \t0.66650 \n",
"8193 \t0.68742 \t0.91331 \t0.91447 \t0.66301 \t0.69360 \n",
"16385\t0.69509 \t0.91522 \t0.91713 \t0.68727 \t0.71155 \n",
"32769\t0.69676 \t0.91618 \t0.91713 \t0.70094 \t0.71460 \n",
"65537\t0.69745 \t0.91590 \t0.91563 \t0.70963 \t0.71832 \n",
"131073\t0.69714 \t0.91613 \t0.91637 \t0.71910 \t0.72856 \n",
"262145\t0.69574 \t0.91587 \t0.91561 \t0.72653 \t0.73396 \n"
]
}
],
"source": [
"cb_explore_adf_covertype_demo(usedro=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
@pmineiro

This comment has been minimized.

Copy link
Owner Author

@pmineiro pmineiro commented Dec 5, 2020

In this gist I:

  • pre-train a logging policy using 10% of covertype, and then fix the logging policy thereafter
  • off-policy train another policy using data from the logging policy, either with or without the --cb_dro flag
  • --cb_dro improves the trained policy's final "since last acc" from 71.9% to 73.4% (see the recorded outputs above)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment