Skip to content

Instantly share code, notes, and snippets.

@Pratik579
Created October 22, 2023 09:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Pratik579/26fbcb3e81c41d3e9618e131370a40e2 to your computer and use it in GitHub Desktop.
Save Pratik579/26fbcb3e81c41d3e9618e131370a40e2 to your computer and use it in GitHub Desktop.
import ai.djl.MalformedModelException;
import ai.djl.ModelException;
import ai.djl.huggingface.translator.TextEmbeddingTranslatorFactory;
import ai.djl.inference.Predictor;
import ai.djl.repository.zoo.Criteria;
import ai.djl.repository.zoo.ModelNotFoundException;
import ai.djl.repository.zoo.ZooModel;
import ai.djl.training.util.ProgressBar;
import ai.djl.translate.TranslateException;
import java.io.IOException;
/**
 * Creates text embedding vectors for two passages of text and then calculates their cosine
 * similarity.
 *
 * <p>Embeddings come from the {@code sentence-transformers/all-MiniLM-L12-v2} model, loaded from
 * the DJL model zoo with the PyTorch engine. Running {@link #main} requires DJL, its PyTorch
 * engine and the DJL Hugging Face extension on the classpath; the build dependencies are not
 * listed in this file.
 *
 * <p>The vector helpers ({@link #cosineSimilarity}, {@link #cosDistance}, {@link #dotProduct},
 * {@link #vectorNorm}) are pure functions over {@code float[]} and do not touch DJL.
 */
public final class SentenceSimilarity {

  /** DJL model-zoo URL of the sentence-embedding model used by {@link #main}. */
  private static final String MODEL_URL =
      "djl://ai.djl.huggingface.pytorch/sentence-transformers/all-MiniLM-L12-v2";

  private SentenceSimilarity() {}

  /**
   * Embeds two hard-coded passages and prints their cosine similarity to standard output.
   *
   * @param args ignored
   * @throws IOException if loading the model fails with an I/O error
   * @throws ModelException if the model cannot be found or is malformed
   * @throws TranslateException if embedding either passage fails
   */
  public static void main(String[] args) throws IOException, ModelException, TranslateException {
    Criteria<String, float[]> criteria =
        Criteria.builder()
            .setTypes(String.class, float[].class)
            .optModelUrls(MODEL_URL)
            .optEngine("PyTorch")
            .optTranslatorFactory(new TextEmbeddingTranslatorFactory())
            .optProgress(new ProgressBar())
            .build();

    String baselineCompletionText = "India has overtaken China in having the largest population in the world with population of 1,425,775,850 at the end of April 2023. Between 1975 and 2010, the population doubled to 1.2 billion, reaching the billion mark in 2000 . The number of children in India peaked more than a decade ago and is now falling. 1,000,000 people in India are Anglo-Indians and 700,000 US citizens are living in India.";
    String currentCompletionText = "India overtook China to become the world's most populous country at the end of April 2023. According to the UN's World Population Dashboard, India's population now stands at slightly over 1.428 billion, edging past China's population of 1.425 billion people. India's population is set to reach 1.7 billion by 2050.";

    // Both the model and the predictor are AutoCloseable and hold engine resources;
    // try-with-resources releases them (predictor first, then model) even if predict() throws.
    try (ZooModel<String, float[]> model = criteria.loadModel();
        Predictor<String, float[]> predictor = model.newPredictor()) {
      float[] baselineEmbedding = predictor.predict(baselineCompletionText);
      float[] currentEmbedding = predictor.predict(currentCompletionText);
      double similarity = cosineSimilarity(baselineEmbedding, currentEmbedding);
      System.out.println("Sentences similarity is " + similarity);
    }
  }

  /**
   * Computes the cosine similarity of two vectors, accumulating in {@code double} precision.
   *
   * @param vectorA first vector
   * @param vectorB second vector; must have the same length as {@code vectorA}
   * @return the cosine of the angle between the vectors (nominally in [-1, 1]); {@code NaN} if
   *     either vector has zero norm, including when both vectors are empty
   * @throws IllegalArgumentException if the vectors differ in length
   */
  public static double cosineSimilarity(float[] vectorA, float[] vectorB) {
    requireSameLength(vectorA, vectorB);
    double dotProduct = 0.0;
    double normA = 0.0;
    double normB = 0.0;
    for (int i = 0; i < vectorA.length; i++) {
      // Widen before multiplying so products are not rounded to float precision.
      double a = vectorA[i];
      double b = vectorB[i];
      dotProduct += a * b;
      normA += a * a;
      normB += b * b;
    }
    return dotProduct / (Math.sqrt(normA) * Math.sqrt(normB));
  }

  /**
   * Calculates the cosine similarity of two vectors as a {@code float}.
   *
   * <p>Despite its name, this method returns the cosine <em>similarity</em> (1 for identical
   * directions), not the cosine distance {@code 1 - similarity}; the name is kept for
   * compatibility with existing callers.
   *
   * @param v1 vector 1
   * @param v2 vector 2; must have the same length as {@code v1}
   * @return cosine similarity of {@code v1} and {@code v2}, or {@code NaN} if either has zero
   *     norm
   * @throws IllegalArgumentException if the vectors differ in length
   */
  public static float cosDistance(float[] v1, float[] v2) {
    // Delegate to the double-precision implementation to avoid float overflow/rounding.
    return (float) cosineSimilarity(v1, v2);
  }

  /**
   * Calculates the dot product of two vectors.
   *
   * <p>The sum is accumulated in {@code double} and rounded to {@code float} once at the end, so
   * intermediate products that exceed the float range do not turn the result into
   * infinity or NaN.
   *
   * @param v1 vector 1
   * @param v2 vector 2; must have the same length as {@code v1}
   * @return dot product of {@code v1} and {@code v2}
   * @throws IllegalArgumentException if the vectors differ in length
   */
  public static float dotProduct(float[] v1, float[] v2) {
    requireSameLength(v1, v2);
    double sum = 0.0;
    for (int i = 0; i < v1.length; i++) {
      sum += (double) v1[i] * v2[i];
    }
    return (float) sum;
  }

  /**
   * Calculates the Euclidean (L2) norm of a vector.
   *
   * <p>Squares are accumulated in {@code double}, so e.g. a component of {@code 1e20f} yields a
   * norm of {@code 1e20f} instead of overflowing to infinity.
   *
   * @param v a float vector
   * @return the L2 norm of {@code v}; 0 for an empty vector
   */
  public static float vectorNorm(float[] v) {
    double sumOfSquares = 0.0;
    for (float component : v) {
      double c = component;
      sumOfSquares += c * c;
    }
    return (float) Math.sqrt(sumOfSquares);
  }

  /**
   * Rejects vector pairs that cannot be compared element-wise. Unlike {@code assert}, this check
   * is always active.
   */
  private static void requireSameLength(float[] v1, float[] v2) {
    if (v1.length != v2.length) {
      throw new IllegalArgumentException(
          "Vector lengths differ: " + v1.length + " != " + v2.length);
    }
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment