Skip to content

Instantly share code, notes, and snippets.

@orausch
Created January 23, 2020 10:53
Show Gist options
  • Save orausch/9a42e24b782319447a515e8c29b364a0 to your computer and use it in GitHub Desktop.
Save orausch/9a42e24b782319447a515e8c29b364a0 to your computer and use it in GitHub Desktop.
/**
 * Checks that Word2Vec training is reproducible: two models trained on the same corpus with the
 * same seed, a single worker and non-parallel tokenization are expected to produce identical
 * word vectors for every word in the corpus.
 *
 * <p>Per the original author's notes, {@code seed = 0} makes this test fail while
 * {@code seed = 1} makes it pass.
 */
@Test
public void test() {
    List<String> trainData = new ArrayList<>();
    trainData.add("1 2 1 2 1 2");
    trainData.add("3 4 3 4 3 4");
    trainData.add("5 6 5 6 5 6");
    trainData.add("7 8 5 6 5 6");

    // Original author's observation: seed 0 reproduces the failure, seed 1 passes.
    int seed = 0;

    Word2Vec vec = trainWord2Vec(trainData, seed);
    Word2Vec vec1 = trainWord2Vec(trainData, seed);

    for (String s : new String[] {"1", "2", "3", "4", "5", "6", "7", "8"}) {
        double[] a = vec.getWordVector(s);
        double[] b = vec1.getWordVector(s);
        // assertArrayEquals(null, null, delta) passes, so a word missing from either vocabulary
        // would otherwise be skipped silently instead of actually compared.
        if (a == null || b == null) {
            throw new AssertionError("no vector for word '" + s + "'");
        }
        assertArrayEquals(a, b, 0.00000001);
    }
}

/**
 * Builds and fits one single-worker Word2Vec model over {@code sentences}, using a fresh sentence
 * iterator and a fresh SkipGram instance, so that every call is configured identically.
 *
 * @param sentences training corpus, one sentence per element
 * @param seed seed passed to the Word2Vec builder
 * @return the fitted model
 */
private static Word2Vec trainWord2Vec(List<String> sentences, int seed) {
    SentenceIterator iter = new CollectionSentenceIterator(sentences);
    Word2Vec model = new Word2Vec.Builder()
            .elementsLearningAlgorithm(new SkipGram<VocabWord>())
            .seed(seed)
            .workers(1)
            // Words "7" and "8" occur only once in the test corpus; set the threshold explicitly
            // so they are kept in the vocabulary rather than depending on the builder default.
            // NOTE(review): the library default minimum frequency is believed to be > 1 — confirm.
            .minWordFrequency(1)
            .iterate(iter)
            .allowParallelTokenization(false)
            .build();
    model.fit();
    return model;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment