Skip to content

Instantly share code, notes, and snippets.

@lan496
Created March 2, 2021 13:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save lan496/4ac1eba2b658b2582487d1e0f2a4fbbf to your computer and use it in GitHub Desktop.
Save lan496/4ac1eba2b658b2582487d1e0f2a4fbbf to your computer and use it in GitHub Desktop.
Usage of Kokkos::StaticCrsGraph
#include <vector>
#include <cassert>
#include <iostream>
#include "Kokkos_Core.hpp"
#include "Kokkos_StaticCrsGraph.hpp"
// Kokkos::StaticCrsGraph<DataType, Arg1Type, Arg2Type, Arg3Type, SizeType>
//   DataType - element type of the stored column indices (int here; the
//              checks in main compare entries against plain ints)
//   Arg1Type - here the default execution space, so the graph lives where
//              device kernels can read it
//   Arg2Type, Arg3Type - left as void (library defaults); presumably these
//              select layout / memory traits like Kokkos::View arguments --
//              confirm against Kokkos_StaticCrsGraph.hpp (Ref link below)
//   SizeType - type of the row_map offsets (int here)
using Space = Kokkos::DefaultExecutionSpace;
using StaticCrsGraphType = Kokkos::StaticCrsGraph<int, Space, void, void, int>;
// Ref: https://github.com/kokkos/kokkos/blob/1fb0c284d458c75370094921d9f202c287502325/containers/src/Kokkos_StaticCrsGraph.hpp
// Ref: https://github.com/kokkos/kokkos/blob/1fb0c284d458c75370094921d9f202c287502325/containers/unit_tests/TestStaticCrsGraph.hpp
int main(int argc, char* argv[]) {
Kokkos::initialize(argc, argv);
{
const int n = 100;
std::vector<std::vector<int>> graph(n);
for (int i = 0; i < n; ++i) {
for (int j = 0; j < i; ++j) {
graph[i].emplace_back(i);
}
}
StaticCrsGraphType d_graph;
StaticCrsGraphType::HostMirror h_graph;
d_graph = Kokkos::create_staticcrsgraph<StaticCrsGraphType>("d_graph", graph);
assert(d_graph.is_allocated());
h_graph = Kokkos::create_mirror(d_graph);
assert(h_graph.numRows() == n);
for (int i = 0; i < n; ++i) {
// StaticCrsGraph.row_map[i] is the position of the first element in the i-th row
const int begin = h_graph.row_map[i];
const int end = h_graph.row_map[i + 1];
assert(end - begin == graph[i].size());
auto rowView = h_graph.rowConst(i); // just use "auto" here
assert((int)rowView.length == end - begin);
for (int j = 0; j < rowView.length; ++j) {
assert(rowView.colidx(j) == graph[i][j]);
assert(rowView(j) == graph[i][j]); // alias for colidx
}
}
int sum = 0;
using team_policy = Kokkos::TeamPolicy<>;
using member_type = team_policy::member_type;
Kokkos::parallel_reduce("sum_graph",
team_policy(n, Kokkos::AUTO),
KOKKOS_LAMBDA(const member_type& teamMember, int& update) {
const int row = teamMember.league_rank();
auto rowView = d_graph.rowConst(row);
int tmpSum = 0;
Kokkos::parallel_reduce(
Kokkos::TeamThreadRange(teamMember, rowView.length),
[=] (const int ii, int& innerTmpSum) { // no need to capture reference here
innerTmpSum += rowView(ii);
},
tmpSum
);
if (teamMember.team_rank() == 0) {
update += tmpSum;
}
},
sum
);
std::cout << "Sum (actual): " << sum << std::endl;
int expect = (n - 1) * n * (2 * n - 1) / 6;
std::cout << "Sum (expect): " << expect << std::endl;
assert(sum == expect);
} // need to deallocate views before Kokkos::finalize is called!
Kokkos::finalize();
return 0;
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment