Skip to content

Instantly share code, notes, and snippets.

Created February 19, 2021 16:02
Show Gist options
  • Save pacarvalho/9774b1766b2ac34aee15b4420c463a0e to your computer and use it in GitHub Desktop.
Save pacarvalho/9774b1766b2ac34aee15b4420c463a0e to your computer and use it in GitHub Desktop.
Simple Lambda Collaborative Filter for Recommendation Engine
# Library used for training of the model
# NOTE: The environment variable SURPRISE_DATA_FOLDER must point to a writable
# directory for Surprise to use. Otherwise Surprise will fail, because the
# default Lambda home directory is read-only.
from surprise import SVD
from surprise import Dataset
from surprise import Reader
import os # Get environment variables
import psycopg2 # PSQL connector library
import csv # Format we will export ratings as
import pickle # Will be used to export model to S3
import boto3 # Required for S3 access
from botocore.vendored import requests # Make requests to server
from collections import defaultdict # Allows creating a dict with default values
# Collect environment variables.
# NOTE(review): only RDS_DB_NAME and BUCKET were present in the original text.
# The connection string below needs a host, user and password as well, so the
# three names RDS_HOST / RDS_USERNAME / RDS_PASSWORD are reconstructed here --
# confirm them against the Lambda's actual configuration before deploying.
RDS_HOST = os.environ['RDS_HOST']
RDS_USERNAME = os.environ['RDS_USERNAME']
RDS_PASSWORD = os.environ['RDS_PASSWORD']
RDS_DB_NAME = os.environ['RDS_DB_NAME']
BUCKET = os.environ['BUCKET']

# Connect to the database.
# Since this code is outside the handler it only executes during cold starts,
# so warm invocations reuse the same DB connection.
print('Creating connection to DB...')
# The original statement ended in a dangling line continuation with no
# operands for the format string; the argument tuple is restored here.
# NOTE: in a libpq key/value connection string, values containing spaces or
# quotes must be quoted/escaped -- keep that in mind for the password.
conn_string = "host=%s user=%s password=%s dbname=%s" % (
    RDS_HOST, RDS_USERNAME, RDS_PASSWORD, RDS_DB_NAME)
conn = psycopg2.connect(conn_string)

# Configure boto for use with S3
print('Setting up S3 connection...')
s3 = boto3.resource('s3')
bucket = s3.Bucket(BUCKET)
def lambda_handler(event, context):
    """Train an SVD collaborative filter and publish per-user predictions.

    The handler runs three steps:

    1. Read ``(user_id, item_id, rating)`` rows from the ``user_ratings``
       table and dump them to a CSV file under ``/tmp``.
    2. Load that CSV into Surprise and fit an ``SVD`` model on the full
       training set.
    3. For every user, predict a rating for each item the user has NOT
       rated, sort the predictions from best to worst, and upload the
       pickled list of ``[item_id, estimated_rating]`` pairs to S3 under
       the key ``predictions/<user_id>``.

    Args:
        event: Lambda invocation payload (unused).
        context: Lambda runtime context (unused).

    Returns:
        None.

    Raises:
        ValueError: if an item id cannot be converted with ``int()``; item
            ids are stored as integers in the uploaded predictions.

    Relies on the module-level ``conn`` (psycopg2 connection) and ``bucket``
    (boto3 S3 Bucket) created at cold start.
    """
    ratings_file_path = '/tmp/user-item-ratings.csv'
    query = "select user_id, item_id, rating from user_ratings"

    # STEP 1: Get Data
    # `with conn` ends the transaction (commit on success, rollback on error)
    # without closing the connection, so the persistent connection is not left
    # "idle in transaction" -- or stuck in an aborted transaction -- between
    # warm invocations.
    with conn, conn.cursor() as cur:
        # Get data from SQL
        print('Query the DB for analytics_events data...')
        cur.execute(query)  # Make the query

        # Results are stored in cur as an iterable of (user_id, item_id, rating)
        # NOTE: If your data does not already consist of ratings you may need to
        # reduce it. For instance, if instead you have viewing time, impressions,
        # clicks, etc. you would determine what types of interactions are related
        # to more positive versus more negative ratings and sum them up to a
        # single rating score for each pair of user and item.

        # Save ratings as a CSV so we can pass it to Surprise
        # TODO: We are currently saving to disk since the Surprise interface
        # expects a file path. Ideally we would keep it in memory since the
        # Lambda has a higher RAM vs Disk limit.
        print('Save ratings to CSV...')
        with open(ratings_file_path, mode='w') as ratings_file:
            ratings_writer = csv.writer(ratings_file, delimiter=',')
            for user_id, item_id, rating in cur:
                ratings_writer.writerow([user_id, item_id, rating])

    # STEP 2: Train
    print('Configure model for training...')
    reader = Reader(line_format='user item rating', sep=',')
    data = Dataset.load_from_file(ratings_file_path, reader=reader)
    algo = SVD()

    print('Train model...')
    trainset = data.build_full_trainset()
    # The model must be fitted before predict() is called; this call was
    # missing from the original text.
    algo.fit(trainset)

    # STEP 3: Predict
    # Create list of unseen items for each user (recommendations) based on
    # their predicted rating
    print('Creating predictions for each user...')
    print('Number of users:', trainset.n_users)
    print('Number of items:', trainset.n_items)

    n_users = trainset.n_users
    # Log progress roughly every 5%. Clamp to 1 so the modulo below cannot
    # divide by zero when there are fewer than 20 users.
    progress_step = max(1, n_users // 20)

    # The inner-id -> raw-id mapping is the same for every user, so resolve
    # it once instead of once per user.
    all_items = [(inner_iid, trainset.to_raw_iid(inner_iid))
                 for inner_iid in trainset.all_items()]

    for user_index in range(n_users):
        # All items this user HAS rated (inner item ids)
        user_items = {j for (j, _) in trainset.ur[user_index]}

        # Get the actual user_id from index
        user_id = trainset.to_raw_uid(user_index)

        # Make predictions for all items this user has NOT rated
        predictions_for_items_not_rated_by_user = []
        for inner_iid, iid in all_items:
            if inner_iid in user_items:
                continue
            prediction = algo.predict(user_id, iid)
            predictions_for_items_not_rated_by_user.append(
                [int(prediction.iid), prediction.est])

        # Sort the prediction from best match to worst match
        predictions_for_items_not_rated_by_user.sort(
            key=lambda x: x[1], reverse=True)

        # Save the results to S3
        # NOTE(review): the head of this call was lost in the original text;
        # `bucket.put_object(Body=pickle.dumps(...` is reconstructed from the
        # surviving tail and the module's pickle/boto3 setup -- confirm.
        bucket.put_object(
            Body=pickle.dumps(predictions_for_items_not_rated_by_user),
            Key='predictions/' + str(user_id))

        # Print percent complete so it's easier to track execution of the
        # algorithm in logs. The percentage is derived from the index rather
        # than incremented by 5, so it stays accurate for any user count.
        if user_index % progress_step == 0:
            percent_complete = 100 * user_index // n_users
            print('Percent complete: ', percent_complete, '%')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment