Last active
July 24, 2020 11:48
-
-
Save Robadob/34687ef50f09e2b657937071b0fae8b8 to your computer and use it in GitHub Desktop.
Test of how to return a formatted string from device code.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cstdint> | |
#include <cstring> | |
#include "cuda_runtime.h" | |
#include "device_launch_parameters.h" | |
#include <stdio.h> | |
/**
 * Fixed-size, flat record of one printf-style call.
 * It holds only character arrays and unsigned integers (no pointers), so it can
 * be copied between device and host memory as raw bytes.
 */
struct FormatBuff {
    /** Maximum number of arguments that can be recorded */
    static constexpr unsigned int MAX_ARGS = 20;
    /** Capacity in bytes of the packed argument storage */
    static constexpr unsigned int ARG_BUFF_LEN = 4096;
    /** Capacity in bytes of the stored format string */
    static constexpr unsigned int FORMAT_BUFF_LEN = 4096;

    char format_string[FORMAT_BUFF_LEN];        // format string handed to the printer
    unsigned int format_args_sizes[MAX_ARGS];   // sizeof() of each recorded argument, in call order
    char format_args[ARG_BUFF_LEN];             // argument bytes packed back to back, sized by the table above
    unsigned int arg_count;                     // how many arguments have been recorded
    unsigned int arg_offset;                    // how many bytes of format_args are in use
};
/**
 * Appends one argument to buff. Its size goes into
 * format_args_sizes[arg_count], and its raw bytes are copied into format_args
 * starting at arg_offset.
 *
 * The argument is dropped if the slot table (MAX_ARGS) or the byte buffer
 * (ARG_BUFF_LEN) is full. After an argument is dropped for lack of bytes, every
 * later argument is dropped too. Otherwise a later, smaller argument could take
 * the dropped one's slot and be printed for the wrong format specifier.
 *
 * @param buff destination buffer
 * @param t    value to record; copied byte-wise, so T should be trivially copyable
 */
template<typename T>
__host__ __device__ void subformat(FormatBuff *buff, const T &t) {
    if (buff->arg_count >= FormatBuff::MAX_ARGS)
        return;
    if (buff->arg_offset + sizeof(T) > FormatBuff::ARG_BUFF_LEN) {
        // Saturate the offset: sizeof(T) >= 1 for every T, so nothing fits after this.
        buff->arg_offset = FormatBuff::ARG_BUFF_LEN;
        return;
    }
    // Record the size, then the bytes, then advance both cursors
    buff->format_args_sizes[buff->arg_count] = sizeof(T);
    memcpy(buff->format_args + buff->arg_offset, &t, sizeof(T));
    ++buff->arg_count;
    buff->arg_offset += sizeof(T);
}
/**
 * Records t and then each of args... into buff, strictly left to right, by
 * handing every value to the single-argument subformat overload.
 */
template<typename T, typename... Args>
__host__ __device__ void subformat(FormatBuff *buff, const T &t, Args... args) {
    subformat(buff, t);
    // Elements of a braced initialiser are evaluated in order, so this visits the
    // remaining values left to right; the trailing 0 keeps the array non-empty.
    int visit_in_order[] = {(subformat(buff, args), 0)..., 0};
    (void)visit_in_order;
}
/**
 * Records a printf-style format string and its arguments into buff, so the host
 * can expand them after copying the buffer back.
 *
 * Once a non-empty format string is stored, later calls return immediately.
 * buff must be zero-initialised first, because arg_count and arg_offset are
 * never reset here.
 * NOTE(review): the "already stored" check is not atomic, so concurrent calls
 * on the same buffer from several threads race.
 *
 * @param buff   destination buffer
 * @param format null-terminated format string, truncated to FORMAT_BUFF_LEN - 1 chars
 * @param args   values for the specifiers in format, each copied byte-wise
 */
template<typename... Args>
__host__ __device__ void format(FormatBuff *buff, const char *format, Args... args) {
    // Only output once
    if (buff->format_string[0])
        return;
    // Copy at most LEN - 1 chars so the stored string is always terminated;
    // the host walks it until '\0'.
    unsigned int len = 0;
    while (len < FormatBuff::FORMAT_BUFF_LEN - 1 && format[len] != '\0')
        ++len;
    memcpy(buff->format_string, format, len * sizeof(char));
    buff->format_string[len] = '\0';
    // Record each arg left to right. The leading 0 keeps the array non-empty, so
    // format() also compiles with no args; the subformat overloads need at least one.
    int expand[] = {0, (subformat(buff, args), 0)...};
    (void)expand;
}
/**
 * Test kernel. It records one format string with six arguments of different
 * types (int, int, float, double, const char*, char) into ptr. Every thread
 * that runs it calls format().
 */
__global__ void test_kernel(FormatBuff *ptr) {
    const int as_int = -12;
    const int as_uint = 13;
    const float as_float = 14.0f;
    const double as_double = 15.0;
    const char *as_str = "s";
    const char as_char = '7';
    format(ptr,
           "This is a test (int %d, uint %u, float %f, double %f, (unsupported) char* %s, char %c)\n",
           as_int, as_uint, as_float, as_double, as_str, as_char);
}
/** Reports a failed CUDA runtime call on stderr; returns true when err is cudaSuccess. */
static bool cudaOk(cudaError_t err, const char *what) {
    if (err == cudaSuccess)
        return true;
    fprintf(stderr, "CUDA error in %s: %s\n", what, cudaGetErrorString(err));
    return false;
}

/** Bounded output buffer that stays null-terminated after every append. */
struct OutBuffer {
    char *data;   // destination storage
    size_t cap;   // capacity of data in bytes, including the terminator (must be > 0)
    size_t used;  // characters written so far, excluding the terminator

    /** Appends up to len bytes of src, silently truncating when full. */
    void append(const char *src, size_t len) {
        const size_t room = cap - 1 - used;
        if (len > room)
            len = room;
        memcpy(data + used, src, len);
        used += len;
        data[used] = '\0';
    }

    /** Appends snprintf(spec, value), silently truncating when full. */
    template<typename V>
    void appendFormatted(const char *spec, V value) {
        const size_t room = cap - 1 - used;
        const int n = snprintf(data + used, room + 1, spec, value);
        if (n < 0) {
            data[used] = '\0';  // encoding error: discard any partial output
            return;
        }
        // snprintf returns the untruncated length; only count what was stored
        used += static_cast<size_t>(n) < room ? static_cast<size_t>(n) : room;
    }
};

/** Copies a T out of the packed (possibly misaligned) argument bytes at src. */
template<typename T>
static T readArg(const char *src) {
    T value;
    memcpy(&value, src, sizeof(T));
    return value;
}

/** True when c is a non-NUL character contained in set. */
static bool isOneOf(char c, const char *set) {
    return c != '\0' && strchr(set, c) != nullptr;
}

/** Writes prefix + modifier + conversion + '\0' into dst; false if that would not fit in dst_cap bytes. */
static bool buildSpec(char *dst, size_t dst_cap, const char *prefix, size_t prefix_len,
                      const char *modifier, char conversion) {
    const size_t modifier_len = strlen(modifier);
    if (prefix_len + modifier_len + 2 > dst_cap)
        return false;
    memcpy(dst, prefix, prefix_len);
    memcpy(dst + prefix_len, modifier, modifier_len);
    dst[prefix_len + modifier_len] = conversion;
    dst[prefix_len + modifier_len + 1] = '\0';
    return true;
}

/**
 * Formats one recorded argument (arg_size bytes at arg) for conversion conv and
 * appends the result to out.
 *
 * prefix/prefix_len is the user's "%", flags, width and precision. Any length
 * modifier the user wrote is replaced by one derived from the recorded size, so
 * snprintf always receives exactly the type its spec expects.
 * Returns false, appending nothing, for unsupported conversion/size pairs.
 */
static bool renderConversion(OutBuffer &out, const char *prefix, size_t prefix_len,
                             char conv, const char *arg, unsigned int arg_size) {
    char spec[FormatBuff::FORMAT_BUFF_LEN + 4];
    switch (conv) {
    case 'd':
    case 'i': {
        // Signed integer, widened to long long
        long long value;
        switch (arg_size) {
        case 1: value = readArg<int8_t>(arg); break;
        case 2: value = readArg<int16_t>(arg); break;
        case 4: value = readArg<int32_t>(arg); break;
        case 8: value = readArg<int64_t>(arg); break;
        default: return false;
        }
        if (!buildSpec(spec, sizeof(spec), prefix, prefix_len, "ll", conv))
            return false;
        out.appendFormatted(spec, value);
        return true;
    }
    case 'u':
    case 'o':
    case 'x':
    case 'X': {
        // Unsigned integer, widened to unsigned long long
        unsigned long long value;
        switch (arg_size) {
        case 1: value = readArg<uint8_t>(arg); break;
        case 2: value = readArg<uint16_t>(arg); break;
        case 4: value = readArg<uint32_t>(arg); break;
        case 8: value = readArg<uint64_t>(arg); break;
        default: return false;
        }
        if (!buildSpec(spec, sizeof(spec), prefix, prefix_len, "ll", conv))
            return false;
        out.appendFormatted(spec, value);
        return true;
    }
    case 'f':
    case 'F':
    case 'e':
    case 'E':
    case 'g':
    case 'G':
    case 'a':
    case 'A': {
        // Floating point; varargs promote float to double anyway
        double value;
        if (arg_size == sizeof(float))
            value = readArg<float>(arg);
        else if (arg_size == sizeof(double))
            value = readArg<double>(arg);
        else
            return false;
        if (!buildSpec(spec, sizeof(spec), prefix, prefix_len, "", conv))
            return false;
        out.appendFormatted(spec, value);
        return true;
    }
    case 'c': {
        // Character; %c takes an int
        int value;
        switch (arg_size) {
        case 1: value = readArg<uint8_t>(arg); break;
        case 2: value = readArg<uint16_t>(arg); break;
        case 4: value = readArg<int32_t>(arg); break;
        default: return false;
        }
        if (!buildSpec(spec, sizeof(spec), prefix, prefix_len, "", conv))
            return false;
        out.appendFormatted(spec, value);
        return true;
    }
    case 'p': {
        // Prints the address exactly as format() recorded it (for test_kernel, an
        // address taken in device code); it is never dereferenced here.
        if (arg_size != sizeof(void *) ||
            !buildSpec(spec, sizeof(spec), prefix, prefix_len, "", conv))
            return false;
        out.appendFormatted(spec, readArg<void *>(arg));
        return true;
    }
    case 'n':
        // The recorded pointer was captured where format() ran, so there is no
        // valid host location to write the count to: consume it, print nothing.
        return true;
    default:
        // 's' (the characters are not captured, only the pointer) and anything else
        return false;
    }
}

/**
 * Expands the format string recorded in buff with its recorded arguments into
 * out (out_cap bytes). The result is always null-terminated when out_cap > 0,
 * and is silently truncated if it does not fit.
 *
 * - "%%" produces a literal '%'.
 * - A conversion whose argument is missing, or cannot be rendered ('s',
 *   unexpected sizes), is copied with its '%' replaced by '#'.
 * - A '%' not followed by a recognised conversion (e.g. "%*d") is copied as
 *   plain text.
 * - Text after the last recorded argument is still copied.
 */
static void expandFormatBuff(const FormatBuff &buff, char *out, size_t out_cap) {
    if (out_cap == 0)
        return;
    out[0] = '\0';
    OutBuffer ob = {out, out_cap, 0};

    // Bounded length: do not rely on the copied buffer containing a terminator
    const char *fmt = buff.format_string;
    size_t fmt_len = 0;
    while (fmt_len < FormatBuff::FORMAT_BUFF_LEN && fmt[fmt_len] != '\0')
        ++fmt_len;

    const unsigned int arg_count =
        buff.arg_count < FormatBuff::MAX_ARGS ? buff.arg_count : FormatBuff::MAX_ARGS;
    unsigned int arg_no = 0;
    unsigned int arg_offset = 0;  // invariant: arg_offset <= ARG_BUFF_LEN

    size_t i = 0;
    while (i < fmt_len) {
        if (fmt[i] != '%') {
            // Copy the literal run up to the next '%'
            size_t run_end = i;
            while (run_end < fmt_len && fmt[run_end] != '%')
                ++run_end;
            ob.append(fmt + i, run_end - i);
            i = run_end;
            continue;
        }
        if (i + 1 < fmt_len && fmt[i + 1] == '%') {
            ob.append("%", 1);
            i += 2;
            continue;
        }
        // Parse "%[flags][width][.precision][length]conversion"
        size_t j = i + 1;
        while (j < fmt_len && isOneOf(fmt[j], "-+ #0123456789."))
            ++j;
        const size_t prefix_len = j - i;  // '%' plus flags/width/precision
        while (j < fmt_len && isOneOf(fmt[j], "hlLqjzt"))
            ++j;
        const char conv = j < fmt_len ? fmt[j] : '\0';
        if (!isOneOf(conv, "diouxXcfFeEgGaAspn")) {
            ob.append("%", 1);
            ++i;
            continue;
        }
        const size_t spec_end = j + 1;  // one past the conversion character

        // Consume the next recorded argument if it exists and lies within the buffer
        const bool arg_ok = arg_no < arg_count &&
            buff.format_args_sizes[arg_no] <= FormatBuff::ARG_BUFF_LEN - arg_offset;
        bool rendered = false;
        if (arg_ok) {
            const unsigned int arg_size = buff.format_args_sizes[arg_no];
            rendered = renderConversion(ob, fmt + i, prefix_len, conv,
                                        buff.format_args + arg_offset, arg_size);
            arg_offset += arg_size;
            ++arg_no;
        } else {
            arg_no = arg_count;  // missing or inconsistent: stop consuming args
        }
        if (!rendered) {
            // Unsupported: emit the spec unchanged but with '#' in place of '%'
            ob.append("#", 1);
            ob.append(fmt + i + 1, spec_end - i - 1);
        }
        i = spec_end;
    }
}

/**
 * Runs test_kernel, copies its FormatBuff back to the host and prints the
 * expanded string. Returns 0 on success, or 1 if any CUDA runtime call fails.
 */
int main()
{
    FormatBuff h_buff;
    memset(&h_buff, 0, sizeof(FormatBuff));
    FormatBuff *d_buff = nullptr;
    if (!cudaOk(cudaMalloc(&d_buff, sizeof(FormatBuff)), "cudaMalloc"))
        return 1;
    // format() needs a zeroed buffer: it skips if a format string is already
    // present, and it appends args at arg_count/arg_offset
    bool ok = cudaOk(cudaMemset(d_buff, 0, sizeof(FormatBuff)), "cudaMemset");
    if (ok) {
        test_kernel<<<1, 1>>>(d_buff);
        ok = cudaOk(cudaGetLastError(), "test_kernel launch");
    }
    if (ok) {
        // Blocking copy: waits for test_kernel and reports any fault it raised
        ok = cudaOk(cudaMemcpy(&h_buff, d_buff, sizeof(FormatBuff), cudaMemcpyDeviceToHost),
                    "cudaMemcpy");
    }
    ok = cudaOk(cudaFree(d_buff), "cudaFree") && ok;
    if (!ok)
        return 1;

    char out_buffer[FormatBuff::FORMAT_BUFF_LEN];
    expandFormatBuff(h_buff, out_buffer, sizeof(out_buffer));
    // Never pass the expanded text as a format string: it may contain '%'
    fputs(out_buffer, stdout);
    return 0;
}
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment