--- /dev/null
+ /*
+The MIT License (MIT)
+
+Copyright (C) 2017 Eric Arnebäck
+Copyright (C) 2019 Michael Zucchi
+
+Permission is hereby granted, free of charge, to any person obtaining a copy
+of this software and associated documentation files (the "Software"), to deal
+in the Software without restriction, including without limitation the rights
+to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+copies of the Software, and to permit persons to whom the Software is
+furnished to do so, subject to the following conditions:
+
+The above copyright notice and this permission notice shall be included in
+all copies or substantial portions of the Software.
+
+THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
+THE SOFTWARE.
+
+ */
+
+/*
+ * This is a Java conversion of a C conversion of this:
+ * https://github.com/Erkaman/vulkan_minimal_compute
+ *
+ * It's been simplified a bit and converted to the 'zvk' api.
+ */
+
+package vulkan.test;
+
+import java.io.InputStream;
+import java.io.FileOutputStream;
+import java.io.IOException;
+import java.nio.channels.Channels;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+
+import java.awt.Graphics;
+import java.awt.Image;
+import java.awt.Toolkit;
+import java.awt.event.ActionEvent;
+import java.awt.event.KeyEvent;
+import java.awt.image.MemoryImageSource;
+import javax.swing.AbstractAction;
+import javax.swing.JComponent;
+import javax.swing.JFrame;
+import javax.swing.JPanel;
+import javax.swing.KeyStroke;
+
+import java.lang.ref.WeakReference;
+
+import java.lang.invoke.*;
+import jdk.incubator.foreign.*;
+import jdk.incubator.foreign.MemoryLayout.PathElement;
+import au.notzed.nativez.*;
+
+import vulkan.*;
+
+import static vulkan.VkBufferUsageFlagBits.*;
+import static vulkan.VkMemoryPropertyFlagBits.*;
+import static vulkan.VkSharingMode.*;
+import static vulkan.VkDescriptorType.*;
+import static vulkan.VkShaderStageFlagBits.*;
+import static vulkan.VkCommandBufferLevel.*;
+import static vulkan.VkCommandBufferUsageFlagBits.*;
+import static vulkan.VkPipelineBindPoint.*;
+
+import static vulkan.VkDebugUtilsMessageSeverityFlagBitsEXT.*;
+import static vulkan.VkDebugUtilsMessageTypeFlagBitsEXT.*;
+
+public class TestVulkan {
+ static final boolean debug = true;
+ ResourceScope scope = ResourceScope.newSharedScope();
+
+ int WIDTH = 1920*1;
+ int HEIGHT = 1080*1;
+
+ VkInstance instance;
+ VkPhysicalDevice physicalDevice;
+
+ VkDevice device;
+ VkQueue computeQueue;
+
+ long dstBufferSize = WIDTH * HEIGHT * 4;
+ //VkBuffer dstBuffer;
+ //VkDeviceMemory dstMemory;
+ BufferMemory dst;
+
+ VkDescriptorSetLayout descriptorSetLayout;
+ VkDescriptorPool descriptorPool;
+ HandleArray<VkDescriptorSet> descriptorSets = VkDescriptorSet.createArray(1, (SegmentAllocator)scope);
+
+ int computeQueueIndex;
+ VkPhysicalDeviceMemoryProperties deviceMemoryProperties;
+
+ String mandelbrot_entry = "main";
+ IntArray mandelbrot_cs;
+
+ VkShaderModule mandelbrotShader;
+ VkPipelineLayout pipelineLayout;
+ HandleArray<VkPipeline> computePipeline = VkPipeline.createArray(1, (SegmentAllocator)scope);
+
+ VkCommandPool commandPool;
+ HandleArray<VkCommandBuffer> commandBuffers;
+
+ record BufferMemory ( VkBuffer buffer, VkDeviceMemory memory ) {};
+
+ VkDebugUtilsMessengerEXT logger;
+
+ void init_debug() throws Exception {
+ if (!debug)
+ return;
+ try (Frame frame = Frame.frame()) {
+ NativeSymbol cb = PFN_vkDebugUtilsMessengerCallbackEXT.of((severity, flags, data) -> {
+ System.out.printf("Debug: %d: %s\n", severity, data.getMessage());
+ return 0;
+ }, scope);
+ VkDebugUtilsMessengerCreateInfoEXT info = VkDebugUtilsMessengerCreateInfoEXT.create(frame,
+ 0,
+ VK_DEBUG_UTILS_MESSAGE_SEVERITY_INFO_BIT_EXT
+ | VK_DEBUG_UTILS_MESSAGE_SEVERITY_WARNING_BIT_EXT
+ | VK_DEBUG_UTILS_MESSAGE_SEVERITY_ERROR_BIT_EXT,
+ VK_DEBUG_UTILS_MESSAGE_TYPE_GENERAL_BIT_EXT | VK_DEBUG_UTILS_MESSAGE_TYPE_VALIDATION_BIT_EXT,
+ cb.address(),
+ null);
+
+ logger = instance.vkCreateDebugUtilsMessengerEXT(info, null);
+ }
+
+ //typedef VkBool32 (*PFN_vkDebugUtilsMessengerCallbackEXT)(VkDebugUtilsMessageSeverityFlagBitsEXT, VkDebugUtilsMessageTypeFlagsEXT, const VkDebugUtilsMessengerCallbackDataEXT *, void *);
+
+ }
+
+ void init_instance() throws Exception {
+ try (Frame frame = Frame.frame()) {
+ VkInstanceCreateInfo info = VkInstanceCreateInfo.create(frame,
+ 0,
+ VkApplicationInfo.create(frame, "test", 1, "test-engine", 2, VK_MAKE_API_VERSION(0, 1, 0, 0)),
+ new String[] { "VK_LAYER_KHRONOS_validation" },
+ debug ? new String[] { "VK_EXT_debug_utils" } : null
+ );
+
+ instance = VkInstance.vkCreateInstance(info, null);
+ }
+ }
+
+ void init_device() throws Exception {
+ try (Frame frame = Frame.frame()) {
+ IntArray count$h = IntArray.create(frame, 1);
+ HandleArray<VkPhysicalDevice> devs;
+ int count;
+ int res;
+
+ devs = instance.vkEnumeratePhysicalDevices();
+
+ int best = 0;
+ int devid = -1;
+ int queueid = -1;
+
+ for (int i=0;i<devs.length();i++) {
+ VkPhysicalDevice dev = devs.getAtIndex(i);
+ VkQueueFamilyProperties famprops;
+
+ // TODO: change to return the allocated array directly
+ dev.vkGetPhysicalDeviceQueueFamilyProperties(count$h, null);
+ famprops = VkQueueFamilyProperties.createArray(frame, count$h.getAtIndex(0));
+ dev.vkGetPhysicalDeviceQueueFamilyProperties(count$h, famprops);
+
+ int family_count = count$h.getAtIndex(0);
+
+ for (int j=0;j<family_count;j++) {
+ int score = 0;
+
+ if ((famprops.getQueueFlags(j) & VkQueueFlagBits.VK_QUEUE_COMPUTE_BIT) != 0)
+ score += 1;
+ if ((famprops.getQueueFlags(j) & VkQueueFlagBits.VK_QUEUE_GRAPHICS_BIT) == 0)
+ score += 1;
+
+ if (score > best) {
+ score = best;
+ devid = i;
+ queueid = j;
+ }
+ }
+ }
+
+ if (devid == -1)
+ throw new Exception("Cannot find a suitable device");
+
+ computeQueueIndex = queueid;
+ physicalDevice = devs.getAtIndex(devid);
+
+ FloatArray qpri = FloatArray.create(frame, 0.0f);
+ VkDeviceQueueCreateInfo qinfo = VkDeviceQueueCreateInfo.create(
+ frame,
+ 0,
+ queueid,
+ 1,
+ qpri);
+ VkDeviceCreateInfo devinfo = VkDeviceCreateInfo.create(
+ frame,
+ 0,
+ 1,
+ qinfo,
+ null,
+ null,
+ null);
+
+ device = physicalDevice.vkCreateDevice(devinfo, null);
+
+ System.out.printf("device = %s\n", device.address());
+
+ // NOTE: app scope
+ deviceMemoryProperties = VkPhysicalDeviceMemoryProperties.create(scope);
+ physicalDevice.vkGetPhysicalDeviceMemoryProperties(deviceMemoryProperties);
+
+ computeQueue = device.vkGetDeviceQueue(queueid, 0);
+ }
+ }
+
+ /**
+ * Buffers are created in three steps:
+ * 1) create buffer, specifying usage and size
+ * 2) allocate memory based on memory requirements
+ * 3) bind memory
+ *
+ */
+ BufferMemory init_buffer(long dataSize, int usage, int properties) throws Exception {
+ try (Frame frame = Frame.frame()) {
+ VkMemoryRequirements req = VkMemoryRequirements.create(frame);
+ VkBufferCreateInfo buf_info = VkBufferCreateInfo.create(frame,
+ 0,
+ dataSize,
+ usage,
+ VK_SHARING_MODE_EXCLUSIVE,
+ 0,
+ null);
+
+ VkBuffer buffer = device.vkCreateBuffer(buf_info, null);
+
+ device.vkGetBufferMemoryRequirements(buffer, req);
+
+ VkMemoryAllocateInfo alloc = VkMemoryAllocateInfo.create(frame,
+ req.getSize(),
+ find_memory_type(deviceMemoryProperties, req.getMemoryTypeBits(), properties));
+
+ VkDeviceMemory memory = device.vkAllocateMemory(alloc, null);
+
+ device.vkBindBufferMemory(buffer, memory, 0);
+
+ return new BufferMemory(buffer, memory);
+ }
+ }
+
+ /**
+ * Descriptors are used to bind and describe memory blocks
+ * to shaders.
+ *
+ * *Pool is used to allocate descriptors, it is per-device.
+ * *Layout is used to group descriptors for a given pipeline,
+ * The descriptors describe individually-addressable blocks.
+ */
+ void init_descriptor() throws Exception {
+ try (Frame frame = Frame.frame()) {
+ /* Create descriptorset layout */
+ VkDescriptorSetLayoutBinding layout_binding = VkDescriptorSetLayoutBinding.create(frame,
+ 0,
+ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+ 1,
+ VK_SHADER_STAGE_COMPUTE_BIT,
+ null);
+
+ VkDescriptorSetLayoutCreateInfo descriptor_layout = VkDescriptorSetLayoutCreateInfo.create(frame,
+ 0,
+ 1,
+ layout_binding);
+
+ descriptorSetLayout = device.vkCreateDescriptorSetLayout(descriptor_layout, null);
+
+ /* Create descriptor pool */
+ VkDescriptorPoolSize type_count = VkDescriptorPoolSize.create(frame,
+ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+ 1);
+
+ VkDescriptorPoolCreateInfo descriptor_pool = VkDescriptorPoolCreateInfo.create(frame,
+ 0,
+ 1,
+ 1,
+ type_count);
+
+ descriptorPool = device.vkCreateDescriptorPool(descriptor_pool, null);
+
+ /* Allocate from pool */
+ HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
+
+ layout_table.setAtIndex(0, descriptorSetLayout);
+
+ VkDescriptorSetAllocateInfo alloc_info = VkDescriptorSetAllocateInfo.create(frame,
+ descriptorPool,
+ 1,
+ layout_table);
+
+ device.vkAllocateDescriptorSets(alloc_info, descriptorSets);
+
+ /* Bind a buffer to the descriptor */
+ VkDescriptorBufferInfo bufferInfo = VkDescriptorBufferInfo.create(frame,
+ dst.buffer,
+ 0,
+ dstBufferSize);
+
+ VkWriteDescriptorSet writeSet = VkWriteDescriptorSet.create(frame,
+ descriptorSets.getAtIndex(0),
+ 0,
+ 0,
+ 1,
+ VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+ null,
+ bufferInfo,
+ null);
+
+ device.vkUpdateDescriptorSets(1, writeSet, 0, null);
+ }
+ }
+
+ /**
+ * Create the compute pipeline. This is the shader and data layouts for it.
+ */
+ void init_pipeline() throws Exception {
+ try (Frame frame = Frame.frame()) {
+ /* Set shader code */
+ VkShaderModuleCreateInfo vsInfo = VkShaderModuleCreateInfo.create(frame,
+ 0,
+ mandelbrot_cs.length() * 4,
+ mandelbrot_cs);
+
+ mandelbrotShader = device.vkCreateShaderModule(vsInfo, null);
+
+ /* Link shader to layout */
+ HandleArray<VkDescriptorSetLayout> layout_table = VkDescriptorSetLayout.createArray(1, frame);
+
+ layout_table.setAtIndex(0, descriptorSetLayout);
+
+ VkPipelineLayoutCreateInfo pipelineinfo = VkPipelineLayoutCreateInfo.create(frame,
+ 0,
+ 1,
+ layout_table,
+ 0,
+ null);
+
+ pipelineLayout = device.vkCreatePipelineLayout(pipelineinfo, null);
+
+ /* Create pipeline */
+ VkComputePipelineCreateInfo pipeline = VkComputePipelineCreateInfo.create(frame,
+ 0,
+ pipelineLayout,
+ null,
+ 0);
+
+ VkPipelineShaderStageCreateInfo stage = pipeline.getStage();
+
+ stage.setStage(VK_SHADER_STAGE_COMPUTE_BIT);
+ stage.setModule(mandelbrotShader);
+ stage.setName(frame, mandelbrot_entry);
+
+ device.vkCreateComputePipelines(null, 1, pipeline, null, computePipeline);
+ }
+ }
+
+ /**
+ * Create a command buffer, this is somewhat like a display list.
+ */
+ void init_command_buffer() throws Exception {
+ try (Frame frame = Frame.frame()) {
+ VkCommandPoolCreateInfo poolinfo = VkCommandPoolCreateInfo.create(frame,
+ 0,
+ computeQueueIndex);
+
+ commandPool = device.vkCreateCommandPool(poolinfo, null);
+
+ VkCommandBufferAllocateInfo cmdinfo = VkCommandBufferAllocateInfo.create(frame,
+ commandPool,
+ VK_COMMAND_BUFFER_LEVEL_PRIMARY,
+ 1);
+
+ // should it take a scope?
+ commandBuffers = device.vkAllocateCommandBuffers(cmdinfo);
+
+ /* Fill command buffer with commands for later operation */
+ VkCommandBufferBeginInfo beginInfo = VkCommandBufferBeginInfo.create(frame,
+ VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
+ null);
+
+ commandBuffers.get(0).vkBeginCommandBuffer(beginInfo);
+
+ /* Bind the compute operation and data */
+ commandBuffers.get(0).vkCmdBindPipeline(VK_PIPELINE_BIND_POINT_COMPUTE, computePipeline.get(0));
+ commandBuffers.get(0).vkCmdBindDescriptorSets(VK_PIPELINE_BIND_POINT_COMPUTE, pipelineLayout, 0, 1, descriptorSets, 0, null);
+
+ /* Run it */
+ commandBuffers.get(0).vkCmdDispatch(WIDTH, HEIGHT, 1);
+
+ commandBuffers.get(0).vkEndCommandBuffer();
+ }
+ }
+
+ /**
+ * Execute the pre-created command buffer.
+ *
+ * A fence is used to wait for completion.
+ */
+ void execute() throws Exception {
+ try (Frame frame = Frame.frame()) {
+ VkSubmitInfo submitInfo = VkSubmitInfo.create(frame);
+
+ submitInfo.setCommandBufferCount(0, 1);
+ submitInfo.setCommandBuffers(0, commandBuffers);
+
+ /* Create fence to mark the task completion */
+ VkFence fence;
+ HandleArray<VkFence> fences = VkFence.createArray(1, frame);
+ VkFenceCreateInfo fenceInfo = VkFenceCreateInfo.create(frame);
+
+ // maybe this should take a HandleArray<Fence> rather than being a constructor
+ fence = device.vkCreateFence(fenceInfo, null);
+ fences.set(0, fence);
+
+ /* Await completion */
+ computeQueue.vkQueueSubmit(1, submitInfo, fence);
+
+ int VK_TRUE = 1;
+ int res;
+ do {
+ res = device.vkWaitForFences(1, fences, VK_TRUE, 1000000);
+ } while (res == VkResult.VK_TIMEOUT);
+
+ device.vkDestroyFence(fence, null);
+ }
+ }
+
+ void shutdown() {
+ device.vkDestroyCommandPool(commandPool, null);
+ device.vkDestroyPipeline(computePipeline.getAtIndex(0), null);
+ device.vkDestroyPipelineLayout(pipelineLayout, null);
+ device.vkDestroyShaderModule(mandelbrotShader, null);
+
+ device.vkDestroyDescriptorPool(descriptorPool, null);
+ device.vkDestroyDescriptorSetLayout(descriptorSetLayout, null);
+
+ device.vkFreeMemory(dst.memory(), null);
+ device.vkDestroyBuffer(dst.buffer(), null);
+
+ device.vkDestroyDevice(null);
+ if (logger != null)
+ instance.vkDestroyDebugUtilsMessengerEXT(logger, null);
+ instance.vkDestroyInstance(null);
+ }
+
+ /**
+ * Accesses the gpu buffer, converts it to RGB byte, and saves it as a pam file.
+ */
+ void save_result() throws Exception {
+ try (ResourceScope scope = ResourceScope.newConfinedScope()) {
+ MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
+ byte[] pixels = new byte[WIDTH * HEIGHT * 3];
+
+ System.out.printf("map %d bytes\n", dstBufferSize);
+
+ for (int i = 0; i < WIDTH * HEIGHT; i++) {
+ pixels[i * 3 + 0] = mem.get(Memory.BYTE, i * 4 + 0);
+ pixels[i * 3 + 1] = mem.get(Memory.BYTE, i * 4 + 1);
+ pixels[i * 3 + 2] = mem.get(Memory.BYTE, i * 4 + 2);
+ }
+
+ device.vkUnmapMemory(dst.memory());
+
+ pam_save("mandelbrot.pam", WIDTH, HEIGHT, 3, pixels);
+ }
+ }
+
+ void show_result() throws Exception {
+ try (ResourceScope scope = ResourceScope.newConfinedScope()) {
+ MemorySegment mem = device.vkMapMemory(dst.memory(), 0, dstBufferSize, 0, scope);
+ int[] pixels = new int[WIDTH * HEIGHT];
+
+ System.out.printf("map %d bytes\n", dstBufferSize);
+
+ MemorySegment.ofArray(pixels).copyFrom(mem);
+
+ device.vkUnmapMemory(dst.memory());
+
+ swing_show(WIDTH, HEIGHT, pixels);
+ }
+ }
+
+ /**
+ * Trivial pnm format image output.
+ */
+ void pam_save(String name, int width, int height, int depth, byte[] pixels) throws IOException {
+ try (FileOutputStream fos = new FileOutputStream(name)) {
+ fos.write(String.format("P6\n%d\n%d\n255\n", width, height).getBytes());
+ fos.write(pixels);
+ System.out.printf("wrote: %s\n", name);
+ }
+ }
+
+ static class DataImage extends JPanel {
+
+ final int w, h, stride;
+ final MemoryImageSource source;
+ final Image image;
+ final int[] pixels;
+
+ public DataImage(int w, int h, int[] pixels) {
+ this.w = w;
+ this.h = h;
+ this.stride = w;
+ this.pixels = pixels;
+ this.source = new MemoryImageSource(w, h, pixels, 0, w);
+ this.source.setAnimated(true);
+ this.source.setFullBufferUpdates(true);
+ this.image = Toolkit.getDefaultToolkit().createImage(source);
+ }
+
+ @Override
+ protected void paintComponent(Graphics g) {
+ super.paintComponent(g);
+ g.drawImage(image, 0, 0, this);
+ }
+ }
+
+ void swing_show(int w, int h, int[] pixels) {
+ JFrame window;
+ DataImage image = new DataImage(w, h, pixels);
+
+ window = new JFrame("mandelbrot");
+ window.setDefaultCloseOperation(JFrame.EXIT_ON_CLOSE);
+ window.setContentPane(image);
+ window.setSize(w, h);
+ window.setVisible(true);
+ }
+
+ IntArray loadSPIRV0(String name) throws IOException {
+ // hmm any way to just load this directly?
+ try (InputStream is = TestVulkan.class.getResourceAsStream(name)) {
+ ByteBuffer bb = ByteBuffer.allocateDirect(8192).order(ByteOrder.nativeOrder());
+ int length = Channels.newChannel(is).read(bb);
+
+ bb.position(0);
+ bb.limit(length);
+
+ return IntArray.create(MemorySegment.ofByteBuffer(bb));
+ }
+ }
+
+ IntArray loadSPIRV(String name) throws IOException {
+ try (InputStream is = TestVulkan.class.getResourceAsStream(name)) {
+ MemorySegment seg = ((SegmentAllocator)scope).allocateArray(Memory.INT, 2048);
+ int length = Channels.newChannel(is).read(seg.asByteBuffer());
+
+ return IntArray.create(seg.asSlice(0, length));
+ }
+ }
+
+ /**
+ * This finds the memory type index for the memory on a specific device.
+ */
+ static int find_memory_type(VkPhysicalDeviceMemoryProperties memory, int typeMask, int query) {
+ VkMemoryType mtypes = memory.getMemoryTypes();
+
+ for (int i = 0; i < memory.getMemoryTypeCount(); i++) {
+ if (((1 << i) & typeMask) != 0 && ((mtypes.getPropertyFlags(i) & query) == query))
+ return i;
+ }
+ return -1;
+ }
+
+ public static int VK_MAKE_API_VERSION(int variant, int major, int minor, int patch) {
+ return (variant << 29) | (major << 22) | (minor << 12) | patch;
+ }
+
+ void demo() throws Exception {
+ mandelbrot_cs = loadSPIRV("mandelbrot.bin");
+
+ init_instance();
+ init_debug();
+ init_device();
+
+ dst = init_buffer(dstBufferSize,
+ VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
+ VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT);
+
+ init_descriptor();
+
+ init_pipeline();
+ init_command_buffer();
+
+ System.out.printf("Calculating %dx%d\n", WIDTH, HEIGHT);
+ execute();
+ //System.out.println("Saving ...");
+ //save_result();
+ System.out.println("Showing ...");
+ show_result();
+ System.out.println("Done.");
+
+ shutdown();
+ }
+
+
+ public static void main(String[] args) throws Throwable {
+ System.loadLibrary("vulkan");
+
+ new TestVulkan().demo();
+ }
+}