Skip to content

Instantly share code, notes, and snippets.

@benvanik
Created January 6, 2021 00:12
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save benvanik/f5a8ff4928f58572882669dd92caced4 to your computer and use it in GitHub Desktop.
Save benvanik/f5a8ff4928f58572882669dd92caced4 to your computer and use it in GitHub Desktop.
WIP api_interfaces_cc.h example for #iree/4369
// Copyright 2020 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// https://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef IREE_HAL_API_INTERFACES_CC_H_
#define IREE_HAL_API_INTERFACES_CC_H_
#include <array>
#include <cstdint>
#include <cstring>
#include <string>

#include "iree/base/status.h"
#include "iree/base/time.h"
#include "iree/hal/api.h"
#include "iree/hal/api_interfaces.h"
#include "iree/hal/executable_format.h"
#ifndef __cplusplus
#error "This header is meant for use with C++ HAL implementations."
#endif // __cplusplus
namespace iree {
namespace hal {
//===----------------------------------------------------------------------===//
// iree_hal_resource_t
//===----------------------------------------------------------------------===//
// CRTP mixin that exposes the reference count embedded in a HAL C struct to
// ref_ptr.
//
// T must provide a base() accessor returning a pointer to a struct that has a
// `resource.ref_count` member, and must be deletable through a T*.
template <typename T>
class ResourceBase {
 public:
  ResourceBase(const ResourceBase&) = delete;
  ResourceBase& operator=(const ResourceBase&) = delete;

  // Adds a reference; used by ref_ptr.
  friend void ref_ptr_add_ref(T* p) {
    // The counter lives inside the C struct, so operate on its address rather
    // than on a copy of its value.
    volatile iree_atomic_ref_count_t* counter = &p->base()->resource.ref_count;
    iree_atomic_ref_count_inc(counter);
  }

  // Releases a reference, potentially deleting the object; used by ref_ptr.
  friend void ref_ptr_release_ref(T* p) {
    volatile iree_atomic_ref_count_t* counter = &p->base()->resource.ref_count;
    // NOTE(review): relies on iree_atomic_ref_count_dec returning the value
    // held before the decrement, so 1 means this was the last reference --
    // confirm against the atomics header.
    if (iree_atomic_ref_count_dec(counter) == 1) {
      delete p;
    }
  }

 protected:
  // Declaring the (deleted) copy constructor suppresses the implicit default
  // constructor, so it has to be restored explicitly for derived classes.
  ResourceBase() = default;

  // Protected and non-virtual: deletion happens through T* (see
  // ref_ptr_release_ref), never through a ResourceBase*.
  ~ResourceBase() = default;
};
//===----------------------------------------------------------------------===//
// iree_hal_allocator_t
//===----------------------------------------------------------------------===//
class BufferBase;  // Needed for the ref_ptr<BufferBase> return types below.

// C++ base class for implementing an iree_hal_allocator_t.
//
// The class embeds the C struct and installs a vtable of static thunks that
// forward each C entry point to the matching virtual method.
class AllocatorBase : public ResourceBase<AllocatorBase> {
 public:
  virtual ~AllocatorBase() = default;

  // Returns the embedded C struct for this allocator.
  iree_hal_allocator_t* base() const noexcept { return &handle_.base; }

  virtual bool CheckBufferCompatibility(
      iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t allowed_usage,
      iree_hal_buffer_usage_t intended_usage) = 0;

  virtual StatusOr<ref_ptr<BufferBase>> AllocateBuffer(
      iree_hal_memory_type_t memory_type, iree_hal_buffer_usage_t buffer_usage,
      iree_host_size_t allocation_size) = 0;

  virtual StatusOr<ref_ptr<BufferBase>> WrapBuffer(
      iree_hal_memory_type_t memory_type,
      iree_hal_memory_access_t allowed_access,
      iree_hal_buffer_usage_t buffer_usage, absl::Span<uint8_t> data) = 0;

 protected:
  AllocatorBase() {
    static const iree_hal_allocator_vtable_t vtable = {
        DestroyThunk,
        CheckBufferCompatibilityThunk,
        AllocateBufferThunk,
        WrapBufferThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to AllocatorBase*. Keeping |base| as the first member of a
  // plain struct makes a pointer to |base| convertible to Handle*, from which
  // |self| recovers the owning object.
  // Assumes iree_hal_allocator_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_allocator_t base;
    AllocatorBase* self;
  };

  // Maps a C pointer (as returned by base()) back to its owning object.
  static AllocatorBase* Unwrap(iree_hal_allocator_t* allocator) noexcept {
    return reinterpret_cast<Handle*>(allocator)->self;
  }

  static void DestroyThunk(iree_hal_allocator_t* allocator) {
    delete Unwrap(allocator);
  }

  static bool CheckBufferCompatibilityThunk(
      iree_hal_allocator_t* allocator, iree_hal_memory_type_t memory_type,
      iree_hal_buffer_usage_t allowed_usage,
      iree_hal_buffer_usage_t intended_usage) {
    return Unwrap(allocator)->CheckBufferCompatibility(
        memory_type, allowed_usage, intended_usage);
  }

  static iree_status_t AllocateBufferThunk(iree_hal_allocator_t* allocator,
                                           iree_hal_memory_type_t memory_type,
                                           iree_hal_buffer_usage_t buffer_usage,
                                           iree_host_size_t allocation_size,
                                           iree_hal_buffer_t** out_buffer) {
    IREE_ASSIGN_OR_RETURN(auto buffer,
                          Unwrap(allocator)->AllocateBuffer(
                              memory_type, buffer_usage, allocation_size));
    // TODO(review): this hands out the C++ object pointer, not the embedded C
    // struct. The correct value is buffer->base(), but BufferBase is still
    // incomplete here; move this definition out of line, after BufferBase.
    *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(buffer.release());
    return iree_ok_status();
  }

  static iree_status_t WrapBufferThunk(iree_hal_allocator_t* allocator,
                                       iree_hal_memory_type_t memory_type,
                                       iree_hal_memory_access_t allowed_access,
                                       iree_hal_buffer_usage_t buffer_usage,
                                       iree_byte_span_t data,
                                       iree_hal_buffer_t** out_buffer) {
    // iree_byte_span_t is a C struct: its contents are fields, not methods.
    IREE_ASSIGN_OR_RETURN(
        auto buffer,
        Unwrap(allocator)->WrapBuffer(
            memory_type, allowed_access, buffer_usage,
            absl::MakeSpan(data.data, data.data_length)));
    // TODO(review): same object-pointer-vs-base() issue as AllocateBufferThunk.
    *out_buffer = reinterpret_cast<iree_hal_buffer_t*>(buffer.release());
    return iree_ok_status();
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_buffer_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_buffer_t.
//
// The class embeds the C struct, exposes read accessors for its fields, and
// installs a vtable of static thunks that forward C entry points to the
// protected *Impl virtual methods.
class BufferBase : public ResourceBase<BufferBase> {
 public:
  virtual ~BufferBase() = default;

  // Returns the embedded C struct for this buffer.
  iree_hal_buffer_t* base() const noexcept { return &handle_.base; }

  // Field accessors. These read straight from the embedded C struct. They are
  // not constexpr: base() is not constexpr, so they could never be evaluated
  // in a constant expression.
  iree_hal_allocator_t* allocator() const { return base()->allocator; }
  iree_hal_memory_type_t memory_type() const {
    return static_cast<iree_hal_memory_type_t>(base()->memory_type);
  }
  iree_hal_memory_access_t allowed_access() const {
    return static_cast<iree_hal_memory_access_t>(base()->allowed_access);
  }
  iree_hal_buffer_usage_t usage() const {
    return static_cast<iree_hal_buffer_usage_t>(base()->usage);
  }
  iree_hal_buffer_t* allocated_buffer() const noexcept {
    return base()->allocated_buffer;
  }
  iree_device_size_t allocation_size() const {
    return base()->allocation_size;
  }
  iree_device_size_t byte_offset() const noexcept {
    return base()->byte_offset;
  }
  iree_device_size_t byte_length() const noexcept {
    return base()->byte_length;
  }

 protected:
  BufferBase() {
    static const iree_hal_buffer_vtable_t vtable = {
        DestroyThunk,
        FillThunk,
        ReadDataThunk,
        WriteDataThunk,
        CopyDataThunk,
        MapThunk,
        UnmapThunk,
        InvalidateMappedMemoryThunk,
        FlushMappedMemoryThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

  virtual Status FillImpl(iree_device_size_t byte_offset,
                          iree_device_size_t byte_length, const void* pattern,
                          iree_device_size_t pattern_length) = 0;
  virtual Status ReadDataImpl(iree_device_size_t source_offset,
                              void* target_data,
                              iree_device_size_t data_length) = 0;
  virtual Status WriteDataImpl(iree_device_size_t target_offset,
                               const void* source_data,
                               iree_device_size_t data_length) = 0;
  virtual Status CopyDataImpl(iree_device_size_t target_offset,
                              iree_hal_buffer_t* source_buffer,
                              iree_device_size_t source_offset,
                              iree_device_size_t data_length) = 0;
  virtual Status MapMemoryImpl(MappingMode mapping_mode,
                               iree_hal_memory_access_t memory_access,
                               iree_device_size_t local_byte_offset,
                               iree_device_size_t local_byte_length,
                               void** out_data) = 0;
  virtual Status UnmapMemoryImpl(iree_device_size_t local_byte_offset,
                                 iree_device_size_t local_byte_length,
                                 void* data) = 0;
  virtual Status InvalidateMappedMemoryImpl(
      iree_device_size_t local_byte_offset,
      iree_device_size_t local_byte_length) = 0;
  virtual Status FlushMappedMemoryImpl(
      iree_device_size_t local_byte_offset,
      iree_device_size_t local_byte_length) = 0;

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to BufferBase*. Keeping |base| as the first member of a plain
  // struct makes a pointer to |base| convertible to Handle*, from which |self|
  // recovers the owning object.
  // Assumes iree_hal_buffer_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_buffer_t base;
    BufferBase* self;
  };

  // Maps a C pointer (as returned by base()) back to its owning object.
  static BufferBase* Unwrap(iree_hal_buffer_t* buffer) noexcept {
    return reinterpret_cast<Handle*>(buffer)->self;
  }

  static void DestroyThunk(iree_hal_buffer_t* buffer) { delete Unwrap(buffer); }

  static iree_status_t FillThunk(iree_hal_buffer_t* buffer,
                                 iree_device_size_t byte_offset,
                                 iree_device_size_t byte_length,
                                 const void* pattern,
                                 iree_host_size_t pattern_length) {
    return Unwrap(buffer)->FillImpl(byte_offset, byte_length, pattern,
                                    pattern_length);
  }

  static iree_status_t ReadDataThunk(iree_hal_buffer_t* buffer,
                                     iree_device_size_t source_offset,
                                     void* target_buffer,
                                     iree_device_size_t data_length) {
    return Unwrap(buffer)->ReadDataImpl(source_offset, target_buffer,
                                        data_length);
  }

  static iree_status_t WriteDataThunk(iree_hal_buffer_t* buffer,
                                      iree_device_size_t target_offset,
                                      const void* source_buffer,
                                      iree_device_size_t data_length) {
    return Unwrap(buffer)->WriteDataImpl(target_offset, source_buffer,
                                         data_length);
  }

  static iree_status_t CopyDataThunk(iree_hal_buffer_t* source_buffer,
                                     iree_device_size_t source_offset,
                                     iree_hal_buffer_t* target_buffer,
                                     iree_device_size_t target_offset,
                                     iree_device_size_t data_length) {
    // CopyDataImpl takes the source as an argument, so the receiver is the
    // target buffer.
    return Unwrap(target_buffer)
        ->CopyDataImpl(target_offset, source_buffer, source_offset,
                       data_length);
  }

  // The mapping thunks below are not wired up to the *Impl methods yet. They
  // return an explicit error instead of flowing off the end of a non-void
  // function, which is undefined behavior.
  // DO NOT SUBMIT: implement before landing.
  static iree_status_t MapThunk(iree_hal_buffer_t* buffer,
                                iree_hal_memory_access_t memory_access,
                                iree_device_size_t byte_offset,
                                iree_device_size_t byte_length,
                                iree_hal_buffer_mapping_t* out_mapped_memory) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "BufferBase::MapThunk is not implemented");
  }

  static iree_status_t UnmapThunk(iree_hal_buffer_t* buffer,
                                  iree_hal_buffer_mapping_t* mapped_memory) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "BufferBase::UnmapThunk is not implemented");
  }

  static iree_status_t InvalidateMappedMemoryThunk(
      iree_hal_buffer_mapping_t* mapped_memory,
      iree_device_size_t local_byte_offset,
      iree_device_size_t local_byte_length) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "BufferBase::InvalidateMappedMemoryThunk is not implemented");
  }

  static iree_status_t FlushMappedMemoryThunk(
      iree_hal_buffer_mapping_t* mapped_memory,
      iree_device_size_t local_byte_offset,
      iree_device_size_t local_byte_length) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "BufferBase::FlushMappedMemoryThunk is not implemented");
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_command_buffer_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_command_buffer_t.
//
// The class embeds the C struct and installs a vtable of static thunks that
// forward C entry points to the public virtual methods.
class CommandBufferBase : public ResourceBase<CommandBufferBase> {
 public:
  virtual ~CommandBufferBase() = default;

  // Returns the embedded C struct for this command buffer.
  iree_hal_command_buffer_t* base() const noexcept { return &handle_.base; }

  virtual Status Begin() = 0;
  virtual Status End() = 0;
  virtual Status ExecutionBarrier(
      iree_hal_execution_stage_t source_stage_mask,
      iree_hal_execution_stage_t target_stage_mask,
      absl::Span<const iree_hal_memory_barrier_t> memory_barriers,
      absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) = 0;
  virtual Status SignalEvent(iree_hal_event_t* event,
                             iree_hal_execution_stage_t source_stage_mask) = 0;
  virtual Status ResetEvent(iree_hal_event_t* event,
                            iree_hal_execution_stage_t source_stage_mask) = 0;
  virtual Status WaitEvents(
      absl::Span<iree_hal_event_t*> events,
      iree_hal_execution_stage_t source_stage_mask,
      iree_hal_execution_stage_t target_stage_mask,
      absl::Span<const iree_hal_memory_barrier_t> memory_barriers,
      absl::Span<const iree_hal_buffer_barrier_t> buffer_barriers) = 0;
  virtual Status FillBuffer(iree_hal_buffer_t* target_buffer,
                            iree_device_size_t target_offset,
                            iree_device_size_t length, const void* pattern,
                            size_t pattern_length) = 0;
  // NOTE(review): no thunk for DiscardBuffer is installed in the vtable below.
  virtual Status DiscardBuffer(iree_hal_buffer_t* buffer) = 0;
  virtual Status UpdateBuffer(const void* source_buffer,
                              iree_device_size_t source_offset,
                              iree_hal_buffer_t* target_buffer,
                              iree_device_size_t target_offset,
                              iree_device_size_t length) = 0;
  virtual Status CopyBuffer(iree_hal_buffer_t* source_buffer,
                            iree_device_size_t source_offset,
                            iree_hal_buffer_t* target_buffer,
                            iree_device_size_t target_offset,
                            iree_device_size_t length) = 0;
  virtual Status PushConstants(iree_hal_executable_layout_t* executable_layout,
                               size_t offset,
                               absl::Span<const uint32_t> values) = 0;
  virtual Status PushDescriptorSet(
      iree_hal_executable_layout_t* executable_layout, int32_t set,
      absl::Span<const iree_hal_descriptor_set_binding_t> bindings) = 0;
  virtual Status BindDescriptorSet(
      iree_hal_executable_layout_t* executable_layout, int32_t set,
      iree_hal_descriptor_set_t* descriptor_set,
      absl::Span<const iree_device_size_t> dynamic_offsets) = 0;
  virtual Status Dispatch(iree_hal_executable_t* executable,
                          int32_t entry_point,
                          std::array<uint32_t, 3> workgroups) = 0;
  virtual Status DispatchIndirect(iree_hal_executable_t* executable,
                                  int32_t entry_point,
                                  iree_hal_buffer_t* workgroups_buffer,
                                  iree_device_size_t workgroups_offset) = 0;

 protected:
  CommandBufferBase() {
    static const iree_hal_command_buffer_vtable_t vtable = {
        DestroyThunk,           BeginThunk,         EndThunk,
        ExecutionBarrierThunk,  SignalEventThunk,   ResetEventThunk,
        WaitEventsThunk,        FillBufferThunk,    UpdateBufferThunk,
        CopyBufferThunk,        PushConstantsThunk, PushDescriptorSetThunk,
        BindDescriptorSetThunk, DispatchThunk,      DispatchIndirectThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to CommandBufferBase*. Keeping |base| as the first member of
  // a plain struct makes a pointer to |base| convertible to Handle*, from
  // which |self| recovers the owning object.
  // Assumes iree_hal_command_buffer_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_command_buffer_t base;
    CommandBufferBase* self;
  };

  // Maps a C pointer (as returned by base()) back to its owning object.
  static CommandBufferBase* Unwrap(
      iree_hal_command_buffer_t* command_buffer) noexcept {
    return reinterpret_cast<Handle*>(command_buffer)->self;
  }

  static void DestroyThunk(iree_hal_command_buffer_t* command_buffer) {
    delete Unwrap(command_buffer);
  }

  // The thunks below previously had empty bodies, i.e. flowed off the end of
  // a non-void function. Each one now forwards to the virtual method with the
  // same name, repackaging (count, pointer) pairs as spans.

  static iree_status_t BeginThunk(iree_hal_command_buffer_t* command_buffer) {
    return Unwrap(command_buffer)->Begin();
  }

  static iree_status_t EndThunk(iree_hal_command_buffer_t* command_buffer) {
    return Unwrap(command_buffer)->End();
  }

  static iree_status_t ExecutionBarrierThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_execution_stage_t source_stage_mask,
      iree_hal_execution_stage_t target_stage_mask,
      iree_host_size_t memory_barrier_count,
      const iree_hal_memory_barrier_t* memory_barriers,
      iree_host_size_t buffer_barrier_count,
      const iree_hal_buffer_barrier_t* buffer_barriers) {
    return Unwrap(command_buffer)
        ->ExecutionBarrier(source_stage_mask, target_stage_mask,
                           absl::Span<const iree_hal_memory_barrier_t>(
                               memory_barriers, memory_barrier_count),
                           absl::Span<const iree_hal_buffer_barrier_t>(
                               buffer_barriers, buffer_barrier_count));
  }

  static iree_status_t SignalEventThunk(
      iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
      iree_hal_execution_stage_t source_stage_mask) {
    return Unwrap(command_buffer)->SignalEvent(event, source_stage_mask);
  }

  static iree_status_t ResetEventThunk(
      iree_hal_command_buffer_t* command_buffer, iree_hal_event_t* event,
      iree_hal_execution_stage_t source_stage_mask) {
    return Unwrap(command_buffer)->ResetEvent(event, source_stage_mask);
  }

  // NOTE(review): unlike every other thunk this signature carries no
  // command_buffer argument, so there is no object to forward to, and |events|
  // does not match the span of event pointers WaitEvents() takes. Fix the
  // signature against the C vtable declaration, then forward to WaitEvents().
  // DO NOT SUBMIT: implement before landing.
  static iree_status_t WaitEventsThunk(
      iree_host_size_t event_count, const iree_hal_event_t* events,
      iree_hal_execution_stage_t source_stage_mask,
      iree_hal_execution_stage_t target_stage_mask,
      iree_host_size_t memory_barrier_count,
      const iree_hal_memory_barrier_t* memory_barriers,
      iree_host_size_t buffer_barrier_count,
      const iree_hal_buffer_barrier_t* buffer_barriers) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "CommandBufferBase::WaitEventsThunk is not implemented");
  }

  static iree_status_t FillBufferThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
      iree_device_size_t length, const void* pattern,
      iree_host_size_t pattern_length) {
    return Unwrap(command_buffer)
        ->FillBuffer(target_buffer, target_offset, length, pattern,
                     pattern_length);
  }

  static iree_status_t UpdateBufferThunk(
      iree_hal_command_buffer_t* command_buffer, const void* source_buffer,
      iree_host_size_t source_offset, iree_hal_buffer_t* target_buffer,
      iree_device_size_t target_offset, iree_device_size_t length) {
    return Unwrap(command_buffer)
        ->UpdateBuffer(source_buffer, source_offset, target_buffer,
                       target_offset, length);
  }

  static iree_status_t CopyBufferThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_buffer_t* source_buffer, iree_device_size_t source_offset,
      iree_hal_buffer_t* target_buffer, iree_device_size_t target_offset,
      iree_device_size_t length) {
    return Unwrap(command_buffer)
        ->CopyBuffer(source_buffer, source_offset, target_buffer,
                     target_offset, length);
  }

  // NOTE(review): |values| is untyped and it is not visible here whether
  // |values_length| counts bytes or 32-bit elements, so the conversion to the
  // span PushConstants() takes is left open rather than guessed.
  // DO NOT SUBMIT: implement before landing.
  static iree_status_t PushConstantsThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_executable_layout_t* executable_layout, iree_host_size_t offset,
      const void* values, iree_host_size_t values_length) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "CommandBufferBase::PushConstantsThunk is not implemented");
  }

  static iree_status_t PushDescriptorSetThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_executable_layout_t* executable_layout, int32_t set,
      iree_host_size_t binding_count,
      const iree_hal_descriptor_set_binding_t* bindings) {
    return Unwrap(command_buffer)
        ->PushDescriptorSet(executable_layout, set,
                            absl::Span<const iree_hal_descriptor_set_binding_t>(
                                bindings, binding_count));
  }

  static iree_status_t BindDescriptorSetThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_executable_layout_t* executable_layout, int32_t set,
      iree_hal_descriptor_set_t* descriptor_set,
      iree_host_size_t dynamic_offset_count,
      const iree_device_size_t* dynamic_offsets) {
    return Unwrap(command_buffer)
        ->BindDescriptorSet(executable_layout, set, descriptor_set,
                            absl::Span<const iree_device_size_t>(
                                dynamic_offsets, dynamic_offset_count));
  }

  static iree_status_t DispatchThunk(iree_hal_command_buffer_t* command_buffer,
                                     iree_hal_executable_t* executable,
                                     int32_t entry_point, uint32_t workgroup_x,
                                     uint32_t workgroup_y,
                                     uint32_t workgroup_z) {
    return Unwrap(command_buffer)
        ->Dispatch(executable, entry_point,
                   std::array<uint32_t, 3>{
                       {workgroup_x, workgroup_y, workgroup_z}});
  }

  static iree_status_t DispatchIndirectThunk(
      iree_hal_command_buffer_t* command_buffer,
      iree_hal_executable_t* executable, int32_t entry_point,
      iree_hal_buffer_t* workgroups_buffer,
      iree_device_size_t workgroups_offset) {
    return Unwrap(command_buffer)
        ->DispatchIndirect(executable, entry_point, workgroups_buffer,
                           workgroups_offset);
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_descriptor_set_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_descriptor_set_t.
//
// The class embeds the C struct and installs a vtable whose only entry
// destroys the C++ object.
class DescriptorSetBase : public ResourceBase<DescriptorSetBase> {
 public:
  virtual ~DescriptorSetBase() = default;

  // Returns the embedded C struct for this descriptor set.
  iree_hal_descriptor_set_t* base() const noexcept { return &handle_.base; }

 protected:
  DescriptorSetBase() {
    static const iree_hal_descriptor_set_vtable_t vtable = {
        DestroyThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to DescriptorSetBase*. Keeping |base| as the first member of
  // a plain struct makes a pointer to |base| convertible to Handle*, from
  // which |self| recovers the owning object.
  // Assumes iree_hal_descriptor_set_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_descriptor_set_t base;
    DescriptorSetBase* self;
  };

  static void DestroyThunk(iree_hal_descriptor_set_t* descriptor_set) {
    delete reinterpret_cast<Handle*>(descriptor_set)->self;
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_descriptor_set_layout_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_descriptor_set_layout_t.
//
// The class embeds the C struct and installs a vtable whose only entry
// destroys the C++ object.
class DescriptorSetLayoutBase : public ResourceBase<DescriptorSetLayoutBase> {
 public:
  virtual ~DescriptorSetLayoutBase() = default;

  // Returns the embedded C struct for this descriptor set layout.
  iree_hal_descriptor_set_layout_t* base() const noexcept {
    return &handle_.base;
  }

 protected:
  // The constructor must carry this class's name; it was previously spelled
  // DescriptorSetBase, which is not a valid declaration here.
  DescriptorSetLayoutBase() {
    static const iree_hal_descriptor_set_layout_vtable_t vtable = {
        DestroyThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to DescriptorSetLayoutBase*. Keeping |base| as the first
  // member of a plain struct makes a pointer to |base| convertible to Handle*,
  // from which |self| recovers the owning object.
  // Assumes iree_hal_descriptor_set_layout_t is a plain (standard-layout) C
  // struct.
  struct Handle {
    iree_hal_descriptor_set_layout_t base;
    DescriptorSetLayoutBase* self;
  };

  static void DestroyThunk(
      iree_hal_descriptor_set_layout_t* descriptor_set_layout) {
    delete reinterpret_cast<Handle*>(descriptor_set_layout)->self;
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_device_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_device_t.
//
// The class embeds the C struct, records the device identifier in it, and
// installs a vtable of static thunks. Apart from DestroyThunk the thunks are
// only declared here; no definitions are present in this header.
class DeviceBase : public ResourceBase<DeviceBase> {
 public:
  virtual ~DeviceBase() = default;

  // Returns the embedded C struct for this device.
  iree_hal_device_t* base() const noexcept { return &handle_.base; }

 protected:
  // |id| is copied: the C struct only stores a view, and that view must not
  // outlive the characters it points at.
  explicit DeviceBase(absl::string_view id) : id_(id) {
    static const iree_hal_device_vtable_t vtable = {
        DestroyThunk,
        CreateBufferThunk,
        CreateCommandBufferThunk,
        CreateDescriptorSetThunk,
        CreateDescriptorSetLayoutThunk,
        CreateEventThunk,
        CreateExecutableCacheThunk,
        CreateExecutableLayoutThunk,
        CreateSemaphoreThunk,
        QueueSubmitThunk,
        WaitSemaphoresWithDeadlineThunk,
        WaitSemaphoresWithTimeoutThunk,
        WaitIdleWithDeadlineThunk,
        WaitIdleWithTimeoutThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.base.id = iree_make_string_view(id_.data(), id_.size());
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to DeviceBase*. Keeping |base| as the first member of a plain
  // struct makes a pointer to |base| convertible to Handle*, from which |self|
  // recovers the owning object.
  // Assumes iree_hal_device_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_device_t base;
    DeviceBase* self;
  };

  static void DestroyThunk(iree_hal_device_t* device) {
    delete reinterpret_cast<Handle*>(device)->self;
  }

  // NOTE(review): this signature is identical to CreateCommandBufferThunk
  // (including |out_command_buffer|); verify it against the C vtable.
  static iree_status_t CreateBufferThunk(
      iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
      iree_hal_command_category_t command_categories,
      iree_allocator_t allocator,
      iree_hal_command_buffer_t** out_command_buffer);
  static iree_status_t CreateCommandBufferThunk(
      iree_hal_device_t* device, iree_hal_command_buffer_mode_t mode,
      iree_hal_command_category_t command_categories,
      iree_allocator_t allocator,
      iree_hal_command_buffer_t** out_command_buffer);
  static iree_status_t CreateDescriptorSetThunk(
      iree_hal_device_t* device, iree_hal_descriptor_set_layout_t* set_layout,
      iree_host_size_t binding_count,
      const iree_hal_descriptor_set_binding_t* bindings,
      iree_allocator_t allocator,
      iree_hal_descriptor_set_t** out_descriptor_set);
  static iree_status_t CreateDescriptorSetLayoutThunk(
      iree_hal_device_t* device,
      iree_hal_descriptor_set_layout_usage_type_t usage_type,
      iree_host_size_t binding_count,
      const iree_hal_descriptor_set_layout_binding_t* bindings,
      iree_allocator_t allocator,
      iree_hal_descriptor_set_layout_t** out_descriptor_set_layout);
  static iree_status_t CreateEventThunk(iree_hal_device_t* device,
                                        iree_allocator_t allocator,
                                        iree_hal_event_t** out_event);
  static iree_status_t CreateExecutableCacheThunk(
      iree_hal_device_t* device, iree_string_view_t identifier,
      iree_allocator_t allocator,
      iree_hal_executable_cache_t** out_executable_cache);
  static iree_status_t CreateExecutableLayoutThunk(
      iree_hal_device_t* device, iree_host_size_t set_layout_count,
      iree_hal_descriptor_set_layout_t** set_layouts,
      iree_host_size_t push_constants, iree_allocator_t allocator,
      iree_hal_executable_layout_t** out_executable_layout);
  static iree_status_t CreateSemaphoreThunk(
      iree_hal_device_t* device, uint64_t initial_value,
      iree_allocator_t allocator, iree_hal_semaphore_t** out_semaphore);
  static iree_status_t QueueSubmitThunk(
      iree_hal_device_t* device, iree_hal_command_category_t command_categories,
      uint64_t queue_affinity, iree_host_size_t batch_count,
      const iree_hal_submission_batch_t* batches);
  static iree_status_t WaitSemaphoresWithDeadlineThunk(
      iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode,
      const iree_hal_semaphore_list_t* semaphore_list, iree_time_t deadline_ns);
  static iree_status_t WaitSemaphoresWithTimeoutThunk(
      iree_hal_device_t* device, iree_hal_wait_mode_t wait_mode,
      const iree_hal_semaphore_list_t* semaphore_list,
      iree_duration_t timeout_ns);
  static iree_status_t WaitIdleWithDeadlineThunk(iree_hal_device_t* device,
                                                 iree_time_t deadline_ns);
  static iree_status_t WaitIdleWithTimeoutThunk(iree_hal_device_t* device,
                                                iree_duration_t timeout_ns);

  // Owning copy of the identifier that |handle_.base.id| points into.
  std::string id_;

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_driver_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_driver_t.
//
// The class embeds the C struct and installs a vtable of static thunks. Only
// destruction is functional so far; the remaining thunks are placeholders.
class DriverBase : public ResourceBase<DriverBase> {
 public:
  virtual ~DriverBase() = default;

  // Returns the embedded C struct for this driver.
  iree_hal_driver_t* base() const noexcept { return &handle_.base; }

 protected:
  DriverBase() {
    static const iree_hal_driver_vtable_t vtable = {
        DestroyThunk,
        QueryAvailableDevicesThunk,
        CreateDeviceThunk,
        CreateDefaultDeviceThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to DriverBase*. Keeping |base| as the first member of a plain
  // struct makes a pointer to |base| convertible to Handle*, from which |self|
  // recovers the owning object.
  // Assumes iree_hal_driver_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_driver_t base;
    DriverBase* self;
  };

  static void DestroyThunk(iree_hal_driver_t* driver) {
    delete reinterpret_cast<Handle*>(driver)->self;
  }

  // The thunks below have no virtual method to forward to yet. They return an
  // explicit error instead of flowing off the end of a non-void function,
  // which is undefined behavior.
  // TODO: add the corresponding virtual methods and forward to them.
  static iree_status_t QueryAvailableDevicesThunk(
      iree_hal_driver_t* driver, iree_allocator_t allocator,
      iree_hal_device_info_t** out_device_infos,
      iree_host_size_t* out_device_info_count) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "DriverBase::QueryAvailableDevicesThunk is not implemented");
  }

  static iree_status_t CreateDeviceThunk(iree_hal_driver_t* driver,
                                         iree_hal_device_id_t device_id,
                                         iree_allocator_t allocator,
                                         iree_hal_device_t** out_device) {
    return iree_make_status(IREE_STATUS_UNIMPLEMENTED,
                            "DriverBase::CreateDeviceThunk is not implemented");
  }

  static iree_status_t CreateDefaultDeviceThunk(
      iree_hal_driver_t* driver, iree_allocator_t allocator,
      iree_hal_device_t** out_device) {
    return iree_make_status(
        IREE_STATUS_UNIMPLEMENTED,
        "DriverBase::CreateDefaultDeviceThunk is not implemented");
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_event_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_event_t.
//
// The class embeds the C struct and installs a vtable whose only entry
// destroys the C++ object.
class EventBase : public ResourceBase<EventBase> {
 public:
  virtual ~EventBase() = default;

  // Returns the embedded C struct for this event.
  iree_hal_event_t* base() const noexcept { return &handle_.base; }

 protected:
  EventBase() {
    static const iree_hal_event_vtable_t vtable = {
        DestroyThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to EventBase*. Keeping |base| as the first member of a plain
  // struct makes a pointer to |base| convertible to Handle*, from which |self|
  // recovers the owning object.
  // Assumes iree_hal_event_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_event_t base;
    EventBase* self;
  };

  static void DestroyThunk(iree_hal_event_t* event) {
    delete reinterpret_cast<Handle*>(event)->self;
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_executable_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_executable_t.
//
// The class embeds the C struct and installs a vtable whose only entry
// destroys the C++ object.
class ExecutableBase : public ResourceBase<ExecutableBase> {
 public:
  virtual ~ExecutableBase() = default;

  // Returns the embedded C struct for this executable.
  iree_hal_executable_t* base() const noexcept { return &handle_.base; }

 protected:
  ExecutableBase() {
    static const iree_hal_executable_vtable_t vtable = {
        DestroyThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to ExecutableBase*. Keeping |base| as the first member of a
  // plain struct makes a pointer to |base| convertible to Handle*, from which
  // |self| recovers the owning object.
  // Assumes iree_hal_executable_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_executable_t base;
    ExecutableBase* self;
  };

  static void DestroyThunk(iree_hal_executable_t* executable) {
    delete reinterpret_cast<Handle*>(executable)->self;
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_executable_cache_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_executable_cache_t.
//
// The class embeds the C struct and installs a vtable of static thunks that
// forward each C entry point to the matching virtual method.
class ExecutableCacheBase : public ResourceBase<ExecutableCacheBase> {
 public:
  virtual ~ExecutableCacheBase() = default;

  // Returns the embedded C struct for this executable cache.
  iree_hal_executable_cache_t* base() const noexcept { return &handle_.base; }

  virtual bool CanPrepareFormat(ExecutableFormat format) const = 0;

  virtual StatusOr<ref_ptr<ExecutableBase>> PrepareExecutable(
      iree_hal_executable_layout_t* executable_layout,
      iree_hal_executable_caching_mode_t caching_mode,
      absl::Span<const uint8_t> executable_data,
      iree_allocator_t allocator) = 0;

 protected:
  ExecutableCacheBase() {
    static const iree_hal_executable_cache_vtable_t vtable = {
        DestroyThunk,
        CanPrepareFormatThunk,
        PrepareExecutableThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to ExecutableCacheBase*. Keeping |base| as the first member
  // of a plain struct makes a pointer to |base| convertible to Handle*, from
  // which |self| recovers the owning object.
  // Assumes iree_hal_executable_cache_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_executable_cache_t base;
    ExecutableCacheBase* self;
  };

  // Maps a C pointer (as returned by base()) back to its owning object.
  static ExecutableCacheBase* Unwrap(
      iree_hal_executable_cache_t* executable_cache) noexcept {
    return reinterpret_cast<Handle*>(executable_cache)->self;
  }

  static void DestroyThunk(iree_hal_executable_cache_t* executable_cache) {
    delete Unwrap(executable_cache);
  }

  static bool CanPrepareFormatThunk(
      iree_hal_executable_cache_t* executable_cache,
      iree_hal_executable_format_t format) {
    return Unwrap(executable_cache)
        ->CanPrepareFormat(static_cast<ExecutableFormat>(format));
  }

  static iree_status_t PrepareExecutableThunk(
      iree_hal_executable_cache_t* executable_cache,
      iree_hal_executable_layout_t* executable_layout,
      iree_hal_executable_caching_mode_t caching_mode,
      iree_const_byte_span_t executable_data, iree_allocator_t allocator,
      iree_hal_executable_t** out_executable) {
    // PrepareExecutable() takes the C layout pointer as-is, so it is passed
    // through without any cast.
    IREE_ASSIGN_OR_RETURN(
        auto executable,
        Unwrap(executable_cache)
            ->PrepareExecutable(
                executable_layout, caching_mode,
                absl::MakeConstSpan(executable_data.data,
                                    executable_data.data_length),
                allocator));
    // Hand the embedded C struct to the caller, not the C++ object pointer.
    *out_executable = executable.release()->base();
    return iree_ok_status();
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_executable_layout_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_executable_layout_t.
//
// The class embeds the C struct and installs a vtable whose only entry
// destroys the C++ object.
class ExecutableLayoutBase : public ResourceBase<ExecutableLayoutBase> {
 public:
  virtual ~ExecutableLayoutBase() = default;

  // Returns the embedded C struct for this executable layout.
  iree_hal_executable_layout_t* base() const noexcept { return &handle_.base; }

 protected:
  ExecutableLayoutBase() {
    static const iree_hal_executable_layout_vtable_t vtable = {
        DestroyThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to ExecutableLayoutBase*. Keeping |base| as the first member
  // of a plain struct makes a pointer to |base| convertible to Handle*, from
  // which |self| recovers the owning object.
  // Assumes iree_hal_executable_layout_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_executable_layout_t base;
    ExecutableLayoutBase* self;
  };

  static void DestroyThunk(iree_hal_executable_layout_t* executable_layout) {
    delete reinterpret_cast<Handle*>(executable_layout)->self;
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
//===----------------------------------------------------------------------===//
// iree_hal_semaphore_t
//===----------------------------------------------------------------------===//
// C++ base class for implementing an iree_hal_semaphore_t.
//
// The class embeds the C struct and installs a vtable of static thunks that
// forward each C entry point to the matching virtual method.
class SemaphoreBase : public ResourceBase<SemaphoreBase> {
 public:
  virtual ~SemaphoreBase() = default;

  // Returns the embedded C struct for this semaphore.
  iree_hal_semaphore_t* base() const noexcept { return &handle_.base; }

  virtual StatusOr<uint64_t> Query() = 0;
  virtual Status Signal(uint64_t value) = 0;
  virtual void Fail(Status status) = 0;
  virtual Status Wait(uint64_t value, Time deadline_ns) = 0;
  virtual Status Wait(uint64_t value, Duration timeout_ns) = 0;

 protected:
  SemaphoreBase() {
    static const iree_hal_semaphore_vtable_t vtable = {
        DestroyThunk, QueryThunk,           SignalThunk,
        FailThunk,    WaitWithDeadlineThunk, WaitWithTimeoutThunk,
    };
    std::memset(&handle_.base, 0, sizeof(handle_.base));
    handle_.base.vtable = &vtable;
    iree_atomic_ref_count_init(&handle_.base.resource.ref_count);
    handle_.self = this;
  }

 private:
  // Storage handed to C code. This class is polymorphic, so the embedded C
  // struct is not at the start of the C++ object and a C pointer cannot simply
  // be cast back to SemaphoreBase*. Keeping |base| as the first member of a
  // plain struct makes a pointer to |base| convertible to Handle*, from which
  // |self| recovers the owning object.
  // Assumes iree_hal_semaphore_t is a plain (standard-layout) C struct.
  struct Handle {
    iree_hal_semaphore_t base;
    SemaphoreBase* self;
  };

  // Maps a C pointer (as returned by base()) back to its owning object.
  static SemaphoreBase* Unwrap(iree_hal_semaphore_t* semaphore) noexcept {
    return reinterpret_cast<Handle*>(semaphore)->self;
  }

  static void DestroyThunk(iree_hal_semaphore_t* semaphore) {
    delete Unwrap(semaphore);
  }

  static iree_status_t QueryThunk(iree_hal_semaphore_t* semaphore,
                                  uint64_t* out_value) {
    IREE_ASSIGN_OR_RETURN(uint64_t value, Unwrap(semaphore)->Query());
    *out_value = value;
    return iree_ok_status();
  }

  static iree_status_t SignalThunk(iree_hal_semaphore_t* semaphore,
                                   uint64_t new_value) {
    return Unwrap(semaphore)->Signal(new_value);
  }

  static void FailThunk(iree_hal_semaphore_t* semaphore, iree_status_t status) {
    Unwrap(semaphore)->Fail(Status(status));
  }

  // The two wait thunks select between the Wait() overloads by constructing
  // the Time/Duration argument explicitly; the class declares no
  // WaitWithDeadline/WaitWithTimeout methods.
  static iree_status_t WaitWithDeadlineThunk(iree_hal_semaphore_t* semaphore,
                                             uint64_t value,
                                             iree_time_t deadline_ns) {
    return Unwrap(semaphore)->Wait(value, Time(deadline_ns));
  }

  static iree_status_t WaitWithTimeoutThunk(iree_hal_semaphore_t* semaphore,
                                            uint64_t value,
                                            iree_duration_t timeout_ns) {
    return Unwrap(semaphore)->Wait(value, Duration(timeout_ns));
  }

  // Mutable so that the const base() accessor can hand out a non-const
  // pointer; the C API mutates the struct (e.g. its reference count).
  mutable Handle handle_;
};
}  // namespace hal
}  // namespace iree

// The __cplusplus check near the top of the file is closed where it is
// opened, so only the include guard remains to be closed here.
#endif  // IREE_HAL_API_INTERFACES_CC_H_
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment