Skip to content

Instantly share code, notes, and snippets.

@mirasrael
Last active May 8, 2024 19:52
Show Gist options
  • Save mirasrael/2fb953ee95b0c9d16880e4cfb2477e76 to your computer and use it in GitHub Desktop.
Save mirasrael/2fb953ee95b0c9d16880e4cfb2477e76 to your computer and use it in GitHub Desktop.
package ru.waveaccess.hibernate;
import org.hibernate.FetchMode;
import org.hibernate.HibernateException;
import org.hibernate.MappingException;
import org.hibernate.QueryException;
import org.hibernate.engine.spi.LoadQueryInfluencers;
import org.hibernate.engine.spi.SessionFactoryImplementor;
import org.hibernate.engine.spi.SessionImplementor;
import org.hibernate.loader.collection.BasicCollectionLoader;
import org.hibernate.loader.collection.CollectionLoader;
import org.hibernate.loader.collection.OneToManyLoader;
import org.hibernate.persister.collection.QueryableCollection;
import org.hibernate.persister.entity.EntityPersister;
import org.hibernate.type.CollectionType;
import org.hibernate.type.Type;
import java.io.Serializable;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.util.Arrays;
import java.util.List;
/**
 * Batch-loads collection properties for a list of entities attached to a session.
 *
 * <p>For every requested collection property, one Hibernate {@link CollectionLoader} is built
 * with a batch size equal to the number of collected entity identifiers. Its
 * {@code loadCollectionBatch} is then invoked once with all of those identifiers.
 *
 * Date: 08/03/2017
 * Time: 15:52
 *
 * @version cld.0.2
 */
@SuppressWarnings("unused")
public class BatchCollectionLoader {

    /**
     * Dynamic proxy around a {@link QueryableCollection}.
     *
     * <p>Every call is forwarded to the wrapped persister except {@code getFetchMode()}, which
     * always answers {@link FetchMode#JOIN}.
     *
     * <p>NOTE(review): presumably this makes the join walker used by
     * {@link BasicCollectionLoader} treat the collection as join-fetched. Confirm against the
     * Hibernate version in use.
     */
    private static final class EagerQueryableCollectionProxy implements InvocationHandler {
        private final QueryableCollection delegate;

        private EagerQueryableCollectionProxy(QueryableCollection collection) {
            this.delegate = collection;
        }

        /** Wraps {@code collection} in a proxy whose {@code getFetchMode()} reports JOIN. */
        static QueryableCollection newInstance(QueryableCollection collection) {
            return (QueryableCollection) Proxy.newProxyInstance(
                    collection.getClass().getClassLoader(),
                    new Class<?>[] { QueryableCollection.class },
                    new EagerQueryableCollectionProxy(collection));
        }

        @Override
        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
            if (method.getName().equals("getFetchMode")) {
                return FetchMode.JOIN;
            }
            try {
                return method.invoke(delegate, args);
            } catch (InvocationTargetException e) {
                // Method.invoke wraps whatever the persister threw. Rethrow the original exception.
                // Otherwise the checked InvocationTargetException would reach proxy callers as an
                // UndeclaredThrowableException, hiding the real HibernateException.
                throw e.getCause();
            }
        }
    }

    /**
     * Resolves the collection persister behind {@code entityClass.propertyName}.
     *
     * @param entityClass  class whose name is used as the entity name for the persister lookup
     * @param propertyName name of a collection-valued property of that entity
     * @param factory      session factory holding the persisters
     * @return the property's collection persister
     * @throws MappingException if the property is not collection-typed
     * @throws QueryException   if the collection role is unknown or its persister is not queryable
     */
    private static QueryableCollection getQueryableCollection(
            Class<?> entityClass,
            String propertyName,
            SessionFactoryImplementor factory) throws HibernateException {
        final String entityName = entityClass.getName();
        final EntityPersister entityPersister = factory.getEntityPersister(entityName);
        final Type type = entityPersister.getPropertyType(propertyName);
        if (!type.isCollectionType()) {
            throw new MappingException(
                    "Property path [" + entityName + "." + propertyName + "] does not reference a collection"
            );
        }
        final String role = ((CollectionType) type).getRole();
        final Object persister;
        try {
            persister = factory.getCollectionPersister(role);
        } catch (MappingException e) {
            throw new QueryException("collection role not found: " + role, e);
        }
        // Explicit type check instead of catching ClassCastException from a blind cast.
        if (!(persister instanceof QueryableCollection)) {
            throw new QueryException("collection role is not queryable: " + role);
        }
        return (QueryableCollection) persister;
    }

    /**
     * Collects the session-context identifiers of the non-null {@code entities}, in list order.
     *
     * <p>Entities for which {@link SessionImplementor#getContextEntityIdentifier} returns
     * {@code null} are omitted, so no null key is handed to the batch loader.
     */
    private static Serializable[] collectIdentifiers(SessionImplementor session, List<?> entities) {
        final Serializable[] ids = new Serializable[entities.size()];
        int count = 0;
        for (Object entity : entities) {
            if (entity == null) {
                continue;
            }
            final Serializable id = session.getContextEntityIdentifier(entity);
            if (id != null) {
                ids[count++] = id;
            }
        }
        return count == ids.length ? ids : Arrays.copyOf(ids, count);
    }

    /**
     * Loads the named collections of {@code entities}, one batch load per collection name.
     *
     * <p>One-to-many collections use {@link OneToManyLoader}. All other collections use
     * {@link BasicCollectionLoader} over a persister proxy that reports {@link FetchMode#JOIN}.
     * Does nothing when no entity yields an identifier.
     *
     * @param session         session the entities are associated with
     * @param entityClass     class whose name is used to look up the entity persister
     * @param entities        entities whose collections are loaded; null elements and elements
     *                        without a session-context identifier are skipped
     * @param collectionNames names of collection-valued properties of {@code entityClass}
     * @throws MappingException if a name does not denote a collection property
     * @throws QueryException   if a collection persister is missing or not queryable
     */
    @SuppressWarnings("WeakerAccess")
    protected static void preloadCollectionsInternal(SessionImplementor session, Class entityClass, List entities, String... collectionNames) {
        final Serializable[] entityIds = collectIdentifiers(session, entities);
        if (entityIds.length == 0) {
            // Nothing to load. Avoid building loaders with a batch size of 0.
            return;
        }
        final SessionFactoryImplementor sf = session.getFactory();
        for (String collectionName : collectionNames) {
            final QueryableCollection collectionPersister = getQueryableCollection(entityClass, collectionName, sf);
            final CollectionLoader loader = collectionPersister.isOneToMany()
                    ? new OneToManyLoader(collectionPersister, entityIds.length, sf, LoadQueryInfluencers.NONE)
                    : new BasicCollectionLoader(EagerQueryableCollectionProxy.newInstance(collectionPersister), entityIds.length, sf, LoadQueryInfluencers.NONE);
            loader.loadCollectionBatch(session, entityIds, collectionPersister.getKeyType());
        }
    }

    /**
     * Returns the runtime class of the first non-null element of {@code entities}.
     *
     * <p>Returns {@code null} when the list is empty or holds only nulls.
     *
     * <p>NOTE(review): if that element is an uninitialized Hibernate proxy, {@code getClass()}
     * presumably yields a generated subclass whose name is not a mapped entity name. Confirm that
     * callers pass initialized entities.
     */
    @SuppressWarnings("WeakerAccess")
    protected static Class getEntityClass(List entities) {
        for (Object entity : entities) {
            if (entity != null) {
                return entity.getClass();
            }
        }
        return null;
    }

    /**
     * Batch-loads the named collection properties of {@code entities}.
     *
     * <p>The entity type is taken from the first non-null element. Returns immediately when there
     * is none.
     *
     * @param session         session the entities are associated with
     * @param entities        entities of a single mapped type; null elements are ignored
     * @param collectionNames names of collection-valued properties to load
     */
    @SuppressWarnings("unused")
    public static void preloadCollections(SessionImplementor session, List entities, String... collectionNames) {
        final Class entityClass = getEntityClass(entities);
        if (entityClass == null) {
            return;
        }
        preloadCollectionsInternal(session, entityClass, entities, collectionNames);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment