Skip to content

Instantly share code, notes, and snippets.

@kozak127
Created May 5, 2020 22:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kozak127/9e44be2d3c8775e9ff7bbda47bab7223 to your computer and use it in GitHub Desktop.
Save kozak127/9e44be2d3c8775e9ff7bbda47bab7223 to your computer and use it in GitHub Desktop.
RSQL JPA Specification SpringBoot nestedProperties
import cz.jirutka.rsql.parser.ast.AndNode;
import cz.jirutka.rsql.parser.ast.ComparisonNode;
import cz.jirutka.rsql.parser.ast.OrNode;
import cz.jirutka.rsql.parser.ast.RSQLVisitor;
import org.springframework.data.jpa.domain.Specification;
/**
 * RSQL AST visitor that converts every visited node into a JPA {@link Specification}.
 *
 * <p>All three node kinds are handed to a {@link RsqlSpecificationBuilder}; the
 * visitor's second argument is never read.
 *
 * @param <T> entity type the produced specifications apply to
 */
public class CustomRsqlVisitor<T> implements RSQLVisitor<Specification<T>, Void> {

    /** Builder that performs the actual node-to-specification conversion. */
    private final RsqlSpecificationBuilder<T> builder = new RsqlSpecificationBuilder<>();

    /** Creates a visitor backed by its own {@link RsqlSpecificationBuilder}. */
    public CustomRsqlVisitor() {
        // The builder is initialised at field declaration.
    }

    /**
     * Converts an AND node.
     *
     * @param node    the logical AND node
     * @param ignored unused visitor argument
     * @return the specification produced by the builder for {@code node}
     */
    @Override
    public Specification<T> visit(AndNode node, Void ignored) {
        return builder.createSpecification(node);
    }

    /**
     * Converts an OR node.
     *
     * @param node    the logical OR node
     * @param ignored unused visitor argument
     * @return the specification produced by the builder for {@code node}
     */
    @Override
    public Specification<T> visit(OrNode node, Void ignored) {
        return builder.createSpecification(node);
    }

    /**
     * Converts a comparison node.
     *
     * @param node    the comparison node
     * @param ignored unused visitor argument
     * @return the specification produced by the builder for {@code node}
     */
    @Override
    public Specification<T> visit(ComparisonNode node, Void ignored) {
        return builder.createSpecification(node);
    }
}
import org.springframework.data.jpa.repository.JpaSpecificationExecutor;
import org.springframework.data.repository.CrudRepository;
/**
 * Repository for {@code Whatever} entities identified by a {@code Long} id.
 *
 * <p>Extends {@link CrudRepository} and {@link JpaSpecificationExecutor}, so it
 * exposes both sets of inherited operations; it declares no methods of its own.
 */
public interface WhateverRepository extends
CrudRepository<Whatever, Long>,
JpaSpecificationExecutor<Whatever> {
}
import cz.jirutka.rsql.parser.ast.ComparisonOperator;
import cz.jirutka.rsql.parser.ast.RSQLOperators;
import java.util.Arrays;
/**
 * The RSQL comparison operators this code knows how to handle, each paired with
 * the parser's {@link ComparisonOperator} constant from {@link RSQLOperators}.
 */
public enum RsqlSearchOperation {
    EQUAL(RSQLOperators.EQUAL),
    NOT_EQUAL(RSQLOperators.NOT_EQUAL),
    GREATER_THAN(RSQLOperators.GREATER_THAN),
    GREATER_THAN_OR_EQUAL(RSQLOperators.GREATER_THAN_OR_EQUAL),
    LESS_THAN(RSQLOperators.LESS_THAN),
    LESS_THAN_OR_EQUAL(RSQLOperators.LESS_THAN_OR_EQUAL),
    IN(RSQLOperators.IN),
    NOT_IN(RSQLOperators.NOT_IN);

    /** Parser-level operator wrapped by this constant. */
    private final ComparisonOperator operator;

    RsqlSearchOperation(final ComparisonOperator operator) {
        this.operator = operator;
    }

    /**
     * Looks up the constant that wraps the given parser operator.
     *
     * @param operator operator taken from a parsed comparison node
     * @return the matching constant, or {@code null} when none matches
     */
    public static RsqlSearchOperation getSimpleOperator(final ComparisonOperator operator) {
        for (final RsqlSearchOperation candidate : values()) {
            if (candidate.getOperator().equals(operator)) {
                return candidate;
            }
        }
        return null;
    }

    /**
     * @return the parser operator wrapped by this constant
     */
    public ComparisonOperator getOperator() {
        return operator;
    }
}
import cz.jirutka.rsql.parser.ast.ComparisonOperator;
import java.lang.reflect.InvocationTargetException;
import java.math.BigDecimal;
import java.time.ZonedDateTime;
import java.time.format.DateTimeFormatter;
import java.util.List;
import java.util.stream.Collectors;
import javax.persistence.criteria.CriteriaBuilder;
import javax.persistence.criteria.CriteriaQuery;
import javax.persistence.criteria.Path;
import javax.persistence.criteria.Predicate;
import javax.persistence.criteria.Root;
import org.springframework.data.jpa.domain.Specification;
/**
 * JPA {@link Specification} for one RSQL comparison: a (possibly nested) property
 * selector, a comparison operator and its raw string arguments.
 *
 * <p>Arguments are converted to the Java type of the selected attribute before the
 * predicate is built, so that comparisons are performed on typed values.
 *
 * @param <T> entity type the specification applies to
 */
public class RsqlSpecification<T> implements Specification<T> {

    /** Property selector; nested attributes are separated by dots. */
    private final String property;
    /** Parsed operator; transient because it is not part of the serialized state. */
    private final transient ComparisonOperator operator;
    /** Raw, unconverted arguments as they appeared in the query. */
    private final List<String> arguments;

    /**
     * @param property  attribute selector, e.g. {@code "name"} or {@code "owner.address.city"}
     * @param operator  RSQL comparison operator
     * @param arguments raw string arguments of the comparison
     */
    public RsqlSpecification(final String property, final ComparisonOperator operator,
            final List<String> arguments) {
        super();
        this.property = property;
        this.operator = operator;
        this.arguments = arguments;
    }

    /**
     * Builds the predicate for this comparison.
     *
     * @return the predicate, or {@code null} when the operator is not one of
     *         {@link RsqlSearchOperation}
     * @throws IllegalArgumentException if an argument cannot be converted to the attribute type
     */
    @Override
    public Predicate toPredicate(final Root<T> root, final CriteriaQuery<?> query, final CriteriaBuilder builder) {
        final Path<String> propertyExpression = parseProperty(root);
        final List<Object> args = castArguments(propertyExpression);
        final Object argument = args.get(0);
        final RsqlSearchOperation operation = RsqlSearchOperation.getSimpleOperator(operator);
        if (operation == null) {
            return null;
        }
        switch (operation) {
            case EQUAL:
                return createEqualsPredicate(builder, propertyExpression, argument);
            case NOT_EQUAL:
                return createNotEqualsPredicate(builder, propertyExpression, argument);
            case GREATER_THAN:
            case GREATER_THAN_OR_EQUAL:
            case LESS_THAN:
            case LESS_THAN_OR_EQUAL:
                return createOrderingPredicate(builder, operation, propertyExpression, argument);
            case IN:
                return propertyExpression.in(args);
            case NOT_IN:
                return builder.not(propertyExpression.in(args));
            default:
                return null;
        }
    }

    /**
     * Builds a {@code >}, {@code >=}, {@code <} or {@code <=} predicate using the typed
     * argument rather than its string form, so numeric and temporal attributes are not
     * compared against a String literal.
     */
    @SuppressWarnings({"unchecked", "rawtypes"})
    private Predicate createOrderingPredicate(CriteriaBuilder builder, RsqlSearchOperation operation,
            Path<String> propertyExpression, Object argument) {
        // The static type Path<String> is only a convenience of parseProperty; the runtime
        // attribute type is whatever castArguments converted the argument to.
        final Path<Comparable> comparablePath = (Path<Comparable>) (Path<?>) propertyExpression;
        final Comparable bound;
        if (argument instanceof Comparable) {
            bound = (Comparable) argument;
        } else {
            // Fallback keeps the previous string-based comparison for non-comparable values.
            bound = argument.toString();
        }
        switch (operation) {
            case GREATER_THAN:
                return builder.greaterThan(comparablePath, bound);
            case GREATER_THAN_OR_EQUAL:
                return builder.greaterThanOrEqualTo(comparablePath, bound);
            case LESS_THAN:
                return builder.lessThan(comparablePath, bound);
            case LESS_THAN_OR_EQUAL:
                return builder.lessThanOrEqualTo(comparablePath, bound);
            default:
                throw new IllegalArgumentException("Not an ordering operator: " + operation);
        }
    }

    /** Negated counterpart of {@link #createEqualsPredicate}. */
    private Predicate createNotEqualsPredicate(CriteriaBuilder builder, Path<String> propertyExpression, Object argument) {
        if (argument instanceof String) {
            // '*' in the query acts as the SQL LIKE wildcard '%'.
            return builder.notLike(propertyExpression, argument.toString().replace('*', '%'));
        } else if (argument == null) {
            return builder.isNotNull(propertyExpression);
        } else {
            return builder.notEqual(propertyExpression, argument);
        }
    }

    /**
     * Builds the equality predicate: LIKE for strings (with {@code *} as wildcard),
     * IS NULL for a null argument, plain equality otherwise.
     */
    private Predicate createEqualsPredicate(CriteriaBuilder builder, Path<String> propertyExpression, Object argument) {
        if (argument instanceof String) {
            // '*' in the query acts as the SQL LIKE wildcard '%'.
            return builder.like(propertyExpression, argument.toString().replace('*', '%'));
        } else if (argument == null) {
            return builder.isNull(propertyExpression);
        } else {
            return builder.equal(propertyExpression, argument);
        }
    }

    /**
     * Resolves the selector to a path, walking dot-separated steps from the root.
     * A selector without dots yields a single step, so no special case is needed.
     */
    private Path<String> parseProperty(Root<T> root) {
        final String[] pathSteps = property.split("\\.");
        Path<String> path = root.get(pathSteps[0]);
        for (int i = 1; i < pathSteps.length; i++) {
            path = path.get(pathSteps[i]);
        }
        return path;
    }

    /** Converts every raw argument to the Java type reported by the selected path. */
    private List<Object> castArguments(Path<?> propertyExpression) {
        final Class<?> type = propertyExpression.getJavaType();
        return arguments.stream()
                .map(arg -> castArgument(type, arg))
                .collect(Collectors.toList());
    }

    /**
     * Converts one raw argument to {@code type}. Types without a dedicated conversion
     * are passed through as the original string.
     *
     * @throws IllegalArgumentException if {@code arg} is not a valid value of {@code type}
     */
    private static Object castArgument(Class<?> type, String arg) {
        // Primitive classes are checked too: a path may report int.class rather than Integer.class.
        if (type.equals(Integer.class) || type.equals(int.class)) {
            return Integer.parseInt(arg);
        } else if (type.equals(Long.class) || type.equals(long.class)) {
            return Long.parseLong(arg);
        } else if (type.equals(Boolean.class) || type.equals(boolean.class)) {
            return Boolean.parseBoolean(arg);
        } else if (type.equals(BigDecimal.class)) {
            return new BigDecimal(arg);
        } else if (type.equals(ZonedDateTime.class)) {
            return ZonedDateTime.parse(arg, DateTimeFormatter.ISO_DATE_TIME);
        } else if (type.isEnum()) {
            try {
                return type.getMethod("valueOf", String.class).invoke(null, arg);
            } catch (IllegalAccessException | NoSuchMethodException | InvocationTargetException e) {
                // Keep the exception type callers already see, but preserve context and cause.
                throw new IllegalArgumentException(
                        "Cannot convert '" + arg + "' to enum " + type.getName(), e);
            }
        } else {
            return arg;
        }
    }
}
import cz.jirutka.rsql.parser.ast.ComparisonNode;
import cz.jirutka.rsql.parser.ast.LogicalNode;
import cz.jirutka.rsql.parser.ast.LogicalOperator;
import cz.jirutka.rsql.parser.ast.Node;
import java.util.List;
import java.util.Objects;
import java.util.stream.Collectors;
import org.springframework.data.jpa.domain.Specification;
/**
 * Converts nodes of a parsed RSQL expression tree into JPA {@link Specification}s.
 *
 * @param <T> entity type the produced specifications apply to
 */
public class RsqlSpecificationBuilder<T> {

    /**
     * Dispatches on the runtime node type.
     *
     * @param node any node of the RSQL tree
     * @return the specification for {@code node}, or {@code null} if the node is neither
     *         a {@link LogicalNode} nor a {@link ComparisonNode}
     */
    public Specification<T> createSpecification(final Node node) {
        if (node instanceof LogicalNode) {
            return createSpecification((LogicalNode) node);
        }
        if (node instanceof ComparisonNode) {
            return createSpecification((ComparisonNode) node);
        }
        return null;
    }

    /**
     * Combines the specifications of all children with the node's logical operator.
     *
     * @param logicalNode an AND/OR node
     * @return the combined specification; the first child's specification alone when the
     *         operator is neither AND nor OR; or {@code null} when no child produced a
     *         specification
     */
    public Specification<T> createSpecification(final LogicalNode logicalNode) {
        final List<Specification<T>> specs = logicalNode.getChildren()
                .stream()
                .map(this::createSpecification)
                .filter(Objects::nonNull)
                .collect(Collectors.toList());
        if (specs.isEmpty()) {
            // Nothing convertible below this node: mirror the Node overload instead of
            // failing with IndexOutOfBoundsException on specs.get(0).
            return null;
        }

        final LogicalOperator logicalOperator = logicalNode.getOperator();
        Specification<T> result = specs.get(0);
        for (final Specification<T> next : specs.subList(1, specs.size())) {
            // result is never null here (nulls were filtered out), so it can be combined directly.
            if (logicalOperator == LogicalOperator.AND) {
                result = result.and(next);
            } else if (logicalOperator == LogicalOperator.OR) {
                result = result.or(next);
            }
        }
        return result;
    }

    /**
     * Wraps a single comparison in an {@link RsqlSpecification}.
     *
     * @param comparisonNode comparison supplying selector, operator and arguments
     * @return specification for that comparison
     */
    public Specification<T> createSpecification(final ComparisonNode comparisonNode) {
        return Specification.where(
                new RsqlSpecification<T>(comparisonNode.getSelector(), comparisonNode.getOperator(),
                        comparisonNode.getArguments()));
    }
}
/**
 * Parses an RSQL query, converts it to a specification and returns the content of
 * the requested page of matching {@code Whatever} entities.
 *
 * @param query      RSQL expression to parse
 * @param pageNumber page index passed to {@code PageRequest.of}
 * @param pageSize   page size passed to {@code PageRequest.of}
 * @param orderBy    property name to sort by
 * @param direction  sort direction
 * @return the content of the page returned by the repository
 */
public List<Whatever> searchWhatever(
        String query, Integer pageNumber, Integer pageSize, String orderBy, Direction direction) {
    // Parse and convert first, then build paging, in the same order as before.
    final Specification<Whatever> specification =
            new RSQLParser().parse(query).accept(new CustomRsqlVisitor<Whatever>());
    final Pageable page = PageRequest.of(pageNumber, pageSize, Sort.by(direction, orderBy));
    return repository.findAll(specification, page).getContent();
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment