Created February 14, 2017 17:33
-
-
Save tedsta/51146af7822668eaea7de3459801bafd to your computer and use it in GitHub Desktop.
SPIR-V + OpenCL fun times in Rust
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate ocl; | |
extern crate rspirv; | |
extern crate spirv_headers as spirv; | |
use ocl::{Platform, Device, Context, Queue, Buffer, Program, Kernel, Event, EventList}; | |
use rspirv::binary::Disassemble; | |
pub fn find_platform() -> Option<Platform> { | |
let platform_name = "Experimental OpenCL 2.1 CPU Only Platform"; | |
for platform in Platform::list() { | |
if platform.name() == platform_name { | |
return Some(platform); | |
} | |
} | |
None | |
} | |
pub fn build_spirv_module() -> Vec<u32> { | |
use rspirv::binary::Assemble; | |
let mut b = rspirv::mr::Builder::new(); | |
b.capability(spirv::Capability::Kernel); | |
b.capability(spirv::Capability::Addresses); | |
b.memory_model(spirv::AddressingModel::Physical64, spirv::MemoryModel::OpenCL); | |
// Type declarations | |
let void = b.type_void(); | |
let uint64 = b.type_int(64, 0); | |
let float32 = b.type_float(32); | |
let vec3_uint64 = b.type_vector(uint64, 3); | |
let global_ptr_float32 = b.type_pointer(spirv::StorageClass::CrossWorkgroup, float32); | |
let multiply_kernel_type = b.type_function(void, vec![global_ptr_float32, float32]); | |
let vec3_uint64_uniform_constant_ptr = b.type_pointer(spirv::StorageClass::UniformConstant, vec3_uint64); | |
// Variable declarations | |
let const_0 = b.constant_u32(uint64, 0); | |
// Multiply kernel | |
let multiply_kernel = b.begin_function(void, | |
spirv::FUNCTION_CONTROL_NONE, | |
multiply_kernel_type).unwrap(); | |
let input_array = b.function_parameter(global_ptr_float32).unwrap(); | |
let input_num = b.function_parameter(float32).unwrap(); | |
b.begin_basic_block().unwrap(); | |
let input_array_at_index = b.in_bounds_ptr_access_chain(global_ptr_float32, input_array, const_0, vec![]).unwrap(); | |
b.store(input_array_at_index, input_num, Some(spirv::MEMORY_ACCESS_ALIGNED), | |
vec![rspirv::mr::Operand::LiteralInt32(4)]).unwrap(); | |
b.ret().unwrap(); | |
b.end_function().unwrap(); | |
// Entry points | |
b.entry_point(spirv::ExecutionModel::Kernel, multiply_kernel, "test_kernel".to_string(), vec![]); | |
let module = b.module(); | |
println!("{}", module.disassemble()); | |
module.assemble() | |
} | |
#[cfg(test)]
mod tests {
    // `super::` resolves in both the 2015 and 2018 editions, unlike the original `::` path.
    use super::{build_spirv_module, find_platform};
    use ocl::{self, Buffer, Context, Device, Event, EventList, Kernel, Queue};

    /// Serializes SPIR-V words into bytes in the host's native byte order.
    ///
    /// SPIR-V is defined as a stream of 32-bit words. When stored as bytes, consumers determine
    /// the byte order from the magic number in the first word. The previous big-endian
    /// serialization produced a byte-swapped magic number on little-endian hosts.
    fn spirv_words_to_bytes(words: &[u32]) -> Vec<u8> {
        let mut bytes = Vec::with_capacity(words.len() * 4);
        for &word in words {
            let little = [
                word as u8,
                (word >> 8) as u8,
                (word >> 16) as u8,
                (word >> 24) as u8,
            ];
            if cfg!(target_endian = "little") {
                bytes.extend_from_slice(&little);
            } else {
                bytes.extend(little.iter().rev());
            }
        }
        bytes
    }

    /// Builds the SPIR-V module, runs `test_kernel` on the experimental CPU platform,
    /// and checks that the scalar argument was written into the buffer.
    #[test]
    fn it_works() {
        let platform = find_platform()
            .expect("Experimental OpenCL 2.1 CPU Only Platform is not installed");
        assert_eq!(platform.name(), "Experimental OpenCL 2.1 CPU Only Platform");

        // Get first (and only) device
        let device = Device::first(platform);

        // Build context using the first device
        let context = Context::builder()
            .platform(platform)
            .devices(device)
            .build()
            .expect("Failed to create context");

        let il_bytes = spirv_words_to_bytes(&build_spirv_module());

        let queue = Queue::new(&context, device, Some(ocl::core::QUEUE_PROFILING_ENABLE))
            .expect("Failed to create queue");
        let dims = [10];
        let buffer = Buffer::<f32>::new(queue.clone(), None, &dims, None)
            .expect("Failed to create buffer");
        let mut buffer_host = vec![0.0f32; dims[0]];

        let program = ocl::Program::with_il(il_bytes, &context)
            .expect("Failed to build program from SPIR-V module");
        let kernel = Kernel::new("test_kernel", &program, queue.clone())
            .expect("Failed to create kernel")
            .gws(&dims)
            .arg_buf(&buffer)
            .arg_scl(42.0f32);

        let mut event_list = EventList::new();
        kernel.cmd().enew(&mut event_list).enq().unwrap();
        event_list.wait().unwrap();

        let mut event = Event::empty();
        buffer.cmd().read(&mut buffer_host).enew(&mut event).enq().unwrap();
        event.wait().unwrap();

        // The kernel stores its scalar argument into element 0 only. The other elements are
        // never written by the kernel, so their contents are not checked.
        assert_eq!(buffer_host[0], 42.0f32);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.