Skip to content

Instantly share code, notes, and snippets.

@natanbc
Created November 26, 2017 20:53
Show Gist options
  • Save natanbc/92ce5097d1c67825fff10460c05386b9 to your computer and use it in GitHub Desktop.
Save natanbc/92ce5097d1c67825fff10460c05386b9 to your computer and use it in GitHub Desktop.
package gabrielbot.module
import java.io.{ByteArrayOutputStream, InputStream}
import java.net.URLClassLoader
import java.util
import java.util.jar.JarFile
import com.github.natanbc.optional.{Optional, OptionalInterface}
import org.objectweb.asm.{AnnotationVisitor, ClassReader, ClassVisitor, ClassWriter, FieldVisitor, MethodVisitor, Opcodes}
import scala.collection.mutable.ArrayBuffer
/**
 * Class loader for a single module.
 *
 * Raw class bytes are read from `source` (a jar file or a URL class loader). Before a
 * class is defined, it is rewritten by the companion object's visitors. Members whose
 * `@Optional` classes or `@Needs` modules are unavailable are removed, and so are
 * `@OptionalInterface` interfaces that cannot be loaded.
 *
 * @param loader module loader used for module-presence checks and module registration
 * @param parent parent loader used by the standard delegation in `ClassLoader.loadClass`
 * @param source where raw class bytes are read from
 */
class ModuleClassLoader(val loader: ModuleLoader, parent: ClassLoader, val source: Either[JarFile, URLClassLoader]) extends ClassLoader(parent) {
  /** Loaders searched, in insertion order, before the standard delegation in [[loadClass]]. */
  val dependencySources: ArrayBuffer[ModuleClassLoader] = new ArrayBuffer()

  /**
   * Resolves `name` by asking each dependency loader in order. If none of them has the
   * class, falls back to the standard parent-first delegation, which ends in [[findClass]].
   *
   * NOTE(review): if two loaders appear in each other's `dependencySources`, this
   * recurses without bound. Confirm that ModuleLoader never builds cyclic dependencies.
   *
   * @throws ClassNotFoundException if neither a dependency nor normal delegation finds `name`
   */
  override def loadClass(name: String): Class[_] =
    loadFromDependencies(name) match {
      case Some(cls) => cls
      case None => super.loadClass(name)
    }

  /** Returns the first class named `name` that a dependency loader provides, trying them in order. */
  private def loadFromDependencies(name: String): Option[Class[_]] =
    dependencySources.iterator.map(tryLoadFrom(_, name)).collectFirst { case Some(cls) => cls }

  /** Loads `name` through `dependency`, mapping only `ClassNotFoundException` to `None`. */
  private def tryLoadFrom(dependency: ModuleClassLoader, name: String): Option[Class[_]] =
    try Some(dependency.loadClass(name))
    catch { case _: ClassNotFoundException => None }

  /**
   * Reads `name` from `source`, removes unavailable optional members and interfaces,
   * and defines the rewritten class.
   *
   * `ClassLoader.findClass` only ever throws `ClassNotFoundException`, so it is not
   * consulted first.
   *
   * @throws ClassNotFoundException if `source` has no entry for `name`
   * @throws MissingModuleException (raised by the analysis pass) if a `required` `@Needs` module is absent
   */
  override def findClass(name: String): Class[_] = {
    val reader = new ClassReader(getClassBytes(name))
    // Building the writer from `reader` lets ASM reuse the constant pool and copy
    // methods that pass through the filter unchanged.
    val writer = new ClassWriter(reader, ClassWriter.COMPUTE_FRAMES | ClassWriter.COMPUTE_MAXS)
    val methods = new util.LinkedList[String]
    val fields = new util.LinkedList[String]
    val interfaces = new util.LinkedList[String]
    // Pass 1 decides what to strip. It only looks at annotations, so method bodies are skipped.
    reader.accept(new ModuleClassLoader.Analyzer(this, methods, fields, interfaces), ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES)
    // Pass 2 copies the class into `writer`, dropping whatever pass 1 recorded.
    reader.accept(new ModuleClassLoader.Filter(writer, methods, fields, interfaces), ClassReader.SKIP_FRAMES)
    val rewritten = writer.toByteArray
    defineClass(name, rewritten, 0, rewritten.length)
  }

  /**
   * Scans the class `name` for a class-level `@Module` annotation. If one is found, it is
   * registered with `loader` through the companion's `ModuleFinder`. The class itself is
   * not defined.
   *
   * @throws ClassNotFoundException if `source` has no entry for `name`
   */
  def checkModule(name: String): Unit =
    new ClassReader(getClassBytes(name))
      .accept(new ModuleClassLoader.ModuleFinder(this, name), ClassReader.SKIP_CODE | ClassReader.SKIP_DEBUG | ClassReader.SKIP_FRAMES)

  /**
   * Reads the complete `.class` file for the binary name `name` from `source`.
   * The opened stream is always closed.
   *
   * @throws ClassNotFoundException if `source` has no such entry
   */
  private def getClassBytes(name: String): Array[Byte] = {
    val resourcePath = name.replace('.', '/') + ".class"
    val opened: Option[InputStream] = source match {
      case Left(jar) =>
        Option(jar.getEntry(resourcePath)).flatMap(entry => Option(jar.getInputStream(entry)))
      case Right(urlLoader) =>
        Option(urlLoader.getResourceAsStream(resourcePath))
    }
    val in = opened.getOrElse(throw new ClassNotFoundException(name))
    try {
      val out = new ByteArrayOutputStream
      val buffer = new Array[Byte](1024)
      Iterator.continually(in.read(buffer)).takeWhile(_ != -1).foreach(count => out.write(buffer, 0, count))
      out.toByteArray
    } finally in.close()
  }
}
object ModuleClassLoader {
  private object ModuleFinder {
    /** JVM type descriptor (`L<internal name>;`) of the [[Module]] annotation. */
    private val MODULE_DESC = "L" + classOf[Module].getName.replace('.', '/') + ";"
  }

  /**
   * Class visitor that looks for a class-level `@Module` annotation on `className`.
   * Once that annotation has been read completely, it registers a [[LoadedModule]] with
   * `loader.loader`. Classes without the annotation cause no side effects.
   *
   * NOTE(review): annotation elements left at their declared default are not written to
   * the class file. An omitted `name` or `version` therefore arrives here as `null`;
   * confirm that `@Module` requires them.
   */
  private class ModuleFinder(loader: ModuleClassLoader, className: String) extends ClassVisitor(Opcodes.ASM5) {
    override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor =
      if (desc == ModuleFinder.MODULE_DESC) new AnnotationVisitor(Opcodes.ASM5) {
        private var moduleName: String = _
        private var version: String = _
        private var dependencies: Seq[ModuleInfo.Dependency] = Seq()

        override def visitArray(name: String): AnnotationVisitor =
          if (name == "dependencies") new AnnotationVisitor(Opcodes.ASM5) {
            private val collected: ArrayBuffer[ModuleInfo.Dependency] = new ArrayBuffer()

            // Each array element is a nested annotation describing one dependency.
            override def visitAnnotation(name: String, desc: String): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
              private var depName: String = _
              // Used when `version` is absent from the class file.
              private var depVersion: String = "*"

              override def visit(name: String, value: Any): Unit = name match {
                case "name" => depName = value.asInstanceOf[String]
                case "version" => depVersion = value.asInstanceOf[String]
                case _ => // ignore attributes this loader does not interpret
              }

              override def visitEnd(): Unit = collected += ModuleInfo.Dependency(depName, depVersion)
            }

            // Store an immutable snapshot rather than the mutable buffer itself.
            override def visitEnd(): Unit = dependencies = collected.toList
          }
          else super.visitArray(name)

        override def visit(name: String, value: Any): Unit = name match {
          case "name" => moduleName = value.asInstanceOf[String]
          case "version" => version = value.asInstanceOf[String]
          case _ => // ignore attributes this loader does not interpret
        }

        override def visitEnd(): Unit = {
          val moduleLoader = loader.loader
          val info = ModuleInfo(moduleName, version, dependencies, className)
          moduleLoader.addModule(moduleName, new LoadedModule(moduleLoader, info, loader))
        }
      }
      else super.visitAnnotation(desc, visible)
  }

  private object Analyzer {
    private val NEEDS_DESC = "L" + classOf[Needs].getName.replace('.', '/') + ";"
    private val OPTIONAL_DESC = "L" + classOf[Optional].getName.replace('.', '/') + ";"
    private val OPTIONAL_INTERFACE_DESC = "L" + classOf[OptionalInterface].getName.replace('.', '/') + ";"
    private val OPTIONAL_INTERFACE_CONTAINER_DESC = "L" + classOf[OptionalInterface.Container].getName.replace('.', '/') + ";"
  }

  /**
   * Strip bookkeeping for one method or field. It remembers whether the member has
   * already been marked for removal and invokes `record` at most once.
   */
  private final class StripState(record: () => Unit) {
    private var stripped = false

    def isStripped: Boolean = stripped

    def strip(): Unit = if (!stripped) {
      stripped = true
      record()
    }
  }

  /**
   * First-pass class visitor that decides what to remove; it changes nothing itself.
   *
   *  - Interfaces named by `@OptionalInterface` (directly or inside its container) that
   *    cannot be loaded are added to `interfaces`, in internal (`a/b/C`) form.
   *  - Methods and fields are added to `methods` / `fields` as `name + descriptor`. This
   *    happens when an `@Optional` class cannot be loaded, or when a non-required
   *    `@Needs` module is not present.
   *
   * @throws MissingModuleException from `@Needs` handling when a `required` module is absent
   */
  private class Analyzer(val loader: ModuleClassLoader, val methods: util.LinkedList[String], val fields: util.LinkedList[String], val interfaces: util.LinkedList[String]) extends ClassVisitor(Opcodes.ASM5) {

    /** True when the binary class name `className` can be loaded through `loader`. */
    private def isLoadable(className: String): Boolean =
      try {
        loader.loadClass(className)
        true
      } catch {
        case _: ClassNotFoundException => false
      }

    /** Records the interface named by `value` for removal if it cannot be loaded. */
    private def checkInterface(value: Any): Unit = {
      val className = value.asInstanceOf[String]
      if (!isLoadable(className)) interfaces.add(className.replace('.', '/'))
    }

    override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor =
      if (desc == Analyzer.OPTIONAL_INTERFACE_CONTAINER_DESC) new AnnotationVisitor(Opcodes.ASM5) {
        // The container holds an array of nested annotations (presumably @OptionalInterface).
        // Every attribute value of each nested annotation is checked as a class name.
        override def visitArray(name: String): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
          override def visitAnnotation(name: String, desc: String): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
            override def visit(name: String, value: Any): Unit = checkInterface(value)
          }
        }
      }
      else if (desc == Analyzer.OPTIONAL_INTERFACE_DESC) new AnnotationVisitor(Opcodes.ASM5) {
        override def visit(name: String, value: Any): Unit = checkInterface(value)
      }
      else super.visitAnnotation(desc, visible)

    /** Member-level `@Optional`: strips the member if any listed class cannot be loaded. */
    private def optionalVisitor(state: StripState): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
      override def visitArray(name: String): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
        override def visit(name: String, value: Any): Unit =
          if (!state.isStripped && !isLoadable(value.asInstanceOf[String])) state.strip()
      }
    }

    /**
     * Member-level `@Needs`. Once all attributes are read, it strips the member if the
     * module is not present, or throws [[MissingModuleException]] when `required` is set.
     * A member already stripped by another annotation is left alone.
     */
    private def needsVisitor(state: StripState): AnnotationVisitor = new AnnotationVisitor(Opcodes.ASM5) {
      private var moduleName: String = _
      // NOTE(review): elements left at their default are absent from the class file. An
      // omitted `version` therefore stays null and `required` stays false; confirm this
      // matches @Needs.
      private var version: String = _
      private var required: Boolean = false

      override def visit(name: String, value: Any): Unit = name match {
        case "module" => moduleName = value.asInstanceOf[String]
        case "version" => version = value.asInstanceOf[String]
        case "required" => required = value.asInstanceOf[Boolean]
        case _ => // ignore attributes this loader does not interpret
      }

      // An already-stripped member needs no module lookup and must not throw.
      // Skipping here also avoids recording the member a second time.
      override def visitEnd(): Unit =
        if (!state.isStripped && !loader.loader.isModulePresent(moduleName, version)) {
          if (required) throw new MissingModuleException(moduleName, version)
          state.strip()
        }
    }

    /** Returns the handler for a member annotation, or `None` if `desc` is not one we process. */
    private def memberAnnotationVisitor(desc: String, state: StripState): Option[AnnotationVisitor] =
      if (desc == Analyzer.OPTIONAL_DESC) Some(optionalVisitor(state))
      else if (desc == Analyzer.NEEDS_DESC) Some(needsVisitor(state))
      else None

    override def visitMethod(access: Int, methodName: String, methodDesc: String, signature: String, exceptions: Array[String]): MethodVisitor = {
      val state = new StripState(() => { methods.add(methodName + methodDesc); () })
      new MethodVisitor(Opcodes.ASM5) {
        override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor =
          memberAnnotationVisitor(desc, state) match {
            case Some(handler) => handler
            case None => super.visitAnnotation(desc, visible)
          }
      }
    }

    override def visitField(access: Int, fieldName: String, fieldDesc: String, signature: String, value: Any): FieldVisitor = {
      val state = new StripState(() => { fields.add(fieldName + fieldDesc); () })
      new FieldVisitor(Opcodes.ASM5) {
        override def visitAnnotation(desc: String, visible: Boolean): AnnotationVisitor =
          memberAnnotationVisitor(desc, state) match {
            case Some(handler) => handler
            case None => super.visitAnnotation(desc, visible)
          }
      }
    }
  }

  /**
   * Second-pass visitor that forwards the class to `c`, omitting the interfaces, methods
   * and fields recorded by [[Analyzer]]. Members are matched on `name + descriptor`.
   * Returning `null` from `visitMethod`/`visitField` is how ASM drops a member.
   */
  private class Filter(val c: ClassVisitor, val methods: util.LinkedList[String], val fields: util.LinkedList[String], val interfaces: util.LinkedList[String]) extends ClassVisitor(Opcodes.ASM5, c) {
    override def visit(version: Int, access: Int, name: String, signature: String, superName: String, interfaces: Array[String]): Unit = {
      // The parameter shadows the field; `this.interfaces` is the removal list.
      val kept = interfaces.filterNot(i => this.interfaces.contains(i))
      super.visit(version, access, name, signature, superName, kept)
    }

    override def visitMethod(access: Int, name: String, desc: String, signature: String, exceptions: Array[String]): MethodVisitor =
      if (methods.contains(name + desc)) null
      else super.visitMethod(access, name, desc, signature, exceptions)

    override def visitField(access: Int, name: String, desc: String, signature: String, value: Any): FieldVisitor =
      if (fields.contains(name + desc)) null
      else super.visitField(access, name, desc, signature, value)
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment