Skip to content

Instantly share code, notes, and snippets.

@anatawa12
Last active February 6, 2021 03:02
Show Gist options
  • Save anatawa12/b01f3978fee7a5a562b59d03fa62cef5 to your computer and use it in GitHub Desktop.
Save anatawa12/b01f3978fee7a5a562b59d03fa62cef5 to your computer and use it in GitHub Desktop.
object-input-stream-classloader-test
# -p: don't fail if out/ already exists from a previous run
mkdir -p out
# The source contains Japanese string literals; compile as UTF-8 regardless of platform default
javac -encoding UTF-8 -d out C0Main.java
java -cp out C0Main > 2-result.txt
============ object input stream subclass ============
OIS: 使用するObjectInputStreamのローダー
called: readObjectを呼ぶクラスのローダー
object: readObjectの戻り値のクラスローダー
==== overloaded object input stream subclass ====
OIS: loader1, called: appLoad, object: loader1
OIS: loader2, called: appLoad, object: loader2
OIS: appLoad, called: appLoad, object: appLoad
==== non-overloaded object input stream subclass ====
OIS: loader1, called: appLoad, object: appLoad
OIS: loader2, called: appLoad, object: appLoad
OIS: appLoad, called: appLoad, object: appLoad
============ readObject ============
classReader: SerializeableのreadObject内でOIS.readObjectを呼ぶためのクラスのローダー
called: readObjectを呼ぶクラスのローダー
object: SerializeableのreadObject内でのOIS.readObjectの戻り値のローダー
========
classReader: loader1, called: appLoad, object: loader1
classReader: loader2, called: appLoad, object: loader2
classReader: appLoad, called: appLoad, object: appLoad
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.Serializable;
import java.net.URL;
import java.net.URLClassLoader;
import java.security.CodeSource;
import java.util.function.Function;
import java.util.function.Supplier;
public class C0Main {
    /** First isolated loader: this program's class path, with the bootstrap loader as parent. */
    ClassLoader loader1;
    /** Second isolated loader: same class path as {@link #loader1}, but a distinct loader. */
    ClassLoader loader2;
    /** The loader that defined {@code C0Main} itself. */
    ClassLoader appLoad;
    /** Serialized form of a {@link ClassToSerialize1} instance. */
    byte[] bytesClassToSerialize1;
    /** Serialized form of a {@link ClassToSerialize2} wrapping {@code ClassToSerialize1.class}. */
    byte[] bytesClassToSerialize2;

    /**
     * Builds the two isolated class loaders and serializes the test payloads.
     *
     * @throws IllegalStateException if the class path of this class cannot be determined, or if
     *         the two isolated loaders unexpectedly return the same {@code Class} object
     */
    public C0Main() throws Throwable {
        // prepare
        // 1. Locate the class-path entry (directory or jar) this class was loaded from.
        //    Convert via URI instead of URL.getPath(): getPath() keeps percent-encoding
        //    (e.g. "%20" for a space), so new File(getPath()) names a non-existent file
        //    whenever the path contains spaces or non-ASCII characters.
        CodeSource codeSource = getClass().getProtectionDomain().getCodeSource();
        if (codeSource == null || codeSource.getLocation() == null) {
            throw new IllegalStateException(
                    "cannot determine the class path of " + getClass().getName());
        }
        // File.toURI() appends the trailing '/' URLClassLoader needs to treat a directory as one.
        URL classPath = new File(codeSource.getLocation().toURI()).toURI().toURL();
        // 2. Create the class loaders. A null parent means the bootstrap loader, so neither
        //    loader delegates to the application class loader.
        loader1 = new URLClassLoader(new URL[]{ classPath }, null);
        loader2 = new URLClassLoader(new URL[]{ classPath }, null);
        appLoad = C0Main.class.getClassLoader();
        // 3. Verify that the two loaders really define distinct classes.
        if (loader1.loadClass(ClassToSerialize1.class.getName())
                == loader2.loadClass(ClassToSerialize1.class.getName())) {
            throw new IllegalStateException("bug");
        }
        bytesClassToSerialize1 = serialize(new ClassToSerialize1());
        bytesClassToSerialize2 = serialize(new ClassToSerialize2(ClassToSerialize1.class));
    }

    /** Runs both experiments, printing a section header before each. */
    private void runTest() throws Throwable {
        System.out.println("============ object input stream subclass ============");
        runTestOISExtend();
        System.out.println("============ readObject ============");
        runTestReadObject();
    }

    /**
     * Experiment 1: ObjectInputStream subclasses.
     *
     * <p>Deserializes a {@link ClassToSerialize1} through an ObjectInputStream subclass defined
     * by each loader in turn. The subclass either overrides {@code resolveClass}
     * ({@link ObjectInputStream1}) or overrides nothing ({@link ObjectInputStream2}). For each
     * case it prints which loader defined the resulting object's class.
     */
    private void runTestOISExtend() throws Throwable {
        // Legend. The runtime output is kept in the original Japanese:
        //   OIS:    loader of the ObjectInputStream subclass in use
        //   called: loader of the class that calls readObject
        //   object: class loader of readObject's return value
        System.out.println("OIS: 使用するObjectInputStreamのローダー");
        System.out.println("called: readObjectを呼ぶクラスのローダー");
        System.out.println("object: readObjectの戻り値のクラスローダー");
        System.out.println("==== overloaded object input stream subclass ====");
        printClassLoader("OIS: loader1, called: appLoad", readWithOIS(loader1, ObjectInputStream1.class));
        printClassLoader("OIS: loader2, called: appLoad", readWithOIS(loader2, ObjectInputStream1.class));
        printClassLoader("OIS: appLoad, called: appLoad", readWithOIS(appLoad, ObjectInputStream1.class));
        System.out.println("==== non-overloaded object input stream subclass ====");
        printClassLoader("OIS: loader1, called: appLoad", readWithOIS(loader1, ObjectInputStream2.class));
        printClassLoader("OIS: loader2, called: appLoad", readWithOIS(loader2, ObjectInputStream2.class));
        printClassLoader("OIS: appLoad, called: appLoad", readWithOIS(appLoad, ObjectInputStream2.class));
    }

    /**
     * Experiment 2: custom {@code readObject}.
     *
     * <p>Inside {@link ClassToSerialize2}'s private {@code readObject}, the nested
     * {@code OIS.readObject} call goes through a {@link ClassReader} defined by each loader in
     * turn. For each case it prints which loader defined the Class that was read back.
     */
    private void runTestReadObject() throws Throwable {
        // Legend. The runtime output is kept in the original Japanese:
        //   classReader: loader of the class that calls OIS.readObject inside the
        //                Serializable's readObject
        //   called:      loader of the class that calls the outer readObject
        //   object:      loader of the value returned by that inner OIS.readObject
        System.out.println("classReader: SerializeableのreadObject内でOIS.readObjectを呼ぶためのクラスのローダー");
        System.out.println("called: readObjectを呼ぶクラスのローダー");
        System.out.println("object: SerializeableのreadObject内でのOIS.readObjectの戻り値のローダー");
        System.out.println("========");
        printClassLoader("classReader: loader1, called: appLoad", readWithClassReader(loader1));
        printClassLoader("classReader: loader2, called: appLoad", readWithClassReader(loader2));
        printClassLoader("classReader: appLoad, called: appLoad", readWithClassReader(appLoad));
    }

    public static void main(String[] args) throws Throwable {
        new C0Main().runTest();
    }

    /**
     * Payload for the ObjectInputStream-subclass experiment.
     */
    public static class ClassToSerialize1 implements Serializable { }

    /**
     * Payload for the custom-readObject experiment.
     *
     * <p>On deserialization, {@code theClass} is set to whatever {@link #classReader} reads
     * from the stream.
     */
    public static class ClassToSerialize2 implements Serializable, Supplier<Class<?>> {
        /** Strategy that performs the nested readObject call; set per test case. */
        public static Function<ObjectInputStream, Class<?>> classReader;
        public Class<?> theClass;

        public ClassToSerialize2(Class<?> theClass) {
            this.theClass = theClass;
        }

        // Writes only theClass, with no default field data, so readObject below must read it
        // back explicitly. That explicit read is exactly the call under test.
        private void writeObject(java.io.ObjectOutputStream stream) throws IOException {
            stream.writeObject(theClass);
        }

        private void readObject(java.io.ObjectInputStream stream) throws IOException, ClassNotFoundException {
            theClass = classReader.apply(stream);
        }

        @Override
        public Class<?> get() {
            return theClass;
        }
    }

    /**
     * A {@link Function} equivalent, apart from exception handling, to
     * <pre>{@code
     * (objectInputStream) -> (Class<?>) objectInputStream.readObject()
     * }</pre>
     * It is a named class so that a specific class loader can define it. Checked exceptions
     * from {@code readObject} are rethrown unchanged without being declared (via {@link #thr}).
     */
    public static class ClassReader implements Function<ObjectInputStream, Class<?>> {
        @Override
        public Class<?> apply(ObjectInputStream objectInputStream) {
            try {
                return (Class<?>) objectInputStream.readObject();
            } catch (IOException | ClassNotFoundException e) {
                throw ClassReader.<RuntimeException>thr(e);
            }
        }

        // Type parameter erasure lets a checked exception escape as if it were unchecked.
        @SuppressWarnings("unchecked")
        private static <T extends Throwable> T thr(Exception e) throws T {
            throw (T) e;
        }
    }

    /**
     * ObjectInputStream whose {@code resolveClass} is overridden only to call {@code super}.
     */
    public static class ObjectInputStream1 extends java.io.ObjectInputStream {
        public ObjectInputStream1(InputStream in) throws IOException {
            super(in);
        }

        @Override
        protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
            return super.resolveClass(desc);
        }
    }

    /**
     * ObjectInputStream that overrides nothing.
     */
    public static class ObjectInputStream2 extends java.io.ObjectInputStream {
        public ObjectInputStream2(InputStream in) throws IOException {
            super(in);
        }
    }

    // utils

    /**
     * Deserializes {@link #bytesClassToSerialize1} through an instance of {@code oisClass} as
     * loaded by {@code loader}, closes the stream, and returns the class of the object read.
     * {@code readObject} is invoked from this class, so the "called" loader is {@link #appLoad}.
     */
    private Class<?> readWithOIS(ClassLoader loader, Class<?> oisClass) throws Throwable {
        try (ObjectInputStream stream =
                     createOIS(loader, oisClass, new ByteArrayInputStream(bytesClassToSerialize1))) {
            return stream.readObject().getClass();
        }
    }

    /**
     * Installs a {@link ClassReader} defined by {@code readerLoader} into the appLoad copy of
     * {@link ClassToSerialize2}, deserializes {@link #bytesClassToSerialize2} with a plain
     * ObjectInputStream, closes it, and returns the deserialized {@code theClass}.
     */
    private Class<?> readWithClassReader(ClassLoader readerLoader) throws Throwable {
        setClassReader(appLoad, createClassReader(readerLoader));
        try (ObjectInputStream stream =
                     new ObjectInputStream(new ByteArrayInputStream(bytesClassToSerialize2))) {
            return getTheClass(stream.readObject());
        }
    }

    /**
     * Returns {@link ClassToSerialize2#theClass} of {@code readObject}.
     * It is accessed through the bootstrap-defined {@link Supplier} interface, so this works
     * whichever loader defined the object.
     */
    private Class<?> getTheClass(Object readObject) {
        return (Class<?>) ((Supplier<?>) readObject).get();
    }

    /**
     * Prints {@code s} followed by {@code ", object: "} and the label of the loader that
     * defined {@code objectClass}.
     */
    private void printClassLoader(String s, Class<?> objectClass) {
        System.out.print(s);
        System.out.print(", object: ");
        System.out.print(loaderName(objectClass.getClassLoader()));
        System.out.println();
    }

    /** Returns the output label for {@code loader}, or {@code "unknown"} if it is none of ours. */
    private String loaderName(ClassLoader loader) {
        if (loader == appLoad) {
            return "appLoad";
        }
        if (loader == loader1) {
            return "loader1";
        }
        if (loader == loader2) {
            return "loader2";
        }
        return "unknown";
    }

    /**
     * Serializes {@code obj} with an ObjectOutputStream and returns the bytes.
     */
    private byte[] serialize(Object obj) throws Throwable {
        ByteArrayOutputStream baos = new ByteArrayOutputStream();
        try (ObjectOutputStream stream = new ObjectOutputStream(baos)) {
            stream.writeObject(obj);
        }
        return baos.toByteArray();
    }

    /**
     * Loads {@code classDesc} by name with {@code loader} and instantiates it through its
     * {@code (InputStream)} constructor, passing {@code stream}.
     */
    private ObjectInputStream createOIS(ClassLoader loader, Class<?> classDesc, InputStream stream) throws Throwable {
        return (ObjectInputStream) loader.loadClass(classDesc.getName())
                .getConstructor(InputStream.class)
                .newInstance(stream);
    }

    /**
     * Loads {@link ClassReader} with {@code loader} and returns a new instance.
     */
    @SuppressWarnings("unchecked")
    private Function<ObjectInputStream, Class<?>> createClassReader(ClassLoader loader) throws Throwable {
        // Class.newInstance() is deprecated since Java 9; use the public no-arg constructor.
        return (Function<ObjectInputStream, Class<?>>) loader.loadClass(ClassReader.class.getName())
                .getConstructor()
                .newInstance();
    }

    /**
     * Sets the static {@link ClassToSerialize2#classReader} field of the copy of
     * {@link ClassToSerialize2} that {@code loader} defines.
     */
    private void setClassReader(
            ClassLoader loader,
            Function<ObjectInputStream, Class<?>> reader) throws Throwable {
        loader.loadClass(ClassToSerialize2.class.getName())
                .getField("classReader")
                .set(null, reader);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment