Created
April 1, 2023 22:56
-
-
Save nariaki3551/d63b99166e586b62c8a969cb6de2fd34 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#!/usr/bin/env python | |
# coding: utf-8 | |
import enum | |
import itertools | |
import collections | |
import argparse | |
import numpy as np | |
import pulp | |
def main(
    num_hosts,
    num_waves,
    num_divides,
    num_muxes,
    num_splitters,
    num_switches,
    num_fxcs,
    num_pipelines,
    all_reduce,
    all_gather,
    all_to_all,
    reduce_scatter,
    ring,
    time_limit,
):
    """Build and solve the FXC collective-communication model, then export it.

    The function allocates integer port ids for hosts, muxes, splitters and
    switches, builds a PuLP mixed-integer program over those ports, solves it
    with the HiGHS command-line solver and writes a per-step reachability
    table to a spreadsheet.

    Args:
        num_hosts: Number of hosts. Each host gets exactly one source port
            (set ``T``) and one target port (set ``R``).
        num_waves: Number of wavelengths (set ``W``).
        num_divides: Number of data pieces per host (set ``P``); the data set
            ``D`` holds one ``Data(host, id)`` per (host, piece) pair.
        num_muxes: Number of muxes; each occupies 5 consecutive port ids
            (1 in ``I``, 4 in ``O``).
        num_splitters: Number of splitters; each occupies 5 consecutive port
            ids (4 in ``I``, 1 in ``O``).
        num_switches: Number of switches; each occupies 4 consecutive port
            ids (2 in ``I``, 2 in ``O``).
        num_fxcs: Size of the index set ``N`` (must be >= 1).
        num_pipelines: Size of the index set ``M`` (must be >= 1).
        all_reduce: Select the all-reduce completion condition.
        all_gather: Select the all-gather completion condition.
        all_to_all: Select the all-to-all completion condition
            (requires ``num_hosts == num_divides``).
        reduce_scatter: Select the reduce-scatter completion condition
            (requires ``num_hosts == num_divides``).
        ring: When truthy, fix the connection variables ``e`` to a ring in
            which host k's source port feeds host k+1's target port.
        time_limit: Time limit handed to ``pulp.HiGHS_CMD(timeLimit=...)``.

    Raises:
        AssertionError: If not exactly one collective is selected, if the
            host/divide counts disagree for all_to_all or reduce_scatter, or
            if ``num_fxcs``/``num_pipelines`` is below 1.

    Side effects:
        Prints model statistics and results, writes ``test.mps`` and
        ``sample.xlsx`` to the working directory and calls
        ``sh.open("sample.xlsx")``.
    """
    # Modelling switches kept from the original script (both always enabled).
    symmetry = True
    ineq_to_eq = True

    assert (all_reduce + all_gather + all_to_all + reduce_scatter) == 1
    assert (
        not all_to_all or num_hosts == num_divides
    )  # all_to_all = 1 then num_hosts == num_divides
    assert (
        not reduce_scatter or num_hosts == num_divides
    )  # reduce_scatter = 1 then num_hosts == num_divides
    assert num_fxcs >= 1
    assert num_pipelines >= 1

    # constant values
    alpha = 2  # upper bound of every flow variable x (big-M in the linking rows)
    beta = 0.1  # weight of p/lp in the "definition of p and lp" rows
    loss_FXC = 0
    pipeline_cost = 0.01  # objective weight of each pipeline indicator u

    # preprocess constants for formulation
    Data = collections.namedtuple("Data", ["host", "id"])

    class DArray(np.ndarray):
        """2-D array whose cell [r, c] holds Data(H[r], P[c])."""

        def __new__(cls, H, P):
            shape = (len(H), len(P))
            return super().__new__(cls, shape, dtype=Data)

        def __init__(self, H, P):
            for r, h in enumerate(H):
                for c, p in enumerate(P):
                    self[r, c] = Data(h, p)

    # index sets for formulation
    H = set(range(num_hosts))
    W = set(range(num_waves))  # set of wavelengths
    P = set(range(num_divides))
    D = set(Data(h, p) for h, p in itertools.product(H, P))
    D_array = DArray(H, P)
    T = set()  # set of input ports from transceiver (source)
    R = set()  # set of exit (output) ports to transceiver (target)
    I = set()  # device ports that act as flow sources
    O = set()  # device ports that act as flow destinations
    N = set(range(num_fxcs))  # set of FXC
    M = set(range(num_pipelines))
    T_host = dict()  # T_host[h] ports connected from host h
    R_host = dict()  # R_host[h] ports connected to host h
    muxs = []
    splitters = []
    switches = []

    # Port ids are handed out consecutively: sources, targets, then devices.
    offset = 0
    # source
    for h in range(num_hosts):
        T_host[h] = set(range(offset, offset + 1))
        T |= T_host[h]
        offset += 1
    # target
    for h in range(num_hosts):
        R_host[h] = set(range(offset, offset + 1))
        R |= R_host[h]
        offset += 1
    # muxs
    for _ in range(num_muxes):
        muxs.append(list(range(offset, offset + 5)))
        I |= set(range(offset, offset + 1))
        O |= set(range(offset + 1, offset + 5))
        offset += 5
    # splitters
    for _ in range(num_splitters):
        splitters.append(list(range(offset, offset + 5)))
        I |= set(range(offset, offset + 4))
        O |= set(range(offset + 4, offset + 5))
        offset += 5
    # switches
    for _ in range(num_switches):
        switches.append(list(range(offset, offset + 4)))
        I |= set(range(offset, offset + 2))
        O |= set(range(offset + 2, offset + 4))
        offset += 4

    # All flow sources / destinations; the port sets are final from here on.
    TI = T | I
    RO = R | O

    print("transceiver --> FXC ports", T)
    print("FXC ports --> transceiver", R)
    print("transceiver --> FXC ports", T_host)
    print("FXC ports --> transceiver", R_host)
    print("mux ports --> FXC ports", set(mux[0] for mux in muxs))
    print("FXC ports --> mux ports", set(port for mux in muxs for port in mux[1:5]))
    print(
        "splitter ports --> FXC ports",
        set(port for splitter in splitters for port in splitter[0:4]),
    )
    print("FXC ports --> splitter ports", set(splitter[4] for splitter in splitters))
    print(
        "switch ports --> FXC ports",
        set(port for switch in switches for port in switch[0:2]),
    )
    print(
        "FXC ports --> switch ports",
        set(port for switch in switches for port in switch[2:4]),
    )
    print("Data", D)

    # formulation
    prob = pulp.LpProblem()
    e = pulp.LpVariable.dicts(name="e", indices=itertools.product(TI, RO), cat="Binary")
    x = pulp.LpVariable.dicts(
        name="x",
        indices=itertools.product(N, M, TI, RO, D),
        lowBound=0,
        upBound=alpha,
        cat="Continuous",
    )
    y = pulp.LpVariable.dicts(
        name="y",
        indices=itertools.product(N, M, TI, RO, W),
        lowBound=0,
        upBound=1,
        cat="Continuous",
    )
    p = pulp.LpVariable.dicts(
        name="p",
        indices=itertools.product(N, M, H, D),
        lowBound=0,
        upBound=1,
        cat="Continuous",
    )
    lp = pulp.LpVariable.dicts(
        name="lp",
        indices=itertools.product(H, D),
        lowBound=0,
        upBound=1,
        cat="Continuous",
    )
    s = pulp.LpVariable.dicts(
        name="s", indices=itertools.product(N, range(num_switches)), cat="Binary"
    )
    v = pulp.LpVariable.dicts(
        name="v", indices=itertools.product(N, M, TI, RO, D), cat="Binary"
    )
    # Wavelength variables leaving a transceiver are replaced by binary ones.
    for n, m, i, j, w in itertools.product(N, M, T, RO, W):
        y[n, m, i, j, w] = pulp.LpVariable(f"y_{n}_{m}_{i}_{j}_{w}", cat="Binary")
    y_base = pulp.LpVariable.dicts(
        name="y_base",
        indices=itertools.product(T, RO, W),
        lowBound=0,
        upBound=1,
        cat="Continuous",
    )
    u = pulp.LpVariable.dicts(name="u", indices=itertools.product(N, M), cat="Binary")
    if all_reduce or reduce_scatter:
        vv = pulp.LpVariable.dicts(
            name="_vv",
            indices=itertools.product(N, M, TI, RO, P),
            lowBound=0,
            upBound=1,
        )
    prod_cost_y = pulp.LpVariable.dicts(
        "prod_cost_y", itertools.product(N, M, TI, RO, W), lowBound=0
    )
    max_cost = pulp.LpVariable.dicts("max_cost", N, lowBound=0, upBound=len(D))

    def is_zero(value):
        """Return True when a dict entry has been fixed to the number 0."""
        return isinstance(value, (int, float)) and value == 0

    # ========================================================
    #
    # Symmetry
    #
    # ========================================================
    if symmetry:
        for h in H:
            assert len(T_host[h]) == 1
            assert len(R_host[h]) == 1
        t_host = sorted([list(T_host[h])[0] for h in H])
        r_host = sorted([list(R_host[h])[0] for h in H])
        # NOTE(review): nesting reconstructed from a whitespace-stripped source.
        # With everything inside the inner loop, each pair (t_host[ix], r) is
        # aliased to the pair starting at t_host[0] with the same cyclic offset.
        for _ in range(len(H)):
            for ix in range(1, len(H)):
                r_host = [r_host[-1]] + r_host[0:-1]
                base_i, base_j = t_host[0], r_host[0]
                i, j = t_host[ix], r_host[ix]
                diff_i = i - base_i
                e[i, j] = e[base_i, base_j]
                for n, m, w in itertools.product(N, M, W):
                    y[n, m, i, j, w] = y[n, m, base_i, base_j, w]
                    prod_cost_y[n, m, i, j, w] = prod_cost_y[n, m, base_i, base_j, w]
                for w in W:
                    y_base[i, j, w] = y_base[base_i, base_j, w]
                for n, m, hix, pix in itertools.product(N, M, H, P):
                    # The data index is shifted by the same amount as the port.
                    shifted = D_array[
                        (hix + diff_i) % len(H), (pix + diff_i) % len(P)
                    ]
                    x[n, m, i, j, shifted] = x[n, m, base_i, base_j, D_array[hix, pix]]
                    v[n, m, i, j, shifted] = v[n, m, base_i, base_j, D_array[hix, pix]]
                if all_reduce or reduce_scatter:
                    for n, m, pix in itertools.product(N, M, P):
                        vv[n, m, i, j, (pix + diff_i) % len(P)] = vv[
                            n, m, base_i, base_j, pix
                        ]

    if ring:
        t_host = sorted([list(T_host[h])[0] for h in H])
        r_host = sorted([list(R_host[h])[0] for h in H])
        for key in e:
            e[key] = 0
        for t, r in zip(t_host, r_host[1:] + [r_host[0]]):
            e[t, r] = 1

    # ========================================================
    #
    # fix variables
    #
    # ========================================================
    # host h cannot transmit data host does not have in the first round
    for h in H:
        for d in D - set(D_array[h]):
            p[0, 0, h, d] = 0
    # host h can always transmit data host initially has
    for n, m, h in itertools.product(N, M, H):
        for d in D_array[h]:
            p[n, m, h, d] = 1
    for h in H:
        for d in D_array[h]:
            lp[h, d] = 1
    # host i cannot transmit any data host i does not have in the first round
    for h in H:
        for i in T_host[h]:
            for j, d in itertools.product(RO, D - set(D_array[h])):
                x[0, 0, i, j, d] = 0
    # no flow of a host's own data into that host's target port
    for h in H:
        for n, m, i, j, d in itertools.product(N, M, T, R_host[h], D_array[h]):
            x[n, m, i, j, d] = 0
    # mux
    for mux in muxs:
        out_port = mux[0]  # out_port in I
        in_ports = mux[1:5]  # in_ports subset of O
        # the k-th mux input only accepts the k-th wavelength
        for n, m, i, (in_port, w) in itertools.product(
            N, M, TI, zip(in_ports, sorted(W))
        ):
            for ww in W - {w}:
                y[n, m, i, in_port, ww] = 0
        # a device's own output is never wired back to its own input
        for in_port in in_ports:
            e[out_port, in_port] = 0
        for n, m, in_port, d in itertools.product(N, M, in_ports, D):
            x[n, m, out_port, in_port, d] = 0
        for n, m, in_port, w in itertools.product(N, M, in_ports, W):
            y[n, m, out_port, in_port, w] = 0
    # splitter
    for splitter in splitters:
        out_ports = splitter[0:4]
        in_port = splitter[4]
        for out_port in out_ports:
            e[out_port, in_port] = 0
        for n, m, out_port, d in itertools.product(N, M, out_ports, D):
            x[n, m, out_port, in_port, d] = 0
        for n, m, out_port, w in itertools.product(N, M, out_ports, W):
            y[n, m, out_port, in_port, w] = 0
    # switch
    for switch in switches:
        out_ports = switch[0:2]
        in_ports = switch[2:4]
        for out_port, in_port in itertools.product(out_ports, in_ports):
            e[out_port, in_port] = 0
        for n, m, out_port, in_port, d in itertools.product(
            N, M, out_ports, in_ports, D
        ):
            x[n, m, out_port, in_port, d] = 0
        for n, m, out_port, in_port, w in itertools.product(
            N, M, out_ports, in_ports, W
        ):
            y[n, m, out_port, in_port, w] = 0

    # ========================================================
    #
    # Constraints
    #
    # ========================================================
    # --------------------------------------------------------
    # mux constraints
    # --------------------------------------------------------
    for mux in muxs:
        out_port = mux[0]
        in_ports = mux[1:5]
        for n, m, d in itertools.product(N, M, D):
            out_flows = pulp.lpSum(x[n, m, out_port, j, d] for j in RO)
            in_flows = pulp.lpSum(x[n, m, j, i, d] for j in TI for i in in_ports)
            prob += out_flows <= in_flows - loss_FXC
        for n, m, (in_port, w) in itertools.product(N, M, zip(in_ports, sorted(W))):
            out_waves = pulp.lpSum(y[n, m, out_port, i, w] for i in RO)
            in_waves = pulp.lpSum(y[n, m, i, in_port, w] for i in TI)
            prob += out_waves >= in_waves
        for n, m, (in_port, w) in itertools.product(N, M, zip(in_ports, sorted(W))):
            in_waves = pulp.lpSum(y[n, m, i, in_port, w] for i in TI)
            in_flows = pulp.lpSum(x[n, m, i, in_port, d] for i in TI for d in D)
            prob += alpha * in_waves >= in_flows
    # --------------------------------------------------------
    # splitter constraints
    # --------------------------------------------------------
    for splitter in splitters:
        out_ports = splitter[0:4]
        in_port = splitter[4]
        for n, m, out_port, d in itertools.product(N, M, out_ports, D):
            in_flows = pulp.lpSum(x[n, m, i, in_port, d] for i in TI)
            out_flows = pulp.lpSum(x[n, m, out_port, j, d] for j in RO)
            prob += 4 * out_flows == in_flows  # under loss_FXC = 0
        for n, m, out_port, w in itertools.product(N, M, out_ports, W):
            in_waves = pulp.lpSum(y[n, m, i, in_port, w] for i in TI)
            out_waves = pulp.lpSum(y[n, m, out_port, i, w] for i in RO)
            prob += out_waves == in_waves
    # --------------------------------------------------------
    # switch constraints
    # --------------------------------------------------------
    for switch_ix, switch in enumerate(switches):
        out_ports = switch[0:2]
        in_ports = switch[2:4]
        a, b = in_ports
        f, g = out_ports
        # s = 0 couples (a -> f, b -> g); s = 1 couples (b -> f, a -> g).
        for n, m, d in itertools.product(N, M, D):
            out_flows_f = pulp.lpSum(x[n, m, f, i, d] for i in RO)
            out_flows_g = pulp.lpSum(x[n, m, g, i, d] for i in RO)
            in_flows_a = pulp.lpSum(x[n, m, i, a, d] for i in TI)
            in_flows_b = pulp.lpSum(x[n, m, i, b, d] for i in TI)
            prob += out_flows_f <= in_flows_a + alpha * s[n, switch_ix]
            prob += out_flows_f <= in_flows_b + alpha * (1 - s[n, switch_ix])
            prob += out_flows_g <= in_flows_b + alpha * s[n, switch_ix]
            prob += out_flows_g <= in_flows_a + alpha * (1 - s[n, switch_ix])
        for n, m, w in itertools.product(N, M, W):
            out_waves_f = pulp.lpSum(y[n, m, f, i, w] for i in RO)
            out_waves_g = pulp.lpSum(y[n, m, g, i, w] for i in RO)
            in_waves_a = pulp.lpSum(y[n, m, i, a, w] for i in TI)
            in_waves_b = pulp.lpSum(y[n, m, i, b, w] for i in TI)
            prob += out_waves_f >= in_waves_a - s[n, switch_ix]
            prob += out_waves_f >= in_waves_b - (1 - s[n, switch_ix])
            prob += out_waves_g >= in_waves_b - s[n, switch_ix]
            prob += out_waves_g >= in_waves_a - (1 - s[n, switch_ix])
    # --------------------------------------------------------
    # Transceiver power
    # --------------------------------------------------------
    for n, m, h, d in itertools.product(N, M, H, D):
        for i, j in itertools.product(T_host[h], RO):
            if not is_zero(x[n, m, i, j, d]):
                prob += x[n, m, i, j, d] <= (alpha - loss_FXC) * p[n, m, h, d]
    # --------------------------------------------------------
    # FXC Connection
    # --------------------------------------------------------
    # number of connected port is less than or equal to 1
    for i in I:
        prob += pulp.lpSum(e[i, j] for j in RO) <= 1
    # signals do not flow between unconnected ports
    for n, m, i, j, d in itertools.product(N, M, TI, RO, D):
        if not is_zero(x[n, m, i, j, d]):
            prob += x[n, m, i, j, d] <= alpha * e[i, j]
    # wavelengths do not flow between unconnected ports
    for n, m, i, j, w in itertools.product(N, M, TI, RO, W):
        if not is_zero(y[n, m, i, j, w]):
            prob += y[n, m, i, j, w] <= e[i, j]
    # forbid the same wavelength from entering a port more than once
    for n, j, w in itertools.product(N, RO, W):
        num_in_wave_w = pulp.lpSum(y[n, m, i, j, w] for m in M for i in T)
        if not (num_in_wave_w.isNumericalConstant() and num_in_wave_w.constant <= 1):
            prob += num_in_wave_w <= 1
    # a connected pair must carry at least one wavelength
    for i, j in itertools.product(TI, RO):
        sum_waves = pulp.lpSum(
            y[n, m, i, j, w] for n, m, w in itertools.product(N, M, W)
        )
        if not (sum_waves.isNumericalConstant() and sum_waves.constant >= 1):
            prob += sum_waves >= e[i, j]
    # signals with same wavelength do not flow on multiple edges from the transceiver
    for n, m, i, w in itertools.product(N, M, T, W):
        prob += pulp.lpSum(y[n, m, i, j, w] for j in RO) <= 1
    for i in TI:
        prob += pulp.lpSum(e[i, j] for j in R) <= len(W)
    # --------------------------------------------------------
    # definition of p and lp
    # --------------------------------------------------------
    # p at the next pipeline index is bounded by p now plus what arrives now
    for n, m, h, d in itertools.product(N, M, H, D):
        if m < num_pipelines - 1:
            in_flows = pulp.lpSum(x[n, m, i, r, d] for i in TI for r in R_host[h])
            prob += beta * p[n, m + 1, h, d] <= beta * p[n, m, h, d] + in_flows
    # same linkage from the last pipeline index of n to the first of n + 1
    for n, m, h, d in itertools.product(N, M, H, D):
        if n < num_fxcs - 1:
            in_flows = pulp.lpSum(
                x[n, num_pipelines - 1, i, r, d] for i in TI for r in R_host[h]
            )
            prob += (
                beta * p[n + 1, 0, h, d]
                <= beta * p[n, num_pipelines - 1, h, d] + in_flows
            )
    # lp is linked to the very last (n, m) step
    for h, d in itertools.product(H, D):
        in_flows = pulp.lpSum(
            x[num_fxcs - 1, num_pipelines - 1, i, r, d] for i in TI for r in R_host[h]
        )
        prob += (
            beta * lp[h, d]
            <= beta * p[num_fxcs - 1, num_pipelines - 1, h, d] + in_flows
        )
    # --------------------------------------------------------
    # for pipeline
    # --------------------------------------------------------
    for n, m, i, j, d in itertools.product(N, M, T, RO, D):
        prob += pulp.lpSum(y[n, m, i, j, w] for w in W) >= v[n, m, i, j, d]
    # fixed wavelength between i and j
    for n, i, j, w in itertools.product(N, T, RO, W):
        if ineq_to_eq:
            prob += y_base[i, j, w] == pulp.lpSum(y[n, m, i, j, w] for m in M)
        else:
            prob += y_base[i, j, w] >= pulp.lpSum(y[n, m, i, j, w] for m in M)
    for i in T:
        if ineq_to_eq:
            prob += pulp.lpSum(y_base[i, j, w] for j in RO for w in W) == len(W)
        else:
            prob += pulp.lpSum(y_base[i, j, w] for j in RO for w in W) <= len(W)
    # pipeline indicator u[n, m] gates every use of pipeline index m >= 1
    for n, m in itertools.product(N, M):
        if m == 0:
            continue
        for i, j, d in itertools.product(TI, RO, D):
            prob += v[n, m, i, j, d] <= u[n, m]
            prob += x[n, m, i, j, d] <= u[n, m]
        prob += u[n, m - 1] >= u[n, m]
    # --------------------------------------------------------
    # In same round, port i can transmit signal with wavelength w only once
    # --------------------------------------------------------
    for n, i, w in itertools.product(N, T, W):
        prob += pulp.lpSum(y[n, m, i, j, w] for m in M for j in RO) == 1
    # --------------------------------------------------------
    # communication data size (receive)
    # --------------------------------------------------------
    # The same rows are generated for transceiver-sourced pairs and for pairs
    # that end at a transceiver, in that order.
    port_pairs = ((T, RO), (TI, R))
    for srcs, dsts in port_pairs:
        for n, m, i, j, d in itertools.product(N, M, srcs, dsts, D):
            if is_zero(x[n, m, i, j, d]):
                v[n, m, i, j, d] = 0
            else:
                prob += alpha * v[n, m, i, j, d] >= x[n, m, i, j, d]
    # prod_cost_y = max_cost[n] * y, linearised with big-M = len(D)
    for srcs, dsts in port_pairs:
        for n, m, i, j, w in itertools.product(N, M, srcs, dsts, W):
            if is_zero(y[n, m, i, j, w]):
                # keyed by the FXC index n, like every other prod_cost_y entry
                prod_cost_y[n, m, i, j, w] = 0
            else:
                prob += prod_cost_y[n, m, i, j, w] <= len(D) * y[n, m, i, j, w]
                prob += prod_cost_y[n, m, i, j, w] <= max_cost[n]
                prob += prod_cost_y[n, m, i, j, w] >= max_cost[n] - len(D) * (
                    1 - y[n, m, i, j, w]
                )
    if all_reduce or reduce_scatter:
        # vv[n, m, i, j, k] = max(v[n, m, i, j, d(h, k)] for h in H)
        for srcs, dsts in port_pairs:
            for n, m, i, j, h, k in itertools.product(N, M, srcs, dsts, H, P):
                prob += vv[n, m, i, j, k] >= v[n, m, i, j, D_array[h, k]]
        for srcs, dsts in port_pairs:
            for n, m, i, j in itertools.product(N, M, srcs, dsts):
                prob += pulp.lpSum(prod_cost_y[n, m, i, j, w] for w in W) >= pulp.lpSum(
                    vv[n, m, i, j, k] for k in P
                )
    else:  # all_gather
        for srcs, dsts in port_pairs:
            for n, m, i, j in itertools.product(N, M, srcs, dsts):
                prob += pulp.lpSum(prod_cost_y[n, m, i, j, w] for w in W) >= pulp.lpSum(
                    v[n, m, i, j, d] for d in D
                )

    # completion condition of the selected collective
    if all_reduce or all_gather:
        for h, d in itertools.product(H, D):
            prob += lp[h, d] == 1
    elif all_to_all:
        for h in H:
            for hh in H:
                prob += lp[h, D_array[hh, h]] == 1
    elif reduce_scatter:
        for h in H:
            for hh in H:
                prob += lp[h, D_array[hh, h]] == 1

    # objective: sum of max_cost, pipeline usage, and small tie-breaking terms
    if all_reduce or reduce_scatter:
        vv_penalty = 0.0001 * pulp.lpSum(
            vv[n, m, i, j, k] for n, m, i, j, k in itertools.product(N, M, T, R, P)
        )
    else:
        vv_penalty = 0
    appendix = vv_penalty - 0.0000001 * pulp.lpSum(
        x[n, m, i, j, d] for n, m, i, j, d in itertools.product(N, M, TI, RO, D)
    )
    prob += (
        pulp.lpSum(max_cost[n] for n in N)
        + pipeline_cost * pulp.lpSum(u[n, m] for n in N for m in M)
        + appendix
    )

    # Counts are by name prefix, exactly as in the original script.
    print("#x-variables", sum(var.name.startswith("x") for var in prob.variables()))
    print("#y-variables", sum(var.name.startswith("y") for var in prob.variables()))
    print("#e-variables", sum(var.name.startswith("e") for var in prob.variables()))
    print("#p-variables", sum(var.name.startswith("p") for var in prob.variables()))
    print("#lp-variables", sum(var.name.startswith("lp") for var in prob.variables()))
    print("#s-variables", sum(var.name.startswith("s") for var in prob.variables()))
    print("#v-variables", sum(var.name.startswith("v") for var in prob.variables()))
    print("#vv-variables", sum(var.name.startswith("_vv") for var in prob.variables()))
    print(
        "#cost-variables", sum(var.name.startswith("cost") for var in prob.variables())
    )
    print(
        "#max_cost-variables",
        sum(var.name.startswith("max_cost") for var in prob.variables()),
    )
    print(
        "#prod_cost_y-variables",
        sum(var.name.startswith("prod_cost_y") for var in prob.variables()),
    )
    print("#constraints", len(prob.constraints))

    # solve
    prob.writeMPS("test.mps")
    solver = pulp.HiGHS_CMD(timeLimit=time_limit)
    prob.solve(solver=solver)

    # results
    epsilon = 0.001
    for key, value in e.items():
        if pulp.value(value) > epsilon:
            print("e", key, pulp.value(value))
    for key, value in max_cost.items():
        if pulp.value(value) > 0:
            print("max_cost", key, pulp.value(value) * (4 / num_divides))
    print(
        "obj value", pulp.value(pulp.lpSum(max_cost[n] for n in N)) * (4 / num_divides)
    )

    # spreadsheet export (imports kept local, as in the original script)
    from openpyxl import Workbook
    from openpyxl.styles.borders import Border, Side
    from openpyxl.styles import PatternFill, Alignment

    wb = Workbook()
    ws = wb.active
    ws.title = "Experiment Results"
    thin_border = Border(
        left=Side(style="thin"),
        right=Side(style="thin"),
        top=Side(style="thin"),
        bottom=Side(style="thin"),
    )
    centered = Alignment(horizontal="center", vertical="center")
    head_row = 2
    head_column = 2
    ws.cell(head_row, 2).value = "0"
    ws.cell(head_row, 2).alignment = centered
    head_row += 2
    gray_fill = PatternFill(
        start_color="D0CECE",
        end_color="D0CECE",
        fill_type="solid",
    )
    # NOTE(review): nesting reconstructed from a whitespace-stripped source;
    # confirm that each host gets its own grid side by side and that each
    # (n, m) step starts a new band of rows.
    for n in N:
        for m in M:
            for h in H:
                for hix, pix in itertools.product(H, P):
                    cell = ws.cell(head_row + hix, head_column + pix)
                    if pulp.value(p[n, m, h, D_array[hix, pix]]) >= epsilon:
                        # p is (close to) positive for this host/data pair
                        cell.fill = gray_fill
                    else:
                        # otherwise record a source port delivering it now
                        for i in T:
                            received = pulp.value(
                                pulp.lpSum(
                                    x[n, m, i, r, D_array[hix, pix]]
                                    for r in R_host[h]
                                )
                            )
                            if received >= beta - epsilon:
                                cell.value = i
                                cell.alignment = Alignment(
                                    horizontal="center", vertical="center"
                                )
                for row in ws.iter_rows(
                    min_row=head_row,
                    max_row=head_row + num_hosts - 1,
                    min_col=head_column,
                    max_col=head_column + num_divides - 1,
                ):
                    for cell in row:
                        cell.border = thin_border
                head_column += num_divides + 2
            head_row += num_hosts + 2
            head_column = 2
    # change the width of every column
    for col in ws.columns:
        col_letter = col[0].column_letter
        ws.column_dimensions[col_letter].width = 5
    # change the height of every row
    for row in ws.rows:
        ws.row_dimensions[row[0].row].height = 30
    wb.save("sample.xlsx")

    import sh

    sh.open("sample.xlsx")
def argparser():
    """Build the command-line parser for this script.

    The five size options (``--num_hosts``, ``--num_waves``,
    ``--num_divides``, ``--num_fxcs``, ``--num_pipelines``) are mandatory:
    they have no usable default and a missing one would otherwise surface
    only as a ``TypeError`` deep inside ``main``. Exactly one of the four
    collective flags must be given; this is enforced at parse time.

    Returns:
        argparse.ArgumentParser: parser whose namespace attributes match the
        parameter names of ``main``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num_hosts",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--num_waves",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--num_divides",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--num_muxes",
        type=int,
        default=0,
        help="number of multiplexers (default is 0)",
    )
    parser.add_argument(
        "--num_splitters",
        type=int,
        default=0,
        help="number of splitters (default is 0)",
    )
    parser.add_argument(
        "--num_switches",
        type=int,
        default=0,
        help="number of switches (default is 0)",
    )
    parser.add_argument(
        "--num_fxcs",
        type=int,
        required=True,
    )
    parser.add_argument(
        "--num_pipelines",
        type=int,
        required=True,
    )
    # Exactly one collective may be selected; reject bad combinations here
    # instead of relying on an assert that is stripped under ``python -O``.
    collective = parser.add_mutually_exclusive_group(required=True)
    collective.add_argument(
        "--all_reduce",
        action="store_true",
    )
    collective.add_argument(
        "--all_gather",
        action="store_true",
    )
    collective.add_argument(
        "--all_to_all",
        action="store_true",
    )
    collective.add_argument(
        "--reduce_scatter",
        action="store_true",
    )
    parser.add_argument(
        "--ring",
        action="store_true",
    )
    parser.add_argument(
        "--time_limit",
        type=float,
        default=3600,
        help="solver time limit (default is 3600)",
    )
    return parser
if __name__ == "__main__":
    # Script entry point: parse the command line and forward every option to
    # main() by keyword so the mapping to its parameters is explicit.
    cli = argparser().parse_args()
    main(
        num_hosts=cli.num_hosts,
        num_waves=cli.num_waves,
        num_divides=cli.num_divides,
        num_muxes=cli.num_muxes,
        num_splitters=cli.num_splitters,
        num_switches=cli.num_switches,
        num_fxcs=cli.num_fxcs,
        num_pipelines=cli.num_pipelines,
        all_reduce=cli.all_reduce,
        all_gather=cli.all_gather,
        all_to_all=cli.all_to_all,
        reduce_scatter=cli.reduce_scatter,
        ring=cli.ring,
        time_limit=cli.time_limit,
    )
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment