Skip to content

Instantly share code, notes, and snippets.

@ita9naiwa
Created October 27, 2020 14:11
Show Gist options
  • Save ita9naiwa/b328c43508193611a83c07ae0553a9f3 to your computer and use it in GitHub Desktop.
Save ita9naiwa/b328c43508193611a83c07ae0553a9f3 to your computer and use it in GitHub Desktop.
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import implicit\n",
"import pickle"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from implicit.cml import CollaborativeMetricLearning\n",
"from implicit.als import AlternatingLeastSquares\n",
"from implicit.lmf import LogisticMatrixFactorization\n",
"from implicit.evaluation import *\n",
"from implicit.datasets.sketchfab import get_sketchfab\n",
"from implicit.datasets.movielens import get_movielens"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"#_, _, mat = get_sketchfab()\n",
"_, mat = get_movielens(variant='1m')"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [],
"source": [
"seed=1541"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
"mat.data[:] = 1"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"tr, te = train_test_split(mat, 0.8)"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"scrolled": true
},
"outputs": [],
"source": [
"m2 = CollaborativeMetricLearning(factors=64, \n",
" threshold=1.0,\n",
" learning_rate=0.1, \n",
" iterations=15, \n",
" num_threads=8, \n",
" regularization=0.00,\n",
" neg_sampling=100,\n",
" random_state=seed)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 15/15 [00:40<00:00, 2.79s/it]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "0c6385a9bd6f426f8012661292f6b607",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"{'precision': 0.2958226060958392, 'map': 0.20293504045410232, 'ndcg': 0.28618070893351455, 'auc': 0.5181144038465575}\n"
]
}
],
"source": [
"m2.fit(tr.T, True)\n",
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {},
"outputs": [],
"source": [
"m2 = CollaborativeMetricLearning(factors=64, \n",
" threshold=1.0,\n",
" learning_rate=0.1, \n",
" iterations=15, \n",
" num_threads=8, \n",
" regularization=0.01,\n",
" neg_sampling=100,\n",
" random_state=seed)"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 15/15 [00:43<00:00, 3.08s/it]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "4b94fa55352f4e7bba4a0ec9c9bff749",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"{'precision': 0.33095257868947764, 'map': 0.23899023594696514, 'ndcg': 0.32266404435662915, 'auc': 0.5204181639238991}\n"
]
}
],
"source": [
"m2.fit(tr.T, True)\n",
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {},
"outputs": [],
"source": [
"m2 = CollaborativeMetricLearning(factors=64, \n",
" threshold=1.0,\n",
" learning_rate=0.3, \n",
" iterations=15, \n",
" num_threads=8, \n",
" regularization=0.03,\n",
" neg_sampling=100,\n",
" random_state=seed)"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 15/15 [01:03<00:00, 4.25s/it]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "74a1d29d18034cdfbcd2d1292e5de7a7",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"{'precision': 0.37646374885806827, 'map': 0.28215497521921507, 'ndcg': 0.3662162875887344, 'auc': 0.5236010609909335}\n"
]
}
],
"source": [
"m2.fit(tr.T, True)\n",
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"m2 = CollaborativeMetricLearning(factors=64, \n",
" threshold=1.0,\n",
" learning_rate=0.1, \n",
" iterations=15, \n",
" num_threads=8, \n",
" regularization=0.05,\n",
" neg_sampling=100,\n",
" random_state=seed)"
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 15/15 [00:54<00:00, 3.70s/it]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "641d0792eeb0442b9dc36aeb7f2b2601",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"{'precision': 0.37380616227888047, 'map': 0.2818614182234088, 'ndcg': 0.36362742743417237, 'auc': 0.5222198795750684}\n"
]
}
],
"source": [
"m2.fit(tr.T, True)\n",
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"m2 = CollaborativeMetricLearning(factors=64, \n",
" threshold=1.0,\n",
" learning_rate=0.1, \n",
" iterations=15, \n",
" num_threads=8, \n",
" regularization=0.1,\n",
" neg_sampling=100,\n",
" random_state=seed)"
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"100%|██████████| 15/15 [00:55<00:00, 3.85s/it]\n"
]
},
{
"data": {
"application/vnd.jupyter.widget-view+json": {
"model_id": "acc87622b9974a4493f5257591f6ec11",
"version_major": 2,
"version_minor": 0
},
"text/plain": [
"HBox(children=(IntProgress(value=0, max=3000), HTML(value='')))"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
"{'precision': 0.3812806245328461, 'map': 0.28596698013301086, 'ndcg': 0.36791157480005765, 'auc': 0.5231909268136861}\n"
]
}
],
"source": [
"m2.fit(tr.T, True)\n",
"print(ranking_metrics_at_k(m2, tr[:3000], te[:3000], 5))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.3"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment