Skip to content

Instantly share code, notes, and snippets.

@likr
Created August 9, 2012 08:41
Show Gist options
Save likr/3302401 to your computer and use it in GitHub Desktop.
PyOpenCL FizzBuzz
# coding: utf-8
# vim: filetype=pyopencl.python
from __future__ import print_function
import numpy as np
import pyopencl as cl
from pyopencl import array as clarray
from pyopencl.scan import ExclusiveScanKernel
# OpenCL C source for a two-pass parallel FizzBuzz:
#   1. calc_length: out[i] = number of bytes line (i + 1) needs, newline included.
#   2. The host runs an exclusive prefix sum over those lengths, turning them
#      into per-line starting offsets in one shared output buffer.
#   3. write_string: work-item i writes line (i + 1) at out + indices[i].
# Line numbers are 1-based (n = i + 1) and use 32-bit int arithmetic.
src = '''//CL//
int len(int x);
int len_n(int x);
void write_fizz(__global char* pos);
void write_buzz(__global char* pos);
void write_fizzbuzz(__global char* pos);
void write_n(__global char* pos, int n);

// Characters printed for value x, not counting the newline.
int len(int x)
{
    if (x % 15 == 0) {
        return 8; // "fizzbuzz"
    } else if (x % 3 == 0 || x % 5 == 0) {
        return 4; // "fizz" or "buzz"
    } else {
        return len_n(x);
    }
}

// Number of decimal digits of a non-negative x (0 counts as one digit).
int len_n(int x)
{
    int res = 1;
    for (; x >= 10; x /= 10) {
        ++res;
    }
    return res;
}

// Writes "fizz" plus newline (5 bytes).
void write_fizz(__global char* pos)
{
    pos[0] = 'f';
    pos[1] = 'i';
    pos[2] = 'z';
    pos[3] = 'z';
    pos[4] = '\\n';
}

// Writes "buzz" plus newline (5 bytes).
void write_buzz(__global char* pos)
{
    pos[0] = 'b';
    pos[1] = 'u';
    pos[2] = 'z';
    pos[3] = 'z';
    pos[4] = '\\n';
}

// Writes "fizzbuzz" plus newline (9 bytes).
void write_fizzbuzz(__global char* pos)
{
    pos[0] = 'f';
    pos[1] = 'i';
    pos[2] = 'z';
    pos[3] = 'z';
    pos[4] = 'b';
    pos[5] = 'u';
    pos[6] = 'z';
    pos[7] = 'z';
    pos[8] = '\\n';
}

// Writes the decimal digits of n followed by a newline.
void write_n(__global char* pos, int n)
{
    // Size the write by the digit count itself. len() would return 4 or 8
    // for multiples of 3 or 5 and write the wrong number of digits.
    const int l = len_n(n);
    for (int i = l - 1; i >= 0; --i) {
        pos[i] = '0' + n % 10;
        n /= 10;
    }
    pos[l] = '\\n';
}

// Pass 1: bytes needed by line i + 1 (text plus one newline).
__kernel void calc_length(__global int* out)
{
    const int i = get_global_id(0);
    out[i] = len(i + 1) + 1;
}

// Pass 2: write line i + 1 at its precomputed offset indices[i].
__kernel void write_string(
    __global const int* indices, __global char* out)
{
    const int i = get_global_id(0);
    const int n = i + 1;
    __global char* pos = out + indices[i];
    if (n % 15 == 0) {
        write_fizzbuzz(pos);
    } else if (n % 3 == 0) {
        write_fizz(pos);
    } else if (n % 5 == 0) {
        write_buzz(pos);
    } else {
        write_n(pos, n);
    }
}
'''
def main(n=100):
    """Print FizzBuzz for the numbers 1..n, computed on an OpenCL device.

    The output text is assembled in parallel in two kernel passes.
    ``calc_length`` stores the byte length of each line, newline included.
    An in-place exclusive prefix sum turns those lengths into starting
    offsets. ``write_string`` then writes every line at its own offset in
    a single byte buffer.

    :param n: how many lines to produce; defaults to 100 (the value that
        used to be hard-coded). The kernels use 32-bit ``int`` for line
        numbers and offsets, so n must stay well below 2**31.
    :raises ValueError: if ``n`` is smaller than 1 (nothing to launch).
    """
    if n < 1:
        raise ValueError('n must be a positive integer, got %r' % (n,))

    ctx = cl.create_some_context()
    queue = cl.CommandQueue(ctx)
    prg = cl.Program(ctx, src).build()
    scan_kernel = ExclusiveScanKernel(ctx, np.int32, 'a + b', '0')

    # Pass 1: per-line byte lengths.
    sizes = clarray.empty(queue, (n,), dtype=np.int32)
    prg.calc_length(queue, sizes.shape, None, sizes.data)
    # Read the total before the scan below overwrites `sizes` in place.
    length = int(clarray.sum(sizes).get())
    # Lengths -> starting offsets (exclusive scan, written back into sizes).
    scan_kernel(sizes)

    # Pass 2: each work-item writes its own line at its offset.
    result = clarray.empty(queue, (length,), np.byte)
    prg.write_string(queue, sizes.shape, None, sizes.data, result.data)

    # ndarray.tostring() is deprecated since NumPy 1.19 and removed in later
    # releases; tobytes() is the drop-in replacement. The text already ends
    # with a newline, so suppress print's own to avoid a trailing blank line.
    print(result.get().tobytes().decode('ascii'), end='')
# Script entry point: run the default 1..100 FizzBuzz when executed directly.
if __name__ == '__main__':
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment