Skip to content

Instantly share code, notes, and snippets.

@keisukefukuda
Created March 10, 2018 14:02
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save keisukefukuda/69e93074f6fc46a55267efeaca2ca093 to your computer and use it in GitHub Desktop.
#include <cassert>
#include <cstring>
#include <cstdlib>
#include <ctime>
#include <iomanip>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <tuple>
#include <vector>
#include "ucxpp.h"
#include <unistd.h>
#include <mpi.h>
#include "ucxpp.h"
// Returns the host name of the calling machine, or an empty string if
// gethostname() fails.
std::string getHostname() {
  char hostname[256];
  if (gethostname(hostname, sizeof(hostname)) != 0) {
    return std::string();
  }
  // POSIX does not guarantee NUL-termination when the name is truncated,
  // so terminate explicitly before constructing the string.
  hostname[sizeof(hostname) - 1] = '\0';
  return std::string(hostname);
}
// Receive descriptor stored in the iface RX headroom in front of each
// incoming active-message payload.
struct recv_desc_t {
  int is_uct_desc;  // non-zero when the descriptor memory is owned by UCT
};
// Symmetric address exchange with `dest_rank` over MPI_COMM_WORLD.
// Both ranks must call this with addresses of identical size.
// Returns the peer's address wrapped in a UCT::Addr of the same type.
template<typename ADDR_TYPE>
UCT::Addr<ADDR_TYPE> exchangeAddr(int dest_rank, const UCT::Addr<ADDR_TYPE> &addr) {
  std::vector<char> recv_buf(addr.size());
  MPI_Request req;
  MPI_Isend(reinterpret_cast<const void*>(addr.get()),
            static_cast<int>(addr.size()), MPI_BYTE, dest_rank, 0,
            MPI_COMM_WORLD, &req);
  MPI_Recv(reinterpret_cast<void*>(recv_buf.data()),
           static_cast<int>(addr.size()), MPI_BYTE, dest_rank, 0,
           MPI_COMM_WORLD, MPI_STATUS_IGNORE);
  // Bug fix: the send request was never completed; an un-waited request
  // leaks and the send buffer may still be in use. Wait before returning.
  MPI_Wait(&req, MPI_STATUS_IGNORE);
  return UCT::Addr<ADDR_TYPE>(recv_buf.data(), addr.size());
}
// Scans every memory domain / transport pair and returns the
// (memory domain, tl_name, dev_name) triple whose MD component name equals
// `md_query` and whose transport name equals `tl_query`.
// NOTE(review): when several transports match, the LAST match wins (the
// loop does not break), and when nothing matches the tuple is returned —
// and printed — with empty strings; callers must check for that.
std::tuple<std::shared_ptr<UCT::MemoryDomain>, std::string, std::string>
dev_tl_lookup(std::shared_ptr<UCT::MemoryDomainResources> &md_rs,
              const std::string& md_query,
              const std::string &tl_query) {
  std::tuple<std::shared_ptr<UCT::MemoryDomain>, std::string, std::string> result;
#if 0
  // Debug dump of every available transport/device pair.
  std::cout << "=================================" << std::endl;
  std::cout << "List of resources:" << std::endl;
  for (size_t i = 0; i < md_rs->size(); i++) {
    auto md = md_rs->get(i);
    for (size_t j = 0; j < md->getNumTL(); j++) {
      auto tl = md->getTL(j);
      std::cout << "\t" << tl->getTlName() << " / " << tl->getDevName()
                << std::endl;
    }
  }
  std::cout << "=================================" << std::endl;
  std::cout << std::endl;
#endif
  // Look for device/tl that matches the query (dev_name, tl_name)
  for (size_t i = 0; i < md_rs->size(); i++) {
    auto md = md_rs->get(i);
    if (md->componentName() == md_query) {
      for (size_t j = 0; j < md->getNumTL(); j++) {
        auto tl = md->getTL(j);
        if (tl->getTlName() == tl_query) {
          result = std::make_tuple(md, tl->getTlName(), tl->getDevName());
        }
      }
    }
  }
  // Prints the selection (empty when no match was found).
  std::cout << "\t" << std::get<1>(result) << " / " << std::get<2>(result)
            << std::endl;
  return result;
}
// Debug dump: prints the memory-domain attributes and the full
// uct_iface_attr_t capability set of `iface` to stdout.
void showIfaceInfo(UCT::MemoryDomain& md, const std::string& dev_name,
                   const std::string& tl_name, UCT::IFace& iface) {
  // --- memory-domain section ---
  std::cout << "\tMemory Domain: " << md.name() << std::endl;
  std::cout << "\t\t" << "component name: " << md.componentName() << "\n";
  std::cout << "\t\t" << "max alloc: " << md.maxAlloc() << "\n";
  std::cout << "\t\t" << " ("
            << (size_t)(md.maxAlloc()/1024./1024/1024) << "[GiB])" << "\n";
  std::cout << "\t\t" << "max reg: " << md.maxReg() << "\n";
  std::cout << "\t\t" << " ("
            << (size_t)(md.maxReg()/1024./1024/1024) << "[GiB])\n";
  std::cout << "\t\t" << "rkey packed size: " << md.rKeyPackedSize() << "\n";
  // Print devices and transport layers
  std::cout << "\t\t" << "Devices:" << std::endl;
  for (size_t j = 0; j < md.getNumTL(); j++) {
    auto tl = md.getTL(j);
    std::cout << "\t\t\t" << tl->getDevName() << "/" << tl->getTlName()
              << std::endl;
  }
  // --- interface attribute section ---
  auto &attr = iface.attr();
  auto lat = attr.latency;  // linear model printed below: growth * x + overhead
  std::cout << "iface_config of " << tl_name << "/" << dev_name << std::endl;
  std::cout << "\tdevice_addr_len = " << attr.device_addr_len << std::endl;
  std::cout << "\tiface_addr_len = " << attr.iface_addr_len << std::endl;
  std::cout << "\tep_addr_len = " << attr.ep_addr_len << std::endl;
  std::cout << "\tmax_conn_priv = " << attr.max_conn_priv << std::endl;
  std::cout << "\toverhead = " << attr.overhead << std::endl;
  std::cout << "\tbandwidth = " << attr.bandwidth << std::endl;
  std::cout << "\tlatency = " << lat.growth << " * x + "
            << lat.overhead << std::endl;
  std::cout << "\tpriority = " << (int)attr.priority << std::endl;
  std::cout << std::endl;
  // --- per-operation capability limits ---
  auto &put = attr.cap.put;
  std::cout << "\tattr.cap.put" << std::endl;
  std::cout << "\t\tmax_short = " << put.max_short << std::endl;
  std::cout << "\t\tmax_bcopy = " << put.max_bcopy << std::endl;
  std::cout << "\t\tmin_zcopy = " << put.min_zcopy << std::endl;
  std::cout << "\t\tmax_zcopy = " << put.max_zcopy << " (="
            << (put.max_zcopy/1024/1024) << " [MiB])" << std::endl;
  std::cout << "\t\topt_zcopy_align = " << put.opt_zcopy_align << std::endl;
  std::cout << "\t\talign_mtu = " << put.align_mtu << std::endl;
  std::cout << "\t\tmax_iov = " << put.max_iov << std::endl;
  auto &get = attr.cap.get;
  std::cout << "\tattr.cap.get" << std::endl;
  std::cout << "\t\tmax_bcopy = " << get.max_bcopy << std::endl;
  std::cout << "\t\tmin_zcopy = " << get.min_zcopy << std::endl;
  std::cout << "\t\tmax_zcopy = " << get.max_zcopy << " (="
            << (get.max_zcopy/1024/1024) << " [MiB])" << std::endl;
  std::cout << "\t\topt_zcopy_align = " << get.opt_zcopy_align << std::endl;
  std::cout << "\t\talign_mtu = " << get.align_mtu << std::endl;
  std::cout << "\t\tmax_iov = " << get.max_iov << std::endl;
  auto &am = attr.cap.am;
  std::cout << "\tattr.cap.am" << std::endl;
  std::cout << "\t\tmax_short = " << am.max_short << std::endl;
  std::cout << "\t\tmax_bcopy = " << am.max_bcopy << std::endl;
  std::cout << "\t\tmin_zcopy = " << am.min_zcopy << std::endl;
  std::cout << "\t\tmax_zcopy = " << am.max_zcopy << " (="
            << (am.max_zcopy/1024/1024) << " [MiB])" << std::endl;
  std::cout << "\t\topt_zcopy_align = " << am.opt_zcopy_align << std::endl;
  std::cout << "\t\talign_mtu = " << am.align_mtu << std::endl;
  std::cout << "\t\tmax_iov = " << am.max_iov << std::endl;
  // --- tag-matching capability limits ---
  auto &recv = attr.cap.tag.recv;
  std::cout << "\tattr.cap.tag.recv" << std::endl;
  std::cout << "\t\tmin_recv = " << recv.min_recv << std::endl;
  std::cout << "\t\tmax_zcopy = " << recv.max_zcopy << std::endl;
  std::cout << "\t\tmax_iov = " << recv.max_iov << std::endl;
  std::cout << "\t\tmax_outstanding = " << recv.max_outstanding << std::endl;
  auto &eager = attr.cap.tag.eager;
  std::cout << "\tattr.cap.tag.eager" << std::endl;
  std::cout << "\t\tmax_short = " << eager.max_short << std::endl;
  std::cout << "\t\tmax_bcopy = " << eager.max_bcopy << std::endl;
  std::cout << "\t\tmax_zcopy = " << eager.max_zcopy << std::endl;
  std::cout << "\t\tmax_iov = " << eager.max_iov << std::endl;
  auto &rndv = attr.cap.tag.rndv;
  std::cout << "\tattr.cap.tag.rndv" << std::endl;
  std::cout << "\t\tmax_zcopy = " << rndv.max_zcopy << std::endl;
  std::cout << "\t\tmax_hdr = " << rndv.max_hdr << std::endl;
  std::cout << "\t\tmax_iov = " << rndv.max_iov << std::endl;
// Helper macro: prints Yes/No for one UCT_IFACE_FLAG_* capability bit.
// Relies on the local variable `flags` defined just below.
#define SHOW_CAP_FLAG(name) do { \
  std::cout << "\t\t" << #name << ": " \
            << ((flags & UCT_IFACE_FLAG_##name) ? "Yes" : "No") \
            << std::endl; \
} while(0)
  // check flag
  auto flags = attr.cap.flags;
  std::cout << "\tflags:" << std::endl;
  SHOW_CAP_FLAG(AM_SHORT);
  SHOW_CAP_FLAG(AM_BCOPY);
  SHOW_CAP_FLAG(AM_ZCOPY);
  SHOW_CAP_FLAG(PENDING);
  SHOW_CAP_FLAG(PUT_SHORT);
  SHOW_CAP_FLAG(PUT_BCOPY);
  SHOW_CAP_FLAG(PUT_ZCOPY);
  SHOW_CAP_FLAG(GET_SHORT);
  SHOW_CAP_FLAG(GET_BCOPY);
  SHOW_CAP_FLAG(GET_ZCOPY);
  SHOW_CAP_FLAG(ATOMIC_ADD32);
  SHOW_CAP_FLAG(ATOMIC_ADD64);
  SHOW_CAP_FLAG(ATOMIC_FADD32);
  SHOW_CAP_FLAG(ATOMIC_FADD64);
  SHOW_CAP_FLAG(ATOMIC_SWAP32);
  SHOW_CAP_FLAG(ATOMIC_SWAP64);
  SHOW_CAP_FLAG(ATOMIC_CSWAP32);
  SHOW_CAP_FLAG(ATOMIC_CPU);
  SHOW_CAP_FLAG(ATOMIC_DEVICE);
}
// Set by am_handler when it keeps a UCT-owned descriptor; polled by the
// receiver loop in Communicate().
void* desc_holder = nullptr;

// Active-message callback. When UCT passes UCT_CB_PARAM_FLAG_DESC the data
// sits inside the iface RX headroom: we stamp the recv_desc_t in front of
// the payload, publish it through desc_holder and keep the descriptor by
// returning UCS_INPROGRESS. Otherwise the data is only valid during the
// callback, so it is consumed (printed) immediately.
ucs_status_t am_handler(void * /* arg */, void *data, size_t length, unsigned flags) {
  if (flags & UCT_CB_PARAM_FLAG_DESC) {
    // The descriptor lives immediately before the payload (rx_headroom).
    recv_desc_t *rdesc = (recv_desc_t*)data - 1;
    // Hold descriptor to release later and return UCS_INPROGRESS
    rdesc->is_uct_desc = 1;
    desc_holder = rdesc;
    return UCS_INPROGRESS;
  }
  // Bug fix: the original allocated a recv_desc_t here and leaked it
  // (never stored in desc_holder, never freed). The data is handled
  // inline, so no descriptor is needed on this path.
  std::string s(reinterpret_cast<char*>(data), length);
  std::cout << "\tmsg = '" << s << "'" << std::endl;
  return UCS_OK;
}
// Sends one random NUL-terminated string of `msg_size` bytes (including the
// terminator) via uct_ep_am_short. The first up-to-8 bytes travel in the
// 64-bit AM header, the remainder as payload.
void send_am(UCT::Endpoint &ep, uint8_t id, size_t msg_size) {
  // Guard: the original computed msg_size - 1 with size_t, which
  // underflows for msg_size == 0.
  assert(msg_size >= 1);
  std::vector<char> msg(msg_size);
  for (size_t i = 0; i + 1 < msg_size; i++) {
    msg[i] = rand() % 26 + 'A';
  }
  msg[msg_size - 1] = '\0';
  std::cout << "sent string = " << msg.data() << std::endl;
  // Bug fix: *(uint64_t*)msg read 8 bytes unconditionally (an over-read
  // when msg_size < 8, plus unaligned/strict-aliasing UB). Copy only the
  // bytes we have; the rest of the header stays zero.
  uint64_t header = 0;
  size_t hdr_bytes = (msg_size < sizeof(uint64_t)) ? msg_size : sizeof(uint64_t);
  memcpy(&header, msg.data(), hdr_bytes);
  // Payload is whatever did not fit in the header.
  size_t rest = (msg_size <= sizeof(uint64_t)) ? 0 : msg_size - sizeof(uint64_t);
  std::vector<char> payload(rest);
  if (rest > 0) {
    memcpy(payload.data(), msg.data() + sizeof(uint64_t), rest);
  }
  UCS_SAFE_CALL(uct_ep_am_short(ep.get(), id,
                                header,
                                rest > 0 ? payload.data() : nullptr,
                                rest));
}
// Runs one active-message ping from rank 0 to rank 1 over the chosen
// transport/device: opens an iface, exchanges addresses via MPI, builds an
// endpoint (iface-connected or ep-connected depending on capabilities),
// performs the send/receive, then tears everything down.
// NOTE(review): this uses MPI_Barrier on MPI_COMM_WORLD, so EVERY rank in
// the job must enter this function or the barriers deadlock — confirm the
// caller restricts the world size accordingly.
void Communicate(int rank,
                 std::shared_ptr<UCT::Worker> worker,
                 std::shared_ptr<UCT::MemoryDomain> md,
                 std::string& tl_name,
                 std::string& dev_name) {
  UCT::IFace iface;
  UCT::IFace::params_t params;
  params.open_mode = UCT_IFACE_OPEN_MODE_DEVICE;
  params.mode.device.tl_name = tl_name.c_str();
  params.mode.device.dev_name = dev_name.c_str();
  params.stats_root = NULL;
  // Headroom in front of each RX buffer where am_handler's recv_desc_t lives.
  params.rx_headroom = sizeof(recv_desc_t);
  UCS_CPU_ZERO(&params.cpu_mask);
  iface.open(worker, md, params);
  iface.progressEnable(UCT_PROGRESS_SEND | UCT_PROGRESS_RECV);
#if 0
  if (rank == 0) {
    showIfaceInfo(*md, dev_name, tl_name, iface);
  }
#endif
  // === get device address
  auto own_dev_addr = iface.getDeviceAddress();
  // Two-rank ping: the peer is always rank 1 - rank.
  auto peer_dev_addr = exchangeAddr(1 - rank, *own_dev_addr);
  MPI_Barrier(MPI_COMM_WORLD);
  if (iface.isReachable(peer_dev_addr) == 0) {
    std::cout << "Rank " << rank
              << " ERROR: peer device address is not reachable!" << std::endl;
    exit(-1);
  }
  std::unique_ptr<UCT::Endpoint> ep;
  if (iface.attr().cap.flags & UCT_IFACE_FLAG_CONNECT_TO_IFACE) {
    // === get interface address
    auto if_addr = iface.getIfaceAddress();
    UCT::IfaceAddr peer_iface_addr = exchangeAddr(1 - rank, *if_addr);
    ep.reset(new UCT::Endpoint(iface, peer_dev_addr, peer_iface_addr));
  } else if (iface.attr().cap.flags & UCT_IFACE_FLAG_CONNECT_TO_EP) {
    // === if connect_to_ep
    ep.reset(new UCT::Endpoint(iface));
    UCT::EpAddr peer_ep_addr = exchangeAddr(1 - rank, ep->addr());
    ep->connect(peer_dev_addr, peer_ep_addr);
  } else {
    std::cerr << "Unsupported" << std::endl;
    goto CLOSE_IFACE;
  }
  // Communicate
  {
    // Try am_short
    // Send AM from rank 0 to 1
    size_t msg_size = 8; // iface.attr().cap.am.max_short;
    const uint8_t id = 0;
    // The sender registers an AM callback function.
    iface.setAmHandler(&am_handler, NULL, id, UCT_CB_FLAG_SYNC);
    MPI_Barrier(MPI_COMM_WORLD);
    if (rank == 0) { // sender
      send_am(*ep, id, msg_size);
    } else { // receiver
      recv_desc_t *rdesc= nullptr;
      // Spin the worker until am_handler publishes a descriptor.
      // NOTE(review): desc_holder is only set on the UCT_CB_PARAM_FLAG_DESC
      // path of am_handler; if the transport delivers without that flag,
      // this loop never exits — confirm for the transports used here.
      while (!desc_holder) {
        unsigned ret = uct_worker_progress(worker->get());
        (void)ret;
        //std::cout << "uct_worker_progress() = " << ret << std::endl;
      }
      rdesc = reinterpret_cast<recv_desc_t*>(desc_holder);
      // Payload starts right after the descriptor (rx_headroom layout).
      std::cout << "Received: " << (char*)(rdesc + 1) << std::endl;
      if (rdesc->is_uct_desc) {
        // UCT owns this descriptor (we returned UCS_INPROGRESS): hand it back.
        uct_iface_release_desc((recv_desc_t*)rdesc);
      } else {
        delete rdesc;
      }
    }
  }
  MPI_Barrier(MPI_COMM_WORLD);
  std::cout << "Rank " << rank << " Done." << std::endl;
  ep->destroy(); // ep must be destroyed before iface.
CLOSE_IFACE:
  iface.close();
}
int main(int argc, char **argv) {
srand(time(NULL));
MPI_Init(&argc, &argv);
int size;
int rank;
MPI_Comm_size(MPI_COMM_WORLD, &size);
MPI_Comm_rank(MPI_COMM_WORLD, &rank);
UCS::AsyncContext async(UCS_ASYNC_MODE_THREAD);
auto worker = std::make_shared<UCT::Worker>(async, UCS_THREAD_MODE_SINGLE);
auto md_rs = std::make_shared<UCT::MemoryDomainResources>();
std::shared_ptr<UCT::MemoryDomain> md;
std::string tl_name, dev_name;
std::tie(md, tl_name, dev_name) = dev_tl_lookup(md_rs, "ib", "rc");
if (tl_name.size() == 0 || dev_name.size() == 0) {
std::cout << "Couldn't get device or transport layer name" << std::endl;
exit(-1);
}
if (size > 1) {
if (rank <= 2) {
Communicate(rank, worker, md, tl_name, dev_name);
}
}
MPI_Finalize();
return 0;
}
#ifndef UCXPP_H__
#define UCXPP_H__
#include <uct/api/uct.h>
#define IDENT_TO_STR_(ident) case ident: return #ident
// Human-readable name for a ucs_status_t value.
// Bug fix: this function is defined in a header, so it must be `inline` —
// a non-inline definition violates the ODR as soon as the header is
// included from more than one translation unit. Also added UCS_OK, which
// was missing from the table.
inline const char * enumToStr(ucs_status_t s) {
  switch(s) {
    IDENT_TO_STR_(UCS_OK);
    IDENT_TO_STR_(UCS_INPROGRESS);
    IDENT_TO_STR_(UCS_ERR_NO_MESSAGE);
    IDENT_TO_STR_(UCS_ERR_NO_RESOURCE);
    IDENT_TO_STR_(UCS_ERR_IO_ERROR);
    IDENT_TO_STR_(UCS_ERR_NO_MEMORY);
    IDENT_TO_STR_(UCS_ERR_INVALID_PARAM);
    IDENT_TO_STR_(UCS_ERR_UNREACHABLE);
    IDENT_TO_STR_(UCS_ERR_INVALID_ADDR);
    IDENT_TO_STR_(UCS_ERR_NOT_IMPLEMENTED);
    IDENT_TO_STR_(UCS_ERR_MESSAGE_TRUNCATED);
    IDENT_TO_STR_(UCS_ERR_NO_PROGRESS);
    IDENT_TO_STR_(UCS_ERR_BUFFER_TOO_SMALL);
    IDENT_TO_STR_(UCS_ERR_NO_ELEM);
    IDENT_TO_STR_(UCS_ERR_SOME_CONNECTS_FAILED);
    IDENT_TO_STR_(UCS_ERR_NO_DEVICE);
    IDENT_TO_STR_(UCS_ERR_BUSY);
    IDENT_TO_STR_(UCS_ERR_CANCELED);
    IDENT_TO_STR_(UCS_ERR_SHMEM_SEGMENT);
    IDENT_TO_STR_(UCS_ERR_ALREADY_EXISTS);
    IDENT_TO_STR_(UCS_ERR_OUT_OF_RANGE);
    IDENT_TO_STR_(UCS_ERR_TIMED_OUT);
    IDENT_TO_STR_(UCS_ERR_EXCEEDS_LIMIT);
    IDENT_TO_STR_(UCS_ERR_UNSUPPORTED);
    IDENT_TO_STR_(UCS_ERR_FIRST_LINK_FAILURE);
    IDENT_TO_STR_(UCS_ERR_LAST_LINK_FAILURE);
    IDENT_TO_STR_(UCS_ERR_FIRST_ENDPOINT_FAILURE);
    IDENT_TO_STR_(UCS_ERR_LAST_ENDPOINT_FAILURE);
    IDENT_TO_STR_(UCS_ERR_ENDPOINT_TIMEOUT);
    default:
      return "Unknown ucs_status_t";
  }
}
// Evaluates `expr` (which must yield a ucs_status_t); on any status other
// than UCS_OK, prints file/line, the failing expression and the status
// name, then terminates the process.
#define UCS_SAFE_CALL(expr) do { \
    ucs_status_t st = (expr); \
    if (st != UCS_OK) { \
      std::cerr << "Error: " \
                << __FILE__ << ":" \
                << __LINE__ << " " \
                << #expr << " failed with " \
                << enumToStr(st) << std::endl; \
      exit(-1); \
    } \
  } while(0)
namespace UCS {
// RAII owner of a ucs_async_context_t (event/progress context).
class AsyncContext {
  ucs_async_context_t *ctx_;
public:
  explicit AsyncContext(ucs_async_mode_t mode) {
    UCS_SAFE_CALL(ucs_async_context_create(mode, &ctx_));
  }
  ~AsyncContext() {
    ucs_async_context_destroy(ctx_);
  }
  // Bug fix: the class owns ctx_ but was copyable — a copy would destroy
  // the same context twice.
  AsyncContext(const AsyncContext&) = delete;
  AsyncContext& operator=(const AsyncContext&) = delete;
  inline const ucs_async_context_t *get() const {
    return ctx_;
  }
};
}
namespace UCT {
class TransportLayer;
class TransportLayerResources;
class MemoryDomain;
class MemoryDomainResources;
// RAII owner of a uct_worker_h (progress engine bound to an async context).
class Worker {
  uct_worker_h handle_;
public:
  // const_cast: uct_worker_create needs a mutable context pointer but
  // AsyncContext::get() only exposes a const one.
  Worker(const UCS::AsyncContext &async, ucs_thread_mode_t tm) {
    UCS_SAFE_CALL(uct_worker_create(const_cast<ucs_async_context*>(async.get()),
                                    tm, &handle_));
  }
  ~Worker() {
    uct_worker_destroy(handle_);
  }
  // Bug fix: the class owns handle_ but was copyable — a copy would
  // destroy the same worker twice.
  Worker(const Worker&) = delete;
  Worker& operator=(const Worker&) = delete;
  uct_worker_h get() { return handle_; }
};
// RAII owner of the uct_tl_resource_desc_t list exposed by a memory domain.
class TransportLayerResources {
  uct_tl_resource_desc_t *tl_rc_;
  unsigned num_;
public:
  explicit TransportLayerResources(uct_md_h md_h) {
    UCS_SAFE_CALL(uct_md_query_tl_resources(md_h, &tl_rc_, &num_));
  }
  ~TransportLayerResources() {
    uct_release_tl_resource_list(tl_rc_);
  }
  // Bug fix: the class owns tl_rc_ but was copyable — a copy would release
  // the resource list twice.
  TransportLayerResources(const TransportLayerResources&) = delete;
  TransportLayerResources& operator=(const TransportLayerResources&) = delete;
  // Bounds-checked (assert only) access to entry i.
  uct_tl_resource_desc_t* getRaw(size_t i) {
    assert(i < num_);
    return &tl_rc_[i];
  }
  size_t size() const {
    return num_;
  }
};
// Non-owning view of one transport entry inside a TransportLayerResources
// list; keeps the list alive through a shared_ptr.
class TransportLayer {
  std::shared_ptr<TransportLayerResources> tl_rc_;
  size_t idx_;
public:
  TransportLayer(std::shared_ptr<TransportLayerResources> p, size_t idx)
      : tl_rc_(p), idx_(idx) {
    assert(idx_ < tl_rc_->size());
  }
  // Transport name (e.g. "rc").
  const char *getTlName() const { return tl_rc_->getRaw(idx_)->tl_name; }
  // Device name (e.g. "mlx5_0:1").
  const char *getDevName() const { return tl_rc_->getRaw(idx_)->dev_name; }
};
// Owns the uct_md_resource_desc_t list and eagerly constructs a
// MemoryDomain wrapper for every entry. Non-copyable.
class MemoryDomainResources {
  uct_md_resource_desc_t *md_rc_;
  unsigned num_md_rc_;
  std::vector<std::shared_ptr<MemoryDomain>> md_;
public:
  MemoryDomainResources()
      : md_rc_(nullptr), num_md_rc_(0), md_() {
    UCS_SAFE_CALL(uct_query_md_resources(&md_rc_, &num_md_rc_));
    assert(md_rc_ != nullptr);
    // Each MemoryDomain stores a raw back-pointer to *this.
    // NOTE(review): a shared_ptr<MemoryDomain> handed out via get() must
    // not outlive this object or that back-pointer dangles — verify callers.
    for (unsigned i = 0; i < num_md_rc_; i++) {
      md_.push_back(std::make_shared<MemoryDomain>(this, i));
    }
  }
  ~MemoryDomainResources() {
    uct_release_md_resource_list(md_rc_);
  }
  std::shared_ptr<MemoryDomain> get(size_t idx) {
    return md_[idx];
  }
  // Raw descriptor for entry idx (no bounds check).
  uct_md_resource_desc_t* getRaw(size_t idx) {
    return &md_rc_[idx];
  }
  size_t size() const {
    return md_.size();
  }
private:
  // Non-copyable: owns the resource list.
  MemoryDomainResources(const MemoryDomainResources &rhs) = delete;
  MemoryDomainResources& operator=(const MemoryDomainResources &rhs) = delete;
};
// Opens and owns one UCT memory domain (uct_md_h) together with its config
// and the transport resources it exposes. Constructed by
// MemoryDomainResources; non-copyable.
class MemoryDomain {
  MemoryDomainResources *md_rc_;  // non-owning back-pointer to the factory
  size_t idx_;                    // index of this MD within md_rc_
  uct_md_h md_h_;
  uct_md_config_t *md_conf_;
  uct_md_attr_t attr_;            // cached by query() during construction
  std::shared_ptr<TransportLayerResources> tl_rc_;
  std::vector<std::shared_ptr<TransportLayer>> tl_;
public:
  MemoryDomain(MemoryDomainResources* md_rc, size_t idx) :
      md_rc_(md_rc), idx_(idx) {
    // Create a memory domain
    UCS_SAFE_CALL(uct_md_config_read(md_rc_->getRaw(idx_)->md_name,
                                     NULL, NULL,
                                     &md_conf_));
    UCS_SAFE_CALL(uct_md_open(md_rc_->getRaw(idx_)->md_name, md_conf_, &md_h_));
    // Enumerate the transports reachable through this MD.
    tl_rc_ = std::make_shared<TransportLayerResources>(md_h_);
    for (size_t i = 0; i < tl_rc_->size(); i++) {
      tl_.push_back(std::make_shared<TransportLayer>(tl_rc_, i));
    }
    query();
  }
  ~MemoryDomain() {
    uct_config_release(md_conf_);
    uct_md_close(md_h_);
  }
  uct_md_h &get() {
    return md_h_;
  }
  // MD name as reported by uct_query_md_resources.
  const char *name() const {
    return md_rc_->getRaw(idx_)->md_name;
  }
  // Returns a vector of TransportLayer
  std::shared_ptr<TransportLayer> getTL(size_t idx) {
    return tl_[idx];
  }
  size_t getNumTL() const {
    return tl_.size();
  }
  std::string componentName() const {
    return std::string(attr_.component_name);
  }
  size_t maxAlloc() const { return attr_.cap.max_alloc; }
  size_t maxReg() const {return attr_.cap.max_reg; }
  // Typed view of the UCT_MD_FLAG_* capability bits.
  enum class Flag {
    ALLOC = UCT_MD_FLAG_ALLOC,
    REG = UCT_MD_FLAG_REG,
    NEED_MEMH = UCT_MD_FLAG_NEED_MEMH,
    NEED_RKEY = UCT_MD_FLAG_NEED_RKEY,
    ADVISE = UCT_MD_FLAG_ADVISE,
    FIXED = UCT_MD_FLAG_FIXED,
    RKEY_PTR = UCT_MD_FLAG_RKEY_PTR,
    SOCKADDR = UCT_MD_FLAG_SOCKADDR,
  };
  // NOTE(review): attr_.cap.flags is a bitmask, so the returned Flag may be
  // a combination of enumerators, not a single named value.
  Flag flag() const { return static_cast<Flag>(attr_.cap.flags); }
  // TODO
  // memoryType() const { ... }
  size_t rKeyPackedSize() const { return attr_.rkey_packed_size; }
  // TODO
  // cpuset_t localCpus() const { return attr_.local_cpus; }
private:
  MemoryDomain(const MemoryDomain& rhs) = delete;
  MemoryDomain& operator=(const MemoryDomain& rhs) = delete;
  // Caches uct_md_attr_t into attr_.
  void query() {
    UCS_SAFE_CALL(uct_md_query(md_h_, &attr_));
  }
};
class IFace;
// Owning byte buffer holding an opaque UCT address, viewed as type T.
// Bug fixes vs the original:
//  * Rule of three/five: the implicit copy ctor/assignment shared addr_
//    and double-deleted it; copies now deep-copy and moves steal.
//  * toString() wrote into a fixed char[100], overflowing for addresses
//    longer than ~33 bytes; it now builds the string dynamically.
template<class T>
class Addr {
  size_t addr_len_;
  uint8_t *addr_;
public:
  Addr() : addr_len_(0), addr_(nullptr) { }
  // Deep-copies `len` bytes from `addr`.
  Addr(const void *addr, size_t len) : addr_len_(len), addr_(new uint8_t[len]) {
    memcpy(addr_, addr, len);
  }
  Addr(const Addr &rhs) : addr_len_(rhs.addr_len_), addr_(nullptr) {
    if (rhs.addr_ != nullptr) {
      addr_ = new uint8_t[addr_len_];
      memcpy(addr_, rhs.addr_, addr_len_);
    }
  }
  Addr(Addr &&rhs) noexcept : addr_len_(rhs.addr_len_), addr_(rhs.addr_) {
    rhs.addr_len_ = 0;
    rhs.addr_ = nullptr;
  }
  // Unified copy/move assignment via copy-and-swap (rhs taken by value).
  Addr& operator=(Addr rhs) noexcept {
    uint8_t *tmp_p = addr_; addr_ = rhs.addr_; rhs.addr_ = tmp_p;
    size_t tmp_n = addr_len_; addr_len_ = rhs.addr_len_; rhs.addr_len_ = tmp_n;
    return *this;
  }
  ~Addr() {
    delete[] addr_;
  }
  size_t size() const { return addr_len_; }
  T* get() { return reinterpret_cast<T*>(addr_); }
  const T* get() const { return reinterpret_cast<const T*>(addr_); }
  // Hex dump of the address bytes, colon-separated (e.g. "0a:1b:2c").
  std::string toString() const {
    std::stringstream ss;
    for (size_t i = 0; i < addr_len_; i++) {
      if (i > 0) { ss << ":"; }
      ss << std::hex << std::setw(2) << std::setfill('0') << (int)addr_[i];
    }
    return ss.str();
  }
};
// Strongly-typed wrappers for the three kinds of UCT address blobs.
using DeviceAddr = Addr<uct_device_addr_t>;
using IfaceAddr = Addr<uct_iface_addr_t>;
using EpAddr = Addr<uct_ep_addr_t>;
// Wraps a uct_ep_h. Built either already connected to a peer iface (second
// ctor) or unconnected, to be wired up later via connect() for
// connect-to-ep transports (first ctor).
class Endpoint {
  bool destroyed_;                // guards against double uct_ep_destroy
  size_t addr_size_;              // ep address length from the iface attrs
  uct_ep_h ep_;
  std::unique_ptr<EpAddr> addr_;  // NOTE(review): only populated by the
                                  // first (unconnected) constructor
public:
  Endpoint(const IFace& iface);
  Endpoint(const IFace& iface, const DeviceAddr& peer_dev_addr, const IfaceAddr& peer_if_addr);
  ~Endpoint();
  EpAddr& addr();
  const EpAddr& addr() const;
  uct_ep_h get() { return ep_; }
  uct_ep_h get() const { return ep_; }
  void connect(const DeviceAddr& peer_dev_addr, const EpAddr& peer_ep_addr);
  // Idempotent; throw() spec kept to match the out-of-line definition.
  void destroy() throw();
};
// Wrapper around uct_iface_h: open/close, active-message handler
// registration, progress control and address queries.
class IFace {
  uct_iface_h iface_;
  uct_iface_attr_t attr_;  // cached by open() via uct_iface_query
public:
  using params_t = uct_iface_params_t;
  IFace() : iface_(nullptr) { }
  // Opens the interface described by `params` (device open mode) on the
  // given worker/md, then caches its attributes into attr_.
  void open(std::shared_ptr<Worker> worker,
            std::shared_ptr<MemoryDomain> md, params_t& params) {
    uct_iface_config_t *config;
    UCS_SAFE_CALL(uct_md_iface_config_read(md->get(),
                                           params.mode.device.tl_name,
                                           NULL, NULL, &config));
    auto status = uct_iface_open(md->get(), worker->get(),
                                 &params, config, &iface_);
    // Release the config before checking status so it is never leaked.
    uct_config_release(config);
    UCS_SAFE_CALL(status);
    UCS_SAFE_CALL(uct_iface_query(iface_, &attr_));
  }
  // Trampoline forwarding the C callback to a C++ functor passed via arg.
  template<class T>
  static ucs_status_t _am_handler(void * arg, void *data, size_t length, unsigned flags) {
    T& functor = *reinterpret_cast<T*>(arg);
    return functor(data, length, flags);
  }
  // Registers a functor as the active-message handler for `id`.
  template<class T>
  void setAmHandler(T* functor, int id, unsigned flag) {
    uct_iface_set_am_handler(iface_, id, &_am_handler<T>,
                             reinterpret_cast<void*>(functor), flag);
  }
  // Registers a plain C callback; `data` is passed through as its arg.
  void setAmHandler(ucs_status_t (*pf)(void *arg, void *data, size_t length,unsigned flags),
                    void *data, int id, unsigned flag) {
    uct_iface_set_am_handler(iface_, id, pf, data, flag);
  }
  void progressEnable(uct_progress_types flags) {
    uct_iface_progress_enable(iface_, flags);
  }
  void progressEnable(unsigned flags) {
    uct_iface_progress_enable(iface_, flags);
  }
  void progressDisable(unsigned flags) {
    uct_iface_progress_disable(iface_, flags);
  }
  void progressDisable(uct_progress_types flags) {
    uct_iface_progress_disable(iface_, flags);
  }
  // Queries the interface address (iface_addr_len bytes) into an IfaceAddr.
  std::shared_ptr<IfaceAddr> getIfaceAddress() {
    size_t len = attr_.iface_addr_len;
    std::vector<char> addr(len);
    auto* p = reinterpret_cast<uct_iface_addr_t*>(addr.data());
    UCS_SAFE_CALL(uct_iface_get_address(iface_, p));
    return std::make_shared<IfaceAddr>(p, len);
  }
  // Queries the device address (device_addr_len bytes) into a DeviceAddr.
  std::shared_ptr<DeviceAddr> getDeviceAddress() {
    size_t len = attr_.device_addr_len;
    std::vector<char> addr(len);
    auto* p = reinterpret_cast<uct_device_addr_t*>(addr.data());
    UCS_SAFE_CALL(uct_iface_get_device_address(iface_, p));
    return std::make_shared<DeviceAddr>(p, len);
  }
  // NOTE(review): unchecked — calling close() before open() passes nullptr
  // to uct_iface_close; verify call order at the call sites.
  void close() { uct_iface_close(iface_); }
  uct_iface_h get() { return iface_; }
  uct_iface_h get() const { return iface_; }
  const uct_iface_attr_t &attr() const { return attr_; }
  // Raw reachability check; non-zero means reachable.
  int isReachable(const uct_device_addr_t *addr) {
    return uct_iface_is_reachable(iface_, addr, NULL);
  }
  bool isReachable(const DeviceAddr &addr) {
    return this->isReachable(addr.get());
  }
  bool isReachable(const std::shared_ptr<DeviceAddr> &addr) {
    return this->isReachable(addr->get());
  }
};
// Creates an unconnected endpoint (connect-to-ep mode) and caches its EP
// address so it can be exchanged with the peer before connect().
Endpoint::Endpoint(const IFace &iface) : destroyed_(false) {
  UCS_SAFE_CALL(uct_ep_create(iface.get(), &ep_));
  addr_size_ = iface.attr().ep_addr_len;
  std::vector<char> addr(addr_size_);
  UCS_SAFE_CALL(uct_ep_get_address(ep_, reinterpret_cast<uct_ep_addr_t*>(addr.data())));
  addr_.reset(new EpAddr(addr.data(), addr_size_));
}
// Creates an endpoint already connected to the peer's iface
// (connect-to-iface mode). No EP address is cached on this path.
Endpoint::Endpoint(const IFace& iface, const DeviceAddr& peer_dev_addr, const IfaceAddr& peer_if_addr)
    : destroyed_(false) {
  UCS_SAFE_CALL(uct_ep_create_connected(iface.get(),
                                        peer_dev_addr.get(),
                                        peer_if_addr.get(), &ep_));
  // Bug fix: addr_size_ was left uninitialized by this constructor.
  // There is no exchanged EP address in connect-to-iface mode, so record 0
  // (addr_ stays null; addr() must not be called on such endpoints).
  addr_size_ = 0;
}
// Releases the underlying uct_ep_h exactly once; later calls are no-ops.
void Endpoint::destroy() throw() {
  if (destroyed_) {
    return;
  }
  destroyed_ = true;
  uct_ep_destroy(ep_);
  ep_ = nullptr;
}
// Destructor delegates to destroy(), which is safe to call repeatedly.
Endpoint::~Endpoint() {
  destroy();
}
// Const accessor for the cached EP address.
// NOTE(review): addr_ is only set by the unconnected constructor; calling
// this on an endpoint built with the connected constructor dereferences a
// null unique_ptr — verify callers.
const EpAddr& Endpoint::addr() const {
  return *addr_;
}
// Completes the connect-to-ep handshake using the peer's device and EP
// addresses obtained out-of-band (here: via MPI exchange).
void Endpoint::connect(const DeviceAddr& peer_dev_addr, const EpAddr& peer_ep_addr) {
  UCS_SAFE_CALL(uct_ep_connect_to_ep(ep_, peer_dev_addr.get(), peer_ep_addr.get()));
}
// Mutable accessor for the cached EP address.
// NOTE(review): like the const overload, this is only valid for endpoints
// created by the unconnected constructor (addr_ is null otherwise).
EpAddr& Endpoint::addr() {
  return *addr_;
}
} // namespace UCT
#endif // UCXPP_H__
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment