Created
February 5, 2023 16:44
-
-
Save achad4/f88a6d52f2f26afd34a94d3aa1796e74 to your computer and use it in GitHub Desktop.
Example code for batch ML inference from DynamoDB using a global secondary index (GSI)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from datetime import datetime

import boto3
import numpy as np
from fastapi import FastAPI
from fastapi.concurrency import run_in_threadpool
# Web application that serves the prediction endpoint defined below.
app = FastAPI()

# Shared AWS handles, created once at import time. No region or credentials
# are passed, so boto3's default configuration chain supplies them.
dynamodb = boto3.resource(service_name="dynamodb")
s3 = boto3.client(service_name="s3")  # not referenced by the endpoint in this file
# Window (in epoch seconds) over which recent purchase volume is aggregated.
_LOOKBACK_SECONDS = 30 * 24 * 60 * 60


@app.get("/predict/{category_name}")
async def get_customer_ids_predicted_to_purchase(category_name: str):
    """Predict, per customer, whether they will purchase in the next 30 days.

    Candidates are the customers with at least one purchase in
    ``category_name``, found through the ``purchase_category`` index of the
    ``customer_purchase_ts`` table. Each candidate is scored on their total
    purchase volume over the last 30 days.

    Args:
        category_name: Value of ``purchase_category`` used to select customers.

    Returns:
        A list of ``{"customer_id": ..., "will_purchase_in_next_30d": ...}``
        dicts, one per candidate customer, in first-seen index order. The list
        is empty when no customer has purchased in the category.
    """
    # The boto3 calls and model inference below are blocking. Running them in
    # the threadpool keeps this ``async`` endpoint from stalling the event loop.
    return await run_in_threadpool(_predict_for_category, category_name)


def _predict_for_category(category_name: str) -> list:
    """Blocking implementation of ``get_customer_ids_predicted_to_purchase``."""
    table = dynamodb.Table("customer_purchase_ts")

    customer_ids = _customer_ids_for_category(table, category_name)
    if not customer_ids:
        # Nothing to score. Returning early also avoids calling predict()
        # on an empty feature matrix.
        return []

    # NOTE(review): assumes transaction_date_epoch is stored in epoch *seconds*.
    since_epoch = int(datetime.now().timestamp()) - _LOOKBACK_SECONDS

    # NOTE(review): this feature was previously named "avg_purchase_volume_30d",
    # but it has always been computed as a sum. The sum is kept here. Confirm
    # which one the "purchase_prediction" model was trained on.
    features = np.array(
        [[_total_purchase_volume(table, cid, since_epoch)] for cid in customer_ids],
        dtype=float,
    )  # shape (n_customers, 1): one row per customer, one feature column

    # ``load_pretrained_model`` is defined outside this file. It is presumed to
    # return an estimator with a scikit-learn style ``predict`` — TODO confirm.
    model = load_pretrained_model("purchase_prediction")
    # A single batched predict call. tolist() turns numpy values into plain
    # Python values that FastAPI can JSON-encode; numpy arrays cannot be encoded.
    predictions = np.asarray(model.predict(features)).ravel().tolist()

    return [
        {"customer_id": cid, "will_purchase_in_next_30d": pred}
        for cid, pred in zip(customer_ids, predictions)
    ]


def _customer_ids_for_category(table, category_name) -> list:
    """Return distinct customer ids having a purchase in ``category_name``.

    Ids keep their first-seen order from the index. Items without a
    ``customer_id`` attribute are skipped.
    """
    items = _query_all(
        table,
        IndexName="purchase_category",
        KeyConditionExpression="purchase_category = :category",
        ExpressionAttributeValues={":category": category_name},
    )
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    return list(
        dict.fromkeys(item["customer_id"] for item in items if "customer_id" in item)
    )


def _total_purchase_volume(table, customer_id, since_epoch: int) -> float:
    """Sum ``purchase_volume`` over ``customer_id``'s purchases after ``since_epoch``."""
    # BatchGetItem cannot query an index or a key range, so a Query is used.
    # NOTE(review): ``feature_time_series`` is assumed to be keyed on
    # customer_id (hash) / transaction_date_epoch (range). Strongly consistent
    # reads are not allowed on a GSI, so none are requested.
    items = _query_all(
        table,
        IndexName="feature_time_series",
        KeyConditionExpression=(
            "customer_id = :customer_id"
            " AND transaction_date_epoch > :target_timestamp"
        ),
        ExpressionAttributeValues={
            ":customer_id": customer_id,
            ":target_timestamp": since_epoch,
        },
    )
    # The resource API returns numbers as Decimal; convert them to float for numpy.
    return sum(
        float(item["purchase_volume"])
        for item in items
        if item.get("purchase_volume") is not None
    )


def _query_all(table, **query_kwargs) -> list:
    """Run ``table.query(**query_kwargs)`` and return the items of every page.

    A single Query call returns at most 1 MB of data, so this follows
    ``LastEvaluatedKey`` until no further page is reported.
    """
    items = []
    while True:
        page = table.query(**query_kwargs)
        items.extend(page.get("Items", []))
        last_key = page.get("LastEvaluatedKey")
        if not last_key:
            return items
        query_kwargs["ExclusiveStartKey"] = last_key
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment