Created
March 10, 2018 14:02
-
-
Save keisukefukuda/69e93074f6fc46a55267efeaca2ca093 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#include <cassert> | |
#include <cstring> | |
#include <cstdlib> | |
#include <ctime> | |
#include <iomanip> | |
#include <iostream> | |
#include <memory> | |
#include <sstream> | |
#include <string> | |
#include <tuple> | |
#include <vector> | |
#include "ucxpp.h" | |
#include <unistd.h> | |
#include <mpi.h> | |
#include "ucxpp.h" | |
/// Returns the local host name reported by gethostname().
///
/// @return the host name, or an empty string if gethostname() fails.
std::string getHostname() {
  // Zero-filled and queried with one byte to spare: gethostname() is not
  // required to NUL-terminate the buffer when the name is truncated.
  char hostname[100] = {};
  if (gethostname(hostname, sizeof(hostname) - 1) != 0) {
    return std::string();
  }
  return std::string(hostname);
}
// Small header located directly in front of a received active-message payload
// (am_handler reaches it with `(recv_desc_t*)data - 1`).
struct recv_desc_t {
  // Non-zero when the descriptor must be handed back with
  // uct_iface_release_desc(); zero when it is released with `delete`.
  int is_uct_desc;
};
template<typename ADDR_TYPE> | |
UCT::Addr<ADDR_TYPE> exchangeAddr(int dest_rank, const UCT::Addr<ADDR_TYPE> &addr) { | |
std::vector<char> recv_buf(addr.size()); | |
MPI_Request req; | |
MPI_Status stat; | |
MPI_Isend(reinterpret_cast<const void*>(addr.get()), | |
addr.size(), MPI_BYTE, dest_rank, 0, MPI_COMM_WORLD, &req); | |
MPI_Recv(reinterpret_cast<void*>(recv_buf.data()), | |
addr.size(), MPI_BYTE, dest_rank, 0, MPI_COMM_WORLD, &stat); | |
return UCT::Addr<ADDR_TYPE>(recv_buf.data(), addr.size()); | |
} | |
std::tuple<std::shared_ptr<UCT::MemoryDomain>, std::string, std::string> | |
dev_tl_lookup(std::shared_ptr<UCT::MemoryDomainResources> &md_rs, | |
const std::string& md_query, | |
const std::string &tl_query) { | |
std::tuple<std::shared_ptr<UCT::MemoryDomain>, std::string, std::string> result; | |
#if 0 | |
std::cout << "=================================" << std::endl; | |
std::cout << "List of resources:" << std::endl; | |
for (size_t i = 0; i < md_rs->size(); i++) { | |
auto md = md_rs->get(i); | |
for (size_t j = 0; j < md->getNumTL(); j++) { | |
auto tl = md->getTL(j); | |
std::cout << "\t" << tl->getTlName() << " / " << tl->getDevName() | |
<< std::endl; | |
} | |
} | |
std::cout << "=================================" << std::endl; | |
std::cout << std::endl; | |
#endif | |
// Look for device/tl that matches the query (dev_name, tl_name) | |
for (size_t i = 0; i < md_rs->size(); i++) { | |
auto md = md_rs->get(i); | |
if (md->componentName() == md_query) { | |
for (size_t j = 0; j < md->getNumTL(); j++) { | |
auto tl = md->getTL(j); | |
if (tl->getTlName() == tl_query) { | |
result = std::make_tuple(md, tl->getTlName(), tl->getDevName()); | |
} | |
} | |
} | |
} | |
std::cout << "\t" << std::get<1>(result) << " / " << std::get<2>(result) | |
<< std::endl; | |
return result; | |
} | |
/// Dumps the attributes of a memory domain and of an opened interface to stdout.
///
/// @param md        memory domain whose name, limits and transports are listed.
/// @param dev_name  device name; used only in the "iface_config of" heading.
/// @param tl_name   transport name; used only in the "iface_config of" heading.
/// @param iface     interface whose attr() (address lengths, put/get/am/tag
///                  limits and capability flags) is printed.
void showIfaceInfo(UCT::MemoryDomain& md, const std::string& dev_name,
                   const std::string& tl_name, UCT::IFace& iface) {
  std::cout << "\tMemory Domain: " << md.name() << std::endl;
  std::cout << "\t\t" << "component name: " << md.componentName() << "\n";
  std::cout << "\t\t" << "max alloc: " << md.maxAlloc() << "\n";
  std::cout << "\t\t" << " ("
            << static_cast<size_t>(md.maxAlloc() / 1024. / 1024 / 1024) << "[GiB])" << "\n";
  std::cout << "\t\t" << "max reg: " << md.maxReg() << "\n";
  std::cout << "\t\t" << " ("
            << static_cast<size_t>(md.maxReg() / 1024. / 1024 / 1024) << "[GiB])\n";
  std::cout << "\t\t" << "rkey packed size: " << md.rKeyPackedSize() << "\n";

  // Print devices and transport layers
  std::cout << "\t\t" << "Devices:" << std::endl;
  for (size_t j = 0; j < md.getNumTL(); j++) {
    auto tl = md.getTL(j);
    std::cout << "\t\t\t" << tl->getDevName() << "/" << tl->getTlName()
              << std::endl;
  }

  const auto &attr = iface.attr();
  const auto &lat = attr.latency;
  std::cout << "iface_config of " << tl_name << "/" << dev_name << std::endl;
  std::cout << "\tdevice_addr_len = " << attr.device_addr_len << std::endl;
  std::cout << "\tiface_addr_len = " << attr.iface_addr_len << std::endl;
  std::cout << "\tep_addr_len = " << attr.ep_addr_len << std::endl;
  std::cout << "\tmax_conn_priv = " << attr.max_conn_priv << std::endl;
  std::cout << "\toverhead = " << attr.overhead << std::endl;
  std::cout << "\tbandwidth = " << attr.bandwidth << std::endl;
  std::cout << "\tlatency = " << lat.growth << " * x + "
            << lat.overhead << std::endl;
  std::cout << "\tpriority = " << static_cast<int>(attr.priority) << std::endl;
  std::cout << std::endl;

  const auto &put = attr.cap.put;
  std::cout << "\tattr.cap.put" << std::endl;
  std::cout << "\t\tmax_short = " << put.max_short << std::endl;
  std::cout << "\t\tmax_bcopy = " << put.max_bcopy << std::endl;
  std::cout << "\t\tmin_zcopy = " << put.min_zcopy << std::endl;
  std::cout << "\t\tmax_zcopy = " << put.max_zcopy << " (="
            << (put.max_zcopy/1024/1024) << " [MiB])" << std::endl;
  std::cout << "\t\topt_zcopy_align = " << put.opt_zcopy_align << std::endl;
  std::cout << "\t\talign_mtu = " << put.align_mtu << std::endl;
  std::cout << "\t\tmax_iov = " << put.max_iov << std::endl;

  const auto &get = attr.cap.get;
  std::cout << "\tattr.cap.get" << std::endl;
  std::cout << "\t\tmax_bcopy = " << get.max_bcopy << std::endl;
  std::cout << "\t\tmin_zcopy = " << get.min_zcopy << std::endl;
  std::cout << "\t\tmax_zcopy = " << get.max_zcopy << " (="
            << (get.max_zcopy/1024/1024) << " [MiB])" << std::endl;
  std::cout << "\t\topt_zcopy_align = " << get.opt_zcopy_align << std::endl;
  std::cout << "\t\talign_mtu = " << get.align_mtu << std::endl;
  std::cout << "\t\tmax_iov = " << get.max_iov << std::endl;

  const auto &am = attr.cap.am;
  std::cout << "\tattr.cap.am" << std::endl;
  std::cout << "\t\tmax_short = " << am.max_short << std::endl;
  std::cout << "\t\tmax_bcopy = " << am.max_bcopy << std::endl;
  std::cout << "\t\tmin_zcopy = " << am.min_zcopy << std::endl;
  std::cout << "\t\tmax_zcopy = " << am.max_zcopy << " (="
            << (am.max_zcopy/1024/1024) << " [MiB])" << std::endl;
  std::cout << "\t\topt_zcopy_align = " << am.opt_zcopy_align << std::endl;
  std::cout << "\t\talign_mtu = " << am.align_mtu << std::endl;
  std::cout << "\t\tmax_iov = " << am.max_iov << std::endl;

  const auto &recv = attr.cap.tag.recv;
  std::cout << "\tattr.cap.tag.recv" << std::endl;
  std::cout << "\t\tmin_recv = " << recv.min_recv << std::endl;
  std::cout << "\t\tmax_zcopy = " << recv.max_zcopy << std::endl;
  std::cout << "\t\tmax_iov = " << recv.max_iov << std::endl;
  std::cout << "\t\tmax_outstanding = " << recv.max_outstanding << std::endl;

  const auto &eager = attr.cap.tag.eager;
  std::cout << "\tattr.cap.tag.eager" << std::endl;
  std::cout << "\t\tmax_short = " << eager.max_short << std::endl;
  std::cout << "\t\tmax_bcopy = " << eager.max_bcopy << std::endl;
  std::cout << "\t\tmax_zcopy = " << eager.max_zcopy << std::endl;
  std::cout << "\t\tmax_iov = " << eager.max_iov << std::endl;

  const auto &rndv = attr.cap.tag.rndv;
  std::cout << "\tattr.cap.tag.rndv" << std::endl;
  std::cout << "\t\tmax_zcopy = " << rndv.max_zcopy << std::endl;
  std::cout << "\t\tmax_hdr = " << rndv.max_hdr << std::endl;
  std::cout << "\t\tmax_iov = " << rndv.max_iov << std::endl;

  // Prints "<NAME>: Yes|No" depending on whether UCT_IFACE_FLAG_<NAME> is set
  // in the local variable `flags`.
#define SHOW_CAP_FLAG(name) do {                                  \
    std::cout << "\t\t" << #name << ": "                          \
              << ((flags & UCT_IFACE_FLAG_##name) ? "Yes" : "No") \
              << std::endl;                                       \
  } while(0)

  // check flag
  const auto flags = attr.cap.flags;
  std::cout << "\tflags:" << std::endl;
  SHOW_CAP_FLAG(AM_SHORT);
  SHOW_CAP_FLAG(AM_BCOPY);
  SHOW_CAP_FLAG(AM_ZCOPY);
  SHOW_CAP_FLAG(PENDING);
  SHOW_CAP_FLAG(PUT_SHORT);
  SHOW_CAP_FLAG(PUT_BCOPY);
  SHOW_CAP_FLAG(PUT_ZCOPY);
  SHOW_CAP_FLAG(GET_SHORT);
  SHOW_CAP_FLAG(GET_BCOPY);
  SHOW_CAP_FLAG(GET_ZCOPY);
  SHOW_CAP_FLAG(ATOMIC_ADD32);
  SHOW_CAP_FLAG(ATOMIC_ADD64);
  SHOW_CAP_FLAG(ATOMIC_FADD32);
  SHOW_CAP_FLAG(ATOMIC_FADD64);
  SHOW_CAP_FLAG(ATOMIC_SWAP32);
  SHOW_CAP_FLAG(ATOMIC_SWAP64);
  // NOTE(review): ATOMIC_CSWAP64 is not listed although ATOMIC_CSWAP32 is —
  // confirm whether the omission is intentional.
  SHOW_CAP_FLAG(ATOMIC_CSWAP32);
  SHOW_CAP_FLAG(ATOMIC_CPU);
  SHOW_CAP_FLAG(ATOMIC_DEVICE);

  // Keep the helper macro from leaking into the rest of the translation unit.
#undef SHOW_CAP_FLAG
}
void* desc_holder = nullptr; | |
ucs_status_t am_handler(void * /* arg */, void *data, size_t length, unsigned flags) { | |
recv_desc_t *rdesc; | |
if (flags & UCT_CB_PARAM_FLAG_DESC) { | |
rdesc = (recv_desc_t*)data - 1; | |
// Hold descriptor to release later and return UCS_INPROGRESS */ | |
rdesc->is_uct_desc = 1; | |
desc_holder = rdesc; | |
return UCS_INPROGRESS; | |
} else { | |
rdesc = new recv_desc_t; | |
rdesc->is_uct_desc = 0; | |
std::string s(reinterpret_cast<char*>(data), length); | |
std::cout << "\tmsg = '" << s << "'" << std::endl; | |
return UCS_OK; | |
} | |
} | |
/// Sends one short active message made of random upper-case letters.
///
/// The message is `msg_size` bytes long, the last of which is a NUL
/// terminator.  Its first 8 bytes travel in the 64-bit header of
/// uct_ep_am_short(); any remaining bytes are passed as the payload.
///
/// @param ep        endpoint to send on.
/// @param id        active-message id.
/// @param msg_size  total message size in bytes; sizes below 8 are zero-padded
///                  inside the header, and 0 sends an all-zero header.
void send_am(UCT::Endpoint &ep, uint8_t id, size_t msg_size) {
  // `i + 1 < msg_size` instead of `i < msg_size - 1`: no unsigned wrap-around
  // when msg_size is 0.
  std::vector<char> msg(msg_size, '\0');
  for (size_t i = 0; i + 1 < msg_size; i++) {
    msg[i] = static_cast<char>(rand() % 26 + 'A');
  }
  std::cout << "sent string = " << (msg.empty() ? "" : msg.data()) << std::endl;

  // Copy at most msg_size bytes into the header so short messages do not read
  // past the end of the buffer; memcpy also avoids a misaligned 64-bit load.
  uint64_t header = 0;
  const size_t header_len = (msg_size < sizeof(header)) ? msg_size : sizeof(header);
  if (header_len > 0) {
    std::memcpy(&header, msg.data(), header_len);
  }

  // The rest of the message is sent straight from `msg`; no extra copy needed.
  const size_t rest = msg_size - header_len;
  const char *payload = (rest > 0) ? msg.data() + header_len : nullptr;

  UCS_SAFE_CALL(uct_ep_am_short(ep.get(), id, header, payload, rest));
}
void Communicate(int rank, | |
std::shared_ptr<UCT::Worker> worker, | |
std::shared_ptr<UCT::MemoryDomain> md, | |
std::string& tl_name, | |
std::string& dev_name) { | |
UCT::IFace iface; | |
UCT::IFace::params_t params; | |
params.open_mode = UCT_IFACE_OPEN_MODE_DEVICE; | |
params.mode.device.tl_name = tl_name.c_str(); | |
params.mode.device.dev_name = dev_name.c_str(); | |
params.stats_root = NULL; | |
params.rx_headroom = sizeof(recv_desc_t); | |
UCS_CPU_ZERO(¶ms.cpu_mask); | |
iface.open(worker, md, params); | |
iface.progressEnable(UCT_PROGRESS_SEND | UCT_PROGRESS_RECV); | |
#if 0 | |
if (rank == 0) { | |
showIfaceInfo(*md, dev_name, tl_name, iface); | |
} | |
#endif | |
// === get device address | |
auto own_dev_addr = iface.getDeviceAddress(); | |
auto peer_dev_addr = exchangeAddr(1 - rank, *own_dev_addr); | |
MPI_Barrier(MPI_COMM_WORLD); | |
if (iface.isReachable(peer_dev_addr) == 0) { | |
std::cout << "Rank " << rank | |
<< " ERROR: peer device address is not reachable!" << std::endl; | |
exit(-1); | |
} | |
std::unique_ptr<UCT::Endpoint> ep; | |
if (iface.attr().cap.flags & UCT_IFACE_FLAG_CONNECT_TO_IFACE) { | |
// === get interface address | |
auto if_addr = iface.getIfaceAddress(); | |
UCT::IfaceAddr peer_iface_addr = exchangeAddr(1 - rank, *if_addr); | |
ep.reset(new UCT::Endpoint(iface, peer_dev_addr, peer_iface_addr)); | |
} else if (iface.attr().cap.flags & UCT_IFACE_FLAG_CONNECT_TO_EP) { | |
// === if connect_to_ep | |
ep.reset(new UCT::Endpoint(iface)); | |
UCT::EpAddr peer_ep_addr = exchangeAddr(1 - rank, ep->addr()); | |
ep->connect(peer_dev_addr, peer_ep_addr); | |
} else { | |
std::cerr << "Unsupported" << std::endl; | |
goto CLOSE_IFACE; | |
} | |
// Communicate | |
{ | |
// Try am_short | |
// Send AM from rank 0 to 1 | |
size_t msg_size = 8; // iface.attr().cap.am.max_short; | |
const uint8_t id = 0; | |
// The sender registers an AM callback function. | |
iface.setAmHandler(&am_handler, NULL, id, UCT_CB_FLAG_SYNC); | |
MPI_Barrier(MPI_COMM_WORLD); | |
if (rank == 0) { // sender | |
send_am(*ep, id, msg_size); | |
} else { // receiver | |
recv_desc_t *rdesc= nullptr; | |
while (!desc_holder) { | |
unsigned ret = uct_worker_progress(worker->get()); | |
(void)ret; | |
//std::cout << "uct_worker_progress() = " << ret << std::endl; | |
} | |
rdesc = reinterpret_cast<recv_desc_t*>(desc_holder); | |
std::cout << "Received: " << (char*)(rdesc + 1) << std::endl; | |
if (rdesc->is_uct_desc) { | |
uct_iface_release_desc((recv_desc_t*)rdesc); | |
} else { | |
delete rdesc; | |
} | |
} | |
} | |
MPI_Barrier(MPI_COMM_WORLD); | |
std::cout << "Rank " << rank << " Done." << std::endl; | |
ep->destroy(); // ep must be destroyed before iface. | |
CLOSE_IFACE: | |
iface.close(); | |
} | |
int main(int argc, char **argv) { | |
srand(time(NULL)); | |
MPI_Init(&argc, &argv); | |
int size; | |
int rank; | |
MPI_Comm_size(MPI_COMM_WORLD, &size); | |
MPI_Comm_rank(MPI_COMM_WORLD, &rank); | |
UCS::AsyncContext async(UCS_ASYNC_MODE_THREAD); | |
auto worker = std::make_shared<UCT::Worker>(async, UCS_THREAD_MODE_SINGLE); | |
auto md_rs = std::make_shared<UCT::MemoryDomainResources>(); | |
std::shared_ptr<UCT::MemoryDomain> md; | |
std::string tl_name, dev_name; | |
std::tie(md, tl_name, dev_name) = dev_tl_lookup(md_rs, "ib", "rc"); | |
if (tl_name.size() == 0 || dev_name.size() == 0) { | |
std::cout << "Couldn't get device or transport layer name" << std::endl; | |
exit(-1); | |
} | |
if (size > 1) { | |
if (rank <= 2) { | |
Communicate(rank, worker, md, tl_name, dev_name); | |
} | |
} | |
MPI_Finalize(); | |
return 0; | |
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#ifndef UCXPP_H__ | |
#define UCXPP_H__ | |
#include <uct/api/uct.h> | |
// Expands to a `case` label that returns the enumerator's own spelling.
#define IDENT_TO_STR_(ident) case ident: return #ident

/// Returns the enumerator name of a ucs_status_t value.
///
/// Declared `inline` because this header defines the function, so it may be
/// included from more than one translation unit.
///
/// @return the enumerator's spelling, or "Unknown ucs_status_t" for values
///         not listed below.
inline const char * enumToStr(ucs_status_t s) {
  switch(s) {
    IDENT_TO_STR_(UCS_OK);
    IDENT_TO_STR_(UCS_INPROGRESS);
    IDENT_TO_STR_(UCS_ERR_NO_MESSAGE);
    IDENT_TO_STR_(UCS_ERR_NO_RESOURCE);
    IDENT_TO_STR_(UCS_ERR_IO_ERROR);
    IDENT_TO_STR_(UCS_ERR_NO_MEMORY);
    IDENT_TO_STR_(UCS_ERR_INVALID_PARAM);
    IDENT_TO_STR_(UCS_ERR_UNREACHABLE);
    IDENT_TO_STR_(UCS_ERR_INVALID_ADDR);
    IDENT_TO_STR_(UCS_ERR_NOT_IMPLEMENTED);
    IDENT_TO_STR_(UCS_ERR_MESSAGE_TRUNCATED);
    IDENT_TO_STR_(UCS_ERR_NO_PROGRESS);
    IDENT_TO_STR_(UCS_ERR_BUFFER_TOO_SMALL);
    IDENT_TO_STR_(UCS_ERR_NO_ELEM);
    IDENT_TO_STR_(UCS_ERR_SOME_CONNECTS_FAILED);
    IDENT_TO_STR_(UCS_ERR_NO_DEVICE);
    IDENT_TO_STR_(UCS_ERR_BUSY);
    IDENT_TO_STR_(UCS_ERR_CANCELED);
    IDENT_TO_STR_(UCS_ERR_SHMEM_SEGMENT);
    IDENT_TO_STR_(UCS_ERR_ALREADY_EXISTS);
    IDENT_TO_STR_(UCS_ERR_OUT_OF_RANGE);
    IDENT_TO_STR_(UCS_ERR_TIMED_OUT);
    IDENT_TO_STR_(UCS_ERR_EXCEEDS_LIMIT);
    IDENT_TO_STR_(UCS_ERR_UNSUPPORTED);
    IDENT_TO_STR_(UCS_ERR_FIRST_LINK_FAILURE);
    IDENT_TO_STR_(UCS_ERR_LAST_LINK_FAILURE);
    IDENT_TO_STR_(UCS_ERR_FIRST_ENDPOINT_FAILURE);
    IDENT_TO_STR_(UCS_ERR_LAST_ENDPOINT_FAILURE);
    IDENT_TO_STR_(UCS_ERR_ENDPOINT_TIMEOUT);
  default:
    return "Unknown ucs_status_t";
  }
}
// Evaluates `expr` (a ucs_status_t) exactly once.  If the result is not UCS_OK,
// prints file, line, the expression text and the status name to std::cerr and
// terminates the process with exit(-1).
//
// The temporary has a deliberately unusual name so that an `expr` mentioning a
// caller variable such as `st` still refers to the caller's variable rather
// than to the macro's own (not yet initialised) local.
#define UCS_SAFE_CALL(expr) do {                                      \
    const ucs_status_t ucs_safe_call_status_ = (expr);                \
    if (ucs_safe_call_status_ != UCS_OK) {                            \
      std::cerr << "Error: "                                          \
                << __FILE__ << ":"                                    \
                << __LINE__ << " "                                    \
                << #expr << " failed with "                           \
                << enumToStr(ucs_safe_call_status_) << std::endl;     \
      exit(-1);                                                       \
    }                                                                 \
  } while(0)
namespace UCS { | |
class AsyncContext { | |
ucs_async_context_t *ctx_; | |
public: | |
AsyncContext(ucs_async_mode_t mode) { | |
UCS_SAFE_CALL(ucs_async_context_create(mode, &ctx_)); | |
} | |
~AsyncContext() { | |
ucs_async_context_destroy(ctx_); | |
} | |
inline const ucs_async_context_t *get() const { | |
return ctx_; | |
} | |
}; | |
} | |
namespace UCT { | |
class TransportLayer; | |
class TransportLayerResources; | |
class MemoryDomain; | |
class MemoryDomainResources; | |
class Worker { | |
uct_worker_h handle_; | |
public: | |
Worker(const UCS::AsyncContext &async, ucs_thread_mode_t tm) { | |
UCS_SAFE_CALL(uct_worker_create(const_cast<ucs_async_context*>(async.get()), | |
tm, &handle_)); | |
} | |
~Worker() { | |
uct_worker_destroy(handle_); | |
} | |
uct_worker_h get() { return handle_; } | |
}; | |
class TransportLayerResources { | |
uct_tl_resource_desc_t *tl_rc_; | |
unsigned num_; | |
public: | |
TransportLayerResources(uct_md_h md_h) { | |
UCS_SAFE_CALL(uct_md_query_tl_resources(md_h, &tl_rc_, &num_)); | |
} | |
~TransportLayerResources() { | |
uct_release_tl_resource_list(tl_rc_); | |
} | |
uct_tl_resource_desc_t* getRaw(size_t i) { | |
assert(i < num_); | |
return &tl_rc_[i]; | |
} | |
size_t size() const { | |
return num_; | |
} | |
}; | |
class TransportLayer { | |
std::shared_ptr<TransportLayerResources> tl_rc_; | |
size_t idx_; | |
public: | |
TransportLayer(std::shared_ptr<TransportLayerResources> p, size_t idx) | |
: tl_rc_(p), idx_(idx) | |
{ | |
assert(tl_rc_->size() > idx); | |
} | |
const char *getTlName() const { return tl_rc_->getRaw(idx_)->tl_name; } | |
const char *getDevName() const { return tl_rc_->getRaw(idx_)->dev_name; } | |
}; | |
class MemoryDomainResources { | |
uct_md_resource_desc_t *md_rc_; | |
unsigned num_md_rc_; | |
std::vector<std::shared_ptr<MemoryDomain>> md_; | |
public: | |
MemoryDomainResources() | |
: md_rc_(nullptr), num_md_rc_(0), md_() | |
{ | |
UCS_SAFE_CALL(uct_query_md_resources(&md_rc_, &num_md_rc_)); | |
assert(md_rc_ != nullptr); | |
for (unsigned i = 0; i < num_md_rc_; i++) { | |
md_.push_back(std::make_shared<MemoryDomain>(this, i)); | |
} | |
} | |
~MemoryDomainResources() { | |
uct_release_md_resource_list(md_rc_); | |
} | |
std::shared_ptr<MemoryDomain> get(size_t idx) { | |
return md_[idx]; | |
} | |
uct_md_resource_desc_t* getRaw(size_t idx) { | |
return &md_rc_[idx]; | |
} | |
size_t size() const { | |
return md_.size(); | |
} | |
private: | |
MemoryDomainResources(const MemoryDomainResources &rhs) = delete; | |
MemoryDomainResources& operator=(const MemoryDomainResources &rhs) = delete; | |
}; | |
class MemoryDomain { | |
MemoryDomainResources *md_rc_; | |
size_t idx_; | |
uct_md_h md_h_; | |
uct_md_config_t *md_conf_; | |
uct_md_attr_t attr_; | |
std::shared_ptr<TransportLayerResources> tl_rc_; | |
std::vector<std::shared_ptr<TransportLayer>> tl_; | |
public: | |
MemoryDomain(MemoryDomainResources* md_rc, size_t idx) : | |
md_rc_(md_rc), idx_(idx) | |
{ | |
// Create a memory domain | |
UCS_SAFE_CALL(uct_md_config_read(md_rc_->getRaw(idx_)->md_name, | |
NULL, NULL, | |
&md_conf_)); | |
UCS_SAFE_CALL(uct_md_open(md_rc_->getRaw(idx_)->md_name, md_conf_, &md_h_)); | |
tl_rc_ = std::make_shared<TransportLayerResources>(md_h_); | |
for (size_t i = 0; i < tl_rc_->size(); i++) { | |
tl_.push_back(std::make_shared<TransportLayer>(tl_rc_, i)); | |
} | |
query(); | |
} | |
~MemoryDomain() { | |
uct_config_release(md_conf_); | |
uct_md_close(md_h_); | |
} | |
uct_md_h &get() { | |
return md_h_; | |
} | |
const char *name() const { | |
return md_rc_->getRaw(idx_)->md_name; | |
} | |
// Returns a vector of TransportLayer | |
std::shared_ptr<TransportLayer> getTL(size_t idx) { | |
return tl_[idx]; | |
} | |
size_t getNumTL() const { | |
return tl_.size(); | |
} | |
std::string componentName() const { | |
return std::string(attr_.component_name); | |
} | |
size_t maxAlloc() const { return attr_.cap.max_alloc; } | |
size_t maxReg() const {return attr_.cap.max_reg; } | |
enum class Flag { | |
ALLOC = UCT_MD_FLAG_ALLOC, | |
REG = UCT_MD_FLAG_REG, | |
NEED_MEMH = UCT_MD_FLAG_NEED_MEMH, | |
NEED_RKEY = UCT_MD_FLAG_NEED_RKEY, | |
ADVISE = UCT_MD_FLAG_ADVISE, | |
FIXED = UCT_MD_FLAG_FIXED, | |
RKEY_PTR = UCT_MD_FLAG_RKEY_PTR, | |
SOCKADDR = UCT_MD_FLAG_SOCKADDR, | |
}; | |
Flag flag() const { return static_cast<Flag>(attr_.cap.flags); } | |
// TODO | |
// memoryType() const { ... } | |
size_t rKeyPackedSize() const { return attr_.rkey_packed_size; } | |
// TODO | |
// cpuset_t localCpus() const { return attr_.local_cpus; } | |
private: | |
MemoryDomain(const MemoryDomain& rhs) = delete; | |
MemoryDomain& operator=(const MemoryDomain& rhs) = delete; | |
void query() { | |
UCS_SAFE_CALL(uct_md_query(md_h_, &attr_)); | |
} | |
}; | |
class IFace; | |
/// Owning, deep-copying byte buffer holding an opaque address of type T.
///
/// Invariant: `addr_` is null exactly when `addr_len_` is 0.
template<class T>
class Addr {
  size_t addr_len_;
  uint8_t *addr_;

  // Returns a freshly allocated copy of `len` bytes at `src`, or null when
  // there is nothing to copy.
  static uint8_t* clone(const void *src, size_t len) {
    if (src == nullptr || len == 0) {
      return nullptr;
    }
    uint8_t *copy = new uint8_t[len];
    memcpy(copy, src, len);
    return copy;
  }

public:
  /// Creates an empty address.
  Addr() : addr_len_(0), addr_(nullptr) { }

  /// Copies `len` bytes starting at `addr`; a null `addr` yields an empty
  /// address.
  Addr(const void *addr, size_t len)
    : addr_len_(addr != nullptr ? len : 0), addr_(clone(addr, len)) { }

  // Deep copy: the implicitly generated copy would share the buffer and free
  // it twice.
  Addr(const Addr &rhs)
    : addr_len_(rhs.addr_len_), addr_(clone(rhs.addr_, rhs.addr_len_)) { }

  Addr(Addr &&rhs) noexcept
    : addr_len_(rhs.addr_len_), addr_(rhs.addr_) {
    rhs.addr_len_ = 0;
    rhs.addr_ = nullptr;
  }

  Addr& operator=(const Addr &rhs) {
    if (this != &rhs) {
      // Allocate first so *this is unchanged if the allocation throws.
      uint8_t *copy = clone(rhs.addr_, rhs.addr_len_);
      delete[] addr_;
      addr_ = copy;
      addr_len_ = rhs.addr_len_;
    }
    return *this;
  }

  Addr& operator=(Addr &&rhs) noexcept {
    if (this != &rhs) {
      delete[] addr_;
      addr_ = rhs.addr_;
      addr_len_ = rhs.addr_len_;
      rhs.addr_ = nullptr;
      rhs.addr_len_ = 0;
    }
    return *this;
  }

  ~Addr() {
    delete[] addr_;
  }

  /// Length of the address in bytes.
  size_t size() const { return addr_len_; }

  /// Pointer to the stored bytes viewed as T (null for an empty address).
  T* get() { return reinterpret_cast<T*>(addr_); }
  const T* get() const { return reinterpret_cast<const T*>(addr_); }

  /// Formats the address as lower-case hex bytes separated by ':'
  /// (e.g. "00:1f:ab"); returns an empty string for an empty address.
  std::string toString() const {
    static const char kHexDigits[] = "0123456789abcdef";
    std::string text;
    text.reserve(addr_len_ * 3);
    for (size_t i = 0; i < addr_len_; i++) {
      if (i > 0) {
        text += ':';
      }
      text += kHexDigits[addr_[i] >> 4];
      text += kHexDigits[addr_[i] & 0x0f];
    }
    return text;
  }
};
using DeviceAddr = Addr<uct_device_addr_t>; | |
using IfaceAddr = Addr<uct_iface_addr_t>; | |
using EpAddr = Addr<uct_ep_addr_t>; | |
class Endpoint { | |
bool destroyed_; | |
size_t addr_size_; | |
uct_ep_h ep_; | |
std::unique_ptr<EpAddr> addr_; | |
public: | |
Endpoint(const IFace& iface); | |
Endpoint(const IFace& iface, const DeviceAddr& peer_dev_addr, const IfaceAddr& peer_if_addr); | |
~Endpoint(); | |
EpAddr& addr(); | |
const EpAddr& addr() const; | |
uct_ep_h get() { return ep_; } | |
uct_ep_h get() const { return ep_; } | |
void connect(const DeviceAddr& peer_dev_addr, const EpAddr& peer_ep_addr); | |
void destroy() throw(); | |
}; | |
class IFace { | |
uct_iface_h iface_; | |
uct_iface_attr_t attr_; | |
public: | |
using params_t = uct_iface_params_t; | |
IFace() : iface_(nullptr) { } | |
void open(std::shared_ptr<Worker> worker, | |
std::shared_ptr<MemoryDomain> md, params_t& params) { | |
uct_iface_config_t *config; | |
UCS_SAFE_CALL(uct_md_iface_config_read(md->get(), | |
params.mode.device.tl_name, | |
NULL, NULL, &config)); | |
auto status = uct_iface_open(md->get(), worker->get(), | |
¶ms, config, &iface_); | |
uct_config_release(config); | |
UCS_SAFE_CALL(status); | |
UCS_SAFE_CALL(uct_iface_query(iface_, &attr_)); | |
} | |
template<class T> | |
static ucs_status_t _am_handler(void * arg, void *data, size_t length, unsigned flags) { | |
T& functor = *reinterpret_cast<T*>(arg); | |
return functor(data, length, flags); | |
} | |
template<class T> | |
void setAmHandler(T* functor, int id, unsigned flag) { | |
uct_iface_set_am_handler(iface_, id, &_am_handler<T>, | |
reinterpret_cast<void*>(functor), flag); | |
} | |
void setAmHandler(ucs_status_t (*pf)(void *arg, void *data, size_t length,unsigned flags), | |
void *data, int id, unsigned flag) { | |
uct_iface_set_am_handler(iface_, id, pf, data, flag); | |
} | |
void progressEnable(uct_progress_types flags) { | |
uct_iface_progress_enable(iface_, flags); | |
} | |
void progressEnable(unsigned flags) { | |
uct_iface_progress_enable(iface_, flags); | |
} | |
void progressDisable(unsigned flags) { | |
uct_iface_progress_disable(iface_, flags); | |
} | |
void progressDisable(uct_progress_types flags) { | |
uct_iface_progress_disable(iface_, flags); | |
} | |
std::shared_ptr<IfaceAddr> getIfaceAddress() { | |
size_t len = attr_.iface_addr_len; | |
std::vector<char> addr(len); | |
auto* p = reinterpret_cast<uct_iface_addr_t*>(addr.data()); | |
UCS_SAFE_CALL(uct_iface_get_address(iface_, p)); | |
return std::make_shared<IfaceAddr>(p, len); | |
} | |
std::shared_ptr<DeviceAddr> getDeviceAddress() { | |
size_t len = attr_.device_addr_len; | |
std::vector<char> addr(len); | |
auto* p = reinterpret_cast<uct_device_addr_t*>(addr.data()); | |
UCS_SAFE_CALL(uct_iface_get_device_address(iface_, p)); | |
return std::make_shared<DeviceAddr>(p, len); | |
} | |
void close() { uct_iface_close(iface_); } | |
uct_iface_h get() { return iface_; } | |
uct_iface_h get() const { return iface_; } | |
const uct_iface_attr_t &attr() const { return attr_; } | |
int isReachable(const uct_device_addr_t *addr) { | |
return uct_iface_is_reachable(iface_, addr, NULL); | |
} | |
bool isReachable(const DeviceAddr &addr) { | |
return this->isReachable(addr.get()); | |
} | |
bool isReachable(const std::shared_ptr<DeviceAddr> &addr) { | |
return this->isReachable(addr->get()); | |
} | |
}; | |
// Creates an unconnected endpoint on `iface` and stores a copy of its address
// (iface.attr().ep_addr_len bytes).  `inline` because this definition lives in
// a header.
inline Endpoint::Endpoint(const IFace &iface)
  : destroyed_(false), addr_size_(0), ep_(nullptr) {
  UCS_SAFE_CALL(uct_ep_create(iface.get(), &ep_));

  addr_size_ = iface.attr().ep_addr_len;
  std::vector<char> addr(addr_size_);
  UCS_SAFE_CALL(uct_ep_get_address(ep_, reinterpret_cast<uct_ep_addr_t*>(addr.data())));
  addr_.reset(new EpAddr(addr.data(), addr_size_));
}
// Creates an endpoint already connected to the peer interface.  No endpoint
// address is stored, so addr_size_ is set to 0 rather than left
// uninitialised.  `inline` because this definition lives in a header.
inline Endpoint::Endpoint(const IFace& iface, const DeviceAddr& peer_dev_addr, const IfaceAddr& peer_if_addr)
  : destroyed_(false), addr_size_(0), ep_(nullptr) {
  UCS_SAFE_CALL(uct_ep_create_connected(iface.get(),
                                        peer_dev_addr.get(),
                                        peer_if_addr.get(), &ep_));
}
void Endpoint::destroy() throw() { | |
if (!destroyed_) { | |
uct_ep_destroy(ep_); | |
ep_ = nullptr; | |
} | |
destroyed_ = true; | |
} | |
// Destroys the endpoint unless destroy() was already called.  `inline` because
// this definition lives in a header.
inline Endpoint::~Endpoint() {
  this->destroy();
}
// Returns the stored endpoint address.  Only the single-argument constructor
// stores one, hence the assertion.  `inline` because this definition lives in
// a header.
inline const EpAddr& Endpoint::addr() const {
  assert(addr_ != nullptr);
  return *addr_;
}
// Connects this endpoint to the peer endpoint; a failure terminates the
// process through UCS_SAFE_CALL.  `inline` because this definition lives in a
// header.
inline void Endpoint::connect(const DeviceAddr& peer_dev_addr, const EpAddr& peer_ep_addr) {
  UCS_SAFE_CALL(uct_ep_connect_to_ep(ep_, peer_dev_addr.get(), peer_ep_addr.get()));
}
// Returns the stored endpoint address.  Only the single-argument constructor
// stores one, hence the assertion.  `inline` because this definition lives in
// a header.
inline EpAddr& Endpoint::addr() {
  assert(addr_ != nullptr);
  return *addr_;
}
} // namespace UCT | |
#endif // UCXPP_H__ |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment