Created
January 9, 2014 12:06
-
-
Save TuomasKiviaho/8333112 to your computer and use it in GitHub Desktop.
Here's a recursive approach to implementing conversions using a visitor (based on logic found in CollectionAnyVisitor and others).
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
/*
 * Copyright 2012, Mysema Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 * http://www.apache.org/licenses/LICENSE-2.0
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.mysema.query.jpa; | |
import java.util.ArrayList; | |
import java.util.List; | |
import javax.annotation.Nullable; | |
import com.google.common.collect.ImmutableList; | |
import com.mysema.query.support.NumberConversion; | |
import com.mysema.query.support.NumberConversions; | |
import com.mysema.query.types.CollectionExpression; | |
import com.mysema.query.types.Constant; | |
import com.mysema.query.types.Expression; | |
import com.mysema.query.types.ExpressionBase; | |
import com.mysema.query.types.ExpressionUtils; | |
import com.mysema.query.types.FactoryExpression; | |
import com.mysema.query.types.Operation; | |
import com.mysema.query.types.OperationImpl; | |
import com.mysema.query.types.Operator; | |
import com.mysema.query.types.Ops; | |
import com.mysema.query.types.ParamExpression; | |
import com.mysema.query.types.Path; | |
import com.mysema.query.types.PathImpl; | |
import com.mysema.query.types.PathMetadata; | |
import com.mysema.query.types.Predicate; | |
import com.mysema.query.types.PredicateOperation; | |
import com.mysema.query.types.PredicateTemplate; | |
import com.mysema.query.types.ProjectionRole; | |
import com.mysema.query.types.SubQueryExpression; | |
import com.mysema.query.types.TemplateExpression; | |
import com.mysema.query.types.TemplateExpressionImpl; | |
import com.mysema.query.types.Visitor; | |
import com.mysema.query.types.path.ListPath; | |
import com.mysema.query.types.path.SimplePath; | |
/** | |
* Conversions provides module specific projection conversion functionality | |
* | |
* @author tiwe | |
* | |
*/ | |
public final class Conversions { | |
public static class Context { | |
public boolean replace; | |
public final List<Expression<?>> expressions = new ArrayList<Expression<?>>(); | |
public final List<Expression<?>> replacements = new ArrayList<Expression<?>>(); | |
public void add(Expression<?> anyExpression, Expression<?> replacement) { | |
replace = true; | |
expressions.add(anyExpression); | |
replacements.add(replacement); | |
} | |
public void add(Context c) { | |
replace |= c.replace; | |
expressions.addAll(c.expressions); | |
replacements.addAll(c.replacements); | |
} | |
public void clear() { | |
expressions.clear(); | |
replacements.clear(); | |
} | |
} | |
public static abstract class ReplacingVisitor implements Visitor<Expression<?>, Context> { | |
static class FactoryExpressionImpl<T> extends ExpressionBase<T> implements FactoryExpression<T> { | |
private static final long serialVersionUID = -1L; | |
private final FactoryExpression<T> inner; | |
private final List<Expression<?>> args; | |
FactoryExpressionImpl(FactoryExpression<T> inner, List<Expression<?>> args) { | |
super(inner.getType()); | |
this.inner = inner; | |
this.args = args; | |
} | |
@Override | |
public List<Expression<?>> getArgs() { | |
return args; | |
} | |
@Override | |
public T newInstance(Object... a) { | |
return inner.newInstance(a); | |
} | |
@Override | |
public <R, C> R accept(Visitor<R, C> v, C context) { | |
return v.visit(this, context); | |
} | |
} | |
static class PredicateFactoryExpression extends FactoryExpressionImpl<Boolean> implements Predicate { | |
private static final long serialVersionUID = -1L; | |
@Nullable | |
private volatile Predicate not; | |
PredicateFactoryExpression(FactoryExpression<Boolean> inner, List<Expression<?>> args) { | |
super(inner, args); | |
} | |
@Override | |
public Predicate not() { | |
if (not == null) { | |
not = PredicateOperation.create(Ops.NOT, this); | |
} | |
return not; | |
} | |
} | |
protected Predicate exists(Context c, Predicate condition) { | |
return condition; | |
} | |
protected Expression<?> replace(Expression<?> expr, Context context) | |
{ | |
return expr; | |
} | |
@SuppressWarnings({ "unchecked", "rawtypes" }) | |
protected <T> Path<T> replaceParent(Path<T> path, Path<?> parent) { | |
PathMetadata<?> metadata = new PathMetadata(parent, path.getMetadata().getElement(), | |
path.getMetadata().getPathType()); | |
if (path instanceof CollectionExpression) { | |
CollectionExpression col = (CollectionExpression)path; | |
return new ListPath(col.getParameter(0), SimplePath.class, metadata); | |
} else { | |
return new PathImpl<T>(path.getType(), metadata); | |
} | |
} | |
@SuppressWarnings({ "unchecked", "rawtypes" }) | |
@Override | |
public Expression<?> visit(FactoryExpression<?> expr, Context context) { | |
Expression<?> replacement = replace(expr,context); | |
if (context.replace) | |
{ | |
return replacement; | |
} else { | |
final Expression<?>[] args = new Expression<?>[expr.getArgs().size()]; | |
for (int i = 0; i < args.length; i++) { | |
Context c = new Context(); | |
args[i] = expr.getArgs().get(i).accept(this, c); | |
context.add(c); | |
} | |
if (context.replace) { | |
if (expr.getType().equals(Boolean.class)) { | |
Predicate predicate = new PredicateFactoryExpression((FactoryExpression) expr, ImmutableList.copyOf(args)); | |
return !context.expressions.isEmpty() ? exists(context, predicate) : predicate; | |
} else { | |
return new FactoryExpressionImpl(expr, ImmutableList.copyOf(args)); | |
} | |
} else { | |
return expr; | |
} | |
} | |
} | |
@SuppressWarnings( { "rawtypes", "unchecked" }) | |
@Override | |
public Expression<?> visit(TemplateExpression<?> expr, Context context) { | |
Expression<?> replacement = replace(expr,context); | |
if (context.replace) | |
{ | |
return replacement; | |
} else { | |
final Object[] args = new Object[expr.getArgs().size()]; | |
for (int i = 0; i < args.length; i++) { | |
Context c = new Context(); | |
if (expr.getArg(i) instanceof Expression) { | |
args[i] = ((Expression<?>)expr.getArg(i)).accept(this, c); | |
} else { | |
args[i] = expr.getArg(i); | |
} | |
context.add(c); | |
} | |
if (context.replace) { | |
if (expr.getType().equals(Boolean.class)) { | |
Predicate predicate = new PredicateTemplate(expr.getTemplate(), args); | |
return !context.expressions.isEmpty() ? exists(context, predicate) : predicate; | |
} else { | |
return new TemplateExpressionImpl(expr.getType(), expr.getTemplate(), ImmutableList.copyOf(args)); | |
} | |
} else { | |
return expr; | |
} | |
} | |
} | |
@SuppressWarnings( { "rawtypes", "unchecked" }) | |
@Override | |
public Expression<?> visit(Operation<?> expr, Context context) { | |
Expression<?> replacement = replace(expr,context); | |
if (context.replace) | |
{ | |
return replacement; | |
} else { | |
final Expression<?>[] args = new Expression<?>[expr.getArgs().size()]; | |
for (int i = 0; i < args.length; i++) { | |
Context c = new Context(); | |
args[i] = expr.getArg(i).accept(this, c); | |
context.add(c); | |
} | |
if (context.replace) { | |
if (expr.getType().equals(Boolean.class)) { | |
Predicate predicate = new PredicateOperation((Operator)expr.getOperator(), ImmutableList.copyOf(args)); | |
return !context.expressions.isEmpty() ? exists(context, predicate) : predicate; | |
} else { | |
return new OperationImpl(expr.getType(), (Operator)expr.getOperator(), ImmutableList.copyOf(args)); | |
} | |
} else { | |
return expr; | |
} | |
} | |
} | |
public Expression<?> visit(Path<?> expr, Context context) { | |
Expression<?> replacement = replace(expr,context); | |
if (context.replace) | |
{ | |
return replacement; | |
} else if (expr.getMetadata().getParent() != null) { | |
Context c = new Context(); | |
Path<?> parent = (Path<?>) expr.getMetadata().getParent().accept(this, c); | |
if (c.replace) { | |
context.add(c); | |
return replaceParent(expr, parent); | |
} | |
} | |
return expr; | |
} | |
@Override | |
public Expression<?> visit(Constant<?> expr, Context context) | |
{ | |
return replace((Expression<?>) expr, context); | |
} | |
@Override | |
public Expression<?> visit(ParamExpression<?> expr, Context context) | |
{ | |
return replace((Expression<?>) expr, context); | |
} | |
@Override | |
public Expression<?> visit(SubQueryExpression<?> expr, Context context) | |
{ | |
// TODO ValidatingVisitor | |
return replace((Expression<?>) expr, context); | |
} | |
} | |
public static final class NumberConversionVisitor extends ReplacingVisitor { | |
@Override | |
public Expression<?> visit(Path<?> expr, Context context) | |
{ | |
if (expr instanceof ProjectionRole<?>) { | |
Context c = new Context(); | |
Expression<?> projection = ((ProjectionRole<?>) expr).getProjection().accept(this, c); | |
context.add(c); | |
if (context.replace) { | |
return projection; | |
} else { | |
return expr; | |
} | |
} else { | |
Expression<?> replacement = replace(expr,context); | |
if (context.replace) | |
{ | |
return replacement; | |
} | |
} | |
return expr; | |
} | |
@SuppressWarnings("unchecked") | |
@Override | |
protected Expression<?> replace(Expression<?> expr, Context context) | |
{ | |
if (Number.class.isAssignableFrom(expr.getType())) { | |
NumberConversion numberConversion = new NumberConversion(expr); | |
context.add(expr, numberConversion); | |
return numberConversion; | |
} else { | |
return expr; | |
} | |
} | |
} | |
public static <RT> Expression<RT> convert(Expression<RT> expr) { | |
if (isAggSumWithConversion(expr) || isCountAggConversion(expr)) { | |
return new NumberConversion<RT>(expr); | |
} else if (expr instanceof FactoryExpression) { | |
FactoryExpression<RT> factorye = (FactoryExpression<RT>)expr; | |
for (Expression<?> e : factorye.getArgs()) { | |
if (isAggSumWithConversion(e) || isCountAggConversion(expr)) { | |
return new NumberConversions<RT>(factorye); | |
} | |
} | |
} | |
return expr; | |
} | |
@SuppressWarnings("unchecked") | |
public static <RT> Expression<RT> convertForNativeQuery(Expression<RT> expr) { | |
Context context = new Context(); | |
NumberConversionVisitor v = new NumberConversionVisitor(); | |
return (Expression<RT>) expr.accept(v, context); | |
// if (Number.class.isAssignableFrom(expr.getType())) { | |
// return new NumberConversion<RT>(expr); | |
// } else if (expr instanceof FactoryExpression) { | |
// FactoryExpression<RT> factorye = (FactoryExpression<RT>)expr; | |
// return new NumberConversions<RT>(factorye); | |
// } | |
// return expr; | |
} | |
private static boolean isAggSumWithConversion(Expression<?> expr) { | |
expr = ExpressionUtils.extract(expr); | |
if (expr instanceof Operation) { | |
Operation<?> operation = (Operation<?>)expr; | |
Class<?> type = operation.getType(); | |
if (type.equals(Float.class) || type.equals(Integer.class) | |
|| type.equals(Short.class) || type.equals(Byte.class)) { | |
if (operation.getOperator() == Ops.AggOps.SUM_AGG) { | |
return true; | |
} else { | |
for (Expression<?> e : operation.getArgs()) { | |
if (isAggSumWithConversion(e)) { | |
return true; | |
} | |
} | |
} | |
} | |
} | |
return false; | |
} | |
private static boolean isCountAggConversion(Expression<?> expr) { | |
expr = ExpressionUtils.extract(expr); | |
if (expr instanceof Operation) { | |
Operation<?> operation = (Operation<?>)expr; | |
return operation.getOperator() == Ops.AggOps.COUNT_AGG; | |
} | |
return false; | |
} | |
private Conversions() {} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment