diff --git a/.github/workflows/gradle.yml b/.github/workflows/gradle.yml index 58719e5..c0a2199 100644 --- a/.github/workflows/gradle.yml +++ b/.github/workflows/gradle.yml @@ -30,10 +30,10 @@ jobs: - name: Grant execute permission for gradlew if: ${{ runner.os != 'Windows' }} run: chmod +x gradlew - - name: Build with Gradle - uses: gradle/gradle-build-action@v2 - with: - arguments: build --no-daemon + - name: Setup Gradle + uses: gradle/actions/setup-gradle@v3 + - name: Execute Gradle build + run: ./gradlew build - name: Upload build reports if: ${{ runner.os == 'Linux' && matrix.java == '22-ea' }} uses: actions/upload-artifact@v4 diff --git a/README.md b/README.md index 2a21eb4..999fec4 100644 --- a/README.md +++ b/README.md @@ -4,7 +4,7 @@ Marshal allows you to conveniently create native library bindings with [FFM API](https://openjdk.org/jeps/454). -See [wiki](https://github.com/Over-Run/marshal/wiki) for more information. +~~See [wiki](https://github.com/Over-Run/marshal/wiki) for more information.~~ This library requires JDK 22 or newer. @@ -101,3 +101,7 @@ dependencies { and add this VM argument to enable native access: `--enable-native-access=io.github.overrun.marshal` or this if you don't use modules: `--enable-native-access=ALL-UNNAMED` + +## Additions + +- [OverrunGL](https://github.com/Over-Run/overrungl), which is using Marshal diff --git a/build.gradle.kts b/build.gradle.kts index db9f62b..c889f7b 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -88,7 +88,7 @@ allprojects { dependencies { // add your dependencies compileOnly("org.jetbrains:annotations:24.1.0") - testImplementation(platform("org.junit:junit-bom:5.10.1")) + testImplementation(platform("org.junit:junit-bom:5.10.2")) testImplementation("org.junit.jupiter:junit-jupiter") } diff --git a/demo/src/test/java/overrun/marshal/demo/CrossModuleTest.java b/demo/src/test/java/overrun/marshal/demo/CrossModuleTest.java index fdcce22..af59cf8 100644 --- a/demo/src/test/java/overrun/marshal/demo/CrossModuleTest.java +++ b/demo/src/test/java/overrun/marshal/demo/CrossModuleTest.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import overrun.marshal.Downcall; +import overrun.marshal.DowncallOption; import java.lang.foreign.*; import java.lang.invoke.MethodHandles; @@ -55,6 +56,6 @@ public interface I { @Test void testCrossModule() { - Assertions.assertEquals(1, Downcall.load(MethodHandles.lookup(), I.class, LOOKUP).get()); + Assertions.assertEquals(1, Downcall.load(MethodHandles.lookup(), LOOKUP, DowncallOption.targetClass(I.class)).get()); } } diff --git a/demo/src/test/java/overrun/marshal/demo/CrossModuleWithDirectAccessTest.java b/demo/src/test/java/overrun/marshal/demo/CrossModuleWithDirectAccessTest.java new file mode 100644 index 0000000..0bb1aae --- /dev/null +++ b/demo/src/test/java/overrun/marshal/demo/CrossModuleWithDirectAccessTest.java @@ -0,0 +1,62 @@ +/* + * MIT License + * + * Copyright (c) 2024 Overrun Organization + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + */ + +package overrun.marshal.demo; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import overrun.marshal.DirectAccess; +import overrun.marshal.Downcall; +import overrun.marshal.DowncallOption; + +import java.lang.foreign.*; +import java.lang.invoke.MethodHandles; +import java.lang.invoke.MethodType; +import java.util.Optional; + +/** + * Test cross module + * + * @author squid233 + * @since 0.1.0 + */ +public final class CrossModuleWithDirectAccessTest { + private static final Linker LINKER = Linker.nativeLinker(); + private static final MemorySegment s_get; + + static { + try { + s_get = LINKER.upcallStub(MethodHandles.lookup().findStatic(CrossModuleWithDirectAccessTest.class, "get", MethodType.methodType(int.class)), FunctionDescriptor.of(ValueLayout.JAVA_INT), Arena.ofAuto()); + } catch (NoSuchMethodException | IllegalAccessException e) { + throw new RuntimeException(e); + } + } + + private static final SymbolLookup LOOKUP = name -> "get".equals(name) ? Optional.of(s_get) : Optional.empty(); + + private static int get() { + return 1; + } + + public interface I extends DirectAccess { + int get(); + } + + @Test + void testCrossModule() { + Assertions.assertEquals(1, Downcall.load(MethodHandles.lookup(), LOOKUP, DowncallOption.targetClass(I.class)).get()); + } +} diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties index 1af9e09..a80b22c 100644 --- a/gradle/wrapper/gradle-wrapper.properties +++ b/gradle/wrapper/gradle-wrapper.properties @@ -1,6 +1,6 @@ distributionBase=GRADLE_USER_HOME distributionPath=wrapper/dists -distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-bin.zip +distributionUrl=https\://services.gradle.org/distributions/gradle-8.6-bin.zip networkTimeout=10000 validateDistributionUrl=true zipStoreBase=GRADLE_USER_HOME diff --git a/src/main/java/module-info.java b/src/main/java/module-info.java index aa86c4c..dfbcb36 100644 --- a/src/main/java/module-info.java +++ b/src/main/java/module-info.java @@ -25,5 +25,7 @@ exports overrun.marshal.gen; exports overrun.marshal.struct; + opens overrun.marshal.internal; + requires static org.jetbrains.annotations; } diff --git a/src/main/java/overrun/marshal/DirectAccess.java b/src/main/java/overrun/marshal/DirectAccess.java new file mode 100644 index 0000000..cc17e9c --- /dev/null +++ b/src/main/java/overrun/marshal/DirectAccess.java @@ -0,0 +1,55 @@ +/* + * MIT License + * + * Copyright (c) 2024 Overrun Organization + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + */ + +package overrun.marshal; + +import org.jetbrains.annotations.Unmodifiable; + +import java.lang.foreign.FunctionDescriptor; +import java.lang.invoke.MethodHandle; +import java.util.Map; + +/** + * This interface provides access to function descriptors and method handles + * for each function that is loaded by {@link Downcall}. + * + * @author squid233 + * @see Downcall + * @since 0.1.0 + */ +public interface DirectAccess { + /** + * {@return an unmodifiable map of the function descriptors for each method} + */ + @Unmodifiable + Map functionDescriptors(); + + /** + * {@return an unmodifiable map of the method handles for each method} + */ + @Unmodifiable + Map methodHandles(); + + /** + * Gets the method handle with the given entrypoint name. + * + * @param entrypoint the entrypoint name + * @return the method handle + */ + default MethodHandle methodHandle(String entrypoint) { + return methodHandles().get(entrypoint); + } +} diff --git a/src/main/java/overrun/marshal/Downcall.java b/src/main/java/overrun/marshal/Downcall.java index 0c8dfc6..43a0ea7 100644 --- a/src/main/java/overrun/marshal/Downcall.java +++ b/src/main/java/overrun/marshal/Downcall.java @@ -18,6 +18,8 @@ import overrun.marshal.gen.Type; import overrun.marshal.gen.*; +import overrun.marshal.internal.DowncallData; +import overrun.marshal.internal.DowncallOptions; import overrun.marshal.struct.ByValue; import overrun.marshal.struct.Struct; @@ -47,7 +49,7 @@ /** * Downcall library loader. *

Loading native library

- * You can load native libraries with {@link #load(MethodHandles.Lookup, Class, SymbolLookup)}. + * You can load native libraries with {@link #load(MethodHandles.Lookup, SymbolLookup, DowncallOption...)}. * This method generates a hidden class that loads method handle with the given symbol lookup. *

* The {@code load} methods accept a lookup object for defining hidden class with the caller. @@ -73,23 +75,26 @@ * {@link Critical @Critical} indicates that the annotated method is {@linkplain Linker.Option#critical(boolean) critical}. *

Parameter Annotations

* See {@link Ref @Ref}, {@link Sized @Sized} {@link SizedSeg @SizedSeg} and {@link StrCharset @StrCharset}. - *

Custom Method Handles

- * You can define a method, that accepts no argument and returns a {@link MethodHandle}, in an interface. - *

- * You MUST set the function descriptor of it explicitly. - *

Example

+ *

Direct Access

+ * You can get direct access by implement the class with {@link DirectAccess}, + * which allows you to get the function descriptor and the method handle for a given method. + *

Custom Function Descriptors

+ * You can use a custom function descriptor for each method. *
{@code
  * public interface GL {
- *     GL INSTANCE = Downcall.load(MethodHandles.lookup(), lookup,
- *         Map.of("glClear", FunctionDescriptor.ofVoid(ValueLayout.JAVA_INT)));
- *     MethodHandle glClear();
- *
- *     @Skip
- *     default void clear(int mask) throws Throwable {
- *         glClear().invokeExact(mask);
- *     }
+ *     GL INSTANCE = Downcall.load(lookup(), "libGL.so",
+ *         DowncallOption.descriptors(Map.of("glClear", FunctionDescriptor.ofVoid(JAVA_INT))));
+ *     void glClear(int mask);
  * }
  * }
+ *

Custom Method Handles

+ * You can get a method handle by declaring a method that returns a method handle. + * This requires a custom function descriptor. + *
{@code
+ * @Entrypoint("glClear")
+ * MethodHandle mh_glClear();
+ * default void glClear(int mask) throws Throwable { mh_glClear().invokeExact(mask); }
+ * }
*

Example

*
{@code
  * public interface GL {
@@ -101,6 +106,7 @@
  *
  * @author squid233
  * @see Critical
+ * @see DirectAccess
  * @see Entrypoint
  * @see Ref
  * @see Sized
@@ -119,6 +125,7 @@ public final class Downcall {
     private static final ClassDesc CD_CEnum = ClassDesc.of("overrun.marshal.CEnum");
     private static final ClassDesc CD_Charset = ClassDesc.of("java.nio.charset.Charset");
     private static final ClassDesc CD_Checks = ClassDesc.of("overrun.marshal.Checks");
+    private static final ClassDesc CD_DowncallData = ClassDesc.of("overrun.marshal.internal.DowncallData");
     private static final ClassDesc CD_IllegalStateException = ClassDesc.of("java.lang.IllegalStateException");
     private static final ClassDesc CD_Marshal = ClassDesc.of("overrun.marshal.Marshal");
     private static final ClassDesc CD_MemorySegment = ClassDesc.of("java.lang.foreign.MemorySegment");
@@ -131,6 +138,7 @@ public final class Downcall {
 
     private static final MethodTypeDesc MTD_Charset_String = MethodTypeDesc.of(CD_Charset, CD_String);
     private static final MethodTypeDesc MTD_long = MethodTypeDesc.of(CD_long);
+    private static final MethodTypeDesc MTD_Map = MethodTypeDesc.of(CD_Map);
     private static final MethodTypeDesc MTD_MemorySegment_Arena_Upcall = MethodTypeDesc.of(CD_MemorySegment, CD_Arena, CD_Upcall);
     private static final MethodTypeDesc MTD_MemorySegment_SegmentAllocator_String = MethodTypeDesc.of(CD_MemorySegment,
         CD_SegmentAllocator,
@@ -144,6 +152,7 @@ public final class Downcall {
         CD_StringArray,
         CD_Charset);
     private static final MethodTypeDesc MTD_MemoryStack = MethodTypeDesc.of(CD_MemoryStack);
+    private static final MethodTypeDesc MTD_MethodHandle = MethodTypeDesc.of(CD_MethodHandle);
     private static final MethodTypeDesc MTD_Object_Object = MethodTypeDesc.of(CD_Object, CD_Object);
     private static final MethodTypeDesc MTD_String_MemorySegment = MethodTypeDesc.of(CD_String, CD_MemorySegment);
     private static final MethodTypeDesc MTD_String_MemorySegment_Charset = MethodTypeDesc.of(CD_String, CD_MemorySegment, CD_Charset);
@@ -156,76 +165,35 @@ public final class Downcall {
     private static final MethodTypeDesc MTD_void_MemorySegment_StringArray_Charset = MethodTypeDesc.of(CD_void, CD_MemorySegment, CD_StringArray, CD_Charset);
     private static final MethodTypeDesc MTD_void_String_Throwable = MethodTypeDesc.of(CD_void, CD_String, CD_Throwable);
 
-    private static final DynamicConstantDesc DCD_classData_Map = DynamicConstantDesc.ofNamed(BSM_CLASS_DATA, DEFAULT_NAME, CD_Map);
+    private static final DynamicConstantDesc DCD_classData_DowncallData = DynamicConstantDesc.ofNamed(BSM_CLASS_DATA, DEFAULT_NAME, CD_DowncallData);
 
     private Downcall() {
     }
 
     /**
-     * Loads the given class with the given symbol lookup.
-     *
-     * @param caller        the lookup object for the caller
-     * @param targetClass   the target class
-     * @param lookup        the symbol lookup
-     * @param descriptorMap the custom function descriptors for each method handle
-     * @param            the type of the target class
-     * @return the loaded implementation instance of the target class
-     */
-    public static  T load(MethodHandles.Lookup caller, Class targetClass, SymbolLookup lookup, Map descriptorMap) {
-        return loadBytecode(caller, targetClass, lookup, descriptorMap);
-    }
-
-    /**
-     * Loads the given class with the given symbol lookup.
+     * Loads a class with the given symbol lookup and options.
      *
-     * @param caller      the lookup object for the caller
-     * @param targetClass the target class
-     * @param lookup      the symbol lookup
-     * @param          the type of the target class
+     * @param caller  the lookup object for the caller
+     * @param lookup  the symbol lookup
+     * @param options the options
+     * @param      the type of the target class
      * @return the loaded implementation instance of the target class
      */
-    public static  T load(MethodHandles.Lookup caller, Class targetClass, SymbolLookup lookup) {
-        return load(caller, targetClass, lookup, Map.of());
-    }
-
-    /**
-     * Loads the caller class with the given symbol lookup.
-     *
-     * @param caller        the lookup object for the caller
-     * @param lookup        the symbol lookup
-     * @param descriptorMap the custom function descriptors for each method handle
-     * @param            the type of the caller class
-     * @return the loaded implementation instance of the caller class
-     */
-    @SuppressWarnings("unchecked")
-    public static  T load(MethodHandles.Lookup caller, SymbolLookup lookup, Map descriptorMap) {
-        return load(caller, (Class) caller.lookupClass(), lookup, descriptorMap);
+    public static  T load(MethodHandles.Lookup caller, SymbolLookup lookup, DowncallOption... options) {
+        return loadBytecode(caller, lookup, options);
     }
 
     /**
-     * Loads the caller class with the given symbol lookup.
-     *
-     * @param caller the lookup object for the caller
-     * @param lookup the symbol lookup
-     * @param     the type of the caller class
-     * @return the loaded implementation instance of the caller class
-     */
-    @SuppressWarnings("unchecked")
-    public static  T load(MethodHandles.Lookup caller, SymbolLookup lookup) {
-        return load(caller, (Class) caller.lookupClass(), lookup);
-    }
-
-    /**
-     * Loads the caller class with the given library name.
+     * Loads a class with the given library name and options.
      *
      * @param caller  the lookup object for the caller
-     * @param libname the library name
-     * @param      the type of the caller class
-     * @return the loaded implementation instance of the caller class
+     * @param libPath the path of the library
+     * @param options the options
+     * @param      the type of the target class
+     * @return the loaded implementation instance of the target class
      */
-    @SuppressWarnings("unchecked")
-    public static  T load(MethodHandles.Lookup caller, String libname) {
-        return load(caller, (Class) caller.lookupClass(), SymbolLookup.libraryLookup(libname, Arena.ofAuto()));
+    public static  T load(MethodHandles.Lookup caller, String libPath, DowncallOption... options) {
+        return load(caller, SymbolLookup.libraryLookup(libPath, Arena.ofAuto()), options);
     }
 
     // bytecode
@@ -343,6 +311,7 @@ private static ClassDesc convertToDowncallCD(AnnotatedElement element, Class
         if (aClass.isPrimitive()) return aClass.describeConstable().orElseThrow();
         if (CEnum.class.isAssignableFrom(aClass)) return CD_int;
         if (SegmentAllocator.class.isAssignableFrom(aClass)) return CD_SegmentAllocator;
+        if (aClass == Object.class) return CD_Object;
         return CD_MemorySegment;
     }
 
@@ -378,23 +347,23 @@ private static Method findCEnumWrapper(Class aClass) {
         });
     }
 
-    private static Method findUpcallWrapper(Class aClass) {
-        return findWrapper(aClass, Upcall.Wrapper.class, method -> {
-            final var types = method.getParameterTypes();
-            return types.length == 2 &&
-                   Arena.class.isAssignableFrom(types[0]) &&
-                   types[1] == MemorySegment.class &&
-                   Upcall.class.isAssignableFrom(method.getReturnType());
-        });
-    }
-
     @SuppressWarnings("unchecked")
-    private static  T loadBytecode(MethodHandles.Lookup caller, Class targetClass, SymbolLookup lookup, Map descriptorMap) {
+    private static  T loadBytecode(MethodHandles.Lookup caller, SymbolLookup lookup, DowncallOption... options) {
+        Class _targetClass = null, targetClass;
+        Map _descriptorMap = null, descriptorMap;
+
+        for (DowncallOption option : options) {
+            if (option instanceof DowncallOptions.TargetClass(var aClass)) {
+                _targetClass = aClass;
+            } else if (option instanceof DowncallOptions.Descriptors(var map)) {
+                _descriptorMap = map;
+            }
+        }
+        targetClass = _targetClass != null ? _targetClass : caller.lookupClass();
+        descriptorMap = _descriptorMap != null ? _descriptorMap : Map.of();
+
         final List methodList = Arrays.stream(targetClass.getMethods())
-            .filter(method ->
-                method.getDeclaredAnnotation(Skip.class) == null &&
-                !Modifier.isStatic(method.getModifiers()) &&
-                !method.isSynthetic())
+            .filter(Predicate.not(Downcall::shouldSkip))
             .toList();
         final Map exceptionStringMap = methodList.stream()
             .collect(Collectors.toUnmodifiableMap(Function.identity(), Downcall::createExceptionString));
@@ -408,11 +377,15 @@ private static  T loadBytecode(MethodHandles.Lookup caller, Class targetCl
         final byte[] bytes = cf.build(cd_thisClass, classBuilder -> {
             classBuilder.withFlags(ACC_FINAL | ACC_SUPER);
 
-            // interface
+            // inherit
             final ClassDesc cd_targetClass = targetClass.describeConstable().orElseThrow();
-            classBuilder.withInterfaceSymbols(cd_targetClass);
+            if (targetClass.isInterface()) {
+                classBuilder.withInterfaceSymbols(cd_targetClass);
+            } else {
+                classBuilder.withSuperclass(cd_targetClass);
+            }
 
-            // method handles
+            //region method handles
             final AtomicInteger handleCount = new AtomicInteger();
             methodList.forEach(method -> {
                 final String entrypoint = getMethodEntrypoint(method);
@@ -433,17 +406,23 @@ private static  T loadBytecode(MethodHandles.Lookup caller, Class targetCl
                 classBuilder.withField(handleName, CD_MethodHandle,
                     ACC_PRIVATE | ACC_FINAL | ACC_STATIC);
             });
+            //endregion
 
-            // constructor
+            //region constructor
             classBuilder.withMethod(INIT_NAME, MTD_void, ACC_PUBLIC,
                 methodBuilder ->
                     // super
                     methodBuilder.withCode(codeBuilder -> codeBuilder
                         .aload(codeBuilder.receiverSlot())
-                        .invokespecial(CD_Object, INIT_NAME, MTD_void)
+                        .invokespecial(targetClass.isInterface() ?
+                                CD_Object :
+                                targetClass.getSuperclass().describeConstable().orElseThrow(),
+                            INIT_NAME,
+                            MTD_void)
                         .return_()));
+            //endregion
 
-            // methods
+            //region methods
             methodDataMap.forEach((method, methodData) -> {
                 final var returnType = method.getReturnType();
                 final String methodName = method.getName();
@@ -512,7 +491,6 @@ private static  T loadBytecode(MethodHandles.Lookup caller, Class targetCl
                         blockCodeBuilder -> {
                             final boolean skipFirstParam = methodData.skipFirstParam();
                             final int parameterSize = parameters.size();
-                            final List parameterCDList = new ArrayList<>(skipFirstParam ? parameterSize - 1 : parameterSize);
 
                             final ClassDesc cd_returnTypeDowncall = convertToDowncallCD(method, returnType);
                             final boolean returnVoid = returnType == void.class;
@@ -548,6 +526,7 @@ private static  T loadBytecode(MethodHandles.Lookup caller, Class targetCl
                             }
 
                             // invocation
+                            final List parameterCDList = new ArrayList<>(skipFirstParam ? parameterSize - 1 : parameterSize);
                             blockCodeBuilder.getstatic(cd_thisClass, handleName, CD_MethodHandle);
                             for (int i = skipFirstParam ? 1 : 0; i < parameterSize; i++) {
                                 final Parameter parameter = parameters.get(i);
@@ -708,18 +687,6 @@ private static  T loadBytecode(MethodHandles.Lookup caller, Class targetCl
                                     wrapper.getName(),
                                     MethodTypeDesc.of(ClassDesc.ofDescriptor(wrapper.getReturnType().descriptorString()), CD_int),
                                     wrapper.getDeclaringClass().isInterface());
-                            } else if (Upcall.class.isAssignableFrom(returnType)) {
-                                final Method wrapper = findUpcallWrapper(returnType);
-                                blockCodeBuilder.ifThenElse(Opcode.IFNONNULL,
-                                    blockCodeBuilder1 -> blockCodeBuilder1.aload(allocatorSlot)
-                                        .aload(resultSlot)
-                                        .invokestatic(cd_returnType,
-                                            wrapper.getName(),
-                                            MethodTypeDesc.of(ClassDesc.ofDescriptor(wrapper.getReturnType().descriptorString()),
-                                                ClassDesc.ofDescriptor(wrapper.getParameterTypes()[0].descriptorString()),
-                                                CD_MemorySegment),
-                                            wrapper.getDeclaringClass().isInterface()),
-                                    CodeBuilder::aconst_null);
                             } else if (returnType.isArray()) {
                                 final Class componentType = returnType.getComponentType();
                                 if (componentType == String.class) {
@@ -815,25 +782,53 @@ private static  T loadBytecode(MethodHandles.Lookup caller, Class targetCl
                         }
                     }));
             });
+            //endregion
+
+            //region DirectAccess
+            final boolean hasDirectAccess = DirectAccess.class.isAssignableFrom(targetClass);
+            if (hasDirectAccess) {
+                classBuilder.withMethod("functionDescriptors",
+                    MTD_Map,
+                    ACC_PUBLIC,
+                    methodBuilder -> methodBuilder.withCode(codeBuilder -> codeBuilder
+                        .ldc(DCD_classData_DowncallData)
+                        .invokevirtual(CD_DowncallData, "descriptorMap", MTD_Map)
+                        .areturn()));
+                classBuilder.withMethod("methodHandles",
+                    MTD_Map,
+                    ACC_PUBLIC,
+                    methodBuilder -> methodBuilder.withCode(codeBuilder -> codeBuilder
+                        .ldc(DCD_classData_DowncallData)
+                        .invokevirtual(CD_DowncallData, "handleMap", MTD_Map)
+                        .areturn()));
+            }
+            //endregion
 
-            // class initializer
+            //region class initializer
             classBuilder.withMethod(CLASS_INIT_NAME, MTD_void, ACC_STATIC,
                 methodBuilder -> methodBuilder.withCode(codeBuilder -> {
+                    final int handleMapSlot = codeBuilder.allocateLocal(TypeKind.ReferenceType);
+                    codeBuilder.ldc(DCD_classData_DowncallData)
+                        .invokevirtual(CD_DowncallData, "handleMap", MTD_Map)
+                        .astore(handleMapSlot);
+
                     // method handles
                     methodDataMap.values().forEach(methodData -> codeBuilder
-                        .ldc(DCD_classData_Map)
+                        .aload(handleMapSlot)
                         .ldc(methodData.entrypoint())
                         .invokeinterface(CD_Map, "get", MTD_Object_Object)
                         .checkcast(CD_MethodHandle)
                         .putstatic(cd_thisClass, methodData.handleName(), CD_MethodHandle));
+
                     codeBuilder.return_();
                 }));
+            //endregion
         });
 
         try {
             final MethodHandles.Lookup hiddenClass = caller.defineHiddenClassWithClassData(
                 bytes,
-                generateHandles(methodDataMap, lookup, descriptorMap),
+                generateData(methodDataMap, lookup, descriptorMap),
                 true,
                 MethodHandles.Lookup.ClassOption.STRONG
             );
@@ -844,6 +839,51 @@ private static  T loadBytecode(MethodHandles.Lookup caller, Class targetCl
         }
     }
 
+    private static boolean shouldSkip(Method method) {
+        final Class returnType = method.getReturnType();
+        final boolean b =
+            method.getDeclaredAnnotation(Skip.class) != null ||
+            Modifier.isStatic(method.getModifiers()) ||
+            method.isSynthetic();
+        if (b) {
+            return true;
+        }
+
+        final String methodName = method.getName();
+        final Class[] types = method.getParameterTypes();
+        final int length = types.length;
+
+        // check method declared by Object
+        if (length == 0) {
+            if (returnType == Class.class && "getClass".equals(methodName) ||
+                returnType == int.class && "hashCode".equals(methodName) ||
+                returnType == Object.class && "clone".equals(methodName) ||
+                returnType == String.class && "toString".equals(methodName) ||
+                returnType == void.class && "notify".equals(methodName) ||
+                returnType == void.class && "notifyAll".equals(methodName) ||
+                returnType == void.class && "wait".equals(methodName) ||
+                returnType == void.class && "finalize".equals(methodName)  // TODO: no finalize in the future
+            ) {
+                return true;
+            }
+        } else if (
+            returnType == boolean.class && length == 1 && types[0] == Object.class && "equals".equals(methodName) ||
+            returnType == void.class && length == 1 && types[0] == long.class && "wait".equals(methodName) ||
+            returnType == void.class && length == 2 && types[0] == long.class && types[1] == long.class && "wait".equals(methodName)
+        ) {
+            return true;
+        }
+
+        // check method declared by DirectAccess
+        if (length == 0 && returnType == Map.class) {
+            return "functionDescriptors".equals(methodName) || "methodHandles".equals(methodName);
+        }
+        if (returnType == MethodHandle.class && length == 1 && types[0] == String.class) {
+            return "methodHandle".equals(methodName);
+        }
+        return false;
+    }
+
     private static String createExceptionString(Method method) {
         return STR."""
             \{method.getReturnType().getCanonicalName()} \
@@ -880,118 +920,162 @@ private static void verifyMethods(List list, Map excepti
                 if (!foundLayout) {
                     throw new IllegalStateException(STR."The struct \{returnType} must contain one public static field that is StructLayout");
                 }
+            } else if (!isValidReturnType(returnType)) {
+                throw new IllegalStateException(STR."Invalid return type: \{exceptionStringMap.get(method)}");
             }
 
             // check method parameter
             final Class[] types = method.getParameterTypes();
             final boolean isFirstArena = types.length > 0 && Arena.class.isAssignableFrom(types[0]);
-            if (Upcall.class.isAssignableFrom(returnType) && !isFirstArena) {
-                throw new IllegalStateException(STR."The first parameter of method \{method} is not an arena; however, this method returns an upcall");
-            }
             for (Parameter parameter : method.getParameters()) {
-                if (Upcall.class.isAssignableFrom(parameter.getType()) && !isFirstArena) {
-                    throw new IllegalStateException(STR."The first parameter of method \{method} is not an arena; however, the parameter \{parameter.toString()} is an upcall");
+                final Class type = parameter.getType();
+                if (Upcall.class.isAssignableFrom(type) && !isFirstArena) {
+                    throw new IllegalStateException(STR."The first parameter of method \{method} is not an arena; however, the parameter \{parameter} is an upcall");
+                } else if (!isValidParamType(type)) {
+                    throw new IllegalStateException(STR."Invalid parameter: \{parameter} in \{method}");
                 }
             }
         });
     }
 
-    private static Map generateHandles(Map methodDataMap, SymbolLookup lookup, Map descriptorMap) {
+    private static boolean isValidParamArrayType(Class aClass) {
+        if (!aClass.isArray()) return false;
+        final Class type = aClass.getComponentType();
+        return type.isPrimitive() ||
+               type == MemorySegment.class ||
+               type == String.class ||
+               Addressable.class.isAssignableFrom(type) ||
+               Upcall.class.isAssignableFrom(type) ||
+               CEnum.class.isAssignableFrom(type);
+    }
+
+    private static boolean isValidReturnArrayType(Class aClass) {
+        if (!aClass.isArray()) return false;
+        final Class type = aClass.getComponentType();
+        return type.isPrimitive() ||
+               type == MemorySegment.class ||
+               type == String.class;
+    }
+
+    private static boolean isValidParamType(Class aClass) {
+        return aClass.isPrimitive() ||
+               aClass == MemorySegment.class ||
+               aClass == String.class ||
+               SegmentAllocator.class.isAssignableFrom(aClass) ||
+               Addressable.class.isAssignableFrom(aClass) ||
+               Upcall.class.isAssignableFrom(aClass) ||
+               CEnum.class.isAssignableFrom(aClass) ||
+               isValidParamArrayType(aClass);
+    }
+
+    private static boolean isValidReturnType(Class aClass) {
+        return aClass.isPrimitive() ||
+               aClass == MemorySegment.class ||
+               aClass == String.class ||
+               Struct.class.isAssignableFrom(aClass) ||
+               CEnum.class.isAssignableFrom(aClass) ||
+               isValidReturnArrayType(aClass) ||
+               aClass == MethodHandle.class;
+    }
+
+    private static DowncallData generateData(Map methodDataMap, SymbolLookup lookup, Map descriptorMap) {
+        final Map descriptorMap1 = HashMap.newHashMap(methodDataMap.size());
         final Map map = HashMap.newHashMap(methodDataMap.size());
+
         methodDataMap.forEach((method, methodData) -> {
             final String entrypoint = methodData.entrypoint();
             final Optional optional = lookup.find(entrypoint);
 
-            // function descriptor
-            FunctionDescriptor descriptor = null;
-            final FunctionDescriptor get = descriptorMap.get(entrypoint);
-            if (get != null) {
-                descriptor = get;
-            } else if (optional.isPresent() || !method.isDefault()) {
-                final var returnType = method.getReturnType();
-                final boolean returnVoid = returnType == void.class;
-                final boolean methodByValue = method.getDeclaredAnnotation(ByValue.class) != null;
-
-                // return layout
-                final MemoryLayout retLayout;
-                if (!returnVoid) {
-                    final Convert convert = method.getDeclaredAnnotation(Convert.class);
-                    if (convert != null && returnType == boolean.class) {
-                        retLayout = convert.value().layout();
-                    } else if (returnType.isPrimitive()) {
-                        retLayout = getValueLayout(returnType);
-                    } else {
-                        final SizedSeg sizedSeg = method.getDeclaredAnnotation(SizedSeg.class);
-                        final Sized sized = method.getDeclaredAnnotation(Sized.class);
-                        final boolean isSizedSeg = sizedSeg != null;
-                        final boolean isSized = sized != null;
-                        if (Struct.class.isAssignableFrom(returnType)) {
-                            StructLayout structLayout = null;
-                            for (Field field : returnType.getDeclaredFields()) {
-                                if (Modifier.isStatic(field.getModifiers()) && field.getType() == StructLayout.class) {
-                                    try {
-                                        structLayout = (StructLayout) field.get(null);
-                                        break;
-                                    } catch (IllegalAccessException e) {
-                                        throw new RuntimeException(e);
+            if (optional.isPresent()) {
+                // function descriptor
+                final FunctionDescriptor descriptor;
+                final FunctionDescriptor get = descriptorMap.get(entrypoint);
+                if (get != null) {
+                    descriptor = get;
+                } else {
+                    final var returnType = method.getReturnType();
+                    final boolean returnVoid = returnType == void.class;
+                    final boolean methodByValue = method.getDeclaredAnnotation(ByValue.class) != null;
+
+                    // return layout
+                    final MemoryLayout retLayout;
+                    if (!returnVoid) {
+                        final Convert convert = method.getDeclaredAnnotation(Convert.class);
+                        if (convert != null && returnType == boolean.class) {
+                            retLayout = convert.value().layout();
+                        } else if (returnType.isPrimitive()) {
+                            retLayout = getValueLayout(returnType);
+                        } else {
+                            final SizedSeg sizedSeg = method.getDeclaredAnnotation(SizedSeg.class);
+                            final Sized sized = method.getDeclaredAnnotation(Sized.class);
+                            final boolean isSizedSeg = sizedSeg != null;
+                            final boolean isSized = sized != null;
+                            if (Struct.class.isAssignableFrom(returnType)) {
+                                StructLayout structLayout = null;
+                                for (Field field : returnType.getDeclaredFields()) {
+                                    if (Modifier.isStatic(field.getModifiers()) && field.getType() == StructLayout.class) {
+                                        try {
+                                            structLayout = (StructLayout) field.get(null);
+                                            break;
+                                        } catch (IllegalAccessException e) {
+                                            throw new RuntimeException(e);
+                                        }
                                     }
                                 }
-                            }
-                            Objects.requireNonNull(structLayout);
-                            if (methodByValue) {
-                                retLayout = structLayout;
-                            } else {
-                                final MemoryLayout targetLayout;
-                                if (isSizedSeg) {
-                                    targetLayout = MemoryLayout.sequenceLayout(sizedSeg.value(), structLayout);
-                                } else if (isSized) {
-                                    targetLayout = MemoryLayout.sequenceLayout(sized.value(), structLayout);
+                                Objects.requireNonNull(structLayout);
+                                if (methodByValue) {
+                                    retLayout = structLayout;
                                 } else {
-                                    targetLayout = structLayout;
+                                    final MemoryLayout targetLayout;
+                                    if (isSizedSeg) {
+                                        targetLayout = MemoryLayout.sequenceLayout(sizedSeg.value(), structLayout);
+                                    } else if (isSized) {
+                                        targetLayout = MemoryLayout.sequenceLayout(sized.value(), structLayout);
+                                    } else {
+                                        targetLayout = structLayout;
+                                    }
+                                    retLayout = ValueLayout.ADDRESS.withTargetLayout(targetLayout);
                                 }
-                                retLayout = ValueLayout.ADDRESS.withTargetLayout(targetLayout);
-                            }
-                        } else {
-                            final ValueLayout valueLayout = getValueLayout(returnType);
-                            if ((valueLayout instanceof AddressLayout addressLayout) && (isSizedSeg || isSized)) {
-                                if (isSizedSeg) {
-                                    retLayout = addressLayout.withTargetLayout(MemoryLayout.sequenceLayout(sizedSeg.value(),
-                                        ValueLayout.JAVA_BYTE));
+                            } else {
+                                final ValueLayout valueLayout = getValueLayout(returnType);
+                                if ((valueLayout instanceof AddressLayout addressLayout) && (isSizedSeg || isSized)) {
+                                    if (isSizedSeg) {
+                                        retLayout = addressLayout.withTargetLayout(MemoryLayout.sequenceLayout(sizedSeg.value(),
+                                            ValueLayout.JAVA_BYTE));
+                                    } else {
+                                        retLayout = addressLayout.withTargetLayout(MemoryLayout.sequenceLayout(sized.value(),
+                                            returnType.isArray() ? getValueLayout(returnType.getComponentType()) : ValueLayout.JAVA_BYTE));
+                                    }
                                 } else {
-                                    retLayout = addressLayout.withTargetLayout(MemoryLayout.sequenceLayout(sized.value(),
-                                        returnType.isArray() ? getValueLayout(returnType.getComponentType()) : ValueLayout.JAVA_BYTE));
+                                    retLayout = valueLayout;
                                 }
-                            } else {
-                                retLayout = valueLayout;
                             }
                         }
+                    } else {
+                        retLayout = null;
                     }
-                } else {
-                    retLayout = null;
-                }
 
-                // argument layouts
-                final var parameters = methodData.parameters();
-                final boolean skipFirstParam = methodData.skipFirstParam();
-                final int size = skipFirstParam || methodByValue ?
-                    parameters.size() - 1 :
-                    parameters.size();
-                final MemoryLayout[] argLayouts = new MemoryLayout[size];
-                for (int i = 0; i < size; i++) {
-                    final Parameter parameter = parameters.get(skipFirstParam ? i + 1 : i);
-                    final Class type = parameter.getType();
-                    final Convert convert = parameter.getDeclaredAnnotation(Convert.class);
-                    if (convert != null && type == boolean.class) {
-                        argLayouts[i] = convert.value().layout();
-                    } else {
-                        argLayouts[i] = getValueLayout(type);
+                    // argument layouts
+                    final var parameters = methodData.parameters();
+                    final boolean skipFirstParam = methodData.skipFirstParam();
+                    final int size = skipFirstParam || methodByValue ?
+                        parameters.size() - 1 :
+                        parameters.size();
+                    final MemoryLayout[] argLayouts = new MemoryLayout[size];
+                    for (int i = 0; i < size; i++) {
+                        final Parameter parameter = parameters.get(skipFirstParam ? i + 1 : i);
+                        final Class type = parameter.getType();
+                        final Convert convert = parameter.getDeclaredAnnotation(Convert.class);
+                        if (convert != null && type == boolean.class) {
+                            argLayouts[i] = convert.value().layout();
+                        } else {
+                            argLayouts[i] = getValueLayout(type);
+                        }
                     }
-                }
 
-                descriptor = returnVoid ? FunctionDescriptor.ofVoid(argLayouts) : FunctionDescriptor.of(retLayout, argLayouts);
-            }
+                    descriptor = returnVoid ? FunctionDescriptor.ofVoid(argLayouts) : FunctionDescriptor.of(retLayout, argLayouts);
+                }
 
-            if (optional.isPresent()) {
                 // linker options
                 final Linker.Option[] options;
                 final Critical critical = method.getDeclaredAnnotation(Critical.class);
@@ -1003,13 +1087,15 @@ private static Map generateHandles(Map aClass) {
+        return new DowncallOptions.TargetClass(aClass);
+    }
+
+    /**
+     * Specifies the custom function descriptors.
+     *
+     * @param descriptorMap the custom function descriptors for each method handle
+     * @return the option instance
+     */
+    static DowncallOption descriptors(Map descriptorMap) {
+        return new DowncallOptions.Descriptors(descriptorMap);
+    }
+}
diff --git a/src/main/java/overrun/marshal/Upcall.java b/src/main/java/overrun/marshal/Upcall.java
index 052fd2f..bc4b4e9 100644
--- a/src/main/java/overrun/marshal/Upcall.java
+++ b/src/main/java/overrun/marshal/Upcall.java
@@ -16,8 +16,6 @@
 
 package overrun.marshal;
 
-import overrun.marshal.gen.SizedSeg;
-
 import java.lang.annotation.ElementType;
 import java.lang.annotation.Retention;
 import java.lang.annotation.RetentionPolicy;
@@ -25,10 +23,6 @@
 import java.lang.foreign.*;
 import java.lang.invoke.MethodHandle;
 import java.lang.invoke.MethodHandles;
-import java.lang.reflect.Method;
-import java.util.Arrays;
-import java.util.function.Function;
-import java.util.function.Supplier;
 
 /**
  * An upcall interface.
@@ -39,15 +33,13 @@
  * 

Example

*
{@code
  * // The implementation must be public if you use Type
- * // It is not necessary to mark it as a functional interface. However, you can mark it
+ * // The interface doesn't have to be a functional interface
  * @FunctionalInterface
  * public interface MyCallback extends Upcall {
  *     // Create a type wrapper
- *     Type TYPE = Upcall.type();
+ *     Type TYPE = Upcall.type("invoke", FunctionDescriptor.of(JAVA_INT, JAVA_INT));
  *
  *     // The function to be invoked in C
- *     // Also the stub provider
- *     @Stub
  *     int invoke(int i);
  *
  *     // Create an upcall stub segment with Type
@@ -56,28 +48,24 @@
  *         return TYPE.of(arena, this);
  *     }
  *
- *     // Create an optional wrap method
- *     @Wrapper
- *     static MyCallback wrap(Arena arena, MemorySegment stub) {
- *         return TYPE.wrap(stub, mh -> i -> {
- *             try {
- *                 return (int) mh.invokeExact(i);
- *             } catch (Throwable e) {
- *                 throw new RuntimeException(e);
- *             }
- *         });
+ *     // Create an optional wrap method for others to invoke
+ *     static int invoke(MemorySegment stub, int i) {
+ *         try {
+ *             return (int) TYPE.downcallTarget().invokeExact(stub, i);
+ *         } catch (Throwable e) {
+ *             throw new RuntimeException(e);
+ *         }
  *     }
  * }
  *
  * // C downcall
+ * void setMyCallback(MyCallback cb);
  * setMyCallback(i -> i * 2);
  * }
* * @author squid233 * @see #stub(Arena) - * @see Stub * @see Type - * @see Wrapper * @since 0.1.0 */ public interface Upcall { @@ -93,12 +81,14 @@ public interface Upcall { /** * Creates {@link Type} with the caller class. * - * @param the type of the upcall interface + * @param targetName the name of the target method + * @param descriptor the function descriptor of the target method + * @param the type of the upcall interface * @return the created {@link Type} - * @see #type(Class) + * @see #type(Class, String, FunctionDescriptor) */ @SuppressWarnings("unchecked") - static Type type() { + static Type type(String targetName, FunctionDescriptor descriptor) { final class Walker { private static final StackWalker STACK_WALKER = StackWalker.getInstance(StackWalker.Option.RETAIN_CLASS_REFERENCE); } @@ -106,19 +96,21 @@ final class Walker { if (!Upcall.class.isAssignableFrom(callerClass)) { throw new ClassCastException(callerClass.getName()); } - return type((Class) callerClass); + return type((Class) callerClass, targetName, descriptor); } /** * Creates {@link Type}. * - * @param tClass the class of the upcall interface - * @param the type of the upcall interface + * @param tClass the class of the upcall interface + * @param targetName the name of the target method + * @param descriptor the function descriptor of the target method + * @param the type of the upcall interface * @return the created {@link Type} - * @see #type() + * @see #type(String, FunctionDescriptor) */ - static Type type(Class tClass) { - return new Type<>(tClass); + static Type type(Class tClass, String targetName, FunctionDescriptor descriptor) { + return new Type<>(tClass, targetName, descriptor); } /** @@ -133,23 +125,9 @@ static Type type(Class tClass) { @interface Stub { } - /** - * Marks a static method as an upcall wrapper. - *

- * The parameters of marked method must be only one {@link Arena} and one {@link MemorySegment}. - * - * @author squid233 - * @see Upcall - * @since 0.1.0 - */ - @Target(ElementType.METHOD) - @Retention(RetentionPolicy.RUNTIME) - @interface Wrapper { - } - /** * The type wrapper of an upcall interface. - * The constructor uses heavy reflective, and you should always cache it as a static field. + * The constructor uses heavy reflection, and you should always cache it as a static field. * * @param The type of the upcall interface. * @author squid233 @@ -160,34 +138,16 @@ final class Type { private static final Linker LINKER = Linker.nativeLinker(); private final MethodHandle target; private final FunctionDescriptor descriptor; + private final MethodHandle downcallTarget; - private Type(Class tClass) { - final Method method = Arrays.stream(tClass.getDeclaredMethods()) - .filter(m -> m.getDeclaredAnnotation(Stub.class) != null) - .findFirst() - .orElseThrow(() -> new IllegalArgumentException("Couldn't find any upcall stub provider in " + tClass)); + private Type(Class tClass, String targetName, FunctionDescriptor descriptor) { try { - target = MethodHandles.publicLookup().unreflect(method); - } catch (IllegalAccessException e) { + target = MethodHandles.publicLookup().findVirtual(tClass, targetName, descriptor.toMethodType()); + } catch (IllegalAccessException | NoSuchMethodException e) { throw new RuntimeException(e); } - final var returnType = method.getReturnType(); - final MemoryLayout[] memoryLayouts = Arrays.stream(method.getParameters()) - .map(p -> withSizedSeg(toMemoryLayout(p.getType()), p.getDeclaredAnnotation(SizedSeg.class))) - .toArray(MemoryLayout[]::new); - if (returnType == void.class) { - descriptor = FunctionDescriptor.ofVoid(memoryLayouts); - } else { - descriptor = FunctionDescriptor.of(withSizedSeg(toMemoryLayout(returnType), method.getDeclaredAnnotation(SizedSeg.class)), - memoryLayouts); - } - } - - private static MemoryLayout withSizedSeg(MemoryLayout layout, SizedSeg sizedSeg) { - if (sizedSeg != null && layout instanceof AddressLayout addressLayout) { - return addressLayout.withTargetLayout(MemoryLayout.sequenceLayout(sizedSeg.value(), ValueLayout.JAVA_BYTE)); - } - return layout; + this.descriptor = descriptor; + this.downcallTarget = LINKER.downcallHandle(descriptor); } /** @@ -209,19 +169,7 @@ public MemorySegment of(Arena arena, T upcall) { * @return the downcall method handle */ public MethodHandle downcall(MemorySegment stub) { - return LINKER.downcallHandle(stub, descriptor()); - } - - /** - * Wraps the given upcall stub segment. - * - * @param stub the upcall stub segment - * @param function the function that transforms the given method handle into the downcall type. - * The {@link Arena} is wrapped in a {@link Supplier} and you should store it with a variable - * @return the downcall type - */ - public T wrap(MemorySegment stub, Function function) { - return Unmarshal.isNullPointer(stub) ? null : function.apply(downcall(stub)); + return downcallTarget.bindTo(stub); } /** @@ -231,57 +179,11 @@ public FunctionDescriptor descriptor() { return descriptor; } - private static MemoryLayout toMemoryLayout(Class carrier) { - if (carrier == boolean.class) { - return ValueLayout.JAVA_BOOLEAN; - } else if (carrier == char.class) { - return ValueLayout.JAVA_CHAR; - } else if (carrier == byte.class) { - return ValueLayout.JAVA_BYTE; - } else if (carrier == short.class) { - return ValueLayout.JAVA_SHORT; - } else if (carrier == int.class) { - return ValueLayout.JAVA_INT; - } else if (carrier == float.class) { - return ValueLayout.JAVA_FLOAT; - } else if (carrier == long.class) { - return ValueLayout.JAVA_LONG; - } else if (carrier == double.class) { - return ValueLayout.JAVA_DOUBLE; - } else if (carrier == MemorySegment.class) { - return ValueLayout.ADDRESS; - } else { - throw new IllegalArgumentException("Unsupported carrier: " + carrier.getName()); - } - } - } - - /** - * The base container of {@link Arena} and {@link Upcall}. - * - * @param the type of the upcall - * @author squid233 - * @since 0.1.0 - */ - abstract class BaseContainer implements Upcall { - /** - * The arena. - */ - protected final Arena arena; /** - * The upcall delegate. - */ - protected final T delegate; - - /** - * Creates a base container. - * - * @param arena the arena - * @param delegate the upcall delegate + * {@return the downcall method handle} */ - public BaseContainer(Arena arena, T delegate) { - this.arena = arena; - this.delegate = delegate; + public MethodHandle downcallTarget() { + return downcallTarget; } } } diff --git a/src/main/java/overrun/marshal/internal/DowncallData.java b/src/main/java/overrun/marshal/internal/DowncallData.java new file mode 100644 index 0000000..9681343 --- /dev/null +++ b/src/main/java/overrun/marshal/internal/DowncallData.java @@ -0,0 +1,32 @@ +/* + * MIT License + * + * Copyright (c) 2024 Overrun Organization + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + */ + +package overrun.marshal.internal; + +import java.lang.foreign.FunctionDescriptor; +import java.lang.invoke.MethodHandle; +import java.util.Map; + +/** + * Downcall class data + * + * @param descriptorMap descriptorMap + * @param handleMap handleMap + * @author squid233 + * @since 0.1.0 + */ +public record DowncallData(Map descriptorMap, Map handleMap) { +} diff --git a/src/main/java/overrun/marshal/internal/DowncallOptions.java b/src/main/java/overrun/marshal/internal/DowncallOptions.java new file mode 100644 index 0000000..9d61bdd --- /dev/null +++ b/src/main/java/overrun/marshal/internal/DowncallOptions.java @@ -0,0 +1,50 @@ +/* + * MIT License + * + * Copyright (c) 2024 Overrun Organization + * + * Permission is hereby granted, free of charge, to any person obtaining a copy + * of this software and associated documentation files (the "Software"), to deal + * in the Software without restriction, including without limitation the rights + * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all + * copies or substantial portions of the Software. + */ + +package overrun.marshal.internal; + +import overrun.marshal.DowncallOption; + +import java.lang.foreign.FunctionDescriptor; +import java.util.Map; + +/** + * Downcall options. + * + * @author squid233 + * @since 0.1.0 + */ +public final class DowncallOptions { + private DowncallOptions() { + } + + /** + * specify target class + * + * @param aClass the class + */ + public record TargetClass(Class aClass) implements DowncallOption { + } + + /** + * specify custom function descriptors + * + * @param descriptorMap the custom function descriptors + */ + + public record Descriptors(Map descriptorMap) implements DowncallOption { + } +} diff --git a/src/main/java/overrun/marshal/struct/StructHandle.java b/src/main/java/overrun/marshal/struct/StructHandle.java index efb8a13..ba63b82 100644 --- a/src/main/java/overrun/marshal/struct/StructHandle.java +++ b/src/main/java/overrun/marshal/struct/StructHandle.java @@ -237,7 +237,7 @@ public static Addressable ofAddressab * @param the type of the upcall * @return the struct handle */ - public static Upcall ofUpcall(Struct struct, String name, BiFunction factory) { + public static Upcall ofUpcall(Struct struct, String name, Function factory) { return new Upcall<>(ofValue(struct, name), factory); } @@ -841,9 +841,9 @@ public T get() { * @since 0.1.0 */ public static final class Upcall extends TypeExt { - private final BiFunction factory; + private final Function factory; - private Upcall(VarHandle varHandle, BiFunction factory) { + private Upcall(VarHandle varHandle, Function factory) { super(varHandle); this.factory = factory; } @@ -861,7 +861,7 @@ public void set(Arena userdata, T value) { @Override public T get(long index, Arena userdata) { if (factory == null) throw new UnsupportedOperationException(); - return factory.apply(userdata, (MemorySegment) varHandle.get(0L, index)); + return factory.apply((MemorySegment) varHandle.get(0L, index)); } @Override diff --git a/src/test/java/overrun/marshal/test/ComplexStruct.java b/src/test/java/overrun/marshal/test/ComplexStruct.java index ea536e9..409297c 100644 --- a/src/test/java/overrun/marshal/test/ComplexStruct.java +++ b/src/test/java/overrun/marshal/test/ComplexStruct.java @@ -58,7 +58,7 @@ public final class ComplexStruct extends Struct { public final StructHandle.Str Str = StructHandle.ofString(this, "Str"); public final StructHandle.Str UTF16Str = StructHandle.ofString(this, "UTF16Str", StandardCharsets.UTF_16); public final StructHandle.Addressable Addressable = StructHandle.ofAddressable(this, "Addressable", Vector3::new); - public final StructHandle.Upcall Upcall = StructHandle.ofUpcall(this, "Upcall", SimpleUpcall::wrap); + public final StructHandle.Upcall Upcall = StructHandle.ofUpcall(this, "Upcall", segment -> i -> SimpleUpcall.invoke(segment, i)); public final StructHandle.Array IntArray = StructHandle.ofArray(this, "IntArray", Marshal::marshal, Unmarshal::unmarshalAsIntArray); /** diff --git a/src/test/java/overrun/marshal/test/ComplexUpcall.java b/src/test/java/overrun/marshal/test/ComplexUpcall.java index 9629be1..4e427d5 100644 --- a/src/test/java/overrun/marshal/test/ComplexUpcall.java +++ b/src/test/java/overrun/marshal/test/ComplexUpcall.java @@ -16,13 +16,13 @@ package overrun.marshal.test; +import overrun.marshal.Marshal; +import overrun.marshal.MemoryStack; +import overrun.marshal.Unmarshal; import overrun.marshal.Upcall; -import overrun.marshal.gen.SizedSeg; +import overrun.marshal.gen.Sized; -import java.lang.foreign.Arena; -import java.lang.foreign.MemorySegment; -import java.lang.foreign.ValueLayout; -import java.lang.invoke.MethodHandle; +import java.lang.foreign.*; /** * A complex upcall @@ -32,74 +32,28 @@ */ @FunctionalInterface public interface ComplexUpcall extends Upcall { - int[] invoke(int[] arr); + AddressLayout ARG_LAYOUT = ValueLayout.ADDRESS.withTargetLayout(MemoryLayout.sequenceLayout(2L, ValueLayout.JAVA_INT)); + Type TYPE = Upcall.type("invoke", FunctionDescriptor.of(ARG_LAYOUT, ARG_LAYOUT)); - /** - * the container - * - * @author squid233 - * @since 0.1.0 - */ - sealed class Container extends BaseContainer implements ComplexUpcall { - public static final Type TYPE = Upcall.type(); + @Sized(2) + int[] invoke(@Sized(2) int[] arr); - public Container(Arena arena, ComplexUpcall delegate) { - super(arena, delegate); - } - - @Stub - @SizedSeg(2 * Integer.BYTES) - public MemorySegment invoke(@SizedSeg(2 * Integer.BYTES) MemorySegment arr) { - return arena.allocateFrom(ValueLayout.JAVA_INT, invoke(arr.toArray(ValueLayout.JAVA_INT))); - } - - @Override - public int[] invoke(int[] arr) { - return delegate.invoke(arr); - } - - @Override - public MemorySegment stub(Arena arena) { - return TYPE.of(arena, this); + default MemorySegment invoke(MemorySegment arr) { + try (MemoryStack stack = MemoryStack.stackPush()) { + return Marshal.marshal(stack, invoke(Unmarshal.unmarshalAsIntArray(arr))); } } - /** - * the wrapper container - * - * @author squid233 - * @since 0.1.0 - */ - final class WrapperContainer extends Container { - private final MethodHandle handle; - - public WrapperContainer(Arena arena, MemorySegment stub) { - super(arena, null); - this.handle = TYPE.downcall(stub); + static int[] invoke(MemorySegment stub, int[] arr) { + try (MemoryStack stack = MemoryStack.stackPush()) { + return Unmarshal.unmarshalAsIntArray((MemorySegment) TYPE.downcallTarget().invokeExact(stub, Marshal.marshal(stack, arr))); + } catch (Throwable e) { + throw new RuntimeException(e); } - - @Override - public MemorySegment invoke(MemorySegment arr) { - try { - return (MemorySegment) handle.invokeExact(arr); - } catch (Throwable e) { - throw new RuntimeException(e); - } - } - - @Override - public int[] invoke(int[] arr) { - return invoke(arena.allocateFrom(ValueLayout.JAVA_INT, arr)).toArray(ValueLayout.JAVA_INT); - } - } - - @Wrapper - static Container wrap(Arena arena, MemorySegment stub) { - return new WrapperContainer(arena, stub); } @Override default MemorySegment stub(Arena arena) { - return new Container(arena, this).stub(arena); + return TYPE.of(arena, this); } } diff --git a/src/test/java/overrun/marshal/test/DescriptorMapTest.java b/src/test/java/overrun/marshal/test/DescriptorMapTest.java index 30bafb6..de20f26 100644 --- a/src/test/java/overrun/marshal/test/DescriptorMapTest.java +++ b/src/test/java/overrun/marshal/test/DescriptorMapTest.java @@ -18,6 +18,7 @@ import org.junit.jupiter.api.Test; import overrun.marshal.Downcall; +import overrun.marshal.DowncallOption; import overrun.marshal.gen.Entrypoint; import overrun.marshal.gen.Skip; @@ -89,10 +90,10 @@ static boolean acceptLong(long d) { public interface Interface { static Interface getInstance(ValueLayout returnLayout, ValueLayout acceptLayout) { - return Downcall.load(MethodHandles.lookup(), lookup(returnLayout, acceptLayout), Map.of( + return Downcall.load(MethodHandles.lookup(), lookup(returnLayout, acceptLayout), DowncallOption.descriptors(Map.of( "testReturn", FunctionDescriptor.of(returnLayout), "testAccept", FunctionDescriptor.of(JAVA_BOOLEAN, acceptLayout) - )); + ))); } @Entrypoint("testReturn") diff --git a/src/test/java/overrun/marshal/test/DowncallProvider.java b/src/test/java/overrun/marshal/test/DowncallProvider.java index d8531b5..8930dba 100644 --- a/src/test/java/overrun/marshal/test/DowncallProvider.java +++ b/src/test/java/overrun/marshal/test/DowncallProvider.java @@ -16,8 +16,6 @@ package overrun.marshal.test; -import overrun.marshal.MemoryStack; - import java.lang.foreign.*; import java.lang.invoke.MethodHandle; import java.lang.invoke.MethodHandles; @@ -111,9 +109,7 @@ private static void testCEnum(int i) { } private static int testUpcall(MemorySegment upcall) { - try (MemoryStack stack = MemoryStack.stackPush()) { - return SimpleUpcall.wrap(stack, upcall).invoke(42); - } + return SimpleUpcall.invoke(upcall, 42); } private static void testIntArray(MemorySegment arr) { diff --git a/src/test/java/overrun/marshal/test/DowncallTest.java b/src/test/java/overrun/marshal/test/DowncallTest.java index fb180e0..a0b0f2f 100644 --- a/src/test/java/overrun/marshal/test/DowncallTest.java +++ b/src/test/java/overrun/marshal/test/DowncallTest.java @@ -82,8 +82,8 @@ void testReturnCEnum() { @Test void testReturnUpcall() { try (Arena arena = Arena.ofConfined()) { - final SimpleUpcall upcall = d.testReturnUpcall(arena); - assertEquals(84, upcall.invoke(42)); + final MemorySegment upcall = d.testReturnUpcall(arena); + assertEquals(84, SimpleUpcall.invoke(upcall, 42)); } } diff --git a/src/test/java/overrun/marshal/test/IDowncall.java b/src/test/java/overrun/marshal/test/IDowncall.java index 87390c8..09139a5 100644 --- a/src/test/java/overrun/marshal/test/IDowncall.java +++ b/src/test/java/overrun/marshal/test/IDowncall.java @@ -16,6 +16,7 @@ package overrun.marshal.test; +import overrun.marshal.DowncallOption; import overrun.marshal.struct.ByValue; import overrun.marshal.Downcall; import overrun.marshal.MemoryStack; @@ -36,7 +37,7 @@ public interface IDowncall { Map MAP = Map.of("testDefault", FunctionDescriptor.of(ValueLayout.JAVA_INT)); static IDowncall getInstance(boolean testDefaultNull) { - return Downcall.load(MethodHandles.lookup(), DowncallProvider.lookup(testDefaultNull), MAP); + return Downcall.load(MethodHandles.lookup(), DowncallProvider.lookup(testDefaultNull), DowncallOption.descriptors(MAP)); } void test(); @@ -86,7 +87,7 @@ default int testDefault() { MyEnum testReturnCEnum(); - SimpleUpcall testReturnUpcall(Arena arena); + MemorySegment testReturnUpcall(Arena arena); Vector3 testReturnStruct(); diff --git a/src/test/java/overrun/marshal/test/IndirectInterfaceTest.java b/src/test/java/overrun/marshal/test/IndirectInterfaceTest.java index 9a149a9..d7bbd47 100644 --- a/src/test/java/overrun/marshal/test/IndirectInterfaceTest.java +++ b/src/test/java/overrun/marshal/test/IndirectInterfaceTest.java @@ -19,6 +19,7 @@ import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; import overrun.marshal.Downcall; +import overrun.marshal.DowncallOption; import java.lang.foreign.SymbolLookup; import java.lang.invoke.MethodHandles; @@ -39,7 +40,7 @@ default int fun1() { public interface I2 extends I1 { } - I2 INSTANCE = Downcall.load(MethodHandles.lookup(), I2.class, SymbolLookup.loaderLookup()); + I2 INSTANCE = Downcall.load(MethodHandles.lookup(), SymbolLookup.loaderLookup(), DowncallOption.targetClass(I2.class)); @Test void testIndirectInterface() { diff --git a/src/test/java/overrun/marshal/test/SimpleUpcall.java b/src/test/java/overrun/marshal/test/SimpleUpcall.java index ab97b40..860cc73 100644 --- a/src/test/java/overrun/marshal/test/SimpleUpcall.java +++ b/src/test/java/overrun/marshal/test/SimpleUpcall.java @@ -19,7 +19,9 @@ import overrun.marshal.Upcall; import java.lang.foreign.Arena; +import java.lang.foreign.FunctionDescriptor; import java.lang.foreign.MemorySegment; +import java.lang.foreign.ValueLayout; /** * A simple upcall @@ -29,20 +31,16 @@ */ @FunctionalInterface public interface SimpleUpcall extends Upcall { - Type TYPE = Upcall.type(); + Type TYPE = Upcall.type("invoke", FunctionDescriptor.of(ValueLayout.JAVA_INT, ValueLayout.JAVA_INT)); - @Stub int invoke(int i); - @Wrapper - static SimpleUpcall wrap(Arena arena, MemorySegment stub) { - return TYPE.wrap(stub, methodHandle -> i -> { - try { - return (int) methodHandle.invokeExact(i); - } catch (Throwable e) { - throw new RuntimeException(e); - } - }); + static int invoke(MemorySegment stub, int i) { + try { + return (int) TYPE.downcallTarget().invokeExact(stub, i); + } catch (Throwable e) { + throw new RuntimeException(e); + } } @Override diff --git a/src/test/java/overrun/marshal/test/UpcallTest.java b/src/test/java/overrun/marshal/test/UpcallTest.java index db6b441..9eff7dc 100644 --- a/src/test/java/overrun/marshal/test/UpcallTest.java +++ b/src/test/java/overrun/marshal/test/UpcallTest.java @@ -31,12 +31,12 @@ * @since 0.1.0 */ public final class UpcallTest { - private int invokeSimpleUpcall(Arena arena, MemorySegment upcall) { - return SimpleUpcall.wrap(arena, upcall).invoke(42); + private int invokeSimpleUpcall(MemorySegment upcall) { + return SimpleUpcall.invoke(upcall, 42); } - private int[] invokeComplexUpcall(Arena arena, MemorySegment upcall) { - return ComplexUpcall.wrap(arena, upcall).invoke(new int[]{4, 2}); + private int[] invokeComplexUpcall(MemorySegment upcall) { + return ComplexUpcall.invoke(upcall, new int[]{4, 2}); } @Test @@ -44,7 +44,7 @@ void testSimpleUpcall() { final Arena arena = Arena.ofAuto(); final SimpleUpcall upcall = i -> i * 2; final MemorySegment stub = upcall.stub(arena); - assertEquals(84, invokeSimpleUpcall(arena, stub)); + assertEquals(84, invokeSimpleUpcall(stub)); } @Test @@ -52,6 +52,6 @@ void testComplexUpcall() { final Arena arena = Arena.ofAuto(); final ComplexUpcall upcall = arr -> new int[]{arr[0] * 4, arr[1] * 2}; final MemorySegment stub = upcall.stub(arena); - assertArrayEquals(new int[]{16, 4}, invokeComplexUpcall(arena, stub)); + assertArrayEquals(new int[]{16, 4}, invokeComplexUpcall(stub)); } }