Skip to content

Instantly share code, notes, and snippets.

@metatype
Last active August 29, 2015 14:27
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save metatype/9b1f39a24e52f5c6f3e1 to your computer and use it in GitHub Desktop.
Save metatype/9b1f39a24e52f5c6f3e1 to your computer and use it in GitHub Desktop.
8/2015 HackDay - Dynamic class loading for GemFire functions
package com.gemstone.gemfire.internal.cache.execute;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;
import com.gemstone.gemfire.InternalGemFireException;
import com.gemstone.gemfire.cache.Cache;
import com.gemstone.gemfire.cache.Region;
import com.gemstone.gemfire.internal.cache.GemFireCacheImpl;
/**
 * Hackday helpers for shipping function bytecode to servers and loading it back: capturing the
 * classes a function class needs, reading their bytecode, deserializing objects against the
 * thread context class loader, and defining classes from bytecode stored in the "hackday" region.
 */
public class Classes {

  /**
   * Reads the bytecode of every class returned by {@link #findClassDependencies(Class)}.
   *
   * <p>Bytecode is located as a {@code .class} resource through the current thread's context
   * class loader.
   *
   * @param clazz the class whose dependencies should be captured
   * @return a map from binary class name to that class's bytecode
   * @throws IOException if a captured class has no {@code .class} resource visible to the context
   *     class loader (for example a runtime-generated class), or if reading it fails
   * @throws ClassNotFoundException if {@code clazz} cannot be loaded through the context class
   *     loader
   */
  public static Map<String, byte[]> getDependenciesAsBytes(Class<?> clazz)
      throws IOException, ClassNotFoundException {
    ClassLoader loader = Thread.currentThread().getContextClassLoader();
    Map<String, byte[]> deps = new HashMap<String, byte[]>();
    for (Class<?> cl : findClassDependencies(clazz)) {
      String resource = cl.getName().replace('.', '/') + ".class";
      InputStream is = loader.getResourceAsStream(resource);
      if (is == null) {
        // getResourceAsStream signals "not found" with null; fail with context instead of an NPE
        throw new IOException("Unable to locate bytecode for class " + cl.getName()
            + " (resource " + resource + ")");
      }
      try {
        deps.put(cl.getName(), readFully(is));
      } finally {
        is.close();
      }
    }
    return deps;
  }

  /** Reads {@code is} to end-of-stream and returns everything read; does not close the stream. */
  private static byte[] readFully(InputStream is) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    byte[] buf = new byte[1024];
    int n;
    while ((n = is.read(buf, 0, buf.length)) != -1) {
      bos.write(buf, 0, n);
    }
    return bos.toByteArray();
  }

  /**
   * Loads {@code clazz} through a fresh {@link CapturingClassLoader} (parented by the context class
   * loader) and returns the classes that loader was asked to load.
   *
   * <p>NOTE(review): {@link CapturingClassLoader} delegates parent-first, so {@code clazz} is
   * defined by the parent loader, and the classes it references are later resolved through that
   * defining (parent) loader, never through the capturing loader. The result is therefore just
   * {@code clazz} itself. Superclasses, nested/anonymous classes and other referenced types are not
   * captured. Capturing them would require the loader to define classes itself (child-first) from
   * the parent's bytecode, and even then linking is lazy, so only eagerly loaded types would show
   * up.
   *
   * @param clazz the class to capture
   * @return the captured classes
   * @throws ClassNotFoundException if {@code clazz} cannot be loaded through the context class
   *     loader
   */
  public static Iterable<Class<?>> findClassDependencies(Class<?> clazz)
      throws ClassNotFoundException {
    CapturingClassLoader cl =
        new CapturingClassLoader(Thread.currentThread().getContextClassLoader());
    cl.loadClass(clazz.getName());
    return cl.getCapturedClasses();
  }

  /**
   * A class loader that records every class returned from
   * {@link #loadClass(String, boolean)}. Loading itself is delegated unchanged to the standard
   * parent-first {@link ClassLoader} implementation.
   */
  public static class CapturingClassLoader extends ClassLoader {
    private final Set<Class<?>> classes;

    public CapturingClassLoader(ClassLoader parent) {
      super(parent);
      classes = new HashSet<Class<?>>();
    }

    /** Returns the live set of classes recorded so far (not a copy). */
    public Iterable<Class<?>> getCapturedClasses() {
      return classes;
    }

    @Override
    protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
      Class<?> clazz = super.loadClass(name, resolve);
      classes.add(clazz);
      return clazz;
    }
  }

  /**
   * An {@link ObjectInputStream} that resolves classes through the current thread's context class
   * loader. This replaces the default behavior, which uses the class loader of the closest
   * user-defined frame on the call stack.
   */
  public static class ResolvingObjectInputStream extends ObjectInputStream {
    public ResolvingObjectInputStream(InputStream in) throws IOException {
      super(in);
    }

    @Override
    protected Class<?> resolveClass(ObjectStreamClass desc)
        throws IOException, ClassNotFoundException {
      ClassLoader tccl = Thread.currentThread().getContextClassLoader();
      if (tccl == null) {
        return super.resolveClass(desc);
      }
      try {
        // Class.forName understands array descriptors such as "[Ljava.lang.String;", which
        // ClassLoader.loadClass rejects; initialize=false matches default deserialization.
        return Class.forName(desc.getName(), false, tccl);
      } catch (ClassNotFoundException e) {
        // Primitive descriptors ("int", "void", ...) cannot be loaded by name; the default
        // implementation maps them to their Class objects (and rethrows for truly missing ones).
        return super.resolveClass(desc);
      }
    }
  }

  /**
   * A parent-first class loader whose {@link #findClass(String)} fallback defines classes from the
   * bytecode map stored under {@code token} in the "hackday" region of the existing cache.
   */
  public static class RegionClassLoader extends ClassLoader {
    private final String token;

    public RegionClassLoader(String token, ClassLoader parent) {
      super(parent);
      this.token = token;
    }

    @Override
    protected Class<?> findClass(final String name) throws ClassNotFoundException {
      // this is a hack to pull classes from a region, stored by token
      Cache c = GemFireCacheImpl.getExisting();
      Region<String, Map<String, byte[]>> r = c.getRegion("hackday");
      if (r == null) {
        throw new InternalGemFireException("Unable to find region hackday");
      }
      Map<String, byte[]> deps = r.get(token);
      if (deps == null) {
        throw new InternalGemFireException("Unable to find token " + token);
      }
      byte[] clazz = deps.get(name);
      if (clazz != null) {
        return defineClass(name, clazz, 0, clazz.length);
      }
      // not shipped for this token: let the default implementation throw ClassNotFoundException
      return super.findClass(name);
    }
  }
}
package com.gemstone.gemfire.internal.cache.execute;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Map;
import java.util.TreeMap;
import javax.xml.bind.annotation.adapters.HexBinaryAdapter;
import org.apache.commons.io.output.ByteArrayOutputStream;
import com.gemstone.gemfire.InternalGemFireException;
import com.gemstone.gemfire.cache.execute.Function;
import com.gemstone.gemfire.cache.execute.FunctionContext;
import com.gemstone.gemfire.internal.cache.execute.Classes.RegionClassLoader;
import com.gemstone.gemfire.internal.cache.execute.Classes.ResolvingObjectInputStream;
/**
 * A {@link Function} wrapper that carries the wrapped function in Java-serialized form plus a
 * token derived from the bytecode of the classes captured for it.
 *
 * <p>On execution, the function is deserialized and run with the thread context class loader set
 * to a {@link RegionClassLoader} for that token. Classes missing from the server classpath can
 * therefore be defined from bytecode stored in the "hackday" region.
 */
public class DynamicFunction implements Function {
  private static final long serialVersionUID = 1L;

  // the wrapped function in Java-serialized form
  private final byte[] serializedFn;
  // hex-encoded MD5 digest of the captured bytecode; used as the lookup key for deps
  private final String token;
  // bytecode for each captured class, keyed by class name; transient, so null after
  // this object has been deserialized
  private final transient Map<String, byte[]> deps;
  // attributes copied from the wrapped function at construction time
  private final boolean ha;
  private final boolean optimizeForWrite;
  private final boolean hasResult;
  private final String id;

  /**
   * Wraps {@code fn}, snapshotting its attributes, its serialized form and the bytecode of its
   * captured classes.
   *
   * @param fn the function to wrap; must be serializable
   * @throws IOException if {@code fn} cannot be serialized or a class's bytecode cannot be read
   * @throws ClassNotFoundException if {@code fn}'s class cannot be loaded through the thread
   *     context class loader
   */
  public DynamicFunction(Function fn) throws IOException, ClassNotFoundException {
    this.ha = fn.isHA();
    this.optimizeForWrite = fn.optimizeForWrite();
    this.hasResult = fn.hasResult();
    this.id = fn.getId();
    this.serializedFn = serialize(fn);
    this.deps = Classes.getDependenciesAsBytes(fn.getClass());
    this.token = tokenize(deps);
  }

  /** Java-serializes {@code fn}; the stream is closed (and thus flushed) before reading bytes. */
  private byte[] serialize(Function fn) throws IOException {
    ByteArrayOutputStream bos = new ByteArrayOutputStream();
    ObjectOutputStream oos = new ObjectOutputStream(bos);
    try {
      oos.writeObject(fn);
    } finally {
      oos.close();
    }
    return bos.toByteArray();
  }

  /**
   * Computes the hex MD5 digest of all class bytecode, visited in class-name order.
   *
   * <p>Sorting matters: the map's iteration order can depend on insertion order, which comes from
   * an identity-hashed set, so it is not guaranteed to be the same across JVMs. Identical bytecode
   * must always produce the identical token.
   */
  private String tokenize(Map<String, byte[]> classes) {
    try {
      MessageDigest md = MessageDigest.getInstance("MD5");
      for (byte[] clazz : new TreeMap<String, byte[]>(classes).values()) {
        md.update(clazz);
      }
      return (new HexBinaryAdapter()).marshal(md.digest());
    } catch (NoSuchAlgorithmException e) {
      throw new InternalGemFireException(e);
    }
  }

  /**
   * Returns the captured bytecode keyed by class name, or {@code null} on a deserialized copy
   * (the field is transient).
   */
  public Map<String, byte[]> getDeps() {
    return deps;
  }

  /** Returns the token identifying the captured bytecode. */
  public String getCodeToken() {
    return token;
  }

  @Override
  public boolean hasResult() {
    return hasResult;
  }

  /**
   * Deserializes the wrapped function and executes it. The thread context class loader is a
   * {@link RegionClassLoader} for this token during both steps and is restored afterwards.
   *
   * <p>Only deserialization failures are wrapped in {@link InternalGemFireException}. Exceptions
   * thrown by the wrapped function itself propagate unchanged, so callers see the function's own
   * exception types.
   */
  @Override
  public void execute(FunctionContext context) {
    Thread current = Thread.currentThread();
    ClassLoader previous = current.getContextClassLoader();
    current.setContextClassLoader(new RegionClassLoader(token, previous));
    try {
      deserialize().execute(context);
    } finally {
      current.setContextClassLoader(previous);
    }
  }

  /** Reads the wrapped function back from {@link #serializedFn} via the context class loader. */
  private Function deserialize() {
    try {
      ObjectInputStream ois =
          new ResolvingObjectInputStream(new ByteArrayInputStream(serializedFn));
      try {
        return (Function) ois.readObject();
      } finally {
        ois.close();
      }
    } catch (IOException e) {
      throw new InternalGemFireException(e);
    } catch (ClassNotFoundException e) {
      throw new InternalGemFireException(e);
    }
  }

  @Override
  public String getId() {
    return id;
  }

  @Override
  public boolean optimizeForWrite() {
    return optimizeForWrite;
  }

  @Override
  public boolean isHA() {
    return ha;
  }
}
diff --git a/gemfire-core/src/main/java/com/gemstone/gemfire/internal/cache/execute/ServerFunctionExecutor.java b/gemfire-core/src/main/java/com/gemstone/gemfire/internal/cache/execute/ServerFunctionExecutor.java
index b9a1d36..61de6b3 100644
--- a/gemfire-core/src/main/java/com/gemstone/gemfire/internal/cache/execute/ServerFunctionExecutor.java
+++ b/gemfire-core/src/main/java/com/gemstone/gemfire/internal/cache/execute/ServerFunctionExecutor.java
@@ -23,6 +23,7 @@ import com.gemstone.gemfire.cache.execute.Function;
import com.gemstone.gemfire.cache.execute.FunctionException;
import com.gemstone.gemfire.cache.execute.FunctionService;
import com.gemstone.gemfire.cache.execute.ResultCollector;
+import com.gemstone.gemfire.internal.cache.GemFireCacheImpl;
import com.gemstone.gemfire.internal.cache.TXManagerImpl;
import com.gemstone.gemfire.internal.i18n.LocalizedStrings;
/**
@@ -153,6 +154,16 @@ public class ServerFunctionExecutor extends AbstractExecution {
validateExecution(function, null);
long start = stats.startTime();
stats.startFunctionExecution(true);
+
+ if (function instanceof DynamicFunction) {
+ DynamicFunction fn = (DynamicFunction) function;
+ String token = fn.getCodeToken();
+
+ GemFireCacheImpl gfe = GemFireCacheImpl.getExisting();
+ System.out.println("Caching token " + fn.getCodeToken());
+ gfe.getRegion("hackday").putIfAbsent(token, fn.getDeps());
+ }
+
ExecuteFunctionOp.execute(this.pool, function, this, args, memberMappedArg,
this.allServers, hasResult, rc, this.isFnSerializationReqd,
UserAttributes.userAttributes.get(), groups);
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment