Skip to content

Instantly share code, notes, and snippets.

@philipturner
Created March 15, 2024 20:50
Show Gist options
  • Save philipturner/87c3584f91995081da88a430e651a1d7 to your computer and use it in GitHub Desktop.
Save philipturner/87c3584f91995081da88a430e651a1d7 to your computer and use it in GitHub Desktop.
/* -------------------------------------------------------------------------- *
* OpenMM *
* -------------------------------------------------------------------------- *
* This is part of the OpenMM molecular simulation toolkit originating from *
* Simbios, the NIH National Center for Physics-Based Simulation of *
* Biological Structures at Stanford, funded under the NIH Roadmap for *
* Medical Research, grant U54 GM072970. See https://simtk.org. *
* *
* Portions copyright (c) 2015-2021 Stanford University and the Authors. *
* Portions copyright (c) 2021 Advanced Micro Devices, Inc. All Rights *
* Reserved. *
* Authors: Peter Eastman *
* Contributors: *
* *
* Permission is hereby granted, free of charge, to any person obtaining a *
* copy of this software and associated documentation files (the "Software"), *
* to deal in the Software without restriction, including without limitation *
* the rights to use, copy, modify, merge, publish, distribute, sublicense, *
* and/or sell copies of the Software, and to permit persons to whom the *
* Software is furnished to do so, subject to the following conditions: *
* *
* The above copyright notice and this permission notice shall be included in *
* all copies or substantial portions of the Software. *
* *
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR *
* IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, *
* FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL *
* THE AUTHORS, CONTRIBUTORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, *
* DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR *
* OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE *
* USE OR OTHER DEALINGS IN THE SOFTWARE. *
* -------------------------------------------------------------------------- */
#include "HipCompilerKernels.h"
#include "openmm/OpenMMException.h"
#include <sstream>
#include <hip/hip_common.h>
#include <stdlib.h>
#include <Windows.h>
#include <iostream>
#include <string>
/**
*
* @addtogroup GlobalDefs
* @{
*
*/
/**
* hiprtc error code
*/
enum hiprtcResult {
    /// Success
    HIPRTC_SUCCESS = 0,
    /// Out of memory
    HIPRTC_ERROR_OUT_OF_MEMORY = 1,
    /// Failed to create program
    HIPRTC_ERROR_PROGRAM_CREATION_FAILURE = 2,
    /// Invalid input
    HIPRTC_ERROR_INVALID_INPUT = 3,
    /// Invalid program
    HIPRTC_ERROR_INVALID_PROGRAM = 4,
    /// Invalid option
    HIPRTC_ERROR_INVALID_OPTION = 5,
    /// Compilation error
    HIPRTC_ERROR_COMPILATION = 6,
    /// Failed in builtin operation
    HIPRTC_ERROR_BUILTIN_OPERATION_FAILURE = 7,
    /// No name expression after compilation
    HIPRTC_ERROR_NO_NAME_EXPRESSIONS_AFTER_COMPILATION = 8,
    /// No lowered names before compilation
    HIPRTC_ERROR_NO_LOWERED_NAMES_BEFORE_COMPILATION = 9,
    /// Invalid name expression
    HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID = 10,
    /// Internal error
    HIPRTC_ERROR_INTERNAL_ERROR = 11,
    /// Error in linking
    HIPRTC_ERROR_LINKING = 100
};
/**
* hiprtc JIT option
*/
enum hiprtcJIT_option {
    /// Maximum registers
    HIPRTC_JIT_MAX_REGISTERS = 0,
    /// Threads per block
    HIPRTC_JIT_THREADS_PER_BLOCK = 1,
    /// Wall clock time
    HIPRTC_JIT_WALL_TIME = 2,
    /// Info log buffer
    HIPRTC_JIT_INFO_LOG_BUFFER = 3,
    /// Info log buffer size in bytes
    HIPRTC_JIT_INFO_LOG_BUFFER_SIZE_BYTES = 4,
    /// Error log buffer
    HIPRTC_JIT_ERROR_LOG_BUFFER = 5,
    /// Error log buffer size in bytes
    HIPRTC_JIT_ERROR_LOG_BUFFER_SIZE_BYTES = 6,
    /// Optimization level
    HIPRTC_JIT_OPTIMIZATION_LEVEL = 7,
    /// Undocumented in the original; presumably "take the target from the HIP context" - TODO confirm
    HIPRTC_JIT_TARGET_FROM_HIPCONTEXT = 8,
    /// JIT target
    HIPRTC_JIT_TARGET = 9,
    /// Fallback strategy
    HIPRTC_JIT_FALLBACK_STRATEGY = 10,
    /// Generate debug information
    HIPRTC_JIT_GENERATE_DEBUG_INFO = 11,
    /// Log verbose
    HIPRTC_JIT_LOG_VERBOSE = 12,
    /// Generate line information
    HIPRTC_JIT_GENERATE_LINE_INFO = 13,
    /// Cache mode
    HIPRTC_JIT_CACHE_MODE = 14,
    /// New SM3X option
    HIPRTC_JIT_NEW_SM3X_OPT = 15,
    /// Fast compile
    HIPRTC_JIT_FAST_COMPILE = 16,
    /// Global symbol names
    HIPRTC_JIT_GLOBAL_SYMBOL_NAMES = 17,
    /// Global symbol address
    HIPRTC_JIT_GLOBAL_SYMBOL_ADDRESS = 18,
    /// Global symbol count
    HIPRTC_JIT_GLOBAL_SYMBOL_COUNT = 19,
    /// LTO
    HIPRTC_JIT_LTO = 20,
    /// FTZ
    HIPRTC_JIT_FTZ = 21,
    /// PREC_DIV
    HIPRTC_JIT_PREC_DIV = 22,
    /// PREC_SQRT
    HIPRTC_JIT_PREC_SQRT = 23,
    /// FMA
    HIPRTC_JIT_FMA = 24,
    /// Number of options
    HIPRTC_JIT_NUM_OPTIONS = 25,
    /// AMD only. Linker options to be passed on.
    HIPRTC_JIT_IR_TO_ISA_OPT_EXT = 10000,
    /// AMD only. Count of linker options
    HIPRTC_JIT_IR_TO_ISA_OPT_COUNT_EXT = 10001
};
/**
* hiprtc JIT input type
*/
enum hiprtcJITInputType {
    /// Input cubin
    HIPRTC_JIT_INPUT_CUBIN = 0,
    /// Input PTX
    HIPRTC_JIT_INPUT_PTX = 1,
    /// Input fat binary
    HIPRTC_JIT_INPUT_FATBINARY = 2,
    /// Input object
    HIPRTC_JIT_INPUT_OBJECT = 3,
    /// Input library
    HIPRTC_JIT_INPUT_LIBRARY = 4,
    /// Input NVVM
    HIPRTC_JIT_INPUT_NVVM = 5,
    /// Number of legacy input types
    HIPRTC_JIT_NUM_LEGACY_INPUT_TYPES = 6,
    /// LLVM bitcode
    HIPRTC_JIT_INPUT_LLVM_BITCODE = 100,
    /// LLVM bundled bitcode
    HIPRTC_JIT_INPUT_LLVM_BUNDLED_BITCODE = 101,
    /// LLVM archives of bundled bitcode
    HIPRTC_JIT_INPUT_LLVM_ARCHIVES_OF_BUNDLED_BITCODE = 102,
    /// Legacy input types plus the three LLVM input types above
    HIPRTC_JIT_NUM_INPUT_TYPES = (HIPRTC_JIT_NUM_LEGACY_INPUT_TYPES + 3)
};
/**
* @}
*/
// Utility function for loading symbols at runtime.
// TODO: Save this to the USB drive and upload to a GitHub gist, if this works.
void* loadHiprtcSymbol(const char* symbolName) {
const char* libraryPath = "C:/Program Files/AMD/ROCm/5.7/lib/hiprtc.lib";
HMODULE loadedLibrary = LoadLibraryA(libraryPath);
if (loadedLibrary) {
std::cout << "Successfully loaded 'hiprtc' dylib." << std::endl;
}
else {
std::cout << "Could not load 'hiprtc' dylib." << std::endl;
}
FARPROC loadedSymbol = GetProcAddress(loadedLibrary, symbolName);
if (loadedSymbol) {
std::cout << "Successfully loaded '" << std::string(symbolName) << "' symbol." << std::endl;
}
else {
std::cout << "Could not load '" << std::string(symbolName) << "' symbol." << std::endl;
}
return (void*)loadedSymbol;
}
/**
* hiprtc link state
*
*/
// Opaque handle: ihiprtcLinkState is only ever used through a pointer here.
struct ihiprtcLinkState;
using hiprtcLinkState = ihiprtcLinkState*;
/**
* @ingroup Runtime
*
* @brief Returns text string message to explain the error which occurred
*
* @param [in] result code to convert to string.
* @returns const char pointer to the NULL-terminated error string
*
* @warning In HIP, this function returns the name of the error
* if the hiprtc result is defined; otherwise it will return "Invalid HIPRTC error code"
*
* @see hiprtcResult
*/
const char* hiprtcGetErrorString(hiprtcResult result) {
    // Forward to the entry point of the same name resolved at runtime.
    using GetErrorStringFn = const char* (*)(hiprtcResult);
    const auto realFunction = reinterpret_cast<GetErrorStringFn>(loadHiprtcSymbol("hiprtcGetErrorString"));
    if (realFunction == nullptr) {
        // Never call through a null pointer; return a usable, non-null message instead.
        return "Unknown HIPRTC error (hiprtcGetErrorString could not be loaded)";
    }
    return realFunction(result);
}
/**
* @ingroup Runtime
* @brief Sets the parameters as major and minor version.
*
* @param [out] major HIP Runtime Compilation major version.
* @param [out] minor HIP Runtime Compilation minor version.
*
* @returns #HIPRTC_ERROR_INVALID_INPUT, #HIPRTC_SUCCESS
*
*/
hiprtcResult hiprtcVersion(int* major, int* minor) {
    // Forward to the entry point of the same name resolved at runtime.
    using VersionFn = hiprtcResult (*)(int*, int*);
    const auto realFunction = reinterpret_cast<VersionFn>(loadHiprtcSymbol("hiprtcVersion"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(major, minor);
}
/**
* hiprtc program
*
*/
// Opaque handle: _hiprtcProgram is only ever used through a pointer here.
struct _hiprtcProgram;
using hiprtcProgram = _hiprtcProgram*;
/**
* @ingroup Runtime
* @brief Adds the given name expression to the runtime compilation program.
*
* @param [in] prog runtime compilation program instance.
* @param [in] name_expression const char pointer to the name expression.
* @returns #HIPRTC_SUCCESS
*
* If const char pointer is NULL, it will return #HIPRTC_ERROR_INVALID_INPUT.
*
* @see hiprtcResult
*/
hiprtcResult hiprtcAddNameExpression(hiprtcProgram prog,
                                     const char* name_expression) {
    // Forward to the entry point of the same name resolved at runtime.
    using AddNameExpressionFn = hiprtcResult (*)(hiprtcProgram, const char*);
    const auto realFunction = reinterpret_cast<AddNameExpressionFn>(loadHiprtcSymbol("hiprtcAddNameExpression"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(prog, name_expression);
}
/**
* @ingroup Runtime
* @brief Compiles the given runtime compilation program.
*
* @param [in] prog runtime compilation program instance.
* @param [in] numOptions number of compiler options.
* @param [in] options compiler options as const array of strings.
* @returns #HIPRTC_SUCCESS
*
* If the compiler failed to build the runtime compilation program,
* it will return #HIPRTC_ERROR_COMPILATION.
*
* @see hiprtcResult
*/
hiprtcResult hiprtcCompileProgram(hiprtcProgram prog,
                                  int numOptions,
                                  const char** options) {
    // Forward to the entry point of the same name resolved at runtime.
    using CompileProgramFn = hiprtcResult (*)(hiprtcProgram, int, const char**);
    const auto realFunction = reinterpret_cast<CompileProgramFn>(loadHiprtcSymbol("hiprtcCompileProgram"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(prog, numOptions, options);
}
/**
* @ingroup Runtime
* @brief Creates an instance of hiprtcProgram with the given input parameters,
* and sets the output hiprtcProgram prog with it.
*
* @param [in, out] prog runtime compilation program instance.
* @param [in] src const char pointer to the program source.
* @param [in] name const char pointer to the program name.
* @param [in] numHeaders number of headers.
* @param [in] headers array of strings pointing to headers.
* @param [in] includeNames array of strings pointing to names included in program source.
* @returns #HIPRTC_SUCCESS
*
* For any invalid input parameter, it will return #HIPRTC_ERROR_INVALID_INPUT
* or #HIPRTC_ERROR_INVALID_PROGRAM.
*
* If failed to create the program, it will return #HIPRTC_ERROR_PROGRAM_CREATION_FAILURE.
*
* @see hiprtcResult
*/
hiprtcResult hiprtcCreateProgram(hiprtcProgram* prog,
                                 const char* src,
                                 const char* name,
                                 int numHeaders,
                                 const char** headers,
                                 const char** includeNames) {
    // Forward to the entry point of the same name resolved at runtime.
    using CreateProgramFn = hiprtcResult (*)(hiprtcProgram*, const char*, const char*, int,
                                             const char**, const char**);
    const auto realFunction = reinterpret_cast<CreateProgramFn>(loadHiprtcSymbol("hiprtcCreateProgram"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(prog, src, name, numHeaders, headers, includeNames);
}
/**
* @brief Destroys an instance of given hiprtcProgram.
* @ingroup Runtime
* @param [in] prog runtime compilation program instance.
* @returns #HIPRTC_SUCCESS
*
* If prog is NULL, it will return #HIPRTC_ERROR_INVALID_INPUT.
*
* @see hiprtcResult
*/
hiprtcResult hiprtcDestroyProgram(hiprtcProgram* prog) {
    // Forward to the entry point of the same name resolved at runtime.
    using DestroyProgramFn = hiprtcResult (*)(hiprtcProgram*);
    const auto realFunction = reinterpret_cast<DestroyProgramFn>(loadHiprtcSymbol("hiprtcDestroyProgram"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(prog);
}
/**
* @brief Gets the lowered (mangled) name from an instance of hiprtcProgram with the given input parameters,
* and sets the output lowered_name with it.
* @ingroup Runtime
* @param [in] prog runtime compilation program instance.
* @param [in] name_expression const char pointer to the name expression.
* @param [in, out] lowered_name const char array to the lowered (mangled) name.
* @returns #HIPRTC_SUCCESS
*
* If any input parameter is an invalid nullptr, it will return #HIPRTC_ERROR_INVALID_INPUT
*
* If name_expression is not found, it will return #HIPRTC_ERROR_NAME_EXPRESSION_NOT_VALID
*
* If failed to get lowered_name from the program, it will return #HIPRTC_ERROR_COMPILATION.
*
* @see hiprtcResult
*/
// Declaration only: no definition is visible in this part of the file.
hiprtcResult hiprtcGetLoweredName(
    hiprtcProgram prog,
    const char* name_expression,
    const char** lowered_name);
/**
* @brief Gets the log generated by the runtime compilation program instance.
* @ingroup Runtime
* @param [in] prog runtime compilation program instance.
* @param [out] log memory pointer to the generated log.
* @returns #HIPRTC_SUCCESS
*
* @see hiprtcResult
*/
hiprtcResult hiprtcGetProgramLog(hiprtcProgram prog, char* log) {
    // Forward to the entry point of the same name resolved at runtime.
    using GetProgramLogFn = hiprtcResult (*)(hiprtcProgram, char*);
    const auto realFunction = reinterpret_cast<GetProgramLogFn>(loadHiprtcSymbol("hiprtcGetProgramLog"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(prog, log);
}
/**
* @brief Gets the size of log generated by the runtime compilation program instance.
*
* @param [in] prog runtime compilation program instance.
* @param [out] logSizeRet size of generated log.
* @returns #HIPRTC_SUCCESS
*
* @see hiprtcResult
*/
hiprtcResult hiprtcGetProgramLogSize(hiprtcProgram prog,
                                     size_t* logSizeRet) {
    // Forward to the entry point of the same name resolved at runtime.
    using GetProgramLogSizeFn = hiprtcResult (*)(hiprtcProgram, size_t*);
    const auto realFunction = reinterpret_cast<GetProgramLogSizeFn>(loadHiprtcSymbol("hiprtcGetProgramLogSize"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(prog, logSizeRet);
}
/**
* @brief Gets the pointer of compilation binary by the runtime compilation program instance.
* @ingroup Runtime
* @param [in] prog runtime compilation program instance.
* @param [out] code char pointer to binary.
* @returns #HIPRTC_SUCCESS
*
* @see hiprtcResult
*/
hiprtcResult hiprtcGetCode(hiprtcProgram prog, char* code) {
    // Forward to the entry point of the same name resolved at runtime.
    using GetCodeFn = hiprtcResult (*)(hiprtcProgram, char*);
    const auto realFunction = reinterpret_cast<GetCodeFn>(loadHiprtcSymbol("hiprtcGetCode"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(prog, code);
}
/**
* @brief Gets the size of compilation binary by the runtime compilation program instance.
* @ingroup Runtime
* @param [in] prog runtime compilation program instance.
* @param [out] codeSizeRet the size of binary.
* @returns #HIPRTC_SUCCESS
*
* @see hiprtcResult
*/
hiprtcResult hiprtcGetCodeSize(hiprtcProgram prog, size_t* codeSizeRet) {
    // Forward to the entry point of the same name resolved at runtime.
    using GetCodeSizeFn = hiprtcResult (*)(hiprtcProgram, size_t*);
    const auto realFunction = reinterpret_cast<GetCodeSizeFn>(loadHiprtcSymbol("hiprtcGetCodeSize"));
    if (realFunction == nullptr) {
        // Unresolved symbol: report an error rather than calling through null.
        return HIPRTC_ERROR_INTERNAL_ERROR;
    }
    return realFunction(prog, codeSizeRet);
}
/**
* @brief Gets the pointer of compiled bitcode by the runtime compilation program instance.
*
* @param [in] prog runtime compilation program instance.
* @param [out] bitcode char pointer to bitcode.
* @return HIPRTC_SUCCESS
*
* @see hiprtcResult
*/
// Declaration only: no definition is visible in this part of the file.
hiprtcResult hiprtcGetBitcode(
    hiprtcProgram prog,
    char* bitcode);
/**
* @brief Gets the size of compiled bitcode by the runtime compilation program instance.
* @ingroup Runtime
*
* @param [in] prog runtime compilation program instance.
* @param [out] bitcode_size the size of bitcode.
* @returns #HIPRTC_SUCCESS
*
* @see hiprtcResult
*/
// Declaration only: no definition is visible in this part of the file.
hiprtcResult hiprtcGetBitcodeSize(
    hiprtcProgram prog,
    size_t* bitcode_size);
/**
* @brief Creates the link instance via hiprtc APIs.
* @ingroup Runtime
* @param [in] num_options Number of options
* @param [in] option_ptr Array of options
* @param [in] option_vals_pptr Array of option values cast to void*
* @param [out] hip_link_state_ptr hiprtc link state created upon success
*
* @returns #HIPRTC_SUCCESS, #HIPRTC_ERROR_INVALID_INPUT, #HIPRTC_ERROR_INVALID_OPTION
*
* @see hiprtcResult
*/
// Declaration only: no definition is visible in this part of the file.
hiprtcResult hiprtcLinkCreate(
    unsigned int num_options,
    hiprtcJIT_option* option_ptr,
    void** option_vals_pptr,
    hiprtcLinkState* hip_link_state_ptr);
/**
* @brief Adds a file with bit code to be linked with options
* @ingroup Runtime
* @param [in] hip_link_state hiprtc link state
* @param [in] input_type Type of the input data or bitcode
* @param [in] file_path Path to the input file where bitcode is present
* @param [in] num_options Size of the options
* @param [in] options_ptr Array of options applied to this input
* @param [in] option_values Array of option values cast to void*
*
* @returns #HIPRTC_SUCCESS
*
* If input values are invalid, it will
* @return #HIPRTC_ERROR_INVALID_INPUT
*
* @see hiprtcResult
*/
// Declaration only: no definition is visible in this part of the file.
hiprtcResult hiprtcLinkAddFile(
    hiprtcLinkState hip_link_state,
    hiprtcJITInputType input_type,
    const char* file_path,
    unsigned int num_options,
    hiprtcJIT_option* options_ptr,
    void** option_values);
/**
* @brief Adds in-memory input data (an image) to be linked with options
* @ingroup Runtime
* @param [in] hip_link_state hiprtc link state
* @param [in] input_type Type of the input data or bitcode
* @param [in] image Input data which is null terminated
* @param [in] image_size Size of the input data
* @param [in] name Optional name for this input
* @param [in] num_options Size of the options
* @param [in] options_ptr Array of options applied to this input
* @param [in] option_values Array of option values cast to void*
*
* @returns #HIPRTC_SUCCESS, #HIPRTC_ERROR_INVALID_INPUT
*
* If adding the file fails, it will
* @return #HIPRTC_ERROR_PROGRAM_CREATION_FAILURE
*
* @see hiprtcResult
*/
// Declaration only: no definition is visible in this part of the file.
hiprtcResult hiprtcLinkAddData(
    hiprtcLinkState hip_link_state,
    hiprtcJITInputType input_type,
    void* image,
    size_t image_size,
    const char* name,
    unsigned int num_options,
    hiprtcJIT_option* options_ptr,
    void** option_values);
/**
* @brief Completes the linking of the given program.
* @ingroup Runtime
* @param [in] hip_link_state hiprtc link state
* @param [out] bin_out Upon success, points to the output binary
* @param [out] size_out Size of the binary is stored (optional)
*
* @returns #HIPRTC_SUCCESS
*
* If adding the data fails, it will
* @return #HIPRTC_ERROR_LINKING
*
* @see hiprtcResult
*/
// Declaration only: no definition is visible in this part of the file.
hiprtcResult hiprtcLinkComplete(
    hiprtcLinkState hip_link_state,
    void** bin_out,
    size_t* size_out);
/**
* @brief Deletes the link instance via hiprtc APIs.
* @ingroup Runtime
* @param [in] hip_link_state link state instance
*
* @returns #HIPRTC_SUCCESS
*
* @see hiprtcResult
*/
// Declaration only: no definition is visible in this part of the file.
hiprtcResult hiprtcLinkDestroy(
    hiprtcLinkState hip_link_state);
using namespace OpenMM;
using namespace std;
// Throws an OpenMMException describing the failure when `result` is not
// HIPRTC_SUCCESS. The `result` expression is evaluated exactly once, so a
// failing call passed directly as the argument is not executed again while the
// message is built. The do/while(0) wrapper makes the macro behave as a single
// statement (safe inside un-braced if/else).
#define CHECK_RESULT(result, prefix) \
    do { \
        const hiprtcResult checkResultValue_ = (result); \
        if (checkResultValue_ != HIPRTC_SUCCESS) { \
            std::stringstream m; \
            m<<prefix<<": "<<getErrorString(checkResultValue_)<<" ("<<checkResultValue_<<")"<<" at "<<__FILE__<<":"<<__LINE__; \
            throw OpenMMException(m.str());\
        } \
    } while (0)
// Wraps the C string returned by hiprtcGetErrorString() in a std::string.
static string getErrorString(hiprtcResult result) {
    const char* const errorText = hiprtcGetErrorString(result);
    return string(errorText);
}
// Constructs the kernel by forwarding both arguments to the HipCompilerKernel base.
HipRuntimeCompilerKernel::HipRuntimeCompilerKernel(const std::string& name, const Platform& platform)
    : HipCompilerKernel(name, platform)
{
    // No additional state to initialize.
}
/**
 * Compile HIP source with hiprtc and return the resulting binary.
 *
 * @param source  the program source to compile.
 * @param flags   whitespace-separated compiler options.
 * @param cu      the context; not used by this implementation.
 * @returns the compiled code as reported by hiprtcGetCodeSize()/hiprtcGetCode().
 * @throws OpenMMException if the program cannot be created, compilation fails
 *         (the message includes the compiler log when it can be retrieved), or
 *         the compiled code cannot be retrieved.
 */
vector<char> HipRuntimeCompilerKernel::createModule(const string& source, const string& flags, HipContext& cu) {
    // Split the command line flags into an array of options.
    stringstream flagsStream(flags);
    vector<string> splitFlags;
    for (string flag; flagsStream >> flag; )
        splitFlags.push_back(flag);
    // The pointers stay valid because splitFlags is not modified after this point.
    vector<const char*> options;
    options.reserve(splitFlags.size());
    for (const string& flag : splitFlags)
        options.push_back(flag.c_str());

    // Compile the program to HSACO.
    hiprtcProgram program = nullptr;
    // Store each result in a local before checking it so the call itself is
    // only ever executed once.
    const hiprtcResult createResult = hiprtcCreateProgram(&program, source.c_str(), NULL, 0, NULL, NULL);
    CHECK_RESULT(createResult, "Error creating program");
    try {
        // data() is well-defined for an empty vector, unlike &options[0].
        const hiprtcResult compileResult = hiprtcCompileProgram(program, static_cast<int>(options.size()), options.data());
        if (compileResult != HIPRTC_SUCCESS) {
            string log;
            size_t logSize = 0;
            if (hiprtcGetProgramLogSize(program, &logSize) == HIPRTC_SUCCESS && logSize > 0) {
                // One extra zeroed byte guarantees the buffer is NUL-terminated.
                vector<char> logBuffer(logSize + 1, '\0');
                if (hiprtcGetProgramLog(program, logBuffer.data()) == HIPRTC_SUCCESS)
                    log = logBuffer.data();
            }
            throw OpenMMException("Error compiling program: " + log);
        }
        size_t codeSize = 0;
        const hiprtcResult sizeResult = hiprtcGetCodeSize(program, &codeSize);
        CHECK_RESULT(sizeResult, "Error getting compiled code size");
        vector<char> code(codeSize);
        const hiprtcResult codeResult = hiprtcGetCode(program, code.data());
        CHECK_RESULT(codeResult, "Error getting compiled code");
        hiprtcDestroyProgram(&program);
        return code;
    }
    catch (...) {
        // Release the program on every failure path before propagating the error.
        hiprtcDestroyProgram(&program);
        throw;
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment