Skip to content

Instantly share code, notes, and snippets.

@digulla
Last active December 12, 2015 12:09
Show Gist options
Save digulla/4770445 to your computer and use it in GitHub Desktop.
Helper class to reset/clean up Spring beans after a unit test
import java.lang.annotation.Annotation;
import java.lang.reflect.Field;
import java.lang.reflect.Method;
import java.lang.reflect.Modifier;
import java.util.List;
import java.util.Map;
import org.junit.After;
import org.junit.Rule;
import org.junit.runner.RunWith;
import org.springframework.aop.TargetClassAware;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.config.AutowireCapableBeanFactory;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.core.annotation.AnnotationUtils;
import org.springframework.test.context.TestExecutionListeners;
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner;
import org.springframework.test.context.support.DependencyInjectionTestExecutionListener;
import com.avanon.exceptions.ShouldNotHappenException;
import com.avanon.utils.JavaUtils;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
/**
 * Base class for Spring-driven JUnit 4 tests that cleans up after every test method.
 * <p>
 * After each test ({@link #afterTest()}):
 * <ol>
 * <li>every singleton bean that already exists in the context gets each of its no-arg,
 * {@code void} methods annotated with JUnit's {@link After} invoked, so beans can reset
 * their state; and</li>
 * <li>every non-static, non-primitive field of the test instance (including fields declared
 * in superclasses, e.g. {@link #applicationContext}) is set to {@code null}, except fields
 * annotated with {@link Rule}.</li>
 * </ol>
 * Only methods declared directly on the bean's class are considered; {@code @After} methods
 * inherited from a superclass of the bean are not invoked.
 */
@RunWith( SpringJUnit4ClassRunner.class )
@TestExecutionListeners( {
    DependencyInjectionTestExecutionListener.class
} )
public abstract class ApplicationContextAwareTestBase implements ApplicationContextAware {

    /**
     * Reset methods per bean type. The result depends only on the type, so it can be shared
     * across tests. This is a concurrent map because it is static and filled lazily, so
     * parallel test runs may write to it at the same time.
     */
    private static final Map<Class<?>, Method[]> methodCache = Maps.newConcurrentMap();

    /** The context injected by Spring; cleared to {@code null} after every test. */
    protected ApplicationContext applicationContext;

    @Override
    public void setApplicationContext( ApplicationContext applicationContext ) throws BeansException {
        this.applicationContext = applicationContext;
    }

    /**
     * Autowires {@code bean} with the context's {@link AutowireCapableBeanFactory}. It then
     * runs {@code initializeBean} on it under the fixed bean name {@code "bean"}.
     *
     * @param bean an object that is not managed by Spring and should receive dependencies
     */
    public void autowire( Object bean ) {
        AutowireCapableBeanFactory factory = applicationContext.getAutowireCapableBeanFactory();
        factory.autowireBean( bean );
        factory.initializeBean( bean, "bean" );
    }

    /** Resets the context's singleton beans, then clears the fields of this test instance. */
    @After
    public void afterTest() {
        resetBeans();
        setLocalFieldsToNull();
    }

    /** Clears the fields declared in this test's class and in every superclass. */
    private void setLocalFieldsToNull() {
        for( Class<?> current = getClass(); null != current; current = current.getSuperclass() ) {
            setLocalFieldsToNull( current );
        }
    }

    /**
     * Sets every field declared directly in {@code current} to {@code null}. Primitive fields,
     * static fields and {@link Rule} fields are skipped.
     */
    private void setLocalFieldsToNull( Class<?> current ) {
        for( Field field : current.getDeclaredFields() ) {
            if( field.getType().isPrimitive()
                    || null != field.getAnnotation( Rule.class )
                    || Modifier.isStatic( field.getModifiers() ) ) {
                continue;
            }
            // NOTE(review): assumes JavaUtils.setFieldValue makes non-public fields accessible.
            JavaUtils.setFieldValue( this, field, null );
        }
    }

    /**
     * Invokes the {@code @After} methods of every singleton bean already present in the
     * context. Does nothing if the context is not a {@link ConfigurableApplicationContext}.
     *
     * @throws ShouldNotHappenException if one of the methods cannot be invoked or throws
     */
    private void resetBeans() {
        if( ! ( applicationContext instanceof ConfigurableApplicationContext ) ) {
            return;
        }
        ConfigurableListableBeanFactory factory = ( (ConfigurableApplicationContext) applicationContext ).getBeanFactory();
        for( String name : applicationContext.getBeanDefinitionNames() ) {
            // getSingleton() yields null for names without a singleton instance (presumably
            // beans not created yet or non-singleton scopes); those beans are skipped.
            Object bean = factory.getSingleton( name );
            if( null == bean ) {
                continue;
            }
            for( Method method : methodsToCall( bean ) ) {
                try {
                    method.invoke( bean );
                } catch( Exception e ) {
                    throw new ShouldNotHappenException( "Error invoking " + method + " on " + JavaUtils.instOrTypeToString( bean ), e );
                }
            }
        }
    }

    /**
     * Returns the {@code @After} methods to invoke on {@code bean}, caching the result per
     * resolved type.
     */
    private Method[] methodsToCall( Object bean ) {
        Class<?> type = resolveType( bean );
        // Keyed by the resolved type, not bean.getClass(): JDK proxies for different targets
        // can share one proxy class.
        Method[] result = methodCache.get( type );
        if( null != result ) {
            return result;
        }
        List<Method> list = Lists.newArrayList();
        methodsToCall( type, list, After.class );
        result = list.toArray( new Method[ list.size() ] );
        methodCache.put( type, result );
        return result;
    }

    /**
     * Returns the class whose declared methods should be scanned. For AOP proxies this is the
     * target class; otherwise, or when the target class is unknown, it is the bean's class.
     */
    private static Class<?> resolveType( Object bean ) {
        if( bean instanceof TargetClassAware ) {
            Class<?> target = ( (TargetClassAware) bean ).getTargetClass();
            if( null != target ) {
                return target;
            }
            // getTargetClass() may return null when the target class is not known.
        }
        // NOTE(review): for JDK dynamic proxies the proxy is not an instance of the target
        // class, so Method.invoke() on the proxy would fail. Confirm beans use class proxies.
        return bean.getClass();
    }

    /**
     * Appends to {@code list} every method declared directly in {@code type} that carries
     * {@code annType} (as found by {@link AnnotationUtils#findAnnotation}). Each added method
     * is made accessible so that non-public ones can be invoked.
     *
     * @param type    the class whose declared methods are scanned (superclasses are not)
     * @param list    receives the matching methods
     * @param annType the marker annotation to look for
     * @throws IllegalArgumentException if an annotated method does not return {@code void}
     *                                  or declares parameters
     */
    protected void methodsToCall( Class<?> type, List<Method> list, Class<? extends Annotation> annType ) {
        String annName = "@" + annType.getSimpleName();
        for( Method method : type.getDeclaredMethods() ) {
            if( null == AnnotationUtils.findAnnotation( method, annType ) ) {
                continue;
            }
            if( ! Void.TYPE.equals( method.getReturnType() ) ) {
                throw new IllegalArgumentException( annName + " method must return void: " + method );
            }
            if( 0 != method.getParameterTypes().length ) {
                throw new IllegalArgumentException( annName + " method must not have any parameters: " + method );
            }
            // getDeclaredMethods() includes private/package-private methods; without this,
            // Method.invoke() would reject them with IllegalAccessException.
            method.setAccessible( true );
            list.add( method );
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment