Skip to content

Instantly share code, notes, and snippets.

@s1monw
Created January 29, 2014 21:26
Show Gist options
  • Save s1monw/8697560 to your computer and use it in GitHub Desktop.
Save s1monw/8697560 to your computer and use it in GitHub Desktop.
/*
* Licensed to Elasticsearch under one or more contributor
* license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright
* ownership. Elasticsearch licenses this file to you under
* the Apache License, Version 2.0 (the "License"); you may
* not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.lucene.queries;
import org.apache.lucene.index.*;
import org.apache.lucene.search.*;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
/**
 * A query that looks up one term text in several fields and, when rewritten,
 * "blends" the per-field document frequencies before handing them to the
 * per-field {@link TermQuery} instances.
 *
 * <p>For every field in which the term occurs, the document frequency is
 * replaced by {@code avg + bias * (df - avg)}, where {@code avg} is the mean
 * document frequency over the fields that contain the term. A bias of
 * {@code 0} therefore gives every field the average, and a bias of {@code 1}
 * leaves each field's own value untouched. Presumably the goal is to make
 * IDF-based scores comparable across fields; the effect on scoring depends on
 * the similarity in use.
 *
 * <p>Subclasses decide how the per-field term queries are combined by
 * implementing {@link #topLevelQuery(Term[], TermContext[])}. Two ready-made
 * combinations are available through {@link #booleanBlendedQuery} and
 * {@link #dismaxBlendedQuery}.
 */
public abstract class BlendedTermQuery extends Query {

    private final String term;
    private final String[] fields;
    private final float bias;

    /**
     * Creates a blended query with a bias of {@code 0}, i.e. every field that
     * contains the term is given the average document frequency.
     *
     * @param fields the fields to search; must not be {@code null}
     * @param term   the term text to look up in every field; must not be {@code null}
     * @throws IllegalArgumentException if {@code fields} or {@code term} is {@code null}
     */
    public BlendedTermQuery(String[] fields, String term) {
        this(fields, term, 0.0f);
    }

    /**
     * Creates a blended query.
     *
     * @param fields the fields to search; must not be {@code null}
     * @param term   the term text to look up in every field; must not be {@code null}
     * @param bias   weight of a field's own document frequency relative to the
     *               average ({@code 0} = use the average, {@code 1} = keep the
     *               field's own value)
     * @throws IllegalArgumentException if {@code fields} or {@code term} is {@code null}
     */
    public BlendedTermQuery(String[] fields, String term, float bias) {
        if (fields == null) {
            throw new IllegalArgumentException("fields must not be null");
        }
        if (term == null) {
            throw new IllegalArgumentException("term must not be null");
        }
        // Defensive copy: equals(), hashCode() and rewrite() all depend on the
        // array content, so it must not change behind our back.
        this.fields = fields.clone();
        this.bias = bias;
        this.term = term;
    }

    /**
     * Builds one {@link Term} and {@link TermContext} per field, blends the
     * document frequencies and delegates to
     * {@link #topLevelQuery(Term[], TermContext[])} to assemble the final query.
     *
     * @param reader the reader the term statistics are collected from
     * @return the query produced by {@code topLevelQuery}, carrying this
     *         query's boost
     * @throws IOException if building a term context fails
     */
    @Override
    public Query rewrite(IndexReader reader) throws IOException {
        IndexReaderContext context = reader.getContext();
        Term[] terms = new Term[fields.length];
        TermContext[] ctx = new TermContext[fields.length];
        for (int i = 0; i < terms.length; i++) {
            terms[i] = new Term(fields[i], term);
            ctx[i] = TermContext.build(context, terms[i]);
        }
        blend(ctx, reader.maxDoc());
        Query rewritten = topLevelQuery(terms, ctx);
        // Carry this query's boost over to the replacement; multiplying keeps
        // any boost a subclass may already have applied to the result.
        rewritten.setBoost(rewritten.getBoost() * getBoost());
        return rewritten;
    }

    /**
     * Combines the per-field terms into the query that replaces this one.
     *
     * @param terms one term per field, in the order the fields were given
     * @param ctx   the term contexts matching {@code terms} by index, with
     *              blended statistics
     * @return the query to execute in place of this one
     */
    protected abstract Query topLevelQuery(Term[] terms, TermContext[] ctx);

    /**
     * Replaces the document frequency of every context whose term exists with
     * {@code avg + bias * (df - avg)}, clamped to {@code [1, maxDoc]}.
     * Contexts with a document frequency of {@code 0} are left untouched and
     * do not take part in the average.
     *
     * @param contexts the per-field term contexts; entries may be replaced in place
     * @param maxDoc   upper bound for any blended document frequency
     */
    protected void blend(TermContext[] contexts, int maxDoc) {
        long sum = 0;
        int numZeroDF = 0;
        for (TermContext ctx : contexts) {
            int df = ctx.docFreq();
            sum += df;
            if (df == 0) {
                numZeroDF++;
            }
        }
        if (sum == 0) {
            return; // the term does not exist in any field, nothing to blend
        }
        // sum > 0 guarantees at least one non-zero entry, so the divisor is >= 1.
        final long avg = sum / (contexts.length - numZeroDF);
        for (int i = 0; i < contexts.length; i++) {
            int df = contexts[i].docFreq();
            if (df == 0) {
                continue;
            }
            long blendedDF = avg + ((long) (bias * (df - avg)));
            // A term that exists occurs in at least one document. The lower
            // bound only matters for a bias outside [0, 1], which could
            // otherwise yield a zero or negative document frequency.
            int clampedDF = (int) Math.max(1L, Math.min(maxDoc, blendedDF));
            contexts[i].setDocFreq(clampedDF);
            contexts[i] = adjustTTF(contexts[i]);
        }
    }

    /**
     * Returns a context whose total term frequency is at least its document
     * frequency. If the given context already satisfies this it is returned
     * as is; otherwise a new context over the same reader is built that
     * reuses the per-leaf term states and reports the document frequency as
     * the total term frequency.
     */
    private TermContext adjustTTF(TermContext termContext) {
        if (termContext.docFreq() <= termContext.totalTermFreq()) {
            return termContext;
        }
        TermContext newTermContext = new TermContext(termContext.topReaderContext);
        List<AtomicReaderContext> leaves = termContext.topReaderContext.leaves();
        final int len = (leaves == null) ? 1 : leaves.size();
        int df = termContext.docFreq();
        // docFreq > totalTermFreq was established above, so the raised total
        // term frequency is simply the document frequency.
        long ttf = df;
        for (int i = 0; i < len; i++) {
            TermState termState = termContext.get(i);
            if (termState == null) {
                continue;
            }
            // The statistics are passed along with the first registered state
            // only; later states contribute zero.
            newTermContext.register(termState, i, df, ttf);
            df = 0;
            ttf = 0;
        }
        return newTermContext;
    }

    @Override
    public String toString(String field) {
        return "blended(\"" + term + "\", fields: " + Arrays.toString(fields) + ")";
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        if (!super.equals(o)) {
            return false;
        }
        BlendedTermQuery that = (BlendedTermQuery) o;
        return Float.compare(that.bias, bias) == 0
                && Arrays.equals(fields, that.fields)
                && term.equals(that.term);
    }

    @Override
    public int hashCode() {
        int result = super.hashCode();
        result = 31 * result + term.hashCode();
        result = 31 * result + Arrays.hashCode(fields);
        result = 31 * result + Float.floatToIntBits(bias);
        return result;
    }

    /**
     * Creates a blended query that combines the per-field term queries as
     * optional ({@code SHOULD}) clauses of a {@link BooleanQuery}.
     *
     * @param fields       the fields to search; must not be {@code null}
     * @param term         the term text; must not be {@code null}
     * @param bias         see {@link #BlendedTermQuery(String[], String, float)}
     * @param disableCoord passed to the {@link BooleanQuery} constructor
     * @return the blended query
     */
    public static BlendedTermQuery booleanBlendedQuery(String[] fields, String term, float bias, final boolean disableCoord) {
        return new BooleanBlendedTermQuery(fields, term, bias, disableCoord);
    }

    /**
     * Creates a blended query that combines the per-field term queries in a
     * {@link DisjunctionMaxQuery}.
     *
     * @param fields               the fields to search; must not be {@code null}
     * @param term                 the term text; must not be {@code null}
     * @param bias                 see {@link #BlendedTermQuery(String[], String, float)}
     * @param tieBreakerMultiplier passed to the {@link DisjunctionMaxQuery} constructor
     * @return the blended query
     */
    public static BlendedTermQuery dismaxBlendedQuery(String[] fields, String term, float bias, final float tieBreakerMultiplier ) {
        return new DisMaxBlendedTermQuery(fields, term, bias, tieBreakerMultiplier);
    }

    /**
     * Boolean flavour. A named class rather than an anonymous one so that
     * {@code disableCoord} can take part in equals/hashCode.
     */
    private static final class BooleanBlendedTermQuery extends BlendedTermQuery {

        private final boolean disableCoord;

        BooleanBlendedTermQuery(String[] fields, String term, float bias, boolean disableCoord) {
            super(fields, term, bias);
            this.disableCoord = disableCoord;
        }

        @Override
        protected Query topLevelQuery(Term[] terms, TermContext[] ctx) {
            BooleanQuery query = new BooleanQuery(disableCoord);
            for (int i = 0; i < terms.length; i++) {
                query.add(new TermQuery(terms[i], ctx[i]), BooleanClause.Occur.SHOULD);
            }
            return query;
        }

        @Override
        public boolean equals(Object o) {
            // super.equals only succeeds when o has exactly this class, so
            // the cast below is safe.
            return super.equals(o) && disableCoord == ((BooleanBlendedTermQuery) o).disableCoord;
        }

        @Override
        public int hashCode() {
            return 31 * super.hashCode() + (disableCoord ? 1231 : 1237);
        }
    }

    /**
     * Dismax flavour. A named class rather than an anonymous one so that
     * {@code tieBreakerMultiplier} can take part in equals/hashCode.
     */
    private static final class DisMaxBlendedTermQuery extends BlendedTermQuery {

        private final float tieBreakerMultiplier;

        DisMaxBlendedTermQuery(String[] fields, String term, float bias, float tieBreakerMultiplier) {
            super(fields, term, bias);
            this.tieBreakerMultiplier = tieBreakerMultiplier;
        }

        @Override
        protected Query topLevelQuery(Term[] terms, TermContext[] ctx) {
            DisjunctionMaxQuery query = new DisjunctionMaxQuery(tieBreakerMultiplier);
            for (int i = 0; i < terms.length; i++) {
                query.add(new TermQuery(terms[i], ctx[i]));
            }
            return query;
        }

        @Override
        public boolean equals(Object o) {
            // super.equals only succeeds when o has exactly this class, so
            // the cast below is safe.
            return super.equals(o)
                    && Float.compare(tieBreakerMultiplier, ((DisMaxBlendedTermQuery) o).tieBreakerMultiplier) == 0;
        }

        @Override
        public int hashCode() {
            return 31 * super.hashCode() + Float.floatToIntBits(tieBreakerMultiplier);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment