Forked from MansurAshraf/ProductRecommendationJob.scala
Created April 2, 2013 15:05
Save samklr/5292937 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package com.mansur.scalding | |
import com.twitter.scalding._ | |
import org.apache.lucene.search.spell._ | |
import org.apache.mahout.common.distance.TanimotoDistanceMeasure | |
import org.apache.mahout.math.DenseVector | |
import org.apache.commons.math.util.MathUtils | |
/**
 * Scalding job that recommends, for every product in a catalog, the three most similar
 * products from the same department and sub-department.
 *
 * Similarity is expressed as a combined distance (lower means more similar): the Tanimoto
 * distance between the products' (regular price, sale price) vectors plus one minus the
 * Lucene n-gram similarity of their descriptions.
 *
 * Job arguments:
 *  - `--input`  path of the catalog CSV (defaults to `input`)
 *  - `--output` path of the recommendations CSV (defaults to `output`)
 *
 * @author Muhammad Ashraf
 * @since 3/29/13
 */
class ProductRecommendationJob(args: Args) extends Job(args) {
  val tanimotoDistanceMeasure = new TanimotoDistanceMeasure()
  val ngram = new NGramDistance()

  /** Number of decimal places every distance component (and their sum) is rounded to. */
  val SCALE = 5

  /** Number of recommendations emitted per product. */
  val RECOMMENDATIONS_PER_PRODUCT = 3

  /** Location of the catalog CSV; override with `--input`. */
  val inputPath: String = args.getOrElse("input", "input")

  /** Location the recommendations are written to; override with `--output`. */
  val outputPath: String = args.getOrElse("output", "output")

  /*
   * Schema of our product catalog.
   */
  val inputSchema = ('DEPARTMENT, 'SUB_DEPARTMENT, 'PRODUCT, 'DESCRIPTION, 'REG_PRICE, 'SALE_PRICE)

  /*
   * Duplicate schema (every field suffixed with `1`) used for the right side of the self-join.
   */
  val renameSchema = ('DEPARTMENT1, 'SUB_DEPARTMENT1, 'PRODUCT1, 'DESCRIPTION1, 'REG_PRICE1, 'SALE_PRICE1)

  /** Output schema: product, recommended product, combined distance. */
  val outputSchema = ('PRODUCT, 'PRODUCT1, 'Distance)

  /** Original (misspelled) name of [[outputSchema]], kept so existing references still compile. */
  val outputSChema = outputSchema

  /** The catalog source; it is read once per side of the self-join. */
  val catalog = Csv(inputPath, separator = ",", fields = inputSchema, quote = "\"")

  /* Left side of the self-join. */
  val productMatrix = catalog.read

  /* Right side of the self-join, with every field renamed so the two sides do not collide. */
  val productMatrixDuplicate = catalog.read.rename(inputSchema -> renameSchema)

  productMatrix
    // Pair every product with every product in the same department and sub-department.
    .joinWithSmaller(('DEPARTMENT, 'SUB_DEPARTMENT) -> ('DEPARTMENT1, 'SUB_DEPARTMENT1), productMatrixDuplicate)
    // Drop pairs of a product with itself before paying for the distance computation.
    .filter(('PRODUCT, 'PRODUCT1)) { pair: (String, String) => pair._1 != pair._2 }
    // Reduce each joined row to (product, candidate, distance).
    .mapTo('* -> outputSchema) {
      row: (String, String, String, String, Double, Double, String, String, String, String, Double, Double) =>
        calculateDistance(row)
    }
    // Keep the closest candidates per product. The secondary sort on PRODUCT1 makes the
    // choice among equal (rounded) distances deterministic instead of shuffle-dependent.
    .groupBy('PRODUCT) { group =>
      group.sortBy(('Distance, 'PRODUCT1)).take(RECOMMENDATIONS_PER_PRODUCT)
    }
    .write(Csv(outputPath, separator = ",", fields = outputSchema))

  /**
   * Calculates the combined Tanimoto and n-gram distance between the two products of a
   * joined catalog row.
   *
   * Each component and their sum are rounded to [[SCALE]] decimal places. Lower values mean
   * more similar products.
   *
   * @param in joined row: (department, sub-department, product, description, regular price,
   *           sale price) of the left product followed by the same six values of the right one
   * @return (left product, right product, combined distance)
   */
  def calculateDistance(in: (String, String, String, String, Double, Double, String, String, String, String, Double, Double)): (String, String, Double) = {
    val (_, _, leftProduct, leftDescription, leftRegPrice, leftSalePrice,
         _, _, rightProduct, rightDescription, rightRegPrice, rightSalePrice) = in

    // getDistance yields a similarity (1 = identical strings), so subtract it from 1 to get a distance.
    val descriptionDistance =
      1 - MathUtils.round(ngram.getDistance(leftDescription, rightDescription).toDouble, SCALE)

    val leftPrices = new DenseVector(Array(leftRegPrice, leftSalePrice))
    val rightPrices = new DenseVector(Array(rightRegPrice, rightSalePrice))
    val priceDistance = MathUtils.round(tanimotoDistanceMeasure.distance(leftPrices, rightPrices), SCALE)

    (leftProduct, rightProduct, MathUtils.round(priceDistance + descriptionDistance, SCALE))
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.