Created
February 26, 2018 14:51
-
-
Save keisukefukuda/cf55bf8aaf9f343f529c25ff46487ce0 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Written by Keisuke Fukuda | |
// All rights reserved. | |
// Released under the MIT License. | |
#include <cassert>
#include <cstdlib>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <uct/api/uct.h>
#define IDENT_TO_STR_(ident) case ident: return #ident | |
namespace { | |
const char * enumToStr(ucs_status_t s) { | |
switch(s) { | |
IDENT_TO_STR_(UCS_INPROGRESS); | |
IDENT_TO_STR_(UCS_ERR_NO_MESSAGE); | |
IDENT_TO_STR_(UCS_ERR_NO_RESOURCE); | |
IDENT_TO_STR_(UCS_ERR_IO_ERROR); | |
IDENT_TO_STR_(UCS_ERR_NO_MEMORY); | |
IDENT_TO_STR_(UCS_ERR_INVALID_PARAM); | |
IDENT_TO_STR_(UCS_ERR_UNREACHABLE); | |
IDENT_TO_STR_(UCS_ERR_INVALID_ADDR); | |
IDENT_TO_STR_(UCS_ERR_NOT_IMPLEMENTED); | |
IDENT_TO_STR_(UCS_ERR_MESSAGE_TRUNCATED); | |
IDENT_TO_STR_(UCS_ERR_NO_PROGRESS); | |
IDENT_TO_STR_(UCS_ERR_BUFFER_TOO_SMALL); | |
IDENT_TO_STR_(UCS_ERR_NO_ELEM); | |
IDENT_TO_STR_(UCS_ERR_SOME_CONNECTS_FAILED); | |
IDENT_TO_STR_(UCS_ERR_NO_DEVICE); | |
IDENT_TO_STR_(UCS_ERR_BUSY); | |
IDENT_TO_STR_(UCS_ERR_CANCELED); | |
IDENT_TO_STR_(UCS_ERR_SHMEM_SEGMENT); | |
IDENT_TO_STR_(UCS_ERR_ALREADY_EXISTS); | |
IDENT_TO_STR_(UCS_ERR_OUT_OF_RANGE); | |
IDENT_TO_STR_(UCS_ERR_TIMED_OUT); | |
IDENT_TO_STR_(UCS_ERR_EXCEEDS_LIMIT); | |
IDENT_TO_STR_(UCS_ERR_UNSUPPORTED); | |
IDENT_TO_STR_(UCS_ERR_FIRST_LINK_FAILURE); | |
IDENT_TO_STR_(UCS_ERR_LAST_LINK_FAILURE); | |
IDENT_TO_STR_(UCS_ERR_FIRST_ENDPOINT_FAILURE); | |
IDENT_TO_STR_(UCS_ERR_LAST_ENDPOINT_FAILURE); | |
IDENT_TO_STR_(UCS_ERR_ENDPOINT_TIMEOUT); | |
default: | |
return "Unknown ucs_status_t"; | |
} | |
} | |
} | |
// Evaluates `expr` exactly once and checks the resulting ucs_status_t.
// If the status is anything other than UCS_OK, a diagnostic containing the
// source location, the stringified expression and the status name (via
// enumToStr) is written to std::cerr and the process terminates with
// exit code -1.
//
// The temporary uses a deliberately unusual name: with a short name such as
// `st`, an `expr` that itself mentions a caller variable called `st` would
// silently refer to the macro's own (not yet initialised) temporary.
// The do { ... } while(0) wrapper makes the macro behave as one statement.
#define UCS_SAFE_CALL(expr) do {                                        \
    const ucs_status_t ucs_safe_call_status_ = (expr);                  \
    if (ucs_safe_call_status_ != UCS_OK) {                              \
        std::cerr << "Error: "                                          \
                  << __FILE__ << ":"                                    \
                  << __LINE__ << " "                                    \
                  << #expr << " failed with "                           \
                  << enumToStr(ucs_safe_call_status_) << std::endl;     \
        std::exit(-1);                                                  \
    }                                                                   \
} while(0)
namespace UCS { | |
class AsyncContext { | |
ucs_async_context_t *ctx_; | |
public: | |
AsyncContext(ucs_async_mode_t mode) { | |
UCS_SAFE_CALL(ucs_async_context_create(mode, &ctx_)); | |
} | |
~AsyncContext() { | |
ucs_async_context_destroy(ctx_); | |
} | |
inline const ucs_async_context_t *get() const { | |
return ctx_; | |
} | |
}; | |
} | |
namespace UCT { | |
class TransportLayer; | |
class TransportLayerResources; | |
class MemoryDomain; | |
class MemoryDomainResources; | |
class Worker { | |
uct_worker_h handle_; | |
public: | |
Worker(const UCS::AsyncContext &async, ucs_thread_mode_t tm) { | |
UCS_SAFE_CALL(uct_worker_create(const_cast<ucs_async_context*>(async.get()), | |
tm, &handle_)); | |
} | |
~Worker() { | |
uct_worker_destroy(handle_); | |
} | |
}; | |
class TransportLayerResources { | |
uct_tl_resource_desc_t *tl_rc_; | |
unsigned num_; | |
public: | |
TransportLayerResources(uct_md_h md_h) { | |
UCS_SAFE_CALL(uct_md_query_tl_resources(md_h, &tl_rc_, &num_)); | |
} | |
~TransportLayerResources() { | |
uct_release_tl_resource_list(tl_rc_); | |
} | |
uct_tl_resource_desc_t* getRaw(size_t i) { | |
assert(i < num_); | |
return &tl_rc_[i]; | |
} | |
size_t size() const { | |
return num_; | |
} | |
}; | |
class TransportLayer { | |
std::shared_ptr<TransportLayerResources> tl_rc_; | |
size_t idx_; | |
public: | |
TransportLayer(std::shared_ptr<TransportLayerResources> p, size_t idx) | |
: tl_rc_(p), idx_(idx) | |
{ | |
assert(tl_rc_->size() > idx); | |
} | |
const char *getTlName() const { return tl_rc_->getRaw(idx_)->tl_name; } | |
const char *getDevName() const { return tl_rc_->getRaw(idx_)->dev_name; } | |
}; | |
class MemoryDomainResources { | |
uct_md_resource_desc_t *md_rc_; | |
unsigned num_md_rc_; | |
std::vector<std::shared_ptr<MemoryDomain>> md_; | |
public: | |
MemoryDomainResources() | |
: md_rc_(nullptr), num_md_rc_(0), md_() | |
{ | |
UCS_SAFE_CALL(uct_query_md_resources(&md_rc_, &num_md_rc_)); | |
assert(md_rc_ != nullptr); | |
for (unsigned i = 0; i < num_md_rc_; i++) { | |
md_.push_back(std::make_shared<MemoryDomain>(this, i)); | |
} | |
} | |
~MemoryDomainResources() { | |
uct_release_md_resource_list(md_rc_); | |
} | |
std::shared_ptr<MemoryDomain> get(size_t idx) { | |
return md_[idx]; | |
} | |
uct_md_resource_desc_t* getRaw(size_t idx) { | |
return &md_rc_[idx]; | |
} | |
size_t size() const { | |
return md_.size(); | |
} | |
private: | |
MemoryDomainResources(const MemoryDomainResources &rhs) = delete; | |
MemoryDomainResources& operator=(const MemoryDomainResources &rhs) = delete; | |
}; | |
class MemoryDomain { | |
MemoryDomainResources *md_rc_; | |
size_t idx_; | |
uct_md_h md_h_; | |
uct_md_config_t *md_conf_; | |
uct_md_attr_t attr_; | |
std::shared_ptr<TransportLayerResources> tl_rc_; | |
std::vector<std::shared_ptr<TransportLayer>> tl_; | |
public: | |
MemoryDomain(MemoryDomainResources* md_rc, size_t idx) : | |
md_rc_(md_rc), idx_(idx) | |
{ | |
// Create a memory domain | |
UCS_SAFE_CALL(uct_md_config_read(md_rc_->getRaw(idx_)->md_name, | |
NULL, NULL, | |
&md_conf_)); | |
UCS_SAFE_CALL(uct_md_open(md_rc_->getRaw(idx_)->md_name, md_conf_, &md_h_)); | |
tl_rc_ = std::make_shared<TransportLayerResources>(md_h_); | |
for (size_t i = 0; i < tl_rc_->size(); i++) { | |
tl_.push_back(std::make_shared<TransportLayer>(tl_rc_, i)); | |
} | |
query(); | |
} | |
~MemoryDomain() { | |
uct_config_release(md_conf_); | |
uct_md_close(md_h_); | |
} | |
const char *name() const { | |
return md_rc_->getRaw(idx_)->md_name; | |
} | |
// Returns a vector of TransportLayer | |
std::shared_ptr<TransportLayer> getTL(size_t idx) { | |
return tl_[idx]; | |
} | |
size_t getNumTL() const { | |
return tl_.size(); | |
} | |
std::string componentName() const { | |
return std::string(attr_.component_name); | |
} | |
size_t maxAlloc() const { return attr_.cap.max_alloc; } | |
size_t maxReg() const {return attr_.cap.max_reg; } | |
enum class Flag { | |
ALLOC = UCT_MD_FLAG_ALLOC, | |
REG = UCT_MD_FLAG_REG, | |
NEED_MEMH = UCT_MD_FLAG_NEED_MEMH, | |
NEED_RKEY = UCT_MD_FLAG_NEED_RKEY, | |
ADVISE = UCT_MD_FLAG_ADVISE, | |
FIXED = UCT_MD_FLAG_FIXED, | |
RKEY_PTR = UCT_MD_FLAG_RKEY_PTR, | |
SOCKADDR = UCT_MD_FLAG_SOCKADDR, | |
}; | |
Flag flag() const { return static_cast<Flag>(attr_.cap.flags); } | |
// TODO | |
// memoryType() const { ... } | |
size_t rKeyPackedSize() const { return attr_.rkey_packed_size; } | |
// TODO | |
// cpuset_t localCpus() const { return attr_.local_cpus; } | |
private: | |
MemoryDomain(const MemoryDomain& rhs) = delete; | |
MemoryDomain& operator=(const MemoryDomain& rhs) = delete; | |
void query() { | |
UCS_SAFE_CALL(uct_md_query(md_h_, &attr_)); | |
} | |
}; | |
} // namespace UCT | |
// Prints a summary of every memory domain in `md_rs` (name, component name,
// allocation/registration limits, packed rkey size and its device/transport
// pairs) to std::cout, then prints a "Using <tl>/<dev>" line for every
// transport layer whose device name equals `dev_name` and whose transport
// name equals `tl_name`.
void dev_tl_lookup(std::shared_ptr<UCT::MemoryDomainResources> &md_rs,
                   const std::string& dev_name,
                   const std::string &tl_name) {
    // Byte count -> whole GiB, computed in floating point and then truncated.
    const auto toGiB = [](size_t bytes) {
        return static_cast<size_t>(bytes / 1024. / 1024 / 1024);
    };

    std::cout << "Resources:" << std::endl;

    const size_t domainCount = md_rs->size();
    for (size_t d = 0; d < domainCount; ++d) {
        const auto domain = md_rs->get(d);
        const size_t allocLimit = domain->maxAlloc();
        const size_t regLimit = domain->maxReg();

        std::cout << "\t" << domain->name() << std::endl;
        std::cout << "\t\tcomponent name: " << domain->componentName() << "\n";
        std::cout << "\t\tmax alloc: " << allocLimit << "\n";
        std::cout << "\t\t (" << toGiB(allocLimit) << "[GiB])\n";
        std::cout << "\t\tmax reg: " << regLimit << "\n";
        std::cout << "\t\t (" << toGiB(regLimit) << "[GiB])\n";
        std::cout << "\t\trkey packed size: " << domain->rKeyPackedSize() << "\n";
        std::cout << "\t\tDevices:" << std::endl;

        const size_t layerCount = domain->getNumTL();
        for (size_t t = 0; t < layerCount; ++t) {
            const auto layer = domain->getTL(t);
            std::cout << "\t\t\t" << layer->getDevName() << "/" << layer->getTlName()
                      << std::endl;
        }
    }

    // Look for device/tl that matches the query (dev_name, tl_name)
    for (size_t d = 0; d < md_rs->size(); ++d) {
        const auto domain = md_rs->get(d);
        for (size_t t = 0; t < domain->getNumTL(); ++t) {
            const auto layer = domain->getTL(t);
            if (dev_name != layer->getDevName()) {
                continue;
            }
            if (tl_name != layer->getTlName()) {
                continue;
            }
            std::cout << "Using "
                      << layer->getTlName() << "/"
                      << layer->getDevName() << std::endl;
        }
    }
}
int main(int argc, char **argv) { | |
assert(argc == 3); | |
UCS::AsyncContext async(UCS_ASYNC_MODE_THREAD); | |
UCT::Worker worker(async, UCS_THREAD_MODE_SINGLE); | |
auto md_rs = std::make_shared<UCT::MemoryDomainResources>(); | |
std::string dev_query(argv[1]); | |
std::string tl_query(argv[2]); | |
dev_tl_lookup(md_rs, dev_query, tl_query); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment