Last active
November 28, 2023 02:55
-
-
Save antholzer/df8041bb633120411eae82940668c46f to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert>
#include <complex>
#include <cstdlib>
#include <iostream>
#include <vector>

#include <hip/hip_runtime_api.h>
#include "rocfft.h"
int main(int argc, char* argv[]) | |
{ | |
std::cout << "rocFFT complex 2d batched FFT example\n"; | |
// The problem size | |
const size_t Nx = (argc < 2) ? 4 : atoi(argv[1]); | |
const size_t Ny = (argc < 3) ? 4 : atoi(argv[2]); | |
const size_t Nz = (argc < 4) ? 4 : atoi(argv[3]); | |
std::cout << "Nx: " << Nx << "\tNy: " << Ny << "\tNz: " << Nz | |
<< std::endl; | |
// Initialize data on the host (column major) | |
std::cout << "Input:\n"; | |
std::vector<std::complex<float>> cx(Nx * Ny * Nz); | |
for(size_t i = 0; i < Nz; i++) | |
{ | |
for(size_t j = 0; j < Ny; j++) | |
{ | |
for(size_t k = 0; k < Nx; k++) | |
{ | |
const size_t pos = k + (i * Nz + j) * Ny; | |
cx[pos] = std::complex<float>(i + j + k, 0); | |
} | |
} | |
} | |
for(size_t i = 0; i < Nz; i++) | |
{ | |
for(size_t j = 0; j < Ny; j++) | |
{ | |
for(size_t k = 0; k < Nx; k++) | |
{ | |
const size_t pos = k + (i * Nz + j) * Ny; | |
std::cout << cx[pos] << " "; | |
} | |
std::cout << "\n"; | |
} | |
std::cout << "\n"; | |
} | |
std::cout << "\n"; | |
rocfft_setup(); | |
// Create HIP device object and copy data: | |
float2* x = NULL; | |
hipMalloc(&x, cx.size() * sizeof(decltype(cx)::value_type)); | |
hipMemcpy(x, cx.data(), cx.size() * sizeof(decltype(cx)::value_type), hipMemcpyHostToDevice); | |
const size_t offsets[3] = {0, 0, 0}; | |
// Nz batchdim | |
// const size_t lengths[2] = {Nx, Ny}; | |
// const size_t in_strides[2] = {1, Nx}; | |
// const size_t out_strides[2] = {1, Nx}; | |
// const size_t distance = Nx*Ny; | |
// const size_t batch = Nz; | |
// Nx batchdim | |
const size_t lengths[2] = {Ny, Nz}; | |
const size_t in_strides[2] = {Nx, Nx*Ny}; | |
const size_t out_strides[2] = {Nx, Nx*Ny}; | |
const size_t distance = 1; | |
const size_t batch = Nx; | |
rocfft_status status = rocfft_status_success; | |
// Create plans | |
rocfft_plan forward = NULL; | |
rocfft_plan_description desc = NULL; | |
rocfft_plan_description_create(&desc); | |
rocfft_plan_description_set_data_layout(desc, | |
rocfft_array_type_complex_interleaved, rocfft_array_type_complex_interleaved, | |
offsets, offsets, | |
2, in_strides, distance, | |
2, out_strides, distance); | |
status = rocfft_plan_create(&forward, | |
rocfft_placement_inplace, | |
rocfft_transform_type_complex_forward, | |
rocfft_precision_single, | |
2, // Dimensions | |
lengths, // lengths | |
batch, // Number of transforms | |
desc); // Description | |
assert(status == rocfft_status_success); | |
// We may need work memory, which is passed via rocfft_execution_info | |
rocfft_execution_info forwardinfo = NULL; | |
status = rocfft_execution_info_create(&forwardinfo); | |
assert(status == rocfft_status_success); | |
size_t fbuffersize = 0; | |
status = rocfft_plan_get_work_buffer_size(forward, &fbuffersize); | |
assert(status == rocfft_status_success); | |
void* fbuffer = NULL; | |
hipMalloc(&fbuffer, fbuffersize); | |
status = rocfft_execution_info_set_work_buffer(forwardinfo, fbuffer, fbuffersize); | |
assert(status == rocfft_status_success); | |
// Create plans | |
rocfft_plan backward = NULL; | |
status = rocfft_plan_create(&backward, | |
rocfft_placement_inplace, | |
rocfft_transform_type_complex_inverse, | |
rocfft_precision_single, | |
2, // Dimensions | |
lengths, // lengths | |
batch, // Number of transforms | |
desc); // Description | |
assert(status == rocfft_status_success); | |
rocfft_execution_info backwardinfo = NULL; | |
status = rocfft_execution_info_create(&backwardinfo); | |
assert(status == rocfft_status_success); | |
size_t bbuffersize = 0; | |
status = rocfft_plan_get_work_buffer_size(backward, &bbuffersize); | |
assert(status == rocfft_status_success); | |
void* bbuffer = NULL; | |
hipMalloc(&bbuffer, bbuffersize); | |
status = rocfft_execution_info_set_work_buffer(backwardinfo, bbuffer, bbuffersize); | |
assert(status == rocfft_status_success); | |
// Execute the forward transform | |
status = rocfft_execute(forward, | |
(void**)&x, // in_buffer | |
NULL, | |
forwardinfo); // execution info | |
assert(status == rocfft_status_success); | |
// Copy result back to host | |
std::vector<std::complex<float>> cy(cx.size()); | |
hipMemcpy(cy.data(), x, cy.size() * sizeof(decltype(cy)::value_type), hipMemcpyDeviceToHost); | |
std::cout << "Transformed:\n"; | |
for(size_t i = 0; i < Nz; i++) | |
{ | |
for(size_t j = 0; j < Ny; j++) | |
{ | |
for(size_t k = 0; k < Nx; k++) | |
{ | |
const size_t pos = k + (i * Nz + j) * Ny; | |
std::cout << cy[pos] << " "; | |
} | |
std::cout << "\n"; | |
} | |
std::cout << "\n"; | |
} | |
std::cout << "\n"; | |
// Execute the backward transform | |
status = rocfft_execute(backward, | |
(void**)&x, // in_buffer | |
NULL, | |
backwardinfo); // execution info | |
assert(status == rocfft_status_success); | |
std::cout << "Transformed back:\n"; | |
hipMemcpy(cy.data(), x, cy.size() * sizeof(decltype(cy)::value_type), hipMemcpyDeviceToHost); | |
for(size_t i = 0; i < Nz; i++) | |
{ | |
for(size_t j = 0; j < Ny; j++) | |
{ | |
for(size_t k = 0; k < Nx; k++) | |
{ | |
const size_t pos = k + (i * Nz + j) * Ny; | |
std::cout << cy[pos] << " "; | |
} | |
std::cout << "\n"; | |
} | |
std::cout << "\n"; | |
} | |
std::cout << "\n"; | |
const float overN = 1.0f / (lengths[0]*lengths[1]); | |
float error = 0.0f; | |
for(size_t i = 0; i < cx.size(); i++) | |
{ | |
float diff = std::abs(cx[i] - cy[i]*overN); | |
if(diff > error) | |
{ | |
error = diff; | |
} | |
} | |
std::cout << "Maximum error: " << error << "\n"; | |
hipFree(x); | |
hipFree(fbuffer); | |
hipFree(bbuffer); | |
// Destroy plans | |
rocfft_plan_destroy(forward); | |
rocfft_plan_destroy(backward); | |
rocfft_plan_description_destroy(desc); | |
rocfft_cleanup(); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment