Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Basic Negative Sampling Implementations"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pandas as pd\n",
"import numpy as np\n",
"from scipy import stats\n",
"from functools import partial\n",
"np.random.seed(322)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>n_items</th>\n",
" <th>pos_inds</th>\n",
" <th>n_samp</th>\n",
" <th>frac_pos</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>case0</th>\n",
" <td>25</td>\n",
" <td>[3, 9, 22]</td>\n",
" <td>1</td>\n",
" <td>0.12000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>case1</th>\n",
" <td>25</td>\n",
" <td>[3, 9, 22]</td>\n",
" <td>100</td>\n",
" <td>0.12000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>case2</th>\n",
" <td>25000</td>\n",
" <td>[3, 9, 22]</td>\n",
" <td>100</td>\n",
" <td>0.00012</td>\n",
" </tr>\n",
" <tr>\n",
" <th>case3</th>\n",
" <td>25</td>\n",
" <td>[3, 9, 22]</td>\n",
" <td>10000</td>\n",
" <td>0.12000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>case4</th>\n",
" <td>25</td>\n",
" <td>[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,...</td>\n",
" <td>100</td>\n",
" <td>0.88000</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" n_items pos_inds n_samp \\\n",
"case0 25 [3, 9, 22] 1 \n",
"case1 25 [3, 9, 22] 100 \n",
"case2 25000 [3, 9, 22] 100 \n",
"case3 25 [3, 9, 22] 10000 \n",
"case4 25 [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,... 100 \n",
"\n",
" frac_pos \n",
"case0 0.12000 \n",
"case1 0.12000 \n",
"case2 0.00012 \n",
"case3 0.12000 \n",
"case4 0.88000 "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Benchmark scenarios: number of items, sorted positive indices, and how many\n",
"# negatives to draw. `frac_pos` is the share of items that are positive.\n",
"cases_df = pd.DataFrame.from_dict(\n",
"    {\n",
"        'case0': dict(n_items=25, pos_inds=np.array([3, 9, 22]), n_samp=1),\n",
"        'case1': dict(n_items=25, pos_inds=np.array([3, 9, 22]), n_samp=100),\n",
"        'case2': dict(n_items=25_000, pos_inds=np.array([3, 9, 22]), n_samp=100),\n",
"        'case3': dict(n_items=25, pos_inds=np.array([3, 9, 22]), n_samp=10_000),\n",
"        # every item except the last 3 is positive\n",
"        'case4': dict(n_items=25, pos_inds=np.arange(25 - 3), n_samp=100),\n",
"    },\n",
"    orient='index',\n",
").assign(frac_pos=lambda df: df['pos_inds'].map(len) / df['n_items'])\n",
"cases_df"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Incremental Guess and Check"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"def negsamp_incr(pos_check, pos_inds, n_items, n_samp=32):\n",
"    \"\"\" Guess and check with an arbitrary positivity check, one draw at a time.\n",
"\n",
"    Draws indices uniformly from [0, n_items) with replacement and keeps each\n",
"    one for which `pos_check(idx, pos_inds)` is falsy, until `n_samp` have\n",
"    been kept. Returns a Python list (duplicates possible).\n",
"    NOTE: with n_samp > 0 this never returns if every index in\n",
"    [0, n_items) is positive.\n",
"    \"\"\"\n",
"    kept = []\n",
"    while True:\n",
"        if len(kept) >= n_samp:\n",
"            return kept\n",
"        candidate = np.random.randint(0, n_items)\n",
"        if pos_check(candidate, pos_inds):\n",
"            continue  # rejected: it is a positive\n",
"        kept.append(candidate)"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"def negsamp_incr_naive(pos_inds, n_items, n_samp=32):\n",
"    \"\"\" Guess and check with plain `in` membership\n",
"    (a linear scan of `pos_inds` for every draw).\n",
"    \"\"\"\n",
"    def is_positive(candidate, positives):\n",
"        return candidate in positives\n",
"\n",
"    return negsamp_incr(is_positive, pos_inds, n_items, n_samp=n_samp)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"def negsamp_incr_set(pos_inds, n_items, n_samp=32):\n",
"    \"\"\" Guess and check with hashtable membership: `pos_inds` is copied\n",
"    into a set once, so each positivity check is a hash lookup.\n",
"    \"\"\"\n",
"    positive_set = set(pos_inds)\n",
"\n",
"    def is_positive(candidate, positives):\n",
"        return candidate in positives\n",
"\n",
"    return negsamp_incr(is_positive, positive_set, n_items, n_samp=n_samp)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"from bisect import bisect_left\n",
"\n",
"def bsearch_in(search_val, val_arr):\n",
"    \"\"\" Whether `search_val` occurs in the ascending-sorted sequence `val_arr`. \"\"\"\n",
"    pos = bisect_left(val_arr, search_val)\n",
"    if pos == len(val_arr):\n",
"        return False\n",
"    return val_arr[pos] == search_val\n",
"\n",
"def negsamp_incr_bsearch(pos_inds, n_items, n_samp=32):\n",
"    \"\"\" Guess and check with binary search.\n",
"    `pos_inds` is assumed to be sorted ascending.\n",
"    \"\"\"\n",
"    return negsamp_incr(bsearch_in, pos_inds, n_items, n_samp)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vectorized Guess and Check"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"def negsamp_vectorized_bsearch(pos_inds, n_items, n_samp=32):\n",
"    \"\"\" Guess and check, vectorized.\n",
"\n",
"    Draws `n_samp` indices uniformly from [0, n_items) in one call and drops\n",
"    those found in `pos_inds`, which must be sorted ascending. Only the\n",
"    survivors are returned, so the result may hold fewer than `n_samp`\n",
"    samples (duplicates possible).\n",
"    \"\"\"\n",
"    pos_inds = np.asarray(pos_inds)\n",
"    raw_samps = np.random.randint(0, n_items, size=n_samp)\n",
"    if pos_inds.size == 0:\n",
"        # Nothing to reject; np.take below would raise on an empty array.\n",
"        return raw_samps\n",
"    # Insertion point of each draw; mode='clip' makes draws beyond the last\n",
"    # positive compare against the last positive instead of indexing past the end.\n",
"    ss = np.searchsorted(pos_inds, raw_samps)\n",
"    pos_mask = raw_samps == np.take(pos_inds, ss, mode='clip')\n",
"    return raw_samps[~pos_mask]"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Vectorized Pre-verified"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def negsamp_vectorized_bsearch_preverif(pos_inds, n_items, n_samp=32):\n",
"    \"\"\" Pre-verified sampling with binary search: every draw is a negative.\n",
"\n",
"    Draws the 0-based rank r of a negative (among the n_items - len(pos_inds)\n",
"    negatives) and maps it to its item index by adding the number of positives\n",
"    that precede it. `pos_inds` is assumed to be sorted ascending and unique.\n",
"    Returns exactly `n_samp` indices (duplicates possible).\n",
"    \"\"\"\n",
"    n_pos = len(pos_inds)\n",
"    ranks = np.random.randint(0, n_items - n_pos, size=n_samp)\n",
"    # pos_inds[k] - k == number of negatives that come before the k-th positive\n",
"    negs_before_pos = pos_inds - np.arange(n_pos)\n",
"    # Every positive with negs_before_pos <= r sits before the r-th negative,\n",
"    # so it shifts that negative's index up by one.\n",
"    n_pos_before = np.searchsorted(negs_before_pos, ranks, side='right')\n",
"    return ranks + n_pos_before"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Sanity Checking"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"# Strategies to compare: incremental (one draw at a time) first, then vectorized.\n",
"# This order also sets the row order of the timing tables below.\n",
"strategies = []\n",
"strategies += [negsamp_incr_naive, negsamp_incr_set, negsamp_incr_bsearch]\n",
"strategies += [negsamp_vectorized_bsearch, negsamp_vectorized_bsearch_preverif]"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [],
"source": [
"# Just a quick sanity check\n",
"def is_valid(samps, n_items, pos_inds, alpha=1e-3, min_expected=5):\n",
"    \"\"\" Sanity-check that `samps` look like uniform draws from the negatives.\n",
"\n",
"    samps: sampled indices (list or 1-D integer array; may be empty)\n",
"    n_items: total number of items; valid indices are 0..n_items-1\n",
"    pos_inds: indices of the positive items\n",
"    alpha: significance level of the chi-square uniformity test\n",
"    min_expected: skip the uniformity test when fewer than this many samples\n",
"        are expected per negative item (the chi-square approximation is\n",
"        unreliable there, e.g. n_samp=1 or n_items >> n_samp)\n",
"    Returns True when every sample is in range, no positive was sampled,\n",
"    and uniformity is not rejected.\n",
"    \"\"\"\n",
"    samps = np.asarray(samps, dtype=np.int64)\n",
"    # Out-of-range samples are invalid outright (np.histogram would have\n",
"    # silently dropped them, or folded a value of n_items into the last bin).\n",
"    if samps.size and (samps.min() < 0 or samps.max() >= n_items):\n",
"        return False\n",
"    # Raw counts rather than densities: densities are NaN for an empty sample\n",
"    # (which made the positivity check fail) and shrink the chi-square\n",
"    # statistic by a factor of len(samps), which made the test vacuous.\n",
"    counts = np.bincount(samps, minlength=n_items)\n",
"    neg_inds = np.setdiff1d(np.arange(n_items), pos_inds)\n",
"    # Should not have any positives sampled as negatives\n",
"    is_non_pos = not counts[pos_inds].any()\n",
"    # Distribution should be ~uniform. Under uniformity the p-value is itself\n",
"    # uniform on [0, 1], so reject only when it is small (below alpha).\n",
"    expected_per_neg = samps.size / max(len(neg_inds), 1)\n",
"    is_uniform = (expected_per_neg < min_expected or\n",
"                  stats.chisquare(counts[neg_inds]).pvalue > alpha)\n",
"    return is_non_pos and is_uniform\n",
"\n",
"for case_name, row in cases_df.iterrows():\n",
"    for strat in strategies:\n",
"        samps = strat(row['pos_inds'], row['n_items'], row['n_samp'])\n",
"        assert is_valid(samps, row['n_items'], row['pos_inds']), (\n",
"            f'{strat.__name__} produced invalid samples for {case_name}')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Timed Runs"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"# Run our strategies against our cases.\n",
"# times_d[case][strategy] = mean seconds per call (TimeitResult.average).\n",
"times_d = {}\n",
"for case_name, row in cases_df.iterrows():\n",
"    # Extract the arguments once per case: leaving the three `row[...]` pandas\n",
"    # lookups inside the timed statement added their cost to every measured\n",
"    # call, inflating the times of the fastest strategies.\n",
"    args = (row['pos_inds'], row['n_items'], row['n_samp'])\n",
"    case_times = {}\n",
"    for strat in strategies:\n",
"        timing = %timeit -o -q strat(*args)\n",
"        case_times[strat.__name__] = timing.average\n",
"    times_d[case_name] = case_times"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>case0</th>\n",
" <th>case1</th>\n",
" <th>case2</th>\n",
" <th>case3</th>\n",
" <th>case4</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>negsamp_incr_naive</th>\n",
" <td>0.000025</td>\n",
" <td>0.000557</td>\n",
" <td>0.000495</td>\n",
" <td>0.051792</td>\n",
" <td>0.003816</td>\n",
" </tr>\n",
" <tr>\n",
" <th>negsamp_incr_set</th>\n",
" <td>0.000021</td>\n",
" <td>0.000185</td>\n",
" <td>0.000171</td>\n",
" <td>0.016344</td>\n",
" <td>0.001264</td>\n",
" </tr>\n",
" <tr>\n",
" <th>negsamp_incr_bsearch</th>\n",
" <td>0.000020</td>\n",
" <td>0.000255</td>\n",
" <td>0.000207</td>\n",
" <td>0.023141</td>\n",
" <td>0.001898</td>\n",
" </tr>\n",
" <tr>\n",
" <th>negsamp_vectorized_bsearch</th>\n",
" <td>0.000025</td>\n",
" <td>0.000027</td>\n",
" <td>0.000027</td>\n",
" <td>0.000214</td>\n",
" <td>0.000029</td>\n",
" </tr>\n",
" <tr>\n",
" <th>negsamp_vectorized_bsearch_preverif</th>\n",
" <td>0.000024</td>\n",
" <td>0.000025</td>\n",
" <td>0.000025</td>\n",
" <td>0.000203</td>\n",
" <td>0.000025</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" case0 case1 case2 case3 \\\n",
"negsamp_incr_naive 0.000025 0.000557 0.000495 0.051792 \n",
"negsamp_incr_set 0.000021 0.000185 0.000171 0.016344 \n",
"negsamp_incr_bsearch 0.000020 0.000255 0.000207 0.023141 \n",
"negsamp_vectorized_bsearch 0.000025 0.000027 0.000027 0.000214 \n",
"negsamp_vectorized_bsearch_preverif 0.000024 0.000025 0.000025 0.000203 \n",
"\n",
" case4 \n",
"negsamp_incr_naive 0.003816 \n",
"negsamp_incr_set 0.001264 \n",
"negsamp_incr_bsearch 0.001898 \n",
"negsamp_vectorized_bsearch 0.000029 \n",
"negsamp_vectorized_bsearch_preverif 0.000025 "
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Average times in seconds per call: one row per strategy (in `strategies`\n",
"# order), one column per case.\n",
"times_df = (\n",
"    pd.DataFrame(times_d)\n",
"    .reindex(index=[strat.__name__ for strat in strategies])\n",
")\n",
"times_df"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>case0</th>\n",
" <th>case1</th>\n",
" <th>case2</th>\n",
" <th>case3</th>\n",
" <th>case4</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>negsamp_incr_naive</th>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" <td>1.000000</td>\n",
" </tr>\n",
" <tr>\n",
" <th>negsamp_incr_set</th>\n",
" <td>1.233551</td>\n",
" <td>3.007306</td>\n",
" <td>2.891718</td>\n",
" <td>3.168921</td>\n",
" <td>3.018263</td>\n",
" </tr>\n",
" <tr>\n",
" <th>negsamp_incr_bsearch</th>\n",
" <td>1.274717</td>\n",
" <td>2.186638</td>\n",
" <td>2.392386</td>\n",
" <td>2.238093</td>\n",
" <td>2.010704</td>\n",
" </tr>\n",
" <tr>\n",
" <th>negsamp_vectorized_bsearch</th>\n",
" <td>1.022648</td>\n",
" <td>20.735678</td>\n",
" <td>18.502182</td>\n",
" <td>242.499565</td>\n",
" <td>131.991317</td>\n",
" </tr>\n",
" <tr>\n",
" <th>negsamp_vectorized_bsearch_preverif</th>\n",
" <td>1.079992</td>\n",
" <td>21.908844</td>\n",
" <td>19.691580</td>\n",
" <td>255.305570</td>\n",
" <td>150.949921</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" case0 case1 case2 \\\n",
"negsamp_incr_naive 1.000000 1.000000 1.000000 \n",
"negsamp_incr_set 1.233551 3.007306 2.891718 \n",
"negsamp_incr_bsearch 1.274717 2.186638 2.392386 \n",
"negsamp_vectorized_bsearch 1.022648 20.735678 18.502182 \n",
"negsamp_vectorized_bsearch_preverif 1.079992 21.908844 19.691580 \n",
"\n",
" case3 case4 \n",
"negsamp_incr_naive 1.000000 1.000000 \n",
"negsamp_incr_set 3.168921 3.018263 \n",
"negsamp_incr_bsearch 2.238093 2.010704 \n",
"negsamp_vectorized_bsearch 242.499565 131.991317 \n",
"negsamp_vectorized_bsearch_preverif 255.305570 150.949921 "
]
},
"execution_count": 13,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Speedup of each strategy over the naive baseline (> 1 means faster);\n",
"# the baseline row is 1 by construction.\n",
"speedup_df = times_df.rdiv(times_df.loc['negsamp_incr_naive'], axis='columns')\n",
"speedup_df"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.4"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
You can’t perform that action at this time.