Skip to content

Instantly share code, notes, and snippets.

@merrymercy
Created April 20, 2022 22:59
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save merrymercy/4785e638ce37f44f9b58b551c8d870ed to your computer and use it in GitHub Desktop.
Save merrymercy/4785e638ce37f44f9b58b551c8d870ed to your computer and use it in GitHub Desktop.
// Original : https://github.com/alpa-projects/tensorflow-alpa/blob/d298f84474a04ecce02085332793e6115c0c8e0e/tensorflow/compiler/xla/service/spmd/auto_sharding_strategy.h#L854-L876
if (adj_list.size() > 1) {
// Merge src to dst.
//
// Before:
//
// src ---- adj ---- dst
// | |
// -------------------
//
// After:
//
// adj ---- dst
//
for (int adj : adj_list) {
if (adj == dst) {
continue;
}
Matrix added_edge_cost(node_lens[dst], node_lens[adj]);
for (int i = 0; i < node_lens[dst]; ++i) {
int j = reindexing[i];
Matrix edge_cost_src_adj = GetEdgeCost(src, adj);
for (int k = 0; k < node_lens[adj]; ++k) {
added_edge_cost(i, k) = edge_cost(i, j) + edge_cost_src_adj(j, k);
}
}
AddEdgeCost(dst, adj, added_edge_cost);
}
} else {
// Merge src to dst.
//
// Before:
//
// src ---- dst
//
// After:
//
// dst
//
for (int i = 0; i < node_lens[dst]; ++i) {
int j = reindexing[i];
extra_node_costs[dst][i] += edge_cost(i, j);
}
}
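// Updated: the src--dst edge cost edge_cost(i, j) is added to dst's node cost once, rather than into every dst--adj edge as in the upstream original linked above: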
// Merge src into dst by walking src's neighbours once:
//   * adj == dst: the src--dst edge becomes part of dst's node cost. For dst
//     strategy i, reindexing[i] is used as the src-side column of edge_cost.
//     This runs once per occurrence of dst in adj_list (presumably exactly
//     once; verify that adj_list holds unique neighbours).
//   * otherwise: the src--adj edge is re-keyed as a dst--adj edge. Row i is
//     row reindexing[i] of GetEdgeCost(src, adj). The result is passed to
//     AddEdgeCost(dst, adj, ...).
for (int adj : adj_list) {
  if (adj == dst) {
    for (int i = 0; i < node_lens[dst]; ++i) {
      int j = reindexing[i];
      extra_node_costs[dst][i] += edge_cost(i, j);
    }
  } else {
    // Loop-invariant with respect to i: fetch once per neighbour, not per row.
    Matrix edge_cost_src_adj = GetEdgeCost(src, adj);
    Matrix added_edge_cost(node_lens[dst], node_lens[adj]);
    for (int i = 0; i < node_lens[dst]; ++i) {
      int j = reindexing[i];
      for (int k = 0; k < node_lens[adj]; ++k) {
        added_edge_cost(i, k) = edge_cost_src_adj(j, k);
      }
    }
    AddEdgeCost(dst, adj, added_edge_cost);
  }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment