Skip to content

Instantly share code, notes, and snippets.

@jstemerdink
Created May 28, 2019 07:52
Show Gist options
  • Save jstemerdink/81b1f9245b7d5adf2e1e004968643a07 to your computer and use it in GitHub Desktop.
Save jstemerdink/81b1f9245b7d5adf2e1e004968643a07 to your computer and use it in GitHub Desktop.

Use the ML.NET recommender for better upsell suggestions (proof of concept)

Read my blog here

Powered by ReSharper image

namespace Epi.Libraries.Commerce.Predictions
{
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Threading;
using System.Web;
using EPiServer.Commerce.Catalog.ContentTypes;
using EPiServer.Commerce.Catalog.Linking;
using EPiServer.Commerce.Order;
using EPiServer.Core;
using EPiServer.ServiceLocation;
using Microsoft.ML;
using Microsoft.ML.Data;
using Microsoft.ML.Trainers;
[ServiceConfiguration(typeof(PredictionEngineService), Lifecycle = ServiceInstanceScope.Singleton)]
public class PredictionEngineService
{
    /// <summary>
    /// Application-relative path of the persisted recommendation model.
    /// </summary>
    private const string ModelVirtualPath = "~/App_Data/recommendations_model.zip";

    /// <summary>
    /// Number of up-sell items returned by <see cref="GetUpSellItems"/> when no explicit limit is passed.
    /// </summary>
    private const int DefaultMaxUpSellItems = 6;

    private readonly IAssociationRepository associationRepository;

    /// <summary>
    /// Guards the model file on disk: a read lock while loading it, a write lock while saving it.
    /// </summary>
    private readonly ReaderWriterLockSlim cacheLock = new ReaderWriterLockSlim();

    private readonly ServiceAccessor<HttpContextBase> httpContextAccessor;

    private readonly IOrderSearchService orderSearchService;

    /// <summary>
    /// Serializes every use and replacement of <see cref="predictionEngine"/>. An ML.NET
    /// PredictionEngine is not thread-safe, and this service is registered as a singleton
    /// that is shared by all concurrent requests.
    /// </summary>
    private readonly object predictionLock = new object();

    /// <summary>
    /// MLContext shared across the model creation workflow objects.
    /// </summary>
    private MLContext mlContext;

    private string modelPath;

    private ITransformer modelTransformer;

    private PredictionEngine<ProductEntry, CoPurchasePrediction> predictionEngine;

    public PredictionEngineService(
        ServiceAccessor<HttpContextBase> httpContextAccessor,
        IOrderSearchService orderSearchService,
        IAssociationRepository associationRepository)
    {
        this.httpContextAccessor = httpContextAccessor;
        this.orderSearchService = orderSearchService;
        this.associationRepository = associationRepository;
    }

    /// <summary>
    /// Gets the best-scoring associated entries for the products in the first form of <paramref name="cart"/>.
    /// </summary>
    /// <param name="cart">The cart whose line items seed the recommendations.</param>
    /// <param name="maxItems">Maximum number of references to return (defaults to 6).</param>
    /// <returns>
    /// Distinct association target references, ordered by predicted score, highest first.
    /// The references are the association targets themselves, so they keep their catalog
    /// provider information.
    /// </returns>
    /// <exception cref="ArgumentNullException"><paramref name="cart"/> is null.</exception>
    /// <exception cref="InvalidOperationException"><see cref="Init"/> has not been called yet.</exception>
    public List<ContentReference> GetUpSellItems(ICart cart, int maxItems = DefaultMaxUpSellItems)
    {
        if (cart == null)
        {
            throw new ArgumentNullException(nameof(cart));
        }

        // One entry per target: several cart lines can point at the same product, and only
        // its best score should count.
        Dictionary<ContentReference, float> bestScores = new Dictionary<ContentReference, float>();

        foreach (ILineItem lineItem in cart.GetFirstForm().GetAllLineItems())
        {
            EntryContentBase entry = lineItem.GetEntryContent();
            if (entry == null)
            {
                // Guard in case the line item's catalog entry can no longer be resolved.
                continue;
            }

            foreach (KeyValuePair<ContentReference, float> candidate in this.GetUpSellRecommendations(
                         referenceToEntry: entry.ContentLink))
            {
                // Defensive: a NaN score carries no ranking information.
                if (float.IsNaN(candidate.Value))
                {
                    continue;
                }

                if (!bestScores.TryGetValue(candidate.Key, out float existing) || candidate.Value > existing)
                {
                    bestScores[candidate.Key] = candidate.Value;
                }
            }
        }

        return bestScores.OrderByDescending(pair => pair.Value)
            .Take(maxItems)
            .Select(pair => pair.Key)
            .ToList();
    }

    /// <summary>
    /// Initializes this instance. The persisted model is loaded from App_Data when present;
    /// otherwise a new model is trained from the purchase order history and saved.
    /// </summary>
    /// <exception cref="InvalidOperationException">No HTTP context is available to resolve the model path.</exception>
    /// <exception cref="LockRecursionException">The current thread already holds the model file lock.</exception>
    public void Init()
    {
        HttpContextBase httpContext = this.httpContextAccessor();
        if (httpContext == null)
        {
            throw new InvalidOperationException("Init requires an HTTP context to resolve the model path.");
        }

        this.mlContext = new MLContext();
        this.modelPath = httpContext.Server.MapPath(ModelVirtualPath);

        ITransformer transformer;
        if (File.Exists(path: this.modelPath))
        {
            this.cacheLock.EnterReadLock();
            try
            {
                transformer = this.mlContext.Model.Load(
                    filePath: this.modelPath,
                    inputSchema: out DataViewSchema _);
            }
            finally
            {
                this.cacheLock.ExitReadLock();
            }
        }
        else
        {
            transformer = this.LoadDataAndTrain(products: this.GetProductEntries());
        }

        this.SetModel(transformer);
    }

    /// <summary>
    /// Retrains the model from the current order history, saves it, and swaps in the new
    /// prediction engine.
    /// </summary>
    /// <exception cref="InvalidOperationException"><see cref="Init"/> has not been called yet.</exception>
    public void UpdateModel()
    {
        if (this.mlContext == null || this.modelPath == null)
        {
            throw new InvalidOperationException("Init must be called before UpdateModel.");
        }

        this.SetModel(this.LoadDataAndTrain(products: this.GetProductEntries()));
    }

    /// <summary>
    /// Scores a single product pair with the current prediction engine.
    /// </summary>
    /// <exception cref="InvalidOperationException"><see cref="Init"/> has not been called yet.</exception>
    private ProductCoPurchasePrediction GetPrediction(int productId, int coPurchaseProductId)
    {
        // NOTE(review): content IDs are cast to uint and used as ML.NET key values. IDs outside
        // the [KeyType] range declared on ProductEntry are presumably not representable; confirm this.
        ProductEntry input = new ProductEntry
        {
            ProductId = (uint)productId,
            CoPurchaseProductId = (uint)coPurchaseProductId
        };

        CoPurchasePrediction prediction;
        lock (this.predictionLock)
        {
            if (this.predictionEngine == null)
            {
                throw new InvalidOperationException("Init must be called before requesting predictions.");
            }

            prediction = this.predictionEngine.Predict(input);
        }

        return new ProductCoPurchasePrediction
        {
            ProductId = productId, CoPurchaseProductId = coPurchaseProductId, Score = prediction.Score
        };
    }

    /// <summary>
    /// Builds a prediction engine for <paramref name="transformer"/> and atomically replaces the current one.
    /// </summary>
    private void SetModel(ITransformer transformer)
    {
        PredictionEngine<ProductEntry, CoPurchasePrediction> engine =
            this.mlContext.Model.CreatePredictionEngine<ProductEntry, CoPurchasePrediction>(
                transformer: transformer);

        PredictionEngine<ProductEntry, CoPurchasePrediction> previous;
        lock (this.predictionLock)
        {
            previous = this.predictionEngine;
            this.modelTransformer = transformer;
            this.predictionEngine = engine;
        }

        // Every Predict call runs under predictionLock against the current field, so the old
        // engine is no longer in use once it has been swapped out.
        (previous as IDisposable)?.Dispose();
    }

    /// <summary>
    /// Builds co-purchase training pairs from purchase orders. For each order, the first
    /// resolvable line item is paired with every other resolvable line item of that order.
    /// </summary>
    private List<ProductEntry> GetProductEntries()
    {
        List<ProductEntry> products = new List<ProductEntry>();

        // NOTE(review): this uses a default OrderSearchFilter. Confirm that it returns the full
        // order history rather than a single page of results.
        IEnumerable<IPurchaseOrder> orders =
            this.orderSearchService.FindPurchaseOrders(new OrderSearchFilter()).Orders;

        foreach (IPurchaseOrder order in orders)
        {
            var form = order.Forms.FirstOrDefault();
            if (form == null)
            {
                continue;
            }

            // Resolve entries up front and drop line items whose entry cannot be loaded.
            List<EntryContentBase> entries = form.GetAllLineItems()
                .Select(lineItem => lineItem.GetEntryContent())
                .Where(entry => entry != null)
                .ToList();

            // A co-purchase needs at least two products in the same order.
            if (entries.Count <= 1)
            {
                continue;
            }

            uint productId = (uint)entries[0].ContentLink.ID;
            products.AddRange(
                entries.Skip(1).Select(
                    entry => new ProductEntry
                    {
                        ProductId = productId,
                        CoPurchaseProductId = (uint)entry.ContentLink.ID,
                        // One-class matrix factorization treats every observed pair as a
                        // positive entry with value 1. The default of 0 would train nothing useful.
                        Label = 1f
                    }));
        }

        return products;
    }

    /// <summary>
    /// Scores every association of <paramref name="referenceToEntry"/>.
    /// </summary>
    /// <returns>Pairs of association target reference and predicted score.</returns>
    private IEnumerable<KeyValuePair<ContentReference, float>> GetUpSellRecommendations(
        ContentReference referenceToEntry)
    {
        List<KeyValuePair<ContentReference, float>> scored = new List<KeyValuePair<ContentReference, float>>();
        foreach (Association association in this.ListAssociations(referenceToEntry: referenceToEntry))
        {
            if (ContentReference.IsNullOrEmpty(association.Source)
                || ContentReference.IsNullOrEmpty(association.Target))
            {
                continue;
            }

            ProductCoPurchasePrediction prediction = this.GetPrediction(
                productId: association.Source.ID,
                coPurchaseProductId: association.Target.ID);

            // Keep the original target reference so its provider information is preserved.
            scored.Add(new KeyValuePair<ContentReference, float>(association.Target, prediction.Score));
        }

        return scored;
    }

    /// <summary>
    /// Lists the catalog associations of the given entry.
    /// </summary>
    private IEnumerable<Association> ListAssociations(ContentReference referenceToEntry)
    {
        return this.associationRepository.GetAssociations(contentLink: referenceToEntry);
    }

    /// <summary>
    /// Trains a one-class matrix factorization model on the given co-purchase pairs and saves
    /// it to <see cref="modelPath"/>.
    /// </summary>
    /// <remarks>
    /// NOTE(review): with no training pairs (no orders with two or more line items), the
    /// behaviour of the ML.NET trainer is not handled here. Confirm whether it throws.
    /// </remarks>
    private ITransformer LoadDataAndTrain(List<ProductEntry> products)
    {
        // Load the in-memory co-purchase pairs. The schema is inferred from ProductEntry.
        IDataView trainData = this.mlContext.Data.LoadFromEnumerable(data: products);

        // Product IDs are declared as key columns on ProductEntry, so they serve directly as
        // matrix indices. One-class loss is used because only observed co-purchases exist.
        MatrixFactorizationTrainer.Options options = new MatrixFactorizationTrainer.Options
        {
            MatrixColumnIndexColumnName = nameof(ProductEntry.ProductId),
            MatrixRowIndexColumnName = nameof(ProductEntry.CoPurchaseProductId),
            LabelColumnName = nameof(ProductEntry.Label),
            LossFunction = MatrixFactorizationTrainer.LossFunctionType.SquareLossOneClass,
            Alpha = 0.01,
            Lambda = 0.025,
            ApproximationRank = 100,
            C = 0.00001
        };

        MatrixFactorizationTrainer trainer = this.mlContext.Recommendation().Trainers
            .MatrixFactorization(options: options);

        ITransformer model = trainer.Fit(input: trainData);

        // Persist the model so a later Init can load it instead of retraining.
        this.cacheLock.EnterWriteLock();
        try
        {
            this.mlContext.Model.Save(model: model, inputSchema: trainData.Schema, filePath: this.modelPath);
        }
        finally
        {
            this.cacheLock.ExitWriteLock();
        }

        return model;
    }
}
/// <summary>
/// Output row produced by the prediction engine for one (product, co-purchase product) pair.
/// </summary>
public class CoPurchasePrediction
{
    /// <summary>
    /// Gets or sets the predicted score for the pair; the service ranks candidates by this
    /// value, highest first.
    /// </summary>
    /// <value>The predicted score.</value>
    public float Score { get; set; }
}
/// <summary>
/// Input row for training and prediction: a product together with a product that was
/// bought in the same order.
/// </summary>
public class ProductEntry
{
    /// <summary>
    /// Gets or sets the ID of the product bought together with <see cref="ProductId"/>.
    /// Declared as an ML.NET key column with a key count of 262111.
    /// </summary>
    /// <value>The co-purchased product ID.</value>
    [KeyType(count: 262111)]
    public uint CoPurchaseProductId { get; set; }

    /// <summary>
    /// Gets or sets the label (matrix cell value) used for this pair during training.
    /// </summary>
    /// <value>The label.</value>
    public float Label { get; set; }

    /// <summary>
    /// Gets or sets the ID of the anchor product.
    /// Declared as an ML.NET key column with a key count of 262111.
    /// </summary>
    /// <value>The anchor product ID.</value>
    [KeyType(count: 262111)]
    public uint ProductId { get; set; }
}
/// <summary>
/// Scored recommendation for a product pair, carrying the content IDs and the model score.
/// </summary>
public class ProductCoPurchasePrediction
{
    /// <summary>
    /// Gets or sets the content ID of the recommended (co-purchased) product.
    /// </summary>
    /// <value>The co-purchased product's content ID.</value>
    public int CoPurchaseProductId { get; set; }

    /// <summary>
    /// Gets or sets the content ID of the product the recommendation was made for.
    /// </summary>
    /// <value>The anchor product's content ID.</value>
    public int ProductId { get; set; }

    /// <summary>
    /// Gets or sets the score the model predicted for this pair.
    /// </summary>
    /// <value>The predicted score.</value>
    public float Score { get; set; }
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment