Skip to content

Instantly share code, notes, and snippets.

@TuomasKiviaho
Created January 9, 2014 12:06
Show Gist options
  • Save TuomasKiviaho/8333112 to your computer and use it in GitHub Desktop.
Save TuomasKiviaho/8333112 to your computer and use it in GitHub Desktop.
Here's a recursive approach to implementing conversions using a visitor (based on the logic found in CollectionAnyVisitor and others).
/*
* Copyright 2012, Mysema Ltd
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.mysema.query.jpa;
import java.util.ArrayList;
import java.util.List;
import javax.annotation.Nullable;
import com.google.common.collect.ImmutableList;
import com.mysema.query.support.NumberConversion;
import com.mysema.query.support.NumberConversions;
import com.mysema.query.types.CollectionExpression;
import com.mysema.query.types.Constant;
import com.mysema.query.types.Expression;
import com.mysema.query.types.ExpressionBase;
import com.mysema.query.types.ExpressionUtils;
import com.mysema.query.types.FactoryExpression;
import com.mysema.query.types.Operation;
import com.mysema.query.types.OperationImpl;
import com.mysema.query.types.Operator;
import com.mysema.query.types.Ops;
import com.mysema.query.types.ParamExpression;
import com.mysema.query.types.Path;
import com.mysema.query.types.PathImpl;
import com.mysema.query.types.PathMetadata;
import com.mysema.query.types.Predicate;
import com.mysema.query.types.PredicateOperation;
import com.mysema.query.types.PredicateTemplate;
import com.mysema.query.types.ProjectionRole;
import com.mysema.query.types.SubQueryExpression;
import com.mysema.query.types.TemplateExpression;
import com.mysema.query.types.TemplateExpressionImpl;
import com.mysema.query.types.Visitor;
import com.mysema.query.types.path.ListPath;
import com.mysema.query.types.path.SimplePath;
/**
* Conversions provides module specific projection conversion functionality
*
* @author tiwe
*
*/
public final class Conversions {

    /**
     * Mutable accumulator passed through a {@link ReplacingVisitor} traversal.
     * <p>
     * {@link #replace} is set once a replacement has been recorded; {@link #expressions} and
     * {@link #replacements} are index-aligned lists of each replaced expression and its substitute.
     */
    public static class Context {

        /** {@code true} once at least one replacement has been recorded (directly or via a merged child). */
        public boolean replace;

        /** Expressions that were replaced, in the order they were recorded. */
        public final List<Expression<?>> expressions = new ArrayList<Expression<?>>();

        /** Substitutes for {@link #expressions}, at the same indices. */
        public final List<Expression<?>> replacements = new ArrayList<Expression<?>>();

        /**
         * Records one replacement and marks this context as replacing.
         *
         * @param anyExpression the expression being replaced
         * @param replacement the expression that takes its place
         */
        public void add(Expression<?> anyExpression, Expression<?> replacement) {
            replace = true;
            expressions.add(anyExpression);
            replacements.add(replacement);
        }

        /**
         * Merges a child context into this one.
         *
         * @param c the child context; its {@code replace} flag is OR-ed into this one and its
         *          recorded pairs are appended
         */
        public void add(Context c) {
            replace |= c.replace;
            expressions.addAll(c.expressions);
            replacements.addAll(c.replacements);
        }

        /**
         * Removes all recorded expressions and replacements.
         * <p>
         * The {@link #replace} flag is intentionally left as it is.
         */
        public void clear() {
            expressions.clear();
            replacements.clear();
        }
    }

    /**
     * Visitor that rebuilds an expression tree, substituting the nodes for which
     * {@link #replace(Expression, Context)} records a replacement.
     * <p>
     * Every visit first offers the node itself to {@code replace}. If that flags the context, the
     * substitute is returned as-is and the node's children are not visited. Otherwise each child is
     * visited with a fresh child {@link Context}, which is merged into the caller's context. The node is
     * re-created around the visited children when the merged context is flagged as replacing. An
     * unaffected node is returned as the same instance.
     */
    public static abstract class ReplacingVisitor implements Visitor<Expression<?>, Context> {

        /**
         * {@link FactoryExpression} that delegates instantiation to an existing factory expression
         * but reports a different argument list.
         */
        static class FactoryExpressionImpl<T> extends ExpressionBase<T> implements FactoryExpression<T> {

            private static final long serialVersionUID = -1L;

            private final FactoryExpression<T> inner;

            private final List<Expression<?>> args;

            FactoryExpressionImpl(FactoryExpression<T> inner, List<Expression<?>> args) {
                super(inner.getType());
                this.inner = inner;
                this.args = args;
            }

            @Override
            public List<Expression<?>> getArgs() {
                return args;
            }

            @Override
            public T newInstance(Object... a) {
                return inner.newInstance(a);
            }

            @Override
            public <R, C> R accept(Visitor<R, C> v, C context) {
                return v.visit(this, context);
            }
        }

        /** Boolean-typed {@link FactoryExpressionImpl} that can also be used as a {@link Predicate}. */
        static class PredicateFactoryExpression extends FactoryExpressionImpl<Boolean> implements Predicate {

            private static final long serialVersionUID = -1L;

            // Negation, created lazily on the first call to not() and cached afterwards.
            @Nullable
            private volatile Predicate not;

            PredicateFactoryExpression(FactoryExpression<Boolean> inner, List<Expression<?>> args) {
                super(inner, args);
            }

            @Override
            public Predicate not() {
                if (not == null) {
                    not = PredicateOperation.create(Ops.NOT, this);
                }
                return not;
            }
        }

        /**
         * Hook applied to a rebuilt boolean node when its subtree recorded replaced expressions.
         * <p>
         * The default implementation returns {@code condition} unchanged.
         *
         * @param c the context holding the recorded replacements
         * @param condition the rebuilt predicate
         * @return the predicate to use in place of {@code condition}
         */
        protected Predicate exists(Context c, Predicate condition) {
            return condition;
        }

        /**
         * Hook deciding whether {@code expr} itself is substituted.
         * <p>
         * The visit methods only use the returned value when {@code context.replace} is set afterwards.
         * Implementations that substitute should therefore record the pair via
         * {@link Context#add(Expression, Expression)}. The default records nothing and returns
         * {@code expr}.
         *
         * @param expr the node being visited
         * @param context the context of the current visit
         * @return the substitute, or {@code expr} when nothing is replaced
         */
        protected Expression<?> replace(Expression<?> expr, Context context) {
            return expr;
        }

        /**
         * Creates a copy of {@code path} attached to {@code parent}, keeping the element and path type
         * of the original metadata.
         * <p>
         * Any {@link CollectionExpression} path is rebuilt as a {@link ListPath} with
         * {@link SimplePath} elements. All other paths become a plain {@link PathImpl}.
         * NOTE(review): non-list collection paths (e.g. sets) also become ListPaths here; confirm this
         * is acceptable for callers.
         */
        @SuppressWarnings({ "unchecked", "rawtypes" })
        protected <T> Path<T> replaceParent(Path<T> path, Path<?> parent) {
            PathMetadata<?> metadata = new PathMetadata(parent, path.getMetadata().getElement(),
                    path.getMetadata().getPathType());
            if (path instanceof CollectionExpression) {
                CollectionExpression col = (CollectionExpression) path;
                return new ListPath(col.getParameter(0), SimplePath.class, metadata);
            } else {
                return new PathImpl<T>(path.getType(), metadata);
            }
        }

        /**
         * Passes {@code predicate} through {@link #exists(Context, Predicate)} when the context
         * recorded at least one replaced expression; otherwise returns it unchanged.
         */
        private Predicate wrapIfReplaced(Context context, Predicate predicate) {
            return context.expressions.isEmpty() ? predicate : exists(context, predicate);
        }

        @SuppressWarnings({ "unchecked", "rawtypes" })
        @Override
        public Expression<?> visit(FactoryExpression<?> expr, Context context) {
            Expression<?> replacement = replace(expr, context);
            if (context.replace) {
                return replacement;
            }
            List<Expression<?>> sourceArgs = expr.getArgs();
            final Expression<?>[] args = new Expression<?>[sourceArgs.size()];
            for (int i = 0; i < args.length; i++) {
                Context c = new Context();
                args[i] = sourceArgs.get(i).accept(this, c);
                context.add(c);
            }
            if (!context.replace) {
                return expr;
            }
            if (expr.getType().equals(Boolean.class)) {
                Predicate predicate = new PredicateFactoryExpression((FactoryExpression) expr,
                        ImmutableList.copyOf(args));
                return wrapIfReplaced(context, predicate);
            }
            return new FactoryExpressionImpl(expr, ImmutableList.copyOf(args));
        }

        @SuppressWarnings({ "rawtypes", "unchecked" })
        @Override
        public Expression<?> visit(TemplateExpression<?> expr, Context context) {
            Expression<?> replacement = replace(expr, context);
            if (context.replace) {
                return replacement;
            }
            // Template arguments may be plain objects; only Expression arguments are visited.
            final Object[] args = new Object[expr.getArgs().size()];
            for (int i = 0; i < args.length; i++) {
                Context c = new Context();
                if (expr.getArg(i) instanceof Expression) {
                    args[i] = ((Expression<?>) expr.getArg(i)).accept(this, c);
                } else {
                    args[i] = expr.getArg(i);
                }
                context.add(c);
            }
            if (!context.replace) {
                return expr;
            }
            if (expr.getType().equals(Boolean.class)) {
                Predicate predicate = new PredicateTemplate(expr.getTemplate(), args);
                return wrapIfReplaced(context, predicate);
            }
            return new TemplateExpressionImpl(expr.getType(), expr.getTemplate(), ImmutableList.copyOf(args));
        }

        @SuppressWarnings({ "rawtypes", "unchecked" })
        @Override
        public Expression<?> visit(Operation<?> expr, Context context) {
            Expression<?> replacement = replace(expr, context);
            if (context.replace) {
                return replacement;
            }
            final Expression<?>[] args = new Expression<?>[expr.getArgs().size()];
            for (int i = 0; i < args.length; i++) {
                Context c = new Context();
                args[i] = expr.getArg(i).accept(this, c);
                context.add(c);
            }
            if (!context.replace) {
                return expr;
            }
            if (expr.getType().equals(Boolean.class)) {
                Predicate predicate = new PredicateOperation((Operator) expr.getOperator(),
                        ImmutableList.copyOf(args));
                return wrapIfReplaced(context, predicate);
            }
            return new OperationImpl(expr.getType(), (Operator) expr.getOperator(), ImmutableList.copyOf(args));
        }

        /**
         * Offers the path to {@link #replace}. If it is not replaced, visits its parent and, when the
         * parent subtree recorded replacements, re-attaches the path to the visited parent.
         * <p>
         * NOTE(review): the visited parent is cast to {@link Path}. A {@code replace} implementation
         * that substitutes a parent path with a non-path expression would fail here.
         */
        @Override
        public Expression<?> visit(Path<?> expr, Context context) {
            Expression<?> replacement = replace(expr, context);
            if (context.replace) {
                return replacement;
            }
            if (expr.getMetadata().getParent() != null) {
                Context c = new Context();
                Path<?> parent = (Path<?>) expr.getMetadata().getParent().accept(this, c);
                if (c.replace) {
                    context.add(c);
                    return replaceParent(expr, parent);
                }
            }
            return expr;
        }

        @Override
        public Expression<?> visit(Constant<?> expr, Context context) {
            return replace((Expression<?>) expr, context);
        }

        @Override
        public Expression<?> visit(ParamExpression<?> expr, Context context) {
            return replace((Expression<?>) expr, context);
        }

        /**
         * Offers the sub-query as a whole to {@link #replace}. Its inner expressions are not traversed.
         */
        @Override
        public Expression<?> visit(SubQueryExpression<?> expr, Context context) {
            // TODO ValidatingVisitor
            return replace((Expression<?>) expr, context);
        }
    }

    /**
     * {@link ReplacingVisitor} that wraps {@link Number}-typed nodes in a {@link NumberConversion}.
     * <p>
     * {@code replace} is consulted before a node's children are visited. As a result, the outermost
     * numeric node of each subtree is wrapped, and the numeric nodes beneath it are left untouched.
     */
    public static final class NumberConversionVisitor extends ReplacingVisitor {

        /**
         * Handles paths in one of two ways:
         * <ul>
         * <li>A {@link ProjectionRole} path is replaced by its visited projection when that projection
         * recorded a conversion.</li>
         * <li>Any other path is offered to {@link #replace}.</li>
         * </ul>
         * Parent paths are not traversed.
         */
        @Override
        public Expression<?> visit(Path<?> expr, Context context) {
            if (expr instanceof ProjectionRole<?>) {
                Context c = new Context();
                Expression<?> projection = ((ProjectionRole<?>) expr).getProjection().accept(this, c);
                context.add(c);
                return context.replace ? projection : expr;
            }
            Expression<?> replacement = replace(expr, context);
            return context.replace ? replacement : expr;
        }

        /**
         * Wraps {@code expr} in a {@link NumberConversion} and records the pair when its type is a
         * {@link Number} subtype. Otherwise returns {@code expr} without recording anything.
         */
        @SuppressWarnings({ "unchecked", "rawtypes" })
        @Override
        protected Expression<?> replace(Expression<?> expr, Context context) {
            if (Number.class.isAssignableFrom(expr.getType())) {
                NumberConversion numberConversion = new NumberConversion(expr);
                context.add(expr, numberConversion);
                return numberConversion;
            }
            return expr;
        }
    }

    /**
     * Wraps a projection in a number conversion when it contains an aggregate whose result type may
     * need coercion. NOTE(review): presumably because the provider returns a wider numeric type than
     * the declared one; confirm.
     * <ul>
     * <li>If {@code expr} is itself such an aggregate, it is wrapped in a {@link NumberConversion}.
     * Such an aggregate is either a COUNT aggregate, or a Float/Integer/Short/Byte operation that is,
     * or contains, a SUM aggregate.</li>
     * <li>A {@link FactoryExpression} with at least one such argument is wrapped in
     * {@link NumberConversions}.</li>
     * <li>Anything else is returned unchanged.</li>
     * </ul>
     *
     * @param expr the projection to convert
     * @return the converting wrapper, or {@code expr} itself
     */
    @SuppressWarnings("unchecked")
    public static <RT> Expression<RT> convert(Expression<RT> expr) {
        if (isAggSumWithConversion(expr) || isCountAggConversion(expr)) {
            return new NumberConversion<RT>(expr);
        } else if (expr instanceof FactoryExpression) {
            FactoryExpression<RT> factorye = (FactoryExpression<RT>) expr;
            for (Expression<?> e : factorye.getArgs()) {
                // Check each argument, not the enclosing factory expression.
                if (isAggSumWithConversion(e) || isCountAggConversion(e)) {
                    return new NumberConversions<RT>(factorye);
                }
            }
        }
        return expr;
    }

    /**
     * Rewrites {@code expr} with a {@link NumberConversionVisitor}. The outermost numeric node of
     * each visited subtree is wrapped in a {@link NumberConversion}.
     *
     * @param expr the projection to convert
     * @return the rewritten expression, or {@code expr} itself when nothing was converted
     */
    @SuppressWarnings("unchecked")
    public static <RT> Expression<RT> convertForNativeQuery(Expression<RT> expr) {
        return (Expression<RT>) expr.accept(new NumberConversionVisitor(), new Context());
    }

    /**
     * Returns {@code true} when the extracted expression satisfies both of these conditions:
     * <ul>
     * <li>It is an {@link Operation} of type Float, Integer, Short or Byte.</li>
     * <li>It is a SUM aggregate, or one of its arguments (recursively) satisfies the same test.</li>
     * </ul>
     */
    private static boolean isAggSumWithConversion(Expression<?> expr) {
        Expression<?> extracted = ExpressionUtils.extract(expr);
        if (!(extracted instanceof Operation)) {
            return false;
        }
        Operation<?> operation = (Operation<?>) extracted;
        Class<?> type = operation.getType();
        if (!(type.equals(Float.class) || type.equals(Integer.class)
                || type.equals(Short.class) || type.equals(Byte.class))) {
            return false;
        }
        if (operation.getOperator() == Ops.AggOps.SUM_AGG) {
            return true;
        }
        for (Expression<?> e : operation.getArgs()) {
            if (isAggSumWithConversion(e)) {
                return true;
            }
        }
        return false;
    }

    /** Returns {@code true} when the extracted expression is a COUNT aggregate {@link Operation}. */
    private static boolean isCountAggConversion(Expression<?> expr) {
        Expression<?> extracted = ExpressionUtils.extract(expr);
        if (extracted instanceof Operation) {
            return ((Operation<?>) extracted).getOperator() == Ops.AggOps.COUNT_AGG;
        }
        return false;
    }

    private Conversions() {}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment