Last active
February 26, 2018 02:43
-
-
Save keisukefukuda/bfa28464b10d36ae2321b22db73f7b38 to your computer and use it in GitHub Desktop.
UCX: Get a list of memory domains / transport layers / devices
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Written by Keisuke Fukuda | |
// Copyright (C) All rights reserved. | |
// This code is licensed under the MIT License. | |
#include <cassert>
#include <cstring>
#include <iostream>
#include <memory>
#include <stdexcept>
#include <string>
#include <vector>

#include <uct/api/uct.h>
namespace UCS { | |
class AsyncContext { | |
ucs_async_context_t *ctx_; | |
public: | |
AsyncContext(ucs_async_mode_t mode) { | |
ucs_status_t st; | |
st = ucs_async_context_create(mode, &ctx_); | |
assert(st == UCS_OK); | |
} | |
~AsyncContext() { | |
ucs_async_context_destroy(ctx_); | |
} | |
inline const ucs_async_context_t *get() const { | |
return ctx_; | |
} | |
}; | |
} | |
namespace UCT { | |
class TransportLayer; | |
class TransportLayerResources; | |
class MemoryDomain; | |
class MemoryDomainResources; | |
class Worker { | |
uct_worker_h handle_; | |
bool good_; | |
ucs_status_t st_; | |
public: | |
Worker(const UCS::AsyncContext &async, ucs_thread_mode_t tm) { | |
st_ = uct_worker_create(const_cast<ucs_async_context*>(async.get()), | |
tm, &handle_); | |
good_ = (st_ == UCS_OK); | |
} | |
~Worker() { | |
uct_worker_destroy(handle_); | |
} | |
ucs_status_t stauts() const { | |
return st_; | |
} | |
bool good() const { | |
return good_; | |
} | |
}; | |
class TransportLayerResources { | |
uct_tl_resource_desc_t *tl_rc_; | |
unsigned num_; | |
public: | |
TransportLayerResources(uct_md_h md_h) { | |
ucs_status_t st = uct_md_query_tl_resources(md_h, &tl_rc_, &num_); | |
assert(st == UCS_OK); | |
} | |
~TransportLayerResources() { | |
uct_release_tl_resource_list(tl_rc_); | |
} | |
uct_tl_resource_desc_t* getRaw(size_t i) { | |
assert(i < num_); | |
return &tl_rc_[i]; | |
} | |
size_t size() const { | |
return num_; | |
} | |
}; | |
class TransportLayer { | |
std::shared_ptr<TransportLayerResources> tl_rc_; | |
size_t idx_; | |
public: | |
TransportLayer(std::shared_ptr<TransportLayerResources> p, size_t idx) | |
: tl_rc_(p), idx_(idx) | |
{ | |
assert(tl_rc_->size() > idx); | |
} | |
const char *getTlName() const { return tl_rc_->getRaw(idx_)->tl_name; } | |
const char *getDevName() const { return tl_rc_->getRaw(idx_)->dev_name; } | |
}; | |
class MemoryDomainResources { | |
uct_md_resource_desc_t *md_rc_; | |
unsigned num_md_rc_; | |
std::vector<std::shared_ptr<MemoryDomain>> md_; | |
public: | |
MemoryDomainResources() | |
: md_rc_(nullptr), num_md_rc_(0), md_() | |
{ | |
ucs_status_t st = uct_query_md_resources(&md_rc_, &num_md_rc_); | |
assert(md_rc_ != nullptr); | |
assert(st == UCS_OK); | |
for (unsigned i = 0; i < num_md_rc_; i++) { | |
md_.push_back(std::make_shared<MemoryDomain>(this, i)); | |
} | |
} | |
~MemoryDomainResources() { | |
uct_release_md_resource_list(md_rc_); | |
} | |
std::shared_ptr<MemoryDomain> get(size_t idx) { | |
return md_[idx]; | |
} | |
uct_md_resource_desc_t* getRaw(size_t idx) { | |
return &md_rc_[idx]; | |
} | |
size_t size() const { | |
return md_.size(); | |
} | |
private: | |
MemoryDomainResources(const MemoryDomainResources &rhs) = delete; | |
MemoryDomainResources& operator=(const MemoryDomainResources &rhs) = delete; | |
}; | |
class MemoryDomain { | |
MemoryDomainResources *md_rc_; | |
size_t idx_; | |
uct_md_h md_h_; | |
uct_md_config_t *md_conf_; | |
std::shared_ptr<TransportLayerResources> tl_rc_; | |
std::vector<std::shared_ptr<TransportLayer>> tl_; | |
public: | |
MemoryDomain(MemoryDomainResources* md_rc, size_t idx) : | |
md_rc_(md_rc), idx_(idx) | |
{ | |
// Create a memory domain | |
ucs_status_t st = uct_md_config_read(md_rc_->getRaw(idx_)->md_name, | |
NULL, NULL, | |
&md_conf_); | |
assert(st == UCS_OK); | |
st = uct_md_open(md_rc_->getRaw(idx_)->md_name, md_conf_, &md_h_); | |
assert(st == UCS_OK); | |
tl_rc_ = std::make_shared<TransportLayerResources>(md_h_); | |
for (size_t i = 0; i < tl_rc_->size(); i++) { | |
tl_.push_back(std::make_shared<TransportLayer>(tl_rc_, i)); | |
} | |
} | |
~MemoryDomain() { | |
uct_config_release(md_conf_); | |
uct_md_close(md_h_); | |
} | |
const char *name() const { | |
return md_rc_->getRaw(idx_)->md_name; | |
} | |
// Returns a vector of TransportLayer | |
std::shared_ptr<TransportLayer> getTL(size_t idx) { | |
return tl_[idx]; | |
} | |
size_t getNumTL() const { | |
return tl_.size(); | |
} | |
private: | |
MemoryDomain(const MemoryDomain& rhs) = delete; | |
MemoryDomain& operator=(const MemoryDomain& rhs) = delete; | |
}; | |
} // namespace UCT | |
/// Prints every memory domain in md_rs with its "device/transport" pairs, then
/// prints a "Using <transport>/<device>" line for each pair whose device name
/// equals dev_name and whose transport name equals tl_name. If no pair
/// matches, a message is written to std::cerr.
///
/// @param md_rs    memory-domain resources to list and search; if null, only
///                 an error message is printed.
/// @param dev_name device name to match exactly.
/// @param tl_name  transport-layer name to match exactly.
void dev_tl_lookup(const std::shared_ptr<UCT::MemoryDomainResources> &md_rs,
                   const std::string &dev_name,
                   const std::string &tl_name) {
  if (!md_rs) {
    std::cerr << "dev_tl_lookup: no memory domain resources" << std::endl;
    return;
  }

  // Collect matches while listing so the resources are walked only once;
  // the "Using" lines are still printed after the full listing.
  std::vector<std::shared_ptr<UCT::TransportLayer>> matches;

  std::cout << "Resources:" << '\n';
  for (size_t i = 0; i < md_rs->size(); i++) {
    const auto md = md_rs->get(i);
    std::cout << "\t" << md->name() << '\n';
    for (size_t j = 0; j < md->getNumTL(); j++) {
      auto tl = md->getTL(j);
      std::cout << "\t\t" << tl->getDevName() << "/" << tl->getTlName()
                << '\n';
      if (dev_name == tl->getDevName() && tl_name == tl->getTlName()) {
        matches.push_back(tl);
      }
    }
  }

  for (const auto &tl : matches) {
    std::cout << "Using "
              << tl->getTlName() << "/"
              << tl->getDevName() << '\n';
  }

  if (matches.empty()) {
    // std::cerr is tied to std::cout, so the listing above is flushed first.
    std::cerr << "No transport matches device \"" << dev_name
              << "\" and transport \"" << tl_name << "\"" << std::endl;
  }
}
int main(int argc, char **argv) { | |
assert(argc == 3); | |
UCS::AsyncContext async(UCS_ASYNC_MODE_THREAD); | |
UCT::Worker worker(async, UCS_THREAD_MODE_SINGLE); | |
auto md_rs = std::make_shared<UCT::MemoryDomainResources>(); | |
std::string dev_query(argv[1]); | |
std::string tl_query(argv[2]); | |
dev_tl_lookup(md_rs, dev_query, tl_query); | |
return 0; | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment