Skip to content

Instantly share code, notes, and snippets.

@cgmb
Created March 31, 2023 11:01
Show Gist options
  • Save cgmb/9b2d4cb1db3244502bfb165779f5aa54 to your computer and use it in GitHub Desktop.
Save cgmb/9b2d4cb1db3244502bfb165779f5aa54 to your computer and use it in GitHub Desktop.
Example of C++-style error handling
/* ************************************************************************
* Copyright (c) 2023 Advanced Micro Devices, Inc.
* ************************************************************************ */
#include <new>
#include "rocsolver_rfinfo.hpp"
/// C entry point that allocates a new rocsolver_rfinfo_ and stores it in *rfinfo.
///
/// @param rfinfo  Output location for the newly created object; must not be null.
/// @param handle  rocBLAS handle forwarded to the rocsolver_rfinfo_ constructor;
///                must not be null.
/// @return rocblas_status_success on success;
///         rocblas_status_invalid_handle if handle is null;
///         rocblas_status_invalid_pointer if rfinfo is null;
///         rocblas_status_memory_error if allocation throws std::bad_alloc;
///         the thrown value if construction throws a rocblas_status;
///         rocblas_status_internal_error for any other exception;
///         rocblas_status_not_implemented when built without
///         ROCSOLVER_WITH_ROCSPARSE.
///
/// No exception escapes this function: every exception type is translated
/// into a status code before returning to the C caller. On failure *rfinfo is
/// left unmodified.
extern "C" rocblas_status rocsolver_create_rfinfo(rocsolver_rfinfo* rfinfo, rocblas_handle handle)
{
#ifdef ROCSOLVER_WITH_ROCSPARSE
    if(!handle)
        return rocblas_status_invalid_handle;
    if(rfinfo == nullptr)
        return rocblas_status_invalid_pointer;

    rocblas_status result = rocblas_status_success;
    try
    {
        // If the constructor throws, the new-expression releases the storage
        // itself, so nothing is leaked on the error paths below.
        *rfinfo = new rocsolver_rfinfo_(handle);
    }
    catch(const std::bad_alloc&)
    {
        result = rocblas_status_memory_error;
    }
    catch(rocblas_status thrown)
    {
        // A status code thrown during construction is passed through as-is.
        result = thrown;
    }
    catch(...)
    {
        result = rocblas_status_internal_error;
    }
    return result;
#else
    return rocblas_status_not_implemented;
#endif
}
/// C entry point that tears down and frees an object created by
/// rocsolver_create_rfinfo.
///
/// @param rfinfo  Object to destroy; must not be null.
/// @return rocblas_status_invalid_pointer if rfinfo is null; otherwise the
///         status returned by rfinfo->reset();
///         rocblas_status_not_implemented when built without
///         ROCSOLVER_WITH_ROCSPARSE.
///
/// The object is deleted regardless of what reset() reports, so the pointer
/// must not be used after this call returns.
extern "C" rocblas_status rocsolver_destroy_rfinfo(rocsolver_rfinfo rfinfo)
{
#ifdef ROCSOLVER_WITH_ROCSPARSE
    if(rfinfo == nullptr)
        return rocblas_status_invalid_pointer;

    // reset() runs first so that a teardown failure can be reported to the
    // caller; the status is captured before the object is freed.
    const rocblas_status teardown_status = rfinfo->reset();
    delete rfinfo;
    return teardown_status;
#else
    return rocblas_status_not_implemented;
#endif
}
/* ************************************************************************
* Copyright (c) 2023 Advanced Micro Devices, Inc.
* ************************************************************************ */
#pragma once
#include <memory>
#include <type_traits>
#include "rocblas.hpp"
#include "rocsolver/rocsolver.h"
#include "rocsparse.hpp"
/// Deleter functor for std::unique_ptr that releases a rocsparse_handle
/// through rocsparse_destroy_handle.
struct sparse_handle_destroyer
{
    void operator()(rocsparse_handle handle) const
    {
        // The result of the destroy call is deliberately discarded: a deleter
        // has no way to hand a status back to its caller.
        static_cast<void>(rocsparse_destroy_handle(handle));
    }
};
/// Deleter functor for std::unique_ptr that releases a rocsparse_mat_descr
/// through rocsparse_destroy_mat_descr.
struct sparse_mat_descr_destroyer
{
    void operator()(rocsparse_mat_descr descr) const
    {
        // The result of the destroy call is deliberately discarded: a deleter
        // has no way to hand a status back to its caller.
        static_cast<void>(rocsparse_destroy_mat_descr(descr));
    }
};
/// Deleter functor for std::unique_ptr that releases a rocsparse_mat_info
/// through rocsparse_destroy_mat_info.
struct sparse_mat_info_destroyer
{
    void operator()(rocsparse_mat_info info) const
    {
        // The result of the destroy call is deliberately discarded: a deleter
        // has no way to hand a status back to its caller.
        static_cast<void>(rocsparse_destroy_mat_info(info));
    }
};
// Owning smart-pointer aliases for the rocSPARSE object types used below.
// std::remove_pointer strips the pointer from each rocSPARSE type to obtain
// the element type std::unique_ptr expects, so that .get()/.release() yield
// the original rocsparse_* type. Each alias is paired with the deleter that
// calls the matching rocsparse_destroy_* function.
using unique_rocsparse_handle
    = std::unique_ptr<std::remove_pointer<rocsparse_handle>::type, sparse_handle_destroyer>;
using unique_rocsparse_mat_descr
    = std::unique_ptr<std::remove_pointer<rocsparse_mat_descr>::type, sparse_mat_descr_destroyer>;
using unique_rocsparse_mat_info
    = std::unique_ptr<std::remove_pointer<rocsparse_mat_info>::type, sparse_mat_info_destroyer>;
/// Creates a rocSPARSE handle and returns it wrapped in an owning
/// unique_rocsparse_handle.
///
/// Failure of rocsparse_create_handle is reported through
/// THROW_IF_ROCSPARSE_ERROR (defined outside this file; expected to throw on a
/// non-success status -- confirm against its definition).
///
/// Declared `inline` because this is a function definition in a header: without
/// it, every translation unit including the header would emit its own
/// external definition and the program would fail to link.
inline unique_rocsparse_handle make_rocsparse_handle()
{
    // Start from null so the variable never holds an indeterminate value.
    rocsparse_handle sphandle = nullptr;
    THROW_IF_ROCSPARSE_ERROR(rocsparse_create_handle(&sphandle));
    return unique_rocsparse_handle{sphandle};
}
/// Creates a rocSPARSE matrix descriptor and returns it wrapped in an owning
/// unique_rocsparse_mat_descr.
///
/// Failure of rocsparse_create_mat_descr is reported through
/// THROW_IF_ROCSPARSE_ERROR (defined outside this file; expected to throw on a
/// non-success status -- confirm against its definition).
///
/// Declared `inline` because this is a function definition in a header: without
/// it, every translation unit including the header would emit its own
/// external definition and the program would fail to link.
inline unique_rocsparse_mat_descr make_rocsparse_mat_descr()
{
    // Start from null so the variable never holds an indeterminate value.
    rocsparse_mat_descr descr = nullptr;
    THROW_IF_ROCSPARSE_ERROR(rocsparse_create_mat_descr(&descr));
    return unique_rocsparse_mat_descr{descr};
}
/// Creates a rocSPARSE matrix-info object and returns it wrapped in an owning
/// unique_rocsparse_mat_info.
///
/// Failure of rocsparse_create_mat_info is reported through
/// THROW_IF_ROCSPARSE_ERROR (defined outside this file; expected to throw on a
/// non-success status -- confirm against its definition).
///
/// Declared `inline` because this is a function definition in a header: without
/// it, every translation unit including the header would emit its own
/// external definition and the program would fail to link.
inline unique_rocsparse_mat_info make_rocsparse_mat_info()
{
    // Start from null so the variable never holds an indeterminate value.
    rocsparse_mat_info info = nullptr;
    THROW_IF_ROCSPARSE_ERROR(rocsparse_create_mat_info(&info));
    return unique_rocsparse_mat_info{info};
}
/// Bundle of rocSPARSE resources -- one handle, three matrix descriptors
/// (L, U, T) and three matrix-info objects (L, U, T) -- together with the
/// solve and analysis policies.
///
/// Every resource is held in a unique_ptr with a rocSPARSE-specific deleter,
/// so anything still owned when the object is destroyed (including after a
/// constructor that throws part-way through) is released automatically.
struct rocsolver_rfinfo_
{
    unique_rocsparse_handle sphandle;
    unique_rocsparse_mat_descr descrL;
    unique_rocsparse_mat_descr descrU;
    unique_rocsparse_mat_descr descrT;
    unique_rocsparse_mat_info infoL;
    unique_rocsparse_mat_info infoU;
    unique_rocsparse_mat_info infoT;

    // Default member initializers keep the policies well-defined from the
    // moment the object exists; init_policy() restores these same values.
    rocsparse_solve_policy solve_policy = rocsparse_solve_policy_auto;
    rocsparse_analysis_policy analysis_policy = rocsparse_analysis_policy_reuse;

    /// Creates all rocSPARSE resources and configures the descriptors.
    ///
    /// @param handle  rocBLAS handle whose stream is applied to the newly
    ///                created rocSPARSE handle.
    ///
    /// Errors are reported through THROW_IF_ROCBLAS_ERROR and
    /// THROW_IF_ROCSPARSE_ERROR (both defined outside this file; expected to
    /// throw on a non-success status -- confirm against their definitions).
    ///
    /// Marked `explicit` so that a rocblas_handle is never silently converted
    /// into a rocsolver_rfinfo_.
    explicit rocsolver_rfinfo_(rocblas_handle handle)
    {
        // create sparse handle
        sphandle = make_rocsparse_handle();

        // make the rocSPARSE handle use the same stream as the rocBLAS handle
        hipStream_t stream;
        THROW_IF_ROCBLAS_ERROR(rocblas_get_stream(handle, &stream));
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_stream(sphandle.get(), stream));

        // L descriptor: general type, zero-based, lower fill, unit diagonal
        descrL = make_rocsparse_mat_descr();
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_type(descrL.get(), rocsparse_matrix_type_general));
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_index_base(descrL.get(), rocsparse_index_base_zero));
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_fill_mode(descrL.get(), rocsparse_fill_mode_lower));
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_diag_type(descrL.get(), rocsparse_diag_type_unit));

        // U descriptor: general type, zero-based, upper fill, non-unit diagonal
        descrU = make_rocsparse_mat_descr();
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_type(descrU.get(), rocsparse_matrix_type_general));
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_index_base(descrU.get(), rocsparse_index_base_zero));
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_fill_mode(descrU.get(), rocsparse_fill_mode_upper));
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_diag_type(descrU.get(), rocsparse_diag_type_non_unit));

        // T descriptor: general type, zero-based; fill mode and diagonal type
        // are left at whatever rocsparse_create_mat_descr sets
        descrT = make_rocsparse_mat_descr();
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_type(descrT.get(), rocsparse_matrix_type_general));
        THROW_IF_ROCSPARSE_ERROR(rocsparse_set_mat_index_base(descrT.get(), rocsparse_index_base_zero));

        // create info holders
        infoL = make_rocsparse_mat_info();
        infoU = make_rocsparse_mat_info();
        infoT = make_rocsparse_mat_info();
    }

    /// Acts like a destructor but can return a status.
    ///
    /// Each resource is released from its unique_ptr and passed to the
    /// matching rocsparse_destroy_* call, so it will not be destroyed a second
    /// time by the member's deleter. On full success the policies are restored
    /// to their defaults and rocblas_status_success is returned.
    ///
    /// NOTE(review): ROCSPARSE_CHECK is defined outside this file; it
    /// presumably returns early with a converted status when a call fails --
    /// confirm. In that case the members not yet released stay owned by their
    /// unique_ptrs and are cleaned up when the object is destroyed.
    rocblas_status reset()
    {
        ROCSPARSE_CHECK(rocsparse_destroy_handle(sphandle.release()));
        ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descrL.release()));
        ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descrU.release()));
        ROCSPARSE_CHECK(rocsparse_destroy_mat_descr(descrT.release()));
        ROCSPARSE_CHECK(rocsparse_destroy_mat_info(infoL.release()));
        ROCSPARSE_CHECK(rocsparse_destroy_mat_info(infoU.release()));
        ROCSPARSE_CHECK(rocsparse_destroy_mat_info(infoT.release()));
        init_policy();
        return rocblas_status_success;
    }

    /// Restores solve_policy and analysis_policy to their default values.
    void init_policy()
    {
        solve_policy = rocsparse_solve_policy_auto;
        analysis_policy = rocsparse_analysis_policy_reuse;
    }
};
// Handle type used by the C entry points: a plain pointer to rocsolver_rfinfo_.
using rocsolver_rfinfo = struct rocsolver_rfinfo_*;
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment