Skip to content

Instantly share code, notes, and snippets.

@qooba
Last active August 22, 2018 16:03
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save qooba/1765d58da9b99801c92e63ecbcec385d to your computer and use it in GitHub Desktop.
Save qooba/1765d58da9b99801c92e63ecbcec385d to your computer and use it in GitHub Desktop.
Another brick in the … recommendation system – Databricks in action. Read more: https://qooba.net/2018/08/22/another-brick-in-the-recommendation-system-databricks-in-action/
Display the source blob
Display the rendered blob
Raw
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 1. Read dataset "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from pyspark.mllib.recommendation import ALS\n",
"from sklearn.model_selection import train_test_split\n",
"import numpy as np\n",
"import pandas as pd\n",
"from itertools import islice\n",
"\n",
"def parseRating(line):\n",
"    \"\"\"Parse one comma-separated ratings line into a 3-tuple.\n",
"\n",
"    The first two fields are converted with int() and the third with float();\n",
"    any further fields on the line are ignored.\n",
"    \"\"\"\n",
"    fields = line.split(',')\n",
"    return (int(fields[0]), int(fields[1]), float(fields[2]))\n",
"\n",
"raw_lines = sc.textFile('/FileStore/tables/ratings.csv')\n",
"\n",
"# Skip the first line of partition 0, i.e. the first line of the file.\n",
"# NOTE(review): assumes ratings.csv starts with a header row -- confirm.\n",
"data_lines = raw_lines.mapPartitionsWithIndex(\n",
"    lambda idx, it: islice(it, 1, None) if idx == 0 else it\n",
")\n",
"\n",
"# Parse the header-free lines. Re-reading the file at this point (as the\n",
"# earlier version did) would silently discard the header filtering above\n",
"# and feed the header row to int() inside parseRating.\n",
"dataset = data_lines.map(parseRating).cache().toDF()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Split dataset into training and testing"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fixed seed so the train/test partition is the same on every run.\n",
"SPLIT_SEED = 42\n",
"\n",
"# Split on the cluster with randomSplit instead of collecting every row to the\n",
"# driver with toPandas() and masking it with an unseeded NumPy random array.\n",
"# As with a random mask, the 80/20 proportions are approximate, not exact.\n",
"training, testing = dataset.randomSplit([0.8, 0.2], seed=SPLIT_SEED)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Training the model "
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Fit the ALS model on the training rows.\n",
"model = ALS.train(training.rdd, rank=10, iterations=5)\n",
"\n",
"# Ask the model for a prediction on the first two fields of every test row.\n",
"test_pairs = testing.rdd.map(lambda row: (row[0], row[1]))\n",
"predictions = model.predictAll(test_pairs)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Model evaluation "
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"RMSE = 1.11032771802\n",
"R-squared = -0.108086892358\n"
]
}
],
"source": [
"from pyspark.mllib.evaluation import RegressionMetrics, RankingMetrics\n",
"\n",
"# Key both the predictions and the held-out rows by their (user, product) pair.\n",
"pred = predictions.map(lambda p: ((p.user, p.product), p.rating))\n",
"ratingsTuple = testing.rdd.map(lambda row: ((row[0], row[1]), row[2]))\n",
"\n",
"# Join on the key and keep only the (predicted, actual) value pair.\n",
"joined = pred.join(ratingsTuple)\n",
"scoreAndLabels = joined.map(lambda kv: kv[1])\n",
"metrics = RegressionMetrics(scoreAndLabels)\n",
"\n",
"# Root mean squared error\n",
"print(\"RMSE = %s\" % metrics.rootMeanSquaredError)\n",
"\n",
"# R-squared\n",
"print(\"R-squared = %s\" % metrics.r2)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 5. Predictions "
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" user product rating\n",
"0 648 3456 2.952487\n",
"1 529 3456 2.499555\n",
"2 452 3272 2.745604\n",
"3 665 4352 1.110651\n",
"4 30 4352 1.166530\n",
"5 48 52328 3.248486\n",
"6 574 52328 3.616721\n",
"7 615 52328 3.009266\n",
"8 516 464 3.756158\n",
"9 376 912 4.178127\n",
"10 520 912 5.047627\n",
"11 232 912 4.209946\n",
"12 153 912 4.830047\n",
"13 17 912 4.497674\n",
"14 569 912 4.253401\n",
"15 457 912 3.596744\n",
"16 466 912 4.339720\n",
"17 411 912 3.443644\n",
"18 236 912 4.236495\n",
"19 420 912 3.979551\n",
"20 220 912 4.167368\n",
"21 572 912 3.908388\n",
"22 620 912 4.297537\n",
"23 372 912 2.412261\n",
"24 580 912 3.645110\n",
"25 125 912 5.063882\n",
"26 37 912 5.244649\n",
"27 605 912 3.695372\n",
"28 222 912 3.780547\n",
"29 30 912 4.292419\n",
"... ... ... ...\n",
"19466 287 112175 4.505859\n",
"19467 663 120799 3.008335\n",
"19468 262 6535 0.669488\n",
"19469 294 6535 1.770202\n",
"19470 166 6535 2.479188\n",
"19471 439 5279 2.453580\n",
"19472 656 3751 4.332597\n",
"19473 624 3751 3.451944\n",
"19474 56 3751 4.078903\n",
"19475 48 3751 3.327549\n",
"19476 562 3751 3.968959\n",
"19477 195 3751 3.364107\n",
"19478 83 3751 3.161479\n",
"19479 115 3751 3.375012\n",
"19480 59 3751 1.274712\n",
"19481 547 3751 3.075101\n",
"19482 316 3751 3.024337\n",
"19483 268 3751 3.551662\n",
"19484 164 3751 3.336100\n",
"19485 468 3751 3.107545\n",
"19486 189 3751 3.604715\n",
"19487 381 3751 4.103798\n",
"19488 61 3751 1.806850\n",
"19489 582 3751 3.177963\n",
"19490 598 3751 3.730484\n",
"19491 174 3751 5.159532\n",
"19492 615 3751 3.757181\n",
"19493 463 3751 3.907354\n",
"19494 431 3751 3.411025\n",
"19495 239 3751 4.315567\n",
"\n",
"[19496 rows x 3 columns]\n",
"\n"
]
}
],
"source": [
"# Convert the prediction RDD to a Spark DataFrame, then collect it as pandas.\n",
"predictions_df = predictions.toDF()\n",
"pred_pd = predictions_df.toPandas()\n",
"pred_pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 6. User features "
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"\n",
" _1 _2\n",
"0 8 [-0.181914702058, -0.115650877357, -0.29060563...\n",
"1 16 [-0.539044797421, -0.199193432927, -0.18596178...\n",
"2 24 [-0.153028860688, 0.286775797606, -0.315734446...\n",
"3 32 [0.0522118471563, -0.880892157555, -0.11604796...\n",
"4 40 [0.197199463844, 0.246979385614, -0.0937066823...\n",
"5 48 [-0.187192514539, 0.122330382466, -0.447892934...\n",
"6 56 [0.266946703196, 0.609441876411, -0.5589547157...\n",
"7 64 [-0.967185676098, 0.900995850563, 0.3997882604...\n",
"8 72 [-0.107276752591, 0.159011214972, -0.362072616...\n",
"9 80 [-0.00726160453632, -0.791743159294, -0.400485...\n",
"10 88 [-0.287688970566, -0.455240428448, -0.35379630...\n",
"11 96 [-0.123236767948, -0.0736145153642, 0.22729793...\n",
"12 104 [-0.103782884777, 0.0293445084244, -0.48008337...\n",
"13 112 [-0.299286931753, 0.795587062836, -0.212774381...\n",
"14 120 [-0.392990142107, 0.170754164457, -0.592899322...\n",
"15 128 [-0.276408433914, 0.85479670763, -0.4386402964...\n",
"16 136 [-0.27713561058, 0.27391731739, 0.041793849319...\n",
"17 144 [-0.345366835594, 0.317992568016, 0.1295921951...\n",
"18 152 [-0.0377134233713, -0.057498164475, -0.0384886...\n",
"19 160 [-0.76272213459, -0.480974048376, -0.031201677...\n",
"20 168 [-0.28441041708, -0.2553345263, -0.35673257708...\n",
"21 176 [-0.176772579551, 0.337848335505, -0.521425485...\n",
"22 184 [-0.259108990431, 0.0231811460108, 0.038888446...\n",
"23 192 [-1.16283667088, 0.593379020691, -0.4730607867...\n",
"24 200 [0.440446376801, -0.224955514073, 0.0334046110...\n",
"25 208 [0.350734114647, -0.59681981802, -0.0968377292...\n",
"26 216 [0.106207422912, 0.474756985903, -0.0745117291...\n",
"27 224 [-0.0743150413036, 0.144918620586, -0.39639681...\n",
"28 232 [-0.380437016487, -0.286985069513, -0.36831775...\n",
"29 240 [-0.330195009708, 0.0716245099902, -0.37601268...\n",
".. ... ...\n",
"641 439 [-0.451001465321, -0.207361519337, 0.204217776...\n",
"642 447 [-0.382181704044, 0.259449094534, -0.210204839...\n",
"643 455 [-0.717461585999, 0.108905680478, 0.2277534306...\n",
"644 463 [-0.369250893593, 0.143394082785, -0.354474574...\n",
"645 471 [-0.0929622277617, 0.349240124226, -0.61510080...\n",
"646 479 [0.138411849737, -0.72600376606, 0.41511940956...\n",
"647 487 [-0.429432213306, 0.0936513096094, -0.05948500...\n",
"648 495 [-0.213055074215, 0.943114280701, 0.5440177917...\n",
"649 503 [0.362305760384, 0.21999822557, -0.27724897861...\n",
"650 511 [-1.07568562031, -0.377701282501, -0.509751379...\n",
"651 519 [0.445893526077, 0.419599086046, -0.2704966068...\n",
"652 527 [0.337932169437, -0.270141869783, -0.247990548...\n",
"653 535 [0.468528628349, -0.208868101239, 0.1047676354...\n",
"654 543 [-0.775096774101, 0.569828510284, -0.002466787...\n",
"655 551 [0.0927337184548, -0.00640812888741, -0.265042...\n",
"656 559 [-0.580280542374, -0.247369945049, -0.33804556...\n",
"657 567 [1.12319242954, -0.416586726904, -0.5191169381...\n",
"658 575 [-0.112690746784, 0.488440066576, -0.577874839...\n",
"659 583 [-0.417977333069, 0.0888268128037, -1.14079797...\n",
"660 591 [-0.0201117657125, 1.24624288082, -0.426221072...\n",
"661 599 [0.621458232403, 0.149647131562, -0.4135710895...\n",
"662 607 [-0.138034597039, 0.213772326708, -0.205611854...\n",
"663 615 [-0.314275950193, 0.0982936099172, -0.45499169...\n",
"664 623 [-0.274739950895, 0.232173085213, -0.317190140...\n",
"665 631 [0.138550639153, 0.338225275278, -0.0649613291...\n",
"666 639 [-0.257849574089, -0.284000188112, -0.25136265...\n",
"667 647 [-0.424806773663, -0.279336392879, -0.09299824...\n",
"668 655 [0.25625321269, 0.064479842782, -0.13081842660...\n",
"669 663 [-0.775534152985, 0.224992081523, -0.307271957...\n",
"670 671 [-0.302962034941, 0.358576714993, -0.411455988...\n",
"\n",
"[671 rows x 2 columns]\n",
"\n"
]
}
],
"source": [
"# Pull the model's user feature RDD into a pandas frame for inspection.\n",
"user_features_df = model.userFeatures().toDF()\n",
"user_features_pd = user_features_df.toPandas()\n",
"user_features_pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 7. Product features "
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"\n",
" _1 _2\n",
"0 8 [-2.56164264679, -0.135379731655, -0.267503380...\n",
"1 16 [0.280905783176, -0.0995275080204, -2.31120014...\n",
"2 24 [1.06595146656, 0.942460775375, -0.99149888753...\n",
"3 32 [0.0849908143282, -0.380194991827, -1.77220821...\n",
"4 40 [-0.641390919685, 1.06841874123, -2.0100176334...\n",
"5 48 [-0.630687594414, 0.611408829689, -1.172370195...\n",
"6 64 [-1.32728481293, 0.121022567153, -1.3647956848...\n",
"7 72 [1.01586747169, -0.781272828579, -1.1023261547...\n",
"8 80 [0.294988274574, -0.444449067116, -0.427629321...\n",
"9 88 [1.16558218002, 0.297061294317, -1.49852311611...\n",
"10 96 [0.0426225103438, 0.12978219986, -0.6232036352...\n",
"11 104 [0.0799032375216, 0.795982241631, -1.248673319...\n",
"12 112 [0.404516816139, 0.142588049173, -0.9666603803...\n",
"13 144 [0.0962482914329, -2.78304743767, -2.050385475...\n",
"14 152 [0.786001861095, -0.110578395426, -1.313220977...\n",
"15 160 [0.0527417510748, 0.651777744293, -0.327918887...\n",
"16 168 [-0.26722189784, -0.142951309681, -0.594646930...\n",
"17 176 [1.92572700977, 1.66081643105, -3.53873419762,...\n",
"18 200 [-0.7074046731, 0.283516049385, -1.30996775627...\n",
"19 208 [-0.249078258872, -0.269577711821, -0.85641849...\n",
"20 216 [-0.737537741661, -0.132424861193, -0.93438577...\n",
"21 224 [0.754742562771, 0.00107711227611, -2.26278805...\n",
"22 232 [0.757782459259, -0.389210581779, -1.381112933...\n",
"23 240 [-0.129541561007, -2.14392805099, -0.146269515...\n",
"24 248 [-0.683298707008, -0.281481802464, -0.68466645...\n",
"25 256 [-0.707969367504, 0.112456910312, -1.143353581...\n",
"26 264 [0.0845773145556, -0.0443584471941, -0.6866024...\n",
"27 272 [-1.10929083824, -0.663882791996, -1.035177469...\n",
"28 280 [-0.675240159035, -0.105041205883, -1.63121020...\n",
"29 288 [0.200414344668, -0.372297793627, -1.867403984...\n",
"... ... ...\n",
"8396 116503 [-0.0877871140838, 0.0784340277314, -0.4758511...\n",
"8397 116799 [0.218765199184, -0.288648873568, -0.278112500...\n",
"8398 116823 [-2.49026703835, 0.038705997169, -1.6413687467...\n",
"8399 116855 [-0.688780725002, 1.48978054523, -0.8829264044...\n",
"8400 117871 [-0.0253734719008, 0.980544626713, -1.68411159...\n",
"8401 117895 [0.115044265985, 1.16033554077, -1.36425220966...\n",
"8402 119655 [0.156621277332, 0.378675609827, -0.0999258607...\n",
"8403 120799 [-0.0284684654325, -1.36316335201, -2.31898665...\n",
"8404 121231 [0.249330818653, -2.58391785622, -2.3615202903...\n",
"8405 130351 [0.463936567307, 1.00956237316, -1.06439805031...\n",
"8406 134783 [-0.532292425632, 0.126817718148, -0.199420616...\n",
"8407 135567 [0.681578099728, -1.07554030418, -3.0175538063...\n",
"8408 135887 [0.123207136989, 0.053562708199, -0.2619900405...\n",
"8409 136447 [1.48365867138, 1.81369459629, -0.853097915649...\n",
"8410 139415 [0.469863861799, 1.13602674007, -0.29977759718...\n",
"8411 139855 [0.875060796738, -1.15459549427, -1.1124500036...\n",
"8412 140247 [0.463936567307, 1.00956237316, -1.06439805031...\n",
"8413 140711 [0.594856381416, -0.108498781919, -2.311982154...\n",
"8414 140743 [-0.860975861549, 1.86222565174, -1.1036579608...\n",
"8415 140751 [-0.860975861549, 1.86222565174, -1.1036579608...\n",
"8416 140759 [-0.860975861549, 1.86222565174, -1.1036579608...\n",
"8417 143255 [0.2349319309, 0.568013370037, -0.149888798594...\n",
"8418 145775 [0.0349026806653, -0.375773221254, -0.38171833...\n",
"8419 145839 [0.437530398369, -0.577297747135, -0.556225001...\n",
"8420 145935 [1.31259119511, -1.73189330101, -1.66867494583...\n",
"8421 152079 [0.656295597553, -0.865946650505, -0.834337472...\n",
"8422 156607 [-0.426517218351, 0.51908493042, -0.8460250496...\n",
"8423 157407 [0.2349319309, 0.568013370037, -0.149888798594...\n",
"8424 160271 [1.09382605553, -1.44324433804, -1.39056241512...\n",
"8425 160567 [1.75012159348, -2.30919098854, -2.22490000725...\n",
"\n",
"[8426 rows x 2 columns]\n",
"\n"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Pull the model's product feature RDD into a pandas frame for inspection.\n",
"product_features_df = model.productFeatures().toDF()\n",
"product_features_pd = product_features_df.toPandas()\n",
"product_features_pd"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 8. Export product features to JSON"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import requests\n",
"import json\n",
"import datetime\n",
"\n",
"pf_json=product_features_pd.to_json()\n",
"# Encode once so the declared Content-Length is a byte count, not a character count.\n",
"payload=pf_json.encode('utf-8')\n",
"\n",
"# Placeholders: fill in the storage account, path and token before running.\n",
"# NOTE(review): sas is appended verbatim to the URL, so it presumably has to\n",
"# include the leading '?' of the query string -- confirm. Never commit a real token.\n",
"sas='{blob_shared_key}'\n",
"url='https://{blob_name}.blob.core.windows.net/{path}/pf.json'\n",
"url_ok=url+sas\n",
"\n",
"headers={\n",
"    'x-ms-blob-type': 'BlockBlob',\n",
"    # Send the timestamp as RFC 1123 in GMT; str(datetime.now()) is local time\n",
"    # in ISO-like form, which is not the format this header is defined to use.\n",
"    'x-ms-date': datetime.datetime.utcnow().strftime('%a, %d %b %Y %H:%M:%S GMT'),\n",
"    'x-ms-version': '2014-02-14',\n",
"    'Content-Type': 'text/plain; charset=UTF-8',\n",
"    'Content-Length': str(len(payload))\n",
"}\n",
"r =requests.put(url=url_ok,data=payload,headers=headers)\n",
"# Surface a rejected upload as an exception instead of ignoring the response.\n",
"r.raise_for_status()\n"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.5.4"
},
"name": "Recommendation",
"notebookId": 1
},
"nbformat": 4,
"nbformat_minor": 1
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment