Created
November 26, 2017 20:53
-
-
Save natanbc/92ce5097d1c67825fff10460c05386b9 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package gabrielbot.module | |
import java.io.{ByteArrayOutputStream, InputStream} | |
import java.net.URLClassLoader | |
import java.util | |
import java.util.jar.JarFile | |
import com.github.natanbc.optional.{Optional, OptionalInterface} | |
import org.objectweb.asm.{AnnotationVisitor, ClassReader, ClassVisitor, ClassWriter, FieldVisitor, MethodVisitor, Opcodes} | |
import scala.collection.mutable.ArrayBuffer | |
/** Class loader for the classes of a single module.
  *
  * Lookups are delegated first to the loaders in [[dependencySources]], then through the
  * standard parent-first delegation of [[java.lang.ClassLoader]]. Classes this loader defines
  * itself are read from `source` and rewritten before definition: methods, fields and
  * interfaces recorded by `ModuleClassLoader.Analyzer` are removed by `ModuleClassLoader.Filter`.
  *
  * @param loader the [[ModuleLoader]] that owns this class loader
  * @param parent parent used by the default delegation in `super.loadClass`
  * @param source where class files are read from: entries of a jar, or resources of a URL class loader
  */
class ModuleClassLoader(val loader: ModuleLoader, parent: ClassLoader, val source: Either[JarFile, URLClassLoader]) extends ClassLoader(parent) {
  /** Loaders of the modules this module depends on; consulted, in order, before the parent. */
  val dependencySources: ArrayBuffer[ModuleClassLoader] = new ArrayBuffer()

  /** Loads `name` from the first dependency loader that can supply it, falling back to the
    * normal `ClassLoader` delegation (parent, then [[findClass]]).
    *
    * NOTE(review): dependency loaders run this same method, so if two modules list each other
    * in `dependencySources`, a miss recurses between them without bound.
    *
    * @throws ClassNotFoundException if neither a dependency nor the default delegation finds the class
    */
  override def loadClass(name: String): Class[_] = {
    // The iterator is lazy, so dependencies after the first hit are never queried.
    val fromDependency = dependencySources.iterator
      .map(dep => tryLoad(dep, name))
      .collectFirst { case Some(cls) => cls }
    fromDependency match {
      case Some(cls) => cls
      case None => super.loadClass(name)
    }
  }

  /** Reads `name` from `source`, strips the members and interfaces the Analyzer pass marks as
    * unavailable, and defines the resulting class.
    *
    * @throws ClassNotFoundException if `source` has no class file for `name`
    */
  override def findClass(name: String): Class[_] = {
    val reader = new ClassReader(getClassBytes(name))
    val strippedMethods = new util.LinkedList[String]
    val strippedFields = new util.LinkedList[String]
    val strippedInterfaces = new util.LinkedList[String]
    // Pass 1: only annotations matter, so skip bytecode, debug info and frames.
    reader.accept(
      new ModuleClassLoader.Analyzer(this, strippedMethods, strippedFields, strippedInterfaces),
      ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES)
    // Pass 2: only whole members/interfaces are dropped and the remaining method bodies are
    // untouched, so their original stack map frames and max stack/locals stay valid. Frames are
    // therefore copied rather than recomputed: COMPUTE_FRAMES would call
    // ClassWriter.getCommonSuperClass, which by default resolves types through ASM's own class
    // loader and cannot see this module's classes.
    val writer = new ClassWriter(reader, 0)
    reader.accept(new ModuleClassLoader.Filter(writer, strippedMethods, strippedFields, strippedInterfaces), 0)
    val rewritten = writer.toByteArray
    defineClass(name, rewritten, 0, rewritten.length)
  }

  /** Runs `ModuleClassLoader.ModuleFinder` over the class file of `name`, which registers a
    * module with [[loader]] if the class carries the `Module` annotation.
    *
    * @throws ClassNotFoundException if `source` has no class file for `name`
    */
  def checkModule(name: String): Unit = {
    val reader = new ClassReader(getClassBytes(name))
    reader.accept(
      new ModuleClassLoader.ModuleFinder(this, name),
      ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES)
  }

  /** Attempts `from.loadClass(name)`, mapping a `ClassNotFoundException` to `None`. */
  private def tryLoad(from: ClassLoader, name: String): Option[Class[_]] =
    try Some(from.loadClass(name))
    catch { case _: ClassNotFoundException => None }

  /** Returns the raw class-file bytes for the binary class name `name` from `source`.
    *
    * @throws ClassNotFoundException if no `<name with '/'>.class` resource exists
    */
  private def getClassBytes(name: String): Array[Byte] = {
    val resource = name.replace('.', '/') + ".class"
    val stream: Option[InputStream] = source match {
      case Left(jar) => Option(jar.getEntry(resource)).flatMap(entry => Option(jar.getInputStream(entry)))
      case Right(urls) => Option(urls.getResourceAsStream(resource))
    }
    val in = stream.getOrElse(throw new ClassNotFoundException(name))
    try readFully(in) finally in.close()
  }

  /** Reads `in` until end of stream; the caller remains responsible for closing it. */
  private def readFully(in: InputStream): Array[Byte] = {
    val out = new ByteArrayOutputStream
    val buffer = new Array[Byte](4096)
    var read = in.read(buffer)
    while(read != -1) {
      out.write(buffer, 0, read)
      read = in.read(buffer)
    }
    out.toByteArray
  }
}
/** ASM visitors used by [[ModuleClassLoader]] to discover modules and to strip class members
  * whose optional requirements are not met.
  */
object ModuleClassLoader {
  private object ModuleFinder {
    /** JVM type descriptor of the `Module` annotation. */
    private val MODULE_DESC = "L" + classOf[Module].getName.replace('.', '/') + ";"
  }

  /** Class visitor that, when the visited class carries the `Module` annotation, builds a
    * [[ModuleInfo]] from it and registers a [[LoadedModule]] with `loader.loader`.
    *
    * @param loader    class loader the visited class was read from
    * @param className name of the visited class, stored in the resulting [[ModuleInfo]]
    */
  private class ModuleFinder(loader: ModuleClassLoader, className: String) extends ClassVisitor(Opcodes.ASM5) {
    override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor =
      if(desc == ModuleFinder.MODULE_DESC) new AnnotationVisitor(Opcodes.ASM5) {
        // NOTE(review): ASM reports only elements present in the class file, so these may stay
        // null if the annotation relies on element defaults — confirm against its declaration.
        private var moduleName: String = _
        private var version: String = _
        private var dependencies: Seq[ModuleInfo.Dependency] = Seq()

        override def visit(name: String, value: Any): Unit = name match {
          case "name" => moduleName = value.asInstanceOf[String]
          case "version" => version = value.asInstanceOf[String]
        }

        override def visitArray(name: String): AnnotationVisitor =
          if(name == "dependencies") new AnnotationVisitor(Opcodes.ASM5) {
            private val collected = new ArrayBuffer[ModuleInfo.Dependency]

            // Each array element is a nested dependency annotation with `name` and optional `version`.
            override def visitAnnotation(name: String, desc: String): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
              private var depName: String = _
              private var depVersion: String = "*"

              override def visit(name: String, value: Any): Unit = name match {
                case "name" => depName = value.asInstanceOf[String]
                case "version" => depVersion = value.asInstanceOf[String]
              }

              override def visitEnd(): Unit = collected += ModuleInfo.Dependency(depName, depVersion)
            }

            // Hand ModuleInfo an immutable snapshot instead of the mutable buffer.
            override def visitEnd(): Unit = dependencies = collected.toList
          }
          else super.visitArray(name)

        override def visitEnd(): Unit = {
          val moduleLoader = loader.loader
          val module = new LoadedModule(moduleLoader, ModuleInfo(moduleName, version, dependencies, className), loader)
          moduleLoader.addModule(moduleName, module)
        }
      }
      else super.visitAnnotation(desc, visible)
  }

  private object Analyzer {
    private val NEEDS_DESC = "L" + classOf[Needs].getName.replace('.', '/') + ";"
    private val OPTIONAL_DESC = "L" + classOf[Optional].getName.replace('.', '/') + ";"
    private val OPTIONAL_INTERFACE_DESC = "L" + classOf[OptionalInterface].getName.replace('.', '/') + ";"
    private val OPTIONAL_INTERFACE_CONTAINER_DESC = "L" + classOf[OptionalInterface.Container].getName.replace('.', '/') + ";"
  }

  /** First pass: inspects annotations only and records what must be removed.
    *
    *  - `interfaces`: internal names of classes named by `OptionalInterface` (directly or via its
    *    container) that `loader` cannot load.
    *  - `methods` / `fields`: `name + descriptor` keys of members whose `Optional` class list
    *    contains an unloadable class, or whose non-required `Needs` module is not present.
    *
    * @throws MissingModuleException (from `visitEnd` of a `Needs` annotation) when a required
    *                                module is not present
    */
  private class Analyzer(val loader: ModuleClassLoader, val methods: util.LinkedList[String], val fields: util.LinkedList[String], val interfaces: util.LinkedList[String]) extends ClassVisitor(Opcodes.ASM5) {
    override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor =
      if(desc == Analyzer.OPTIONAL_INTERFACE_CONTAINER_DESC) new AnnotationVisitor(Opcodes.ASM5) {
        override def visitArray(name: String): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
          override def visitAnnotation(name: String, desc: String): AnnotationVisitor = optionalInterfaceVisitor()
        }
      }
      else if(desc == Analyzer.OPTIONAL_INTERFACE_DESC) optionalInterfaceVisitor()
      else super.visitAnnotation(desc, visible)

    override def visitMethod(access: Int, methodName: String, methodDesc: String, signature: String, exceptions: Array[String]): MethodVisitor =
      new MethodVisitor(Opcodes.ASM5) {
        private val stripper = new MemberStripper(methodName + methodDesc, methods)

        override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor =
          stripper.annotationVisitor(desc) match {
            case Some(visitor) => visitor
            case None => super.visitAnnotation(desc, visible)
          }
      }

    override def visitField(access: Int, fieldName: String, fieldDesc: String, signature: String, value: Any): FieldVisitor =
      new FieldVisitor(Opcodes.ASM5) {
        private val stripper = new MemberStripper(fieldName + fieldDesc, fields)

        override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor =
          stripper.annotationVisitor(desc) match {
            case Some(visitor) => visitor
            case None => super.visitAnnotation(desc, visible)
          }
      }

    /** Visitor for one `OptionalInterface` annotation: records each named class that cannot be
      * loaded, converted to its internal (slash-separated) name. */
    private def optionalInterfaceVisitor(): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
      override def visit(name: String, value: Any): Unit = {
        val interfaceName = value.asInstanceOf[String]
        if(!isLoadable(interfaceName)) interfaces.add(interfaceName.replace('.', '/'))
      }
    }

    /** True if `loader.loadClass(className)` succeeds; only `ClassNotFoundException` means false. */
    private def isLoadable(className: String): Boolean =
      try {
        loader.loadClass(className)
        true
      } catch {
        case _: ClassNotFoundException => false
      }

    /** Per-member state shared by the `Optional` and `Needs` visitors of one method or field.
      * Once the member is marked, later annotations on it are ignored.
      *
      * @param key    `name + descriptor` of the member
      * @param target list the key is appended to when the member must be removed
      */
    private class MemberStripper(key: String, target: util.LinkedList[String]) {
      private var strip = false

      /** Visitor for an annotation with descriptor `desc`, or `None` if it is not one we handle. */
      def annotationVisitor(desc: String): Option[AnnotationVisitor] =
        if(desc == Analyzer.OPTIONAL_DESC) Some(optionalVisitor())
        else if(desc == Analyzer.NEEDS_DESC) Some(needsVisitor())
        else None

      private def markStripped(): Unit = {
        strip = true
        target.add(key)
      }

      /** `Optional`: strip the member if any class in its array value cannot be loaded. */
      private def optionalVisitor(): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
        override def visitArray(name: String): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
          override def visit(name: String, value: Any): Unit =
            if(!strip && !isLoadable(value.asInstanceOf[String])) markStripped()
        }
      }

      /** `Needs`: if the named module/version is absent, throw when `required`, else strip. */
      private def needsVisitor(): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
        private var moduleName: String = _
        private var version: String = _
        private var required: Boolean = false

        override def visit(name: String, value: Any): Unit =
          if(!strip) name match {
            case "module" => moduleName = value.asInstanceOf[String]
            case "version" => version = value.asInstanceOf[String]
            case "required" => required = value.asInstanceOf[Boolean]
          }

        override def visitEnd(): Unit =
          // When an earlier annotation already stripped this member, visit() ignored every
          // element and the fields above are unset; checking would query the module loader with
          // nulls (and could throw MissingModuleException(null, null)), so skip instead.
          if(!strip && !loader.loader.isModulePresent(moduleName, version)) {
            if(required) throw new MissingModuleException(moduleName, version)
            markStripped()
          }
      }
    }
  }

  /** Second pass: forwards the class to `c`, omitting the interfaces (internal names), methods
    * and fields (`name + descriptor` keys) recorded by [[Analyzer]].
    */
  private class Filter(val c: ClassVisitor, val methods: util.LinkedList[String], val fields: util.LinkedList[String], val interfaces: util.LinkedList[String]) extends ClassVisitor(Opcodes.ASM5, c) {
    override def visit(version: Int, access: Int, name: String, signature: String, superName: String, interfaces: Array[String]): Unit = {
      // The parameter shadows the field: `this.interfaces` is the list of names to drop.
      val kept = interfaces.filterNot(s => this.interfaces.contains(s))
      super.visit(version, access, name, signature, superName, kept)
    }

    // Returning null instead of delegating keeps the member out of the downstream writer.
    override def visitMethod(access: Int, name: String, desc: String, signature: String, exceptions: Array[String]): MethodVisitor =
      if(methods.contains(name + desc)) null
      else super.visitMethod(access, name, desc, signature, exceptions)

    override def visitField(access: Int, name: String, desc: String, signature: String, value: Any): FieldVisitor =
      if(fields.contains(name + desc)) null
      else super.visitField(access, name, desc, signature, value)
  }
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment