From 0549e1ddc30b927358c33f0b676544c218176509 Mon Sep 17 00:00:00 2001 From: Not Zed Date: Sun, 26 Jan 2020 18:09:09 +1030 Subject: [PATCH] Improve CLEventList design and fix CLEvent.setEventCallback. Now uses a MemorySegment directly to read/write events, bypassing Java unless necessary. Added tests. --- .../classes/au/notzed/zcl/CLCommandQueue.java | 19 +- .../classes/au/notzed/zcl/CLEvent.java | 43 +-- .../classes/au/notzed/zcl/CLEventList.java | 107 ++++++-- .../tests/au/notzed/zcl/CLEventTest.java | 248 ++++++++++++++++++ 4 files changed, 372 insertions(+), 45 deletions(-) create mode 100644 src/notzed.zcl/tests/au/notzed/zcl/CLEventTest.java diff --git a/src/notzed.zcl/classes/au/notzed/zcl/CLCommandQueue.java b/src/notzed.zcl/classes/au/notzed/zcl/CLCommandQueue.java index 6de70f5..f878bb8 100644 --- a/src/notzed.zcl/classes/au/notzed/zcl/CLCommandQueue.java +++ b/src/notzed.zcl/classes/au/notzed/zcl/CLCommandQueue.java @@ -1506,6 +1506,13 @@ public class CLCommandQueue extends CLExtendable { } } + /** + * Simplify wait/event handling. + * + * To use, create an EventInfo from the passed in arguments. + * In the enqueue command pass in .wait and .event from this + * structure. If the command succeeds, then call post(). + */ static private class EventInfo { final int nwait; final MemoryAddress wait; @@ -1513,19 +1520,13 @@ public class CLCommandQueue extends CLExtendable { EventInfo(Allocator frame, CLEventList waiters, CLEventList events) { nwait = waiters != null ? waiters.size() : 0; - if (nwait > 0) { - wait = frame.alloca(8 * nwait); - for (int i=0;i 0 ? waiters.slots() : MemoryAddress.NULL; + event = events != null ? 
events.currentSlot() : MemoryAddress.NULL; } void post(CLEventList events) { if (events != null) - events.add(resolve(getAddr(event), CLEvent::new)); + events.incrementSlot(); } } diff --git a/src/notzed.zcl/classes/au/notzed/zcl/CLEvent.java b/src/notzed.zcl/classes/au/notzed/zcl/CLEvent.java index c378cfe..d921e6e 100644 --- a/src/notzed.zcl/classes/au/notzed/zcl/CLEvent.java +++ b/src/notzed.zcl/classes/au/notzed/zcl/CLEvent.java @@ -19,9 +19,9 @@ package au.notzed.zcl; import jdk.incubator.foreign.*; import static au.notzed.zcl.CL.*; import static au.notzed.zcl.CLLib.*; -import api.Native; -import api.Callback; +import api.*; import java.lang.invoke.MethodHandle; +import java.util.ArrayList; /** * Interface for cl_event. @@ -34,9 +34,10 @@ public class CLEvent extends CLObject { final int apiVersion; /** - * This is used to retain a reference for any callback set + * This is used to retain a reference for any callback set. + * There may be multiple. */ - Callback callback; + ArrayList> callbacks; public CLEvent(MemoryAddress p) { super(p); @@ -58,8 +59,10 @@ public class CLEvent extends CLObject { @Override public void release() { - Native.release(callback); - callback = null; + if (callbacks != null) { + callbacks.forEach((c) -> c.release()); + callbacks = null; + } super.release(); } @@ -94,6 +97,8 @@ public class CLEvent extends CLObject { /** * Call clSetEventCallback(type, notify). * + * Adds a new callback for the given state. 
+ * * @param type * @param notify * @throws CLRuntimeException @@ -102,23 +107,21 @@ public class CLEvent extends CLObject { public void setEventCallback(int type, CLEventNotify notify) throws CLRuntimeException { CLPlatform.requireAPIVersion(apiVersion, CLPlatform.VERSION_1_1); - Native.release(callback); + Callback callback = CLEventNotify.call(notify); - if (notify != null) { - callback = CLEventNotify.call(notify); + try { + int res = clSetEventCallback(addr(), type, callback.addr(), MemoryAddress.NULL); - try { - int res = clSetEventCallback(addr(), type, callback.addr(), MemoryAddress.NULL); + if (res != 0) + throw new CLRuntimeException(res); - if (res != 0) - throw new CLRuntimeException(res); - } catch (RuntimeException | Error t) { - throw t; - } catch (Throwable t) { - throw new RuntimeException(t); - } - } else { - callback = null; + if (callbacks == null) + callbacks = new ArrayList<>(); + callbacks.add(callback); + } catch (RuntimeException | Error t) { + throw t; + } catch (Throwable t) { + throw new RuntimeException(t); } } diff --git a/src/notzed.zcl/classes/au/notzed/zcl/CLEventList.java b/src/notzed.zcl/classes/au/notzed/zcl/CLEventList.java index df0d398..187bc9d 100644 --- a/src/notzed.zcl/classes/au/notzed/zcl/CLEventList.java +++ b/src/notzed.zcl/classes/au/notzed/zcl/CLEventList.java @@ -17,6 +17,7 @@ package au.notzed.zcl; import jdk.incubator.foreign.MemoryAddress; +import jdk.incubator.foreign.MemorySegment; import api.Memory; import api.Allocator; import api.Native; @@ -31,29 +32,40 @@ import static au.notzed.zcl.CLLib.*; * See {@link au.notzed.zcl.CLCommandQueue} for more information on usage. *

<h2>Internal Details</h2>

 * <p>

- * Internally the CLEventList is maintained as an array of long values which hold + * Internally the CLEventList is maintained as a MemorySegment which holds * the cl_event pointers. CLEvent objects are only created to access this pointer * from Java or to transfer the event pointers to new event lists. * <p>

- * The JNI code directly reads from the events list and index in order to build - * a list of wait events, and calls the internal add(long) method to append to - * the events array. + * The list is read or written to directly within the CLCommandQueue methods + * making normal operation very cheap. + * <p>

+ * Currently event lists are not tracked for garbage collection and must + * be released explicitly by calling release() or close(). */ -public class CLEventList { +public final class CLEventList implements AutoCloseable { + + /** + * Raw event values. + */ + final MemoryAddress cevents; /** - * Event references. + * Event references? */ final CLEvent[] jevents; int index = 0; + boolean reserve = false; /** * Creates a new event list. * + * The event list MUST be released once it is no longer needed. + * * @param capacity Sets the event list capacity. */ public CLEventList(int capacity) { this.jevents = new CLEvent[capacity]; + this.cevents = MemorySegment.allocateNative(8 * capacity, 8).baseAddress(); } /** @@ -63,24 +75,89 @@ public class CLEventList { */ public void reset() { for (int i = 0; i < index; i++) { - jevents[i].release(); - jevents[i] = null; + CLEvent ev = jevents[i]; + + if (ev == null) { + try { + clReleaseEvent(Native.getAddr(cevents, i)); + } catch (Throwable t) { + } + } else { + ev.release(); + jevents[i] = null; + } + Native.setAddr(cevents, i, MemoryAddress.NULL); } index = 0; } + /** + * Release all resources. + * + * Currently this must be called on the same thread + * that created the event list. + */ public void release() { reset(); + cevents.segment().close(); + } + + @Override + public void close() { + release(); } - + + /** + * Get the base address for all slots. + * + * This is used internally by CLCommandQueue.EventInfo to pass the event list directly as a wait list. + */ + MemoryAddress slots() { + return cevents; + } + + /** + * Get current output slot. + * + * This is used internally by CLCommandQueue.EventInfo to write directly to the event list. + * + * @throws IllegalStateException if the CLEventList has been released. + * @throws ArrayIndexOutOfBoundsException if the CLEventList is full. 
+ */ + MemoryAddress currentSlot() { + if (index < jevents.length) { + MemoryAddress addr = cevents.addOffset(index * 8); + + // This should already be null, but this performs a range check + Native.setAddr(addr, MemoryAddress.NULL); + return addr; + } else + throw new ArrayIndexOutOfBoundsException(); + } + + /** + * Indicate slot value is valid. + * + * This is used internally by CLCommandQueue.EventInfo to write directly to the event list. + */ + void incrementSlot() { + index++; + } + /** * Creates an interface to the given event. * * @param index * @return An event interface. */ - public CLEvent get(int index) { - return jevents[index]; + public CLEvent get(int i) { + if (i < index) { + CLEvent ev = jevents[i]; + if (ev == null) + jevents[i] = ev = Native.resolve(Native.getAddr(cevents, i), CLEvent::new); + return ev; + } else + throw new ArrayIndexOutOfBoundsException(); } /** @@ -89,6 +166,7 @@ public class CLEventList { * @param event */ public void add(CLEvent event) { + Native.setAddr(cevents, index, event.addr()); jevents[index++] = event; } @@ -108,11 +186,8 @@ public class CLEventList { */ public void waitForEvents() throws CLException { if (size() > 0) { - try (Allocator frame = Memory.stack()) { - MemoryAddress events = Native.toAddrV(frame, jevents, index); - int res; - - res = clWaitForEvents(size(), events); + try { + int res = clWaitForEvents(size(), cevents); if (res != 0) throw new CLException(res); } catch (CLException | RuntimeException | Error t) { diff --git a/src/notzed.zcl/tests/au/notzed/zcl/CLEventTest.java b/src/notzed.zcl/tests/au/notzed/zcl/CLEventTest.java new file mode 100644 index 0000000..ad2118b --- /dev/null +++ b/src/notzed.zcl/tests/au/notzed/zcl/CLEventTest.java @@ -0,0 +1,248 @@ + +package au.notzed.zcl; + +import org.junit.*; +import static org.junit.Assert.*; +import static au.notzed.zcl.CL.*; +import static au.notzed.zcl.CLLib.*; +import jdk.incubator.foreign.*; +import api.*; + +/* + CLEvent and CLEventList tests + */ 
+public class CLEventTest { + boolean haveCL() { + CLPlatform[] list = CLPlatform.getPlatforms(); + return list != null && list.length > 0; + } + + CLPlatform plat; + CLDevice devs[]; + CLContext cl; + CLCommandQueue q; + + @Before + public void setup() { + org.junit.Assume.assumeTrue(haveCL()); + + plat = CLPlatform.getPlatforms()[0]; + devs = new CLDevice[] { plat.getDevices(CL_DEVICE_TYPE_ALL)[0] }; + cl = CLContext.createContext(null, devs); + q = cl.createCommandQueue(devs[0], 0); + } + @After + public void shutdown() { + q.release(); + cl.release(); + } + + @Test + public void testUser() throws Exception { + org.junit.Assume.assumeTrue(plat.getAPIVersion() >= CLPlatform.VERSION_1_1); + System.out.println("createUserEvent"); + + CLEvent ev = cl.createUserEvent(); + + assertNotNull(ev); + } + + @Test + public void testUserFields() throws Exception { + org.junit.Assume.assumeTrue(plat.getAPIVersion() >= CLPlatform.VERSION_1_1); + System.out.println("UserEvent fields"); + + CLEvent ev = cl.createUserEvent(); + + assertNull(ev.getCommandQueue()); + assertEquals(cl, ev.getContext()); + assertEquals(CL_COMMAND_USER, ev.getCommandType()); + } + + /* + * I think this should pass but it times out on Mesa 19. NOTE(review): the read below appears to be a blocking read that waits on the user event, which is only completed after the call returns, so a hang may be expected on any implementation; confirm. 
+ */ + @Ignore + @Test(timeout=1000) + public void testUserWait() throws Exception { + org.junit.Assume.assumeTrue(plat.getAPIVersion() >= CLPlatform.VERSION_1_1); + System.out.println("UserEvent wait"); + + CLEvent ev = cl.createUserEvent(); + + CLBuffer mem = cl.createBuffer(0, 64); + try (MemorySegment seg = MemorySegment.allocateNative(64, 8)) { + CLEventList list = new CLEventList(1); + CLEventList wait = new CLEventList(1); + + wait.add(ev); + assertEquals(ev, wait.get(0)); + + q.enqueueReadBuffer(mem, true, 0, 64, seg, wait, list); + + assertEquals(1, list.size()); + q.flush(); + + assertEquals(CL_QUEUED, list.get(0).getCommandExecutionStatus()); + + ev.setUserEventStatus(CL_COMPLETE); + + assertEquals(CL_COMPLETE, wait.get(0).getCommandExecutionStatus()); + + q.finish(); + } + } + + @Test + public void testCallback() throws Exception { + org.junit.Assume.assumeTrue(plat.getAPIVersion() >= CLPlatform.VERSION_1_1); + int[] count = new int[1]; + + System.out.println("event callback, state"); + + CLBuffer mem = cl.createBuffer(0, 64); + try (MemorySegment seg = MemorySegment.allocateNative(64, 8)) { + CLEventList list = new CLEventList(1); + CLEvent ev; + + q.enqueueReadBuffer(mem, true, 0, 64, seg, null, list); + + ev = list.get(0); + + assertEquals(q, ev.getCommandQueue()); + assertEquals(CL_COMMAND_READ_BUFFER, ev.getCommandType()); + + ev.setEventCallback(CL_SUBMITTED, (e, s) -> { + assertTrue(s <= CL_SUBMITTED); + count[0] += 1; + }); + ev.setEventCallback(CL_RUNNING, (e, s) -> { + assertTrue(s <= CL_RUNNING); + count[0] += 1; + }); + ev.setEventCallback(CL_COMPLETE, (e, s) -> { + assertTrue(s <= CL_COMPLETE); + count[0] += 1; + }); + + q.finish(); + + assertEquals(3, count[0]); + } + } + + void retainEvent(MemoryAddress x) { + try { + clRetainEvent(x); + } catch (Throwable T) { + } + } + + int countEvent(MemoryAddress x) { + int res = -1; + try (Allocator a = Memory.stack()) { + MemoryAddress rc = a.alloca(8); + clGetEventInfo(x, CL_EVENT_REFERENCE_COUNT, 
4, rc, MemoryAddress.NULL); + res = Native.getInt(rc); + } catch (Throwable T) { + } + return res; + } + + void releaseEvent(MemoryAddress x) { + try { + clReleaseEvent(x); + } catch (Throwable T) { + } + } + + @Test + public void testClose() throws Exception { + System.out.println("close + oob"); + CLEventList list = new CLEventList(2); + Throwable x; + + try { + x = null; + list.get(0); + } catch (Throwable t) { + x = t; + } + assertEquals(ArrayIndexOutOfBoundsException.class, x.getClass()); + + list.close(); + + try { + x = null; + list.add(cl.createUserEvent()); + } catch (Throwable t) { + x = t; + } + assertEquals(IllegalStateException.class, x.getClass()); + + CLBuffer mem = cl.createBuffer(0, 64); + try (MemorySegment seg = MemorySegment.allocateNative(64, 8)) { + x = null; + q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list); + } catch (Throwable t) { + x = t; + } + assertEquals(IllegalStateException.class, x.getClass()); + } + + @Test + public void testMulti() throws Exception { + System.out.println("multi"); + CLBuffer mem = cl.createBuffer(0, 64); + try (MemorySegment seg = MemorySegment.allocateNative(64, 8); + CLEventList list = new CLEventList(2); + CLEventList last = new CLEventList(1)) { + + q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list); + assertEquals(1, list.size()); + q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list); + assertEquals(2, list.size()); + + q.enqueueReadBuffer(mem, false, 0, 64, seg, list, last); + assertEquals(1, last.size()); + + q.finish(); + + assertEquals(list.get(0).getCommandExecutionStatus(), CL_COMPLETE); + assertEquals(list.get(1).getCommandExecutionStatus(), CL_COMPLETE); + assertEquals(last.get(0).getCommandExecutionStatus(), CL_COMPLETE); + + list.reset(); + assertEquals(0, list.size()); + + last.reset(); + assertEquals(0, last.size()); + } + } + + @Test(timeout=1000) + public void testMultiWait() throws Exception { + System.out.println("multi wait"); + CLBuffer mem = cl.createBuffer(0, 64); + try 
(MemorySegment seg = MemorySegment.allocateNative(64, 8); + CLEventList list = new CLEventList(3)) { + + q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list); + assertEquals(1, list.size()); + q.enqueueReadBuffer(mem, false, 0, 64, seg, null, list); + assertEquals(2, list.size()); + q.enqueueReadBuffer(mem, false, 0, 64, seg, list, list); + assertEquals(3, list.size()); + + q.flush(); + list.waitForEvents(); + + assertEquals(list.get(0).getCommandExecutionStatus(), CL_COMPLETE); + assertEquals(list.get(1).getCommandExecutionStatus(), CL_COMPLETE); + assertEquals(list.get(2).getCommandExecutionStatus(), CL_COMPLETE); + + list.reset(); + assertEquals(0, list.size()); + } + } +} -- 2.39.5