ffi: java interop improvements

This commit is contained in:
Alex Zenla 2023-09-07 18:16:47 -07:00
parent a2f2252965
commit 38efbe1844
Signed by: alex
GPG Key ID: C0780728420EBFE5
8 changed files with 126 additions and 36 deletions

View File

@ -1,9 +1,7 @@
import std ffi.malloc import std ffi.malloc
export func main() { export func main() {
while true { let pointer = malloc(8192)
let pointer = malloc(8192) println(pointer)
println(pointer) free(pointer)
free(pointer)
}
} }

View File

@ -1,8 +1,13 @@
import java java.lang.System import java java.lang.System
import java java.io.PrintStream
func java_io_PrintStream_println(a) native java "java.io.PrintStream:virtual:println:void:String" import java java.io.InputStreamReader
import java java.io.BufferedReader
export func main() { export func main() {
let input = java_lang_System_in_get()
let reader = java_io_InputStreamReader_new_inputstream(input)
let bufferedReader = java_io_BufferedReader_new_reader(reader)
let line = java_io_BufferedReader_readLine(bufferedReader)
let stream = java_lang_System_err_get() let stream = java_lang_System_err_get()
java_io_PrintStream_println(stream, "Hello World") java_io_PrintStream_println_string(stream, line)
} }

View File

@ -7,11 +7,12 @@ class FfiFunctionDefinition(
) { ) {
companion object { companion object {
fun parse(def: String): FfiFunctionDefinition { fun parse(def: String): FfiFunctionDefinition {
val parts = def.split(":", limit = 3) val parts = def.split(":", limit = 4)
if (parts.size != 3 || parts.any { it.trim().isEmpty() }) { if (parts.size !in arrayOf(3, 4) || parts.any { it.trim().isEmpty() }) {
throw RuntimeException( throw RuntimeException(
"FFI function definition is invalid, " + "FFI function definition is invalid, " +
"excepted format is 'library:function:return-type' but '${def}' was specified") "excepted format is 'library:function:return-type:(optional)parameters'" +
" but '${def}' was specified")
} }
val (library, function, returnType) = parts val (library, function, returnType) = parts
return FfiFunctionDefinition(library, function, returnType) return FfiFunctionDefinition(library, function, returnType)

View File

@ -1,8 +1,14 @@
package gay.pizza.pork.ffi package gay.pizza.pork.ffi
import gay.pizza.pork.ast.* import gay.pizza.pork.ast.CompilationUnit
import java.io.PrintStream import gay.pizza.pork.ast.DefinitionModifiers
import gay.pizza.pork.ast.FunctionDefinition
import gay.pizza.pork.ast.Native
import gay.pizza.pork.ast.StringLiteral
import gay.pizza.pork.ast.Symbol
import java.lang.reflect.Method
import java.lang.reflect.Modifier import java.lang.reflect.Modifier
import java.lang.reflect.Parameter
class JavaAutogen(val javaClass: Class<*>) { class JavaAutogen(val javaClass: Class<*>) {
private val prefix = javaClass.name.replace(".", "_") private val prefix = javaClass.name.replace(".", "_")
@ -16,23 +22,62 @@ class JavaAutogen(val javaClass: Class<*>) {
fun generateFunctionDefinitions(): List<FunctionDefinition> { fun generateFunctionDefinitions(): List<FunctionDefinition> {
val definitions = mutableMapOf<String, FunctionDefinition>() val definitions = mutableMapOf<String, FunctionDefinition>()
val methodGroups = mutableMapOf<String, MutableList<Method>>()
for (method in javaClass.methods) { for (method in javaClass.methods) {
if (!Modifier.isPublic(method.modifiers)) { if (!Modifier.isPublic(method.modifiers)) {
continue continue
} }
methodGroups.getOrPut(method.name) { mutableListOf() }.add(method)
}
val name = method.name for ((baseName, methods) in methodGroups) {
val returnTypeName = method.returnType.name for (method in methods) {
val parameterNames = method.parameters.indices.map { ('a' + it).toString() } var name = baseName
val parameterTypeNames = method.parameters.map { it.type.name } if (methods.size > 1 && method.parameters.isNotEmpty()) {
name += "_" + method.parameters.joinToString("_") {
discriminate(it)
}
}
val returnTypeName = method.returnType.name
val parameterNames = method.parameters.indices.map { ('a' + it).toString() }
val parameterTypeNames = method.parameters.map { it.type.name }
fun form(kind: String): JavaFunctionDefinition = fun form(kind: String): JavaFunctionDefinition =
JavaFunctionDefinition(javaClass.name, kind, name, returnTypeName, parameterTypeNames) JavaFunctionDefinition(
javaClass.name,
kind,
method.name,
returnTypeName,
parameterTypeNames
)
if (Modifier.isStatic(method.modifiers)) { if (Modifier.isStatic(method.modifiers)) {
definitions[name] = function(name, parameterNames, form("static")) definitions[name] = function(name, parameterNames, form("static"))
} else { } else {
definitions[name] = function(name, parameterNames, form("virtual")) definitions[name] = function(name, parameterNames, form("virtual"))
}
}
for (constructor in javaClass.constructors) {
val parameterNames = constructor.parameters.indices.map { ('a' + it).toString() }
val parameterTypeNames = constructor.parameters.map { it.type.name }
var name = "new"
if (javaClass.constructors.isNotEmpty()) {
name += "_" + constructor.parameters.joinToString("_") {
discriminate(it)
}
}
val javaFunctionDefinition = JavaFunctionDefinition(
javaClass.name,
"constructor",
"new",
javaClass.name,
parameterTypeNames
)
definitions[name] = function(name, parameterNames, javaFunctionDefinition)
} }
} }
@ -44,8 +89,8 @@ class JavaAutogen(val javaClass: Class<*>) {
val name = field.name val name = field.name
val valueTypeName = field.type.name val valueTypeName = field.type.name
val isStatic = Modifier.isStatic(field.modifiers) val isStatic = Modifier.isStatic(field.modifiers)
fun form(kind: String, getOrSet: Boolean): JavaFunctionDefinition = fun form(kind: String, getOrSet: Boolean): JavaFunctionDefinition {
JavaFunctionDefinition(javaClass.name, kind, name, valueTypeName, if (getOrSet) { val parameters = if (getOrSet) {
if (isStatic) { if (isStatic) {
emptyList() emptyList()
} else { } else {
@ -57,7 +102,15 @@ class JavaAutogen(val javaClass: Class<*>) {
} else { } else {
listOf(javaClass.name, valueTypeName) listOf(javaClass.name, valueTypeName)
} }
}) }
return JavaFunctionDefinition(
javaClass.name,
kind,
name,
valueTypeName,
parameters
)
}
val parametersForGetter = if (isStatic) { val parametersForGetter = if (isStatic) {
emptyList() emptyList()
@ -73,14 +126,26 @@ class JavaAutogen(val javaClass: Class<*>) {
val getterKind = if (isStatic) "static-getter" else "getter" val getterKind = if (isStatic) "static-getter" else "getter"
val setterKind = if (isStatic) "static-setter" else "setter" val setterKind = if (isStatic) "static-setter" else "setter"
definitions[name + "_get"] = function(name + "_get", parametersForGetter, form(getterKind, true)) definitions[name + "_get"] = function(
definitions[name + "_set"] = function(name + "_set", parametersForSetter, form(setterKind, false)) name + "_get",
parametersForGetter,
form(getterKind, true)
)
definitions[name + "_set"] = function(
name + "_set",
parametersForSetter,
form(setterKind, false)
)
} }
return definitions.values.toList() return definitions.values.toList()
} }
private fun function(name: String, parameterNames: List<String>, functionDefinition: JavaFunctionDefinition): FunctionDefinition = private fun function(
name: String,
parameterNames: List<String>,
functionDefinition: JavaFunctionDefinition
): FunctionDefinition =
FunctionDefinition( FunctionDefinition(
modifiers = DefinitionModifiers(true), modifiers = DefinitionModifiers(true),
symbol = Symbol("${prefix}_${name}"), symbol = Symbol("${prefix}_${name}"),
@ -91,4 +156,7 @@ class JavaAutogen(val javaClass: Class<*>) {
private fun asNative(functionDefinition: JavaFunctionDefinition): Native = private fun asNative(functionDefinition: JavaFunctionDefinition): Native =
Native(Symbol("java"), StringLiteral(functionDefinition.encode())) Native(Symbol("java"), StringLiteral(functionDefinition.encode()))
private fun discriminate(parameter: Parameter): String =
parameter.type.simpleName.lowercase().replace("[]", "_array")
} }

View File

@ -35,12 +35,21 @@ class JavaNativeProvider : NativeFunctionProvider {
else -> lookup.findClass(name) else -> lookup.findClass(name)
} }
private fun mapKindToHandle(kind: String, symbol: String, javaClass: Class<*>, returnType: Class<*>, parameterTypes: List<Class<*>>) = when (kind) { private fun mapKindToHandle(
kind: String,
symbol: String,
javaClass: Class<*>,
returnType: Class<*>,
parameterTypes: List<Class<*>>
) = when (kind) {
"getter" -> lookup.findGetter(javaClass, symbol, returnType) "getter" -> lookup.findGetter(javaClass, symbol, returnType)
"setter" -> lookup.findSetter(javaClass, symbol, returnType) "setter" -> lookup.findSetter(javaClass, symbol, returnType)
"constructor" -> lookup.findConstructor(javaClass, MethodType.methodType(returnType, parameterTypes)) "constructor" ->
"static" -> lookup.findStatic(javaClass, symbol, MethodType.methodType(returnType, parameterTypes)) lookup.findConstructor(javaClass, MethodType.methodType(Void.TYPE, parameterTypes))
"virtual" -> lookup.findVirtual(javaClass, symbol, MethodType.methodType(returnType, parameterTypes)) "static" ->
lookup.findStatic(javaClass, symbol, MethodType.methodType(returnType, parameterTypes))
"virtual" ->
lookup.findVirtual(javaClass, symbol, MethodType.methodType(returnType, parameterTypes))
"static-getter" -> lookup.findStaticGetter(javaClass, symbol, returnType) "static-getter" -> lookup.findStaticGetter(javaClass, symbol, returnType)
"static-setter" -> lookup.findStaticSetter(javaClass, symbol, returnType) "static-setter" -> lookup.findStaticSetter(javaClass, symbol, returnType)
else -> throw RuntimeException("Unknown Handle Kind: $kind") else -> throw RuntimeException("Unknown Handle Kind: $kind")

View File

@ -9,7 +9,7 @@ class JnaNativeProvider : NativeFunctionProvider {
val functionDefinition = FfiFunctionDefinition.parse(definition) val functionDefinition = FfiFunctionDefinition.parse(definition)
val function = Function.getFunction(functionDefinition.library, functionDefinition.function) val function = Function.getFunction(functionDefinition.library, functionDefinition.function)
return CallableFunction { return CallableFunction {
return@CallableFunction invoke(function, it.values.toTypedArray(), functionDefinition.returnType) invoke(function, it.values.toTypedArray(), functionDefinition.returnType)
} }
} }

View File

@ -1,5 +1,8 @@
export func malloc(size) export func malloc(size)
native ffi "c:malloc:void*" native ffi "c:malloc:void*:size_t"
export func calloc(size, count)
native ffi "c:calloc:void*:size_t,size_t"
export func free(pointer) export func free(pointer)
native ffi "c:free:void" native ffi "c:free:void:void*"

View File

@ -37,6 +37,12 @@ graalvmNative {
mainClass.set("gay.pizza.pork.tool.MainKt") mainClass.set("gay.pizza.pork.tool.MainKt")
sharedLibrary.set(false) sharedLibrary.set(false)
buildArgs("-march=compatibility") buildArgs("-march=compatibility")
resources {
includedPatterns.addAll(listOf(
".*/*.pork$",
".*/*.manifest$"
))
}
} }
} }
} }