Read my blog here
Created
May 28, 2019 07:52
-
-
Save jstemerdink/81b1f9245b7d5adf2e1e004968643a07 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
namespace Epi.Libraries.Commerce.Predictions | |
{ | |
using System.Collections.Generic; | |
using System.IO; | |
using System.Linq; | |
using System.Threading; | |
using System.Web; | |
using EPiServer.Commerce.Catalog.ContentTypes; | |
using EPiServer.Commerce.Catalog.Linking; | |
using EPiServer.Commerce.Order; | |
using EPiServer.Core; | |
using EPiServer.ServiceLocation; | |
using Microsoft.ML; | |
using Microsoft.ML.Data; | |
using Microsoft.ML.Trainers; | |
[ServiceConfiguration(typeof(PredictionEngineService), Lifecycle = ServiceInstanceScope.Singleton)]
public class PredictionEngineService
{
    /// <summary>
    /// Label given to every observed co-purchase pair in the training data.
    /// With <see cref="MatrixFactorizationTrainer.LossFunctionType.SquareLossOneClass"/>, ML.NET fits unobserved
    /// matrix entries towards <c>Options.C</c> (0.00001 here). If observed pairs kept the default label of 0,
    /// the model could not distinguish them from unobserved pairs.
    /// </summary>
    private const float ObservedCoPurchaseLabel = 1f;

    private readonly IAssociationRepository associationRepository;

    /// <summary>
    /// Serializes access to the model file: loading takes the read lock, saving takes the write lock.
    /// </summary>
    private readonly ReaderWriterLockSlim cacheLock = new ReaderWriterLockSlim();

    private readonly ServiceAccessor<HttpContextBase> httpContextAccessor;

    private readonly IOrderSearchService orderSearchService;

    /// <summary>
    /// Guards <see cref="predictionEngine"/> and <see cref="modelTransformer"/>.
    /// ML.NET's PredictionEngine is not thread-safe, and this service is registered as a singleton,
    /// so concurrent requests must not call Predict at the same time or observe a half-swapped model.
    /// </summary>
    private readonly object predictionSync = new object();

    /// <summary>
    /// MLContext shared across the model creation workflow objects; created by <see cref="Init"/>.
    /// </summary>
    private MLContext mlContext;

    /// <summary>
    /// Physical path of the persisted model (~/App_Data/recommendations_model.zip); set by <see cref="Init"/>.
    /// </summary>
    private string modelPath;

    private ITransformer modelTransformer;

    private PredictionEngine<ProductEntry, CoPurchasePrediction> predictionEngine;

    public PredictionEngineService(
        ServiceAccessor<HttpContextBase> httpContextAccessor,
        IOrderSearchService orderSearchService,
        IAssociationRepository associationRepository)
    {
        this.httpContextAccessor = httpContextAccessor;
        this.orderSearchService = orderSearchService;
        this.associationRepository = associationRepository;
    }

    /// <summary>
    /// Gets up-sell recommendations for the line items in the first form of <paramref name="cart"/>.
    /// Each line item's catalog associations are scored by the co-purchase model.
    /// </summary>
    /// <param name="cart">The cart to recommend products for.</param>
    /// <param name="maxItems">Maximum number of recommendations to return. Defaults to 6.</param>
    /// <returns>
    /// References to the highest-scoring associated products, best first.
    /// A product associated with several cart items appears once, with its best score.
    /// </returns>
    /// <exception cref="System.ArgumentNullException"><paramref name="cart"/> is <c>null</c>.</exception>
    /// <exception cref="System.InvalidOperationException">
    /// A prediction is needed but <see cref="Init"/> has not been called.
    /// </exception>
    public List<ContentReference> GetUpSellItems(ICart cart, int maxItems = 6)
    {
        if (cart == null)
        {
            throw new System.ArgumentNullException(nameof(cart));
        }

        List<ProductCoPurchasePrediction> predictions = new List<ProductCoPurchasePrediction>();

        foreach (ILineItem lineItem in cart.GetFirstForm().GetAllLineItems())
        {
            // Skip line items whose catalog entry cannot be resolved (GetEntryContent may return null).
            EntryContentBase entry = lineItem.GetEntryContent();
            if (entry == null)
            {
                continue;
            }

            predictions.AddRange(this.GetUpSellRecommendations(referenceToEntry: entry.ContentLink));
        }

        // Several cart items can be associated with the same product. Keep only its best score so the
        // same recommendation does not occupy more than one of the maxItems slots.
        return predictions
            .GroupBy(p => p.CoPurchaseProductId)
            .Select(group => group.OrderByDescending(p => p.Score).First())
            .OrderByDescending(p => p.Score)
            .Take(maxItems)
            .Select(p => new ContentReference(contentID: p.CoPurchaseProductId))
            .ToList();
    }

    /// <summary>
    /// Initializes this instance.
    /// If the model file exists, the persisted model is loaded from it.
    /// Otherwise a new model is trained from the purchase orders and saved to that path.
    /// In both cases the prediction engine is then created.
    /// Must be called before <see cref="GetUpSellItems"/> or <see cref="UpdateModel"/>.
    /// </summary>
    /// <remarks>
    /// The model path is resolved through the current <see cref="HttpContextBase"/>.
    /// NOTE(review): this assumes an HTTP context is available when Init runs — confirm for the caller.
    /// </remarks>
    public void Init()
    {
        this.mlContext = new MLContext();

        HttpContextBase httpContext = this.httpContextAccessor();
        this.modelPath = httpContext.Server.MapPath("~/App_Data/recommendations_model.zip");

        ITransformer transformer;

        if (File.Exists(path: this.modelPath))
        {
            // Read lock: LoadDataAndTrain writes the same file under the write lock.
            this.cacheLock.EnterReadLock();
            try
            {
                transformer = this.mlContext.Model.Load(
                    filePath: this.modelPath,
                    inputSchema: out DataViewSchema _);
            }
            finally
            {
                this.cacheLock.ExitReadLock();
            }
        }
        else
        {
            transformer = this.LoadDataAndTrain(products: this.GetProductEntries());
        }

        this.ActivateModel(transformer);
    }

    /// <summary>
    /// Updates the model.
    /// The model is retrained from the current purchase orders and saved to the model path.
    /// Predictions are then switched over to the new model.
    /// </summary>
    /// <exception cref="System.InvalidOperationException"><see cref="Init"/> has not been called.</exception>
    public void UpdateModel()
    {
        if (this.mlContext == null || this.modelPath == null)
        {
            throw new System.InvalidOperationException("Init must be called before UpdateModel.");
        }

        ITransformer transformer = this.LoadDataAndTrain(products: this.GetProductEntries());
        this.ActivateModel(transformer);
    }

    /// <summary>
    /// Creates a prediction engine for <paramref name="transformer"/>.
    /// The engine and transformer are published together under <see cref="predictionSync"/>, so an
    /// in-flight prediction finishes on the old engine before the new one is used.
    /// </summary>
    private void ActivateModel(ITransformer transformer)
    {
        PredictionEngine<ProductEntry, CoPurchasePrediction> engine =
            this.mlContext.Model.CreatePredictionEngine<ProductEntry, CoPurchasePrediction>(
                transformer: transformer);

        lock (this.predictionSync)
        {
            this.modelTransformer = transformer;
            this.predictionEngine = engine;
        }
    }

    /// <summary>
    /// Scores a single (product, co-purchased product) pair with the current model.
    /// </summary>
    /// <exception cref="System.InvalidOperationException"><see cref="Init"/> has not been called.</exception>
    private ProductCoPurchasePrediction GetPrediction(int productId, int coPurchaseProductId)
    {
        ProductEntry input = new ProductEntry
        {
            ProductId = (uint)productId,
            CoPurchaseProductId = (uint)coPurchaseProductId
        };

        CoPurchasePrediction prediction;

        // PredictionEngine is not safe for concurrent use, and this service is a singleton.
        lock (this.predictionSync)
        {
            if (this.predictionEngine == null)
            {
                throw new System.InvalidOperationException("Init must be called before requesting predictions.");
            }

            prediction = this.predictionEngine.Predict(input);
        }

        return new ProductCoPurchasePrediction
        {
            ProductId = productId,
            CoPurchaseProductId = coPurchaseProductId,
            Score = prediction.Score
        };
    }

    /// <summary>
    /// Builds the training set from the purchase orders returned by the order search service.
    /// An order qualifies when its first form has at least two line items that resolve to catalog entries.
    /// For each qualifying order, the first resolvable item is paired with every other item.
    /// Each pair is labelled as an observed co-purchase.
    /// </summary>
    /// <remarks>
    /// NOTE(review): <see cref="ProductEntry"/> declares both ids as keys with count 262111.
    /// Content ids beyond that are not filtered here — confirm the catalog stays within that range.
    /// </remarks>
    private List<ProductEntry> GetProductEntries()
    {
        List<ProductEntry> products = new List<ProductEntry>();
        IEnumerable<IPurchaseOrder> orders =
            this.orderSearchService.FindPurchaseOrders(new OrderSearchFilter()).Orders;

        foreach (IPurchaseOrder order in orders)
        {
            IOrderForm form = order.Forms.FirstOrDefault();
            if (form == null)
            {
                continue;
            }

            // Drop line items whose catalog entry cannot be resolved instead of failing on them.
            List<EntryContentBase> entries = form.GetAllLineItems()
                .Select(lineItem => lineItem.GetEntryContent())
                .Where(entry => entry != null)
                .ToList();

            if (entries.Count <= 1)
            {
                continue;
            }

            uint productId = (uint)entries[0].ContentLink.ID;
            products.AddRange(
                entries.Skip(1).Select(
                    entry => new ProductEntry
                    {
                        ProductId = productId,
                        CoPurchaseProductId = (uint)entry.ContentLink.ID,
                        Label = ObservedCoPurchaseLabel
                    }));
        }

        return products;
    }

    /// <summary>
    /// Scores every catalog association of <paramref name="referenceToEntry"/> with the co-purchase model.
    /// </summary>
    private IEnumerable<ProductCoPurchasePrediction> GetUpSellRecommendations(ContentReference referenceToEntry)
    {
        IEnumerable<Association> associations = this.ListAssociations(referenceToEntry: referenceToEntry);
        List<ProductCoPurchasePrediction> predictions = associations.Select(
            association => this.GetPrediction(
                productId: association.Source.ID,
                coPurchaseProductId: association.Target.ID)).ToList();
        return predictions;
    }

    /// <summary>
    /// Returns the associations of <paramref name="referenceToEntry"/> from the association repository.
    /// </summary>
    private IEnumerable<Association> ListAssociations(ContentReference referenceToEntry)
    {
        IEnumerable<Association> associations =
            this.associationRepository.GetAssociations(contentLink: referenceToEntry);
        return associations;
    }

    /// <summary>
    /// Trains a one-class matrix factorization model on <paramref name="products"/>.
    /// The trained model is saved to the model path under the write lock.
    /// </summary>
    /// <param name="products">Observed co-purchase pairs.</param>
    /// <returns>The trained model.</returns>
    private ITransformer LoadDataAndTrain(List<ProductEntry> products)
    {
        // Wrap the in-memory co-purchase pairs as an IDataView.
        // The id columns are already key-typed through the KeyType attributes on ProductEntry.
        IDataView trainData = this.mlContext.Data.LoadFromEnumerable(data: products);

        // Only observed co-purchases exist, so use the one-class loss. Per the ML.NET docs:
        // - Alpha weights the unobserved entries.
        // - Lambda is the regularization coefficient.
        // - C is the value assumed for unobserved entries.
        // - ApproximationRank is the rank of the factor matrices.
        MatrixFactorizationTrainer.Options options = new MatrixFactorizationTrainer.Options
        {
            MatrixColumnIndexColumnName = nameof(ProductEntry.ProductId),
            MatrixRowIndexColumnName = nameof(ProductEntry.CoPurchaseProductId),
            LabelColumnName = nameof(ProductEntry.Label),
            LossFunction = MatrixFactorizationTrainer.LossFunctionType.SquareLossOneClass,
            Alpha = 0.01,
            Lambda = 0.025,
            ApproximationRank = 100,
            C = 0.00001
        };

        MatrixFactorizationTrainer trainer = this.mlContext.Recommendation().Trainers
            .MatrixFactorization(options: options);

        ITransformer model = trainer.Fit(input: trainData);

        // Write lock: Init may be reading the same file under the read lock.
        this.cacheLock.EnterWriteLock();
        try
        {
            this.mlContext.Model.Save(model: model, inputSchema: trainData.Schema, filePath: this.modelPath);
        }
        finally
        {
            this.cacheLock.ExitWriteLock();
        }

        return model;
    }
}
/// <summary>
/// Output row produced by the co-purchase prediction engine.
/// Its values are filled in by <c>PredictionEngine.Predict</c>.
/// </summary>
public class CoPurchasePrediction
{
    private float score;

    /// <summary>
    /// Gets or sets the model's co-purchase score for the scored product pair.
    /// Higher scores rank earlier in the up-sell list.
    /// </summary>
    /// <value>The predicted score.</value>
    public float Score
    {
        get { return this.score; }
        set { this.score = value; }
    }
}
/// <summary>
/// One row for the co-purchase recommender.
/// Each row holds a (product, co-purchased product) pair and its label.
/// It is used both for training and as the prediction input.
/// </summary>
public class ProductEntry
{
    /// <summary>
    /// Key count declared for both id columns through <c>KeyType</c>.
    /// It is kept in one place so the two columns cannot drift apart.
    /// NOTE(review): content ids beyond this count may fall outside the declared key range — confirm the catalog's ids stay within it.
    /// </summary>
    private const ulong ProductKeyCount = 262111;

    /// <summary>
    /// Gets or sets the id of the product bought together with <see cref="ProductId"/>.
    /// </summary>
    /// <value>The co-purchased product's content id, as a key value.</value>
    [KeyType(ProductKeyCount)]
    public uint CoPurchaseProductId { get; set; }

    /// <summary>
    /// Gets or sets the label for this pair.
    /// It is configured as the label column when the model is trained.
    /// </summary>
    /// <value>The label; defaults to 0 unless set by the code that builds the training data.</value>
    public float Label { get; set; }

    /// <summary>
    /// Gets or sets the id of the source product.
    /// </summary>
    /// <value>The source product's content id, as a key value.</value>
    [KeyType(ProductKeyCount)]
    public uint ProductId { get; set; }
}
/// <summary>
/// A scored recommendation.
/// It records the source product, the associated product, and the model's co-purchase score for the pair.
/// </summary>
public class ProductCoPurchasePrediction
{
    /// <summary>
    /// Gets or sets the content id of the product the recommendation was computed for.
    /// </summary>
    /// <value>The source product id.</value>
    public int ProductId { get; set; }

    /// <summary>
    /// Gets or sets the content id of the recommended (associated) product.
    /// </summary>
    /// <value>The recommended product id.</value>
    public int CoPurchaseProductId { get; set; }

    /// <summary>
    /// Gets or sets the model score for this pair; recommendations are ranked by it, highest first.
    /// </summary>
    /// <value>The co-purchase score.</value>
    public float Score { get; set; }
}
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment