diff --git a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiFunctionDefinition.kt b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiFunctionDefinition.kt index 2173f3d..92caadf 100644 --- a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiFunctionDefinition.kt +++ b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiFunctionDefinition.kt @@ -3,7 +3,8 @@ package gay.pizza.pork.ffi class FfiFunctionDefinition( val library: String, val function: String, - val returnType: String + val returnType: String, + val parameters: List ) { companion object { fun parse(def: String): FfiFunctionDefinition { @@ -15,7 +16,13 @@ class FfiFunctionDefinition( "but '${def}' was specified") } val (library, function, returnType) = parts - return FfiFunctionDefinition(library, function, returnType) + val parametersString = if (parts.size == 4) parts[3] else "" + return FfiFunctionDefinition( + library, + function, + returnType, + parametersString.split(",") + ) } } } diff --git a/ffi/src/main/kotlin/gay/pizza/pork/ffi/JnaNativeProvider.kt b/ffi/src/main/kotlin/gay/pizza/pork/ffi/JnaNativeProvider.kt index 97925fd..ed0d5e7 100644 --- a/ffi/src/main/kotlin/gay/pizza/pork/ffi/JnaNativeProvider.kt +++ b/ffi/src/main/kotlin/gay/pizza/pork/ffi/JnaNativeProvider.kt @@ -1,6 +1,7 @@ package gay.pizza.pork.ffi import com.sun.jna.Function +import com.sun.jna.Pointer import gay.pizza.pork.ast.ArgumentSpec import gay.pizza.pork.evaluator.CallableFunction import gay.pizza.pork.evaluator.NativeProvider @@ -12,20 +13,22 @@ class JnaNativeProvider : NativeProvider { return CallableFunction { functionArgs -> val ffiArgs = mutableListOf() for ((index, spec) in arguments.withIndex()) { + val ffiType = functionDefinition.parameters[index] if (spec.multiple) { val variableArguments = functionArgs.values .subList(index, functionArgs.values.size) ffiArgs.addAll(variableArguments) break } else { - ffiArgs.add(functionArgs.values[index]) + val converted = convert(ffiType, functionArgs.values[index]) + ffiArgs.add(converted) } } invoke(function, ffiArgs.toTypedArray(), functionDefinition.returnType) } } - private fun invoke(function: Function, values: Array, type: String): Any = when (type) { + private fun invoke(function: Function, values: Array, type: String): Any = when (rewriteType(type)) { "void*" -> function.invokePointer(values) "int" -> function.invokeInt(values) "long" -> function.invokeLong(values) @@ -35,4 +38,48 @@ class JnaNativeProvider : NativeProvider { "char*" -> function.invokeString(values, false) else -> throw RuntimeException("Unsupported ffi return type: $type") } + + private fun rewriteType(type: String): String = when (type) { + "size_t" -> "long" + else -> type + } + + private fun convert(type: String, value: Any?): Any? = when (rewriteType(type)) { + "short" -> numberConvert(type, value) { toShort() } + "unsigned short" -> numberConvert(type, value) { toShort().toUShort() } + "int" -> numberConvert(type, value) { toInt() } + "unsigned int" -> numberConvert(type, value) { toInt().toUInt() } + "long" -> numberConvert(type, value) { toLong() } + "unsigned long" -> numberConvert(type, value) { toLong().toULong() } + "double" -> numberConvert(type, value) { toDouble() } + "float" -> numberConvert(type, value) { toFloat() } + "char*" -> notNullConvert(type, value) { toString() } + "void*" -> nullableConvert(type, value) { this as Pointer } + else -> throw RuntimeException("Unsupported ffi type: $type") + } + + private fun notNullConvert(type: String, value: Any?, into: Any.() -> T): T { + if (value == null) { + throw RuntimeException("Null values cannot be used for converting to type $type") + } + return into(value) + } + + private fun nullableConvert(type: String, value: Any?, into: Any.() -> T): T? { + if (value == null) { + return null + } + return into(value) + } + + private fun numberConvert(type: String, value: Any?, into: Number.() -> T): T { + if (value == null) { + throw RuntimeException("Null values cannot be used for converting to numeric type $type") + } + + if (value !is Number) { + throw RuntimeException("Cannot convert value '$value' into type $type") + } + return into(value) + } }