Skip to content

Instantly share code, notes, and snippets.

@recht
Created December 6, 2011 22:31
Show Gist options
  • Save recht/1440345 to your computer and use it in GitHub Desktop.
Save recht/1440345 to your computer and use it in GitHub Desktop.
ObjectBuilder
package com.tradeshift.commons.groovy;
import groovy.lang.Closure;
import groovy.lang.GroovyClassLoader;
import groovy.util.ConfigObject;
import groovy.util.ConfigSlurper;
import java.beans.PropertyDescriptor;
import java.io.IOException;
import java.io.InputStream;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.ParameterizedType;
import java.lang.reflect.Type;
import java.net.URL;
import java.security.CodeSource;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.commons.beanutils.PropertyUtils;
import org.apache.commons.io.IOUtils;
import org.codehaus.groovy.ast.ASTNode;
import org.codehaus.groovy.ast.ClassNode;
import org.codehaus.groovy.ast.CodeVisitorSupport;
import org.codehaus.groovy.ast.expr.ArrayExpression;
import org.codehaus.groovy.ast.expr.BinaryExpression;
import org.codehaus.groovy.ast.expr.ClassExpression;
import org.codehaus.groovy.ast.expr.ClosureExpression;
import org.codehaus.groovy.ast.expr.ConstantExpression;
import org.codehaus.groovy.ast.expr.Expression;
import org.codehaus.groovy.ast.expr.MapExpression;
import org.codehaus.groovy.ast.expr.MethodCallExpression;
import org.codehaus.groovy.ast.expr.PropertyExpression;
import org.codehaus.groovy.ast.expr.VariableExpression;
import org.codehaus.groovy.ast.stmt.ExpressionStatement;
import org.codehaus.groovy.classgen.GeneratorContext;
import org.codehaus.groovy.control.CompilationFailedException;
import org.codehaus.groovy.control.CompilationUnit;
import org.codehaus.groovy.control.CompilationUnit.PrimaryClassNodeOperation;
import org.codehaus.groovy.control.CompilerConfiguration;
import org.codehaus.groovy.control.MultipleCompilationErrorsException;
import org.codehaus.groovy.control.Phases;
import org.codehaus.groovy.control.SourceUnit;
import org.codehaus.groovy.syntax.Token;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.BeanUtils;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.NoSuchBeanDefinitionException;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.annotation.Scope;
import org.springframework.context.annotation.ScopedProxyMode;
import org.springframework.stereotype.Component;
import org.springframework.util.ClassUtils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.StringUtils;
import com.googlecode.ehcache.annotations.Cacheable;
import com.tradeshift.commons.exception.GenericErrorCode;
import com.tradeshift.commons.exception.ValidationException;
/**
 * Maps a Groovy {@link ConfigSlurper} DSL onto Java objects.
 *
 * <p>The structure of the DSL must match the classes it is mapped onto: every key is written to
 * either a bean property (through its setter) or a field of the same name (private fields included).
 *
 * <p>Advanced usage: sometimes the DSL only contains string references to other objects, which
 * must be looked up. Support this by adding a {@code set<Property>(List, Factory)} method that
 * takes the list of references and a factory of some kind, then pass the factory in as a context
 * parameter of {@link #readConfig(URL, Class, Object...)}.
 *
 * <p>If a property is a closure/block, its typed parameters are looked up in the Spring
 * application context and curried automatically.
 *
 * <p>The {@code String} overloads of {@code readConfig} compile scripts through a class loader that
 * rejects a small set of constructs (see {@link CodeVisitor}) and stores the source position of every
 * {@code name = { ... }} closure assignment in the script variable {@code __linenumbers}.
 * NOTE(review): that check is a blacklist over the script's statement block only and should not be
 * treated as a sandbox for untrusted input.
 *
 * @author recht
 */
@Component
@Scope(proxyMode=ScopedProxyMode.TARGET_CLASS)
public class ObjectBuilder implements ApplicationContextAware {
    private static final Logger log = LoggerFactory.getLogger(ObjectBuilder.class);

    /** Script variable the compiler hook assigns closure positions to; skipped when mapping beans. */
    private static final String LINE_NUMBERS_VARIABLE = "__linenumbers";

    /** Groovy token type of the plain assignment operator {@code =} (Groovy's {@code Types.EQUAL}). */
    private static final int ASSIGNMENT_TOKEN = 100;

    /** Name prefix of the target bean that sits behind a Spring scoped proxy. */
    private static final String SCOPED_TARGET_PREFIX = "scopedTarget.";

    private ApplicationContext applicationContext;
    /** Slurper without any compiler hook; used by {@link #readConfig(URL, Class, Object...)}. */
    private final ConfigSlurper cs = new ConfigSlurper();
    /** Slurper whose class loader runs {@link CodeVisitor} over every script it compiles. */
    private final ConfigSlurper restrictedCs = new ConfigSlurper();
    private AutowireCapableBeanFactory beanFactory;

    /**
     * Creates the builder. The restricted slurper gets a class loader that, at semantic analysis,
     * runs {@link CodeVisitor} over each script's statement block (throwing on forbidden constructs)
     * and appends an assignment of the collected closure positions to {@code __linenumbers}.
     */
    public ObjectBuilder() {
        CompilerConfiguration cfg = new CompilerConfiguration();
        GroovyClassLoader cl = new GroovyClassLoader(Thread.currentThread().getContextClassLoader(), cfg) {
            @Override
            protected CompilationUnit createCompilationUnit(CompilerConfiguration config, CodeSource source) {
                CompilationUnit cu = super.createCompilationUnit(config, source);
                cu.addPhaseOperation(new PrimaryClassNodeOperation() {
                    @Override
                    public void call(SourceUnit source, GeneratorContext context, ClassNode classNode) throws CompilationFailedException {
                        CodeVisitor visitor = new CodeVisitor();
                        // Throws ValidationException on forbidden constructs; otherwise collects positions.
                        source.getAST().getStatementBlock().visit(visitor);
                        source.getAST().getStatementBlock().addStatement(
                                new ExpressionStatement(lineNumbersAssignment(visitor.lineNumbers)));
                    }
                }, Phases.SEMANTIC_ANALYSIS);
                return cu;
            }
        };
        restrictedCs.setClassLoader(cl);
    }

    /**
     * Builds the AST for {@code __linenumbers = [name: Integer[]{line, column, lastLine, lastColumn}, ...]}.
     *
     * @param positions closure variable name to the assignment node that defined it
     * @return an assignment expression ready to be appended to the script body
     */
    private static BinaryExpression lineNumbersAssignment(Map<String, ASTNode> positions) {
        MapExpression map = new MapExpression();
        for (Map.Entry<String, ASTNode> e : positions.entrySet()) {
            ASTNode node = e.getValue();
            List<Expression> coordinates = new ArrayList<Expression>();
            coordinates.add(new ConstantExpression(node.getLineNumber()));
            coordinates.add(new ConstantExpression(node.getColumnNumber()));
            coordinates.add(new ConstantExpression(node.getLastLineNumber()));
            coordinates.add(new ConstantExpression(node.getLastColumnNumber()));
            map.addMapEntryExpression(new ConstantExpression(e.getKey()),
                    new ArrayExpression(new ClassNode(Integer.class), coordinates));
        }
        return new BinaryExpression(new VariableExpression(LINE_NUMBERS_VARIABLE),
                new Token(ASSIGNMENT_TOKEN, null, 1, 1), map);
    }

    /**
     * Parses the UTF-8 DSL at {@code location} and maps each top-level block onto a new {@code target}.
     *
     * @param location where to read the DSL from; {@code null} yields {@code null}
     * @param target   class each top-level block is mapped onto
     * @param context  objects offered to {@code set<Property>(List, ctx)} methods for complex lists
     * @return block name to (autowired, if a bean factory is set) bean, in the parsed config's order
     * @throws RuntimeException if the location cannot be read (with the {@link IOException} as cause),
     *         or the validation/compilation error raised while parsing
     */
    @SuppressWarnings("unchecked")
    public <T> Map<String, T> readConfig(URL location, Class<T> target, Object ... context) {
        if (location == null)
            return null;
        InputStream in = null;
        try {
            in = location.openStream();
            ConfigObject cfgs = cs.parse(IOUtils.toString(in, "UTF-8"));
            Map<String, T> os = new LinkedHashMap<String, T>();
            for (Map.Entry<?, ConfigObject> e : (Set<Map.Entry<Object, ConfigObject>>) cfgs.entrySet()) {
                String name = e.getKey().toString();
                os.put(name, autowire(writeBean(target, e.getValue(), name, context)));
            }
            return os;
        } catch (MultipleCompilationErrorsException e) {
            throw handleCompilationError(e);
        } catch (IOException e) {
            // Keep the location and the cause; the old bare RuntimeException lost both.
            throw new RuntimeException("Could not read config: " + location, e);
        } finally {
            IOUtils.closeQuietly(in);
        }
    }

    /**
     * Parses {@code config} with the restricted slurper and maps the whole script onto {@code target}.
     * Results are cached in {@code scriptCache}.
     *
     * @param config DSL source; {@code null} yields {@code null}
     * @param target class the script is mapped onto
     * @return the (autowired, if a bean factory is set) bean
     */
    @Cacheable(cacheName="scriptCache")
    public <T> T readConfig(String config, Class<T> target) {
        if (config == null) return null;
        try {
            return autowire(writeBean(target, restrictedCs.parse(config), null, new Object[0]));
        } catch (MultipleCompilationErrorsException e) {
            throw handleCompilationError(e);
        }
    }

    /**
     * Parses {@code config} with the restricted slurper without mapping it onto a bean.
     *
     * @param config DSL source; {@code null} yields {@code null}, as in the other overloads
     * @return the parsed config
     */
    public ConfigObject readConfig(String config) {
        if (config == null) return null;
        try {
            return restrictedCs.parse(config);
        } catch (MultipleCompilationErrorsException e) {
            throw handleCompilationError(e);
        }
    }

    /**
     * Returns the first {@link ValidationException} recorded by the compiler (for example one thrown
     * by {@link CodeVisitor}), so callers see the validation error instead of Groovy's wrapper;
     * otherwise returns {@code e} itself.
     */
    private RuntimeException handleCompilationError(MultipleCompilationErrorsException e) {
        int c = e.getErrorCollector().getErrorCount();
        for (int i = 0; i < c; i++) {
            Exception root = e.getErrorCollector().getException(i);
            if (root instanceof ValidationException) {
                return (ValidationException)root;
            }
        }
        return e;
    }

    /** Autowires {@code bean} through the Spring bean factory, if one has been injected. */
    private <T> T autowire(T bean) {
        if (beanFactory != null) {
            beanFactory.autowireBean(bean);
        }
        return bean;
    }

    /**
     * Instantiates {@code type} and populates it from {@code obj}, recursing into nested blocks.
     *
     * @param type    class to instantiate; a public {@code (String)} constructor receives {@code name}
     * @param obj     parsed DSL block whose keys name properties or fields of {@code type}
     * @param name    key under which the block appeared, or {@code null} for a whole script
     * @param context objects offered to {@code set<Property>(List, ctx)} methods for complex lists
     * @return the populated instance (not yet autowired)
     * @throws IllegalArgumentException if a key has no matching property/field or a list cannot be mapped
     */
    @SuppressWarnings("rawtypes")
    private <T> T writeBean(Class<T> type, ConfigObject obj, String name, Object[] context) {
        T bean = instantiate(type, name);
        for (Object rawKey : obj.keySet()) {
            String key = rawKey.toString();
            if (LINE_NUMBERS_VARIABLE.equals(key)) continue;
            Class<?> propertyType = getPropertyType(type, key);
            Object value = obj.get(rawKey);
            if (value == null) {
                // Previously an NPE in every non-simple branch; an explicit null is simply written through.
                setProperty(bean, key, null);
            } else if (propertyType.isAssignableFrom(Closure.class)) {
                setProperty(bean, key, curryFromContext((Closure) value));
            } else if (BeanUtils.isSimpleProperty(propertyType)) {
                setProperty(bean, key, value);
            } else if (propertyType.isAssignableFrom(List.class)) {
                writeListProperty(bean, type, key, value, context);
            } else if (propertyType.isAssignableFrom(Map.class) || propertyType.isAssignableFrom(HashMap.class)) {
                Map<Object, Object> values = new LinkedHashMap<Object, Object>();
                values.putAll((Map<?, ?>) value);
                setProperty(bean, key, values);
            } else if (propertyType.isAssignableFrom(value.getClass())) {
                setProperty(bean, key, value);
            } else {
                Constructor<?> ctor = ClassUtils.getConstructorIfAvailable(propertyType, String.class);
                if (ctor != null) {
                    Object converted;
                    try {
                        converted = ctor.newInstance(value);
                    } catch (Exception e) {
                        // Best effort as before (the property stays unset), but no longer silent.
                        log.warn("Could not convert value for " + key + " on " + type + " to " + propertyType, e);
                        continue;
                    }
                    setProperty(bean, key, converted);
                } else {
                    setProperty(bean, key, writeBean(propertyType, (ConfigObject) value, key, context));
                }
            }
        }
        return bean;
    }

    /** Instantiates {@code type}, preferring a public {@code (String name)} constructor. */
    private static <T> T instantiate(Class<T> type, String name) {
        try {
            return BeanUtils.instantiateClass(type.getConstructor(String.class), name);
        } catch (NoSuchMethodException e) {
            return BeanUtils.instantiate(type);
        }
    }

    /**
     * Curries each typed (non-{@code Object}) parameter of {@code closure} with the bean the
     * application context resolves for that type. Parameters whose bean is missing or ambiguous are
     * left open. Without an application context the closure is returned unchanged.
     */
    @SuppressWarnings("rawtypes")
    private Closure curryFromContext(Closure closure) {
        if (applicationContext == null) {
            return closure;
        }
        Closure result = closure;
        Class[] types = closure.getParameterTypes();
        int curried = 0;
        for (int i = 0; i < types.length; i++) {
            if (types[i] == Object.class) continue;
            try {
                String[] names = applicationContext.getBeanNamesForType(types[i]);
                String beanName = selectBeanName(names);
                if (beanName == null) {
                    // Fixed: the old code tested the block name here and called getBean(null).
                    if (names.length > 0) {
                        log.debug("Unable to curry {} at position {}", new Object[] { types[i], i });
                    }
                    continue;
                }
                Object b = applicationContext.getBean(beanName);
                // Every earlier curry removed one parameter, so shift the position left accordingly.
                result = result.ncurry(i - curried, new Object[] { b });
                log.debug("Currying {} to {} at position {}", new Object[] { types[i], b, i });
                curried++;
            } catch (NoSuchBeanDefinitionException e) {
                log.debug("Unable to curry {} at position {}", new Object[] { types[i], i });
            }
        }
        return result;
    }

    /**
     * Picks the bean to curry from the names registered for a type: a single name is used as is; of
     * two names, a name whose partner carries the {@code scopedTarget.} prefix is preferred (the first
     * one when both do). Any other combination yields {@code null}.
     */
    private static String selectBeanName(String[] names) {
        if (names.length == 1) {
            return names[0];
        }
        if (names.length == 2) {
            if (names[1].startsWith(SCOPED_TARGET_PREFIX)) return names[0];
            if (names[0].startsWith(SCOPED_TARGET_PREFIX)) return names[1];
        }
        return null;
    }

    /**
     * Writes a {@code List}-typed property. Lists of simple values are set directly; non-empty lists
     * of complex values are handed to a {@code set<Key>(List, ctx)} method matching one of the context
     * objects; a nested block is mapped entry by entry onto the list's element type.
     */
    @SuppressWarnings("unchecked")
    private void writeListProperty(Object bean, Class<?> type, String key, Object value, Object[] context) {
        Class<?> elementType = getListType(type, key);
        if (value instanceof List<?>) {
            List<?> values = (List<?>) value;
            if (BeanUtils.isSimpleProperty(elementType)) {
                setProperty(bean, key, values);
            } else if (!values.isEmpty() && !invokeContextListSetter(bean, type, key, values, context)) {
                throw new IllegalArgumentException("Complex lists not supported for " + key + " on " + type);
            }
        } else if (value instanceof ConfigObject) {
            List<Object> os = new ArrayList<Object>();
            for (Map.Entry<?, ConfigObject> e : (Set<Map.Entry<Object, ConfigObject>>) ((ConfigObject) value).entrySet()) {
                os.add(writeBean(elementType, e.getValue(), e.getKey().toString(), context));
            }
            setProperty(bean, key, os);
        } else {
            throw new IllegalArgumentException("Value is not a list: " + value + ", class: " + value.getClass());
        }
    }

    /**
     * Invokes {@code set<Key>(List, ctx)} on {@code bean} for the first non-null context object whose
     * exact class matches such a method.
     *
     * @return {@code true} if a method was found and invoked
     */
    private static boolean invokeContextListSetter(Object bean, Class<?> type, String key, List<?> values, Object[] context) {
        if (context == null) {
            return false;
        }
        String methodName = "set" + StringUtils.capitalize(key);
        for (Object ctx : context) {
            if (ctx == null) continue;
            Method m = BeanUtils.findMethod(type, methodName, List.class, ctx.getClass());
            if (m != null) {
                ReflectionUtils.invokeMethod(m, bean, values, ctx);
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the declared type of property {@code name}, falling back to a field of that name.
     *
     * @throws IllegalArgumentException if {@code type} has neither
     */
    private static Class<?> getPropertyType(Class<?> type, String name) {
        PropertyDescriptor pd = BeanUtils.getPropertyDescriptor(type, name);
        if (pd != null) {
            return pd.getPropertyType();
        }
        Field field = ReflectionUtils.findField(type, name);
        if (field != null) {
            return field.getType();
        }
        throw new IllegalArgumentException("No field " + name + " in " + type);
    }

    /**
     * Sets {@code name} on {@code bean} through its setter or, if there is none, by writing the field
     * directly (private fields included).
     *
     * @throws IllegalArgumentException if there is neither a setter nor a field named {@code name}
     * @throws RuntimeException         if the setter rejects the value
     */
    private static void setProperty(Object bean, String name, Object value) {
        try {
            PropertyUtils.setProperty(bean, name, value);
        } catch (NoSuchMethodException e) {
            Field field = ReflectionUtils.findField(bean.getClass(), name);
            if (field == null) {
                throw new IllegalArgumentException("No field " + name + " in " + bean.getClass(), e);
            }
            // makeAccessible only calls setAccessible when the field is not already accessible.
            ReflectionUtils.makeAccessible(field);
            ReflectionUtils.setField(field, bean, value);
        } catch (Exception e) {
            throw new RuntimeException("Tried setting " + name + " with type " + (value != null ? value.getClass() : null), e);
        }
    }

    /**
     * Returns the element class of the {@code List} property (setter parameter) or field {@code name}.
     *
     * @throws IllegalArgumentException if there is no such member or its element type is not a plain
     *         class (raw list, wildcard, type variable, ...)
     */
    private static Class<?> getListType(Class<?> type, String name) {
        PropertyDescriptor pd = BeanUtils.getPropertyDescriptor(type, name);
        Type generic;
        if (pd != null && pd.getWriteMethod() != null) {
            generic = pd.getWriteMethod().getGenericParameterTypes()[0];
        } else {
            Field field = ReflectionUtils.findField(type, name);
            if (field == null) {
                throw new IllegalArgumentException("No field " + name + " in " + type);
            }
            generic = field.getGenericType();
        }
        if (generic instanceof ParameterizedType) {
            Type argument = ((ParameterizedType) generic).getActualTypeArguments()[0];
            if (argument instanceof Class<?>) {
                return (Class<?>) argument;
            }
        }
        throw new IllegalArgumentException("Cannot determine list element type of " + name + " in " + type + ": " + generic);
    }

    @Override
    public void setApplicationContext(ApplicationContext applicationContext) throws BeansException {
        this.applicationContext = applicationContext;
        beanFactory = applicationContext.getAutowireCapableBeanFactory();
    }

    /**
     * AST visitor run over the statement block of every script compiled by the restricted slurper.
     * It throws {@link ValidationException} on class literals other than enums, on calls to
     * {@code getClass}, {@code getMetaClass}, {@code invokeMethod}, {@code sleep} and {@code execute},
     * and on {@code .class} / {@code .metaClass} property access. It also records every
     * {@code variable = { closure }} assignment so its position can be exposed to the script.
     *
     * <p>NOTE(review): this is a blacklist, not a sandbox. Only the statement block is visited, so
     * bodies of methods or classes declared in the script are not checked. Constructor calls (e.g.
     * {@code new ProcessBuilder(...)}) and calls whose method name is computed at runtime are not
     * inspected. {@code evaluate(...)} presumably compiles code without this hook. Consider a whitelist
     * (e.g. Groovy's SecureASTCustomizer) before accepting untrusted scripts.
     */
    private static class CodeVisitor extends CodeVisitorSupport {
        /** Method names a restricted script may not call; the error text is "Do not call NAME()". */
        private static final String[] FORBIDDEN_METHODS = {
            "getClass", "getMetaClass", "invokeMethod", "sleep", "execute"
        };

        /** Closure variable name to the assignment expression that defined it. */
        private final Map<String, ASTNode> lineNumbers = new HashMap<String, ASTNode>();

        @Override
        public void visitClassExpression(ClassExpression expression) {
            if (!expression.getType().isEnum()) {
                log.warn("Script visits class {}", expression);
                // Report the referenced class; getClass() here always named the AST node type.
                throw new ValidationException(GenericErrorCode.INVALID_INPUT.subError(expression.getType().getName()));
            }
            super.visitClassExpression(expression);
        }

        @Override
        public void visitMethodCallExpression(MethodCallExpression call) {
            String method = call.getMethodAsString();
            for (String forbidden : FORBIDDEN_METHODS) {
                if (forbidden.equals(method)) {
                    throw new ValidationException(GenericErrorCode.INVALID_INPUT.subError("Do not call " + forbidden + "()"));
                }
            }
            super.visitMethodCallExpression(call);
        }

        @Override
        public void visitPropertyExpression(PropertyExpression expression) {
            String property = expression.getPropertyAsString();
            if ("class".equals(property)) {
                throw new ValidationException(GenericErrorCode.INVALID_INPUT.subError("Do not call getClass()"));
            }
            if ("metaClass".equals(property)) {
                throw new ValidationException(GenericErrorCode.INVALID_INPUT.subError("Do not call getMetaClass()"));
            }
            super.visitPropertyExpression(expression);
        }

        @Override
        public void visitBinaryExpression(BinaryExpression expression) {
            // Remember plain "name = { ... }" assignments so their positions can be reported.
            if (expression.getLeftExpression() instanceof VariableExpression
                    && expression.getRightExpression() instanceof ClosureExpression
                    && expression.getOperation().getMeaning() == ASSIGNMENT_TOKEN) {
                VariableExpression var = (VariableExpression) expression.getLeftExpression();
                lineNumbers.put(var.getName(), expression);
            }
            super.visitBinaryExpression(expression);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment