Skip to content

Instantly share code, notes, and snippets.

@farnoy
Created May 21, 2024 11:10
Show Gist options
  • Save farnoy/5ce05c4cb12e2f3189214442e6324713 to your computer and use it in GitHub Desktop.
Save farnoy/5ce05c4cb12e2f3189214442e6324713 to your computer and use it in GitHub Desktop.
import java.lang.foreign.*
import scala.util.{Using, Try}
/** Reports the outcome of a guarded native call.
  *
  * A successful value is printed as `Result: <value>`; a failure has its stack trace printed.
  */
def handleResult[T](res: Try[T]): Unit =
  res match
    case scala.util.Failure(err)   => err.printStackTrace()
    case scala.util.Success(value) => println(s"Result: $value")
/** Scala-side mirror of the `#[repr(C)] struct MyData` consumed by the native `process` function.
  *
  * NOTE(review): the Rust fields `y: u8` and `z: u16` are unsigned, while `Byte`/`Short` are
  * signed, so native values above 127 / 32767 would read back negative here — confirm intended.
  */
final case class MyData(x: Double, y: Byte, z: Short, last: Double)
/** Native layout of `MyData`, with explicit padding reproducing C struct alignment:
  * `x` at offset 0, `y` at 8, one padding byte so `z` sits on a 2-byte boundary at 10,
  * four padding bytes so `last` sits on an 8-byte boundary at 16 (24 bytes total).
  */
given NativeType[MyData] with
  def memoryLayout =
    // Layout of the primitive `T`, labelled so VarHandles can address it by `name`.
    def field[T](name: String)(using nt: NativeType[T]): MemoryLayout =
      nt.memoryLayout.withName(name)

    MemoryLayout.structLayout(
      field[Double]("x"),
      field[Byte]("y"),
      MemoryLayout.paddingLayout(1),
      field[Short]("z"),
      MemoryLayout.paddingLayout(4),
      field[Double]("last")
    )
/** Path of the native library that `lookup` loads.
  *
  * Resolution order: the `foreign.lib` JVM system property, then the `FOREIGN_TEST_LIB`
  * environment variable, then the original hard-coded development path (so existing behaviour
  * is unchanged when neither override is set).
  */
val libraryPath: String =
  sys.props
    .get("foreign.lib")
    .orElse(sys.env.get("FOREIGN_TEST_LIB"))
    .getOrElse("C:\\Projects\\scala3-learning/foreign-test/target/release/foreign_test.dll")
// Platform linker, and a symbol lookup over the library at `libraryPath` bound to the global arena.
val linker: Linker = Linker.nativeLinker()
val lookup: SymbolLookup = SymbolLookup.libraryLookup(libraryPath, Arena.global())

/** Native `add`: two `Long` arguments, `Long` result. */
val add = genNativeFunction[(Long, Long), Long](
  name = "add",
  arena = Arena.global(),
  linker = linker,
  lookup = lookup
)

/** Native `process`: a single pointer to a `MyData` struct, `Double` result. */
val process = genNativeFunction[Tuple1[Pointer[MyData]], Double](
  name = "process",
  arena = Arena.global(),
  linker = linker,
  lookup = lookup
)
/** Demo entry point calling the native `add` and `process` functions.
  *
  * Optional positional arguments are the two `Long` operands for `add` (defaults 1 and 2).
  * An argument that is not a valid `Long` is reported on stderr and replaced by its default,
  * instead of aborting the program with an uncaught `NumberFormatException`.
  */
@main def Foreign(args: String*) =
  println(s"args: $args")

  // Positional argument `index` as a Long; `default` when absent or not a number.
  def longArg(index: Int, default: Long): Long =
    args.lift(index) match
      case None => default
      case Some(raw) =>
        raw.toLongOption.getOrElse {
          System.err.println(s"Ignoring non-numeric argument #$index ('$raw'); using $default")
          default
        }

  val first = longArg(0, 1L)
  val second = longArg(1, 2L)

  // Demo 1: scalar arguments and result.
  handleResult(Try(add((first, second))))

  // Demo 2: struct passed by pointer. Native side:
  // #[repr(C)]
  // pub struct MyData {
  // x: f64,
  // y: u8,
  // z: u16,
  // last: f64,
  // }
  // #[no_mangle]
  // pub unsafe extern "C" fn process(data: *const MyData) -> f64 {
  // let d = data.read();
  // d.x + d.y as f64 + d.z as f64 + d.last as f64
  // }
  val processed = Using(Arena.ofConfined()): arena =>
    val layout = summon[NativeType[MyData]].memoryLayout
    val allocated = arena.allocate(layout)
    def field(name: String) = layout.varHandle(MemoryLayout.PathElement.groupElement(name))
    // Layout VarHandles take (segment, baseOffset: long, value); pass the offset as a Long
    // so the call matches the handle's coordinate type exactly.
    field("x").set(allocated, 0L, 3.14)
    field("z").set(allocated, 0L, 6: Short)
    field("last").set(allocated, 0L, 8.9)
    // NOTE(review): `y` is never written; this relies on `allocate` zero-initialising the segment.
    process(Tuple1(allocated))
  // process(Tuple1(MyData(3.14, 1, 384, 8.9))) // Error
  handleResult(processed)
import java.lang.foreign.*
import scala.compiletime.*
import scala.quoted.*
import java.lang.invoke.MethodHandle
/** Type class supplying the foreign-memory `MemoryLayout` used to describe a `T` to native code. */
trait NativeType[T]:
  /** Layout describing how a `T` is represented in native memory or in a call signature. */
  def memoryLayout: MemoryLayout
end NativeType
/** A native function exposed as a Scala value.
  *
  * `Args` is the tuple of values handed to the call and `R` its result type.
  */
sealed trait NativeFunction[Args <: Tuple, R]:
  /** Calls the native function with `args`. */
  def apply(args: Args): R

  /** The underlying downcall method handle. */
  def handle: MethodHandle
end NativeFunction
// Primitive JVM types map one-to-one onto the JAVA_* value layouts, listed here by width.
given NativeType[Byte] with
  def memoryLayout: ValueLayout.OfByte = ValueLayout.JAVA_BYTE

given NativeType[Short] with
  def memoryLayout: ValueLayout.OfShort = ValueLayout.JAVA_SHORT

given NativeType[Int] with
  def memoryLayout: ValueLayout.OfInt = ValueLayout.JAVA_INT

given NativeType[Float] with
  def memoryLayout: ValueLayout.OfFloat = ValueLayout.JAVA_FLOAT

given NativeType[Long] with
  def memoryLayout: ValueLayout.OfLong = ValueLayout.JAVA_LONG

given NativeType[Double] with
  def memoryLayout: ValueLayout.OfDouble = ValueLayout.JAVA_DOUBLE
/** Marker used in a `genNativeFunction` argument tuple to declare a pointer-to-`T` parameter.
  *
  * `NativeArgs` replaces each `Pointer[t]` element with `MemorySegment`, so callers pass the
  * segment holding the pointee rather than a `Pointer` value.
  */
final case class Pointer[T](t: T) extends AnyVal
/** A `Pointer[T]` argument is passed as a native address (a C `T*`).
  *
  * Returning the pointee's own layout here would make the generated `FunctionDescriptor`
  * declare the struct as a by-value argument, which does not match a `T*` parameter. That
  * mismatch can appear to work on ABIs that pass large structs via a hidden pointer, but it is
  * incorrect in general. The pointee layout is kept as the address's target layout for
  * documentation and dereferencing. `withTargetLayout` is a restricted FFM method, as is
  * `downcallHandle`.
  */
given [T: NativeType]: NativeType[Pointer[T]] with
  def memoryLayout = ValueLayout.ADDRESS.withTargetLayout(summon[NativeType[T]].memoryLayout)
/** The argument tuple actually handed to the downcall: every `Pointer[_]` element becomes a
  * `MemorySegment`, and all other elements are kept as-is. The `Pointer` case must precede the
  * general `elem *: rest` case, because match-type cases are tried in order.
  */
type NativeArgs[T <: Tuple] <: Tuple = T match
  case EmptyTuple               => EmptyTuple
  case Pointer[pointee] *: rest => MemorySegment *: NativeArgs[rest]
  case elem *: rest             => elem *: NativeArgs[rest]
/** Builds a typed wrapper around the native symbol `name`, found through `lookup`.
  *
  * `Args` lists the declared argument types (`Pointer[t]` entries surface as `MemorySegment`
  * in `apply`, see `NativeArgs`), and `R` is the result type. `name` must be a compile-time
  * constant because the macro reads its value while expanding.
  * NOTE(review): `makeNativeFunction` does not splice `arena` into the generated code.
  */
inline def genNativeFunction[Args <: Tuple, R](
    inline name: String,
    arena: Arena,
    linker: Linker,
    lookup: SymbolLookup
): NativeFunction[NativeArgs[Args], R] =
  ${ makeNativeFunction[Args, R]('name, 'arena, 'linker, 'lookup) }
/** Macro implementation behind `genNativeFunction`.
  *
  * Resolves a `NativeType` for every element of `Args` and for `R`, and builds the matching
  * `FunctionDescriptor`. It then emits an anonymous `NativeFunction` whose `apply` forwards the
  * tuple elements to a downcall `MethodHandle`.
  *
  * @param nameExpr symbol name; must be a compile-time constant (compilation aborts otherwise)
  * @param arena    not used by the generated code
  * @param linker   linker used to create the downcall handle
  * @param lookup   lookup searched for the symbol when the generated object is initialised
  * @return code constructing the `NativeFunction`. A type without a `NativeType` instance is a
  *         compile error at the call site. A missing symbol throws `NoSuchElementException`
  *         naming the symbol.
  */
def makeNativeFunction[Args <: Tuple: Type, R: Type](
    nameExpr: Expr[String],
    arena: Expr[Arena],
    linker: Expr[Linker],
    lookup: Expr[SymbolLookup]
)(using Quotes): Expr[NativeFunction[NativeArgs[Args], R]] =
  import quotes.reflect._
  val name = nameExpr.valueOrAbort

  // Flattens a tuple type into its element types. Aliases are resolved first.
  // `EmptyTuple` yields no elements, so zero-argument functions work.
  // `h *: t` is unrolled recursively, and `TupleN[...]` yields its type arguments.
  def decomposeTuple(tpe: TypeRepr): List[TypeRepr] =
    val normalized = tpe.dealias
    normalized.asType match
      case '[EmptyTuple] => Nil
      case '[h *: t]     => TypeRepr.of[h] :: decomposeTuple(TypeRepr.of[t])
      case _ =>
        normalized match
          case AppliedType(_, args) => args
          case _                    => List(normalized)

  // Resolves the `NativeType` instance for `tpe`. A missing instance aborts compilation
  // immediately, at the `genNativeFunction` call site.
  def mapTypeToLayout(tpe: TypeRepr): Expr[MemoryLayout] = tpe.asType match
    case '[t] =>
      Expr.summon[NativeType[t]] match
        case Some(instance) => '{ $instance.memoryLayout }
        case None =>
          report.errorAndAbort(s"Unexpected type when generating ValueLayout: ${tpe.show}")

  val elementLayouts = decomposeTuple(TypeRepr.of[Args]).map(mapTypeToLayout)
  val returnLayout = mapTypeToLayout(TypeRepr.of[R])
  val notFoundMessage = s"Native symbol not found: $name"

  '{
    new NativeFunction:
      val fun = ${ lookup }
        .find(${ Expr(name) })
        .orElseThrow(() => new java.util.NoSuchElementException(${ Expr(notFoundMessage) }))
      val desc =
        FunctionDescriptor.of(${ returnLayout }, ${ Expr.ofSeq(elementLayouts) }*)
      val _handle =
        ${ linker }.downcallHandle(fun, desc, Linker.Option.critical(false))
      override def handle = _handle
      override def apply(args: NativeArgs[Args]): R =
        handle.invokeWithArguments(args.toArray*).asInstanceOf[R]
  }
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment