Skip to content

Instantly share code, notes, and snippets.

Embed
What would you like to do?
UCX: Get a list of memory domains / transport layers / devices
// Written by Keisuke Fukuda
// Copyright (C) All rights reserved.
// This code is licensed under the MIT License.
#include <cassert>
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>

#include <uct/api/uct.h>
namespace UCS {
/// Owns a UCS async context created with ucs_async_context_create().
///
/// The context is created in the constructor and destroyed in the
/// destructor. Creation failure is caught by assert(); in NDEBUG builds the
/// handle is left null and the destructor skips the destroy call.
/// Copying is disabled because two copies would destroy the same context.
class AsyncContext {
    ucs_async_context_t *ctx_ = nullptr;
public:
    /// @param mode  async mode forwarded to ucs_async_context_create().
    explicit AsyncContext(ucs_async_mode_t mode) {
        const ucs_status_t st = ucs_async_context_create(mode, &ctx_);
        assert(st == UCS_OK);
        if (st != UCS_OK) {
            // Don't keep whatever a failed create may have written.
            ctx_ = nullptr;
        }
    }

    ~AsyncContext() {
        if (ctx_ != nullptr) {
            ucs_async_context_destroy(ctx_);
        }
    }

    AsyncContext(const AsyncContext &) = delete;
    AsyncContext &operator=(const AsyncContext &) = delete;

    /// Returns the underlying handle (null only if creation failed in an
    /// NDEBUG build).
    inline const ucs_async_context_t *get() const {
        return ctx_;
    }
};
}
namespace UCT {
class TransportLayer;
class TransportLayerResources;
class MemoryDomain;
class MemoryDomainResources;
class Worker {
uct_worker_h handle_;
bool good_;
ucs_status_t st_;
public:
Worker(const UCS::AsyncContext &async, ucs_thread_mode_t tm) {
st_ = uct_worker_create(const_cast<ucs_async_context*>(async.get()),
tm, &handle_);
good_ = (st_ == UCS_OK);
}
~Worker() {
uct_worker_destroy(handle_);
}
ucs_status_t stauts() const {
return st_;
}
bool good() const {
return good_;
}
};
class TransportLayerResources {
uct_tl_resource_desc_t *tl_rc_;
unsigned num_;
public:
TransportLayerResources(uct_md_h md_h) {
ucs_status_t st = uct_md_query_tl_resources(md_h, &tl_rc_, &num_);
assert(st == UCS_OK);
}
~TransportLayerResources() {
uct_release_tl_resource_list(tl_rc_);
}
uct_tl_resource_desc_t* getRaw(size_t i) {
assert(i < num_);
return &tl_rc_[i];
}
size_t size() const {
return num_;
}
};
class TransportLayer {
std::shared_ptr<TransportLayerResources> tl_rc_;
size_t idx_;
public:
TransportLayer(std::shared_ptr<TransportLayerResources> p, size_t idx)
: tl_rc_(p), idx_(idx)
{
assert(tl_rc_->size() > idx);
}
const char *getTlName() const { return tl_rc_->getRaw(idx_)->tl_name; }
const char *getDevName() const { return tl_rc_->getRaw(idx_)->dev_name; }
};
class MemoryDomainResources {
uct_md_resource_desc_t *md_rc_;
unsigned num_md_rc_;
std::vector<std::shared_ptr<MemoryDomain>> md_;
public:
MemoryDomainResources()
: md_rc_(nullptr), num_md_rc_(0), md_()
{
ucs_status_t st = uct_query_md_resources(&md_rc_, &num_md_rc_);
assert(md_rc_ != nullptr);
assert(st == UCS_OK);
for (unsigned i = 0; i < num_md_rc_; i++) {
md_.push_back(std::make_shared<MemoryDomain>(this, i));
}
}
~MemoryDomainResources() {
uct_release_md_resource_list(md_rc_);
}
std::shared_ptr<MemoryDomain> get(size_t idx) {
return md_[idx];
}
uct_md_resource_desc_t* getRaw(size_t idx) {
return &md_rc_[idx];
}
size_t size() const {
return md_.size();
}
private:
MemoryDomainResources(const MemoryDomainResources &rhs) = delete;
MemoryDomainResources& operator=(const MemoryDomainResources &rhs) = delete;
};
class MemoryDomain {
MemoryDomainResources *md_rc_;
size_t idx_;
uct_md_h md_h_;
uct_md_config_t *md_conf_;
std::shared_ptr<TransportLayerResources> tl_rc_;
std::vector<std::shared_ptr<TransportLayer>> tl_;
public:
MemoryDomain(MemoryDomainResources* md_rc, size_t idx) :
md_rc_(md_rc), idx_(idx)
{
// Create a memory domain
ucs_status_t st = uct_md_config_read(md_rc_->getRaw(idx_)->md_name,
NULL, NULL,
&md_conf_);
assert(st == UCS_OK);
st = uct_md_open(md_rc_->getRaw(idx_)->md_name, md_conf_, &md_h_);
assert(st == UCS_OK);
tl_rc_ = std::make_shared<TransportLayerResources>(md_h_);
for (size_t i = 0; i < tl_rc_->size(); i++) {
tl_.push_back(std::make_shared<TransportLayer>(tl_rc_, i));
}
}
~MemoryDomain() {
uct_config_release(md_conf_);
uct_md_close(md_h_);
}
const char *name() const {
return md_rc_->getRaw(idx_)->md_name;
}
// Returns a vector of TransportLayer
std::shared_ptr<TransportLayer> getTL(size_t idx) {
return tl_[idx];
}
size_t getNumTL() const {
return tl_.size();
}
private:
MemoryDomain(const MemoryDomain& rhs) = delete;
MemoryDomain& operator=(const MemoryDomain& rhs) = delete;
};
} // namespace UCT
/// Prints every memory domain with its "device/transport" pairs, then
/// prints a "Using tl/dev" line for each entry whose device and transport
/// names both match the request, or a message on stderr if none matches.
///
/// @param md_rs     enumerated memory-domain resources (not reseated; may be null).
/// @param dev_name  device name to look for.
/// @param tl_name   transport-layer name to look for.
void dev_tl_lookup(const std::shared_ptr<UCT::MemoryDomainResources> &md_rs,
                   const std::string& dev_name,
                   const std::string &tl_name) {
    if (!md_rs) {
        std::cerr << "No memory domain resources available" << '\n';
        return;
    }

    std::cout << "Resources:" << '\n';
    for (size_t i = 0; i < md_rs->size(); i++) {
        const auto md = md_rs->get(i);
        std::cout << "\t" << md->name() << '\n';
        for (size_t j = 0; j < md->getNumTL(); j++) {
            const auto tl = md->getTL(j);
            std::cout << "\t\t" << tl->getDevName() << "/" << tl->getTlName()
                      << '\n';
        }
    }

    bool found = false;
    for (size_t i = 0; i < md_rs->size(); i++) {
        const auto md = md_rs->get(i);
        for (size_t j = 0; j < md->getNumTL(); j++) {
            const auto tl = md->getTL(j);
            if (dev_name == tl->getDevName() &&
                tl_name == tl->getTlName()) {
                std::cout << "Using "
                          << tl->getTlName() << "/"
                          << tl->getDevName() << '\n';
                found = true;
            }
        }
    }
    if (!found) {
        // std::cerr is tied to std::cout, so the listing above is flushed
        // before this message appears.
        std::cerr << "Transport " << tl_name << " on device " << dev_name
                  << " not found" << '\n';
    }
}
int main(int argc, char **argv) {
assert(argc == 3);
UCS::AsyncContext async(UCS_ASYNC_MODE_THREAD);
UCT::Worker worker(async, UCS_THREAD_MODE_SINGLE);
auto md_rs = std::make_shared<UCT::MemoryDomainResources>();
std::string dev_query(argv[1]);
std::string tl_query(argv[2]);
dev_tl_lookup(md_rs, dev_query, tl_query);
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment