+
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+
+#include "compute.h"
+
+/* ********************************************************************** */
+
+static int mapDescriptorType(int binding) {
+ switch (binding) {
+ case ZVK_BUFFER:
+ case ZVK_POD:
+ return VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
+ case ZVK_ROIMAGE:
+ return VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE;
+ case ZVK_WOIMAGE:
+ return VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
+ case ZVK_SAMPLER:
+ return VK_DESCRIPTOR_TYPE_SAMPLER;
+ }
+ return -1;
+}
+
+struct modstate *compute_createModule(struct zvk *zvk, const struct modinfo *mod) {
+ struct modstate *ms = calloc(1, sizeof(*ms));
+
+ ms->zvk = zvk;
+ ms->modinfo = mod;
+
+ VkShaderModuleCreateInfo createInfo = {
+ .sType = VK_STRUCTURE_TYPE_SHADER_MODULE_CREATE_INFO,
+ .codeSize = mod->codeSize,
+ .pCode = mod->code,
+ };
+
+ ZVK_FATAL(vkCreateShaderModule(zvk->device, &createInfo, NULL, &ms->shader));
+
+ return ms;
+}
+
+void compute_destroyModule(struct modstate *mod) {
+ vkDestroyShaderModule(mod->zvk->device, mod->shader, NULL);
+ free(mod);
+}
+
+static int kernelIndex(const struct modinfo *mi, const char * name) {
+ for (int i=0;i<mi->nkernels;i++) {
+ if (strcmp(mi->kernels[i]->name, name) == 0)
+ return i;
+ }
+ return -1;
+}
+
+struct kernstate *compute_createKernel(struct modstate *mod, const char *name) {
+ int ki = kernelIndex(mod->modinfo, name);
+
+ if (ki < 0)
+ return NULL;
+
+ struct zvk *zvk = mod->zvk;
+ struct kernstate *ks = calloc(1, sizeof(*ks));
+ const struct kerninfo *kern = mod->modinfo->kernels[ki];
+
+ ks->mod = mod;
+ ks->kern = kern;
+
+ /* allocate data bindings */
+ VkDescriptorSetLayoutBinding layout_bindings[kern->nbindings];
+ VkDescriptorBindingFlagBitsEXT layout_bindings_flags[kern->nbindings];
+
+ for (int i=0;i<kern->nbindings;i++) {
+ memset(&layout_bindings[i], 0, sizeof(layout_bindings[0]));
+ layout_bindings_flags[i] = VK_DESCRIPTOR_BINDING_UPDATE_AFTER_BIND_BIT_EXT;
+ layout_bindings[i].binding = i;
+ layout_bindings[i].descriptorType = mapDescriptorType(kern->bindings[i].type);
+ layout_bindings[i].descriptorCount = 1;
+ layout_bindings[i].stageFlags = VK_SHADER_STAGE_COMPUTE_BIT;
+ }
+ VkDescriptorSetLayoutBindingFlagsCreateInfoEXT flags = {
+ .sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_BINDING_FLAGS_CREATE_INFO_EXT,
+ .bindingCount = kern->nbindings,
+ .pBindingFlags = layout_bindings_flags,
+ };
+ VkDescriptorSetLayoutCreateInfo descriptor_layout = {
+ .sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO,
+ .pNext = &flags,
+ .flags = VK_DESCRIPTOR_SET_LAYOUT_CREATE_UPDATE_AFTER_BIND_POOL_BIT_EXT,
+ .bindingCount = kern->nbindings,
+ .pBindings = layout_bindings,
+ };
+
+ ZVK_FATAL(vkCreateDescriptorSetLayout(zvk->device, &descriptor_layout, NULL, &ks->descriptorSetLayout));
+
+ VkDescriptorSetAllocateInfo alloc_info[] = {
+ {
+ .sType = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_ALLOCATE_INFO,
+ .descriptorPool = zvk->descriptorPool,
+ .descriptorSetCount = 1,
+ .pSetLayouts = &ks->descriptorSetLayout,
+ },
+ };
+
+ ZVK_FATAL(vkAllocateDescriptorSets(zvk->device, alloc_info, ks->descriptorSets));
+
+ /* Check for the (1) 'pod' binding, allocate memory for it */
+ for (int i=0;i<kern->nbindings;i++) {
+ if (kern->bindings[i].type == ZVK_POD) {
+ zvkAllocBuffer(zvk, kern->bindings[i].size,
+ VK_BUFFER_USAGE_STORAGE_BUFFER_BIT,
+ VK_MEMORY_PROPERTY_HOST_VISIBLE_BIT | VK_MEMORY_PROPERTY_HOST_COHERENT_BIT,
+ &ks->podBuffer, &ks->podMemory);
+
+ VkDescriptorBufferInfo bufferInfo = {
+ .buffer = ks->podBuffer,
+ .offset = 0,
+ .range = kern->bindings[i].size,
+ };
+ VkWriteDescriptorSet writeSet = {
+ .sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
+ .dstSet = ks->descriptorSets[0],
+ .dstBinding = i,
+ .descriptorCount = 1,
+ .descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+ .pBufferInfo = &bufferInfo,
+ };
+ vkUpdateDescriptorSets(zvk->device, 1, &writeSet, 0, NULL);
+ break;
+ }
+ }
+
+ /* Create pipeline */
+ VkPipelineLayoutCreateInfo pipelineinfo = {
+ .sType = VK_STRUCTURE_TYPE_PIPELINE_LAYOUT_CREATE_INFO,
+ .setLayoutCount = 1,
+ .pSetLayouts = &ks->descriptorSetLayout,
+ };
+
+ ZVK_FATAL(vkCreatePipelineLayout(zvk->device, &pipelineinfo, NULL, &ks->pipelineLayout));
+
+ VkComputePipelineCreateInfo pipeline = {
+ .sType = VK_STRUCTURE_TYPE_COMPUTE_PIPELINE_CREATE_INFO,
+ .stage =
+ {
+ .sType = VK_STRUCTURE_TYPE_PIPELINE_SHADER_STAGE_CREATE_INFO,
+ .stage = VK_SHADER_STAGE_COMPUTE_BIT,
+ .module = mod->shader,
+ .pName = kern->name,
+ },
+ .layout = ks->pipelineLayout
+
+ };
+ ZVK_FATAL(vkCreateComputePipelines(zvk->device, NULL, 1, &pipeline, NULL, &ks->pipeline));
+
+ return ks;
+}
+
+void compute_destroyKernel(struct kernstate *ks) {
+ struct zvk *zvk = ks->mod->zvk;
+
+ if (ks->podMemory) {
+ vkFreeMemory(zvk->device, ks->podMemory, NULL);
+ vkDestroyBuffer(zvk->device, ks->podBuffer, NULL);
+ }
+
+ vkDestroyPipeline(zvk->device, ks->pipeline, NULL);
+ vkDestroyPipelineLayout(zvk->device, ks->pipelineLayout, NULL);
+
+ vkFreeDescriptorSets(zvk->device, zvk->descriptorPool, 1, ks->descriptorSets);
+ vkDestroyDescriptorSetLayout(zvk->device, ks->descriptorSetLayout, NULL);
+
+ free(ks);
+}
+
+// one time optional?
+VkCommandBuffer compute_createCommand(struct kernstate *ks, uint32_t sizex, uint32_t sizey, uint32_t sizez) {
+ struct zvk *zvk = ks->mod->zvk;
+ VkCommandBuffer commandBuffers[1];
+
+ /* Create a command buffer to run this kernel with it's data set for the given size */
+ VkCommandBufferAllocateInfo cmdinfo = {
+ .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_ALLOCATE_INFO,
+ .commandPool = zvk->commandPool,
+ .level = VK_COMMAND_BUFFER_LEVEL_PRIMARY,
+ .commandBufferCount = 1,
+ };
+
+ ZVK_FATAL(vkAllocateCommandBuffers(zvk->device, &cmdinfo, commandBuffers));
+
+ VkCommandBufferBeginInfo beginInfo = {
+ .sType = VK_STRUCTURE_TYPE_COMMAND_BUFFER_BEGIN_INFO,
+ //.flags = VK_COMMAND_BUFFER_USAGE_ONE_TIME_SUBMIT_BIT,
+ .flags = 0,
+ };
+ ZVK_FATAL(vkBeginCommandBuffer(commandBuffers[0], &beginInfo));
+
+ vkCmdBindPipeline(commandBuffers[0], VK_PIPELINE_BIND_POINT_COMPUTE, ks->pipeline);
+ vkCmdBindDescriptorSets(commandBuffers[0], VK_PIPELINE_BIND_POINT_COMPUTE, ks->pipelineLayout, 0, 1, ks->descriptorSets, 0, NULL);
+
+ vkCmdDispatch(commandBuffers[0], sizex, sizey, sizez);
+
+ ZVK_FATAL(vkEndCommandBuffer(commandBuffers[0]));
+
+ return commandBuffers[0];
+}
+
+// rather inefficient way to set arguments one at a time
+void compute_setArg(struct kernstate *ks, int index, void *data, size_t size) {
+ struct zvk *zvk = ks->mod->zvk;
+ const struct kerninfo *kern = ks->kern;
+ const struct paraminfo *pi = &kern->params[index];
+ const struct bindinfo *bi = &kern->bindings[pi->binding];
+
+ switch (bi->type) {
+ case ZVK_POD: {
+ void *pod __attribute__ ((aligned(16)));
+
+ ZVK_FATAL(!(size == pi->size));
+ ZVK_FATAL(vkMapMemory(zvk->device, ks->podMemory, 0, VK_WHOLE_SIZE, 0, &pod));
+ memcpy(pod + pi->offset, data, size);
+ vkUnmapMemory(zvk->device, ks->podMemory);
+ break;
+ }
+ case ZVK_BUFFER:
+ ZVK_FATAL(!(size == sizeof(VkDescriptorBufferInfo)));
+ VkWriteDescriptorSet bufferSet = {
+ .sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
+ .dstSet = ks->descriptorSets[0],
+ .dstBinding = pi->binding,
+ .descriptorCount = 1,
+ .descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+ .pBufferInfo = data
+ };
+ vkUpdateDescriptorSets(zvk->device, 1, &bufferSet, 0, NULL);
+ break;
+ case ZVK_WOIMAGE:
+ ZVK_FATAL(!(size == sizeof(VkDescriptorImageInfo)));
+ VkWriteDescriptorSet woSet = {
+ .sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
+ .dstSet = ks->descriptorSets[0],
+ .dstBinding = pi->binding,
+ .descriptorCount = 1,
+ .descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE,
+ .pImageInfo = data
+ };
+ vkUpdateDescriptorSets(zvk->device, 1, &woSet, 0, NULL);
+ break;
+ case ZVK_ROIMAGE:
+ ZVK_FATAL(!(size == sizeof(VkDescriptorImageInfo)));
+ VkWriteDescriptorSet roSet = {
+ .sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
+ .dstSet = ks->descriptorSets[0],
+ .dstBinding = pi->binding,
+ .descriptorCount = 1,
+ .descriptorType = VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE,
+ .pImageInfo = data
+ };
+ vkUpdateDescriptorSets(zvk->device, 1, &roSet, 0, NULL);
+ break;
+ // shader?
+ }
+}
+
+void compute_setBuffers(struct kernstate *ks, VkDescriptorBufferInfo *buffers[]) {
+ VkWriteDescriptorSet writeSet = {
+ .sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
+ .dstSet = ks->descriptorSets[0],
+ .descriptorCount = 1,
+ .descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER,
+ };
+ const struct kerninfo *kern = ks->kern;
+ struct zvk *zvk = ks->mod->zvk;
+
+ for (int i=0,bindex=0;i<kern->nparams;i++) {
+ const struct paraminfo *pi = &kern->params[i];
+ const struct bindinfo *bi = &kern->bindings[pi->binding];
+
+ if (bi->type == ZVK_BUFFER) {
+ writeSet.pBufferInfo = buffers[bindex];
+ writeSet.dstBinding = pi->binding;
+
+ vkUpdateDescriptorSets(zvk->device, 1, &writeSet, 0, NULL);
+ bindex++;
+ }
+ }
+}
+
+void compute_setImages(struct kernstate *ks, VkDescriptorImageInfo *images[]) {
+ VkWriteDescriptorSet writeSet = {
+ .sType = VK_STRUCTURE_TYPE_WRITE_DESCRIPTOR_SET,
+ .dstSet = ks->descriptorSets[0],
+ .descriptorCount = 1,
+ };
+ const struct kerninfo *kern = ks->kern;
+ struct zvk *zvk = ks->mod->zvk;
+
+ for (int i=0,bindex=0;i<kern->nparams;i++) {
+ const struct paraminfo *pi = &kern->params[i];
+ const struct bindinfo *bi = &kern->bindings[pi->binding];
+
+ if (bi->type == ZVK_WOIMAGE) {
+ writeSet.descriptorType = VK_DESCRIPTOR_TYPE_STORAGE_IMAGE;
+ } else if (bi->type == ZVK_ROIMAGE) {
+ writeSet.descriptorType = VK_DESCRIPTOR_TYPE_SAMPLED_IMAGE;
+ } else
+ continue;
+
+ writeSet.pImageInfo = images[bindex];
+ writeSet.dstBinding = pi->binding;
+
+ vkUpdateDescriptorSets(zvk->device, 1, &writeSet, 0, NULL);
+ bindex++;
+ }
+}