Improve CLEventList design and fix CLEvent.setEventCallback.
authorNot Zed <notzed@gmail.com>
Sun, 26 Jan 2020 07:39:09 +0000 (18:09 +1030)
committerNot Zed <notzed@gmail.com>
Sun, 26 Jan 2020 07:39:09 +0000 (18:09 +1030)
Now uses a MemorySegment directly to read/write events, bypassing Java
unless necessary.
Added tests.

src/notzed.zcl/classes/au/notzed/zcl/CLCommandQueue.java
src/notzed.zcl/classes/au/notzed/zcl/CLEvent.java
src/notzed.zcl/classes/au/notzed/zcl/CLEventList.java
src/notzed.zcl/tests/au/notzed/zcl/CLEventTest.java [new file with mode: 0644]

index 6de70f5..f878bb8 100644 (file)
@@ -1506,6 +1506,13 @@ public class CLCommandQueue extends CLExtendable {
                }
        }
 
+       /**
+        * Simplify wait/event handling.
+        *
+        * To use, create an EventInfo from the passed in arguments.
+        * In the enqueue command pass in .wait and .event from this
+        * structure.  If the command succeeds, then call post().
+        */
        static private class EventInfo {
                final int nwait;
                final MemoryAddress wait;
@@ -1513,19 +1520,13 @@ public class CLCommandQueue extends CLExtendable {
 
                EventInfo(Allocator frame, CLEventList waiters, CLEventList events) {
                        nwait = waiters != null ? waiters.size() : 0;
-                       if (nwait > 0) {
-                               wait = frame.alloca(8 * nwait);
-                               for (int i=0;i<nwait;i++)
-                                       Native.setAddr(wait, i, waiters.get(i).addr());
-                       } else {
-                               wait = MemoryAddress.NULL;
-                       }
-                       event = events != null ? frame.alloca(8) : MemoryAddress.NULL;
+                       wait = nwait > 0 ? waiters.slots() : MemoryAddress.NULL;
+                       event = events != null ? events.currentSlot() : MemoryAddress.NULL;
                }
 
                void post(CLEventList events) {
                        if (events != null)
-                               events.add(resolve(getAddr(event), CLEvent::new));
+                               events.incrementSlot();
                }
        }
 
index c378cfe..d921e6e 100644 (file)
@@ -19,9 +19,9 @@ package au.notzed.zcl;
 import jdk.incubator.foreign.*;
 import static au.notzed.zcl.CL.*;
 import static au.notzed.zcl.CLLib.*;
-import api.Native;
-import api.Callback;
+import api.*;
 import java.lang.invoke.MethodHandle;
+import java.util.ArrayList;
 
 /**
  * Interface for cl_event.
@@ -34,9 +34,10 @@ public class CLEvent extends CLObject {
        final int apiVersion;
 
        /**
-        * This is used to retain a reference for any callback set
+        * This is used to retain a reference for any callback set.
+        * There may be multiple.
         */
-       Callback<CLEventNotify> callback;
+       ArrayList<Callback<CLEventNotify>> callbacks;
 
        public CLEvent(MemoryAddress p) {
                super(p);
@@ -58,8 +59,10 @@ public class CLEvent extends CLObject {
 
        @Override
        public void release() {
-               Native.release(callback);
-               callback = null;
+               if (callbacks != null) {
+                       callbacks.forEach((c) -> c.release());
+                       callbacks = null;
+               }
                super.release();
        }
 
@@ -94,6 +97,8 @@ public class CLEvent extends CLObject {
        /**
         * Call clSetEventCallback(type, notify).
         *
+        * Adds a new callback for the given state.
+        *
         * @param type
         * @param notify
         * @throws CLRuntimeException
@@ -102,23 +107,21 @@ public class CLEvent extends CLObject {
        public void setEventCallback(int type, CLEventNotify notify) throws CLRuntimeException {
                CLPlatform.requireAPIVersion(apiVersion, CLPlatform.VERSION_1_1);
 
-               Native.release(callback);
+               Callback<CLEventNotify> callback = CLEventNotify.call(notify);
 
-               if (notify != null) {
-                       callback = CLEventNotify.call(notify);
+               try {
+                       int res = clSetEventCallback(addr(), type, callback.addr(), MemoryAddress.NULL);
 
-                       try {
-                               int res = clSetEventCallback(addr(), type, callback.addr(), MemoryAddress.NULL);
+                       if (res != 0)
+                               throw new CLRuntimeException(res);
 
-                               if (res != 0)
-                                       throw new CLRuntimeException(res);
-                       } catch (RuntimeException | Error t) {
-                               throw t;
-                       } catch (Throwable t) {
-                               throw new RuntimeException(t);
-                       }
-               } else {
-                       callback = null;
+                       if (callbacks == null)
+                               callbacks = new ArrayList<>();
+                       callbacks.add(callback);
+               } catch (RuntimeException | Error t) {
+                       throw t;
+               } catch (Throwable t) {
+                       throw new RuntimeException(t);
                }
        }
 
index df0d398..187bc9d 100644 (file)
@@ -17,6 +17,7 @@
 package au.notzed.zcl;
 
 import jdk.incubator.foreign.MemoryAddress;
+import jdk.incubator.foreign.MemorySegment;
 import api.Memory;
 import api.Allocator;
 import api.Native;
@@ -31,29 +32,40 @@ import static au.notzed.zcl.CLLib.*;
  * See {@link au.notzed.zcl.CLCommandQueue} for more information on usage.
  * <h2>Internal Details</h2>
  * <p>
- * Internally the CLEventList is maintained as an array of long values which hold
+ * Internally the CLEventList is maintained as a MemorySegment which holds
  * the cl_event pointers. CLEvent objects are only created to access this pointer
  * from Java or to transfer the event pointers to new event lists.
  * <p>
- * The JNI code directly reads from the events list and index in order to build
- * a list of wait events, and calls the internal add(long) method to append to
- * the events array.
+ * The list is read or written to directly within the CLCommandQueue methods
+ * making normal operation very cheap.
+ * <p>
+ * Currently event lists are not tracked for garbage collection and must
+ * be released() explicitly.
  */
-public class CLEventList {
+public final class CLEventList implements AutoCloseable {
+
+       /**
+        * Raw event values.
+        */
+       final MemoryAddress cevents;
 
        /**
-        * Event references.
+        * Event references?
         */
        final CLEvent[] jevents;
        int index = 0;
+       boolean reserve = false;
 
        /**
         * Creates a new event list.
         *
+        * The event list MUST be released when done with.
+        *
         * @param capacity Sets the event list capacity.
         */
        public CLEventList(int capacity) {
                this.jevents = new CLEvent[capacity];
+               this.cevents = MemorySegment.allocateNative(8 * capacity, 8).baseAddress();
        }
 
        /**
@@ -63,24 +75,89 @@ public class CLEventList {
         */
        public void reset() {
                for (int i = 0; i < index; i++) {
-                       jevents[i].release();
-                       jevents[i] = null;
+                       CLEvent ev = jevents[i];
+
+                       if (ev == null) {
+                               try {
+                                       clReleaseEvent(Native.getAddr(cevents, i));
+                               } catch (Throwable t) {
+                               }
+                       } else {
+                               ev.release();
+                               jevents[i] = null;
+                       }
+                       Native.setAddr(cevents, i, MemoryAddress.NULL);
                }
                index = 0;
        }
 
+       /**
+        * Release all resources.
+        *
+        * Currently this must be called on the same thread
+        * that created the event list.
+        */
        public void release() {
                reset();
+               cevents.segment().close();
+       }
+
+       @Override
+       public void close() {
+               release();
        }
-       
+
+       /**
+        * Get the base address for all slots.
+        *
+        * This is used internally by CLCommandQueue.EventInfo to write directly to the event list.
+        */
+       MemoryAddress slots() {
+               return cevents;
+       }
+
+       /**
+        * Get current output slot.
+        *
+        * This is used internally by CLCommandQueue.EventInfo to write directly to the event list.
+        *
+        * @throws IllegalStateException if the CLEventList has been released.
+        * @throws ArrayIndexOutOfBoundsException if the CLEventList is full.
+        */
+       MemoryAddress currentSlot() {
+               if (index < jevents.length) {
+                       MemoryAddress addr = cevents.addOffset(index * 8);
+
+                       // This should already be null, but this performs a range check
+                       Native.setAddr(addr, MemoryAddress.NULL);
+                       return addr;
+               } else
+                       throw new ArrayIndexOutOfBoundsException();
+       }
+
+       /**
+        * Indicate slot value is valid.
+        *
+        * This is used internally by CLCommandQueue.EventInfo to write directly to the event list.
+        */
+       void incrementSlot() {
+               index++;
+       }
+
        /**
         * Creates an interface to the given event.
         *
         * @param index
         * @return An event interface.
         */
-       public CLEvent get(int index) {
-               return jevents[index];
+       public CLEvent get(int i) {
+               if (i < index) {
+                       CLEvent ev = jevents[i];
+                       if (ev == null)
+                               jevents[i] = ev = Native.resolve(Native.getAddr(cevents, i), CLEvent::new);
+                       return ev;
+               } else
+                       throw new ArrayIndexOutOfBoundsException();
        }
 
        /**
@@ -89,6 +166,7 @@ public class CLEventList {
         * @param event
         */
        public void add(CLEvent event) {
+               Native.setAddr(cevents, index, event.addr());
                jevents[index++] = event;
        }
 
@@ -108,11 +186,8 @@ public class CLEventList {
         */
        public void waitForEvents() throws CLException {
                if (size() > 0) {
-                       try (Allocator frame = Memory.stack()) {
-                               MemoryAddress events = Native.toAddrV(frame, jevents, index);
-                               int res;
-
-                               res = clWaitForEvents(size(), events);
+                       try {
+                               int res = clWaitForEvents(size(), cevents);
                                if (res != 0)
                                        throw new CLException(res);
                        } catch (CLException | RuntimeException | Error t) {
diff --git a/src/notzed.zcl/tests/au/notzed/zcl/CLEventTest.java b/src/notzed.zcl/tests/au/notzed/zcl/CLEventTest.java
new file mode 100644 (file)
index 0000000..ad2118b
--- /dev/null
@@ -0,0 +1,248 @@
+
+package au.notzed.zcl;
+
+import org.junit.*;
+import static org.junit.Assert.*;
+import static au.notzed.zcl.CL.*;
+import static au.notzed.zcl.CLLib.*;
+import jdk.incubator.foreign.*;
+import api.*;
+
+/*
+  CLEvent and CLEventList tests
+ */
+public class CLEventTest {
+       boolean haveCL() {
+               CLPlatform[] list = CLPlatform.getPlatforms();
+               return list != null && list.length > 0;
+       }
+
+       CLPlatform plat;
+       CLDevice devs[];
+       CLContext cl;
+       CLCommandQueue q;
+
+       @Before
+       public void setup() {
+               org.junit.Assume.assumeTrue(haveCL());
+
+               plat = CLPlatform.getPlatforms()[0];
+               devs = new CLDevice[] { plat.getDevices(CL_DEVICE_TYPE_ALL)[0] };
+               cl = CLContext.createContext(null, devs);
+               q = cl.createCommandQueue(devs[0], 0);
+       }
+       @After
+       public void shutdown() {
+               q.release();
+               cl.release();
+       }
+
+       @Test
+       public void testUser() throws Exception {
+               org.junit.Assume.assumeTrue(plat.getAPIVersion() >= CLPlatform.VERSION_1_1);
+               System.out.println("createUserEvent");
+
+               CLEvent ev = cl.createUserEvent();
+
+               assertNotNull(ev);
+       }
+
+       @Test
+       public void testUserFields() throws Exception {
+               org.junit.Assume.assumeTrue(plat.getAPIVersion() >= CLPlatform.VERSION_1_1);
+               System.out.println("UserEvent fields");
+
+               CLEvent ev = cl.createUserEvent();
+
+               assertNull(ev.getCommandQueue());
+               assertEquals(cl, ev.getContext());
+               assertEquals(CL_COMMAND_USER, ev.getCommandType());
+       }
+
+       /*
+        * I think this should pass but it times out on Mesa 19.
+        */
+       @Ignore
+       @Test(timeout=1000)
+       public void testUserWait() throws Exception {
+               org.junit.Assume.assumeTrue(plat.getAPIVersion() >= CLPlatform.VERSION_1_1);
+               System.out.println("UserEvent wait");
+
+               CLEvent ev = cl.createUserEvent();
+
+               CLBuffer mem = cl.createBuffer(0, 64);
+               try (MemorySegment seg = MemorySegment.allocateNative(64, 8)) {
+                       CLEventList list = new CLEventList(1);
+                       CLEventList wait = new CLEventList(1);
+
+                       wait.add(ev);
+                       assertEquals(ev, wait.get(0));
+
+                       q.enqueueReadBuffer(mem, true, 0, 64, seg, wait, list);
+
+                       assertEquals(1, list.size());
+                       q.flush();
+
+                       assertEquals(CL_QUEUED, list.get(0).getCommandExecutionStatus());
+
+                       ev.setUserEventStatus(CL_COMPLETE);
+
+                       assertEquals(CL_COMPLETE, wait.get(0).getCommandExecutionStatus());
+
+                       q.finish();
+               }
+       }
+
+       @Test
+       public void testCallback() throws Exception {
+               org.junit.Assume.assumeTrue(plat.getAPIVersion() >= CLPlatform.VERSION_1_1);
+               int[] count = new int[1];
+
+               System.out.println("event callback, state");
+
+               CLBuffer mem = cl.createBuffer(0, 64);
+               try (MemorySegment seg = MemorySegment.allocateNative(64, 8)) {
+                       CLEventList list = new CLEventList(1);
+                       CLEvent ev;
+
+                       q.enqueueReadBuffer(mem, true, 0, 64, seg, null, list);
+
+                       ev = list.get(0);
+
+                       assertEquals(q, ev.getCommandQueue());
+                       assertEquals(CL_COMMAND_READ_BUFFER, ev.getCommandType());
+
+                       ev.setEventCallback(CL_SUBMITTED, (e, s) -> {
+                                       assertTrue(s <= CL_SUBMITTED);
+                                       count[0] += 1;
+                               });
+                       ev.setEventCallback(CL_RUNNING, (e, s) -> {
+                                       assertTrue(s <= CL_RUNNING);
+                                       count[0] += 1;
+                               });
+                       ev.setEventCallback(CL_COMPLETE, (e, s) -> {
+                                       assertTrue(s <= CL_COMPLETE);
+                                       count[0] += 1;
+                               });
+
+                       q.finish();
+
+                       assertEquals(3, count[0]);
+               }
+       }
+
+       void retainEvent(MemoryAddress x) {
+               try {
+                       clRetainEvent(x);
+               } catch (Throwable T) {
+               }
+       }
+
+       int countEvent(MemoryAddress x) {
+               int res = -1;
+               try (Allocator a = Memory.stack()) {
+                       MemoryAddress rc = a.alloca(8);
+                       clGetEventInfo(x, CL_EVENT_REFERENCE_COUNT, 4, rc, MemoryAddress.NULL);
+                       res = Native.getInt(rc);
+               } catch (Throwable T) {
+               }
+               return res;
+       }
+
+       void releaseEvent(MemoryAddress x) {
+               try {
+                       clReleaseEvent(x);
+               } catch (Throwable T) {
+               }
+       }
+
+       @Test
+       public void testClose() throws Exception {
+               System.out.println("close + oob");
+               CLEventList list = new CLEventList(2);
+               Throwable x;
+
+               try {
+                       x = null;
+                       list.get(0);
+               } catch (Throwable t) {
+                       x = t;
+               }
+               assertEquals(ArrayIndexOutOfBoundsException.class, x.getClass());
+
+               list.close();
+
+               try {
+                       x = null;
+                       list.add(cl.createUserEvent());
+               } catch (Throwable t) {
+                       x = t;
+               }
+               assertEquals(IllegalStateException.class, x.getClass());
+
+               CLBuffer mem = cl.createBuffer(0, 64);
+               try (MemorySegment seg = MemorySegment.allocateNative(64, 8)) {
+                       x = null;
+                       q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list);
+               } catch (Throwable t) {
+                       x = t;
+               }
+               assertEquals(IllegalStateException.class, x.getClass());
+       }
+
+       @Test
+       public void testMulti() throws Exception {
+               System.out.println("multi");
+               CLBuffer mem = cl.createBuffer(0, 64);
+               try (MemorySegment seg = MemorySegment.allocateNative(64, 8);
+                       CLEventList list = new CLEventList(2);
+                       CLEventList last = new CLEventList(1)) {
+
+                       q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list);
+                       assertEquals(1, list.size());
+                       q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list);
+                       assertEquals(2, list.size());
+
+                       q.enqueueReadBuffer(mem, false, 0, 64, seg, list, last);
+                       assertEquals(1, last.size());
+
+                       q.finish();
+
+                       assertEquals(list.get(0).getCommandExecutionStatus(), CL_COMPLETE);
+                       assertEquals(list.get(1).getCommandExecutionStatus(), CL_COMPLETE);
+                       assertEquals(last.get(0).getCommandExecutionStatus(), CL_COMPLETE);
+
+                       list.reset();
+                       assertEquals(0, list.size());
+
+                       last.reset();
+                       assertEquals(0, last.size());
+               }
+       }
+
+       @Test(timeout=1000)
+       public void testMultiWait() throws Exception {
+               System.out.println("multi wait");
+               CLBuffer mem = cl.createBuffer(0, 64);
+               try (MemorySegment seg = MemorySegment.allocateNative(64, 8);
+                       CLEventList list = new CLEventList(3)) {
+
+                       q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list);
+                       assertEquals(1, list.size());
+                       q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list);
+                       assertEquals(2, list.size());
+                       q.enqueueReadBuffer(mem, false, 0, 64, seg, list, list);
+                       assertEquals(3, list.size());
+
+                       q.flush();
+                       list.waitForEvents();
+
+                       assertEquals(list.get(0).getCommandExecutionStatus(), CL_COMPLETE);
+                       assertEquals(list.get(1).getCommandExecutionStatus(), CL_COMPLETE);
+                       assertEquals(list.get(2).getCommandExecutionStatus(), CL_COMPLETE);
+
+                       list.reset();
+                       assertEquals(0, list.size());
+               }
+       }
+}