Skip to content

Instantly share code, notes, and snippets.

@vrajat
Created April 15, 2015 05:58
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save vrajat/3314181737ecf28ef443 to your computer and use it in GitHub Desktop.
Save vrajat/3314181737ecf28ef443 to your computer and use it in GitHub Desktop.
A rule to help test DAOs using Hibernate.
package com.qubole.nezha;
import com.codahale.metrics.MetricRegistry;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Sets;
import com.google.common.io.Resources;
import com.qubole.nezha.NezhaConfiguration;
import io.dropwizard.configuration.ConfigurationException;
import io.dropwizard.configuration.ConfigurationFactory;
import io.dropwizard.db.DataSourceFactory;
import io.dropwizard.db.ManagedDataSource;
import io.dropwizard.flyway.FlywayFactory;
import io.dropwizard.jackson.Jackson;
import java.io.File;
import java.io.IOException;
import java.net.URISyntaxException;
import java.util.List;
import java.util.Map;
import java.util.SortedSet;
import javax.sql.DataSource;
import javax.validation.Validation;
import javax.validation.Validator;
import org.flywaydb.core.Flyway;
import org.hibernate.Session;
import org.hibernate.SessionFactory;
import org.hibernate.boot.registry.StandardServiceRegistryBuilder;
import org.hibernate.cfg.AvailableSettings;
import org.hibernate.cfg.Configuration;
import org.hibernate.context.internal.ManagedSessionContext;
import org.hibernate.engine.jdbc.connections.internal.DatasourceConnectionProviderImpl;
import org.hibernate.engine.jdbc.connections.spi.ConnectionProvider;
import org.hibernate.service.ServiceRegistry;
import org.junit.rules.TestRule;
import org.junit.runner.Description;
import org.junit.runners.model.Statement;
/*
Inspired by https://gist.github.com/Sch3lp/9185192
*/
/**
 * JUnit {@link TestRule} that prepares a Hibernate {@link SessionFactory} for DAO tests.
 *
 * <p>Before the wrapped statement runs, the rule loads a {@link NezhaConfiguration} from a
 * classpath resource, builds a data source and a session factory from it, binds a freshly opened
 * session to {@link ManagedSessionContext}, and runs the configured Flyway migrations. After the
 * statement finishes (or if the setup itself fails), everything that was opened is released again.
 */
public class DropwizardHibernateRule implements TestRule {
    private static final ObjectMapper MAPPER = Jackson.newObjectMapper();

    private final Validator validator = Validation.buildDefaultValidatorFactory().getValidator();
    private final ImmutableList<Class<?>> entities;
    private final String configPath;

    private NezhaConfiguration nezhaConfiguration;
    private SessionFactory sessionFactory;
    // Kept so that after() can stop the pool that before() built.
    private ManagedDataSource dataSource;

    /**
     * Creates a rule for the given configuration resource and entity classes.
     *
     * @param configPath classpath location of the configuration file, resolved with
     *                   {@link Resources#getResource(String)}
     * @param entities   annotated entity classes to register with Hibernate
     * @return a new, not yet started rule
     */
    public static DropwizardHibernateRule create(String configPath,
                                                 ImmutableList<Class<?>> entities) {
        return new DropwizardHibernateRule(configPath, entities);
    }

    private DropwizardHibernateRule(String configPath, ImmutableList<Class<?>> entities) {
        this.configPath = configPath;
        this.entities = entities;
    }

    /**
     * Returns the session factory built by {@link #before()}.
     *
     * @return the active session factory, or {@code null} when the rule is not currently running
     */
    public SessionFactory getSessionFactory() {
        return sessionFactory;
    }

    /**
     * Wraps {@code base} so that {@link #before()} runs first and {@link #after()} always runs
     * last, including when the setup itself fails part-way through.
     */
    @Override
    public Statement apply(final Statement base, Description description) {
        return new Statement() {
            @Override
            public void evaluate() throws Throwable {
                try {
                    before();
                    base.evaluate();
                } catch (Throwable failure) {
                    // Release whatever was acquired, but keep the original failure as the
                    // primary error instead of letting a cleanup problem hide it.
                    try {
                        after();
                    } catch (RuntimeException cleanupFailure) {
                        failure.addSuppressed(cleanupFailure);
                    }
                    throw failure;
                }
                after();
            }
        };
    }

    /**
     * Loads the configuration, builds the data source and session factory, binds an open session
     * to {@link ManagedSessionContext} and runs the Flyway migrations.
     *
     * @throws IOException            if the configuration file cannot be read
     * @throws ConfigurationException if the configuration file is invalid
     * @throws ClassNotFoundException declared by the data source factory
     * @throws URISyntaxException     if the configuration resource URL is not a valid URI
     */
    protected void before() throws
            IOException, ConfigurationException, ClassNotFoundException, URISyntaxException {
        final ConfigurationFactory<NezhaConfiguration> factory = new ConfigurationFactory<>(
                NezhaConfiguration.class, validator, MAPPER, "");
        nezhaConfiguration = factory.build(new File(Resources.getResource(configPath).toURI()));

        final DataSourceFactory dataSourceFactory = nezhaConfiguration.getDataSourceFactory();
        dataSource = dataSourceFactory.build(new MetricRegistry(), "Rule");
        final ConnectionProvider provider = buildConnectionProvider(dataSource,
                dataSourceFactory.getProperties());
        // NOTE(review): the configured data source properties are handed to the connection
        // provider only; the Hibernate configuration itself receives an empty map. Confirm this
        // is intentional (e.g. to keep schema handling with Flyway).
        sessionFactory = buildSessionFactory(dataSourceFactory, provider,
                ImmutableMap.<String, String>of(), entities);

        // Open a session and make it the "current" one; no transaction is started here.
        ManagedSessionContext.bind(sessionFactory.openSession());

        final FlywayFactory flywayFactory = nezhaConfiguration.getFlywayFactory();
        final String[] fwLocations = new String[flywayFactory.getLocations().size()];
        final Flyway flyway = new Flyway();
        flyway.setDataSource(dataSourceFactory.getUrl(),
                dataSourceFactory.getUser(),
                dataSourceFactory.getPassword());
        flyway.setLocations(flywayFactory.getLocations().toArray(fwLocations));
        flyway.migrate();
    }

    /**
     * Releases everything {@link #before()} acquired: the bound session, the session factory and
     * the data source. Safe to call when the setup never ran or only partially completed.
     *
     * @throws IllegalStateException if the data source cannot be stopped
     */
    protected void after() {
        final SessionFactory factory = sessionFactory;
        final ManagedDataSource pool = dataSource;
        // Clear the fields first so a second call cannot close the same resources twice.
        sessionFactory = null;
        dataSource = null;
        try {
            if (factory != null) {
                closeSessionFactory(factory);
            }
        } finally {
            if (pool != null) {
                stopDataSource(pool);
            }
        }
    }

    /**
     * Unbinds and closes the session bound for {@code factory}, then closes the factory itself.
     */
    private static void closeSessionFactory(SessionFactory factory) {
        try {
            final Session session = ManagedSessionContext.unbind(factory);
            // The test may already have closed the session itself.
            if (session != null && session.isOpen()) {
                session.close();
            }
        } finally {
            factory.close();
        }
    }

    /**
     * Stops the data source, translating the checked exception so {@link #after()} keeps its
     * original signature.
     */
    private static void stopDataSource(ManagedDataSource pool) {
        try {
            pool.stop();
        } catch (Exception e) {
            throw new IllegalStateException("Unable to stop the data source used by the test", e);
        }
    }

    /**
     * From io.dropwizard.hibernate.SessionFactoryFactory.
     *
     * <p>Wraps {@code dataSource} in a Hibernate connection provider configured with
     * {@code properties}.
     */
    private ConnectionProvider buildConnectionProvider(DataSource dataSource,
                                                       Map<String, String> properties) {
        final DatasourceConnectionProviderImpl connectionProvider =
                new DatasourceConnectionProviderImpl();
        connectionProvider.setDataSource(dataSource);
        connectionProvider.configure(properties);
        return connectionProvider;
    }

    /**
     * From com.yammer.dropwizard.hibernate.SessionFactoryFactory.
     *
     * <p>Builds a session factory that uses the "managed" current-session context, the given
     * connection provider and the given entity classes. Entries in {@code properties} are applied
     * after the defaults, so they override them.
     */
    private SessionFactory buildSessionFactory(DataSourceFactory dbConfig,
                                               ConnectionProvider connectionProvider,
                                               ImmutableMap<String, String> properties,
                                               List<Class<?>> entities) {
        final Configuration configuration = new Configuration();
        configuration.setProperty(AvailableSettings.CURRENT_SESSION_CONTEXT_CLASS, "managed");
        configuration.setProperty(AvailableSettings.USE_SQL_COMMENTS,
                Boolean.toString(dbConfig.isAutoCommentsEnabled()));
        configuration.setProperty(AvailableSettings.USE_GET_GENERATED_KEYS, "true");
        configuration.setProperty(AvailableSettings.GENERATE_STATISTICS, "true");
        configuration.setProperty(AvailableSettings.USE_REFLECTION_OPTIMIZER, "true");
        configuration.setProperty(AvailableSettings.ORDER_UPDATES, "true");
        configuration.setProperty(AvailableSettings.ORDER_INSERTS, "true");
        configuration.setProperty(AvailableSettings.USE_NEW_ID_GENERATOR_MAPPINGS, "true");
        configuration.setProperty("jadira.usertype.autoRegisterUserTypes", "true");
        for (Map.Entry<String, String> property : properties.entrySet()) {
            configuration.setProperty(property.getKey(), property.getValue());
        }

        addAnnotatedClasses(configuration, entities);

        final ServiceRegistry registry = new StandardServiceRegistryBuilder()
                .addService(ConnectionProvider.class, connectionProvider)
                .applySettings(properties)
                .build();
        return configuration.buildSessionFactory(registry);
    }

    /**
     * From com.yammer.dropwizard.hibernate.SessionFactoryFactory.
     *
     * <p>Registers every class in {@code entities} as an annotated class on {@code configuration}.
     */
    private void addAnnotatedClasses(Configuration configuration,
                                     Iterable<Class<?>> entities) {
        for (Class<?> klass : entities) {
            configuration.addAnnotatedClass(klass);
        }
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment