diff --git a/ffi/build.gradle.kts b/ffi/build.gradle.kts index 637c795..4749eed 100644 --- a/ffi/build.gradle.kts +++ b/ffi/build.gradle.kts @@ -7,4 +7,6 @@ dependencies { api(project(":evaluator")) implementation(project(":common")) + implementation("com.github.jnr:jffi:1.3.12") + implementation("com.github.jnr:jffi:1.3.12:native") } diff --git a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiAddress.kt b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiAddress.kt new file mode 100644 index 0000000..20eeda2 --- /dev/null +++ b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiAddress.kt @@ -0,0 +1,7 @@ +package gay.pizza.pork.ffi + +data class FfiAddress(val location: Long) { + companion object { + val Null = FfiAddress(0L) + } +} diff --git a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiLibraryCache.kt b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiLibraryCache.kt deleted file mode 100644 index edda80e..0000000 --- a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiLibraryCache.kt +++ /dev/null @@ -1,46 +0,0 @@ -package gay.pizza.pork.ffi - -import java.lang.foreign.* - -object FfiLibraryCache { - private val dlopenFunctionDescriptor = FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.JAVA_INT) - private val dlsymFunctionDescriptor = FunctionDescriptor.of(ValueLayout.ADDRESS, ValueLayout.ADDRESS, ValueLayout.ADDRESS) - - private val dlopenMemorySegment = Linker.nativeLinker().defaultLookup().find("dlopen").orElseThrow() - private val dlsymMemorySegment = Linker.nativeLinker().defaultLookup().find("dlsym").orElseThrow() - - private val dlopen = Linker.nativeLinker().downcallHandle( - dlopenMemorySegment, - dlopenFunctionDescriptor - ) - - private val dlsym = Linker.nativeLinker().downcallHandle( - dlsymMemorySegment, - dlsymFunctionDescriptor - ) - - private val libraryHandles = mutableMapOf() - - private fun dlopen(name: String): MemorySegment { - var handle = libraryHandles[name] - if (handle != null) { - return handle - } - return Arena.ofConfined().use { arena -> - val nameStringPointer = arena.allocateUtf8String(name) - handle = dlopen.invokeExact(nameStringPointer, 0) as MemorySegment - if (handle == MemorySegment.NULL) { - throw RuntimeException("Unable to dlopen library: $name") - } - handle!! - } - } - - fun dlsym(name: String, symbol: String): MemorySegment { - val libraryHandle = dlopen(name) - return Arena.ofConfined().use { arena -> - val symbolStringPointer = arena.allocateUtf8String(symbol) - dlsym.invokeExact(libraryHandle, symbolStringPointer) as MemorySegment - } - } -} diff --git a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiNativeProvider.kt b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiNativeProvider.kt index 2d17746..b8b683a 100644 --- a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiNativeProvider.kt +++ b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiNativeProvider.kt @@ -1,10 +1,11 @@ package gay.pizza.pork.ffi -import gay.pizza.pork.ast.ArgumentSpec +import com.kenai.jffi.* +import com.kenai.jffi.Function +import gay.pizza.pork.ast.gen.ArgumentSpec import gay.pizza.pork.evaluator.CallableFunction import gay.pizza.pork.evaluator.NativeProvider import gay.pizza.pork.evaluator.None -import java.lang.foreign.* import java.nio.file.Path import kotlin.io.path.Path import kotlin.io.path.absolutePathString @@ -15,7 +16,6 @@ class FfiNativeProvider : NativeProvider { override fun provideNativeFunction(definitions: List, arguments: List): CallableFunction { val functionDefinition = FfiFunctionDefinition.parse(definitions[0], definitions[1]) - val linker = Linker.nativeLinker() val functionAddress = lookupSymbol(functionDefinition) val parameters = functionDefinition.parameters.map { id -> @@ -25,47 +25,72 @@ class FfiNativeProvider : NativeProvider { val returnTypeId = functionDefinition.returnType val returnType = ffiTypeRegistry.lookup(returnTypeId) ?: throw RuntimeException("Unknown ffi return type: $returnTypeId") - val parameterArray = parameters.map { typeAsLayout(it) }.toTypedArray() - val descriptor = if (returnType == FfiPrimitiveType.Void) - FunctionDescriptor.ofVoid(*parameterArray) - else FunctionDescriptor.of(typeAsLayout(returnType), *parameterArray) - val handle = linker.downcallHandle(functionAddress, descriptor) + val returnTypeFfi = typeConversion(returnType) + val parameterArray = parameters.map { typeConversion(it) }.toTypedArray() + val function = Function(functionAddress, returnTypeFfi, *parameterArray) + val context = function.callContext + val invoker = Invoker.getInstance() return CallableFunction { functionArguments, _ -> - Arena.ofConfined().use { arena -> - handle.invokeWithArguments(functionArguments.map { valueAsFfi(it, arena) }) ?: None + val buffer = HeapInvocationBuffer(context) + val freeStringList = mutableListOf() + for ((index, spec) in arguments.withIndex()) { + val ffiType = ffiTypeRegistry.lookup(functionDefinition.parameters[index]) ?: + throw RuntimeException("Unknown ffi type: ${functionDefinition.parameters[index]}") + if (spec.multiple) { + val variableArguments = functionArguments + .subList(index, functionArguments.size) + variableArguments.forEach { + var value = it + if (value is String) { + value = FfiStringWrapper(value) + freeStringList.add(value) + } + put(buffer, value) + } + break + } else { + val converted = convert(ffiType, functionArguments[index]) + if (converted is FfiStringWrapper) { + freeStringList.add(converted) + } + put(buffer, converted) + } + } + + try { + return@CallableFunction invoke(invoker, function, buffer, returnType) + } finally { + freeStringList.forEach { it.free() } } } } - private fun lookupSymbol(functionDefinition: FfiFunctionDefinition): MemorySegment { - if (functionDefinition.library == "c") { - return SymbolLookup.loaderLookup().find(functionDefinition.function).orElseThrow { - RuntimeException("Unknown function: ${functionDefinition.function}") - } - } + private fun lookupSymbol(functionDefinition: FfiFunctionDefinition): Long { val actualLibraryPath = findLibraryPath(functionDefinition.library) - val functionAddress = FfiLibraryCache.dlsym(actualLibraryPath.absolutePathString(), functionDefinition.function) - if (functionAddress.address() == 0L) { - throw RuntimeException("Unknown function: ${functionDefinition.function} in library $actualLibraryPath") + val library = Library.getCachedInstance(actualLibraryPath.absolutePathString(), Library.NOW) + ?: throw RuntimeException("Failed to load library $actualLibraryPath") + val functionAddress = library.getSymbolAddress(functionDefinition.function) + if (functionAddress == 0L) { + throw RuntimeException( + "Failed to find symbol ${functionDefinition.function} in " + + "library ${actualLibraryPath.absolutePathString()}") } return functionAddress } - private fun typeAsLayout(type: FfiType): MemoryLayout = when (type) { - FfiPrimitiveType.UnsignedByte, FfiPrimitiveType.Byte -> ValueLayout.JAVA_BYTE - FfiPrimitiveType.UnsignedInt, FfiPrimitiveType.Int -> ValueLayout.JAVA_INT - FfiPrimitiveType.UnsignedShort, FfiPrimitiveType.Short -> ValueLayout.JAVA_SHORT - FfiPrimitiveType.UnsignedLong, FfiPrimitiveType.Long -> ValueLayout.JAVA_LONG - FfiPrimitiveType.String -> ValueLayout.ADDRESS - FfiPrimitiveType.Pointer -> ValueLayout.ADDRESS - FfiPrimitiveType.Void -> MemoryLayout.sequenceLayout(0, ValueLayout.JAVA_INT) - else -> throw RuntimeException("Unknown ffi type to convert to memory layout: $type") - } - - private fun valueAsFfi(value: Any, allocator: SegmentAllocator): Any = when (value) { - is String -> allocator.allocateUtf8String(value) - None -> MemorySegment.NULL - else -> value + private fun typeConversion(type: FfiType): Type = when (type) { + FfiPrimitiveType.UnsignedByte -> Type.UINT8 + FfiPrimitiveType.Byte -> Type.SINT8 + FfiPrimitiveType.UnsignedInt -> Type.UINT32 + FfiPrimitiveType.Int -> Type.SINT32 + FfiPrimitiveType.UnsignedShort -> Type.UINT16 + FfiPrimitiveType.Short -> Type.SINT16 + FfiPrimitiveType.UnsignedLong -> Type.UINT64 + FfiPrimitiveType.Long -> Type.SINT64 + FfiPrimitiveType.String -> Type.POINTER + FfiPrimitiveType.Pointer -> Type.POINTER + FfiPrimitiveType.Void -> Type.VOID + else -> throw RuntimeException("Unknown ffi type: $type") } private fun findLibraryPath(name: String): Path { @@ -76,4 +101,69 @@ class FfiNativeProvider : NativeProvider { return FfiPlatforms.current.platform.findLibrary(name) ?: throw RuntimeException("Unable to find library: $name") } + + private fun convert(type: FfiType, value: Any?): Any { + if (type !is FfiPrimitiveType) { + return value ?: FfiAddress.Null + } + + if (type.numberConvert != null) { + return numberConvert(type.id, value, type.numberConvert) + } + + if (type.notNullConversion != null) { + return notNullConvert(type.id, value, type.notNullConversion) + } + + if (type.nullableConversion != null) { + return nullableConvert(value, type.nullableConversion) ?: FfiAddress.Null + } + return value ?: FfiAddress.Null + } + + private fun notNullConvert(type: String, value: Any?, into: Any.() -> T): T { + if (value == null) { + throw RuntimeException("Null values cannot be used for converting to type $type") + } + return into(value) + } + + private fun nullableConvert(value: Any?, into: Any.() -> T): T? { + if (value == null || value == None) { + return null + } + return into(value) + } + + private fun numberConvert(type: String, value: Any?, into: Number.() -> T): T { + if (value == null || value == None) { + throw RuntimeException("Null values cannot be used for converting to numeric type $type") + } + + if (value !is Number) { + throw RuntimeException("Cannot convert value '$value' into type $type") + } + return into(value) + } + + private fun put(buffer: InvocationBuffer, value: Any): Unit = when (value) { + is Byte -> buffer.putByte(value.toInt()) + is Short -> buffer.putShort(value.toInt()) + is Int -> buffer.putInt(value) + is Long -> buffer.putLong(value) + is FfiAddress -> buffer.putAddress(value.location) + is FfiStringWrapper -> buffer.putAddress(value.address) + else -> throw RuntimeException("Unknown buffer insertion: $value (${value.javaClass.name})") + } + + private fun invoke(invoker: Invoker, function: Function, buffer: HeapInvocationBuffer, type: FfiType): Any = when (type) { + FfiPrimitiveType.Pointer -> invoker.invokeAddress(function, buffer) + FfiPrimitiveType.UnsignedInt, FfiPrimitiveType.Int -> invoker.invokeInt(function, buffer) + FfiPrimitiveType.Long -> invoker.invokeLong(function, buffer) + FfiPrimitiveType.Void -> invoker.invokeStruct(function, buffer) + FfiPrimitiveType.Double -> invoker.invokeDouble(function, buffer) + FfiPrimitiveType.Float -> invoker.invokeFloat(function, buffer) + FfiPrimitiveType.String -> invoker.invokeAddress(function, buffer) + else -> throw RuntimeException("Unsupported ffi return type: $type") + } ?: None } diff --git a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiPrimitiveType.kt b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiPrimitiveType.kt index 5f03d98..9a0f00d 100644 --- a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiPrimitiveType.kt +++ b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiPrimitiveType.kt @@ -1,7 +1,6 @@ package gay.pizza.pork.ffi import gay.pizza.pork.evaluator.None -import java.lang.foreign.MemorySegment enum class FfiPrimitiveType( val id: kotlin.String, @@ -20,13 +19,14 @@ enum class FfiPrimitiveType( Long("long", 8, numberConvert = { toLong() }), UnsignedLong("unsigned long", 8, numberConvert = { toLong() }), Double("double", 8, numberConvert = { toDouble() }), - String("char*", 8, nullableConversion = { toString() }), + String("char*", 8, nullableConversion = { FfiStringWrapper(toString()) }), Pointer("void*", 8, nullableConversion = { - if (this is kotlin.Long) { - MemorySegment.ofAddress(this) - } else if (this == None) { - MemorySegment.NULL - } else this as MemorySegment + when (this) { + is FfiAddress -> this + is None -> FfiAddress.Null + is Number -> FfiAddress(this.toLong()) + else -> FfiAddress.Null + } }), Void("void", 0) } diff --git a/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiStringWrapper.kt b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiStringWrapper.kt new file mode 100644 index 0000000..729812a --- /dev/null +++ b/ffi/src/main/kotlin/gay/pizza/pork/ffi/FfiStringWrapper.kt @@ -0,0 +1,17 @@ +package gay.pizza.pork.ffi + +import com.kenai.jffi.MemoryIO + +class FfiStringWrapper(input: String) { + val address: Long + + init { + val bytes = input.toByteArray() + address = MemoryIO.getInstance().allocateMemory((bytes.size + 1).toLong(), true) + MemoryIO.getInstance().putZeroTerminatedByteArray(address, bytes, 0, bytes.size) + } + + fun free() { + MemoryIO.getInstance().freeMemory(address) + } +} diff --git a/ffi/src/main/kotlin/gay/pizza/pork/ffi/JnaNativeProvider.kt b/ffi/src/main/kotlin/gay/pizza/pork/ffi/JnaNativeProvider.kt deleted file mode 100644 index 7175ddb..0000000 --- a/ffi/src/main/kotlin/gay/pizza/pork/ffi/JnaNativeProvider.kt +++ /dev/null @@ -1,93 +0,0 @@ -package gay.pizza.pork.ffi - -import com.sun.jna.Function -import com.sun.jna.NativeLibrary -import gay.pizza.pork.ast.gen.ArgumentSpec -import gay.pizza.pork.evaluator.CallableFunction -import gay.pizza.pork.evaluator.NativeProvider -import gay.pizza.pork.evaluator.None - -class JnaNativeProvider : NativeProvider { - override fun provideNativeFunction(definitions: List, arguments: List): CallableFunction { - val functionDefinition = FfiFunctionDefinition.parse(definitions[0], definitions[1]) - val library = NativeLibrary.getInstance(functionDefinition.library) - val function = library.getFunction(functionDefinition.function) - ?: throw RuntimeException("Failed to find function ${functionDefinition.function} in library ${functionDefinition.library}") - return CallableFunction { functionArgs, _ -> - val ffiArgs = mutableListOf() - for ((index, spec) in arguments.withIndex()) { - val ffiType = functionDefinition.parameters[index] - if (spec.multiple) { - val variableArguments = functionArgs - .subList(index, functionArgs.size) - ffiArgs.addAll(variableArguments) - break - } else { - val converted = convert(ffiType, functionArgs[index]) - ffiArgs.add(converted) - } - } - invoke(function, ffiArgs.toTypedArray(), functionDefinition.returnType) - } - } - - private fun invoke(function: Function, values: Array, type: String): Any = when (rewriteType(type)) { - "void*" -> function.invokePointer(values) - "int" -> function.invokeInt(values) - "long" -> function.invokeLong(values) - "float" -> function.invokeFloat(values) - "double" -> function.invokeDouble(values) - "void" -> function.invokeVoid(values) - "char*" -> function.invokeString(values, false) - else -> throw RuntimeException("Unsupported ffi return type: $type") - } ?: None - - private fun rewriteType(type: String): String = when (type) { - "size_t" -> "long" - else -> type - } - - private fun convert(type: String, value: Any?): Any? { - val rewritten = rewriteType(type) - val primitive = FfiPrimitiveType.entries.firstOrNull { it.id == rewritten } - ?: throw RuntimeException("Unsupported ffi type: $type") - if (primitive.numberConvert != null) { - return numberConvert(type, value, primitive.numberConvert) - } - - if (primitive.notNullConversion != null) { - return notNullConvert(type, value, primitive.notNullConversion) - } - - if (primitive.nullableConversion != null) { - return nullableConvert(value, primitive.nullableConversion) - } - - return value - } - - private fun notNullConvert(type: String, value: Any?, into: Any.() -> T): T { - if (value == null) { - throw RuntimeException("Null values cannot be used for converting to type $type") - } - return into(value) - } - - private fun nullableConvert(value: Any?, into: Any.() -> T): T? { - if (value == null || value == None) { - return null - } - return into(value) - } - - private fun numberConvert(type: String, value: Any?, into: Number.() -> T): T { - if (value == null || value == None) { - throw RuntimeException("Null values cannot be used for converting to numeric type $type") - } - - if (value !is Number) { - throw RuntimeException("Cannot convert value '$value' into type $type") - } - return into(value) - } -}