Last active
December 12, 2015 12:09
-
-
Save digulla/4770445 to your computer and use it in GitHub Desktop.
Helper class to reset/clean up Spring beans after a unit test
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import java.lang.annotation.Annotation; | |
import java.lang.reflect.Field; | |
import java.lang.reflect.Method; | |
import java.lang.reflect.Modifier; | |
import java.util.List; | |
import java.util.Map; | |
import org.junit.After; | |
import org.junit.Rule; | |
import org.junit.runner.RunWith; | |
import org.springframework.aop.TargetClassAware; | |
import org.springframework.beans.BeansException; | |
import org.springframework.beans.factory.config.AutowireCapableBeanFactory; | |
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; | |
import org.springframework.context.ApplicationContext; | |
import org.springframework.context.ApplicationContextAware; | |
import org.springframework.context.ConfigurableApplicationContext; | |
import org.springframework.core.annotation.AnnotationUtils; | |
import org.springframework.test.context.TestExecutionListeners; | |
import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; | |
import org.springframework.test.context.support.DependencyInjectionTestExecutionListener; | |
import com.avanon.exceptions.ShouldNotHappenException; | |
import com.avanon.utils.JavaUtils; | |
import com.google.common.collect.Lists; | |
import com.google.common.collect.Maps; | |
@RunWith( SpringJUnit4ClassRunner.class ) | |
@TestExecutionListeners( { | |
DependencyInjectionTestExecutionListener.class | |
} ) | |
public abstract class ApplicationContextAwareTestBase implements ApplicationContextAware { | |
protected ApplicationContext applicationContext; | |
@Override | |
public void setApplicationContext( ApplicationContext applicationContext ) throws BeansException { | |
this.applicationContext = applicationContext; | |
} | |
public void autowire( Object bean ) { | |
AutowireCapableBeanFactory factory = applicationContext.getAutowireCapableBeanFactory(); | |
factory.autowireBean( bean ); | |
factory.initializeBean( bean, "bean" ); | |
} | |
@After | |
public void afterTest() { | |
resetBeans(); | |
setLocalFieldsToNull(); | |
} | |
private void setLocalFieldsToNull() { | |
// System.out.println(); | |
// System.out.println(); | |
// System.out.println( "setLocalFieldsToNull():" ); | |
Class<?> current = getClass(); | |
while( null != current ) { | |
setLocalFieldsToNull( current ); | |
current = current.getSuperclass(); | |
} | |
} | |
private void setLocalFieldsToNull( Class<?> current ) { | |
for( Field field : current.getDeclaredFields() ) { | |
if( field.getType().isPrimitive() ) { | |
continue; | |
} | |
if( null != field.getAnnotation( Rule.class ) ) { | |
continue; | |
} | |
if( Modifier.isStatic( field.getModifiers() ) ) { | |
continue; | |
} | |
// System.out.println( field ); | |
JavaUtils.setFieldValue( this, field, null ); | |
} | |
} | |
private void resetBeans() { | |
if( ! ( applicationContext instanceof ConfigurableApplicationContext ) ) { | |
return; | |
} | |
// System.out.println(); | |
// System.out.println(); | |
// System.out.println( "resetBeans():" ); | |
ConfigurableListableBeanFactory factory = ( (ConfigurableApplicationContext) applicationContext ).getBeanFactory(); | |
for( String name : applicationContext.getBeanDefinitionNames() ) { | |
// System.out.println( name ); | |
Object bean = factory.getSingleton( name ); | |
if( null == bean ) { | |
continue; | |
} | |
Method[] methods = methodsToCall( bean ); | |
for( Method method : methods ) { | |
// System.out.println( method ); | |
try { | |
method.invoke( bean ); | |
} catch( Exception e ) { | |
throw new ShouldNotHappenException( "Error invoking " + method + " on " + JavaUtils.instOrTypeToString( bean ), e ); | |
} | |
} | |
} | |
} | |
private static Map<Class<?>, Method[]> methodCache = Maps.newHashMap(); | |
private Method[] methodsToCall( Object bean ) { | |
Method[] result = methodCache.get( bean.getClass() ); | |
if( null != result ) { | |
return result; | |
} | |
List<Method> list = Lists.newArrayList(); | |
Class<?> type = ( bean instanceof TargetClassAware ) | |
? ( (TargetClassAware) bean ).getTargetClass() | |
: bean.getClass(); | |
methodsToCall( type, list, After.class ); | |
result = new Method[ list.size() ]; | |
list.toArray( result ); | |
return result; | |
} | |
protected void methodsToCall( Class<?> type, List<Method> list, Class<? extends Annotation> annType ) { | |
// System.out.println( type ); | |
for( Method method : type.getDeclaredMethods() ) { | |
// System.out.println( method ); | |
if( null == AnnotationUtils.findAnnotation( method, annType ) ) { | |
continue; | |
} | |
if( ! Void.TYPE.equals( method.getReturnType() ) ) { | |
throw new IllegalArgumentException( "@After method must return void: " + method ); | |
} | |
if( 0 != method.getParameterTypes().length ) { | |
throw new IllegalArgumentException( "@After method must not have any parameters: " + method ); | |
} | |
list.add( method ); | |
} | |
} | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment