Skip to content

Instantly share code, notes, and snippets.

@Robadob
Last active July 24, 2020 11:48
Show Gist options
  • Save Robadob/34687ef50f09e2b657937071b0fae8b8 to your computer and use it in GitHub Desktop.
Save Robadob/34687ef50f09e2b657937071b0fae8b8 to your computer and use it in GitHub Desktop.
Test of how to return a formatted string from device code.
#include <cstdint>
#include <cstring>
#include "cuda_runtime.h"
#include "device_launch_parameters.h"
#include <stdio.h>
/**
 * Fixed-size record that code fills with a printf-style format string plus the
 * raw bytes of its arguments, so that another side (here: the host, after a
 * cudaMemcpy) can render the final text.
 *
 * It holds only arrays and unsigned ints, so it can be zeroed with
 * cudaMemset/memset and copied bytewise with cudaMemcpy.
 */
struct FormatBuff {
    static constexpr unsigned int MAX_ARGS = 20;           ///< number of slots in format_args_sizes
    static constexpr unsigned int ARG_BUFF_LEN = 4096;     ///< bytes available in format_args
    static constexpr unsigned int FORMAT_BUFF_LEN = 4096;  ///< chars available in format_string

    /// Format string supplied by the caller of the printer.
    char format_string[FORMAT_BUFF_LEN];
    /// sizeof() of each recorded argument, in recording order.
    unsigned int format_args_sizes[MAX_ARGS];
    /// Recorded argument bytes, laid end to end; entry i spans format_args_sizes[i] bytes.
    char format_args[ARG_BUFF_LEN];
    /// How many arguments have been recorded.
    unsigned int arg_count;
    /// How many bytes of format_args are in use.
    unsigned int arg_offset;
};
/**
 * Append one argument's raw bytes to `buff`.
 *
 * The argument's size is recorded in format_args_sizes[arg_count] and its bytes
 * are copied into format_args at arg_offset. Both counters are then advanced.
 *
 * The argument is dropped when MAX_ARGS arguments are already stored, or when
 * its bytes would not fit in ARG_BUFF_LEN. In the latter case the byte storage
 * is marked full (arg_offset = ARG_BUFF_LEN), so every later argument is
 * dropped too. Otherwise a later, smaller argument could land in the slot the
 * reader expects for the dropped one, shifting all subsequent arguments.
 *
 * Values are copied bytewise, so T is expected to be trivially copyable. A
 * pointer argument is stored as its address value, not the data it points to.
 */
template<typename T>
__host__ __device__ void subformat(FormatBuff *buff, const T &t) {
    // No argument slots left: drop.
    if (buff->arg_count >= FormatBuff::MAX_ARGS)
        return;
    if (buff->arg_offset + sizeof(T) > FormatBuff::ARG_BUFF_LEN) {
        // Not enough byte storage left: latch the storage as full so later
        // arguments cannot be mistaken for this one by the reader.
        buff->arg_offset = FormatBuff::ARG_BUFF_LEN;
        return;
    }
    // Record the size so the reader can walk the packed bytes.
    buff->format_args_sizes[buff->arg_count] = sizeof(T);
    // Bytewise copy: the packed offset is not necessarily aligned for T.
    memcpy(buff->format_args + buff->arg_offset, &t, sizeof(T));
    ++buff->arg_count;
    buff->arg_offset += sizeof(T);
}
/**
 * Append two or more arguments to `buff`, first to last.
 *
 * Each argument goes through the single-argument subformat() overload, in the
 * order the arguments were written at the call site.
 */
template<typename T, typename... Args>
__host__ __device__ void subformat(FormatBuff *buff, const T &t, Args... args) {
    subformat(buff, t);
    // Elements of a braced initializer list are evaluated strictly left to
    // right, so the remaining arguments are visited in call-site order.
    // The leading 0 keeps the array non-empty.
    int visit_in_order[] = {0, (subformat(buff, args), 0)...};
    (void)visit_in_order;
}
/**
 * Record a printf-style format string and its arguments into `buff`, so the
 * text can be rendered later from the buffer contents.
 *
 * Only a call that finds buff->format_string empty records anything; later
 * calls return immediately. That check is a plain read, not an atomic
 * operation. NOTE(review): concurrent callers from several threads can race
 * on it.
 *
 * The format string is truncated to FORMAT_BUFF_LEN - 1 characters and is
 * always NUL-terminated. Arguments are appended in order through the
 * single-argument subformat(), which drops those that do not fit.
 *
 * @param buff   Destination buffer, expected to be zero-initialised before first use.
 * @param format Format string. If null, nothing is recorded.
 * @param args   Zero or more arguments, stored by value as raw bytes.
 */
template<typename... Args>
__host__ __device__ void format(FormatBuff *buff, const char *format, Args... args) {
    // Only output once (and never dereference a null format string).
    if (buff->format_string[0] || format == nullptr)
        return;
    // Reserve the last slot for the terminator, so the reader never sees an
    // unterminated string even when the format string is too long.
    unsigned int eos = 0;
    while (eos < FormatBuff::FORMAT_BUFF_LEN - 1 && format[eos] != '\0')
        ++eos;
    memcpy(buff->format_string, format, eos * sizeof(char));
    buff->format_string[eos] = '\0';
    // Store each argument in order. The leading 0 keeps the array non-empty,
    // so a call with no arguments compiles. The braced list guarantees
    // left-to-right evaluation.
    int expand[] = {0, (subformat(buff, args), 0)...};
    (void)expand;
}
/**
 * Smoke-test kernel: records one format string that exercises a signed int,
 * an int printed as unsigned, a float, a double, a string (which the host
 * renderer does not support) and a char.
 *
 * It uses no thread or block indices; main launches it as <<<1,1>>>.
 */
__global__ void test_kernel(FormatBuff *ptr) {
    const char *const message =
        "This is a test (int %d, uint %u, float %f, double %f, (unsupported) char* %s, char %c)\n";
    // 15.0 is deliberately a double literal, so both floating-point widths are recorded.
    format(ptr, message, -12, 13, 14.0f, 15.0, "s", '7');
}
/**
 * Report a failed CUDA runtime call on stderr.
 *
 * @param err  Status returned by the CUDA runtime call.
 * @param what Short description of the call, used in the message.
 * @return true if err is cudaSuccess, false otherwise.
 */
static bool cuda_ok(cudaError_t err, const char *what) {
    if (err == cudaSuccess)
        return true;
    fprintf(stderr, "CUDA error during %s: %s\n", what, cudaGetErrorString(err));
    return false;
}

/**
 * Fixed-capacity host output string. It is NUL-terminated after every append,
 * so it is always safe to print, even when the output was truncated.
 */
struct HostOutput {
    char data[FormatBuff::FORMAT_BUFF_LEN];
    unsigned int len;

    HostOutput() : len(0) { data[0] = '\0'; }

    /** True once only the terminator slot is left. */
    bool full() const { return len + 1 >= FormatBuff::FORMAT_BUFF_LEN; }

    /** Append one character, silently truncating when full. */
    void put(char c) {
        if (full())
            return;
        data[len++] = c;
        data[len] = '\0';
    }

    /**
     * snprintf one value at the end of the buffer. snprintf returns the length
     * it *would* have written, so only the part that actually fit is counted.
     */
    template<typename T>
    void emit(const char *spec, T value) {
        if (full())
            return;
        const unsigned int space = FormatBuff::FORMAT_BUFF_LEN - len;
        const int written = snprintf(data + len, space, spec, value);
        if (written < 0) {
            data[len] = '\0';  // encoding error: discard whatever was written
            return;
        }
        const unsigned int fitted = static_cast<unsigned int>(written);
        len += fitted < space - 1 ? fitted : space - 1;
    }
};

/**
 * Read a T from the packed argument bytes. memcpy is used because arguments
 * are packed end to end, so `offset` is not necessarily aligned for T.
 */
template<typename T>
static T load_arg(const FormatBuff &buff, unsigned int offset) {
    T value;
    memcpy(&value, buff.format_args + offset, sizeof(T));
    return value;
}

/** Read a 1/2/4/8-byte integer argument, sign-extended. Returns false for other sizes. */
static bool load_signed(const FormatBuff &buff, unsigned int offset, unsigned int size, long long *value) {
    switch (size) {
    case 1: *value = load_arg<int8_t>(buff, offset); return true;
    case 2: *value = load_arg<int16_t>(buff, offset); return true;
    case 4: *value = load_arg<int32_t>(buff, offset); return true;
    case 8: *value = load_arg<int64_t>(buff, offset); return true;
    default: return false;
    }
}

/** Read a 1/2/4/8-byte integer argument, zero-extended. Returns false for other sizes. */
static bool load_unsigned(const FormatBuff &buff, unsigned int offset, unsigned int size, unsigned long long *value) {
    switch (size) {
    case 1: *value = load_arg<uint8_t>(buff, offset); return true;
    case 2: *value = load_arg<uint16_t>(buff, offset); return true;
    case 4: *value = load_arg<uint32_t>(buff, offset); return true;
    case 8: *value = load_arg<uint64_t>(buff, offset); return true;
    default: return false;
    }
}

/** printf length-modifier characters, which are rewritten to match the stored argument size. */
static const char *const LENGTH_MODIFIERS = "hljztL";

/**
 * Return the index of the conversion character for the spec that starts with
 * the '%' at fmt[start]. Flags, width, precision and length modifiers are
 * skipped. Returns `end` if the chunk ends first.
 */
static unsigned int find_conversion(const char *fmt, unsigned int start, unsigned int end) {
    unsigned int i = start + 1;
    // fmt[i] is never '\0' here (i < end), so strchr cannot match the terminator.
    while (i < end && strchr("-+ #0123456789.", fmt[i]) != nullptr)
        ++i;
    while (i < end && strchr(LENGTH_MODIFIERS, fmt[i]) != nullptr)
        ++i;
    return i;
}

/**
 * Copy fmt[start, end) into `spec`, dropping any length modifiers before the
 * conversion at `conv` and inserting `length` in their place.
 * `spec` must hold (end - start) + strlen(length) + 1 chars.
 */
static void build_spec(const char *fmt, unsigned int start, unsigned int conv, unsigned int end,
                       const char *length, char *spec) {
    unsigned int n = 0;
    for (unsigned int i = start; i < conv; ++i)
        if (strchr(LENGTH_MODIFIERS, fmt[i]) == nullptr)
            spec[n++] = fmt[i];
    for (; *length != '\0'; ++length)
        spec[n++] = *length;
    for (unsigned int i = conv; i < end; ++i)
        spec[n++] = fmt[i];
    spec[n] = '\0';
}

/** Copy an unrenderable chunk verbatim, with its leading '%' replaced by '#'. */
static void emit_unsupported(const char *fmt, unsigned int start, unsigned int end, HostOutput &out) {
    out.put('#');
    for (unsigned int i = start + 1; i < end; ++i)
        out.put(fmt[i]);
}

/**
 * Render the chunk fmt[start, end) using the argument stored at
 * `offset`/`size`. The chunk is one conversion at fmt[conv] plus the literal
 * text after it.
 *
 * The user's length modifier is replaced with one that matches the argument's
 * stored size, so snprintf always receives the type it expects.
 *
 * @return false if this conversion/size combination is not supported.
 */
static bool format_one(const FormatBuff &buff, unsigned int offset, unsigned int size,
                       const char *fmt, unsigned int start, unsigned int conv, unsigned int end,
                       HostOutput &out) {
    // Never read outside the argument storage, whatever sizes were recorded.
    if (size > FormatBuff::ARG_BUFF_LEN || offset > FormatBuff::ARG_BUFF_LEN - size)
        return false;
    char spec[FormatBuff::FORMAT_BUFF_LEN + 3];  // chunk + "ll" + NUL
    switch (fmt[conv]) {
    case 'd':
    case 'i': {
        long long v;
        if (!load_signed(buff, offset, size, &v))
            return false;
        build_spec(fmt, start, conv, end, "ll", spec);
        out.emit(spec, v);
        return true;
    }
    case 'u':
    case 'o':
    case 'x':
    case 'X': {
        unsigned long long v;
        if (!load_unsigned(buff, offset, size, &v))
            return false;
        build_spec(fmt, start, conv, end, "ll", spec);
        out.emit(spec, v);
        return true;
    }
    case 'f':
    case 'F':
    case 'e':
    case 'E':
    case 'g':
    case 'G':
    case 'a':
    case 'A': {
        double v;
        if (size == sizeof(float))
            v = load_arg<float>(buff, offset);  // varargs promote float to double anyway
        else if (size == sizeof(double))
            v = load_arg<double>(buff, offset);
        else
            return false;
        build_spec(fmt, start, conv, end, "", spec);
        out.emit(spec, v);
        return true;
    }
    case 'c': {
        long long v;
        if (!load_signed(buff, offset, size, &v))
            return false;
        build_spec(fmt, start, conv, end, "", spec);
        out.emit(spec, static_cast<int>(v));
        return true;
    }
    case 'p': {
        if (size != sizeof(void *))
            return false;
        // Print the stored pointer value itself (for device callers, a device address).
        build_spec(fmt, start, conv, end, "", spec);
        out.emit(spec, load_arg<void *>(buff, offset));
        return true;
    }
    default:
        // Not supported:
        //  - 's': the host cannot dereference the stored pointer.
        //  - 'n': there is nowhere meaningful to write the count back to.
        //  - '*' width/precision and unknown conversions.
        return false;
    }
}

/**
 * Render the format string and arguments recorded in `buff` into `out`.
 *
 * Literal text is copied as-is, and "%%" produces a single '%'. Every other
 * '%' spec that has a conversion character consumes the next recorded
 * argument. Each chunk given to snprintf runs from a '%' to the next '%' (or
 * the end of the string).
 */
static void render(const FormatBuff &buff, HostOutput &out) {
    const char *fmt = buff.format_string;
    // Bound the length even if the buffer came back unterminated.
    unsigned int fmt_len = 0;
    while (fmt_len < FormatBuff::FORMAT_BUFF_LEN && fmt[fmt_len] != '\0')
        ++fmt_len;
    const unsigned int stored_args =
        buff.arg_count < FormatBuff::MAX_ARGS ? buff.arg_count : FormatBuff::MAX_ARGS;

    unsigned int arg_no = 0;
    unsigned int arg_offset = 0;
    unsigned int i = 0;
    while (i < fmt_len && !out.full()) {
        if (fmt[i] != '%') {
            out.put(fmt[i]);
            ++i;
            continue;
        }
        if (i + 1 < fmt_len && fmt[i + 1] == '%') {
            // Escaped percent: emit it without consuming an argument.
            out.put('%');
            i += 2;
            continue;
        }
        unsigned int end = i + 1;
        while (end < fmt_len && fmt[end] != '%')
            ++end;
        const unsigned int conv = find_conversion(fmt, i, end);
        if (conv == end) {
            // A '%' with no conversion character before the chunk ends.
            emit_unsupported(fmt, i, end, out);
        } else if (arg_no < stored_args) {
            const unsigned int size = buff.format_args_sizes[arg_no];
            if (!format_one(buff, arg_offset, size, fmt, i, conv, end, out))
                emit_unsupported(fmt, i, end, out);
            // The argument was recorded even if it could not be rendered.
            arg_offset += size;
            ++arg_no;
        } else {
            // More specs than recorded arguments.
            emit_unsupported(fmt, i, end, out);
        }
        i = end;
    }
}

/**
 * Launch test_kernel on a single thread, copy the recorded FormatBuff back,
 * render it on the host and write it to stdout.
 *
 * @return 0 on success, 1 if any CUDA call fails (the error is reported on stderr).
 */
int main()
{
    FormatBuff *d_buff = nullptr;
    if (!cuda_ok(cudaMalloc(&d_buff, sizeof(FormatBuff)), "cudaMalloc"))
        return 1;
    // format() only records into a buffer whose format string is empty, so start zeroed.
    bool ok = cuda_ok(cudaMemset(d_buff, 0, sizeof(FormatBuff)), "cudaMemset");
    if (ok) {
        test_kernel<<<1, 1>>>(d_buff);
        ok = cuda_ok(cudaGetLastError(), "kernel launch");
    }
    FormatBuff h_buff;
    memset(&h_buff, 0, sizeof(FormatBuff));
    if (ok) {
        // Blocking copy on the default stream: waits for the kernel and surfaces its faults.
        ok = cuda_ok(cudaMemcpy(&h_buff, d_buff, sizeof(FormatBuff), cudaMemcpyDeviceToHost), "cudaMemcpy");
    }
    cuda_ok(cudaFree(d_buff), "cudaFree");
    if (!ok)
        return 1;

    HostOutput out;
    render(h_buff, out);
    // fputs, not printf: the rendered text may legitimately contain '%'.
    fputs(out.data, stdout);
    //getchar();
    return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment