@merrymercy · Created April 7, 2022 00:01
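What follows is the optimized XLA HLO that Alpa emits for a sharded GPT training step (train_step_shard_parallel): a long input_output_alias header marking the donated parameter buffers, a series of scalar add/max reduction computations, and the ENTRY computation with per-instruction sharding annotations. As a minimal, hypothetical sketch (the actual export script, test_export_hlo.py, is not reproduced here), a text dump in this format can be obtained from plain JAX roughly as shown below with a recent JAX release; the function, shapes, and donate_argnums are assumptions for illustration, and donating an argument is what produces the may-alias entries in the header.

# Hypothetical sketch, not the original test_export_hlo.py.
# Donated arguments show up as input_output_alias may-alias entries
# in the HloModule header, as in the dump below.
import jax
import jax.numpy as jnp

def train_step(params, batch):
    # Toy stand-in for the real sharded GPT train step.
    loss = jnp.sum((batch @ params) ** 2)
    return params - 1e-3 * loss, loss

lowered = jax.jit(train_step, donate_argnums=(0,)).lower(
    jnp.ones((1024, 2048), jnp.float16),
    jnp.ones((8, 1024), jnp.float16),
)
print(lowered.compile().as_text())  # optimized HloModule text, like this gist

Note that this plain-JAX sketch only reproduces the file format; the sharding={devices=[...]} annotations seen on each instruction below come from Alpa's auto-sharding pass.
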
HloModule train_step_shard_parallel.3684, input_output_alias={ {0}: (0, {}, may-alias), {1}: (1, {}, may-alias), {2}: (2, {}, may-alias), {3}: (3, {}, may-alias), {4}: (4, {}, may-alias), {5}: (5, {}, may-alias), {6}: (6, {}, may-alias), {7}: (7, {}, may-alias), {8}: (8, {}, may-alias), {9}: (9, {}, may-alias), {10}: (10, {}, may-alias), {11}: (11, {}, may-alias), {12}: (12, {}, may-alias), {13}: (13, {}, may-alias), {14}: (14, {}, may-alias), {15}: (15, {}, may-alias), {16}: (16, {}, may-alias), {17}: (17, {}, may-alias), {18}: (18, {}, may-alias), {19}: (19, {}, may-alias), {20}: (20, {}, may-alias), {21}: (21, {}, may-alias), {22}: (22, {}, may-alias), {23}: (23, {}, may-alias), {24}: (24, {}, may-alias), {25}: (25, {}, may-alias), {26}: (26, {}, may-alias), {27}: (27, {}, may-alias), {28}: (28, {}, may-alias), {29}: (29, {}, may-alias), {30}: (30, {}, may-alias), {31}: (31, {}, may-alias), {32}: (32, {}, may-alias), {33}: (33, {}, may-alias), {34}: (34, {}, may-alias), {35}: (35, {}, may-alias), {36}: (36, {}, may-alias), {37}: (37, {}, may-alias), {38}: (38, {}, may-alias), {39}: (39, {}, may-alias), {40}: (40, {}, may-alias), {41}: (41, {}, may-alias), {42}: (42, {}, may-alias), {43}: (43, {}, may-alias), {44}: (44, {}, may-alias), {45}: (45, {}, may-alias), {46}: (46, {}, may-alias), {47}: (47, {}, may-alias), {48}: (48, {}, may-alias), {49}: (49, {}, may-alias), {50}: (50, {}, may-alias), {51}: (51, {}, may-alias), {52}: (52, {}, may-alias), {53}: (53, {}, may-alias), {54}: (54, {}, may-alias), {55}: (55, {}, may-alias), {56}: (56, {}, may-alias), {57}: (57, {}, may-alias), {58}: (58, {}, may-alias), {59}: (59, {}, may-alias), {60}: (60, {}, may-alias), {61}: (61, {}, may-alias), {62}: (62, {}, may-alias), {63}: (63, {}, may-alias), {64}: (64, {}, may-alias), {65}: (65, {}, may-alias), {66}: (66, {}, may-alias), {67}: (67, {}, may-alias), {68}: (68, {}, may-alias), {69}: (69, {}, may-alias), {70}: (70, {}, may-alias), {71}: (71, {}, may-alias), {72}: (72, {}, may-alias), {73}: (73, {}, may-alias), {74}: (74, {}, may-alias), {75}: (75, {}, may-alias), {76}: (76, {}, may-alias), {77}: (77, {}, may-alias), {78}: (78, {}, may-alias), {79}: (79, {}, may-alias), {80}: (80, {}, may-alias), {81}: (81, {}, may-alias), {82}: (82, {}, may-alias), {83}: (83, {}, may-alias), {84}: (84, {}, may-alias), {85}: (85, {}, may-alias), {86}: (86, {}, may-alias), {87}: (87, {}, may-alias), {88}: (88, {}, may-alias), {89}: (89, {}, may-alias), {90}: (90, {}, may-alias), {91}: (91, {}, may-alias), {92}: (92, {}, may-alias), {93}: (93, {}, may-alias), {94}: (94, {}, may-alias), {95}: (95, {}, may-alias), {96}: (96, {}, may-alias), {97}: (97, {}, may-alias), {98}: (98, {}, may-alias), {99}: (99, {}, may-alias), {100}: (100, {}, may-alias), {101}: (101, {}, may-alias), {102}: (102, {}, may-alias), {103}: (103, {}, may-alias), {104}: (104, {}, may-alias), {105}: (105, {}, may-alias), {106}: (106, {}, may-alias), {107}: (107, {}, may-alias), {108}: (108, {}, may-alias), {109}: (109, {}, may-alias), {110}: (110, {}, may-alias), {111}: (111, {}, may-alias), {112}: (112, {}, may-alias), {113}: (113, {}, may-alias), {114}: (114, {}, may-alias), {115}: (115, {}, may-alias), {116}: (116, {}, may-alias), {117}: (117, {}, may-alias) }
%primitive_computation_add.166 (parameter.167: f32[], parameter.168: f32[]) -> f32[] {
%parameter.167 = f32[] parameter(0)
%parameter.168 = f32[] parameter(1)
ROOT %add.170 = f32[] add(f32[] %parameter.167, f32[] %parameter.168), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__1.181 (parameter.182: f32[], parameter.183: f32[]) -> f32[] {
%parameter.182 = f32[] parameter(0)
%parameter.183 = f32[] parameter(1)
ROOT %add.185 = f32[] add(f32[] %parameter.182, f32[] %parameter.183), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_max.343 (parameter.344: f16[], parameter.345: f16[]) -> f16[] {
%parameter.344 = f16[] parameter(0)
%parameter.345 = f16[] parameter(1)
ROOT %maximum.347 = f16[] maximum(f16[] %parameter.344, f16[] %parameter.345), metadata={op_type="max" op_name="max" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__2.356 (parameter.357: f32[], parameter.358: f32[]) -> f32[] {
%parameter.357 = f32[] parameter(0)
%parameter.358 = f32[] parameter(1)
ROOT %add.360 = f32[] add(f32[] %parameter.357, f32[] %parameter.358), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__3.395 (parameter.396: f32[], parameter.397: f32[]) -> f32[] {
%parameter.396 = f32[] parameter(0)
%parameter.397 = f32[] parameter(1)
ROOT %add.399 = f32[] add(f32[] %parameter.396, f32[] %parameter.397), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__4.409 (parameter.410: f32[], parameter.411: f32[]) -> f32[] {
%parameter.410 = f32[] parameter(0)
%parameter.411 = f32[] parameter(1)
ROOT %add.413 = f32[] add(f32[] %parameter.410, f32[] %parameter.411), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_add__5.523 (parameter.524: f32[], parameter.525: f32[]) -> f32[] {
%parameter.524 = f32[] parameter(0)
%parameter.525 = f32[] parameter(1)
ROOT %add.527 = f32[] add(f32[] %parameter.524, f32[] %parameter.525), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__6.537 (parameter.538: f32[], parameter.539: f32[]) -> f32[] {
%parameter.538 = f32[] parameter(0)
%parameter.539 = f32[] parameter(1)
ROOT %add.541 = f32[] add(f32[] %parameter.538, f32[] %parameter.539), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_max__1.687 (parameter.688: f16[], parameter.689: f16[]) -> f16[] {
%parameter.688 = f16[] parameter(0)
%parameter.689 = f16[] parameter(1)
ROOT %maximum.691 = f16[] maximum(f16[] %parameter.688, f16[] %parameter.689), metadata={op_type="max" op_name="max" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__7.700 (parameter.701: f32[], parameter.702: f32[]) -> f32[] {
%parameter.701 = f32[] parameter(0)
%parameter.702 = f32[] parameter(1)
ROOT %add.704 = f32[] add(f32[] %parameter.701, f32[] %parameter.702), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__8.739 (parameter.740: f32[], parameter.741: f32[]) -> f32[] {
%parameter.740 = f32[] parameter(0)
%parameter.741 = f32[] parameter(1)
ROOT %add.743 = f32[] add(f32[] %parameter.740, f32[] %parameter.741), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__9.753 (parameter.754: f32[], parameter.755: f32[]) -> f32[] {
%parameter.754 = f32[] parameter(0)
%parameter.755 = f32[] parameter(1)
ROOT %add.757 = f32[] add(f32[] %parameter.754, f32[] %parameter.755), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_add__10.867 (parameter.868: f32[], parameter.869: f32[]) -> f32[] {
%parameter.868 = f32[] parameter(0)
%parameter.869 = f32[] parameter(1)
ROOT %add.871 = f32[] add(f32[] %parameter.868, f32[] %parameter.869), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__11.881 (parameter.882: f32[], parameter.883: f32[]) -> f32[] {
%parameter.882 = f32[] parameter(0)
%parameter.883 = f32[] parameter(1)
ROOT %add.885 = f32[] add(f32[] %parameter.882, f32[] %parameter.883), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_max__2.941 (parameter.942: f16[], parameter.943: f16[]) -> f16[] {
%parameter.942 = f16[] parameter(0)
%parameter.943 = f16[] parameter(1)
ROOT %maximum.945 = f16[] maximum(f16[] %parameter.942, f16[] %parameter.943), metadata={op_type="max" op_name="max" source_file="test_export_hlo.py" source_line=79}
}
%primitive_computation_add__13.966 (parameter.967: f32[], parameter.968: f32[]) -> f32[] {
%parameter.967 = f32[] parameter(0)
%parameter.968 = f32[] parameter(1)
ROOT %add.970 = f32[] add(f32[] %parameter.967, f32[] %parameter.968), metadata={op_type="add" op_name="add" source_file="test_export_hlo.py" source_line=79}
}
%primitive_computation_add__16.999 (parameter.1000: f32[], parameter.1001: f32[]) -> f32[] {
%parameter.1000 = f32[] parameter(0)
%parameter.1001 = f32[] parameter(1)
ROOT %add.1003 = f32[] add(f32[] %parameter.1000, f32[] %parameter.1001), metadata={op_type="add" op_name="add" source_file="test_export_hlo.py" source_line=80}
}
%primitive_computation_add__17.1016 (parameter.1017: f16[], parameter.1018: f16[]) -> f16[] {
%parameter.1017 = f16[] parameter(0)
%parameter.1018 = f16[] parameter(1)
ROOT %add.1020 = f16[] add(f16[] %parameter.1017, f16[] %parameter.1018), metadata={op_type="add" op_name="add" source_file="test_export_hlo.py" source_line=79}
}
%primitive_computation_add__19.1037 (parameter.1038: f16[], parameter.1039: f16[]) -> f16[] {
%parameter.1038 = f16[] parameter(0)
%parameter.1039 = f16[] parameter(1)
ROOT %add.1041 = f16[] add(f16[] %parameter.1038, f16[] %parameter.1039), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/model/gpt_model.py" source_line=75}
}
%primitive_computation_max__3.1127 (parameter.1128: f16[], parameter.1129: f16[]) -> f16[] {
%parameter.1128 = f16[] parameter(0)
%parameter.1129 = f16[] parameter(1)
ROOT %maximum.1131 = f16[] maximum(f16[] %parameter.1128, f16[] %parameter.1129), metadata={op_type="max" op_name="max" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__22.1152 (parameter.1153: f32[], parameter.1154: f32[]) -> f32[] {
%parameter.1153 = f32[] parameter(0)
%parameter.1154 = f32[] parameter(1)
ROOT %add.1156 = f32[] add(f32[] %parameter.1153, f32[] %parameter.1154), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__23.1195 (parameter.1196: f32[], parameter.1197: f32[]) -> f32[] {
%parameter.1196 = f32[] parameter(0)
%parameter.1197 = f32[] parameter(1)
ROOT %add.1199 = f32[] add(f32[] %parameter.1196, f32[] %parameter.1197), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__24.1212 (parameter.1213: f32[], parameter.1214: f32[]) -> f32[] {
%parameter.1213 = f32[] parameter(0)
%parameter.1214 = f32[] parameter(1)
ROOT %add.1216 = f32[] add(f32[] %parameter.1213, f32[] %parameter.1214), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_add__25.1336 (parameter.1337: f32[], parameter.1338: f32[]) -> f32[] {
%parameter.1337 = f32[] parameter(0)
%parameter.1338 = f32[] parameter(1)
ROOT %add.1340 = f32[] add(f32[] %parameter.1337, f32[] %parameter.1338), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__26.1353 (parameter.1354: f32[], parameter.1355: f32[]) -> f32[] {
%parameter.1354 = f32[] parameter(0)
%parameter.1355 = f32[] parameter(1)
ROOT %add.1357 = f32[] add(f32[] %parameter.1354, f32[] %parameter.1355), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_add__27.1389 (parameter.1390: f16[], parameter.1391: f16[]) -> f16[] {
%parameter.1390 = f16[] parameter(0)
%parameter.1391 = f16[] parameter(1)
ROOT %add.1393 = f16[] add(f16[] %parameter.1390, f16[] %parameter.1391), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
}
%primitive_computation_add__29.1411 (parameter.1412: f16[], parameter.1413: f16[]) -> f16[] {
%parameter.1412 = f16[] parameter(0)
%parameter.1413 = f16[] parameter(1)
ROOT %add.1415 = f16[] add(f16[] %parameter.1412, f16[] %parameter.1413), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__31.1429 (parameter.1430: f16[], parameter.1431: f16[]) -> f16[] {
%parameter.1430 = f16[] parameter(0)
%parameter.1431 = f16[] parameter(1)
ROOT %add.1433 = f16[] add(f16[] %parameter.1430, f16[] %parameter.1431), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__33.1453 (parameter.1454: f32[], parameter.1455: f32[]) -> f32[] {
%parameter.1454 = f32[] parameter(0)
%parameter.1455 = f32[] parameter(1)
ROOT %add.1457 = f32[] add(f32[] %parameter.1454, f32[] %parameter.1455), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
}
%primitive_computation_add__35.1475 (parameter.1476: f16[], parameter.1477: f16[]) -> f16[] {
%parameter.1476 = f16[] parameter(0)
%parameter.1477 = f16[] parameter(1)
ROOT %add.1479 = f16[] add(f16[] %parameter.1476, f16[] %parameter.1477), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
}
%primitive_computation_add__36.1501 (parameter.1502: f16[], parameter.1503: f16[]) -> f16[] {
%parameter.1502 = f16[] parameter(0)
%parameter.1503 = f16[] parameter(1)
ROOT %add.1505 = f16[] add(f16[] %parameter.1502, f16[] %parameter.1503), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
}
%primitive_computation_add__37.1514 (parameter.1515: f16[], parameter.1516: f16[]) -> f16[] {
%parameter.1515 = f16[] parameter(0)
%parameter.1516 = f16[] parameter(1)
ROOT %add.1518 = f16[] add(f16[] %parameter.1515, f16[] %parameter.1516), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
}
%primitive_computation_add__39.1536 (parameter.1537: f16[], parameter.1538: f16[]) -> f16[] {
%parameter.1537 = f16[] parameter(0)
%parameter.1538 = f16[] parameter(1)
ROOT %add.1540 = f16[] add(f16[] %parameter.1537, f16[] %parameter.1538), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__41.1554 (parameter.1555: f16[], parameter.1556: f16[]) -> f16[] {
%parameter.1555 = f16[] parameter(0)
%parameter.1556 = f16[] parameter(1)
ROOT %add.1558 = f16[] add(f16[] %parameter.1555, f16[] %parameter.1556), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__43.1578 (parameter.1579: f32[], parameter.1580: f32[]) -> f32[] {
%parameter.1579 = f32[] parameter(0)
%parameter.1580 = f32[] parameter(1)
ROOT %add.1582 = f32[] add(f32[] %parameter.1579, f32[] %parameter.1580), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
}
%primitive_computation_add__45.1600 (parameter.1601: f16[], parameter.1602: f16[]) -> f16[] {
%parameter.1601 = f16[] parameter(0)
%parameter.1602 = f16[] parameter(1)
ROOT %add.1604 = f16[] add(f16[] %parameter.1601, f16[] %parameter.1602), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
}
%primitive_computation_add__46.1633 (parameter.1634: f16[], parameter.1635: f16[]) -> f16[] {
%parameter.1634 = f16[] parameter(0)
%parameter.1635 = f16[] parameter(1)
ROOT %add.1637 = f16[] add(f16[] %parameter.1634, f16[] %parameter.1635), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__48.1681 (parameter.1682: f16[], parameter.1683: f16[]) -> f16[] {
%parameter.1682 = f16[] parameter(0)
%parameter.1683 = f16[] parameter(1)
ROOT %add.1685 = f16[] add(f16[] %parameter.1682, f16[] %parameter.1683), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
}
%primitive_computation_max__4.1765 (parameter.1766: f16[], parameter.1767: f16[]) -> f16[] {
%parameter.1766 = f16[] parameter(0)
%parameter.1767 = f16[] parameter(1)
ROOT %maximum.1769 = f16[] maximum(f16[] %parameter.1766, f16[] %parameter.1767), metadata={op_type="max" op_name="max" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__50.1790 (parameter.1791: f32[], parameter.1792: f32[]) -> f32[] {
%parameter.1791 = f32[] parameter(0)
%parameter.1792 = f32[] parameter(1)
ROOT %add.1794 = f32[] add(f32[] %parameter.1791, f32[] %parameter.1792), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__51.1833 (parameter.1834: f32[], parameter.1835: f32[]) -> f32[] {
%parameter.1834 = f32[] parameter(0)
%parameter.1835 = f32[] parameter(1)
ROOT %add.1837 = f32[] add(f32[] %parameter.1834, f32[] %parameter.1835), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__52.1850 (parameter.1851: f32[], parameter.1852: f32[]) -> f32[] {
%parameter.1851 = f32[] parameter(0)
%parameter.1852 = f32[] parameter(1)
ROOT %add.1854 = f32[] add(f32[] %parameter.1851, f32[] %parameter.1852), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_add__53.1974 (parameter.1975: f32[], parameter.1976: f32[]) -> f32[] {
%parameter.1975 = f32[] parameter(0)
%parameter.1976 = f32[] parameter(1)
ROOT %add.1978 = f32[] add(f32[] %parameter.1975, f32[] %parameter.1976), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
}
%primitive_computation_add__54.1991 (parameter.1992: f32[], parameter.1993: f32[]) -> f32[] {
%parameter.1992 = f32[] parameter(0)
%parameter.1993 = f32[] parameter(1)
ROOT %add.1995 = f32[] add(f32[] %parameter.1992, f32[] %parameter.1993), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
}
%primitive_computation_add__55.2027 (parameter.2028: f16[], parameter.2029: f16[]) -> f16[] {
%parameter.2028 = f16[] parameter(0)
%parameter.2029 = f16[] parameter(1)
ROOT %add.2031 = f16[] add(f16[] %parameter.2028, f16[] %parameter.2029), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
}
%primitive_computation_add__57.2049 (parameter.2050: f16[], parameter.2051: f16[]) -> f16[] {
%parameter.2050 = f16[] parameter(0)
%parameter.2051 = f16[] parameter(1)
ROOT %add.2053 = f16[] add(f16[] %parameter.2050, f16[] %parameter.2051), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__59.2067 (parameter.2068: f16[], parameter.2069: f16[]) -> f16[] {
%parameter.2068 = f16[] parameter(0)
%parameter.2069 = f16[] parameter(1)
ROOT %add.2071 = f16[] add(f16[] %parameter.2068, f16[] %parameter.2069), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__61.2091 (parameter.2092: f32[], parameter.2093: f32[]) -> f32[] {
%parameter.2092 = f32[] parameter(0)
%parameter.2093 = f32[] parameter(1)
ROOT %add.2095 = f32[] add(f32[] %parameter.2092, f32[] %parameter.2093), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
}
%primitive_computation_add__63.2113 (parameter.2114: f16[], parameter.2115: f16[]) -> f16[] {
%parameter.2114 = f16[] parameter(0)
%parameter.2115 = f16[] parameter(1)
ROOT %add.2117 = f16[] add(f16[] %parameter.2114, f16[] %parameter.2115), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
}
%primitive_computation_add__64.2139 (parameter.2140: f16[], parameter.2141: f16[]) -> f16[] {
%parameter.2140 = f16[] parameter(0)
%parameter.2141 = f16[] parameter(1)
ROOT %add.2143 = f16[] add(f16[] %parameter.2140, f16[] %parameter.2141), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
}
%primitive_computation_add__65.2152 (parameter.2153: f16[], parameter.2154: f16[]) -> f16[] {
%parameter.2153 = f16[] parameter(0)
%parameter.2154 = f16[] parameter(1)
ROOT %add.2156 = f16[] add(f16[] %parameter.2153, f16[] %parameter.2154), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
}
%primitive_computation_add__67.2174 (parameter.2175: f16[], parameter.2176: f16[]) -> f16[] {
%parameter.2175 = f16[] parameter(0)
%parameter.2176 = f16[] parameter(1)
ROOT %add.2178 = f16[] add(f16[] %parameter.2175, f16[] %parameter.2176), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__69.2192 (parameter.2193: f16[], parameter.2194: f16[]) -> f16[] {
%parameter.2193 = f16[] parameter(0)
%parameter.2194 = f16[] parameter(1)
ROOT %add.2196 = f16[] add(f16[] %parameter.2193, f16[] %parameter.2194), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__71.2216 (parameter.2217: f32[], parameter.2218: f32[]) -> f32[] {
%parameter.2217 = f32[] parameter(0)
%parameter.2218 = f32[] parameter(1)
ROOT %add.2220 = f32[] add(f32[] %parameter.2217, f32[] %parameter.2218), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
}
%primitive_computation_add__73.2238 (parameter.2239: f16[], parameter.2240: f16[]) -> f16[] {
%parameter.2239 = f16[] parameter(0)
%parameter.2240 = f16[] parameter(1)
ROOT %add.2242 = f16[] add(f16[] %parameter.2239, f16[] %parameter.2240), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
}
%primitive_computation_add__74.2271 (parameter.2272: f16[], parameter.2273: f16[]) -> f16[] {
%parameter.2272 = f16[] parameter(0)
%parameter.2273 = f16[] parameter(1)
ROOT %add.2275 = f16[] add(f16[] %parameter.2272, f16[] %parameter.2273), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
}
%primitive_computation_add__76.2319 (parameter.2320: f16[], parameter.2321: f16[]) -> f16[] {
%parameter.2320 = f16[] parameter(0)
%parameter.2321 = f16[] parameter(1)
ROOT %add.2323 = f16[] add(f16[] %parameter.2320, f16[] %parameter.2321), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
}
%primitive_computation_add__77.2332 (parameter.2333: f16[], parameter.2334: f16[]) -> f16[] {
%parameter.2333 = f16[] parameter(0)
%parameter.2334 = f16[] parameter(1)
ROOT %add.2336 = f16[] add(f16[] %parameter.2333, f16[] %parameter.2334), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
}
%primitive_computation_add__79.2354 (parameter.2355: f16[], parameter.2356: f16[]) -> f16[] {
%parameter.2355 = f16[] parameter(0)
%parameter.2356 = f16[] parameter(1)
ROOT %add.2358 = f16[] add(f16[] %parameter.2355, f16[] %parameter.2356), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__81.2372 (parameter.2373: f16[], parameter.2374: f16[]) -> f16[] {
%parameter.2373 = f16[] parameter(0)
%parameter.2374 = f16[] parameter(1)
ROOT %add.2376 = f16[] add(f16[] %parameter.2373, f16[] %parameter.2374), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
}
%primitive_computation_add__83.2397 (parameter.2398: f32[], parameter.2399: f32[]) -> f32[] {
%parameter.2398 = f32[] parameter(0)
%parameter.2399 = f32[] parameter(1)
ROOT %add.2401 = f32[] add(f32[] %parameter.2398, f32[] %parameter.2399), metadata={op_type="add" op_name="add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
}
ENTRY %train_step_shard_parallel.3684 (parameter.1: s32[], parameter.2: f16[51200], parameter.3: f16[2048], parameter.4: f16[2048], parameter.5: f16[1024,2048], parameter.6: f16[51200,2048], parameter.7: f16[2048], parameter.8: f16[2048], parameter.9: f16[2048], parameter.10: f16[2048,2048], parameter.11: f16[6144], parameter.12: f16[2048,6144], parameter.13: f16[8192], parameter.14: f16[2048,8192], parameter.15: f16[2048], parameter.16: f16[2048], parameter.17: f16[2048], parameter.18: f16[8192,2048], parameter.19: f16[2048], parameter.20: f16[2048], parameter.21: f16[2048], parameter.22: f16[2048,2048], parameter.23: f16[6144], parameter.24: f16[2048,6144], parameter.25: f16[8192], parameter.26: f16[2048,8192], parameter.27: f16[2048], parameter.28: f16[2048], parameter.29: f16[2048], parameter.30: f16[8192,2048], parameter.31: s32[], parameter.32: f32[51200], parameter.33: f32[2048], parameter.34: f32[2048], parameter.35: f32[1024,2048], parameter.36: f32[51200,2048], parameter.37: f32[2048], parameter.38: f32[2048], parameter.39: f32[2048], parameter.40: f32[2048,2048], parameter.41: f32[6144], parameter.42: f32[2048,6144], parameter.43: f32[8192], parameter.44: f32[2048,8192], parameter.45: f32[2048], parameter.46: f32[2048], parameter.47: f32[2048], parameter.48: f32[8192,2048], parameter.49: f32[2048], parameter.50: f32[2048], parameter.51: f32[2048], parameter.52: f32[2048,2048], parameter.53: f32[6144], parameter.54: f32[2048,6144], parameter.55: f32[8192], parameter.56: f32[2048,8192], parameter.57: f32[2048], parameter.58: f32[2048], parameter.59: f32[2048], parameter.60: f32[8192,2048], parameter.61: f32[51200], parameter.62: f32[2048], parameter.63: f32[2048], parameter.64: f32[1024,2048], parameter.65: f32[51200,2048], parameter.66: f32[2048], parameter.67: f32[2048], parameter.68: f32[2048], parameter.69: f32[2048,2048], parameter.70: f32[6144], parameter.71: f32[2048,6144], parameter.72: f32[8192], parameter.73: f32[2048,8192], parameter.74: f32[2048], parameter.75: f32[2048], parameter.76: f32[2048], parameter.77: f32[8192,2048], parameter.78: f32[2048], parameter.79: f32[2048], parameter.80: f32[2048], parameter.81: f32[2048,2048], parameter.82: f32[6144], parameter.83: f32[2048,6144], parameter.84: f32[8192], parameter.85: f32[2048,8192], parameter.86: f32[2048], parameter.87: f32[2048], parameter.88: f32[2048], parameter.89: f32[8192,2048], parameter.90: f32[51200], parameter.91: f32[2048], parameter.92: f32[2048], parameter.93: f32[1024,2048], parameter.94: f32[51200,2048], parameter.95: f32[2048], parameter.96: f32[2048], parameter.97: f32[2048], parameter.98: f32[2048,2048], parameter.99: f32[6144], parameter.100: f32[2048,6144], parameter.101: f32[8192], parameter.102: f32[2048,8192], parameter.103: f32[2048], parameter.104: f32[2048], parameter.105: f32[2048], parameter.106: f32[8192,2048], parameter.107: f32[2048], parameter.108: f32[2048], parameter.109: f32[2048], parameter.110: f32[2048,2048], parameter.111: f32[6144], parameter.112: f32[2048,6144], parameter.113: f32[8192], parameter.114: f32[2048,8192], parameter.115: f32[2048], parameter.116: f32[2048], parameter.117: f32[2048], parameter.118: f32[8192,2048], parameter.119: s32[8,1024], parameter.120: s32[8,1024], parameter.121: s32[8,1024], parameter.122: s32[8,1024], parameter.123: s32[8,1024], parameter.124: u32[2]) -> (s32[], f16[51200], f16[2048], f16[2048], f16[1024,2048], /*index=5*/f16[51200,2048], f16[2048], f16[2048], f16[2048], f16[2048,2048], /*index=10*/f16[6144], f16[2048,6144], f16[8192], f16[2048,8192], f16[2048], /*index=15*/f16[2048], f16[2048], f16[8192,2048], f16[2048], f16[2048], /*index=20*/f16[2048], f16[2048,2048], f16[6144], f16[2048,6144], f16[8192], /*index=25*/f16[2048,8192], f16[2048], f16[2048], f16[2048], f16[8192,2048], /*index=30*/s32[], f32[51200], f32[2048], f32[2048], f32[1024,2048], /*index=35*/f32[51200,2048], f32[2048], f32[2048], f32[2048], f32[2048,2048], /*index=40*/f32[6144], f32[2048,6144], f32[8192], f32[2048,8192], f32[2048], /*index=45*/f32[2048], f32[2048], f32[8192,2048], f32[2048], f32[2048], /*index=50*/f32[2048], f32[2048,2048], f32[6144], f32[2048,6144], f32[8192], /*index=55*/f32[2048,8192], f32[2048], f32[2048], f32[2048], f32[8192,2048], /*index=60*/f32[51200], f32[2048], f32[2048], f32[1024,2048], f32[51200,2048], /*index=65*/f32[2048], f32[2048], f32[2048], f32[2048,2048], f32[6144], /*index=70*/f32[2048,6144], f32[8192], f32[2048,8192], f32[2048], f32[2048], /*index=75*/f32[2048], f32[8192,2048], f32[2048], f32[2048], f32[2048], /*index=80*/f32[2048,2048], f32[6144], f32[2048,6144], f32[8192], f32[2048,8192], /*index=85*/f32[2048], f32[2048], f32[2048], f32[8192,2048], f32[51200], /*index=90*/f32[2048], f32[2048], f32[1024,2048], f32[51200,2048], f32[2048], /*index=95*/f32[2048], f32[2048], f32[2048,2048], f32[6144], f32[2048,6144], /*index=100*/f32[8192], f32[2048,8192], f32[2048], f32[2048], f32[2048], /*index=105*/f32[8192,2048], f32[2048], f32[2048], f32[2048], f32[2048,2048], /*index=110*/f32[6144], f32[2048,6144], f32[8192], f32[2048,8192], f32[2048], /*index=115*/f32[2048], f32[2048], f32[8192,2048]) {
%parameter.123 = s32[8,1024]{1,0} parameter(122), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%parameter.1 = s32[] parameter(0), sharding={replicated}
%constant.2951 = s32[] constant(1), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/numerics.py" source_line=68}
%add.3682 = s32[] add(s32[] %parameter.1, s32[] %constant.2951), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=339}
%parameter.90 = f32[51200]{0} parameter(89), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%parameter.121 = s32[8,1024]{1,0} parameter(120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.935 = s32[8,1024,51200]{2,1,0} broadcast(s32[8,1024]{1,0} %parameter.121), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/eq" source_file="test_export_hlo.py" source_line=78}
%iota.2 = s32[8,1024,51200]{2,1,0} iota(), iota_dimension=2, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/eq" source_file="test_export_hlo.py" source_line=78}
%compare.938 = pred[8,1024,51200]{2,1,0} compare(s32[8,1024,51200]{2,1,0} %broadcast.935, s32[8,1024,51200]{2,1,0} %iota.2), direction=EQ, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/eq" source_file="test_export_hlo.py" source_line=78}
%constant.302 = s32[] constant(0), sharding={replicated}
%broadcast.916 = s32[8,1024]{1,0} broadcast(s32[] %constant.302), dimensions={}, sharding={replicated}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/gt" source_file="test_export_hlo.py" source_line=77}
%compare.917 = pred[8,1024]{1,0} compare(s32[8,1024]{1,0} %parameter.121, s32[8,1024]{1,0} %broadcast.916), direction=GT, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/gt" source_file="test_export_hlo.py" source_line=77}
%constant.299 = f32[] constant(1), sharding={replicated}
%broadcast = f32[8,1024]{1,0} broadcast(f32[] %constant.299), dimensions={}, sharding={replicated}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/jit(_where)/broadcast_in_dim[\n broadcast_dimensions=()\n shape=(8, 1024)\n]" source_file="test_export_hlo.py" source_line=77}
%constant.165 = f32[] constant(0), sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.1 = f32[8,1024]{1,0} broadcast(f32[] %constant.165), dimensions={}, sharding={replicated}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/jit(_where)/broadcast_in_dim[\n broadcast_dimensions=()\n shape=(8, 1024)\n]" source_file="test_export_hlo.py" source_line=77}
%select = f32[8,1024]{1,0} select(pred[8,1024]{1,0} %compare.917, f32[8,1024]{1,0} %broadcast, f32[8,1024]{1,0} %broadcast.1), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="select" op_name="parallelize(train_step_shard_parallel)/jit(_where)/select" source_file="test_export_hlo.py" source_line=77}
%reduce.1004 = f32[] reduce(f32[8,1024]{1,0} %select, f32[] %constant.165), dimensions={0,1}, to_apply=%primitive_computation_add__16.999, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(0, 1)]" source_file="test_export_hlo.py" source_line=80}
%divide.1007 = f32[] divide(f32[] %constant.299, f32[] %reduce.1004), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="test_export_hlo.py" source_line=80}
%broadcast.1008 = f32[8,1024]{1,0} broadcast(f32[] %divide.1007), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=()\n shape=(8, 1024)\n]" source_file="test_export_hlo.py" source_line=80}
%multiply.1009 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %select, f32[8,1024]{1,0} %broadcast.1008), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="test_export_hlo.py" source_line=80}
%negate.1010 = f32[8,1024]{1,0} negate(f32[8,1024]{1,0} %multiply.1009), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/neg" source_file="test_export_hlo.py" source_line=79}
%broadcast.1011 = f32[8,1024,51200]{2,1,0} broadcast(f32[8,1024]{1,0} %negate.1010), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=(0, 1)\n shape=(8, 1024, 51200)\n]" source_file="test_export_hlo.py" source_line=79}
%broadcast.4 = f32[8,1024,51200]{2,1,0} broadcast(f32[] %constant.165), dimensions={}, sharding={replicated}
%select.2 = f32[8,1024,51200]{2,1,0} select(pred[8,1024,51200]{2,1,0} %compare.938, f32[8,1024,51200]{2,1,0} %broadcast.1011, f32[8,1024,51200]{2,1,0} %broadcast.4), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="test_export_hlo.py" source_line=79}
%convert.1013 = f16[8,1024,51200]{2,1,0} convert(f32[8,1024,51200]{2,1,0} %select.2), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="test_export_hlo.py" source_line=79}
%negate.1014 = f16[8,1024,51200]{2,1,0} negate(f16[8,1024,51200]{2,1,0} %convert.1013), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/neg" source_file="test_export_hlo.py" source_line=79}
%constant.1015 = f16[] constant(0), sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="test_export_hlo.py" source_line=79}
%reduce.1021 = f16[8,1024]{1,0} reduce(f16[8,1024,51200]{2,1,0} %negate.1014, f16[] %constant.1015), dimensions={2}, to_apply=%primitive_computation_add__17.1016, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="test_export_hlo.py" source_line=79}
%parameter.120 = s32[8,1024]{1,0} parameter(119), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.137 = s32[8,1024,51200]{2,1,0} broadcast(s32[8,1024]{1,0} %parameter.120), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%compare.140 = pred[8,1024,51200]{2,1,0} compare(s32[8,1024,51200]{2,1,0} %broadcast.137, s32[8,1024,51200]{2,1,0} %iota.2), direction=EQ, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%convert.141 = f16[8,1024,51200]{2,1,0} convert(pred[8,1024,51200]{2,1,0} %compare.140), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%reshape = f16[8192,51200]{1,0} reshape(f16[8,1024,51200]{2,1,0} %convert.141), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%parameter.6 = f16[51200,2048]{1,0} parameter(5), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.16 = f16[8192,2048]{1,0} dot(f16[8192,51200]{1,0} %reshape, f16[51200,2048]{1,0} %parameter.6), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%parameter.122 = s32[8,1024]{1,0} parameter(121), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.154 = s32[8,1024,1024]{2,1,0} broadcast(s32[8,1024]{1,0} %parameter.122), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%iota.3 = s32[8,1024,1024]{2,1,0} iota(), iota_dimension=2, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%compare.157 = pred[8,1024,1024]{2,1,0} compare(s32[8,1024,1024]{2,1,0} %broadcast.154, s32[8,1024,1024]{2,1,0} %iota.3), direction=EQ, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%convert.158 = f16[8,1024,1024]{2,1,0} convert(pred[8,1024,1024]{2,1,0} %compare.157), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%reshape.3 = f16[8192,1024]{1,0} reshape(f16[8,1024,1024]{2,1,0} %convert.158), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%parameter.5 = f16[1024,2048]{1,0} parameter(4), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.17 = f16[8192,2048]{1,0} dot(f16[8192,1024]{1,0} %reshape.3, f16[1024,2048]{1,0} %parameter.5), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%add = f16[8192,2048]{1,0} add(f16[8192,2048]{1,0} %dot.16, f16[8192,2048]{1,0} %dot.17), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=130}
%convert.20 = f32[8192,2048]{1,0} convert(f16[8192,2048]{1,0} %add), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reshape.428 = f32[8,1024,2048]{2,1,0} reshape(f32[8192,2048]{1,0} %convert.20), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reduce.171 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %reshape.428, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add.166, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%constant.35 = f32[] constant(0.00048828125), sharding={replicated}
%broadcast.120 = f32[8,1024]{1,0} broadcast(f32[] %constant.35), dimensions={}, sharding={replicated}
%multiply.77 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.171, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.212 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %multiply.77), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.213 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %reshape.428, f32[8,1024,2048]{2,1,0} %broadcast.212), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.176 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %reshape.428, f32[8,1024,2048]{2,1,0} %reshape.428), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.186 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.176, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__1.181, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.78 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.186, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.96 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %multiply.77, f32[8,1024]{1,0} %multiply.77), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %multiply.78, f32[8,1024]{1,0} %multiply.96), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%constant.196 = f32[] constant(1e-12), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%broadcast.180 = f32[8,1024]{1,0} broadcast(f32[] %constant.196), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%add.1 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract, f32[8,1024]{1,0} %broadcast.180), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%reshape.476 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %add.1), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt.199 = f32[8,1024,1]{2,1,0} rsqrt(f32[8,1024,1]{2,1,0} %reshape.476), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%convert.204 = f16[8,1024,1]{2,1,0} convert(f32[8,1024,1]{2,1,0} %rsqrt.199), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.206 = f16[8,1024]{1,0} reshape(f16[8,1024,1]{2,1,0} %convert.204), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%broadcast.207 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %reshape.206), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%parameter.4 = f16[2048]{0} parameter(3), sharding={replicated}
%broadcast.209 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.4), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.210 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.207, f16[8,1024,2048]{2,1,0} %broadcast.209), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.214 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.210), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.215 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.213, f32[8,1024,2048]{2,1,0} %convert.214), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.216 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.215), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%parameter.3 = f16[2048]{0} parameter(2), sharding={replicated}
%broadcast.219 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.3), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%add.220 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.216, f16[8,1024,2048]{2,1,0} %broadcast.219), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%reshape.6 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %add.220), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%parameter.12 = f16[2048,6144]{1,0} parameter(11), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.18 = f16[8192,6144]{1,0} dot(f16[8192,2048]{1,0} %reshape.6, f16[2048,6144]{1,0} %parameter.12), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.8 = f16[8,1024,6144]{2,1,0} reshape(f16[8192,6144]{1,0} %dot.18), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%parameter.11 = f16[6144]{0} parameter(10), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.310 = f16[8,1024,6144]{2,1,0} broadcast(f16[6144]{0} %parameter.11), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.311 = f16[8,1024,6144]{2,1,0} add(f16[8,1024,6144]{2,1,0} %reshape.8, f16[8,1024,6144]{2,1,0} %broadcast.310), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%reshape.312 = f16[8,1024,2048,3]{3,2,1,0} reshape(f16[8,1024,6144]{2,1,0} %add.311), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 3)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=162}
%slice.367 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.312), slice={[0:8], [0:1024], [0:2048], [1:2]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/slice[\n limit_indices=(8, 1024, 2048, 2)\n start_indices=(0, 0, 0, 1)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%reshape.368 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %slice.367), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=172}
%transpose.92 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.368), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%slice.313 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.312), slice={[0:8], [0:1024], [0:2048], [0:1]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/slice[\n limit_indices=(8, 1024, 2048, 1)\n start_indices=(0, 0, 0, 0)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%constant.63 = f16[] constant(0.125), sharding={replicated}
%broadcast.140 = f16[8,1024,2048,1]{3,2,1,0} broadcast(f16[] %constant.63), dimensions={}, sharding={replicated}
%multiply.92 = f16[8,1024,2048,1]{3,2,1,0} multiply(f16[8,1024,2048,1]{3,2,1,0} %slice.313, f16[8,1024,2048,1]{3,2,1,0} %broadcast.140), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%reshape.338 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %multiply.92), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%transpose.90 = f16[8,32,1024,64]{3,1,2,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.338), dimensions={0,2,1,3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%slice.317 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.312), slice={[0:8], [0:1024], [0:2048], [2:3]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/slice[\n limit_indices=(8, 1024, 2048, 3)\n start_indices=(0, 0, 0, 2)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%reshape.318 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %slice.317), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=175}
%transpose.91 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.318), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.55 = f16[8,32,1024,1024]{3,2,1,0} dot(f16[8,32,1024,64]{3,1,2,0} %transpose.90, f16[8,32,64,1024]{2,1,3,0} %transpose.91), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/jit(jvp(_einsum))/dot_general[\n dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
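/* %reshape.312 through %dot.55 unpack a fused QKV projection and form the
   attention scores. From the dataflow: the [0:1] slice is scaled by
   0.125 = 1/sqrt(64) (head_dim = 64) and contracted over head_dim in %dot.55,
   so it is the query; the [2:3] slice is the key; the [1:2] slice
   (%transpose.92) is contracted against the softmax output in %dot.56 below,
   so it is the value. A minimal JAX sketch of the same math (qkv, q, k,
   scores are illustrative names, not from this dump):

     import jax.numpy as jnp
     # qkv: f16[8, 1024, 2048, 3], the fused projection output (%reshape.312)
     q = (qkv[..., 0] * jnp.float16(0.125)).reshape(8, 1024, 32, 64)
     k = qkv[..., 2].reshape(8, 1024, 32, 64)
     # batch over (batch, head), contract head_dim -> f16[8, 32, 1024, 1024]
     scores = jnp.einsum('bqhd,bkhd->bhqk', q, k)
*/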
%parameter.119 = s32[8,1024]{1,0} parameter(118), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%compare = pred[8,1024]{1,0} compare(s32[8,1024]{1,0} %parameter.119, s32[8,1024]{1,0} %broadcast.916), direction=GT, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/gt" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=182}
%broadcast.160 = f16[8,1024]{1,0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=()\n shape=(8, 1, 1, 1024)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=183}
%constant.342 = f16[] constant(-inf), sharding={replicated}, metadata={op_type="reduce_max" op_name="parallelize(train_step_shard_parallel)/reduce_max[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.161 = f16[8,1024]{1,0} broadcast(f16[] %constant.342), dimensions={}, sharding={replicated}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=()\n shape=(8, 1, 1, 1024)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=184}
%select.3 = f16[8,1024]{1,0} select(pred[8,1024]{1,0} %compare, f16[8,1024]{1,0} %broadcast.160, f16[8,1024]{1,0} %broadcast.161), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="select" op_name="parallelize(train_step_shard_parallel)/select" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=184}
%broadcast.340 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,1024]{1,0} %select.3), dimensions={0,3}, sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
%add.341 = f16[8,32,1024,1024]{3,2,1,0} add(f16[8,32,1024,1024]{3,2,1,0} %dot.55, f16[8,32,1024,1024]{3,2,1,0} %broadcast.340), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
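/* %parameter.119 through %add.341 build the additive attention mask: where
   the s32 mask input is greater than %broadcast.916 (defined earlier in the
   module, assumed to be zeros), %select.3 picks the finite bias
   %constant.1015 (assumed 0); elsewhere it picks -inf. The bias is then
   broadcast over the head and query dims and added to the scores. Hedged
   sketch (mask, scores illustrative):

     keep = jnp.float16(0.0)                                  # assumed %constant.1015
     bias = jnp.where(mask > 0, keep, jnp.float16(-jnp.inf))  # f16[8, 1024]
     scores = scores + bias[:, None, None, :]                 # broadcast dims {0,3}
*/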
%reduce.348 = f16[8,32,1024]{2,1,0} reduce(f16[8,32,1024,1024]{3,2,1,0} %add.341, f16[] %constant.342), dimensions={3}, to_apply=%primitive_computation_max.343, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_max" op_name="parallelize(train_step_shard_parallel)/reduce_max[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.351 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %reduce.348), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%subtract.352 = f16[8,32,1024,1024]{3,2,1,0} subtract(f16[8,32,1024,1024]{3,2,1,0} %add.341, f16[8,32,1024,1024]{3,2,1,0} %broadcast.351), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%exponential.353 = f16[8,32,1024,1024]{3,2,1,0} exponential(f16[8,32,1024,1024]{3,2,1,0} %subtract.352), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="exp" op_name="parallelize(train_step_shard_parallel)/exp" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%convert.354 = f32[8,32,1024,1024]{3,2,1,0} convert(f16[8,32,1024,1024]{3,2,1,0} %exponential.353), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%reduce.361 = f32[8,32,1024]{2,1,0} reduce(f32[8,32,1024,1024]{3,2,1,0} %convert.354, f32[] %constant.165), dimensions={3}, to_apply=%primitive_computation_add__2.356, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%convert.8 = f16[8,32,1024]{2,1,0} convert(f32[8,32,1024]{2,1,0} %reduce.361), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.365 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %convert.8), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%divide.366 = f16[8,32,1024,1024]{3,2,1,0} divide(f16[8,32,1024,1024]{3,2,1,0} %exponential.353, f16[8,32,1024,1024]{3,2,1,0} %broadcast.365), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
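/* %reduce.348 through %divide.366 are a numerically stable softmax over the
   key axis: subtract the per-row max before exp, and accumulate the
   normalizer in f32 (%convert.354, %reduce.361) while activations stay f16:

     m = scores.max(axis=-1, keepdims=True)
     e = jnp.exp(scores - m)                                # f16
     z = e.astype(jnp.float32).sum(axis=-1, keepdims=True)  # f32 accumulation
     probs = e / z.astype(jnp.float16)
*/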
%dot.56 = f16[8,32,64,1024]{3,2,1,0} dot(f16[8,32,64,1024]{2,1,3,0} %transpose.92, f16[8,32,1024,1024]{3,2,1,0} %divide.366), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/jit(jvp(_einsum))/dot_general[\n dimension_numbers=(((1,), (3,)), ((0, 2), (0, 1)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%transpose = f16[8,1024,32,64]{1,3,2,0} transpose(f16[8,32,64,1024]{3,2,1,0} %dot.56), dimensions={0,3,1,2}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/jit(jvp(_einsum))/transpose[\n permutation=(0, 3, 1, 2)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%reshape.9 = f16[8192,2048]{1,0} reshape(f16[8,1024,32,64]{1,3,2,0} %transpose), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%parameter.10 = f16[2048,2048]{1,0} parameter(9), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
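// Weight shardings enumerate devices in a different order: devices=[8,1,2]
// 0,8,1,9,... is a row-major tile assignment of shape [8,1,2], so row-shard i
// of this [2048,2048] projection lives on the device pair {i, i+8}: 8-way
// sharded along the operator-parallel mesh axis and, per
// last_tile_dim_replicate, replicated across the 2-way data-parallel axis.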
%dot.19 = f16[8192,2048]{1,0} dot(f16[8192,2048]{1,0} %reshape.9, f16[2048,2048]{1,0} %parameter.10), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.11 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.19), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%parameter.9 = f16[2048]{0} parameter(8), sharding={replicated}
%broadcast.390 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.9), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.391 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %reshape.11, f16[8,1024,2048]{2,1,0} %broadcast.390), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.392 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %add.391, f16[8,1024,2048]{2,1,0} %add.220), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=233}
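/* %dot.56 through %add.392: attention output, output projection, and the
   residual connection. The einsum result is flattened to
   [batch*seq, hidden] = [8192, 2048] so the projection runs as a plain 2-D
   matmul. Sketch (v, probs, w_o, b_o are illustrative; resid stands for
   %add.220, the block input defined earlier in the module):

     out = jnp.einsum('bhdk,bhqk->bqhd', v, probs)   # f16[8, 1024, 32, 64]
     out = out.reshape(8 * 1024, 2048) @ w_o + b_o   # w_o: f16[2048, 2048]
     x = out.reshape(8, 1024, 2048) + resid
*/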
%convert.393 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.392), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reduce.400 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %convert.393, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__3.395, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%multiply.79 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.400, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.405 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %multiply.79), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.406 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %convert.393, f32[8,1024,2048]{2,1,0} %broadcast.405), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.407 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.393, f32[8,1024,2048]{2,1,0} %convert.393), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.414 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.407, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__4.409, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.80 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.414, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.98 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %multiply.79, f32[8,1024]{1,0} %multiply.79), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract.1 = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %multiply.80, f32[8,1024]{1,0} %multiply.98), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%add.2 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract.1, f32[8,1024]{1,0} %broadcast.180), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt = f32[8,1024]{1,0} rsqrt(f32[8,1024]{1,0} %add.2), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%convert.28 = f16[8,1024]{1,0} convert(f32[8,1024]{1,0} %rsqrt), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%broadcast.426 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %convert.28), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%parameter.8 = f16[2048]{0} parameter(7), sharding={replicated}
%broadcast.428 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.8), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.429 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.426, f16[8,1024,2048]{2,1,0} %broadcast.428), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.430 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.429), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.431 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.406, f32[8,1024,2048]{2,1,0} %convert.430), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.432 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.431), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%parameter.7 = f16[2048]{0} parameter(6), sharding={replicated}
%broadcast.435 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.7), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%add.436 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.432, f16[8,1024,2048]{2,1,0} %broadcast.435), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
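/* %convert.393 through %add.436 are LayerNorm with f32 statistics on f16
   data, using Var[x] = E[x^2] - E[x]^2 (%broadcast.120 is assumed to carry
   1/2048 and %broadcast.180 the epsilon; both are defined earlier). Sketch
   (x is the f16 residual sum, gamma and beta the f16 scale/bias parameters):

     xf = x.astype(jnp.float32)
     mean = xf.mean(axis=-1, keepdims=True)
     var = (xf * xf).mean(axis=-1, keepdims=True) - mean * mean
     inv = (1.0 / jnp.sqrt(var + eps)).astype(jnp.float16)
     scaled = (inv * gamma).astype(jnp.float32)      # f16 product, widened
     y = ((xf - mean) * scaled).astype(jnp.float16) + beta
*/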
%reshape.12 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %add.436), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%parameter.14 = f16[2048,8192]{1,0} parameter(13), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.20 = f16[8192,8192]{1,0} dot(f16[8192,2048]{1,0} %reshape.12, f16[2048,8192]{1,0} %parameter.14), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.14 = f16[8,1024,8192]{2,1,0} reshape(f16[8192,8192]{1,0} %dot.20), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%parameter.13 = f16[8192]{0} parameter(12), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.440 = f16[8,1024,8192]{2,1,0} broadcast(f16[8192]{0} %parameter.13), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.441 = f16[8,1024,8192]{2,1,0} add(f16[8,1024,8192]{2,1,0} %reshape.14, f16[8,1024,8192]{2,1,0} %broadcast.440), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
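// %convert.507 onward applies GELU to the first FFN dense (2048 -> 8192).
// XLA has inlined erf as a clamped rational polynomial: %clamp.449 limits
// x/sqrt(2) to [-4, 4], %multiply.450 is its square t, %add.502 accumulates
// the numerator polynomial in t, %add.472 the denominator, and %divide.504
// approximates erf(x/sqrt(2)).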
%convert.507 = f32[8,1024,8192]{2,1,0} convert(f16[8,1024,8192]{2,1,0} %add.441), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.446 = f32[] constant(-4), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.447 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.446), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.39 = f32[] constant(0.707106769), sharding={replicated}
%broadcast.25 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.39), dimensions={}, sharding={replicated}
%multiply.4 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %convert.507, f32[8,1024,8192]{2,1,0} %broadcast.25), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.445 = f32[] constant(4), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.448 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.445), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%clamp.449 = f32[8,1024,8192]{2,1,0} clamp(f32[8,1024,8192]{2,1,0} %broadcast.447, f32[8,1024,8192]{2,1,0} %multiply.4, f32[8,1024,8192]{2,1,0} %broadcast.448), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.475 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.165), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.450 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %clamp.449, f32[8,1024,8192]{2,1,0} %clamp.449), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.476 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %broadcast.475, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.474 = f32[] constant(-2.72614237e-10), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.477 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.474), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.478 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.476, f32[8,1024,8192]{2,1,0} %broadcast.477), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.480 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.478, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.479 = f32[] constant(2.77068146e-08), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.481 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.479), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.482 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.480, f32[8,1024,8192]{2,1,0} %broadcast.481), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.484 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.482, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.483 = f32[] constant(-2.10102394e-06), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.485 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.483), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.486 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.484, f32[8,1024,8192]{2,1,0} %broadcast.485), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.488 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.486, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.487 = f32[] constant(-5.69250624e-05), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.489 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.487), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.490 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.488, f32[8,1024,8192]{2,1,0} %broadcast.489), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.492 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.490, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.491 = f32[] constant(-0.000734990637), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.493 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.491), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.494 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.492, f32[8,1024,8192]{2,1,0} %broadcast.493), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.496 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.494, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.495 = f32[] constant(-0.0029546), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.497 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.495), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.498 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.496, f32[8,1024,8192]{2,1,0} %broadcast.497), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.500 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.498, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.499 = f32[] constant(-0.0160960332), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.501 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.499), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.502 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.500, f32[8,1024,8192]{2,1,0} %broadcast.501), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.503 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %clamp.449, f32[8,1024,8192]{2,1,0} %add.502), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.452 = f32[] constant(-1.45660715e-05), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.455 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.452), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.456 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.476, f32[8,1024,8192]{2,1,0} %broadcast.455), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.458 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.456, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.457 = f32[] constant(-0.000213374049), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.459 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.457), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.460 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.458, f32[8,1024,8192]{2,1,0} %broadcast.459), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.462 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.460, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.461 = f32[] constant(-0.00168282702), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.463 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.461), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.464 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.462, f32[8,1024,8192]{2,1,0} %broadcast.463), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.466 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.464, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.465 = f32[] constant(-0.00737332925), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.467 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.465), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.468 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.466, f32[8,1024,8192]{2,1,0} %broadcast.467), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.470 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.468, f32[8,1024,8192]{2,1,0} %multiply.450), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.469 = f32[] constant(-0.0142647391), sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.471 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.469), dimensions={}, sharding={replicated}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.472 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.470, f32[8,1024,8192]{2,1,0} %broadcast.471), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.504 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %multiply.503, f32[8,1024,8192]{2,1,0} %add.472), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.505 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.299), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.506 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %divide.504, f32[8,1024,8192]{2,1,0} %broadcast.505), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.508 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %convert.507, f32[8,1024,8192]{2,1,0} %add.506), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.40 = f32[] constant(0.5), sharding={replicated}
%broadcast.26 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.40), dimensions={}, sharding={replicated}
%multiply.5 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %multiply.508, f32[8,1024,8192]{2,1,0} %broadcast.26), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%convert.514 = f16[8,1024,8192]{2,1,0} convert(f32[8,1024,8192]{2,1,0} %multiply.5), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
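/* Putting %multiply.4 through %convert.514 together: this is exact GELU,
   0.5 * x * (1 + erf(x / sqrt(2))), evaluated in f32 and cast back to f16
   (0.707106769 = 1/sqrt(2); %constant.299, defined earlier, is assumed to
   be 1). Equivalent JAX sketch:

     from jax.scipy.special import erf
     h = x.astype(jnp.float32)
     g = 0.5 * h * (1.0 + erf(h * 0.7071067811865476))
     y = g.astype(jnp.float16)   # cf. jax.nn.gelu(x, approximate=False)
*/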
%reshape.15 = f16[8192,8192]{1,0} reshape(f16[8,1024,8192]{2,1,0} %convert.514), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%parameter.18 = f16[8192,2048]{1,0} parameter(17), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.21 = f16[8192,2048]{1,0} dot(f16[8192,8192]{1,0} %reshape.15, f16[8192,2048]{1,0} %parameter.18), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.17 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.21), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%parameter.17 = f16[2048]{0} parameter(16), sharding={replicated}
%broadcast.518 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.17), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.519 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %reshape.17, f16[8,1024,2048]{2,1,0} %broadcast.518), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.520 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %add.519, f16[8,1024,2048]{2,1,0} %add.436), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=310}
%convert.521 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.520), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reduce.528 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %convert.521, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__5.523, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%multiply.81 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.528, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.533 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %multiply.81), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.534 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %convert.521, f32[8,1024,2048]{2,1,0} %broadcast.533), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.535 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.521, f32[8,1024,2048]{2,1,0} %convert.521), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.542 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.535, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__6.537, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.82 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.542, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.99 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %multiply.81, f32[8,1024]{1,0} %multiply.81), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract.2 = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %multiply.82, f32[8,1024]{1,0} %multiply.99), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%add.3 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract.2, f32[8,1024]{1,0} %broadcast.180), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt.1 = f32[8,1024]{1,0} rsqrt(f32[8,1024]{1,0} %add.3), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%convert.29 = f16[8,1024]{1,0} convert(f32[8,1024]{1,0} %rsqrt.1), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%broadcast.554 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %convert.29), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%parameter.16 = f16[2048]{0} parameter(15), sharding={replicated}
%broadcast.556 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.16), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.557 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.554, f16[8,1024,2048]{2,1,0} %broadcast.556), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.558 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.557), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.559 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.534, f32[8,1024,2048]{2,1,0} %convert.558), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.560 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.559), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%parameter.15 = f16[2048]{0} parameter(14), sharding={replicated}
%broadcast.563 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.15), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%add.564 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.560, f16[8,1024,2048]{2,1,0} %broadcast.563), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%reshape.18 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %add.564), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%parameter.24 = f16[2048,6144]{1,0} parameter(23), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.22 = f16[8192,6144]{1,0} dot(f16[8192,2048]{1,0} %reshape.18, f16[2048,6144]{1,0} %parameter.24), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.20 = f16[8,1024,6144]{2,1,0} reshape(f16[8192,6144]{1,0} %dot.22), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%parameter.23 = f16[6144]{0} parameter(22), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.654 = f16[8,1024,6144]{2,1,0} broadcast(f16[6144]{0} %parameter.23), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.655 = f16[8,1024,6144]{2,1,0} add(f16[8,1024,6144]{2,1,0} %reshape.20, f16[8,1024,6144]{2,1,0} %broadcast.654), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%reshape.656 = f16[8,1024,2048,3]{3,2,1,0} reshape(f16[8,1024,6144]{2,1,0} %add.655), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 3)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=162}
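// %reshape.18 through %reshape.656 begin the second transformer layer: the
// same fused QKV projection (2048 -> 6144 = 3 * 2048) and attention pattern
// as %reshape.312 above, with fresh parameters; the mask bias %broadcast.340
// and the query scale %broadcast.140 are reused from the first layer.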
%slice.711 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.656), slice={[0:8], [0:1024], [0:2048], [1:2]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/slice[\n limit_indices=(8, 1024, 2048, 2)\n start_indices=(0, 0, 0, 1)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%reshape.712 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %slice.711), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=172}
%transpose.96 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.712), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%slice.657 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.656), slice={[0:8], [0:1024], [0:2048], [0:1]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/slice[\n limit_indices=(8, 1024, 2048, 1)\n start_indices=(0, 0, 0, 0)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%multiply.93 = f16[8,1024,2048,1]{3,2,1,0} multiply(f16[8,1024,2048,1]{3,2,1,0} %slice.657, f16[8,1024,2048,1]{3,2,1,0} %broadcast.140), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%reshape.342 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %multiply.93), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%transpose.94 = f16[8,32,1024,64]{3,1,2,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.342), dimensions={0,2,1,3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%slice.661 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.656), slice={[0:8], [0:1024], [0:2048], [2:3]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/slice[\n limit_indices=(8, 1024, 2048, 3)\n start_indices=(0, 0, 0, 2)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%reshape.662 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %slice.661), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=175}
%transpose.95 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.662), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.57 = f16[8,32,1024,1024]{3,2,1,0} dot(f16[8,32,1024,64]{3,1,2,0} %transpose.94, f16[8,32,64,1024]{2,1,3,0} %transpose.95), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/jit(jvp(_einsum))/dot_general[\n dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
%add.685 = f16[8,32,1024,1024]{3,2,1,0} add(f16[8,32,1024,1024]{3,2,1,0} %dot.57, f16[8,32,1024,1024]{3,2,1,0} %broadcast.340), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
%reduce.692 = f16[8,32,1024]{2,1,0} reduce(f16[8,32,1024,1024]{3,2,1,0} %add.685, f16[] %constant.342), dimensions={3}, to_apply=%primitive_computation_max__1.687, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_max" op_name="parallelize(train_step_shard_parallel)/reduce_max[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.695 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %reduce.692), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%subtract.696 = f16[8,32,1024,1024]{3,2,1,0} subtract(f16[8,32,1024,1024]{3,2,1,0} %add.685, f16[8,32,1024,1024]{3,2,1,0} %broadcast.695), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%exponential.697 = f16[8,32,1024,1024]{3,2,1,0} exponential(f16[8,32,1024,1024]{3,2,1,0} %subtract.696), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="exp" op_name="parallelize(train_step_shard_parallel)/exp" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%convert.698 = f32[8,32,1024,1024]{3,2,1,0} convert(f16[8,32,1024,1024]{3,2,1,0} %exponential.697), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%reduce.705 = f32[8,32,1024]{2,1,0} reduce(f32[8,32,1024,1024]{3,2,1,0} %convert.698, f32[] %constant.165), dimensions={3}, to_apply=%primitive_computation_add__7.700, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%convert.9 = f16[8,32,1024]{2,1,0} convert(f32[8,32,1024]{2,1,0} %reduce.705), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.709 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %convert.9), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%divide.710 = f16[8,32,1024,1024]{3,2,1,0} divide(f16[8,32,1024,1024]{3,2,1,0} %exponential.697, f16[8,32,1024,1024]{3,2,1,0} %broadcast.709), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
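/* %dot.58 applies the attention probabilities to the (transposed) value tensor
   %transpose.96, contracting the key-position dimension, and %transpose.1 moves
   the result back to [batch, seq, heads, head_dim] layout for the output
   projection (bert_model.py:207). */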
%dot.58 = f16[8,32,64,1024]{3,2,1,0} dot(f16[8,32,64,1024]{2,1,3,0} %transpose.96, f16[8,32,1024,1024]{3,2,1,0} %divide.710), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/jit(jvp(_einsum))/dot_general[\n dimension_numbers=(((1,), (3,)), ((0, 2), (0, 1)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%transpose.1 = f16[8,1024,32,64]{1,3,2,0} transpose(f16[8,32,64,1024]{3,2,1,0} %dot.58), dimensions={0,3,1,2}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/jit(jvp(_einsum))/transpose[\n permutation=(0, 3, 1, 2)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
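/* Attention output projection: flatten to [8192, 2048], dense matmul against
   %parameter.22 (the 2048x2048 output kernel), add bias %parameter.21, then the
   residual connection with %add.564 (defined earlier; presumably the block's
   input, per bert_model.py:233). */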
%reshape.21 = f16[8192,2048]{1,0} reshape(f16[8,1024,32,64]{1,3,2,0} %transpose.1), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%parameter.22 = f16[2048,2048]{1,0} parameter(21), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.23 = f16[8192,2048]{1,0} dot(f16[8192,2048]{1,0} %reshape.21, f16[2048,2048]{1,0} %parameter.22), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.23 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.23), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%parameter.21 = f16[2048]{0} parameter(20), sharding={replicated}
%broadcast.734 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.21), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.735 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %reshape.23, f16[8,1024,2048]{2,1,0} %broadcast.734), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.736 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %add.735, f16[8,1024,2048]{2,1,0} %add.564), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=233}
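/* Layer normalization over the hidden dimension (monkey_patch.py:141-155),
   computed in f32. Sketch of what %reduce.744..%add.780 implement, inferred
   directly from the ops (eps is %broadcast.180, defined earlier in the module;
   %broadcast.120 is presumably 1/2048):
     x    = x16.astype(f32)
     mean = x.sum(-1) / 2048                    # %multiply.83
     var  = (x * x).sum(-1) / 2048 - mean**2    # E[x^2] - E[x]^2
     y    = ((x - mean) * (rsqrt(var + eps).astype(f16) * scale)).astype(f16) + shift
   with scale = %parameter.20 and shift = %parameter.19. */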
%convert.737 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.736), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reduce.744 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %convert.737, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__8.739, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%multiply.83 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.744, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.749 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %multiply.83), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.750 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %convert.737, f32[8,1024,2048]{2,1,0} %broadcast.749), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.751 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.737, f32[8,1024,2048]{2,1,0} %convert.737), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.758 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.751, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__9.753, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.84 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.758, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.100 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %multiply.83, f32[8,1024]{1,0} %multiply.83), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract.3 = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %multiply.84, f32[8,1024]{1,0} %multiply.100), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%add.4 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract.3, f32[8,1024]{1,0} %broadcast.180), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt.2 = f32[8,1024]{1,0} rsqrt(f32[8,1024]{1,0} %add.4), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%convert.30 = f16[8,1024]{1,0} convert(f32[8,1024]{1,0} %rsqrt.2), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%broadcast.770 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %convert.30), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%parameter.20 = f16[2048]{0} parameter(19), sharding={replicated}
%broadcast.772 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.20), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.773 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.770, f16[8,1024,2048]{2,1,0} %broadcast.772), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.774 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.773), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.775 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.750, f32[8,1024,2048]{2,1,0} %convert.774), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.776 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.775), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%parameter.19 = f16[2048]{0} parameter(18), sharding={replicated}
%broadcast.779 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.19), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%add.780 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.776, f16[8,1024,2048]{2,1,0} %broadcast.779), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
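/* Feed-forward up-projection: [8,1024,2048] -> [8,1024,8192] via %parameter.26
   (2048x8192 kernel) plus bias %parameter.25; the result is sharded [2,1,8], so
   the 8192 intermediate dimension is split across 8 devices. */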
%reshape.24 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %add.780), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%parameter.26 = f16[2048,8192]{1,0} parameter(25), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.24 = f16[8192,8192]{1,0} dot(f16[8192,2048]{1,0} %reshape.24, f16[2048,8192]{1,0} %parameter.26), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.26 = f16[8,1024,8192]{2,1,0} reshape(f16[8192,8192]{1,0} %dot.24), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%parameter.25 = f16[8192]{0} parameter(24), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.784 = f16[8,1024,8192]{2,1,0} broadcast(f16[8192]{0} %parameter.25), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.785 = f16[8,1024,8192]{2,1,0} add(f16[8,1024,8192]{2,1,0} %reshape.26, f16[8,1024,8192]{2,1,0} %broadcast.784), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
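/* GELU activation (bert_model.py:285): gelu(x) = 0.5 * x * (1 + erf(x / sqrt(2))).
   %multiply.10 divides by sqrt(2) (via %broadcast.25, presumably 1/sqrt(2)), and
   erf is expanded by XLA as a clamped rational approximation: %multiply.794 is
   x^2, the long multiply/add chains below evaluate the numerator and denominator
   polynomials in Horner form, %divide.848 divides them, %broadcast.505 is
   presumably the "+1", and %broadcast.26 presumably the final 0.5 scaling. */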
%convert.851 = f32[8,1024,8192]{2,1,0} convert(f16[8,1024,8192]{2,1,0} %add.785), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.10 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %convert.851, f32[8,1024,8192]{2,1,0} %broadcast.25), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%clamp.793 = f32[8,1024,8192]{2,1,0} clamp(f32[8,1024,8192]{2,1,0} %broadcast.447, f32[8,1024,8192]{2,1,0} %multiply.10, f32[8,1024,8192]{2,1,0} %broadcast.448), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.794 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %clamp.793, f32[8,1024,8192]{2,1,0} %clamp.793), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.820 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %broadcast.475, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.822 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.820, f32[8,1024,8192]{2,1,0} %broadcast.477), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.824 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.822, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.826 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.824, f32[8,1024,8192]{2,1,0} %broadcast.481), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.828 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.826, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.830 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.828, f32[8,1024,8192]{2,1,0} %broadcast.485), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.832 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.830, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.834 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.832, f32[8,1024,8192]{2,1,0} %broadcast.489), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.836 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.834, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.838 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.836, f32[8,1024,8192]{2,1,0} %broadcast.493), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.840 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.838, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.842 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.840, f32[8,1024,8192]{2,1,0} %broadcast.497), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.844 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.842, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.846 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.844, f32[8,1024,8192]{2,1,0} %broadcast.501), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.847 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %clamp.793, f32[8,1024,8192]{2,1,0} %add.846), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.800 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.820, f32[8,1024,8192]{2,1,0} %broadcast.455), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.802 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.800, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.804 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.802, f32[8,1024,8192]{2,1,0} %broadcast.459), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.806 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.804, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.808 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.806, f32[8,1024,8192]{2,1,0} %broadcast.463), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.810 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.808, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.812 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.810, f32[8,1024,8192]{2,1,0} %broadcast.467), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.814 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.812, f32[8,1024,8192]{2,1,0} %multiply.794), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.816 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.814, f32[8,1024,8192]{2,1,0} %broadcast.471), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.848 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %multiply.847, f32[8,1024,8192]{2,1,0} %add.816), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.850 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %divide.848, f32[8,1024,8192]{2,1,0} %broadcast.505), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.852 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %convert.851, f32[8,1024,8192]{2,1,0} %add.850), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.11 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %multiply.852, f32[8,1024,8192]{2,1,0} %broadcast.26), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%convert.858 = f16[8,1024,8192]{2,1,0} convert(f32[8,1024,8192]{2,1,0} %multiply.11), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
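/* Feed-forward down-projection back to the hidden size: [8192,8192] against
   %parameter.30 (8192x2048), bias %parameter.29, then the residual connection
   with %add.780 from the post-attention layer norm (bert_model.py:310). */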
%reshape.27 = f16[8192,8192]{1,0} reshape(f16[8,1024,8192]{2,1,0} %convert.858), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%parameter.30 = f16[8192,2048]{1,0} parameter(29), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.25 = f16[8192,2048]{1,0} dot(f16[8192,8192]{1,0} %reshape.27, f16[8192,2048]{1,0} %parameter.30), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.29 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.25), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%parameter.29 = f16[2048]{0} parameter(28), sharding={replicated}
%broadcast.862 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.29), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.863 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %reshape.29, f16[8,1024,2048]{2,1,0} %broadcast.862), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.864 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %add.863, f16[8,1024,2048]{2,1,0} %add.780), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=310}
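/* Second layer normalization (monkey_patch.py:141-155), identical in structure
   to the one above: f32 mean and variance via E[x^2] - E[x]^2, rsqrt(var + eps),
   scale %parameter.28 and shift %parameter.27. */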
%convert.865 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.864), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reduce.872 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %convert.865, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__10.867, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%multiply.85 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.872, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.877 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %multiply.85), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.878 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %convert.865, f32[8,1024,2048]{2,1,0} %broadcast.877), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.879 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.865, f32[8,1024,2048]{2,1,0} %convert.865), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.886 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.879, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__11.881, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.86 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %reduce.886, f32[8,1024]{1,0} %broadcast.120), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.101 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %multiply.85, f32[8,1024]{1,0} %multiply.85), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract.4 = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %multiply.86, f32[8,1024]{1,0} %multiply.101), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%add.5 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract.4, f32[8,1024]{1,0} %broadcast.180), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt.3 = f32[8,1024]{1,0} rsqrt(f32[8,1024]{1,0} %add.5), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%convert.31 = f16[8,1024]{1,0} convert(f32[8,1024]{1,0} %rsqrt.3), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%broadcast.898 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %convert.31), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%parameter.28 = f16[2048]{0} parameter(27), sharding={replicated}
%broadcast.900 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.28), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.901 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.898, f16[8,1024,2048]{2,1,0} %broadcast.900), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.902 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.901), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.903 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.878, f32[8,1024,2048]{2,1,0} %convert.902), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.904 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.903), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%parameter.27 = f16[2048]{0} parameter(26), sharding={replicated}
%broadcast.907 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %parameter.27), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%add.908 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.904, f16[8,1024,2048]{2,1,0} %broadcast.907), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
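/* LM head (gpt_model.py:70-75): the normalized hidden states are multiplied by
   %parameter.6 (f16[51200,2048]) with rhs_contracting_dims={1}, i.e. the word
   embedding table used transposed -- consistent with tied input/output
   embeddings -- then the vocabulary bias %parameter.2 is added. */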
%reshape.30 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %add.908), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.26 = f16[8192,51200]{1,0} dot(f16[8192,2048]{1,0} %reshape.30, f16[51200,2048]{1,0} %parameter.6), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.32 = f16[8,1024,51200]{2,1,0} reshape(f16[8192,51200]{1,0} %dot.26), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/gpt_model.py" source_line=70}
%parameter.2 = f16[51200]{0} parameter(1), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.913 = f16[8,1024,51200]{2,1,0} broadcast(f16[51200]{0} %parameter.2), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/gpt_model.py" source_line=75}
%add.914 = f16[8,1024,51200]{2,1,0} add(f16[8,1024,51200]{2,1,0} %reshape.32, f16[8,1024,51200]{2,1,0} %broadcast.913), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/gpt_model.py" source_line=75}
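/* Numerically stable softmax over the 51200-entry vocabulary (the loss site,
   test_export_hlo.py:79), same pattern as the attention softmax above: subtract
   the row max, exp in f16, normalizer accumulated in f32. */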
%reduce.946 = f16[8,1024]{1,0} reduce(f16[8,1024,51200]{2,1,0} %add.914, f16[] %constant.342), dimensions={2}, to_apply=%primitive_computation_max__2.941, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_max" op_name="parallelize(train_step_shard_parallel)/reduce_max[axes=(2,)]" source_file="test_export_hlo.py" source_line=79}
%broadcast.961 = f16[8,1024,51200]{2,1,0} broadcast(f16[8,1024]{1,0} %reduce.946), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="test_export_hlo.py" source_line=79}
%subtract.962 = f16[8,1024,51200]{2,1,0} subtract(f16[8,1024,51200]{2,1,0} %add.914, f16[8,1024,51200]{2,1,0} %broadcast.961), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="test_export_hlo.py" source_line=79}
%exponential.963 = f16[8,1024,51200]{2,1,0} exponential(f16[8,1024,51200]{2,1,0} %subtract.962), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="exp" op_name="parallelize(train_step_shard_parallel)/exp" source_file="test_export_hlo.py" source_line=79}
%convert.964 = f32[8,1024,51200]{2,1,0} convert(f16[8,1024,51200]{2,1,0} %exponential.963), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="test_export_hlo.py" source_line=79}
%reduce.971 = f32[8,1024]{1,0} reduce(f32[8,1024,51200]{2,1,0} %convert.964, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__13.966, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="test_export_hlo.py" source_line=79}
%convert.10 = f16[8,1024]{1,0} convert(f32[8,1024]{1,0} %reduce.971), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="test_export_hlo.py" source_line=79}
%divide.43 = f16[8,1024]{1,0} divide(f16[8,1024]{1,0} %reduce.1021, f16[8,1024]{1,0} %convert.10), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="test_export_hlo.py" source_line=79}
%broadcast.53 = f16[8,1024,51200]{2,1,0} broadcast(f16[8,1024]{1,0} %divide.43), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=(0, 1)\n shape=(8, 1024, 51200)\n]" source_file="test_export_hlo.py" source_line=79}
%multiply.1034 = f16[8,1024,51200]{2,1,0} multiply(f16[8,1024,51200]{2,1,0} %broadcast.53, f16[8,1024,51200]{2,1,0} %exponential.963), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="test_export_hlo.py" source_line=79}
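/* Backward of the softmax cross-entropy: %divide.43 rescales by the softmax
   normalizer (%reduce.1021 is defined earlier in the module), %multiply.1034
   forms softmax * scale, and %add.1035 below accumulates it with the other
   gradient term %convert.1013 (add_any). %reduce.1042 then sums over
   (batch, seq); its metadata points at the bias add in gpt_model.py:75, so the
   result is the gradient of the vocabulary bias %parameter.2. */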
%add.1035 = f16[8,1024,51200]{2,1,0} add(f16[8,1024,51200]{2,1,0} %convert.1013, f16[8,1024,51200]{2,1,0} %multiply.1034), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/add_any" source_file="test_export_hlo.py" source_line=83}
%reduce.1042 = f16[51200]{0} reduce(f16[8,1024,51200]{2,1,0} %add.1035, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__19.1037, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(0, 1)]" source_file="/home/ubuntu/efs/alpa/alpa/model/gpt_model.py" source_line=75}
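/* Adam first-moment update for the bias gradient (optax transform.py:81):
     m_new = (1 - b1) * g + b1 * m
   %constant.2456 = 0.099976 is 1 - 0.9 rounded to f16; %parameter.32 is the
   stored first moment; the result is accumulated in f32 (%add.2463). */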
%constant.2456 = f16[] constant(0.099976), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.109 = f16[51200]{0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.55 = f16[51200]{0} multiply(f16[51200]{0} %reduce.1042, f16[51200]{0} %broadcast.109), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2462 = f32[51200]{0} convert(f16[51200]{0} %multiply.55), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.32 = f32[51200]{0} parameter(31), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%constant.2459 = f32[] constant(0.9), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2460 = f32[51200]{0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2461 = f32[51200]{0} multiply(f32[51200]{0} %parameter.32, f32[51200]{0} %broadcast.2460), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2463 = f32[51200]{0} add(f32[51200]{0} %convert.2462, f32[51200]{0} %multiply.2461), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
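/* Overflow-safe step counter and first-moment bias correction (optax
   numerics.py:68, transform.py:86): the count %parameter.31 is incremented (by
   %constant.2951, presumably 1) only while it is below INT32_MAX, then
   1 - 0.9**count is computed (%subtract.2968, with %constant.299 presumably 1.0)
   and broadcast for the division in transform.py:87. */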
%parameter.31 = s32[] parameter(30), sharding={replicated}
%constant.2949 = s32[] constant(2147483647), sharding={replicated}, metadata={op_type="lt" op_name="parallelize(train_step_shard_parallel)/lt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/numerics.py" source_line=68}
%compare.2950 = pred[] compare(s32[] %parameter.31, s32[] %constant.2949), direction=LT, sharding={replicated}, metadata={op_type="lt" op_name="parallelize(train_step_shard_parallel)/lt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/numerics.py" source_line=68}
%add.2952 = s32[] add(s32[] %parameter.31, s32[] %constant.2951), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/numerics.py" source_line=68}
%select.1 = s32[] select(pred[] %compare.2950, s32[] %add.2952, s32[] %constant.2949), sharding={replicated}, metadata={op_type="select" op_name="parallelize(train_step_shard_parallel)/jit(_where)/select" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/numerics.py" source_line=68}
%convert.2964 = f32[] convert(s32[] %select.1), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=True\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=86}
%power.2966 = f32[] power(f32[] %constant.2459, f32[] %convert.2964), sharding={replicated}, metadata={op_type="pow" op_name="parallelize(train_step_shard_parallel)/pow" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=86}
%subtract.2968 = f32[] subtract(f32[] %constant.299, f32[] %power.2966), sharding={replicated}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=86}
%broadcast.2970 = f32[51200]{0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
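/* Adam second-moment update (transform.py:81):
     v_new = (1 - b2) * g**2 + b2 * v
   %constant.2689 = 0.0010004 is 1 - 0.999 rounded to f16; %parameter.61 is the
   stored second moment. */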
%multiply.66 = f16[51200]{0} multiply(f16[51200]{0} %reduce.1042, f16[51200]{0} %reduce.1042), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%constant.2689 = f16[] constant(0.0010004), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2690 = f16[51200]{0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2691 = f16[51200]{0} multiply(f16[51200]{0} %multiply.66, f16[51200]{0} %broadcast.2690), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2695 = f32[51200]{0} convert(f16[51200]{0} %multiply.2691), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.61 = f32[51200]{0} parameter(60), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%constant.2692 = f32[] constant(0.999), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2693 = f32[51200]{0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2694 = f32[51200]{0} multiply(f32[51200]{0} %parameter.61, f32[51200]{0} %broadcast.2693), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2696 = f32[51200]{0} add(f32[51200]{0} %convert.2695, f32[51200]{0} %multiply.2694), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
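/* Second-moment bias correction and denominator (transform.py:86-87, 302):
   v_hat = v_new / (1 - 0.999**count), then sqrt(v_hat) + 1e-8. */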
%power.3058 = f32[] power(f32[] %constant.2692, f32[] %convert.2964), sharding={replicated}, metadata={op_type="pow" op_name="parallelize(train_step_shard_parallel)/pow" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=86}
%subtract.3060 = f32[] subtract(f32[] %constant.299, f32[] %power.3058), sharding={replicated}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=86}
%broadcast.3062 = f32[51200]{0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3063 = f32[51200]{0} divide(f32[51200]{0} %add.2696, f32[51200]{0} %broadcast.3062), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3151 = f32[51200]{0} sqrt(f32[51200]{0} %divide.3063), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%constant.3152 = f32[] constant(1e-08), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3153 = f32[51200]{0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3154 = f32[51200]{0} add(f32[51200]{0} %sqrt.3151, f32[51200]{0} %broadcast.3153), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
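/* Final Adam step, with the two bias corrections fused into one division:
     update = m_new / ((1 - b1**t) * (sqrt(v_hat) + eps))
   (%multiply.14 builds the combined denominator; this is algebraically the
   standard m_hat / (sqrt(v_hat) + eps)). The update is scaled by the learning
   rate -0.01 (%constant.3420) and added to the f32 master copy %parameter.90
   (optax update.py:43). %convert.3536 casts the updated parameter back to f16;
   %multiply.3567 weights the old f16 bias %parameter.2 by %constant.1015 -- the
   same constant used as the additive identity of the reduce above -- so the add
   appears to be a constant-folded blend that keeps just the fresh cast
   (model_util.py:327-334). */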
%multiply.14 = f32[51200]{0} multiply(f32[51200]{0} %broadcast.2970, f32[51200]{0} %add.3154), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3155 = f32[51200]{0} divide(f32[51200]{0} %add.2463, f32[51200]{0} %multiply.14), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%constant.3420 = f32[] constant(-0.01), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%broadcast.3421 = f32[51200]{0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3422 = f32[51200]{0} multiply(f32[51200]{0} %divide.3155, f32[51200]{0} %broadcast.3421), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3507 = f32[51200]{0} add(f32[51200]{0} %parameter.90, f32[51200]{0} %multiply.3422), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3536 = f16[51200]{0} convert(f32[51200]{0} %add.3507), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3566 = f16[51200]{0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3567 = f16[51200]{0} multiply(f16[51200]{0} %parameter.2, f16[51200]{0} %broadcast.3566), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3568 = f16[51200]{0} add(f16[51200]{0} %convert.3536, f16[51200]{0} %multiply.3567), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
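/* Remaining operands gathered for the rematerialization call below: assorted
   replicated scalars (2048 = hidden size; sqrt(2); 8 in f16, presumably
   sqrt(head_dim) from the attention score scaling at attention.py:87), the two
   [8,1,1,1024] attention-mask broadcasts, the RNG key %parameter.124, and two
   pred[] constants (true/false) tagged with the remat_call metadata. */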
%parameter.91 = f32[2048]{0} parameter(90), sharding={replicated}
%constant.297 = f32[] constant(2048), sharding={replicated}
%constant.177 = f32[] constant(2), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%constant.300 = f32[] constant(1.41421354), sharding={replicated}
%constant.52 = f16[] constant(8), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%broadcast.15 = f16[8,1,1,1024]{3,2,1,0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=()\n shape=(8, 1, 1, 1024)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=183}
%broadcast.17 = f16[8,1,1,1024]{3,2,1,0} broadcast(f16[] %constant.342), dimensions={}, sharding={replicated}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=()\n shape=(8, 1, 1, 1024)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=184}
%parameter.124 = u32[2]{0} parameter(123), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%constant.1055 = pred[] constant(true), sharding={replicated}, metadata={op_type="remat_call" op_name="parallelize(train_step_shard_parallel)/remat_call[\n concrete=True\n differentiated=True\n name=jvp(core_fn)\n policy=None\n prevent_cse=True\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%constant.1056 = pred[] constant(false), sharding={replicated}, metadata={op_type="remat_call" op_name="parallelize(train_step_shard_parallel)/remat_call[\n concrete=True\n differentiated=True\n name=jvp(core_fn)\n policy=None\n prevent_cse=True\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
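/* Backward of the LM-head matmul: the vocabulary-space gradient %add.1035 is
   multiplied by the embedding table %parameter.6, contracting the 51200
   dimension, giving the gradient w.r.t. the final hidden states,
   f16[8,1024,2048]. */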
%reshape.36 = f16[8192,51200]{1,0} reshape(f16[8,1024,51200]{2,1,0} %add.1035), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.28 = f16[8192,2048]{1,0} dot(f16[8192,51200]{1,0} %reshape.36, f16[51200,2048]{1,0} %parameter.6), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.38 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.28), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/dot_general[\n dimension_numbers=(((2,), (1,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/gpt_model.py" source_line=70}
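/* EDITOR'S NOTE (annotation): %reshape.36 .. %reshape.38 flatten a f16[8,1024,51200] operand to
   [8192,51200] and contract it with the [51200,2048] embedding-shaped %parameter.6 (gpt_model.py:70),
   mapping the vocabulary dimension back to the 2048-wide hidden dimension. %tuple.1057 and
   %custom-call.1058 below pack every live value of the layer into a single tuple and pass it through an
   "identity" custom call whose output_to_operand_aliasing makes each output alias its operand; together
   with metadata op_type="remat_begin" and the remat_call metadata (prevent_cse=True), this acts as a
   rematerialization barrier that keeps the recomputed layer body from being CSE'd away. */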
%tuple.1057 = (f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) tuple(f32[] %constant.297, f32[] %constant.177, f32[] %constant.299, f32[] %constant.300, f32[] %constant.297, /*index=5*/f16[] %constant.52, s32[] %constant.302, f16[8,1,1,1024]{3,2,1,0} %broadcast.15, f16[8,1,1,1024]{3,2,1,0} %broadcast.17, f32[] %constant.196, /*index=10*/f32[] %constant.297, f32[] %constant.196, f32[] %constant.297, f16[2048]{0} %parameter.19, f16[2048]{0} %parameter.20, /*index=15*/f16[2048]{0} %parameter.21, f16[2048,2048]{1,0} %parameter.22, f16[6144]{0} %parameter.23, f16[2048,6144]{1,0} %parameter.24, f16[8192]{0} %parameter.25, /*index=20*/f16[2048,8192]{1,0} %parameter.26, f16[2048]{0} %parameter.27, f16[2048]{0} %parameter.28, f16[2048]{0} %parameter.29, f16[8192,2048]{1,0} %parameter.30, /*index=25*/u32[2]{0} %parameter.124, f16[8,1024,2048]{2,1,0} %add.564, s32[8,1024]{1,0} %parameter.119, pred[] %constant.1055, pred[] %constant.1056, /*index=30*/f16[8,1024,2048]{2,1,0} %reshape.38), sharding={{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=5*/{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=10*/{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=15*/{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=20*/{devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=25*/{devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {replicated}, {replicated}, /*index=30*/{devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}}, metadata={op_type="remat_call" op_name="parallelize(train_step_shard_parallel)/remat_call[\n concrete=True\n differentiated=True\n name=jvp(core_fn)\n policy=None\n prevent_cse=True\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%custom-call.1058 = (f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) custom-call((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %tuple.1057), custom_call_target="identity", output_to_operand_aliasing={{0}: (0, {0}), {1}: (0, {1}), {2}: (0, {2}), {3}: (0, {3}), {4}: (0, {4}), {5}: (0, {5}), {6}: (0, {6}), {7}: (0, {7}), {8}: (0, {8}), {9}: (0, {9}), {10}: (0, {10}), {11}: (0, {11}), {12}: (0, {12}), {13}: (0, {13}), {14}: (0, {14}), {15}: (0, {15}), {16}: (0, {16}), {17}: (0, {17}), {18}: (0, {18}), {19}: (0, {19}), {20}: (0, {20}), {21}: (0, {21}), {22}: (0, {22}), {23}: (0, {23}), {24}: (0, {24}), {25}: (0, {25}), {26}: (0, {26}), {27}: (0, {27}), {28}: (0, {28}), {29}: (0, {29}), {30}: (0, {30})}, sharding={{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=5*/{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=10*/{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=15*/{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=20*/{devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=25*/{devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {replicated}, {replicated}, /*index=30*/{devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}}, metadata={op_type="remat_begin"}
%get-tuple-element.1085 = f16[8,1024,2048]{2,1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=26, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
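/* EDITOR'S NOTE (annotation): %reshape.39 .. %reshape.1096 below are the fused QKV projection of the
   attention block: the [8,1024,2048] hidden states (tuple index 26) are flattened, multiplied by the
   [2048,6144] combined weight (index 18), biased (index 17), and reshaped to [8,1024,2048,3] so that
   the last axis indexes the three attention operands (bert_model.py:162). */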
%reshape.39 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %get-tuple-element.1085), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%get-tuple-element.1077 = f16[2048,6144]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=18, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.29 = f16[8192,6144]{1,0} dot(f16[8192,2048]{1,0} %reshape.39, f16[2048,6144]{1,0} %get-tuple-element.1077), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.41 = f16[8,1024,6144]{2,1,0} reshape(f16[8192,6144]{1,0} %dot.29), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%get-tuple-element.1076 = f16[6144]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=17, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.1094 = f16[8,1024,6144]{2,1,0} broadcast(f16[6144]{0} %get-tuple-element.1076), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1095 = f16[8,1024,6144]{2,1,0} add(f16[8,1024,6144]{2,1,0} %reshape.41, f16[8,1024,6144]{2,1,0} %broadcast.1094), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%reshape.1096 = f16[8,1024,2048,3]{3,2,1,0} reshape(f16[8,1024,6144]{2,1,0} %add.1095), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 3)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=162}
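/* EDITOR'S NOTE (annotation): the three slices below carve the query/key/value tensors out of the
   packed [8,1024,2048,3] result and reshape each to [8,1024,32,64], i.e. 32 heads of size 64
   (bert_model.py:165-175). The slice at [0:1] is the query: it is divided by the f16 constant 8
   (tuple index 5), matching the 1/sqrt(head_dim) = 1/sqrt(64) scaling at flax attention.py:87. */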
%slice.1167 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.1096), slice={[0:8], [0:1024], [0:2048], [1:2]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/slice[\n limit_indices=(8, 1024, 2048, 2)\n start_indices=(0, 0, 0, 1)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%reshape.1168 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %slice.1167), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=172}
%transpose.100 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.1168), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%slice.1097 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.1096), slice={[0:8], [0:1024], [0:2048], [0:1]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/slice[\n limit_indices=(8, 1024, 2048, 1)\n start_indices=(0, 0, 0, 0)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%get-tuple-element.1064 = f16[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=5, sharding={replicated}
%broadcast.95 = f16[8,1024,2048,1]{3,2,1,0} broadcast(f16[] %get-tuple-element.1064), dimensions={}, sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%divide.31 = f16[8,1024,2048,1]{3,2,1,0} divide(f16[8,1024,2048,1]{3,2,1,0} %slice.1097, f16[8,1024,2048,1]{3,2,1,0} %broadcast.95), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%reshape.346 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %divide.31), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%transpose.98 = f16[8,32,1024,64]{3,1,2,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.346), dimensions={0,2,1,3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%slice.1101 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.1096), slice={[0:8], [0:1024], [0:2048], [2:3]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/slice[\n limit_indices=(8, 1024, 2048, 3)\n start_indices=(0, 0, 0, 2)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%reshape.1102 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %slice.1101), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=175}
%transpose.99 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.1102), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
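/* EDITOR'S NOTE (annotation): %dot.59 below is the batched Q·K^T contraction over the 64-wide head
   dimension, producing [8,32,1024,1024] attention scores, one 1024x1024 matrix per batch element and
   head. %compare.2 .. %add.1125 then build an additive attention mask: where the s32 token input
   (tuple index 27) compares GT against a broadcast scalar, a select chooses between two prefilled
   f16[8,1,1,1024] tensors (indices 7 and 8, presumably zero and a large negative fill), which is
   broadcast over heads and query positions and added to the scores (bert_model.py:182-184,
   attention.py:94). */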
%dot.59 = f16[8,32,1024,1024]{3,2,1,0} dot(f16[8,32,1024,64]{3,1,2,0} %transpose.98, f16[8,32,64,1024]{2,1,3,0} %transpose.99), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(jvp(_einsum))/dot_general[\n dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
%get-tuple-element.1086 = s32[8,1024]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=27, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%get-tuple-element.1065 = s32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=6, sharding={replicated}
%broadcast.96 = s32[8,1024]{1,0} broadcast(s32[] %get-tuple-element.1065), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/gt" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=182}
%compare.2 = pred[8,1024]{1,0} compare(s32[8,1024]{1,0} %get-tuple-element.1086, s32[8,1024]{1,0} %broadcast.96), direction=GT, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/gt" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=182}
%reshape.347 = pred[8,1,1,1024]{3,2,1,0} reshape(pred[8,1024]{1,0} %compare.2), sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/gt" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=182}
%get-tuple-element.1066 = f16[8,1,1,1024]{3,2,1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=7, sharding={replicated}
%get-tuple-element.1067 = f16[8,1,1,1024]{3,2,1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=8, sharding={replicated}
%select.1122 = f16[8,1,1,1024]{3,2,1,0} select(pred[8,1,1,1024]{3,2,1,0} %reshape.347, f16[8,1,1,1024]{3,2,1,0} %get-tuple-element.1066, f16[8,1,1,1024]{3,2,1,0} %get-tuple-element.1067), sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="select" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/select" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=184}
%reshape.1123 = f16[8,1024]{1,0} reshape(f16[8,1,1,1024]{3,2,1,0} %select.1122), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
%broadcast.1124 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,1024]{1,0} %reshape.1123), dimensions={0,3}, sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
%add.1125 = f16[8,32,1024,1024]{3,2,1,0} add(f16[8,32,1024,1024]{3,2,1,0} %dot.59, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1124), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
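/* EDITOR'S NOTE (annotation): %reduce.1132 .. %divide.1162 below implement a numerically stable
   softmax over the key axis (dimension 3), keeping the exp in f16 while accumulating the normalizing
   sum in f32. A minimal JAX sketch of the same computation, assuming `import jax.numpy as jnp` and a
   `scores` array of shape (8, 32, 1024, 1024) (names illustrative, not from the dump):

       m = jnp.max(scores, axis=-1, keepdims=True)                  # %reduce.1132
       e = jnp.exp(scores - m)                                      # f16, %exponential.1149
       s = jnp.sum(e.astype(jnp.float32), axis=-1, keepdims=True)   # f32 accumulation, %reduce.1157
       weights = e / s.astype(jnp.float16)                          # %divide.1162
*/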
%reduce.1132 = f16[8,32,1024]{2,1,0} reduce(f16[8,32,1024,1024]{3,2,1,0} %add.1125, f16[] %constant.342), dimensions={3}, to_apply=%primitive_computation_max__3.1127, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_max" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_max[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.1147 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %reduce.1132), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%subtract.1148 = f16[8,32,1024,1024]{3,2,1,0} subtract(f16[8,32,1024,1024]{3,2,1,0} %add.1125, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1147), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%exponential.1149 = f16[8,32,1024,1024]{3,2,1,0} exponential(f16[8,32,1024,1024]{3,2,1,0} %subtract.1148), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="exp" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/exp" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%convert.1150 = f32[8,32,1024,1024]{3,2,1,0} convert(f16[8,32,1024,1024]{3,2,1,0} %exponential.1149), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%reduce.1157 = f32[8,32,1024]{2,1,0} reduce(f32[8,32,1024,1024]{3,2,1,0} %convert.1150, f32[] %constant.165), dimensions={3}, to_apply=%primitive_computation_add__22.1152, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%convert.11 = f16[8,32,1024]{2,1,0} convert(f32[8,32,1024]{2,1,0} %reduce.1157), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.1161 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %convert.11), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%divide.1162 = f16[8,32,1024,1024]{3,2,1,0} divide(f16[8,32,1024,1024]{3,2,1,0} %exponential.1149, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1161), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
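/* EDITOR'S NOTE (annotation): %dot.60 .. %reshape.44 below finish the attention block: the softmax
   weights are contracted against the remaining [8,32,64,1024] head tensor over the key axis
   (bert_model.py:207), the result is transposed back to [8,1024,32,64], flattened to [8192,2048],
   and multiplied by the [2048,2048] output-projection weight (tuple index 16). */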
%dot.60 = f16[8,32,64,1024]{3,2,1,0} dot(f16[8,32,64,1024]{2,1,3,0} %transpose.100, f16[8,32,1024,1024]{3,2,1,0} %divide.1162), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(jvp(_einsum))/dot_general[\n dimension_numbers=(((1,), (3,)), ((0, 2), (0, 1)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%transpose.2 = f16[8,1024,32,64]{1,3,2,0} transpose(f16[8,32,64,1024]{3,2,1,0} %dot.60), dimensions={0,3,1,2}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(jvp(_einsum))/transpose[\n permutation=(0, 3, 1, 2)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%reshape.42 = f16[8192,2048]{1,0} reshape(f16[8,1024,32,64]{1,3,2,0} %transpose.2), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%get-tuple-element.1075 = f16[2048,2048]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=16, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.30 = f16[8192,2048]{1,0} dot(f16[8192,2048]{1,0} %reshape.42, f16[2048,2048]{1,0} %get-tuple-element.1075), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.44 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.30), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%get-tuple-element.1074 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=15, sharding={replicated}
%broadcast.1190 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1074), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1191 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %reshape.44, f16[8,1024,2048]{2,1,0} %broadcast.1190), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1192 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %add.1191, f16[8,1024,2048]{2,1,0} %get-tuple-element.1085), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=233}
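/* EDITOR'S NOTE (annotation): %add.1191/%add.1192 above add the output-projection bias and the
   residual connection. %convert.1193 .. %add.1246 below are an inlined layer norm from alpa's
   monkey_patch.py (lines 141-155): the mean and E[x^2] are reduced in f32 over the 2048-wide feature
   axis, the variance is formed as E[x^2] - mean^2, and the normalized value is scaled and shifted.
   A rough JAX sketch under those assumptions (eps is the broadcast scalar at tuple index 9; scale and
   bias are the f16[2048] parameters at indices 14 and 13; assumes `import jax, jax.numpy as jnp`):

       x32  = x.astype(jnp.float32)                                     # %convert.1193
       mean = jnp.mean(x32, axis=-1, keepdims=True)                     # %reduce.1200 / %divide.32
       var  = jnp.mean(x32 * x32, axis=-1, keepdims=True) - mean * mean # %divide.33 - %multiply.103
       inv  = jax.lax.rsqrt(var + eps).astype(jnp.float16)              # %rsqrt.1228 / %convert.1229
       y    = ((x32 - mean) * (inv * scale).astype(jnp.float32)).astype(jnp.float16) + bias
*/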
%convert.1193 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.1192), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reduce.1200 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %convert.1193, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__23.1195, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%get-tuple-element.1063 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=4, sharding={replicated}
%broadcast.97 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1063), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%divide.32 = f32[8,1024]{1,0} divide(f32[8,1024]{1,0} %reduce.1200, f32[8,1024]{1,0} %broadcast.97), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.1205 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %divide.32), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.1206 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %convert.1193, f32[8,1024,2048]{2,1,0} %broadcast.1205), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.1207 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1193, f32[8,1024,2048]{2,1,0} %convert.1193), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.1217 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.1207, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__24.1212, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%get-tuple-element.1069 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=10, sharding={replicated}
%broadcast.98 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1069), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%divide.33 = f32[8,1024]{1,0} divide(f32[8,1024]{1,0} %reduce.1217, f32[8,1024]{1,0} %broadcast.98), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.103 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %divide.32, f32[8,1024]{1,0} %divide.32), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract.5 = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %divide.33, f32[8,1024]{1,0} %multiply.103), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%get-tuple-element.1068 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=9, sharding={replicated}
%broadcast.185 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1068), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%add.6 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract.5, f32[8,1024]{1,0} %broadcast.185), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%reshape.481 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %add.6), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt.1228 = f32[8,1024,1]{2,1,0} rsqrt(f32[8,1024,1]{2,1,0} %reshape.481), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%convert.1229 = f16[8,1024,1]{2,1,0} convert(f32[8,1024,1]{2,1,0} %rsqrt.1228), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.1235 = f16[8,1024]{1,0} reshape(f16[8,1024,1]{2,1,0} %convert.1229), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%broadcast.1236 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %reshape.1235), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%get-tuple-element.1073 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=14, sharding={replicated}
%broadcast.1238 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1073), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.1239 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.1236, f16[8,1024,2048]{2,1,0} %broadcast.1238), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.1240 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.1239), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.1241 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.1206, f32[8,1024,2048]{2,1,0} %convert.1240), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.1242 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.1241), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%get-tuple-element.1072 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=13, sharding={replicated}
%broadcast.1245 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1072), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%add.1246 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.1242, f16[8,1024,2048]{2,1,0} %broadcast.1245), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
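/* EDITOR'S NOTE (annotation): %reshape.45 .. %add.1251 below are the first MLP dense layer of the
   block: [8192,2048] x [2048,8192] (weight at tuple index 20, bias at index 19), expanding the hidden
   dimension 4x to 8192 (flax linear.py:177/181). */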
%reshape.45 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %add.1246), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%get-tuple-element.1079 = f16[2048,8192]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=20, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.31 = f16[8192,8192]{1,0} dot(f16[8192,2048]{1,0} %reshape.45, f16[2048,8192]{1,0} %get-tuple-element.1079), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.47 = f16[8,1024,8192]{2,1,0} reshape(f16[8192,8192]{1,0} %dot.31), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%get-tuple-element.1078 = f16[8192]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=19, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.1250 = f16[8,1024,8192]{2,1,0} broadcast(f16[8192]{0} %get-tuple-element.1078), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1251 = f16[8,1024,8192]{2,1,0} add(f16[8,1024,8192]{2,1,0} %reshape.47, f16[8,1024,8192]{2,1,0} %broadcast.1250), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
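/* EDITOR'S NOTE (annotation): %convert.1320 .. %convert.1327 below evaluate GELU in f32
   (bert_model.py:285). XLA has expanded erf into a clamped rational approximation, which is why the
   long ladder of fused multiply/add pairs against precomputed %broadcast constants appears:
   %divide.1254 forms u = x / sqrt(2) (the scalar at tuple index 3 is 1.41421354), %multiply.1260 is
   u^2, the two polynomial chains build the numerator and denominator, and %divide.1323 applies the
   final x * (erf(u) + c) / 2 with c presumably 1. A one-line JAX sketch of the intended function,
   assuming `import jax, jax.numpy as jnp`:

       gelu = lambda x: x * (1.0 + jax.lax.erf(x / jnp.sqrt(2.0))) / 2.0
*/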
%convert.1320 = f32[8,1024,8192]{2,1,0} convert(f16[8,1024,8192]{2,1,0} %add.1251), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%get-tuple-element.1062 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=3, sharding={replicated}
%broadcast.1253 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %get-tuple-element.1062), dimensions={}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.1254 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %convert.1320, f32[8,1024,8192]{2,1,0} %broadcast.1253), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%clamp.1259 = f32[8,1024,8192]{2,1,0} clamp(f32[8,1024,8192]{2,1,0} %broadcast.447, f32[8,1024,8192]{2,1,0} %divide.1254, f32[8,1024,8192]{2,1,0} %broadcast.448), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1260 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %clamp.1259, f32[8,1024,8192]{2,1,0} %clamp.1259), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1286 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %broadcast.475, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1288 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1286, f32[8,1024,8192]{2,1,0} %broadcast.477), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1290 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1288, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1292 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1290, f32[8,1024,8192]{2,1,0} %broadcast.481), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1294 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1292, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1296 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1294, f32[8,1024,8192]{2,1,0} %broadcast.485), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1298 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1296, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1300 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1298, f32[8,1024,8192]{2,1,0} %broadcast.489), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1302 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1300, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1304 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1302, f32[8,1024,8192]{2,1,0} %broadcast.493), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1306 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1304, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1308 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1306, f32[8,1024,8192]{2,1,0} %broadcast.497), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1310 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1308, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1312 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1310, f32[8,1024,8192]{2,1,0} %broadcast.501), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1313 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %clamp.1259, f32[8,1024,8192]{2,1,0} %add.1312), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1266 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1286, f32[8,1024,8192]{2,1,0} %broadcast.455), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1268 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1266, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1270 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1268, f32[8,1024,8192]{2,1,0} %broadcast.459), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1272 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1270, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1274 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1272, f32[8,1024,8192]{2,1,0} %broadcast.463), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1276 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1274, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1278 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1276, f32[8,1024,8192]{2,1,0} %broadcast.467), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1280 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1278, f32[8,1024,8192]{2,1,0} %multiply.1260), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1282 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1280, f32[8,1024,8192]{2,1,0} %broadcast.471), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.1314 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %multiply.1313, f32[8,1024,8192]{2,1,0} %add.1282), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%get-tuple-element.1061 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=2, sharding={replicated}
%broadcast.1318 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %get-tuple-element.1061), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1319 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %divide.1314, f32[8,1024,8192]{2,1,0} %broadcast.1318), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1321 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %convert.1320, f32[8,1024,8192]{2,1,0} %add.1319), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%get-tuple-element.1060 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=1, sharding={replicated}
%broadcast.1322 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %get-tuple-element.1060), dimensions={}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.1323 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %multiply.1321, f32[8,1024,8192]{2,1,0} %broadcast.1322), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%convert.1327 = f16[8,1024,8192]{2,1,0} convert(f32[8,1024,8192]{2,1,0} %divide.1323), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
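/* Note (an assumption from shapes and the flax/linen/linear.py metadata): %dot.32 below
   recomputes the forward of the FFN down-projection -- flattened GELU activations [8192,8192]
   times the f16[8192,2048] kernel at tuple index 24 -- followed by the bias at index 23 and the
   residual add (%add.1246, defined earlier). All of this is rematerialized rather than stored. */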
%reshape.48 = f16[8192,8192]{1,0} reshape(f16[8,1024,8192]{2,1,0} %convert.1327), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%get-tuple-element.1083 = f16[8192,2048]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=24, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.32 = f16[8192,2048]{1,0} dot(f16[8192,8192]{1,0} %reshape.48, f16[8192,2048]{1,0} %get-tuple-element.1083), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.50 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.32), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%get-tuple-element.1082 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=23, sharding={replicated}
%broadcast.1331 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1082), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1332 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %reshape.50, f16[8,1024,2048]{2,1,0} %broadcast.1331), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1333 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %add.1332, f16[8,1024,2048]{2,1,0} %add.1246), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=310}
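/* Note: the f32 region below recomputes the layer-norm mean (monkey_patch.py:143): a reduce_sum
   over the hidden dimension divided by the scalar at tuple index 0 (%get-tuple-element.1059,
   plausibly the feature count 2048), then subtracted from the activations to center them. */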
%convert.1334 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.1333), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reduce.1341 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %convert.1334, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__25.1336, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%get-tuple-element.1059 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=0, sharding={replicated}
%broadcast.99 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1059), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%divide.34 = f32[8,1024]{1,0} divide(f32[8,1024]{1,0} %reduce.1341, f32[8,1024]{1,0} %broadcast.99), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.1346 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %divide.34), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.1347 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %convert.1334, f32[8,1024,2048]{2,1,0} %broadcast.1346), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
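/* Note: here the incoming gradient (tuple index 30) is multiplied by the centered activations
   and by the layer-norm scale gamma (tuple index 22), then reduced over the hidden dimension;
   this sum feeds the d(rstd) term of the layer-norm backward pass. */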
%get-tuple-element.1089 = f16[8,1024,2048]{2,1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=30, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%convert.1403 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %get-tuple-element.1089), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%multiply.1404 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.1347, f32[8,1024,2048]{2,1,0} %convert.1403), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.1405 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.1404), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%get-tuple-element.1081 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=22, sharding={replicated}
%broadcast.1426 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1081), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.1427 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %convert.1405, f16[8,1024,2048]{2,1,0} %broadcast.1426), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.1434 = f16[8,1024]{1,0} reduce(f16[8,1024,2048]{2,1,0} %multiply.1427, f16[] %constant.1015), dimensions={2}, to_apply=%primitive_computation_add__31.1429, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.12 = f32[8,1024]{1,0} convert(f16[8,1024]{1,0} %reduce.1434), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.354 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %convert.12), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%constant.201 = f32[] constant(-0.5), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%broadcast.1373 = f32[8,1024,1]{2,1,0} broadcast(f32[] %constant.201), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
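/* Note: the lines below recompute the variance as E[x^2] - mean^2 (tuple index 12 as the
   divisor, index 11 plausibly the epsilon), take rsqrt, and form -0.5 * rsqrt(v) / v =
   -0.5 * v^(-3/2), the standard d(rstd)/d(variance) factor used next. */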
%multiply.1348 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1334, f32[8,1024,2048]{2,1,0} %convert.1334), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.1358 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.1348, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__26.1353, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%get-tuple-element.1071 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=12, sharding={replicated}
%broadcast.100 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1071), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%divide.35 = f32[8,1024]{1,0} divide(f32[8,1024]{1,0} %reduce.1358, f32[8,1024]{1,0} %broadcast.100), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.105 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %divide.34, f32[8,1024]{1,0} %divide.34), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract.6 = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %divide.35, f32[8,1024]{1,0} %multiply.105), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%get-tuple-element.1070 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1058), index=11, sharding={replicated}
%broadcast.186 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1070), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%add.7 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract.6, f32[8,1024]{1,0} %broadcast.186), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%reshape.482 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %add.7), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt.1369 = f32[8,1024,1]{2,1,0} rsqrt(f32[8,1024,1]{2,1,0} %reshape.482), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%divide.1371 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %rsqrt.1369, f32[8,1024,1]{2,1,0} %reshape.482), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.1374 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %broadcast.1373, f32[8,1024,1]{2,1,0} %divide.1371), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
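/* Note (an assumption: %constant.177, defined earlier, is consistent with the literal 2):
   %multiply.1437 below is d(variance); dividing by N (tuple index 12), doubling, broadcasting,
   and multiplying by x in %multiply.1450 yields the 2*x*dvar/N contribution of d(var)/dx to
   the input gradient. */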
%multiply.1437 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %reshape.354, f32[8,1024,1]{2,1,0} %multiply.1374), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%broadcast.1440 = f32[8,1024,1]{2,1,0} broadcast(f32[] %get-tuple-element.1071), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%divide.1441 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %multiply.1437, f32[8,1024,1]{2,1,0} %broadcast.1440), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%broadcast.130 = f32[8,1024,1]{2,1,0} broadcast(f32[] %constant.177), dimensions={}, sharding={replicated}
%multiply.87 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %divide.1441, f32[8,1024,1]{2,1,0} %broadcast.130), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.408 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %multiply.87), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.63 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.408), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%multiply.1450 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1334, f32[8,1024,2048]{2,1,0} %broadcast.63), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%convert.1370 = f16[8,1024,1]{2,1,0} convert(f32[8,1024,1]{2,1,0} %rsqrt.1369), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.1376 = f16[8,1024]{1,0} reshape(f16[8,1024,1]{2,1,0} %convert.1370), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%broadcast.1377 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %reshape.1376), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.1380 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.1377, f16[8,1024,2048]{2,1,0} %broadcast.1426), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.1381 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.1380), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.1406 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1403, f32[8,1024,2048]{2,1,0} %convert.1381), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.1460 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %multiply.1450, f32[8,1024,2048]{2,1,0} %multiply.1406), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
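/* Note: %add.1460 merges the d(variance) path with the gamma*rstd*dy path. The block below
   builds the d(mean) contribution -- -dvar*2*mean plus -sum(gamma*rstd*dy), divided by N and
   broadcast -- so %add.1472 is the full layer-norm input gradient before the cast back to f16. */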
%negate.1438 = f32[8,1024,1]{2,1,0} negate(f32[8,1024,1]{2,1,0} %multiply.1437), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%broadcast.164 = f32[8,1024]{1,0} broadcast(f32[] %constant.177), dimensions={}, sharding={replicated}
%multiply.106 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %broadcast.164, f32[8,1024]{1,0} %divide.34), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%reshape.444 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %multiply.106), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.1439 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %negate.1438, f32[8,1024,1]{2,1,0} %reshape.444), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%negate.1451 = f32[8,1024,2048]{2,1,0} negate(f32[8,1024,2048]{2,1,0} %multiply.1406), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reduce.1458 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %negate.1451, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__33.1453, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reshape.1459 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %reduce.1458), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.1461 = f32[8,1024,1]{2,1,0} add(f32[8,1024,1]{2,1,0} %multiply.1439, f32[8,1024,1]{2,1,0} %reshape.1459), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%broadcast.1462 = f32[8,1024,1]{2,1,0} broadcast(f32[] %get-tuple-element.1059), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%divide.1463 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %add.1461, f32[8,1024,1]{2,1,0} %broadcast.1462), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%reshape.260 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %divide.1463), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.1471 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.260), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/broadcast_in_dim[\n broadcast_dimensions=(0, 1)\n shape=(8, 1024, 2048)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%add.1472 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %add.1460, f32[8,1024,2048]{2,1,0} %broadcast.1471), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%convert.1473 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %add.1472), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
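/* Note (suggested by shapes, not stated in the metadata): %dot.34 below contracts the layer-norm
   input gradient with the same [8192,2048] down-projection kernel along its output dimension,
   i.e. dY * W^T, giving the gradient w.r.t. the GELU output; the divide by the scalar at tuple
   index 1 mirrors the divide in the forward GELU at bert_model.py:285. */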
%reshape.54 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %convert.1473), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.34 = f16[8192,8192]{1,0} dot(f16[8192,2048]{1,0} %reshape.54, f16[8192,2048]{1,0} %get-tuple-element.1083), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%convert.13 = f32[8192,8192]{1,0} convert(f16[8192,8192]{1,0} %dot.34), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.165 = f32[8192,8192]{1,0} broadcast(f32[] %get-tuple-element.1060), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.44 = f32[8192,8192]{1,0} divide(f32[8192,8192]{1,0} %convert.13, f32[8192,8192]{1,0} %broadcast.165), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%reshape.445 = f32[8,1024,8192]{2,1,0} reshape(f32[8192,8192]{1,0} %divide.44), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
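/* Note: GELU backward proper. One branch multiplies the upstream gradient by the recomputed
   (1 + erf(u)) factor (%add.1319); the other multiplies it by x and by
   %constant.1492 = 1.12837923 ~= 2/sqrt(pi) -- the derivative of erf -- times the recomputed
   exp(-u^2), with u = x/sqrt(2) (%divide.1254, defined earlier). %add.1499 sums the branches. */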
%multiply.1490 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %reshape.445, f32[8,1024,8192]{2,1,0} %add.1319), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%convert.1491 = f16[8,1024,8192]{2,1,0} convert(f32[8,1024,8192]{2,1,0} %multiply.1490), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%constant.1492 = f32[] constant(1.12837923), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.1493 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %constant.1492), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1489 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %convert.1320, f32[8,1024,8192]{2,1,0} %reshape.445), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1494 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %broadcast.1493, f32[8,1024,8192]{2,1,0} %multiply.1489), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1315 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %divide.1254, f32[8,1024,8192]{2,1,0} %divide.1254), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%negate.1316 = f32[8,1024,8192]{2,1,0} negate(f32[8,1024,8192]{2,1,0} %multiply.1315), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%exponential.1317 = f32[8,1024,8192]{2,1,0} exponential(f32[8,1024,8192]{2,1,0} %negate.1316), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="exp" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/exp" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1495 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %multiply.1494, f32[8,1024,8192]{2,1,0} %exponential.1317), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.1497 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %multiply.1495, f32[8,1024,8192]{2,1,0} %broadcast.1253), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%convert.1498 = f16[8,1024,8192]{2,1,0} convert(f32[8,1024,8192]{2,1,0} %divide.1497), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1499 = f16[8,1024,8192]{2,1,0} add(f16[8,1024,8192]{2,1,0} %convert.1491, f16[8,1024,8192]{2,1,0} %convert.1498), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
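/* Note (an assumption: %get-tuple-element.1079, defined earlier, is the f16[2048,8192]
   up-projection kernel at tuple index 20): %dot.36 backpropagates through the first FFN matmul,
   and %add.1512 adds the gradient arriving via the residual connection (%convert.1473). */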
%reshape.60 = f16[8192,8192]{1,0} reshape(f16[8,1024,8192]{2,1,0} %add.1499), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.36 = f16[8192,2048]{1,0} dot(f16[8192,8192]{1,0} %reshape.60, f16[2048,8192]{1,0} %get-tuple-element.1079), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.62 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.36), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (1,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%add.1512 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.1473, f16[8,1024,2048]{2,1,0} %reshape.62), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
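/* Note: the region below is the backward pass of the second (attention-output) layer norm.
   It follows the same recipe as above -- recompute mean/variance/rsqrt in f32, form the
   gamma*rstd*dy, d(variance), and d(mean) terms, and sum them -- reusing statistics such as
   %subtract.1206 and %rsqrt.1228 that were defined earlier in the module. */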
%convert.1528 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.1512), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%multiply.1529 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.1206, f32[8,1024,2048]{2,1,0} %convert.1528), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.1530 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.1529), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.1552 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %convert.1530, f16[8,1024,2048]{2,1,0} %broadcast.1238), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.1559 = f16[8,1024]{1,0} reduce(f16[8,1024,2048]{2,1,0} %multiply.1552, f16[] %constant.1015), dimensions={2}, to_apply=%primitive_computation_add__41.1554, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.14 = f32[8,1024]{1,0} convert(f16[8,1024]{1,0} %reduce.1559), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.356 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %convert.14), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%divide.1230 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %rsqrt.1228, f32[8,1024,1]{2,1,0} %reshape.481), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.1233 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %broadcast.1373, f32[8,1024,1]{2,1,0} %divide.1230), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.1562 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %reshape.356, f32[8,1024,1]{2,1,0} %multiply.1233), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%broadcast.1565 = f32[8,1024,1]{2,1,0} broadcast(f32[] %get-tuple-element.1069), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%divide.1566 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %multiply.1562, f32[8,1024,1]{2,1,0} %broadcast.1565), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.88 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %divide.1566, f32[8,1024,1]{2,1,0} %broadcast.130), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.409 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %multiply.88), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.67 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.409), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%multiply.1575 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1193, f32[8,1024,2048]{2,1,0} %broadcast.67), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.1531 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1528, f32[8,1024,2048]{2,1,0} %convert.1240), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.1585 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %multiply.1575, f32[8,1024,2048]{2,1,0} %multiply.1531), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%negate.1563 = f32[8,1024,1]{2,1,0} negate(f32[8,1024,1]{2,1,0} %multiply.1562), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.104 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %broadcast.164, f32[8,1024]{1,0} %divide.32), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%reshape.442 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %multiply.104), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.1564 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %negate.1563, f32[8,1024,1]{2,1,0} %reshape.442), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%negate.1576 = f32[8,1024,2048]{2,1,0} negate(f32[8,1024,2048]{2,1,0} %multiply.1531), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reduce.1583 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %negate.1576, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__43.1578, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reshape.1584 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %reduce.1583), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.1586 = f32[8,1024,1]{2,1,0} add(f32[8,1024,1]{2,1,0} %multiply.1564, f32[8,1024,1]{2,1,0} %reshape.1584), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%broadcast.1587 = f32[8,1024,1]{2,1,0} broadcast(f32[] %get-tuple-element.1063), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%divide.1588 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %add.1586, f32[8,1024,1]{2,1,0} %broadcast.1587), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%reshape.263 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %divide.1588), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.1596 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.263), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/broadcast_in_dim[\n broadcast_dimensions=(0, 1)\n shape=(8, 1024, 2048)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%add.1597 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %add.1585, f32[8,1024,2048]{2,1,0} %broadcast.1596), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%convert.1598 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %add.1597), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
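/* Note (inferred from shapes and the bert_model.py:207/165 metadata): %dot.38 backprops through
   the f16[2048,2048] attention output projection (%get-tuple-element.1075, defined earlier);
   the result is reshaped to per-head layout [8,1024,32,64]. %dot.7 contracts it with the
   recomputed attention probabilities (%divide.1162), and %dot.61 contracts it with a cached
   per-head operand (%transpose.100) to get the gradient w.r.t. the pre-softmax scores, which
   %divide.1643 rescales by the softmax normalizer, matching the forward pass. */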
%reshape.66 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %convert.1598), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.38 = f16[8192,2048]{1,0} dot(f16[8192,2048]{1,0} %reshape.66, f16[2048,2048]{1,0} %get-tuple-element.1075), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.1611 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8192,2048]{1,0} %dot.38), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=208}
%transpose.3 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.1611), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/transpose[\n permutation=(0, 2, 3, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%dot.7 = f16[8,32,64,1024]{3,2,1,0} dot(f16[8,32,64,1024]{2,1,3,0} %transpose.3, f16[8,32,1024,1024]{3,2,1,0} %divide.1162), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/dot_general[\n dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%transpose.4 = f16[8,1024,32,64]{1,3,2,0} transpose(f16[8,32,64,1024]{3,2,1,0} %dot.7), dimensions={0,3,1,2}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/transpose[\n permutation=(0, 3, 1, 2)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%reshape.1625 = f16[8,1024,2048,1]{3,2,1,0} reshape(f16[8,1024,32,64]{1,3,2,0} %transpose.4), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=172}
%pad.1627 = f16[8,1024,2048,3]{3,2,1,0} pad(f16[8,1024,2048,1]{3,2,1,0} %reshape.1625, f16[] %constant.1015), padding=0_0x0_0x0_0x1_1, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="pad" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/pad[\n padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0), (1, 1, 0))\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%transpose.102 = f16[8,32,1024,64]{3,1,2,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.1611), dimensions={0,2,1,3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.61 = f16[8,32,1024,1024]{3,2,1,0} dot(f16[8,32,1024,64]{3,1,2,0} %transpose.102, f16[8,32,64,1024]{2,1,3,0} %transpose.100), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/dot_general[\n dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%divide.1643 = f16[8,32,1024,1024]{3,2,1,0} divide(f16[8,32,1024,1024]{3,2,1,0} %dot.61, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1161), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
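/* Note: softmax VJP. integer_pow[y=-2] is expanded as 1/sum^2 (%divide.46); multiplying the
   score gradient by exp and 1/sum^2 and reducing over the key dimension gives sum(dP * P);
   negating, broadcasting, adding to dP/sum, and multiplying by exp in %multiply.1655 yields
   the standard logits gradient P * (dP - sum(dP * P)). */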
%constant.1164 = f16[] constant(1), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=-2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.171 = f16[8,32,1024]{2,1,0} broadcast(f16[] %constant.1164), dimensions={}, sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=-2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%multiply.102 = f16[8,32,1024]{2,1,0} multiply(f16[8,32,1024]{2,1,0} %convert.11, f16[8,32,1024]{2,1,0} %convert.11), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=-2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%divide.46 = f16[8,32,1024]{2,1,0} divide(f16[8,32,1024]{2,1,0} %broadcast.171, f16[8,32,1024]{2,1,0} %multiply.102), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=-2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.1629 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %divide.46), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%multiply.1630 = f16[8,32,1024,1024]{3,2,1,0} multiply(f16[8,32,1024,1024]{3,2,1,0} %dot.61, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1629), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%multiply.1631 = f16[8,32,1024,1024]{3,2,1,0} multiply(f16[8,32,1024,1024]{3,2,1,0} %multiply.1630, f16[8,32,1024,1024]{3,2,1,0} %exponential.1149), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%reduce.1638 = f16[8,32,1024]{2,1,0} reduce(f16[8,32,1024,1024]{3,2,1,0} %multiply.1631, f16[] %constant.1015), dimensions={3}, to_apply=%primitive_computation_add__46.1633, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%negate = f16[8,32,1024]{2,1,0} negate(f16[8,32,1024]{2,1,0} %reduce.1638), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.68 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %negate), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/broadcast_in_dim[\n broadcast_dimensions=(0, 1, 2)\n shape=(8, 32, 1024, 1024)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%add.1654 = f16[8,32,1024,1024]{3,2,1,0} add(f16[8,32,1024,1024]{3,2,1,0} %divide.1643, f16[8,32,1024,1024]{3,2,1,0} %broadcast.68), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%multiply.1655 = f16[8,32,1024,1024]{3,2,1,0} multiply(f16[8,32,1024,1024]{3,2,1,0} %add.1654, f16[8,32,1024,1024]{3,2,1,0} %exponential.1149), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
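/* %broadcast.171 .. %multiply.1655 read as the softmax VJP from flax/linen/attention.py:101.
   With e = exp(x - max), s = sum(e) and incoming gradient g (so that %divide.1643, defined
   earlier, is presumably g / s), the integer_pow[y=-2] block forms 1/s^2 (assuming
   %constant.1164 is 1.0), giving d_logits = e * (g/s - sum(g*e)/s^2). Equivalent JAX-style
   sketch: p = e / s; dx = p * (g - jnp.sum(g * p, -1, keepdims=True)). */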
%dot.63 = f16[8,32,1024,64]{3,2,1,0} dot(f16[8,32,1024,1024]{3,2,1,0} %multiply.1655, f16[8,32,1024,64]{3,1,2,0} %transpose.98), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/dot_general[\n dimension_numbers=(((2,), (1,)), ((0, 1), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
%transpose.6 = f16[8,1024,32,64]{3,1,2,0} transpose(f16[8,32,1024,64]{3,2,1,0} %dot.63), dimensions={0,2,1,3}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/transpose[\n permutation=(0, 2, 1, 3)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
%reshape.1669 = f16[8,1024,2048,1]{3,2,1,0} reshape(f16[8,1024,32,64]{3,1,2,0} %transpose.6), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=175}
%pad.1671 = f16[8,1024,2048,3]{3,2,1,0} pad(f16[8,1024,2048,1]{3,2,1,0} %reshape.1669, f16[] %constant.1015), padding=0_0x0_0x0_0x2_0, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="pad" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/pad[\n padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0), (2, 0, 0))\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%add.1672 = f16[8,1024,2048,3]{3,2,1,0} add(f16[8,1024,2048,3]{3,2,1,0} %pad.1627, f16[8,1024,2048,3]{3,2,1,0} %pad.1671), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
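/* %dot.63 .. %add.1672: gradient through one operand of the attention einsum
   (transpose(jvp(_einsum)) at attention.py:90). The [8,32,1024,64] gradient is transposed back
   to [8,1024,32,64], reshaped to [8,1024,2048,1], and pad(low=2 on the last dim) writes it into
   the last of the three slots of the fused (q,k,v) projection gradient from bert_model.py:165,
   where add_any accumulates it with %pad.1627. */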
%transpose.105 = f16[8,32,1024,64]{3,1,2,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.1102), dimensions={0,2,1,3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.62 = f16[8,32,1024,64]{3,2,1,0} dot(f16[8,32,1024,1024]{3,2,1,0} %multiply.1655, f16[8,32,1024,64]{3,1,2,0} %transpose.105), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/dot_general[\n dimension_numbers=(((3,), (1,)), ((0, 1), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
%broadcast.101 = f16[8,32,1024,64]{3,2,1,0} broadcast(f16[] %get-tuple-element.1064), dimensions={}, sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%divide.36 = f16[8,32,1024,64]{3,2,1,0} divide(f16[8,32,1024,64]{3,2,1,0} %dot.62, f16[8,32,1024,64]{3,2,1,0} %broadcast.101), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%transpose.121 = f16[8,1024,32,64]{3,2,1,0} transpose(f16[8,32,1024,64]{3,2,1,0} %divide.36), dimensions={0,2,1,3}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%reshape.1675 = f16[8,1024,2048,1]{3,2,1,0} reshape(f16[8,1024,32,64]{3,2,1,0} %transpose.121), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=169}
%pad.1677 = f16[8,1024,2048,3]{3,2,1,0} pad(f16[8,1024,2048,1]{3,2,1,0} %reshape.1675, f16[] %constant.1015), padding=0_0x0_0x0_0x0_2, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="pad" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/pad[\n padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 2, 0))\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%add.1678 = f16[8,1024,2048,3]{3,2,1,0} add(f16[8,1024,2048,3]{3,2,1,0} %add.1672, f16[8,1024,2048,3]{3,2,1,0} %pad.1677), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
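/* %transpose.105 .. %add.1678: the matching gradient through the other einsum operand.
   %divide.36 applies the attention scaling divisor from attention.py:87 (the replicated f16
   scalar %get-tuple-element.1064, presumably sqrt(head_dim) = 8 for these 64-dim heads);
   pad(high=2) places this piece in the first slot of the packed (..,3) tensor, completing the
   fused QKV gradient in %add.1678. */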
%reshape.72 = f16[8192,6144]{1,0} reshape(f16[8,1024,2048,3]{3,2,1,0} %add.1678), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.40 = f16[8192,2048]{1,0} dot(f16[8192,6144]{1,0} %reshape.72, f16[2048,6144]{1,0} %get-tuple-element.1077), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.74 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.40), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (1,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%add.1692 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.1598, f16[8,1024,2048]{2,1,0} %reshape.74), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
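/* %reshape.72 .. %add.1692: backprop through the fused QKV dense layer (linear.py:177). The
   gradient is flattened to [8192,6144] and contracted with the [2048,6144] kernel over the
   output dimension -- roughly dx = d_qkv @ W_qkv.T -- then added into the residual-stream
   gradient (%convert.1598) via add_any. */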
%tuple.1695 = (f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) tuple(f32[] %constant.297, f32[] %constant.177, f32[] %constant.299, f32[] %constant.300, f32[] %constant.297, /*index=5*/f16[] %constant.52, s32[] %constant.302, f16[8,1,1,1024]{3,2,1,0} %broadcast.15, f16[8,1,1,1024]{3,2,1,0} %broadcast.17, f32[] %constant.196, /*index=10*/f32[] %constant.297, f32[] %constant.196, f32[] %constant.297, f16[2048]{0} %parameter.7, f16[2048]{0} %parameter.8, /*index=15*/f16[2048]{0} %parameter.9, f16[2048,2048]{1,0} %parameter.10, f16[6144]{0} %parameter.11, f16[2048,6144]{1,0} %parameter.12, f16[8192]{0} %parameter.13, /*index=20*/f16[2048,8192]{1,0} %parameter.14, f16[2048]{0} %parameter.15, f16[2048]{0} %parameter.16, f16[2048]{0} %parameter.17, f16[8192,2048]{1,0} %parameter.18, /*index=25*/u32[2]{0} %parameter.124, f16[8,1024,2048]{2,1,0} %add.220, s32[8,1024]{1,0} %parameter.119, pred[] %constant.1055, pred[] %constant.1056, /*index=30*/f16[8,1024,2048]{2,1,0} %add.1692), sharding={{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=5*/{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=10*/{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=15*/{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=20*/{devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=25*/{devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {replicated}, {replicated}, /*index=30*/{devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}}, metadata={op_type="remat_call" op_name="parallelize(train_step_shard_parallel)/remat_call[\n concrete=True\n differentiated=True\n name=jvp(core_fn)\n policy=None\n prevent_cse=True\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%custom-call.1696 = (f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) custom-call((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %tuple.1695), custom_call_target="identity", output_to_operand_aliasing={{0}: (0, {0}), {1}: (0, {1}), {2}: (0, {2}), {3}: (0, {3}), {4}: (0, {4}), {5}: (0, {5}), {6}: (0, {6}), {7}: (0, {7}), {8}: (0, {8}), {9}: (0, {9}), {10}: (0, {10}), {11}: (0, {11}), {12}: (0, {12}), {13}: (0, {13}), {14}: (0, {14}), {15}: (0, {15}), {16}: (0, {16}), {17}: (0, {17}), {18}: (0, {18}), {19}: (0, {19}), {20}: (0, {20}), {21}: (0, {21}), {22}: (0, {22}), {23}: (0, {23}), {24}: (0, {24}), {25}: (0, {25}), {26}: (0, {26}), {27}: (0, {27}), {28}: (0, {28}), {29}: (0, {29}), {30}: (0, {30})}, sharding={{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=5*/{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=10*/{replicated}, {replicated}, {replicated}, {replicated}, {replicated}, /*index=15*/{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=20*/{devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=25*/{devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {replicated}, {replicated}, /*index=30*/{devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}}, metadata={op_type="remat_begin"}
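/* %tuple.1695 / %custom-call.1696: a rematerialization boundary. The custom-call target is
   "identity" and every output is aliased to the corresponding operand
   (output_to_operand_aliasing), so it is a no-op at runtime; per metadata
   op_type="remat_begin" it marks where remat(jvp(core_fn)) starts recomputing forward
   activations for the backward pass instead of keeping them live. */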
%get-tuple-element.1723 = f16[8,1024,2048]{2,1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=26, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.75 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %get-tuple-element.1723), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%get-tuple-element.1715 = f16[2048,6144]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=18, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.41 = f16[8192,6144]{1,0} dot(f16[8192,2048]{1,0} %reshape.75, f16[2048,6144]{1,0} %get-tuple-element.1715), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.77 = f16[8,1024,6144]{2,1,0} reshape(f16[8192,6144]{1,0} %dot.41), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%get-tuple-element.1714 = f16[6144]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=17, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.1732 = f16[8,1024,6144]{2,1,0} broadcast(f16[6144]{0} %get-tuple-element.1714), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1733 = f16[8,1024,6144]{2,1,0} add(f16[8,1024,6144]{2,1,0} %reshape.77, f16[8,1024,6144]{2,1,0} %broadcast.1732), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%reshape.1734 = f16[8,1024,2048,3]{3,2,1,0} reshape(f16[8,1024,6144]{2,1,0} %add.1733), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 3)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=162}
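/* %get-tuple-element.1723 .. %reshape.1734: the recomputed forward pass begins. The saved layer
   input (tuple index 26) is run through the fused QKV dense layer again -- x @ W_qkv + b_qkv
   (linear.py:177/181) -- and reshaped to [8,1024,2048,3], the packed (q,k,v) layout from
   bert_model.py:162. */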
%slice.1805 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.1734), slice={[0:8], [0:1024], [0:2048], [1:2]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/slice[\n limit_indices=(8, 1024, 2048, 2)\n start_indices=(0, 0, 0, 1)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%reshape.1806 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %slice.1805), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=172}
%transpose.110 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.1806), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%slice.1735 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.1734), slice={[0:8], [0:1024], [0:2048], [0:1]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/slice[\n limit_indices=(8, 1024, 2048, 1)\n start_indices=(0, 0, 0, 0)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%get-tuple-element.1702 = f16[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=5, sharding={replicated}
%broadcast.102 = f16[8,1024,2048,1]{3,2,1,0} broadcast(f16[] %get-tuple-element.1702), dimensions={}, sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%divide.37 = f16[8,1024,2048,1]{3,2,1,0} divide(f16[8,1024,2048,1]{3,2,1,0} %slice.1735, f16[8,1024,2048,1]{3,2,1,0} %broadcast.102), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%reshape.358 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %divide.37), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%transpose.108 = f16[8,32,1024,64]{3,1,2,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.358), dimensions={0,2,1,3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%slice.1739 = f16[8,1024,2048,1]{3,2,1,0} slice(f16[8,1024,2048,3]{3,2,1,0} %reshape.1734), slice={[0:8], [0:1024], [0:2048], [2:3]}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="slice" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/slice[\n limit_indices=(8, 1024, 2048, 3)\n start_indices=(0, 0, 0, 2)\n strides=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%reshape.1740 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8,1024,2048,1]{3,2,1,0} %slice.1739), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=175}
%transpose.109 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.1740), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.64 = f16[8,32,1024,1024]{3,2,1,0} dot(f16[8,32,1024,64]{3,1,2,0} %transpose.108, f16[8,32,64,1024]{2,1,3,0} %transpose.109), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(jvp(_einsum))/dot_general[\n dimension_numbers=(((3,), (3,)), ((0, 2), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
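/* %slice.1805 .. %dot.64: the packed tensor is sliced back into the three attention inputs
   (bert_model.py:165). The slot-0 slice is divided by the same replicated scalar (tuple
   index 5, presumably sqrt(head_dim)), which marks it as the query; the slot-2 slice becomes
   the rhs of %dot.64, which recomputes the raw attention logits by contracting the 64-dim head
   axis of the two [8,32,*,64] operands into [8,32,1024,1024] scores (jvp(_einsum) at
   attention.py:90). */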
%get-tuple-element.1724 = s32[8,1024]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=27, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%get-tuple-element.1703 = s32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=6, sharding={replicated}
%broadcast.103 = s32[8,1024]{1,0} broadcast(s32[] %get-tuple-element.1703), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/gt" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=182}
%compare.3 = pred[8,1024]{1,0} compare(s32[8,1024]{1,0} %get-tuple-element.1724, s32[8,1024]{1,0} %broadcast.103), direction=GT, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/gt" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=182}
%reshape.359 = pred[8,1,1,1024]{3,2,1,0} reshape(pred[8,1024]{1,0} %compare.3), sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="gt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/gt" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=182}
%get-tuple-element.1704 = f16[8,1,1,1024]{3,2,1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=7, sharding={replicated}
%get-tuple-element.1705 = f16[8,1,1,1024]{3,2,1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=8, sharding={replicated}
%select.1760 = f16[8,1,1,1024]{3,2,1,0} select(pred[8,1,1,1024]{3,2,1,0} %reshape.359, f16[8,1,1,1024]{3,2,1,0} %get-tuple-element.1704, f16[8,1,1,1024]{3,2,1,0} %get-tuple-element.1705), sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="select" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/select" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=184}
%reshape.1761 = f16[8,1024]{1,0} reshape(f16[8,1,1,1024]{3,2,1,0} %select.1760), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
%broadcast.1762 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,1024]{1,0} %reshape.1761), dimensions={0,3}, sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
%add.1763 = f16[8,32,1024,1024]{3,2,1,0} add(f16[8,32,1024,1024]{3,2,1,0} %dot.64, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1762), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=94}
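/* %get-tuple-element.1724 .. %add.1763: the attention bias is recomputed. Token ids (tuple
   index 27) are compared against a replicated scalar (gt at bert_model.py:182) and the
   resulting mask selects between the two saved [8,1,1,1024] bias tensors (tuple indices 7
   and 8, presumably 0 and a large negative value); the selected bias is broadcast over heads
   and query positions and added to the logits (attention.py:94). */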
%reduce.1770 = f16[8,32,1024]{2,1,0} reduce(f16[8,32,1024,1024]{3,2,1,0} %add.1763, f16[] %constant.342), dimensions={3}, to_apply=%primitive_computation_max__4.1765, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_max" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_max[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.1785 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %reduce.1770), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%subtract.1786 = f16[8,32,1024,1024]{3,2,1,0} subtract(f16[8,32,1024,1024]{3,2,1,0} %add.1763, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1785), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%exponential.1787 = f16[8,32,1024,1024]{3,2,1,0} exponential(f16[8,32,1024,1024]{3,2,1,0} %subtract.1786), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="exp" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/exp" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%convert.1788 = f32[8,32,1024,1024]{3,2,1,0} convert(f16[8,32,1024,1024]{3,2,1,0} %exponential.1787), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%reduce.1795 = f32[8,32,1024]{2,1,0} reduce(f32[8,32,1024,1024]{3,2,1,0} %convert.1788, f32[] %constant.165), dimensions={3}, to_apply=%primitive_computation_add__50.1790, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%convert.15 = f16[8,32,1024]{2,1,0} convert(f32[8,32,1024]{2,1,0} %reduce.1795), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.1799 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %convert.15), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%divide.1800 = f16[8,32,1024,1024]{3,2,1,0} divide(f16[8,32,1024,1024]{3,2,1,0} %exponential.1787, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1799), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
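/* %reduce.1770 .. %divide.1800: numerically stable softmax recompute over the key axis
   (attention.py:101): subtract the row max, exponentiate in f16, but accumulate the normalizer
   in f32 (%convert.1788 / %reduce.1795) before converting back to f16 and dividing. Sketch:
   e = exp(x - max(x)); p = e / jnp.sum(e.astype(f32), -1, keepdims=True).astype(f16). */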
%dot.65 = f16[8,32,64,1024]{3,2,1,0} dot(f16[8,32,64,1024]{2,1,3,0} %transpose.110, f16[8,32,1024,1024]{3,2,1,0} %divide.1800), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(jvp(_einsum))/dot_general[\n dimension_numbers=(((1,), (3,)), ((0, 2), (0, 1)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%transpose.7 = f16[8,1024,32,64]{1,3,2,0} transpose(f16[8,32,64,1024]{3,2,1,0} %dot.65), dimensions={0,3,1,2}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(jvp(_einsum))/transpose[\n permutation=(0, 3, 1, 2)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%reshape.78 = f16[8192,2048]{1,0} reshape(f16[8,1024,32,64]{1,3,2,0} %transpose.7), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%get-tuple-element.1713 = f16[2048,2048]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=16, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.42 = f16[8192,2048]{1,0} dot(f16[8192,2048]{1,0} %reshape.78, f16[2048,2048]{1,0} %get-tuple-element.1713), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.80 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.42), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%get-tuple-element.1712 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=15, sharding={replicated}
%broadcast.1828 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1712), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1829 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %reshape.80, f16[8,1024,2048]{2,1,0} %broadcast.1828), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1830 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %add.1829, f16[8,1024,2048]{2,1,0} %get-tuple-element.1723), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=233}
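/* %dot.65 .. %add.1830: attention output and residual. %dot.65 contracts the probabilities
   with the slot-1 slice over the key axis (the einsum at bert_model.py:207); the result is
   transposed and flattened, pushed through the output projection %get-tuple-element.1713 plus
   bias (linear.py:177/181), and the saved layer input (tuple index 26) is added back as the
   residual connection (bert_model.py:233). */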
%convert.1831 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.1830), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
%reduce.1838 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %convert.1831, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__51.1833, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%get-tuple-element.1701 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=4, sharding={replicated}
%broadcast.104 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1701), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%divide.38 = f32[8,1024]{1,0} divide(f32[8,1024]{1,0} %reduce.1838, f32[8,1024]{1,0} %broadcast.104), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.1843 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %divide.38), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.1844 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %convert.1831, f32[8,1024,2048]{2,1,0} %broadcast.1843), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.1845 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1831, f32[8,1024,2048]{2,1,0} %convert.1831), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.1855 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.1845, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__52.1850, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%get-tuple-element.1707 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=10, sharding={replicated}
%broadcast.105 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1707), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%divide.39 = f32[8,1024]{1,0} divide(f32[8,1024]{1,0} %reduce.1855, f32[8,1024]{1,0} %broadcast.105), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.108 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %divide.38, f32[8,1024]{1,0} %divide.38), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract.7 = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %divide.39, f32[8,1024]{1,0} %multiply.108), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%get-tuple-element.1706 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=9, sharding={replicated}
%broadcast.188 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1706), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%add.8 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract.7, f32[8,1024]{1,0} %broadcast.188), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%reshape.483 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %add.8), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt.1866 = f32[8,1024,1]{2,1,0} rsqrt(f32[8,1024,1]{2,1,0} %reshape.483), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%convert.1867 = f16[8,1024,1]{2,1,0} convert(f32[8,1024,1]{2,1,0} %rsqrt.1866), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.1873 = f16[8,1024]{1,0} reshape(f16[8,1024,1]{2,1,0} %convert.1867), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%broadcast.1874 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %reshape.1873), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%get-tuple-element.1711 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=14, sharding={replicated}
%broadcast.1876 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1711), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.1877 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.1874, f16[8,1024,2048]{2,1,0} %broadcast.1876), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.1878 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.1877), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.1879 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.1844, f32[8,1024,2048]{2,1,0} %convert.1878), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.1880 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.1879), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%get-tuple-element.1710 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=13, sharding={replicated}
%broadcast.1883 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1710), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%add.1884 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.1880, f16[8,1024,2048]{2,1,0} %broadcast.1883), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
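/* %convert.1831 .. %add.1884: layer norm recomputed in f32 (alpa/monkey_patch.py:141-155).
   Reading the ops: mean = sum(x)/n and mean2 = sum(x^2)/n (the two reduce_sum/divide pairs,
   with n the replicated scalars at tuple indices 4 and 10), var = mean2 - mean^2, then
   y = (x - mean) * gamma * rsqrt(var + eps) + beta, with eps at tuple index 9 and gamma/beta
   at indices 14/13. */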
%reshape.81 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %add.1884), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%get-tuple-element.1717 = f16[2048,8192]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=20, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.43 = f16[8192,8192]{1,0} dot(f16[8192,2048]{1,0} %reshape.81, f16[2048,8192]{1,0} %get-tuple-element.1717), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.83 = f16[8,1024,8192]{2,1,0} reshape(f16[8192,8192]{1,0} %dot.43), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%get-tuple-element.1716 = f16[8192]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=19, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.1888 = f16[8,1024,8192]{2,1,0} broadcast(f16[8192]{0} %get-tuple-element.1716), dimensions={2}, sharding={devices=[1,1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1889 = f16[8,1024,8192]{2,1,0} add(f16[8,1024,8192]{2,1,0} %reshape.83, f16[8,1024,8192]{2,1,0} %broadcast.1888), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
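/* %reshape.81 .. %add.1889: first MLP dense layer of the recomputed block, 2048 -> 8192
   (linear.py:177/181): the normalized activations are flattened to [8192,2048], multiplied by
   the [2048,8192] kernel (tuple index 20), and the [8192] bias (tuple index 19) is broadcast
   in. */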
%convert.1958 = f32[8,1024,8192]{2,1,0} convert(f16[8,1024,8192]{2,1,0} %add.1889), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%get-tuple-element.1700 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=3, sharding={replicated}
%broadcast.1891 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %get-tuple-element.1700), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.1892 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %convert.1958, f32[8,1024,8192]{2,1,0} %broadcast.1891), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%clamp.1897 = f32[8,1024,8192]{2,1,0} clamp(f32[8,1024,8192]{2,1,0} %broadcast.447, f32[8,1024,8192]{2,1,0} %divide.1892, f32[8,1024,8192]{2,1,0} %broadcast.448), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1898 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %clamp.1897, f32[8,1024,8192]{2,1,0} %clamp.1897), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1924 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %broadcast.475, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1926 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1924, f32[8,1024,8192]{2,1,0} %broadcast.477), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1928 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1926, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1930 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1928, f32[8,1024,8192]{2,1,0} %broadcast.481), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1932 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1930, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1934 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1932, f32[8,1024,8192]{2,1,0} %broadcast.485), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1936 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1934, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1938 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1936, f32[8,1024,8192]{2,1,0} %broadcast.489), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1940 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1938, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1942 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1940, f32[8,1024,8192]{2,1,0} %broadcast.493), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1944 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1942, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1946 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1944, f32[8,1024,8192]{2,1,0} %broadcast.497), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1948 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1946, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1950 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1948, f32[8,1024,8192]{2,1,0} %broadcast.501), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1951 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %clamp.1897, f32[8,1024,8192]{2,1,0} %add.1950), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1904 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1924, f32[8,1024,8192]{2,1,0} %broadcast.455), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1906 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1904, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1908 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1906, f32[8,1024,8192]{2,1,0} %broadcast.459), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1910 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1908, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1912 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1910, f32[8,1024,8192]{2,1,0} %broadcast.463), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1914 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1912, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1916 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1914, f32[8,1024,8192]{2,1,0} %broadcast.467), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1918 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %add.1916, f32[8,1024,8192]{2,1,0} %multiply.1898), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1920 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %multiply.1918, f32[8,1024,8192]{2,1,0} %broadcast.471), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.1952 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %multiply.1951, f32[8,1024,8192]{2,1,0} %add.1920), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="erf" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/erf" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
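/* Annotation (inferred, not part of the original dump): the chain from %multiply.1898
   through %divide.1952 is consistent with XLA's rational f32 erf expansion for the GELU
   at bert_model.py:285. With t = %clamp.1897 (the clamped erf argument, apparently
   x/sqrt(2)), two polynomials in t^2 are evaluated in Horner form -- %add.1950 (numerator
   chain) and %add.1920 (denominator chain) -- and erf(t) ~= t * P(t^2) / Q(t^2) is formed
   by %multiply.1951 and %divide.1952. */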
%get-tuple-element.1699 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=2, sharding={replicated}
%broadcast.1956 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %get-tuple-element.1699), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.1957 = f32[8,1024,8192]{2,1,0} add(f32[8,1024,8192]{2,1,0} %divide.1952, f32[8,1024,8192]{2,1,0} %broadcast.1956), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1959 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %convert.1958, f32[8,1024,8192]{2,1,0} %add.1957), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%get-tuple-element.1698 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=1, sharding={replicated}
%broadcast.1960 = f32[8,1024,8192]{2,1,0} broadcast(f32[] %get-tuple-element.1698), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.1961 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %multiply.1959, f32[8,1024,8192]{2,1,0} %broadcast.1960), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%convert.1965 = f16[8,1024,8192]{2,1,0} convert(f32[8,1024,8192]{2,1,0} %divide.1961), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
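/* Annotation (inferred): GELU assembly at bert_model.py:285. %add.1957 adds a scalar
   forwarded through the pipeline-boundary %custom-call.1696 (tuple index 2, likely the
   constant 1.0) to the erf result, %multiply.1959 multiplies by the pre-activation x
   (%convert.1958), and %divide.1961 divides by tuple index 1 (likely 2.0), i.e.
   gelu(x) = x * (1 + erf(x/sqrt(2))) / 2; %convert.1965 casts back to f16. The scalar
   values are not visible in this section and are inferred from the formula. */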
%reshape.84 = f16[8192,8192]{1,0} reshape(f16[8,1024,8192]{2,1,0} %convert.1965), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%get-tuple-element.1721 = f16[8192,2048]{1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=24, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%dot.44 = f16[8192,2048]{1,0} dot(f16[8192,8192]{1,0} %reshape.84, f16[8192,2048]{1,0} %get-tuple-element.1721), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.86 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.44), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (0,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%get-tuple-element.1720 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=23, sharding={replicated}
%broadcast.1969 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1720), dimensions={2}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1970 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %reshape.86, f16[8,1024,2048]{2,1,0} %broadcast.1969), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%add.1971 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %add.1970, f16[8,1024,2048]{2,1,0} %add.1884), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=310}
%convert.1972 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.1971), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
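/* Annotation (inferred): rematerialized forward of the FFN output projection inside
   remat(jvp(core_fn)). %dot.44 multiplies the gelu output by the [8192,2048] kernel
   (tuple index 24), %add.1970 adds the bias (index 23), %add.1971 adds the residual
   branch (%add.1884), and %convert.1972 upcasts to f32 for the LayerNorm statistics
   recomputed below. Which parameter each tuple index holds is inferred from shapes. */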
%reduce.1979 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %convert.1972, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__53.1974, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%get-tuple-element.1697 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=0, sharding={replicated}
%broadcast.106 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1697), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%divide.40 = f32[8,1024]{1,0} divide(f32[8,1024]{1,0} %reduce.1979, f32[8,1024]{1,0} %broadcast.106), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.1984 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %divide.40), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%subtract.1985 = f32[8,1024,2048]{2,1,0} subtract(f32[8,1024,2048]{2,1,0} %convert.1972, f32[8,1024,2048]{2,1,0} %broadcast.1984), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
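/* Annotation (inferred): recomputed LayerNorm mean (monkey_patch.py:143). %reduce.1979
   sums over the 2048 features, %divide.40 divides by a scalar from the custom-call tuple
   (index 0, presumably the feature count 2048.0), and %subtract.1985 centers the
   activations. */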
%get-tuple-element.1727 = f16[8,1024,2048]{2,1,0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=30, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%convert.2041 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %get-tuple-element.1727), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%multiply.2042 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.1985, f32[8,1024,2048]{2,1,0} %convert.2041), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.2043 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.2042), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%get-tuple-element.1719 = f16[2048]{0} get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=22, sharding={replicated}
%broadcast.2064 = f16[8,1024,2048]{2,1,0} broadcast(f16[2048]{0} %get-tuple-element.1719), dimensions={2}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.2065 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %convert.2043, f16[8,1024,2048]{2,1,0} %broadcast.2064), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.2072 = f16[8,1024]{1,0} reduce(f16[8,1024,2048]{2,1,0} %multiply.2065, f16[] %constant.1015), dimensions={2}, to_apply=%primitive_computation_add__59.2067, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.16 = f32[8,1024]{1,0} convert(f16[8,1024]{1,0} %reduce.2072), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.366 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %convert.16), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
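/* Annotation (inferred): first backward reduction for this LayerNorm
   (monkey_patch.py:150-152). %multiply.2042 forms dy * (x - mean) using the incoming
   gradient at tuple index 30, %multiply.2065 scales by the gamma parameter (index 22),
   and %reduce.2072 sums over the feature axis -- the term that feeds the d(variance)
   path below. */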
%multiply.1986 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1972, f32[8,1024,2048]{2,1,0} %convert.1972), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%reduce.1996 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %multiply.1986, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__54.1991, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%get-tuple-element.1709 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=12, sharding={replicated}
%broadcast.107 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1709), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%divide.41 = f32[8,1024]{1,0} divide(f32[8,1024]{1,0} %reduce.1996, f32[8,1024]{1,0} %broadcast.107), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.110 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %divide.40, f32[8,1024]{1,0} %divide.40), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%subtract.8 = f32[8,1024]{1,0} subtract(f32[8,1024]{1,0} %divide.41, f32[8,1024]{1,0} %multiply.110), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="sub" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/sub" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%get-tuple-element.1708 = f32[] get-tuple-element((f32[], f32[], f32[], f32[], f32[], /*index=5*/f16[], s32[], f16[8,1,1,1024]{3,2,1,0}, f16[8,1,1,1024]{3,2,1,0}, f32[], /*index=10*/f32[], f32[], f32[], f16[2048]{0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=20*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=25*/u32[2]{0}, f16[8,1024,2048]{2,1,0}, s32[8,1024]{1,0}, pred[], pred[], /*index=30*/f16[8,1024,2048]{2,1,0}) %custom-call.1696), index=11, sharding={replicated}
%broadcast.190 = f32[8,1024]{1,0} broadcast(f32[] %get-tuple-element.1708), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%add.9 = f32[8,1024]{1,0} add(f32[8,1024]{1,0} %subtract.8, f32[8,1024]{1,0} %broadcast.190), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%reshape.484 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %add.9), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%rsqrt.2007 = f32[8,1024,1]{2,1,0} rsqrt(f32[8,1024,1]{2,1,0} %reshape.484), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="rsqrt" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/rsqrt" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%divide.2009 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %rsqrt.2007, f32[8,1024,1]{2,1,0} %reshape.484), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.2012 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %broadcast.1373, f32[8,1024,1]{2,1,0} %divide.2009), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.2075 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %reshape.366, f32[8,1024,1]{2,1,0} %multiply.2012), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
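/* Annotation (inferred): recomputed variance and its gradient coefficient
   (monkey_patch.py:144-146). %divide.41 is E[x^2], %subtract.8 is E[x^2] - mean^2,
   %add.9 adds epsilon (tuple index 11, value not visible), and %rsqrt.2007 is the
   normalizer. %divide.2009 = rsqrt/(var+eps) = (var+eps)^(-3/2), scaled by
   %broadcast.1373 (likely the -1/2 from d rsqrt / d var) in %multiply.2012 before
   combining with the reduction above in %multiply.2075. */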
%broadcast.2078 = f32[8,1024,1]{2,1,0} broadcast(f32[] %get-tuple-element.1709), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%divide.2079 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %multiply.2075, f32[8,1024,1]{2,1,0} %broadcast.2078), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.89 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %divide.2079, f32[8,1024,1]{2,1,0} %broadcast.130), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.410 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %multiply.89), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.78 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.410), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%multiply.2088 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1972, f32[8,1024,2048]{2,1,0} %broadcast.78), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%convert.2008 = f16[8,1024,1]{2,1,0} convert(f32[8,1024,1]{2,1,0} %rsqrt.2007), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.2014 = f16[8,1024]{1,0} reshape(f16[8,1024,1]{2,1,0} %convert.2008), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%broadcast.2015 = f16[8,1024,2048]{2,1,0} broadcast(f16[8,1024]{1,0} %reshape.2014), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.2018 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.2015, f16[8,1024,2048]{2,1,0} %broadcast.2064), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.2019 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %multiply.2018), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.2044 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.2041, f32[8,1024,2048]{2,1,0} %convert.2019), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.2098 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %multiply.2088, f32[8,1024,2048]{2,1,0} %multiply.2044), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
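/* Annotation (inferred): dx assembly, part 1. %divide.2079 spreads the d(variance)
   term over the feature count (tuple index 12) and %multiply.89 applies a constant
   factor (%broadcast.130, plausibly the 2 from d(x^2)); %multiply.2088 multiplies by x
   for the variance path, %multiply.2044 is the direct path dy * gamma * rsqrt, and
   %add.2098 sums the two (the add_any from flax lift.py:845). */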
%negate.2076 = f32[8,1024,1]{2,1,0} negate(f32[8,1024,1]{2,1,0} %multiply.2075), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.111 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %broadcast.164, f32[8,1024]{1,0} %divide.40), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%reshape.451 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %multiply.111), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.2077 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %negate.2076, f32[8,1024,1]{2,1,0} %reshape.451), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%negate.2089 = f32[8,1024,2048]{2,1,0} negate(f32[8,1024,2048]{2,1,0} %multiply.2044), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reduce.2096 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %negate.2089, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__61.2091, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reshape.2097 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %reduce.2096), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.2099 = f32[8,1024,1]{2,1,0} add(f32[8,1024,1]{2,1,0} %multiply.2077, f32[8,1024,1]{2,1,0} %reshape.2097), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%broadcast.2100 = f32[8,1024,1]{2,1,0} broadcast(f32[] %get-tuple-element.1697), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%divide.2101 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %add.2099, f32[8,1024,1]{2,1,0} %broadcast.2100), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%reshape.288 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %divide.2101), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.2109 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.288), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/broadcast_in_dim[\n broadcast_dimensions=(0, 1)\n shape=(8, 1024, 2048)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%add.2110 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %add.2098, f32[8,1024,2048]{2,1,0} %broadcast.2109), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%convert.2111 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %add.2110), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
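/* Annotation (inferred): dx assembly, part 2. %multiply.2077 and
   %reduce.2096/%reshape.2097 are the two negated d(mean) contributions, %divide.2101
   spreads them over the feature count (tuple index 0), and %add.2110 adds the broadcast
   mean gradient onto %add.2098, giving the full LayerNorm input gradient, downcast to
   f16 by %convert.2111. */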
%reshape.90 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %convert.2111), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.46 = f16[8192,8192]{1,0} dot(f16[8192,2048]{1,0} %reshape.90, f16[8192,2048]{1,0} %get-tuple-element.1721), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%convert.17 = f32[8192,8192]{1,0} convert(f16[8192,8192]{1,0} %dot.46), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%broadcast.168 = f32[8192,8192]{1,0} broadcast(f32[] %get-tuple-element.1698), dimensions={}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.45 = f32[8192,8192]{1,0} divide(f32[8192,8192]{1,0} %convert.17, f32[8192,8192]{1,0} %broadcast.168), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%reshape.452 = f32[8,1024,8192]{2,1,0} reshape(f32[8192,8192]{1,0} %divide.45), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.2128 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %reshape.452, f32[8,1024,8192]{2,1,0} %add.1957), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%convert.2129 = f16[8,1024,8192]{2,1,0} convert(f32[8,1024,8192]{2,1,0} %multiply.2128), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
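/* Annotation (inferred): backprop through the FFN output projection. %dot.46 contracts
   the LayerNorm input gradient with the same [8192,2048] kernel on its output dimension
   (rhs_contracting_dims={1}, i.e. dY @ W^T) to get the gradient of the gelu output;
   %divide.45 applies the 1/2 of the GELU formula, and %multiply.2128 forms the first
   backward branch, dG * (1 + erf(x/sqrt(2))). */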
%multiply.2127 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %convert.1958, f32[8,1024,8192]{2,1,0} %reshape.452), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.2132 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %broadcast.1493, f32[8,1024,8192]{2,1,0} %multiply.2127), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.1953 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %divide.1892, f32[8,1024,8192]{2,1,0} %divide.1892), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=2]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%negate.1954 = f32[8,1024,8192]{2,1,0} negate(f32[8,1024,8192]{2,1,0} %multiply.1953), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%exponential.1955 = f32[8,1024,8192]{2,1,0} exponential(f32[8,1024,8192]{2,1,0} %negate.1954), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="exp" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/exp" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%multiply.2133 = f32[8,1024,8192]{2,1,0} multiply(f32[8,1024,8192]{2,1,0} %multiply.2132, f32[8,1024,8192]{2,1,0} %exponential.1955), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%divide.2135 = f32[8,1024,8192]{2,1,0} divide(f32[8,1024,8192]{2,1,0} %multiply.2133, f32[8,1024,8192]{2,1,0} %broadcast.1891), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%convert.2136 = f16[8,1024,8192]{2,1,0} convert(f32[8,1024,8192]{2,1,0} %divide.2135), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=285}
%add.2137 = f16[8,1024,8192]{2,1,0} add(f16[8,1024,8192]{2,1,0} %convert.2129, f16[8,1024,8192]{2,1,0} %convert.2136), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
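/* Annotation (inferred): second GELU backward branch. Since
   d/dx erf(x/sqrt(2)) = sqrt(2/pi) * exp(-x^2/2), the ops
   %multiply.1953/%negate.1954/%exponential.1955 recompute exp(-(x/sqrt(2))^2) from
   %divide.1892 (apparently x/sqrt(2)), and %multiply.2132/%divide.2135 apply the
   remaining constant factors (%broadcast.1493, %broadcast.1891; values not visible
   here) to x * dG. %add.2137 sums both branches into the gradient of the FFN
   pre-activation. */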
%reshape.96 = f16[8192,8192]{1,0} reshape(f16[8,1024,8192]{2,1,0} %add.2137), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.48 = f16[8192,2048]{1,0} dot(f16[8192,8192]{1,0} %reshape.96, f16[2048,8192]{1,0} %get-tuple-element.1717), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.98 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.48), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (1,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%add.2150 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.2111, f16[8,1024,2048]{2,1,0} %reshape.98), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
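/* Annotation (inferred): backprop through the first FFN dense. %dot.48 contracts with
   the [2048,8192] kernel (%get-tuple-element.1717, index inferred from shape) to map
   the pre-activation gradient back to the 2048-wide residual stream, and %add.2150
   accumulates it with the residual-path gradient %convert.2111. */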
%convert.2166 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.2150), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%multiply.2167 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.1844, f32[8,1024,2048]{2,1,0} %convert.2166), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.2168 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.2167), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.2190 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %convert.2168, f16[8,1024,2048]{2,1,0} %broadcast.1876), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.2197 = f16[8,1024]{1,0} reduce(f16[8,1024,2048]{2,1,0} %multiply.2190, f16[] %constant.1015), dimensions={2}, to_apply=%primitive_computation_add__69.2192, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.18 = f32[8,1024]{1,0} convert(f16[8,1024]{1,0} %reduce.2197), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.369 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %convert.18), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%divide.1868 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %rsqrt.1866, f32[8,1024,1]{2,1,0} %reshape.483), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.1871 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %broadcast.1373, f32[8,1024,1]{2,1,0} %divide.1868), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.2200 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %reshape.369, f32[8,1024,1]{2,1,0} %multiply.1871), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%broadcast.2203 = f32[8,1024,1]{2,1,0} broadcast(f32[] %get-tuple-element.1707), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%divide.2204 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %multiply.2200, f32[8,1024,1]{2,1,0} %broadcast.2203), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.90 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %divide.2204, f32[8,1024,1]{2,1,0} %broadcast.130), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.411 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %multiply.90), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.82 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.411), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%multiply.2213 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.1831, f32[8,1024,2048]{2,1,0} %broadcast.82), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.2169 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.2166, f32[8,1024,2048]{2,1,0} %convert.1878), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.2223 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %multiply.2213, f32[8,1024,2048]{2,1,0} %multiply.2169), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%negate.2201 = f32[8,1024,1]{2,1,0} negate(f32[8,1024,1]{2,1,0} %multiply.2200), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.109 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %broadcast.164, f32[8,1024]{1,0} %divide.38), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%reshape.449 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %multiply.109), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.2202 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %negate.2201, f32[8,1024,1]{2,1,0} %reshape.449), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%negate.2214 = f32[8,1024,2048]{2,1,0} negate(f32[8,1024,2048]{2,1,0} %multiply.2169), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reduce.2221 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %negate.2214, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__71.2216, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reshape.2222 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %reduce.2221), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.2224 = f32[8,1024,1]{2,1,0} add(f32[8,1024,1]{2,1,0} %multiply.2202, f32[8,1024,1]{2,1,0} %reshape.2222), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%broadcast.2225 = f32[8,1024,1]{2,1,0} broadcast(f32[] %get-tuple-element.1701), dimensions={}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%divide.2226 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %add.2224, f32[8,1024,1]{2,1,0} %broadcast.2225), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%reshape.291 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %divide.2226), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.2234 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.291), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/broadcast_in_dim[\n broadcast_dimensions=(0, 1)\n shape=(8, 1024, 2048)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%add.2235 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %add.2223, f32[8,1024,2048]{2,1,0} %broadcast.2234), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%convert.2236 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %add.2235), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
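/* Annotation (inferred): %convert.2166 through %convert.2236 repeat the LayerNorm
   backward pattern above, this time for the attention-output LayerNorm (the same
   monkey_patch.py:143-152 source lines, with %subtract.1844, %convert.1878 and
   %broadcast.1876 as the corresponding recomputed statistics and gamma). */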
%reshape.102 = f16[8192,2048]{1,0} reshape(f16[8,1024,2048]{2,1,0} %convert.2236), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.50 = f16[8192,2048]{1,0} dot(f16[8192,2048]{1,0} %reshape.102, f16[2048,2048]{1,0} %get-tuple-element.1713), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%reshape.2249 = f16[8,1024,32,64]{3,2,1,0} reshape(f16[8192,2048]{1,0} %dot.50), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 32, 64)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=208}
%transpose.8 = f16[8,32,64,1024]{2,1,3,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.2249), dimensions={0,2,3,1}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/transpose[\n permutation=(0, 2, 3, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%dot.13 = f16[8,32,64,1024]{3,2,1,0} dot(f16[8,32,64,1024]{2,1,3,0} %transpose.8, f16[8,32,1024,1024]{3,2,1,0} %divide.1800), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/dot_general[\n dimension_numbers=(((3,), (2,)), ((0, 1), (0, 1)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%transpose.9 = f16[8,1024,32,64]{1,3,2,0} transpose(f16[8,32,64,1024]{3,2,1,0} %dot.13), dimensions={0,3,1,2}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/transpose[\n permutation=(0, 3, 1, 2)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%reshape.2263 = f16[8,1024,2048,1]{3,2,1,0} reshape(f16[8,1024,32,64]{1,3,2,0} %transpose.9), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=172}
%pad.2265 = f16[8,1024,2048,3]{3,2,1,0} pad(f16[8,1024,2048,1]{3,2,1,0} %reshape.2263, f16[] %constant.1015), padding=0_0x0_0x0_0x1_1, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="pad" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/pad[\n padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0), (1, 1, 0))\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
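// Softmax VJP (flax/linen/attention.py:101). The forward appears lowered as
// exp(x)/sum(exp(x)), so the backward uses the quotient rule: dot.66 forms
// the gradient w.r.t. the probabilities, integer_pow[y=-2] supplies the
// 1/sum^2 term, and the reduce/negate/broadcast/add chain below combines
// them before the final multiply by %exponential.1787.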
%transpose.112 = f16[8,32,1024,64]{3,1,2,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.2249), dimensions={0,2,1,3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.66 = f16[8,32,1024,1024]{3,2,1,0} dot(f16[8,32,1024,64]{3,1,2,0} %transpose.112, f16[8,32,64,1024]{2,1,3,0} %transpose.110), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/dot_general[\n dimension_numbers=(((2,), (3,)), ((0, 1), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=207}
%divide.2281 = f16[8,32,1024,1024]{3,2,1,0} divide(f16[8,32,1024,1024]{3,2,1,0} %dot.66, f16[8,32,1024,1024]{3,2,1,0} %broadcast.1799), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%multiply.107 = f16[8,32,1024]{2,1,0} multiply(f16[8,32,1024]{2,1,0} %convert.15, f16[8,32,1024]{2,1,0} %convert.15), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=-2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%divide.47 = f16[8,32,1024]{2,1,0} divide(f16[8,32,1024]{2,1,0} %broadcast.171, f16[8,32,1024]{2,1,0} %multiply.107), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/integer_pow[y=-2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.2267 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %divide.47), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%multiply.2268 = f16[8,32,1024,1024]{3,2,1,0} multiply(f16[8,32,1024,1024]{3,2,1,0} %dot.66, f16[8,32,1024,1024]{3,2,1,0} %broadcast.2267), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%multiply.2269 = f16[8,32,1024,1024]{3,2,1,0} multiply(f16[8,32,1024,1024]{3,2,1,0} %multiply.2268, f16[8,32,1024,1024]{3,2,1,0} %exponential.1787), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%reduce.2276 = f16[8,32,1024]{2,1,0} reduce(f16[8,32,1024,1024]{3,2,1,0} %multiply.2269, f16[] %constant.1015), dimensions={3}, to_apply=%primitive_computation_add__74.2271, sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[axes=(3,)]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%negate.1 = f16[8,32,1024]{2,1,0} negate(f16[8,32,1024]{2,1,0} %reduce.2276), sharding={devices=[2,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/neg" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%broadcast.83 = f16[8,32,1024,1024]{3,2,1,0} broadcast(f16[8,32,1024]{2,1,0} %negate.1), dimensions={0,1,2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/broadcast_in_dim[\n broadcast_dimensions=(0, 1, 2)\n shape=(8, 32, 1024, 1024)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
%add.2292 = f16[8,32,1024,1024]{3,2,1,0} add(f16[8,32,1024,1024]{3,2,1,0} %divide.2281, f16[8,32,1024,1024]{3,2,1,0} %broadcast.83), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%multiply.2293 = f16[8,32,1024,1024]{3,2,1,0} multiply(f16[8,32,1024,1024]{3,2,1,0} %add.2292, f16[8,32,1024,1024]{3,2,1,0} %exponential.1787), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=101}
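// With the score gradient (%multiply.2293) assembled, the two batched dots
// below (dot.68, dot.67) propagate it to the remaining attention inputs
// (attention.py:90).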
%dot.68 = f16[8,32,1024,64]{3,2,1,0} dot(f16[8,32,1024,1024]{3,2,1,0} %multiply.2293, f16[8,32,1024,64]{3,1,2,0} %transpose.108), lhs_batch_dims={0,1}, lhs_contracting_dims={2}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/dot_general[\n dimension_numbers=(((2,), (1,)), ((0, 1), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
%transpose.11 = f16[8,1024,32,64]{3,1,2,0} transpose(f16[8,32,1024,64]{3,2,1,0} %dot.68), dimensions={0,2,1,3}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/transpose[\n permutation=(0, 2, 1, 3)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
%reshape.2307 = f16[8,1024,2048,1]{3,2,1,0} reshape(f16[8,1024,32,64]{3,1,2,0} %transpose.11), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=175}
%pad.2309 = f16[8,1024,2048,3]{3,2,1,0} pad(f16[8,1024,2048,1]{3,2,1,0} %reshape.2307, f16[] %constant.1015), padding=0_0x0_0x0_0x2_0, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="pad" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/pad[\n padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0), (2, 0, 0))\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%add.2310 = f16[8,1024,2048,3]{3,2,1,0} add(f16[8,1024,2048,3]{3,2,1,0} %pad.2265, f16[8,1024,2048,3]{3,2,1,0} %pad.2309), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
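// Query branch: dot.67's result is divided by the broadcast scalar
// %get-tuple-element.1702, mirroring the query / sqrt(depth) scaling of the
// forward pass (attention.py:87), then padded into its QKV slot.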
%transpose.115 = f16[8,32,1024,64]{3,1,2,0} transpose(f16[8,1024,32,64]{3,2,1,0} %reshape.1740), dimensions={0,2,1,3}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.67 = f16[8,32,1024,64]{3,2,1,0} dot(f16[8,32,1024,1024]{3,2,1,0} %multiply.2293, f16[8,32,1024,64]{3,1,2,0} %transpose.115), lhs_batch_dims={0,1}, lhs_contracting_dims={3}, rhs_batch_dims={0,1}, rhs_contracting_dims={2}, sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/jit(transpose(jvp(_einsum)))/dot_general[\n dimension_numbers=(((3,), (1,)), ((0, 1), (0, 2)))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=90}
%broadcast.108 = f16[8,32,1024,64]{3,2,1,0} broadcast(f16[] %get-tuple-element.1702), dimensions={}, sharding={devices=[2,1,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%divide.42 = f16[8,32,1024,64]{3,2,1,0} divide(f16[8,32,1024,64]{3,2,1,0} %dot.67, f16[8,32,1024,64]{3,2,1,0} %broadcast.108), sharding={devices=[2,8,1,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%transpose.122 = f16[8,1024,32,64]{3,2,1,0} transpose(f16[8,32,1024,64]{3,2,1,0} %divide.42), dimensions={0,2,1,3}, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/attention.py" source_line=87}
%reshape.2313 = f16[8,1024,2048,1]{3,2,1,0} reshape(f16[8,1024,32,64]{3,2,1,0} %transpose.122), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 2048, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=169}
%pad.2315 = f16[8,1024,2048,3]{3,2,1,0} pad(f16[8,1024,2048,1]{3,2,1,0} %reshape.2313, f16[] %constant.1015), padding=0_0x0_0x0_0x0_2, sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="pad" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/pad[\n padding_config=((0, 0, 0), (0, 0, 0), (0, 0, 0), (0, 2, 0))\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=165}
%add.2316 = f16[8,1024,2048,3]{3,2,1,0} add(f16[8,1024,2048,3]{3,2,1,0} %add.2310, f16[8,1024,2048,3]{3,2,1,0} %pad.2315), sharding={devices=[2,1,8,1]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
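// Backward through the fused QKV projection: the stacked gradient is
// flattened to [8192,6144] and contracted with the [2048,6144] weight;
// add.2330 folds the result into the residual-branch gradient %convert.2236.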
%reshape.108 = f16[8192,6144]{1,0} reshape(f16[8,1024,2048,3]{3,2,1,0} %add.2316), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%dot.52 = f16[8192,2048]{1,0} dot(f16[8192,6144]{1,0} %reshape.108, f16[2048,6144]{1,0} %get-tuple-element.1715), lhs_contracting_dims={1}, rhs_contracting_dims={1}, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.110 = f16[8,1024,2048]{2,1,0} reshape(f16[8192,2048]{1,0} %dot.52), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="dot_general" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/dot_general[\n dimension_numbers=(((2,), (1,)), ((), ()))\n precision=None\n preferred_element_type=None\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%add.2330 = f16[8,1024,2048]{2,1,0} add(f16[8,1024,2048]{2,1,0} %convert.2236, f16[8,1024,2048]{2,1,0} %reshape.110), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/add_any" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
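// Optimizer section for a [2048] vector parameter: the gradient is reduced
// over batch and sequence (axes 0,1), then the first moment is updated in
// f32 as mu_new = (1-b1)*g + b1*mu (optax transform.py:81, the usual
// scale_by_adam moment update).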
%reduce.2337 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %add.2330, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__77.2332, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(0, 1)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%broadcast.110 = f16[2048]{0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.56 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2337, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2470 = f32[2048]{0} convert(f16[2048]{0} %multiply.56), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.33 = f32[2048]{0} parameter(32), sharding={replicated}
%broadcast.2468 = f32[2048]{0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2469 = f32[2048]{0} multiply(f32[2048]{0} %parameter.33, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2471 = f32[2048]{0} add(f32[2048]{0} %convert.2470, f32[2048]{0} %multiply.2469), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
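// Second moment and Adam step: nu_new = (1-b2)*g^2 + b2*nu, both moments
// bias-corrected by the broadcast (1 - beta^t) terms (transform.py:87), and
// the update computed as mu_hat / (sqrt(nu_hat) + eps) (transform.py:302),
// with the mu correction folded into the denominator via multiply.23.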
%broadcast.2973 = f32[2048]{0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.67 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2337, f16[2048]{0} %reduce.2337), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2699 = f16[2048]{0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2700 = f16[2048]{0} multiply(f16[2048]{0} %multiply.67, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2704 = f32[2048]{0} convert(f16[2048]{0} %multiply.2700), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.62 = f32[2048]{0} parameter(61), sharding={replicated}
%broadcast.2702 = f32[2048]{0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2703 = f32[2048]{0} multiply(f32[2048]{0} %parameter.62, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2705 = f32[2048]{0} add(f32[2048]{0} %convert.2704, f32[2048]{0} %multiply.2703), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3065 = f32[2048]{0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3066 = f32[2048]{0} divide(f32[2048]{0} %add.2705, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3159 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3066), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3161 = f32[2048]{0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3162 = f32[2048]{0} add(f32[2048]{0} %sqrt.3159, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.23 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3162), sharding={replicated}
%divide.3163 = f32[2048]{0} divide(f32[2048]{0} %add.2471, f32[2048]{0} %multiply.23), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
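// The update is scaled by the learning-rate constant (transform.py:328) and
// added to the f32 master weight %parameter.91 (update.py:43), then cast to
// f16 (model_util.py:327). The multiply by %constant.1015 (the f16 zero used
// as pad/reduce init throughout) plus add at model_util.py:334 looks like
// the lowered form of a select between old and new f16 weights, e.g. for
// dynamic loss scaling.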
%broadcast.3424 = f32[2048]{0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3425 = f32[2048]{0} multiply(f32[2048]{0} %divide.3163, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3508 = f32[2048]{0} add(f32[2048]{0} %parameter.91, f32[2048]{0} %multiply.3425), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3537 = f16[2048]{0} convert(f32[2048]{0} %add.3508), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3570 = f16[2048]{0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3571 = f16[2048]{0} multiply(f16[2048]{0} %parameter.3, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3572 = f16[2048]{0} add(f16[2048]{0} %convert.3537, f16[2048]{0} %multiply.3571), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
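// The same moment/bias-correction/apply pattern repeats for the layer-norm
// scale: its gradient reduce.2359 sums dy times the normalized activations
// (monkey_patch.py:150-151), with f32 state in %parameter.34/.63/.92.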
%parameter.92 = f32[2048]{0} parameter(91), sharding={replicated}
%convert.2346 = f32[8,1024,2048]{2,1,0} convert(f16[8,1024,2048]{2,1,0} %add.2330), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=152}
%multiply.2347 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %subtract.213, f32[8,1024,2048]{2,1,0} %convert.2346), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%convert.2348 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %multiply.2347), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%multiply.2352 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.207, f16[8,1024,2048]{2,1,0} %convert.2348), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.2359 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %multiply.2352, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__79.2354, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(0, 1)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.57 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2359, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2478 = f32[2048]{0} convert(f16[2048]{0} %multiply.57), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.34 = f32[2048]{0} parameter(33), sharding={replicated}
%multiply.2477 = f32[2048]{0} multiply(f32[2048]{0} %parameter.34, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2479 = f32[2048]{0} add(f32[2048]{0} %convert.2478, f32[2048]{0} %multiply.2477), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.68 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2359, f16[2048]{0} %reduce.2359), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2709 = f16[2048]{0} multiply(f16[2048]{0} %multiply.68, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2713 = f32[2048]{0} convert(f16[2048]{0} %multiply.2709), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.63 = f32[2048]{0} parameter(62), sharding={replicated}
%multiply.2712 = f32[2048]{0} multiply(f32[2048]{0} %parameter.63, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2714 = f32[2048]{0} add(f32[2048]{0} %convert.2713, f32[2048]{0} %multiply.2712), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3069 = f32[2048]{0} divide(f32[2048]{0} %add.2714, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3167 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3069), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3170 = f32[2048]{0} add(f32[2048]{0} %sqrt.3167, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.24 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3170), sharding={replicated}
%divide.3171 = f32[2048]{0} divide(f32[2048]{0} %add.2479, f32[2048]{0} %multiply.24), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3428 = f32[2048]{0} multiply(f32[2048]{0} %divide.3171, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3509 = f32[2048]{0} add(f32[2048]{0} %parameter.92, f32[2048]{0} %multiply.3428), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3538 = f16[2048]{0} convert(f32[2048]{0} %add.3509), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3575 = f16[2048]{0} multiply(f16[2048]{0} %parameter.4, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3576 = f16[2048]{0} add(f16[2048]{0} %convert.3538, f16[2048]{0} %multiply.3575), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%parameter.93 = f32[1024,2048]{1,0} parameter(92), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%iota = s32[1,1,1024]{2,1,0} iota(), iota_dimension=2, sharding={replicated}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=(2,)\n shape=(1, 1, 1024)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
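// Backward of the custom layer norm w.r.t. its input (monkey_patch.py lines
// 143-151), combining the dy*gamma, mean, and variance terms; the constant
// 0.0009765625 = 1/1024 is likely the 2/N factor from the variance
// derivative (N = 2048). The result, %convert.2418, feeds both embedding
// gradients below.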
%multiply.2370 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %convert.2348, f16[8,1024,2048]{2,1,0} %broadcast.209), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.2377 = f16[8,1024]{1,0} reduce(f16[8,1024,2048]{2,1,0} %multiply.2370, f16[] %constant.1015), dimensions={2}, to_apply=%primitive_computation_add__81.2372, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%convert.19 = f32[8,1024]{1,0} convert(f16[8,1024]{1,0} %reduce.2377), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%reshape.371 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %convert.19), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=147}
%divide.200 = f32[8,1024,1]{2,1,0} divide(f32[8,1024,1]{2,1,0} %rsqrt.199, f32[8,1024,1]{2,1,0} %reshape.476), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.203 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %broadcast.1373, f32[8,1024,1]{2,1,0} %divide.200), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%multiply.2380 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %reshape.371, f32[8,1024,1]{2,1,0} %multiply.203), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=146}
%constant.65 = f32[] constant(0.0009765625), sharding={replicated}
%broadcast.158 = f32[8,1024,1]{2,1,0} broadcast(f32[] %constant.65), dimensions={}, sharding={replicated}
%multiply.91 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %multiply.2380, f32[8,1024,1]{2,1,0} %broadcast.158), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.412 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %multiply.91), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.87 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.412), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%multiply.2394 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %reshape.428, f32[8,1024,2048]{2,1,0} %broadcast.87), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=144}
%multiply.2349 = f32[8,1024,2048]{2,1,0} multiply(f32[8,1024,2048]{2,1,0} %convert.2346, f32[8,1024,2048]{2,1,0} %convert.214), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.2404 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %multiply.2394, f32[8,1024,2048]{2,1,0} %multiply.2349), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/add_any" source_file="test_export_hlo.py" source_line=83}
%negate.2381 = f32[8,1024,1]{2,1,0} negate(f32[8,1024,1]{2,1,0} %multiply.2380), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.97 = f32[8,1024]{1,0} multiply(f32[8,1024]{1,0} %broadcast.164, f32[8,1024]{1,0} %multiply.77), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%reshape.430 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %multiply.97), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%multiply.2382 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %negate.2381, f32[8,1024,1]{2,1,0} %reshape.430), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=145}
%negate.2395 = f32[8,1024,2048]{2,1,0} negate(f32[8,1024,2048]{2,1,0} %multiply.2349), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="neg" op_name="parallelize(train_step_shard_parallel)/neg" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reduce.2402 = f32[8,1024]{1,0} reduce(f32[8,1024,2048]{2,1,0} %negate.2395, f32[] %constant.165), dimensions={2}, to_apply=%primitive_computation_add__83.2397, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%reshape.2403 = f32[8,1024,1]{2,1,0} reshape(f32[8,1024]{1,0} %reduce.2402), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/reshape[\n dimensions=None\n new_sizes=(8, 1024, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=151}
%add.2405 = f32[8,1024,1]{2,1,0} add(f32[8,1024,1]{2,1,0} %multiply.2382, f32[8,1024,1]{2,1,0} %reshape.2403), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/add_any" source_file="test_export_hlo.py" source_line=83}
%broadcast.85 = f32[8,1024,1]{2,1,0} broadcast(f32[] %constant.35), dimensions={}, sharding={replicated}
%multiply.28 = f32[8,1024,1]{2,1,0} multiply(f32[8,1024,1]{2,1,0} %add.2405, f32[8,1024,1]{2,1,0} %broadcast.85), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%reshape.301 = f32[8,1024]{1,0} reshape(f32[8,1024,1]{2,1,0} %multiply.28), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/reduce_sum[axes=(2,)]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%broadcast.2416 = f32[8,1024,2048]{2,1,0} broadcast(f32[8,1024]{1,0} %reshape.301), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=(0, 1)\n shape=(8, 1024, 2048)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=143}
%add.2417 = f32[8,1024,2048]{2,1,0} add(f32[8,1024,2048]{2,1,0} %add.2404, f32[8,1024,2048]{2,1,0} %broadcast.2416), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/add_any" source_file="test_export_hlo.py" source_line=83}
%convert.2418 = f16[8,1024,2048]{2,1,0} convert(f32[8,1024,2048]{2,1,0} %add.2417), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=141}
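// Remat boundary (op_type=remat_begin): the rematerialized embedding
// computation is wrapped in an "identity" custom-call whose outputs alias
// its operands one-to-one, keeping the region intact across optimization
// (prevent_cse=True in the remat_call metadata).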
%tuple.2419 = (s32[1,1,1024]{2,1,0}, f16[1024,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) tuple(s32[1,1,1024]{2,1,0} %iota, f16[1024,2048]{1,0} %parameter.5, u32[2]{0} %parameter.124, s32[8,1024]{1,0} %parameter.122, f16[8,1024,2048]{2,1,0} %convert.2418), sharding={{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}}, metadata={op_type="remat_call" op_name="parallelize(train_step_shard_parallel)/remat_call[\n concrete=True\n differentiated=True\n name=jvp(core_fn)\n policy=None\n prevent_cse=True\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%custom-call.2420 = (s32[1,1,1024]{2,1,0}, f16[1024,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) custom-call((s32[1,1,1024]{2,1,0}, f16[1024,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) %tuple.2419), custom_call_target="identity", output_to_operand_aliasing={{0}: (0, {0}), {1}: (0, {1}), {2}: (0, {2}), {3}: (0, {3}), {4}: (0, {4})}, sharding={{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}}, metadata={op_type="remat_begin"}
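// Position-embedding gradient, recomputed inside the remat region: iota plus
// compare(EQ) build a one-hot [8,1024,1024] of the position ids
// (monkey_patch.py:120), and dot.69 contracts it with the incoming gradient
// to form the [1024,2048] table gradient (monkey_patch.py:121).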
%get-tuple-element.2424 = s32[8,1024]{1,0} get-tuple-element((s32[1,1,1024]{2,1,0}, f16[1024,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) %custom-call.2420), index=3, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.2429 = s32[8,1024,1024]{2,1,0} broadcast(s32[8,1024]{1,0} %get-tuple-element.2424), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%get-tuple-element.2421 = s32[1,1,1024]{2,1,0} get-tuple-element((s32[1,1,1024]{2,1,0}, f16[1024,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) %custom-call.2420), index=0, sharding={replicated}
%reshape.2430 = s32[1024]{0} reshape(s32[1,1,1024]{2,1,0} %get-tuple-element.2421), sharding={replicated}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%broadcast.2431 = s32[8,1024,1024]{2,1,0} broadcast(s32[1024]{0} %reshape.2430), dimensions={2}, sharding={replicated}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%compare.2432 = pred[8,1024,1024]{2,1,0} compare(s32[8,1024,1024]{2,1,0} %broadcast.2429, s32[8,1024,1024]{2,1,0} %broadcast.2431), direction=EQ, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%convert.2433 = f16[8,1024,1024]{2,1,0} convert(pred[8,1024,1024]{2,1,0} %compare.2432), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%reshape.112 = f16[8192,1024]{1,0} reshape(f16[8,1024,1024]{2,1,0} %convert.2433), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.498 = f16[8192,1024]{1,0} reshape(f16[8192,1024]{1,0} %reshape.112), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%get-tuple-element.2425 = f16[8,1024,2048]{2,1,0} get-tuple-element((s32[1,1,1024]{2,1,0}, f16[1024,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) %custom-call.2420), index=4, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%transpose.86 = f16[2048,8,1024]{0,2,1} transpose(f16[8,1024,2048]{2,1,0} %get-tuple-element.2425), dimensions={2,0,1}, sharding={devices=[1,2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.111 = f16[2048,8192]{1,0} reshape(f16[2048,8,1024]{0,2,1} %transpose.86), sharding={devices=[1,2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.69 = f16[1024,2048]{0,1} dot(f16[8192,1024]{1,0} %reshape.498, f16[2048,8192]{1,0} %reshape.111), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=121}
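// Adam for the [1024,2048] table follows the same pattern, with one
// addition: transform.py:571 applies decoupled weight decay (0.0001 * param)
// before the learning-rate scale; the [2048] vector parameters above skipped
// this step, consistent with the usual AdamW masking of biases and norms.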
%broadcast.2481 = f16[1024,2048]{1,0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2482 = f16[1024,2048]{1,0} multiply(f16[1024,2048]{0,1} %dot.69, f16[1024,2048]{1,0} %broadcast.2481), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2486 = f32[1024,2048]{1,0} convert(f16[1024,2048]{1,0} %multiply.2482), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.35 = f32[1024,2048]{1,0} parameter(34), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2484 = f32[1024,2048]{1,0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2485 = f32[1024,2048]{1,0} multiply(f32[1024,2048]{1,0} %parameter.35, f32[1024,2048]{1,0} %broadcast.2484), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2487 = f32[1024,2048]{1,0} add(f32[1024,2048]{1,0} %convert.2486, f32[1024,2048]{1,0} %multiply.2485), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2979 = f32[1024,2048]{1,0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.2715 = f16[1024,2048]{0,1} multiply(f16[1024,2048]{0,1} %dot.69, f16[1024,2048]{0,1} %dot.69), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2717 = f16[1024,2048]{1,0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2718 = f16[1024,2048]{1,0} multiply(f16[1024,2048]{0,1} %multiply.2715, f16[1024,2048]{1,0} %broadcast.2717), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2722 = f32[1024,2048]{1,0} convert(f16[1024,2048]{1,0} %multiply.2718), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.64 = f32[1024,2048]{1,0} parameter(63), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2720 = f32[1024,2048]{1,0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2721 = f32[1024,2048]{1,0} multiply(f32[1024,2048]{1,0} %parameter.64, f32[1024,2048]{1,0} %broadcast.2720), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2723 = f32[1024,2048]{1,0} add(f32[1024,2048]{1,0} %convert.2722, f32[1024,2048]{1,0} %multiply.2721), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3071 = f32[1024,2048]{1,0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3072 = f32[1024,2048]{1,0} divide(f32[1024,2048]{1,0} %add.2723, f32[1024,2048]{1,0} %broadcast.3071), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3175 = f32[1024,2048]{1,0} sqrt(f32[1024,2048]{1,0} %divide.3072), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3177 = f32[1024,2048]{1,0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3178 = f32[1024,2048]{1,0} add(f32[1024,2048]{1,0} %sqrt.3175, f32[1024,2048]{1,0} %broadcast.3177), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.29 = f32[1024,2048]{1,0} multiply(f32[1024,2048]{1,0} %broadcast.2979, f32[1024,2048]{1,0} %add.3178), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3179 = f32[1024,2048]{1,0} divide(f32[1024,2048]{1,0} %add.2487, f32[1024,2048]{1,0} %multiply.29), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%constant.3380 = f32[] constant(0.0001), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%broadcast.3381 = f32[1024,2048]{1,0} broadcast(f32[] %constant.3380), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3382 = f32[1024,2048]{1,0} multiply(f32[1024,2048]{1,0} %parameter.93, f32[1024,2048]{1,0} %broadcast.3381), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3383 = f32[1024,2048]{1,0} add(f32[1024,2048]{1,0} %divide.3179, f32[1024,2048]{1,0} %multiply.3382), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%broadcast.3430 = f32[1024,2048]{1,0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3431 = f32[1024,2048]{1,0} multiply(f32[1024,2048]{1,0} %add.3383, f32[1024,2048]{1,0} %broadcast.3430), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3510 = f32[1024,2048]{1,0} add(f32[1024,2048]{1,0} %parameter.93, f32[1024,2048]{1,0} %multiply.3431), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3539 = f16[1024,2048]{1,0} convert(f32[1024,2048]{1,0} %add.3510), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3578 = f16[1024,2048]{1,0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3579 = f16[1024,2048]{1,0} multiply(f16[1024,2048]{1,0} %parameter.5, f16[1024,2048]{1,0} %broadcast.3578), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3580 = f16[1024,2048]{1,0} add(f16[1024,2048]{1,0} %convert.3539, f16[1024,2048]{1,0} %multiply.3579), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
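// Word-embedding gradient: dot.71 contracts the transposed [8,1024,51200]
// logits gradient with the hidden states (the tied output projection,
// gpt_model.py:70), and the iota/compare one-hot path below repeats the
// position-embedding scatter over the 51200-token vocabulary.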
%parameter.94 = f32[51200,2048]{1,0} parameter(93), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.34 = f16[51200,8,1024]{0,2,1} transpose(f16[8,1024,51200]{2,1,0} %add.1035), dimensions={2,0,1}, sharding={devices=[8,2,1]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%reshape.33 = f16[51200,8192]{1,0} reshape(f16[51200,8,1024]{0,2,1} %transpose.34), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%dot.71 = f16[51200,2048]{1,0} dot(f16[51200,8192]{1,0} %reshape.33, f16[8192,2048]{1,0} %reshape.30), lhs_contracting_dims={1}, rhs_contracting_dims={0}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/transpose[permutation=(1, 0)]" source_file="/home/ubuntu/efs/alpa/alpa/model/gpt_model.py" source_line=70}
%iota.1 = s32[1,1,51200]{2,1,0} iota(), iota_dimension=2, sharding={replicated}, metadata={op_type="broadcast_in_dim" op_name="parallelize(train_step_shard_parallel)/broadcast_in_dim[\n broadcast_dimensions=(2,)\n shape=(1, 1, 51200)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%tuple.2437 = (s32[1,1,51200]{2,1,0}, f16[51200,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) tuple(s32[1,1,51200]{2,1,0} %iota.1, f16[51200,2048]{1,0} %parameter.6, u32[2]{0} %parameter.124, s32[8,1024]{1,0} %parameter.120, f16[8,1024,2048]{2,1,0} %convert.2418), sharding={{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}}, metadata={op_type="remat_call" op_name="parallelize(train_step_shard_parallel)/remat_call[\n concrete=True\n differentiated=True\n name=jvp(core_fn)\n policy=None\n prevent_cse=True\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/core/lift.py" source_line=845}
%custom-call.2438 = (s32[1,1,51200]{2,1,0}, f16[51200,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) custom-call((s32[1,1,51200]{2,1,0}, f16[51200,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) %tuple.2437), custom_call_target="identity", output_to_operand_aliasing={{0}: (0, {0}), {1}: (0, {1}), {2}: (0, {2}), {3}: (0, {3}), {4}: (0, {4})}, sharding={{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, {devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}}, metadata={op_type="remat_begin"}
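/* The tuple/custom-call "identity" pair (metadata op_type="remat_begin")
   appears to be Alpa's rematerialization boundary: output_to_operand_aliasing
   aliases every tuple element to its input, so the marker is a no-op that
   fences the checkpointed values feeding the recomputed backward region. */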
%get-tuple-element.2442 = s32[8,1024]{1,0} get-tuple-element((s32[1,1,51200]{2,1,0}, f16[51200,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) %custom-call.2438), index=3, sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%broadcast.2447 = s32[8,1024,51200]{2,1,0} broadcast(s32[8,1024]{1,0} %get-tuple-element.2442), dimensions={0,1}, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%get-tuple-element.2439 = s32[1,1,51200]{2,1,0} get-tuple-element((s32[1,1,51200]{2,1,0}, f16[51200,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) %custom-call.2438), index=0, sharding={replicated}
%reshape.2448 = s32[51200]{0} reshape(s32[1,1,51200]{2,1,0} %get-tuple-element.2439), sharding={replicated}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%broadcast.2449 = s32[8,1024,51200]{2,1,0} broadcast(s32[51200]{0} %reshape.2448), dimensions={2}, sharding={replicated}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%compare.2450 = pred[8,1024,51200]{2,1,0} compare(s32[8,1024,51200]{2,1,0} %broadcast.2447, s32[8,1024,51200]{2,1,0} %broadcast.2449), direction=EQ, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="eq" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/eq" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
%convert.2451 = f16[8,1024,51200]{2,1,0} convert(pred[8,1024,51200]{2,1,0} %compare.2450), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=120}
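/* %iota.1 over the vocab axis, the broadcast token ids, the EQ compare, and
   the f16 convert build a one-hot [8,1024,51200] mask of the ids
   (monkey_patch.py:120); flattened below and contracted against the
   [8,1024,2048] activations in %dot.72, it yields the [51200,2048]
   embedding-table gradient. */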
%reshape.115 = f16[8192,51200]{1,0} reshape(f16[8,1024,51200]{2,1,0} %convert.2451), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.497 = f16[8192,51200]{1,0} reshape(f16[8192,51200]{1,0} %reshape.115), sharding={devices=[2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}
%get-tuple-element.2443 = f16[8,1024,2048]{2,1,0} get-tuple-element((s32[1,1,51200]{2,1,0}, f16[51200,2048]{1,0}, u32[2]{0}, s32[8,1024]{1,0}, f16[8,1024,2048]{2,1,0}) %custom-call.2438), index=4, sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%transpose.88 = f16[2048,8,1024]{0,2,1} transpose(f16[8,1024,2048]{2,1,0} %get-tuple-element.2443), dimensions={2,0,1}, sharding={devices=[1,2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.114 = f16[2048,8192]{1,0} reshape(f16[2048,8,1024]{0,2,1} %transpose.88), sharding={devices=[1,2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.72 = f16[51200,2048]{0,1} dot(f16[8192,51200]{1,0} %reshape.497, f16[2048,8192]{1,0} %reshape.114), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=121}
%add.2455 = f16[51200,2048]{1,0} add(f16[51200,2048]{1,0} %dot.71, f16[51200,2048]{0,1} %dot.72), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add_any" op_name="parallelize(train_step_shard_parallel)/add_any" source_file="test_export_hlo.py" source_line=83}
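/* %add.2455 sums two gradient contributions to the same [51200,2048] weight:
   %dot.71 through the output projection (gpt_model.py:70) and %dot.72 through
   the embedding lookup (monkey_patch.py:121), consistent with tied input and
   output embeddings. */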
%broadcast.2489 = f16[51200,2048]{1,0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2490 = f16[51200,2048]{1,0} multiply(f16[51200,2048]{1,0} %add.2455, f16[51200,2048]{1,0} %broadcast.2489), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2494 = f32[51200,2048]{1,0} convert(f16[51200,2048]{1,0} %multiply.2490), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.36 = f32[51200,2048]{1,0} parameter(35), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2492 = f32[51200,2048]{1,0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2493 = f32[51200,2048]{1,0} multiply(f32[51200,2048]{1,0} %parameter.36, f32[51200,2048]{1,0} %broadcast.2492), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2495 = f32[51200,2048]{1,0} add(f32[51200,2048]{1,0} %convert.2494, f32[51200,2048]{1,0} %multiply.2493), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2982 = f32[51200,2048]{1,0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.2724 = f16[51200,2048]{1,0} multiply(f16[51200,2048]{1,0} %add.2455, f16[51200,2048]{1,0} %add.2455), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2726 = f16[51200,2048]{1,0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2727 = f16[51200,2048]{1,0} multiply(f16[51200,2048]{1,0} %multiply.2724, f16[51200,2048]{1,0} %broadcast.2726), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2731 = f32[51200,2048]{1,0} convert(f16[51200,2048]{1,0} %multiply.2727), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.65 = f32[51200,2048]{1,0} parameter(64), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2729 = f32[51200,2048]{1,0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2730 = f32[51200,2048]{1,0} multiply(f32[51200,2048]{1,0} %parameter.65, f32[51200,2048]{1,0} %broadcast.2729), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2732 = f32[51200,2048]{1,0} add(f32[51200,2048]{1,0} %convert.2731, f32[51200,2048]{1,0} %multiply.2730), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3074 = f32[51200,2048]{1,0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3075 = f32[51200,2048]{1,0} divide(f32[51200,2048]{1,0} %add.2732, f32[51200,2048]{1,0} %broadcast.3074), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3183 = f32[51200,2048]{1,0} sqrt(f32[51200,2048]{1,0} %divide.3075), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3185 = f32[51200,2048]{1,0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3186 = f32[51200,2048]{1,0} add(f32[51200,2048]{1,0} %sqrt.3183, f32[51200,2048]{1,0} %broadcast.3185), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.30 = f32[51200,2048]{1,0} multiply(f32[51200,2048]{1,0} %broadcast.2982, f32[51200,2048]{1,0} %add.3186), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3187 = f32[51200,2048]{1,0} divide(f32[51200,2048]{1,0} %add.2495, f32[51200,2048]{1,0} %multiply.30), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3385 = f32[51200,2048]{1,0} broadcast(f32[] %constant.3380), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3386 = f32[51200,2048]{1,0} multiply(f32[51200,2048]{1,0} %parameter.94, f32[51200,2048]{1,0} %broadcast.3385), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3387 = f32[51200,2048]{1,0} add(f32[51200,2048]{1,0} %divide.3187, f32[51200,2048]{1,0} %multiply.3386), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%broadcast.3433 = f32[51200,2048]{1,0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3434 = f32[51200,2048]{1,0} multiply(f32[51200,2048]{1,0} %add.3387, f32[51200,2048]{1,0} %broadcast.3433), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3511 = f32[51200,2048]{1,0} add(f32[51200,2048]{1,0} %parameter.94, f32[51200,2048]{1,0} %multiply.3434), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3540 = f16[51200,2048]{1,0} convert(f32[51200,2048]{1,0} %add.3511), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3582 = f16[51200,2048]{1,0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3583 = f16[51200,2048]{1,0} multiply(f16[51200,2048]{1,0} %parameter.6, f16[51200,2048]{1,0} %broadcast.3582), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3584 = f16[51200,2048]{1,0} add(f16[51200,2048]{1,0} %convert.3540, f16[51200,2048]{1,0} %multiply.3583), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
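/* The f32[2048] parameters that follow (parameter.95-97 plus their optimizer
   state) stay replicated and skip the weight-decay multiply at
   transform.py:571, as is conventional for biases and layer-norm parameters;
   their gradients are reduce_sums of the backward activations over the
   (batch, seq) dimensions. */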
%parameter.95 = f32[2048]{0} parameter(94), sharding={replicated}
%reduce.2157 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %add.2150, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__65.2152, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%multiply.58 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2157, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2502 = f32[2048]{0} convert(f16[2048]{0} %multiply.58), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.37 = f32[2048]{0} parameter(36), sharding={replicated}
%multiply.2501 = f32[2048]{0} multiply(f32[2048]{0} %parameter.37, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2503 = f32[2048]{0} add(f32[2048]{0} %convert.2502, f32[2048]{0} %multiply.2501), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.69 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2157, f16[2048]{0} %reduce.2157), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2736 = f16[2048]{0} multiply(f16[2048]{0} %multiply.69, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2740 = f32[2048]{0} convert(f16[2048]{0} %multiply.2736), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.66 = f32[2048]{0} parameter(65), sharding={replicated}
%multiply.2739 = f32[2048]{0} multiply(f32[2048]{0} %parameter.66, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2741 = f32[2048]{0} add(f32[2048]{0} %convert.2740, f32[2048]{0} %multiply.2739), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3078 = f32[2048]{0} divide(f32[2048]{0} %add.2741, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3191 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3078), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3194 = f32[2048]{0} add(f32[2048]{0} %sqrt.3191, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.31 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3194), sharding={replicated}
%divide.3195 = f32[2048]{0} divide(f32[2048]{0} %add.2503, f32[2048]{0} %multiply.31), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3437 = f32[2048]{0} multiply(f32[2048]{0} %divide.3195, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3512 = f32[2048]{0} add(f32[2048]{0} %parameter.95, f32[2048]{0} %multiply.3437), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3541 = f16[2048]{0} convert(f32[2048]{0} %add.3512), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3587 = f16[2048]{0} multiply(f16[2048]{0} %parameter.7, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3588 = f16[2048]{0} add(f16[2048]{0} %convert.3541, f16[2048]{0} %multiply.3587), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%parameter.96 = f32[2048]{0} parameter(95), sharding={replicated}
%multiply.2172 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.1874, f16[8,1024,2048]{2,1,0} %convert.2168), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.2179 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %multiply.2172, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__67.2174, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.59 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2179, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2510 = f32[2048]{0} convert(f16[2048]{0} %multiply.59), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.38 = f32[2048]{0} parameter(37), sharding={replicated}
%multiply.2509 = f32[2048]{0} multiply(f32[2048]{0} %parameter.38, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2511 = f32[2048]{0} add(f32[2048]{0} %convert.2510, f32[2048]{0} %multiply.2509), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.70 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2179, f16[2048]{0} %reduce.2179), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2745 = f16[2048]{0} multiply(f16[2048]{0} %multiply.70, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2749 = f32[2048]{0} convert(f16[2048]{0} %multiply.2745), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.67 = f32[2048]{0} parameter(66), sharding={replicated}
%multiply.2748 = f32[2048]{0} multiply(f32[2048]{0} %parameter.67, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2750 = f32[2048]{0} add(f32[2048]{0} %convert.2749, f32[2048]{0} %multiply.2748), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3081 = f32[2048]{0} divide(f32[2048]{0} %add.2750, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3199 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3081), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3202 = f32[2048]{0} add(f32[2048]{0} %sqrt.3199, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.32 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3202), sharding={replicated}
%divide.3203 = f32[2048]{0} divide(f32[2048]{0} %add.2511, f32[2048]{0} %multiply.32), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3440 = f32[2048]{0} multiply(f32[2048]{0} %divide.3203, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3513 = f32[2048]{0} add(f32[2048]{0} %parameter.96, f32[2048]{0} %multiply.3440), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3542 = f16[2048]{0} convert(f32[2048]{0} %add.3513), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3591 = f16[2048]{0} multiply(f16[2048]{0} %parameter.8, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3592 = f16[2048]{0} add(f16[2048]{0} %convert.3542, f16[2048]{0} %multiply.3591), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%parameter.97 = f32[2048]{0} parameter(96), sharding={replicated}
%reduce.2243 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %convert.2236, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__73.2238, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%multiply.2514 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2243, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2518 = f32[2048]{0} convert(f16[2048]{0} %multiply.2514), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.39 = f32[2048]{0} parameter(38), sharding={replicated}
%multiply.2517 = f32[2048]{0} multiply(f32[2048]{0} %parameter.39, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2519 = f32[2048]{0} add(f32[2048]{0} %convert.2518, f32[2048]{0} %multiply.2517), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2751 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2243, f16[2048]{0} %reduce.2243), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2754 = f16[2048]{0} multiply(f16[2048]{0} %multiply.2751, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2758 = f32[2048]{0} convert(f16[2048]{0} %multiply.2754), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.68 = f32[2048]{0} parameter(67), sharding={replicated}
%multiply.2757 = f32[2048]{0} multiply(f32[2048]{0} %parameter.68, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2759 = f32[2048]{0} add(f32[2048]{0} %convert.2758, f32[2048]{0} %multiply.2757), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3084 = f32[2048]{0} divide(f32[2048]{0} %add.2759, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3207 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3084), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3210 = f32[2048]{0} add(f32[2048]{0} %sqrt.3207, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.33 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3210), sharding={replicated}
%divide.3211 = f32[2048]{0} divide(f32[2048]{0} %add.2519, f32[2048]{0} %multiply.33), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3443 = f32[2048]{0} multiply(f32[2048]{0} %divide.3211, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3514 = f32[2048]{0} add(f32[2048]{0} %parameter.97, f32[2048]{0} %multiply.3443), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3543 = f16[2048]{0} convert(f32[2048]{0} %add.3514), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3595 = f16[2048]{0} multiply(f16[2048]{0} %parameter.9, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3596 = f16[2048]{0} add(f16[2048]{0} %convert.3543, f16[2048]{0} %multiply.3595), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%parameter.98 = f32[2048,2048]{1,0} parameter(97), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.78 = f16[2048,8,1024]{0,2,1} transpose(f16[8,1024,2048]{2,1,0} %convert.2236), dimensions={2,0,1}, sharding={devices=[1,2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.99 = f16[2048,8192]{1,0} reshape(f16[2048,8,1024]{0,2,1} %transpose.78), sharding={devices=[1,2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.73 = f16[2048,2048]{0,1} dot(f16[8192,2048]{1,0} %reshape.78, f16[2048,8192]{1,0} %reshape.99), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%broadcast.2521 = f16[2048,2048]{1,0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2522 = f16[2048,2048]{1,0} multiply(f16[2048,2048]{0,1} %dot.73, f16[2048,2048]{1,0} %broadcast.2521), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2526 = f32[2048,2048]{1,0} convert(f16[2048,2048]{1,0} %multiply.2522), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.40 = f32[2048,2048]{1,0} parameter(39), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2524 = f32[2048,2048]{1,0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2525 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %parameter.40, f32[2048,2048]{1,0} %broadcast.2524), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2527 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %convert.2526, f32[2048,2048]{1,0} %multiply.2525), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2994 = f32[2048,2048]{1,0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.2760 = f16[2048,2048]{0,1} multiply(f16[2048,2048]{0,1} %dot.73, f16[2048,2048]{0,1} %dot.73), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2762 = f16[2048,2048]{1,0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2763 = f16[2048,2048]{1,0} multiply(f16[2048,2048]{0,1} %multiply.2760, f16[2048,2048]{1,0} %broadcast.2762), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2767 = f32[2048,2048]{1,0} convert(f16[2048,2048]{1,0} %multiply.2763), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.69 = f32[2048,2048]{1,0} parameter(68), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2765 = f32[2048,2048]{1,0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2766 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %parameter.69, f32[2048,2048]{1,0} %broadcast.2765), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2768 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %convert.2767, f32[2048,2048]{1,0} %multiply.2766), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3086 = f32[2048,2048]{1,0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3087 = f32[2048,2048]{1,0} divide(f32[2048,2048]{1,0} %add.2768, f32[2048,2048]{1,0} %broadcast.3086), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3215 = f32[2048,2048]{1,0} sqrt(f32[2048,2048]{1,0} %divide.3087), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3217 = f32[2048,2048]{1,0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3218 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %sqrt.3215, f32[2048,2048]{1,0} %broadcast.3217), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.34 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %broadcast.2994, f32[2048,2048]{1,0} %add.3218), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3219 = f32[2048,2048]{1,0} divide(f32[2048,2048]{1,0} %add.2527, f32[2048,2048]{1,0} %multiply.34), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3389 = f32[2048,2048]{1,0} broadcast(f32[] %constant.3380), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3390 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %parameter.98, f32[2048,2048]{1,0} %broadcast.3389), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3391 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %divide.3219, f32[2048,2048]{1,0} %multiply.3390), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%broadcast.3445 = f32[2048,2048]{1,0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3446 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %add.3391, f32[2048,2048]{1,0} %broadcast.3445), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3515 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %parameter.98, f32[2048,2048]{1,0} %multiply.3446), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3544 = f16[2048,2048]{1,0} convert(f32[2048,2048]{1,0} %add.3515), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3598 = f16[2048,2048]{1,0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3599 = f16[2048,2048]{1,0} multiply(f16[2048,2048]{1,0} %parameter.10, f16[2048,2048]{1,0} %broadcast.3598), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3600 = f16[2048,2048]{1,0} add(f16[2048,2048]{1,0} %convert.3544, f16[2048,2048]{1,0} %multiply.3599), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
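/* parameter.99 is an f32[6144] = 3*2048 bias; the [8,1024,2048,3] tensor
   reshaped at bert_model.py:162 suggests a fused query/key/value projection,
   so this is plausibly the fused QKV bias, tiled 8-way along its length. */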
%parameter.99 = f32[6144]{0} parameter(98), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%reshape.2317 = f16[8,1024,6144]{2,1,0} reshape(f16[8,1024,2048,3]{3,2,1,0} %add.2316), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 6144)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=162}
%reduce.2324 = f16[6144]{0} reduce(f16[8,1024,6144]{2,1,0} %reshape.2317, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__76.2319, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%broadcast.2529 = f16[6144]{0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2530 = f16[6144]{0} multiply(f16[6144]{0} %reduce.2324, f16[6144]{0} %broadcast.2529), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2534 = f32[6144]{0} convert(f16[6144]{0} %multiply.2530), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.41 = f32[6144]{0} parameter(40), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2532 = f32[6144]{0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2533 = f32[6144]{0} multiply(f32[6144]{0} %parameter.41, f32[6144]{0} %broadcast.2532), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2535 = f32[6144]{0} add(f32[6144]{0} %convert.2534, f32[6144]{0} %multiply.2533), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2997 = f32[6144]{0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.2769 = f16[6144]{0} multiply(f16[6144]{0} %reduce.2324, f16[6144]{0} %reduce.2324), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2771 = f16[6144]{0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2772 = f16[6144]{0} multiply(f16[6144]{0} %multiply.2769, f16[6144]{0} %broadcast.2771), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2776 = f32[6144]{0} convert(f16[6144]{0} %multiply.2772), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.70 = f32[6144]{0} parameter(69), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2774 = f32[6144]{0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2775 = f32[6144]{0} multiply(f32[6144]{0} %parameter.70, f32[6144]{0} %broadcast.2774), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2777 = f32[6144]{0} add(f32[6144]{0} %convert.2776, f32[6144]{0} %multiply.2775), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3089 = f32[6144]{0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3090 = f32[6144]{0} divide(f32[6144]{0} %add.2777, f32[6144]{0} %broadcast.3089), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3223 = f32[6144]{0} sqrt(f32[6144]{0} %divide.3090), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3225 = f32[6144]{0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3226 = f32[6144]{0} add(f32[6144]{0} %sqrt.3223, f32[6144]{0} %broadcast.3225), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.35 = f32[6144]{0} multiply(f32[6144]{0} %broadcast.2997, f32[6144]{0} %add.3226), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3227 = f32[6144]{0} divide(f32[6144]{0} %add.2535, f32[6144]{0} %multiply.35), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3448 = f32[6144]{0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3449 = f32[6144]{0} multiply(f32[6144]{0} %divide.3227, f32[6144]{0} %broadcast.3448), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3516 = f32[6144]{0} add(f32[6144]{0} %parameter.99, f32[6144]{0} %multiply.3449), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3545 = f16[6144]{0} convert(f32[6144]{0} %add.3516), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3602 = f16[6144]{0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3603 = f16[6144]{0} multiply(f16[6144]{0} %parameter.11, f16[6144]{0} %broadcast.3602), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3604 = f16[6144]{0} add(f16[6144]{0} %convert.3545, f16[6144]{0} %multiply.3603), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
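// The block above is one complete optimizer step for the f32[6144] tensor behind
// %parameter.99. From the optax metadata it appears to be scale_by_adam with the
// first-moment bias correction fused into the denominator (a reconstruction from
// the op pattern, not authoritative):
//   mu_t = b1*mu + (1-b1)*g                                          (transform.py:81)
//   nu_t = b2*nu + (1-b2)*g^2                                        (transform.py:81)
//   dir  = mu_t / ((1 - b1^t) * (sqrt(nu_t / (1 - b2^t)) + eps))     (transform.py:87, :302)
//   p    = p + (-lr) * dir                                           (transform.py:328, update.py:43)
// %subtract.2968 and %subtract.3060 are presumably the (1 - b^t) terms, %constant.3152
// the eps, and %constant.3420 the -lr scale. The trailing convert/multiply/add
// (model_util.py:327/334) refreshes the f16 copy of the weight; %constant.1015 is the
// f16 zero (it also serves as the reduce-sum identity below), so the %parameter.11 term
// contributes nothing here.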
%parameter.100 = f32[2048,6144]{1,0} parameter(99), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.82 = f16[6144,8,1024]{0,2,1} transpose(f16[8,1024,6144]{2,1,0} %reshape.2317), dimensions={2,0,1}, sharding={devices=[8,2,1]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%reshape.105 = f16[6144,8192]{1,0} reshape(f16[6144,8,1024]{0,2,1} %transpose.82), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%dot.74 = f16[2048,6144]{0,1} dot(f16[8192,2048]{1,0} %reshape.75, f16[6144,8192]{1,0} %reshape.105), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%broadcast.2537 = f16[2048,6144]{1,0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2538 = f16[2048,6144]{1,0} multiply(f16[2048,6144]{0,1} %dot.74, f16[2048,6144]{1,0} %broadcast.2537), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2542 = f32[2048,6144]{1,0} convert(f16[2048,6144]{1,0} %multiply.2538), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.42 = f32[2048,6144]{1,0} parameter(41), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2540 = f32[2048,6144]{1,0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2541 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %parameter.42, f32[2048,6144]{1,0} %broadcast.2540), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2543 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %convert.2542, f32[2048,6144]{1,0} %multiply.2541), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3000 = f32[2048,6144]{1,0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.2778 = f16[2048,6144]{0,1} multiply(f16[2048,6144]{0,1} %dot.74, f16[2048,6144]{0,1} %dot.74), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2780 = f16[2048,6144]{1,0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2781 = f16[2048,6144]{1,0} multiply(f16[2048,6144]{0,1} %multiply.2778, f16[2048,6144]{1,0} %broadcast.2780), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2785 = f32[2048,6144]{1,0} convert(f16[2048,6144]{1,0} %multiply.2781), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.71 = f32[2048,6144]{1,0} parameter(70), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2783 = f32[2048,6144]{1,0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2784 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %parameter.71, f32[2048,6144]{1,0} %broadcast.2783), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2786 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %convert.2785, f32[2048,6144]{1,0} %multiply.2784), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3092 = f32[2048,6144]{1,0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3093 = f32[2048,6144]{1,0} divide(f32[2048,6144]{1,0} %add.2786, f32[2048,6144]{1,0} %broadcast.3092), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3231 = f32[2048,6144]{1,0} sqrt(f32[2048,6144]{1,0} %divide.3093), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3233 = f32[2048,6144]{1,0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3234 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %sqrt.3231, f32[2048,6144]{1,0} %broadcast.3233), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.36 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %broadcast.3000, f32[2048,6144]{1,0} %add.3234), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3235 = f32[2048,6144]{1,0} divide(f32[2048,6144]{1,0} %add.2543, f32[2048,6144]{1,0} %multiply.36), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3393 = f32[2048,6144]{1,0} broadcast(f32[] %constant.3380), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3394 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %parameter.100, f32[2048,6144]{1,0} %broadcast.3393), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3395 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %divide.3235, f32[2048,6144]{1,0} %multiply.3394), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%broadcast.3451 = f32[2048,6144]{1,0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3452 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %add.3395, f32[2048,6144]{1,0} %broadcast.3451), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3517 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %parameter.100, f32[2048,6144]{1,0} %multiply.3452), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3546 = f16[2048,6144]{1,0} convert(f32[2048,6144]{1,0} %add.3517), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3606 = f16[2048,6144]{1,0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3607 = f16[2048,6144]{1,0} multiply(f16[2048,6144]{1,0} %parameter.12, f16[2048,6144]{1,0} %broadcast.3606), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3608 = f16[2048,6144]{1,0} add(f16[2048,6144]{1,0} %convert.3546, f16[2048,6144]{1,0} %multiply.3607), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
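// Same Adam step for the f32[2048,6144] kernel (%parameter.100), with one extra
// branch: %multiply.3394/%add.3395 (transform.py:571) add weight_decay * param to
// the update, i.e. optax's add_decayed_weights. This branch appears only for the
// 2D kernels in this module and not for the 1D bias/scale vectors, consistent
// with a weight-decay mask that skips non-matrix parameters.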
%parameter.101 = f32[8192]{0} parameter(100), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%reduce.2144 = f16[8192]{0} reduce(f16[8,1024,8192]{2,1,0} %add.2137, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__64.2139, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%broadcast.2545 = f16[8192]{0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2546 = f16[8192]{0} multiply(f16[8192]{0} %reduce.2144, f16[8192]{0} %broadcast.2545), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2550 = f32[8192]{0} convert(f16[8192]{0} %multiply.2546), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.43 = f32[8192]{0} parameter(42), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2548 = f32[8192]{0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2549 = f32[8192]{0} multiply(f32[8192]{0} %parameter.43, f32[8192]{0} %broadcast.2548), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2551 = f32[8192]{0} add(f32[8192]{0} %convert.2550, f32[8192]{0} %multiply.2549), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3003 = f32[8192]{0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.2787 = f16[8192]{0} multiply(f16[8192]{0} %reduce.2144, f16[8192]{0} %reduce.2144), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2789 = f16[8192]{0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2790 = f16[8192]{0} multiply(f16[8192]{0} %multiply.2787, f16[8192]{0} %broadcast.2789), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2794 = f32[8192]{0} convert(f16[8192]{0} %multiply.2790), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.72 = f32[8192]{0} parameter(71), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2792 = f32[8192]{0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2793 = f32[8192]{0} multiply(f32[8192]{0} %parameter.72, f32[8192]{0} %broadcast.2792), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2795 = f32[8192]{0} add(f32[8192]{0} %convert.2794, f32[8192]{0} %multiply.2793), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3095 = f32[8192]{0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3096 = f32[8192]{0} divide(f32[8192]{0} %add.2795, f32[8192]{0} %broadcast.3095), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3239 = f32[8192]{0} sqrt(f32[8192]{0} %divide.3096), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3241 = f32[8192]{0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3242 = f32[8192]{0} add(f32[8192]{0} %sqrt.3239, f32[8192]{0} %broadcast.3241), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.37 = f32[8192]{0} multiply(f32[8192]{0} %broadcast.3003, f32[8192]{0} %add.3242), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3243 = f32[8192]{0} divide(f32[8192]{0} %add.2551, f32[8192]{0} %multiply.37), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3454 = f32[8192]{0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3455 = f32[8192]{0} multiply(f32[8192]{0} %divide.3243, f32[8192]{0} %broadcast.3454), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3518 = f32[8192]{0} add(f32[8192]{0} %parameter.101, f32[8192]{0} %multiply.3455), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3547 = f16[8192]{0} convert(f32[8192]{0} %add.3518), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3610 = f16[8192]{0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3611 = f16[8192]{0} multiply(f16[8192]{0} %parameter.13, f16[8192]{0} %broadcast.3610), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3612 = f16[8192]{0} add(f16[8192]{0} %convert.3547, f16[8192]{0} %multiply.3611), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
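// Bias update for %parameter.101 (f32[8192]). Its gradient is %reduce.2144, a
// reduce_sum of the incoming f16[8,1024,8192] cotangent over the (batch, seq)
// dims (flax linear.py:181), seeded with the f16 zero %constant.1015 as the
// identity of the add reduction.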
%parameter.102 = f32[2048,8192]{1,0} parameter(101), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.74 = f16[8192,8,1024]{0,2,1} transpose(f16[8,1024,8192]{2,1,0} %add.2137), dimensions={2,0,1}, sharding={devices=[8,2,1]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%reshape.93 = f16[8192,8192]{1,0} reshape(f16[8192,8,1024]{0,2,1} %transpose.74), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%dot.75 = f16[2048,8192]{0,1} dot(f16[8192,2048]{1,0} %reshape.81, f16[8192,8192]{1,0} %reshape.93), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%broadcast.2553 = f16[2048,8192]{1,0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2554 = f16[2048,8192]{1,0} multiply(f16[2048,8192]{0,1} %dot.75, f16[2048,8192]{1,0} %broadcast.2553), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2558 = f32[2048,8192]{1,0} convert(f16[2048,8192]{1,0} %multiply.2554), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.44 = f32[2048,8192]{1,0} parameter(43), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2556 = f32[2048,8192]{1,0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2557 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %parameter.44, f32[2048,8192]{1,0} %broadcast.2556), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2559 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %convert.2558, f32[2048,8192]{1,0} %multiply.2557), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3006 = f32[2048,8192]{1,0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.2796 = f16[2048,8192]{0,1} multiply(f16[2048,8192]{0,1} %dot.75, f16[2048,8192]{0,1} %dot.75), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2798 = f16[2048,8192]{1,0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2799 = f16[2048,8192]{1,0} multiply(f16[2048,8192]{0,1} %multiply.2796, f16[2048,8192]{1,0} %broadcast.2798), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2803 = f32[2048,8192]{1,0} convert(f16[2048,8192]{1,0} %multiply.2799), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.73 = f32[2048,8192]{1,0} parameter(72), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2801 = f32[2048,8192]{1,0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2802 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %parameter.73, f32[2048,8192]{1,0} %broadcast.2801), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2804 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %convert.2803, f32[2048,8192]{1,0} %multiply.2802), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3098 = f32[2048,8192]{1,0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3099 = f32[2048,8192]{1,0} divide(f32[2048,8192]{1,0} %add.2804, f32[2048,8192]{1,0} %broadcast.3098), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3247 = f32[2048,8192]{1,0} sqrt(f32[2048,8192]{1,0} %divide.3099), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3249 = f32[2048,8192]{1,0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3250 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %sqrt.3247, f32[2048,8192]{1,0} %broadcast.3249), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.38 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %broadcast.3006, f32[2048,8192]{1,0} %add.3250), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3251 = f32[2048,8192]{1,0} divide(f32[2048,8192]{1,0} %add.2559, f32[2048,8192]{1,0} %multiply.38), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3397 = f32[2048,8192]{1,0} broadcast(f32[] %constant.3380), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3398 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %parameter.102, f32[2048,8192]{1,0} %broadcast.3397), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3399 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %divide.3251, f32[2048,8192]{1,0} %multiply.3398), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%broadcast.3457 = f32[2048,8192]{1,0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3458 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %add.3399, f32[2048,8192]{1,0} %broadcast.3457), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3519 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %parameter.102, f32[2048,8192]{1,0} %multiply.3458), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3548 = f16[2048,8192]{1,0} convert(f32[2048,8192]{1,0} %add.3519), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3614 = f16[2048,8192]{1,0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3615 = f16[2048,8192]{1,0} multiply(f16[2048,8192]{1,0} %parameter.14, f16[2048,8192]{1,0} %broadcast.3614), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3616 = f16[2048,8192]{1,0} add(f16[2048,8192]{1,0} %convert.3548, f16[2048,8192]{1,0} %multiply.3615), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
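// Sharding note: {devices=[1,8,2]0,8,1,9,... last_tile_dim_replicate} is a tile
// assignment of shape [1,8,2] over 16 devices: tensor dim 0 is unsplit, dim 1 is
// split 8 ways, and the trailing size-2 axis denotes replication, so each shard
// lives on a device pair such as (0,8) or (1,9). The small per-feature vectors
// below instead carry sharding={replicated}.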
%parameter.103 = f32[2048]{0} parameter(102), sharding={replicated}
%reduce.2032 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %get-tuple-element.1727, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__55.2027, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%multiply.60 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2032, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2566 = f32[2048]{0} convert(f16[2048]{0} %multiply.60), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.45 = f32[2048]{0} parameter(44), sharding={replicated}
%multiply.2565 = f32[2048]{0} multiply(f32[2048]{0} %parameter.45, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2567 = f32[2048]{0} add(f32[2048]{0} %convert.2566, f32[2048]{0} %multiply.2565), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.71 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2032, f16[2048]{0} %reduce.2032), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2808 = f16[2048]{0} multiply(f16[2048]{0} %multiply.71, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2812 = f32[2048]{0} convert(f16[2048]{0} %multiply.2808), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.74 = f32[2048]{0} parameter(73), sharding={replicated}
%multiply.2811 = f32[2048]{0} multiply(f32[2048]{0} %parameter.74, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2813 = f32[2048]{0} add(f32[2048]{0} %convert.2812, f32[2048]{0} %multiply.2811), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3102 = f32[2048]{0} divide(f32[2048]{0} %add.2813, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3255 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3102), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3258 = f32[2048]{0} add(f32[2048]{0} %sqrt.3255, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.39 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3258), sharding={replicated}
%divide.3259 = f32[2048]{0} divide(f32[2048]{0} %add.2567, f32[2048]{0} %multiply.39), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3461 = f32[2048]{0} multiply(f32[2048]{0} %divide.3259, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3520 = f32[2048]{0} add(f32[2048]{0} %parameter.103, f32[2048]{0} %multiply.3461), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3549 = f16[2048]{0} convert(f32[2048]{0} %add.3520), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3619 = f16[2048]{0} multiply(f16[2048]{0} %parameter.15, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3620 = f16[2048]{0} add(f16[2048]{0} %convert.3549, f16[2048]{0} %multiply.3619), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
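// For the replicated f32[2048] tensors (%parameter.103 onward) the per-shape
// scalar broadcasts are not re-emitted: these clusters reuse shared values such
// as %broadcast.110, %broadcast.2468, %broadcast.3065, and %broadcast.3424 that
// were materialized earlier, whereas the sharded clusters above each carry their
// own broadcast of the same constants.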
%parameter.104 = f32[2048]{0} parameter(103), sharding={replicated}
%multiply.2047 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.2015, f16[8,1024,2048]{2,1,0} %convert.2043), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.2054 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %multiply.2047, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__57.2049, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.61 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2054, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2574 = f32[2048]{0} convert(f16[2048]{0} %multiply.61), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.46 = f32[2048]{0} parameter(45), sharding={replicated}
%multiply.2573 = f32[2048]{0} multiply(f32[2048]{0} %parameter.46, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2575 = f32[2048]{0} add(f32[2048]{0} %convert.2574, f32[2048]{0} %multiply.2573), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.72 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2054, f16[2048]{0} %reduce.2054), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2817 = f16[2048]{0} multiply(f16[2048]{0} %multiply.72, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2821 = f32[2048]{0} convert(f16[2048]{0} %multiply.2817), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.75 = f32[2048]{0} parameter(74), sharding={replicated}
%multiply.2820 = f32[2048]{0} multiply(f32[2048]{0} %parameter.75, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2822 = f32[2048]{0} add(f32[2048]{0} %convert.2821, f32[2048]{0} %multiply.2820), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3105 = f32[2048]{0} divide(f32[2048]{0} %add.2822, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3263 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3105), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3266 = f32[2048]{0} add(f32[2048]{0} %sqrt.3263, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.40 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3266), sharding={replicated}
%divide.3267 = f32[2048]{0} divide(f32[2048]{0} %add.2575, f32[2048]{0} %multiply.40), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3464 = f32[2048]{0} multiply(f32[2048]{0} %divide.3267, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3521 = f32[2048]{0} add(f32[2048]{0} %parameter.104, f32[2048]{0} %multiply.3464), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3550 = f16[2048]{0} convert(f32[2048]{0} %add.3521), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3623 = f16[2048]{0} multiply(f16[2048]{0} %parameter.16, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3624 = f16[2048]{0} add(f16[2048]{0} %convert.3550, f16[2048]{0} %multiply.3623), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
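// %multiply.2047/%reduce.2054 (monkey_patch.py:150) reduce an elementwise product
// of a broadcast and a cotangent over (batch, seq); together with %reduce.2032
// (monkey_patch.py:155) this looks like the backward pass of an alpa-patched
// normalization layer producing its scale and bias gradients, though the HLO
// alone does not confirm which op was patched.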
%parameter.105 = f32[2048]{0} parameter(104), sharding={replicated}
%reduce.2118 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %convert.2111, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__63.2113, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%multiply.2578 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2118, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2582 = f32[2048]{0} convert(f16[2048]{0} %multiply.2578), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.47 = f32[2048]{0} parameter(46), sharding={replicated}
%multiply.2581 = f32[2048]{0} multiply(f32[2048]{0} %parameter.47, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2583 = f32[2048]{0} add(f32[2048]{0} %convert.2582, f32[2048]{0} %multiply.2581), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2823 = f16[2048]{0} multiply(f16[2048]{0} %reduce.2118, f16[2048]{0} %reduce.2118), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2826 = f16[2048]{0} multiply(f16[2048]{0} %multiply.2823, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2830 = f32[2048]{0} convert(f16[2048]{0} %multiply.2826), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.76 = f32[2048]{0} parameter(75), sharding={replicated}
%multiply.2829 = f32[2048]{0} multiply(f32[2048]{0} %parameter.76, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2831 = f32[2048]{0} add(f32[2048]{0} %convert.2830, f32[2048]{0} %multiply.2829), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3108 = f32[2048]{0} divide(f32[2048]{0} %add.2831, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3271 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3108), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3274 = f32[2048]{0} add(f32[2048]{0} %sqrt.3271, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.41 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3274), sharding={replicated}
%divide.3275 = f32[2048]{0} divide(f32[2048]{0} %add.2583, f32[2048]{0} %multiply.41), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3467 = f32[2048]{0} multiply(f32[2048]{0} %divide.3275, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3522 = f32[2048]{0} add(f32[2048]{0} %parameter.105, f32[2048]{0} %multiply.3467), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3551 = f16[2048]{0} convert(f32[2048]{0} %add.3522), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3627 = f16[2048]{0} multiply(f16[2048]{0} %parameter.17, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3628 = f16[2048]{0} add(f16[2048]{0} %convert.3551, f16[2048]{0} %multiply.3627), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
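// The cluster starting at %parameter.106 updates an f32[8192,2048] kernel. Its
// gradient %dot.76 appears to contract an f16[8192,8192] activation matrix
// (%reshape.84) against the flattened cotangent %convert.2111, the same value
// that produced the bias gradient %reduce.2118 above: %transpose.70/%reshape.87
// only reshuffle layout so that (batch=8, seq=1024) collapses into the 8192
// contraction axis, and the logical transpose recorded in the metadata
// (flax linear.py:177) is folded into the lhs/rhs contracting dims rather than
// materialized.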
%parameter.106 = f32[8192,2048]{1,0} parameter(105), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.70 = f16[2048,8,1024]{0,2,1} transpose(f16[8,1024,2048]{2,1,0} %convert.2111), dimensions={2,0,1}, sharding={devices=[1,2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.87 = f16[2048,8192]{1,0} reshape(f16[2048,8,1024]{0,2,1} %transpose.70), sharding={devices=[1,2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.76 = f16[8192,2048]{0,1} dot(f16[8192,8192]{1,0} %reshape.84, f16[2048,8192]{1,0} %reshape.87), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%broadcast.2585 = f16[8192,2048]{1,0} broadcast(f16[] %constant.2456), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2586 = f16[8192,2048]{1,0} multiply(f16[8192,2048]{0,1} %dot.76, f16[8192,2048]{1,0} %broadcast.2585), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2590 = f32[8192,2048]{1,0} convert(f16[8192,2048]{1,0} %multiply.2586), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.48 = f32[8192,2048]{1,0} parameter(47), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2588 = f32[8192,2048]{1,0} broadcast(f32[] %constant.2459), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2589 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %parameter.48, f32[8192,2048]{1,0} %broadcast.2588), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2591 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %convert.2590, f32[8192,2048]{1,0} %multiply.2589), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3018 = f32[8192,2048]{1,0} broadcast(f32[] %subtract.2968), dimensions={}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%multiply.2832 = f16[8192,2048]{0,1} multiply(f16[8192,2048]{0,1} %dot.76, f16[8192,2048]{0,1} %dot.76), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.2834 = f16[8192,2048]{1,0} broadcast(f16[] %constant.2689), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2835 = f16[8192,2048]{1,0} multiply(f16[8192,2048]{0,1} %multiply.2832, f16[8192,2048]{1,0} %broadcast.2834), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2839 = f32[8192,2048]{1,0} convert(f16[8192,2048]{1,0} %multiply.2835), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.77 = f32[8192,2048]{1,0} parameter(76), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%broadcast.2837 = f32[8192,2048]{1,0} broadcast(f32[] %constant.2692), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2838 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %parameter.77, f32[8192,2048]{1,0} %broadcast.2837), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2840 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %convert.2839, f32[8192,2048]{1,0} %multiply.2838), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%broadcast.3110 = f32[8192,2048]{1,0} broadcast(f32[] %subtract.3060), dimensions={}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%divide.3111 = f32[8192,2048]{1,0} divide(f32[8192,2048]{1,0} %add.2840, f32[8192,2048]{1,0} %broadcast.3110), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3279 = f32[8192,2048]{1,0} sqrt(f32[8192,2048]{1,0} %divide.3111), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3281 = f32[8192,2048]{1,0} broadcast(f32[] %constant.3152), dimensions={}, sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3282 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %sqrt.3279, f32[8192,2048]{1,0} %broadcast.3281), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.42 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %broadcast.3018, f32[8192,2048]{1,0} %add.3282), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3283 = f32[8192,2048]{1,0} divide(f32[8192,2048]{1,0} %add.2591, f32[8192,2048]{1,0} %multiply.42), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%broadcast.3401 = f32[8192,2048]{1,0} broadcast(f32[] %constant.3380), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3402 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %parameter.106, f32[8192,2048]{1,0} %broadcast.3401), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3403 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %divide.3283, f32[8192,2048]{1,0} %multiply.3402), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%broadcast.3469 = f32[8192,2048]{1,0} broadcast(f32[] %constant.3420), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%multiply.3470 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %add.3403, f32[8192,2048]{1,0} %broadcast.3469), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3523 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %parameter.106, f32[8192,2048]{1,0} %multiply.3470), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3552 = f16[8192,2048]{1,0} convert(f32[8192,2048]{1,0} %add.3523), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%broadcast.3630 = f16[8192,2048]{1,0} broadcast(f16[] %constant.1015), dimensions={}, sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%multiply.3631 = f16[8192,2048]{1,0} multiply(f16[8192,2048]{1,0} %parameter.18, f16[8192,2048]{1,0} %broadcast.3630), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3632 = f16[8192,2048]{1,0} add(f16[8192,2048]{1,0} %convert.3552, f16[8192,2048]{1,0} %multiply.3631), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
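// Reading aid: %dot.76 through %add.3632 above form one complete optax-style
// Adam(W) update for this f16[8192,2048] weight, judging from the transform.py /
// update.py source lines in the metadata. A hedged sketch; the names b1, b2, eps,
// lr, wd, t below are assumptions and appear nowhere in the dump:
//   mu = (1 - b1) * g     + b1 * mu                            (transform.py:81)
//   nu = (1 - b2) * (g*g) + b2 * nu                            (transform.py:81)
//   u  = (mu / (1 - b1**t)) / (sqrt(nu / (1 - b2**t)) + eps)   (:87, :302)
//   u  = u + wd * p                                            (:571, weights only)
//   p  = p + (-lr) * u                                         (:328, update.py:43)
// The mu bias correction appears to be folded into the denominator here:
// %multiply.42 scales (sqrt(nu_hat) + eps) by the (1 - b1**t) factor before the
// divide. Gradients stay in f16 until the convert at transform.py:81; optimizer
// state and the master weight (%parameter.48, %parameter.77, %parameter.106) are
// f32, and the result is cast back to f16 at model_util.py:327. The multiply/add
// at model_util.py:334 scales %parameter.18 by %constant.1015, the same f16 value
// used as the reduce_sum identity below (i.e. zero), so that blend term appears
// to be a no-op here.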
%parameter.107 = f32[2048]{0} parameter(106), sharding={replicated}
%reduce.1519 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %add.1512, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__37.1514, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%multiply.62 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1519, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2598 = f32[2048]{0} convert(f16[2048]{0} %multiply.62), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.49 = f32[2048]{0} parameter(48), sharding={replicated}
%multiply.2597 = f32[2048]{0} multiply(f32[2048]{0} %parameter.49, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2599 = f32[2048]{0} add(f32[2048]{0} %convert.2598, f32[2048]{0} %multiply.2597), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.73 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1519, f16[2048]{0} %reduce.1519), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2844 = f16[2048]{0} multiply(f16[2048]{0} %multiply.73, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2848 = f32[2048]{0} convert(f16[2048]{0} %multiply.2844), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.78 = f32[2048]{0} parameter(77), sharding={replicated}
%multiply.2847 = f32[2048]{0} multiply(f32[2048]{0} %parameter.78, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2849 = f32[2048]{0} add(f32[2048]{0} %convert.2848, f32[2048]{0} %multiply.2847), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3114 = f32[2048]{0} divide(f32[2048]{0} %add.2849, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3287 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3114), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3290 = f32[2048]{0} add(f32[2048]{0} %sqrt.3287, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.43 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3290), sharding={replicated}
%divide.3291 = f32[2048]{0} divide(f32[2048]{0} %add.2599, f32[2048]{0} %multiply.43), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3473 = f32[2048]{0} multiply(f32[2048]{0} %divide.3291, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3524 = f32[2048]{0} add(f32[2048]{0} %parameter.107, f32[2048]{0} %multiply.3473), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3553 = f16[2048]{0} convert(f32[2048]{0} %add.3524), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3635 = f16[2048]{0} multiply(f16[2048]{0} %parameter.19, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3636 = f16[2048]{0} add(f16[2048]{0} %convert.3553, f16[2048]{0} %multiply.3635), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
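// The same chain repeats for this f16[2048] bias. Its gradient is the reduce_sum
// of the upstream gradient over the (batch, seq) axes (%reduce.1519,
// dimensions={0,1}), the optimizer state is replicated instead of sharded, and
// the chain has no transform.py:571 step, consistent with weight decay being
// masked off bias parameters.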
%parameter.108 = f32[2048]{0} parameter(107), sharding={replicated}
%multiply.1534 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.1236, f16[8,1024,2048]{2,1,0} %convert.1530), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.1541 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %multiply.1534, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__39.1536, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.63 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1541, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2606 = f32[2048]{0} convert(f16[2048]{0} %multiply.63), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.50 = f32[2048]{0} parameter(49), sharding={replicated}
%multiply.2605 = f32[2048]{0} multiply(f32[2048]{0} %parameter.50, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2607 = f32[2048]{0} add(f32[2048]{0} %convert.2606, f32[2048]{0} %multiply.2605), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.74 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1541, f16[2048]{0} %reduce.1541), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2853 = f16[2048]{0} multiply(f16[2048]{0} %multiply.74, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2857 = f32[2048]{0} convert(f16[2048]{0} %multiply.2853), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.79 = f32[2048]{0} parameter(78), sharding={replicated}
%multiply.2856 = f32[2048]{0} multiply(f32[2048]{0} %parameter.79, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2858 = f32[2048]{0} add(f32[2048]{0} %convert.2857, f32[2048]{0} %multiply.2856), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3117 = f32[2048]{0} divide(f32[2048]{0} %add.2858, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3295 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3117), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3298 = f32[2048]{0} add(f32[2048]{0} %sqrt.3295, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.44 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3298), sharding={replicated}
%divide.3299 = f32[2048]{0} divide(f32[2048]{0} %add.2607, f32[2048]{0} %multiply.44), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3476 = f32[2048]{0} multiply(f32[2048]{0} %divide.3299, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3525 = f32[2048]{0} add(f32[2048]{0} %parameter.108, f32[2048]{0} %multiply.3476), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3554 = f16[2048]{0} convert(f32[2048]{0} %add.3525), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3639 = f16[2048]{0} multiply(f16[2048]{0} %parameter.20, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3640 = f16[2048]{0} add(f16[2048]{0} %convert.3554, f16[2048]{0} %multiply.3639), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%parameter.109 = f32[2048]{0} parameter(108), sharding={replicated}
%reduce.1605 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %convert.1598, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__45.1600, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%multiply.2610 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1605, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2614 = f32[2048]{0} convert(f16[2048]{0} %multiply.2610), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.51 = f32[2048]{0} parameter(50), sharding={replicated}
%multiply.2613 = f32[2048]{0} multiply(f32[2048]{0} %parameter.51, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2615 = f32[2048]{0} add(f32[2048]{0} %convert.2614, f32[2048]{0} %multiply.2613), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2859 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1605, f16[2048]{0} %reduce.1605), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2862 = f16[2048]{0} multiply(f16[2048]{0} %multiply.2859, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2866 = f32[2048]{0} convert(f16[2048]{0} %multiply.2862), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.80 = f32[2048]{0} parameter(79), sharding={replicated}
%multiply.2865 = f32[2048]{0} multiply(f32[2048]{0} %parameter.80, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2867 = f32[2048]{0} add(f32[2048]{0} %convert.2866, f32[2048]{0} %multiply.2865), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3120 = f32[2048]{0} divide(f32[2048]{0} %add.2867, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3303 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3120), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3306 = f32[2048]{0} add(f32[2048]{0} %sqrt.3303, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.45 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3306), sharding={replicated}
%divide.3307 = f32[2048]{0} divide(f32[2048]{0} %add.2615, f32[2048]{0} %multiply.45), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3479 = f32[2048]{0} multiply(f32[2048]{0} %divide.3307, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3526 = f32[2048]{0} add(f32[2048]{0} %parameter.109, f32[2048]{0} %multiply.3479), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3555 = f16[2048]{0} convert(f32[2048]{0} %add.3526), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3643 = f16[2048]{0} multiply(f16[2048]{0} %parameter.21, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3644 = f16[2048]{0} add(f16[2048]{0} %convert.3555, f16[2048]{0} %multiply.3643), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%parameter.110 = f32[2048,2048]{1,0} parameter(109), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.54 = f16[2048,8,1024]{0,2,1} transpose(f16[8,1024,2048]{2,1,0} %convert.1598), dimensions={2,0,1}, sharding={devices=[1,2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.63 = f16[2048,8192]{1,0} reshape(f16[2048,8,1024]{0,2,1} %transpose.54), sharding={devices=[1,2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.77 = f16[2048,2048]{0,1} dot(f16[8192,2048]{1,0} %reshape.42, f16[2048,8192]{1,0} %reshape.63), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%multiply.2618 = f16[2048,2048]{1,0} multiply(f16[2048,2048]{0,1} %dot.77, f16[2048,2048]{1,0} %broadcast.2521), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2622 = f32[2048,2048]{1,0} convert(f16[2048,2048]{1,0} %multiply.2618), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.52 = f32[2048,2048]{1,0} parameter(51), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2621 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %parameter.52, f32[2048,2048]{1,0} %broadcast.2524), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2623 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %convert.2622, f32[2048,2048]{1,0} %multiply.2621), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2868 = f16[2048,2048]{0,1} multiply(f16[2048,2048]{0,1} %dot.77, f16[2048,2048]{0,1} %dot.77), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2871 = f16[2048,2048]{1,0} multiply(f16[2048,2048]{0,1} %multiply.2868, f16[2048,2048]{1,0} %broadcast.2762), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2875 = f32[2048,2048]{1,0} convert(f16[2048,2048]{1,0} %multiply.2871), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.81 = f32[2048,2048]{1,0} parameter(80), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2874 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %parameter.81, f32[2048,2048]{1,0} %broadcast.2765), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2876 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %convert.2875, f32[2048,2048]{1,0} %multiply.2874), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3123 = f32[2048,2048]{1,0} divide(f32[2048,2048]{1,0} %add.2876, f32[2048,2048]{1,0} %broadcast.3086), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3311 = f32[2048,2048]{1,0} sqrt(f32[2048,2048]{1,0} %divide.3123), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3314 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %sqrt.3311, f32[2048,2048]{1,0} %broadcast.3217), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.46 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %broadcast.2994, f32[2048,2048]{1,0} %add.3314), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3315 = f32[2048,2048]{1,0} divide(f32[2048,2048]{1,0} %add.2623, f32[2048,2048]{1,0} %multiply.46), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3406 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %parameter.110, f32[2048,2048]{1,0} %broadcast.3389), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3407 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %divide.3315, f32[2048,2048]{1,0} %multiply.3406), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3482 = f32[2048,2048]{1,0} multiply(f32[2048,2048]{1,0} %add.3407, f32[2048,2048]{1,0} %broadcast.3445), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3527 = f32[2048,2048]{1,0} add(f32[2048,2048]{1,0} %parameter.110, f32[2048,2048]{1,0} %multiply.3482), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3556 = f16[2048,2048]{1,0} convert(f32[2048,2048]{1,0} %add.3527), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3647 = f16[2048,2048]{1,0} multiply(f16[2048,2048]{1,0} %parameter.22, f16[2048,2048]{1,0} %broadcast.3598), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3648 = f16[2048,2048]{1,0} add(f16[2048,2048]{1,0} %convert.3556, f16[2048,2048]{1,0} %multiply.3647), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
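// Weight-gradient pattern for the f16[2048,2048] kernel: the upstream gradient
// f16[8,1024,2048] is transposed and reshaped so the (batch, seq) axes collapse
// into one 8192-long contraction axis (%transpose.54, %reshape.63), and %dot.77
// then contracts both operands over it. This appears to compute grad_W = x^T @ dy,
// with the transpose[permutation=(1, 0)] metadata reflecting the flax
// linear-kernel layout.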
%parameter.111 = f32[6144]{0} parameter(110), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%reshape.1679 = f16[8,1024,6144]{2,1,0} reshape(f16[8,1024,2048,3]{3,2,1,0} %add.1678), sharding={devices=[2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15}, metadata={op_type="reshape" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reshape[\n dimensions=None\n new_sizes=(8, 1024, 6144)\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/bert_model.py" source_line=162}
%reduce.1686 = f16[6144]{0} reduce(f16[8,1024,6144]{2,1,0} %reshape.1679, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__48.1681, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%multiply.2626 = f16[6144]{0} multiply(f16[6144]{0} %reduce.1686, f16[6144]{0} %broadcast.2529), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2630 = f32[6144]{0} convert(f16[6144]{0} %multiply.2626), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.53 = f32[6144]{0} parameter(52), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2629 = f32[6144]{0} multiply(f32[6144]{0} %parameter.53, f32[6144]{0} %broadcast.2532), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2631 = f32[6144]{0} add(f32[6144]{0} %convert.2630, f32[6144]{0} %multiply.2629), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2877 = f16[6144]{0} multiply(f16[6144]{0} %reduce.1686, f16[6144]{0} %reduce.1686), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2880 = f16[6144]{0} multiply(f16[6144]{0} %multiply.2877, f16[6144]{0} %broadcast.2771), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2884 = f32[6144]{0} convert(f16[6144]{0} %multiply.2880), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.82 = f32[6144]{0} parameter(81), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2883 = f32[6144]{0} multiply(f32[6144]{0} %parameter.82, f32[6144]{0} %broadcast.2774), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2885 = f32[6144]{0} add(f32[6144]{0} %convert.2884, f32[6144]{0} %multiply.2883), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3126 = f32[6144]{0} divide(f32[6144]{0} %add.2885, f32[6144]{0} %broadcast.3089), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3319 = f32[6144]{0} sqrt(f32[6144]{0} %divide.3126), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3322 = f32[6144]{0} add(f32[6144]{0} %sqrt.3319, f32[6144]{0} %broadcast.3225), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.47 = f32[6144]{0} multiply(f32[6144]{0} %broadcast.2997, f32[6144]{0} %add.3322), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3323 = f32[6144]{0} divide(f32[6144]{0} %add.2631, f32[6144]{0} %multiply.47), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3485 = f32[6144]{0} multiply(f32[6144]{0} %divide.3323, f32[6144]{0} %broadcast.3448), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3528 = f32[6144]{0} add(f32[6144]{0} %parameter.111, f32[6144]{0} %multiply.3485), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3557 = f16[6144]{0} convert(f32[6144]{0} %add.3528), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3651 = f16[6144]{0} multiply(f16[6144]{0} %parameter.23, f16[6144]{0} %broadcast.3602), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3652 = f16[6144]{0} add(f16[6144]{0} %convert.3557, f16[6144]{0} %multiply.3651), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
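// This f16[6144] bias gradient flows from %reshape.1679, which flattens an
// f16[8,1024,2048,3] tensor into [8,1024,6144] (bert_model.py:162). Since
// 6144 = 3 * 2048, this looks like the fused query/key/value projection of the
// attention block; unlike the replicated 2048-long biases above, its state is
// tiled 8 ways along its only axis (devices=[8,2] ... last_tile_dim_replicate).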
%parameter.112 = f32[2048,6144]{1,0} parameter(111), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.58 = f16[6144,8,1024]{0,2,1} transpose(f16[8,1024,6144]{2,1,0} %reshape.1679), dimensions={2,0,1}, sharding={devices=[8,2,1]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%reshape.69 = f16[6144,8192]{1,0} reshape(f16[6144,8,1024]{0,2,1} %transpose.58), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%dot.78 = f16[2048,6144]{0,1} dot(f16[8192,2048]{1,0} %reshape.39, f16[6144,8192]{1,0} %reshape.69), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%multiply.2634 = f16[2048,6144]{1,0} multiply(f16[2048,6144]{0,1} %dot.78, f16[2048,6144]{1,0} %broadcast.2537), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2638 = f32[2048,6144]{1,0} convert(f16[2048,6144]{1,0} %multiply.2634), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.54 = f32[2048,6144]{1,0} parameter(53), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2637 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %parameter.54, f32[2048,6144]{1,0} %broadcast.2540), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2639 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %convert.2638, f32[2048,6144]{1,0} %multiply.2637), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2886 = f16[2048,6144]{0,1} multiply(f16[2048,6144]{0,1} %dot.78, f16[2048,6144]{0,1} %dot.78), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2889 = f16[2048,6144]{1,0} multiply(f16[2048,6144]{0,1} %multiply.2886, f16[2048,6144]{1,0} %broadcast.2780), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2893 = f32[2048,6144]{1,0} convert(f16[2048,6144]{1,0} %multiply.2889), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.83 = f32[2048,6144]{1,0} parameter(82), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2892 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %parameter.83, f32[2048,6144]{1,0} %broadcast.2783), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2894 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %convert.2893, f32[2048,6144]{1,0} %multiply.2892), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3129 = f32[2048,6144]{1,0} divide(f32[2048,6144]{1,0} %add.2894, f32[2048,6144]{1,0} %broadcast.3092), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3327 = f32[2048,6144]{1,0} sqrt(f32[2048,6144]{1,0} %divide.3129), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3330 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %sqrt.3327, f32[2048,6144]{1,0} %broadcast.3233), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.48 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %broadcast.3000, f32[2048,6144]{1,0} %add.3330), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3331 = f32[2048,6144]{1,0} divide(f32[2048,6144]{1,0} %add.2639, f32[2048,6144]{1,0} %multiply.48), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3410 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %parameter.112, f32[2048,6144]{1,0} %broadcast.3393), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3411 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %divide.3331, f32[2048,6144]{1,0} %multiply.3410), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3488 = f32[2048,6144]{1,0} multiply(f32[2048,6144]{1,0} %add.3411, f32[2048,6144]{1,0} %broadcast.3451), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3529 = f32[2048,6144]{1,0} add(f32[2048,6144]{1,0} %parameter.112, f32[2048,6144]{1,0} %multiply.3488), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3558 = f16[2048,6144]{1,0} convert(f32[2048,6144]{1,0} %add.3529), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3655 = f16[2048,6144]{1,0} multiply(f16[2048,6144]{1,0} %parameter.24, f16[2048,6144]{1,0} %broadcast.3606), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3656 = f16[2048,6144]{1,0} add(f16[2048,6144]{1,0} %convert.3558, f16[2048,6144]{1,0} %multiply.3655), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
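// Sharding conventions in this region, read off the annotations: a spec such as
// devices=[8,1,2]0,8,1,9,...,15 last_tile_dim_replicate tiles dim 0 into 8
// shards, each shard replicated on a device pair like {0,8}; devices=[1,8,2] is
// the column-parallel counterpart with dim 1 tiled 8 ways; and activation specs
// like devices=[2,1,1,8] ... last_tile_dim_replicate split the batch axis across
// the other, 2-way mesh axis. Together these look like a (2, 8) device mesh
// running data parallelism on one axis and tensor (operator) parallelism on the
// other, which is the layout the shard_parallel module name suggests.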
%parameter.113 = f32[8192]{0} parameter(112), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%reduce.1506 = f16[8192]{0} reduce(f16[8,1024,8192]{2,1,0} %add.1499, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__36.1501, sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%multiply.2642 = f16[8192]{0} multiply(f16[8192]{0} %reduce.1506, f16[8192]{0} %broadcast.2545), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2646 = f32[8192]{0} convert(f16[8192]{0} %multiply.2642), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.55 = f32[8192]{0} parameter(54), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2645 = f32[8192]{0} multiply(f32[8192]{0} %parameter.55, f32[8192]{0} %broadcast.2548), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2647 = f32[8192]{0} add(f32[8192]{0} %convert.2646, f32[8192]{0} %multiply.2645), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2895 = f16[8192]{0} multiply(f16[8192]{0} %reduce.1506, f16[8192]{0} %reduce.1506), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2898 = f16[8192]{0} multiply(f16[8192]{0} %multiply.2895, f16[8192]{0} %broadcast.2789), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2902 = f32[8192]{0} convert(f16[8192]{0} %multiply.2898), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.84 = f32[8192]{0} parameter(83), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2901 = f32[8192]{0} multiply(f32[8192]{0} %parameter.84, f32[8192]{0} %broadcast.2792), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2903 = f32[8192]{0} add(f32[8192]{0} %convert.2902, f32[8192]{0} %multiply.2901), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3132 = f32[8192]{0} divide(f32[8192]{0} %add.2903, f32[8192]{0} %broadcast.3095), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3335 = f32[8192]{0} sqrt(f32[8192]{0} %divide.3132), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3338 = f32[8192]{0} add(f32[8192]{0} %sqrt.3335, f32[8192]{0} %broadcast.3241), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.49 = f32[8192]{0} multiply(f32[8192]{0} %broadcast.3003, f32[8192]{0} %add.3338), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3339 = f32[8192]{0} divide(f32[8192]{0} %add.2647, f32[8192]{0} %multiply.49), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3491 = f32[8192]{0} multiply(f32[8192]{0} %divide.3339, f32[8192]{0} %broadcast.3454), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3530 = f32[8192]{0} add(f32[8192]{0} %parameter.113, f32[8192]{0} %multiply.3491), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3559 = f16[8192]{0} convert(f32[8192]{0} %add.3530), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3659 = f16[8192]{0} multiply(f16[8192]{0} %parameter.25, f16[8192]{0} %broadcast.3610), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3660 = f16[8192]{0} add(f16[8192]{0} %convert.3559, f16[8192]{0} %multiply.3659), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
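/* Next group: the f32[2048,8192] parameter.114, sharded [1,8,2] over the 16 devices.
   dot.79 materializes its gradient from the transposed activations (flax linear.py:177);
   unlike the 1-D parameters it also receives decoupled weight decay (optax transform.py:571)
   before the step-size scaling (transform.py:328) and apply (update.py:43), followed by the
   f16 copy refresh (alpa model_util.py:327, :334). */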
%parameter.114 = f32[2048,8192]{1,0} parameter(113), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.50 = f16[8192,8,1024]{0,2,1} transpose(f16[8,1024,8192]{2,1,0} %add.1499), dimensions={2,0,1}, sharding={devices=[8,2,1]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%reshape.57 = f16[8192,8192]{1,0} reshape(f16[8192,8,1024]{0,2,1} %transpose.50), sharding={devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15}
%dot.79 = f16[2048,8192]{0,1} dot(f16[8192,2048]{1,0} %reshape.45, f16[8192,8192]{1,0} %reshape.57), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%multiply.2650 = f16[2048,8192]{1,0} multiply(f16[2048,8192]{0,1} %dot.79, f16[2048,8192]{1,0} %broadcast.2553), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2654 = f32[2048,8192]{1,0} convert(f16[2048,8192]{1,0} %multiply.2650), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.56 = f32[2048,8192]{1,0} parameter(55), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2653 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %parameter.56, f32[2048,8192]{1,0} %broadcast.2556), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2655 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %convert.2654, f32[2048,8192]{1,0} %multiply.2653), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2904 = f16[2048,8192]{0,1} multiply(f16[2048,8192]{0,1} %dot.79, f16[2048,8192]{0,1} %dot.79), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2907 = f16[2048,8192]{1,0} multiply(f16[2048,8192]{0,1} %multiply.2904, f16[2048,8192]{1,0} %broadcast.2798), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2911 = f32[2048,8192]{1,0} convert(f16[2048,8192]{1,0} %multiply.2907), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.85 = f32[2048,8192]{1,0} parameter(84), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2910 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %parameter.85, f32[2048,8192]{1,0} %broadcast.2801), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2912 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %convert.2911, f32[2048,8192]{1,0} %multiply.2910), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3135 = f32[2048,8192]{1,0} divide(f32[2048,8192]{1,0} %add.2912, f32[2048,8192]{1,0} %broadcast.3098), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3343 = f32[2048,8192]{1,0} sqrt(f32[2048,8192]{1,0} %divide.3135), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3346 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %sqrt.3343, f32[2048,8192]{1,0} %broadcast.3249), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.50 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %broadcast.3006, f32[2048,8192]{1,0} %add.3346), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3347 = f32[2048,8192]{1,0} divide(f32[2048,8192]{1,0} %add.2655, f32[2048,8192]{1,0} %multiply.50), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3414 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %parameter.114, f32[2048,8192]{1,0} %broadcast.3397), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3415 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %divide.3347, f32[2048,8192]{1,0} %multiply.3414), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3494 = f32[2048,8192]{1,0} multiply(f32[2048,8192]{1,0} %add.3415, f32[2048,8192]{1,0} %broadcast.3457), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3531 = f32[2048,8192]{1,0} add(f32[2048,8192]{1,0} %parameter.114, f32[2048,8192]{1,0} %multiply.3494), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3560 = f16[2048,8192]{1,0} convert(f32[2048,8192]{1,0} %add.3531), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3663 = f16[2048,8192]{1,0} multiply(f16[2048,8192]{1,0} %parameter.26, f16[2048,8192]{1,0} %broadcast.3614), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3664 = f16[2048,8192]{1,0} add(f16[2048,8192]{1,0} %convert.3560, f16[2048,8192]{1,0} %multiply.3663), sharding={devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
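/* f32[2048] parameter.115 (replicated): its gradient reduce.1394 sums a [8,1024,2048]
   activation over the batch/sequence dims (alpa monkey_patch.py:155, consistent with a
   layer-norm bias gradient), then runs the same moment/bias-correction/step pipeline;
   the 1-D parameters carry no weight-decay term. */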
%parameter.115 = f32[2048]{0} parameter(114), sharding={replicated}
%reduce.1394 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %get-tuple-element.1089, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__27.1389, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=155}
%multiply.64 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1394, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2662 = f32[2048]{0} convert(f16[2048]{0} %multiply.64), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.57 = f32[2048]{0} parameter(56), sharding={replicated}
%multiply.2661 = f32[2048]{0} multiply(f32[2048]{0} %parameter.57, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2663 = f32[2048]{0} add(f32[2048]{0} %convert.2662, f32[2048]{0} %multiply.2661), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.75 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1394, f16[2048]{0} %reduce.1394), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2916 = f16[2048]{0} multiply(f16[2048]{0} %multiply.75, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2920 = f32[2048]{0} convert(f16[2048]{0} %multiply.2916), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.86 = f32[2048]{0} parameter(85), sharding={replicated}
%multiply.2919 = f32[2048]{0} multiply(f32[2048]{0} %parameter.86, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2921 = f32[2048]{0} add(f32[2048]{0} %convert.2920, f32[2048]{0} %multiply.2919), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3138 = f32[2048]{0} divide(f32[2048]{0} %add.2921, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3351 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3138), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3354 = f32[2048]{0} add(f32[2048]{0} %sqrt.3351, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.51 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3354), sharding={replicated}
%divide.3355 = f32[2048]{0} divide(f32[2048]{0} %add.2663, f32[2048]{0} %multiply.51), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3497 = f32[2048]{0} multiply(f32[2048]{0} %divide.3355, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3532 = f32[2048]{0} add(f32[2048]{0} %parameter.115, f32[2048]{0} %multiply.3497), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3561 = f16[2048]{0} convert(f32[2048]{0} %add.3532), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3667 = f16[2048]{0} multiply(f16[2048]{0} %parameter.27, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3668 = f16[2048]{0} add(f16[2048]{0} %convert.3561, f16[2048]{0} %multiply.3667), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
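/* f32[2048] parameter.116 (replicated): multiply.1409 followed by reduce.1416
   (alpa monkey_patch.py:150) is consistent with the matching layer-norm scale gradient,
   an elementwise product with the normalized activations reduced over batch/sequence. */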
%parameter.116 = f32[2048]{0} parameter(115), sharding={replicated}
%multiply.1409 = f16[8,1024,2048]{2,1,0} multiply(f16[8,1024,2048]{2,1,0} %broadcast.1377, f16[8,1024,2048]{2,1,0} %convert.1405), sharding={devices=[2,1,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/mul" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%reduce.1416 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %multiply.1409, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__29.1411, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/efs/alpa/alpa/monkey_patch.py" source_line=150}
%multiply.65 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1416, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2670 = f32[2048]{0} convert(f16[2048]{0} %multiply.65), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.58 = f32[2048]{0} parameter(57), sharding={replicated}
%multiply.2669 = f32[2048]{0} multiply(f32[2048]{0} %parameter.58, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2671 = f32[2048]{0} add(f32[2048]{0} %convert.2670, f32[2048]{0} %multiply.2669), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.76 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1416, f16[2048]{0} %reduce.1416), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2925 = f16[2048]{0} multiply(f16[2048]{0} %multiply.76, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2929 = f32[2048]{0} convert(f16[2048]{0} %multiply.2925), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.87 = f32[2048]{0} parameter(86), sharding={replicated}
%multiply.2928 = f32[2048]{0} multiply(f32[2048]{0} %parameter.87, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2930 = f32[2048]{0} add(f32[2048]{0} %convert.2929, f32[2048]{0} %multiply.2928), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3141 = f32[2048]{0} divide(f32[2048]{0} %add.2930, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3359 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3141), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3362 = f32[2048]{0} add(f32[2048]{0} %sqrt.3359, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.52 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3362), sharding={replicated}
%divide.3363 = f32[2048]{0} divide(f32[2048]{0} %add.2671, f32[2048]{0} %multiply.52), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3500 = f32[2048]{0} multiply(f32[2048]{0} %divide.3363, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3533 = f32[2048]{0} add(f32[2048]{0} %parameter.116, f32[2048]{0} %multiply.3500), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3562 = f16[2048]{0} convert(f32[2048]{0} %add.3533), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3671 = f16[2048]{0} multiply(f16[2048]{0} %parameter.28, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3672 = f16[2048]{0} add(f16[2048]{0} %convert.3562, f16[2048]{0} %multiply.3671), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
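/* f32[2048] parameter.117 (replicated): reduce.1480 sums the back-propagated activations
   from flax linen Dense (linear.py:181), i.e. a Dense bias gradient, before the same
   Adam-style update. */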
%parameter.117 = f32[2048]{0} parameter(116), sharding={replicated}
%reduce.1480 = f16[2048]{0} reduce(f16[8,1024,2048]{2,1,0} %convert.1473, f16[] %constant.1015), dimensions={0,1}, to_apply=%primitive_computation_add__35.1475, sharding={replicated}, metadata={op_type="reduce_sum" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/reduce_sum[\n axes=(0, 1)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=181}
%multiply.2674 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1480, f16[2048]{0} %broadcast.110), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2678 = f32[2048]{0} convert(f16[2048]{0} %multiply.2674), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.59 = f32[2048]{0} parameter(58), sharding={replicated}
%multiply.2677 = f32[2048]{0} multiply(f32[2048]{0} %parameter.59, f32[2048]{0} %broadcast.2468), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2679 = f32[2048]{0} add(f32[2048]{0} %convert.2678, f32[2048]{0} %multiply.2677), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2931 = f16[2048]{0} multiply(f16[2048]{0} %reduce.1480, f16[2048]{0} %reduce.1480), sharding={replicated}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2934 = f16[2048]{0} multiply(f16[2048]{0} %multiply.2931, f16[2048]{0} %broadcast.2699), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2938 = f32[2048]{0} convert(f16[2048]{0} %multiply.2934), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.88 = f32[2048]{0} parameter(87), sharding={replicated}
%multiply.2937 = f32[2048]{0} multiply(f32[2048]{0} %parameter.88, f32[2048]{0} %broadcast.2702), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2939 = f32[2048]{0} add(f32[2048]{0} %convert.2938, f32[2048]{0} %multiply.2937), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3144 = f32[2048]{0} divide(f32[2048]{0} %add.2939, f32[2048]{0} %broadcast.3065), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3367 = f32[2048]{0} sqrt(f32[2048]{0} %divide.3144), sharding={replicated}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3370 = f32[2048]{0} add(f32[2048]{0} %sqrt.3367, f32[2048]{0} %broadcast.3161), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.53 = f32[2048]{0} multiply(f32[2048]{0} %broadcast.2973, f32[2048]{0} %add.3370), sharding={replicated}
%divide.3371 = f32[2048]{0} divide(f32[2048]{0} %add.2679, f32[2048]{0} %multiply.53), sharding={replicated}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3503 = f32[2048]{0} multiply(f32[2048]{0} %divide.3371, f32[2048]{0} %broadcast.3424), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3534 = f32[2048]{0} add(f32[2048]{0} %parameter.117, f32[2048]{0} %multiply.3503), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3563 = f16[2048]{0} convert(f32[2048]{0} %add.3534), sharding={replicated}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3675 = f16[2048]{0} multiply(f16[2048]{0} %parameter.29, f16[2048]{0} %broadcast.3570), sharding={replicated}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3676 = f16[2048]{0} add(f16[2048]{0} %convert.3563, f16[2048]{0} %multiply.3675), sharding={replicated}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
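/* Last parameter group: the f32[8192,2048] kernel parameter.118, sharded [8,1,2].
   dot.80 contracts reshape.48 with reshape.51 over the flattened batch*sequence dim
   (8*1024 = 8192) to form the kernel gradient (flax linear.py:177), and the update again
   includes decoupled weight decay (transform.py:571). */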
%parameter.118 = f32[8192,2048]{1,0} parameter(117), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%transpose.46 = f16[2048,8,1024]{0,2,1} transpose(f16[8,1024,2048]{2,1,0} %convert.1473), dimensions={2,0,1}, sharding={devices=[1,2,1,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%reshape.51 = f16[2048,8192]{1,0} reshape(f16[2048,8,1024]{0,2,1} %transpose.46), sharding={devices=[1,2,8]0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15 last_tile_dim_replicate}
%dot.80 = f16[8192,2048]{0,1} dot(f16[8192,8192]{1,0} %reshape.48, f16[2048,8192]{1,0} %reshape.51), lhs_contracting_dims={0}, rhs_contracting_dims={1}, sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="transpose" op_name="parallelize(train_step_shard_parallel)/remat(jvp(core_fn))/transpose[\n permutation=(1, 0)\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/flax/linen/linear.py" source_line=177}
%multiply.2682 = f16[8192,2048]{1,0} multiply(f16[8192,2048]{0,1} %dot.80, f16[8192,2048]{1,0} %broadcast.2585), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2686 = f32[8192,2048]{1,0} convert(f16[8192,2048]{1,0} %multiply.2682), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.60 = f32[8192,2048]{1,0} parameter(59), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2685 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %parameter.60, f32[8192,2048]{1,0} %broadcast.2588), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2687 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %convert.2686, f32[8192,2048]{1,0} %multiply.2685), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2940 = f16[8192,2048]{0,1} multiply(f16[8192,2048]{0,1} %dot.80, f16[8192,2048]{0,1} %dot.80), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="integer_pow" op_name="parallelize(train_step_shard_parallel)/integer_pow[y=2]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%multiply.2943 = f16[8192,2048]{1,0} multiply(f16[8192,2048]{0,1} %multiply.2940, f16[8192,2048]{1,0} %broadcast.2834), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%convert.2947 = f32[8192,2048]{1,0} convert(f16[8192,2048]{1,0} %multiply.2943), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float32\n weak_type=False\n]" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%parameter.89 = f32[8192,2048]{1,0} parameter(88), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%multiply.2946 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %parameter.89, f32[8192,2048]{1,0} %broadcast.2837), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%add.2948 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %convert.2947, f32[8192,2048]{1,0} %multiply.2946), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=81}
%divide.3147 = f32[8192,2048]{1,0} divide(f32[8192,2048]{1,0} %add.2948, f32[8192,2048]{1,0} %broadcast.3110), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=87}
%sqrt.3375 = f32[8192,2048]{1,0} sqrt(f32[8192,2048]{1,0} %divide.3147), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="sqrt" op_name="parallelize(train_step_shard_parallel)/sqrt" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%add.3378 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %sqrt.3375, f32[8192,2048]{1,0} %broadcast.3281), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.54 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %broadcast.3018, f32[8192,2048]{1,0} %add.3378), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}
%divide.3379 = f32[8192,2048]{1,0} divide(f32[8192,2048]{1,0} %add.2687, f32[8192,2048]{1,0} %multiply.54), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="div" op_name="parallelize(train_step_shard_parallel)/div" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=302}
%multiply.3418 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %parameter.118, f32[8192,2048]{1,0} %broadcast.3401), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%add.3419 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %divide.3379, f32[8192,2048]{1,0} %multiply.3418), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=571}
%multiply.3506 = f32[8192,2048]{1,0} multiply(f32[8192,2048]{1,0} %add.3419, f32[8192,2048]{1,0} %broadcast.3469), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/transform.py" source_line=328}
%add.3535 = f32[8192,2048]{1,0} add(f32[8192,2048]{1,0} %parameter.118, f32[8192,2048]{1,0} %multiply.3506), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/.local/lib/python3.7/site-packages/optax/_src/update.py" source_line=43}
%convert.3564 = f16[8192,2048]{1,0} convert(f32[8192,2048]{1,0} %add.3535), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="convert_element_type" op_name="parallelize(train_step_shard_parallel)/convert_element_type[\n new_dtype=float16\n weak_type=False\n]" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=327}
%multiply.3679 = f16[8192,2048]{1,0} multiply(f16[8192,2048]{1,0} %parameter.30, f16[8192,2048]{1,0} %broadcast.3630), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="mul" op_name="parallelize(train_step_shard_parallel)/mul" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
%add.3680 = f16[8192,2048]{1,0} add(f16[8192,2048]{1,0} %convert.3564, f16[8192,2048]{1,0} %multiply.3679), sharding={devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, metadata={op_type="add" op_name="parallelize(train_step_shard_parallel)/add" source_file="/home/ubuntu/efs/alpa/alpa/model/model_util.py" source_line=334}
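/* ROOT tuple (118 elements): the incremented step counter (index 0) and a second s32
   scalar from select.1 (index 30, presumably dynamic-loss-scale state); the 29 refreshed
   f16 parameter copies (indices 1-29); the 29 new f32 first moments (31-59); the 29 new
   f32 second moments (60-88); and the 29 new f32 master parameters (89-117), each with
   its sharding annotation. */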
ROOT %tuple.3683 = (s32[], f16[51200]{0}, f16[2048]{0}, f16[2048]{0}, f16[1024,2048]{1,0}, /*index=5*/f16[51200,2048]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[2048,2048]{1,0}, /*index=10*/f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, f16[2048,8192]{1,0}, f16[2048]{0}, /*index=15*/f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, f16[2048]{0}, f16[2048]{0}, /*index=20*/f16[2048]{0}, f16[2048,2048]{1,0}, f16[6144]{0}, f16[2048,6144]{1,0}, f16[8192]{0}, /*index=25*/f16[2048,8192]{1,0}, f16[2048]{0}, f16[2048]{0}, f16[2048]{0}, f16[8192,2048]{1,0}, /*index=30*/s32[], f32[51200]{0}, f32[2048]{0}, f32[2048]{0}, f32[1024,2048]{1,0}, /*index=35*/f32[51200,2048]{1,0}, f32[2048]{0}, f32[2048]{0}, f32[2048]{0}, f32[2048,2048]{1,0}, /*index=40*/f32[6144]{0}, f32[2048,6144]{1,0}, f32[8192]{0}, f32[2048,8192]{1,0}, f32[2048]{0}, /*index=45*/f32[2048]{0}, f32[2048]{0}, f32[8192,2048]{1,0}, f32[2048]{0}, f32[2048]{0}, /*index=50*/f32[2048]{0}, f32[2048,2048]{1,0}, f32[6144]{0}, f32[2048,6144]{1,0}, f32[8192]{0}, /*index=55*/f32[2048,8192]{1,0}, f32[2048]{0}, f32[2048]{0}, f32[2048]{0}, f32[8192,2048]{1,0}, /*index=60*/f32[51200]{0}, f32[2048]{0}, f32[2048]{0}, f32[1024,2048]{1,0}, f32[51200,2048]{1,0}, /*index=65*/f32[2048]{0}, f32[2048]{0}, f32[2048]{0}, f32[2048,2048]{1,0}, f32[6144]{0}, /*index=70*/f32[2048,6144]{1,0}, f32[8192]{0}, f32[2048,8192]{1,0}, f32[2048]{0}, f32[2048]{0}, /*index=75*/f32[2048]{0}, f32[8192,2048]{1,0}, f32[2048]{0}, f32[2048]{0}, f32[2048]{0}, /*index=80*/f32[2048,2048]{1,0}, f32[6144]{0}, f32[2048,6144]{1,0}, f32[8192]{0}, f32[2048,8192]{1,0}, /*index=85*/f32[2048]{0}, f32[2048]{0}, f32[2048]{0}, f32[8192,2048]{1,0}, f32[51200]{0}, /*index=90*/f32[2048]{0}, f32[2048]{0}, f32[1024,2048]{1,0}, f32[51200,2048]{1,0}, f32[2048]{0}, /*index=95*/f32[2048]{0}, f32[2048]{0}, f32[2048,2048]{1,0}, f32[6144]{0}, f32[2048,6144]{1,0}, /*index=100*/f32[8192]{0}, f32[2048,8192]{1,0}, f32[2048]{0}, f32[2048]{0}, f32[2048]{0}, /*index=105*/f32[8192,2048]{1,0}, f32[2048]{0}, f32[2048]{0}, f32[2048]{0}, f32[2048,2048]{1,0}, /*index=110*/f32[6144]{0}, f32[2048,6144]{1,0}, f32[8192]{0}, f32[2048,8192]{1,0}, f32[2048]{0}, /*index=115*/f32[2048]{0}, f32[2048]{0}, f32[8192,2048]{1,0}) tuple(s32[] %add.3682, f16[51200]{0} %add.3568, f16[2048]{0} %add.3572, f16[2048]{0} %add.3576, f16[1024,2048]{1,0} %add.3580, /*index=5*/f16[51200,2048]{1,0} %add.3584, f16[2048]{0} %add.3588, f16[2048]{0} %add.3592, f16[2048]{0} %add.3596, f16[2048,2048]{1,0} %add.3600, /*index=10*/f16[6144]{0} %add.3604, f16[2048,6144]{1,0} %add.3608, f16[8192]{0} %add.3612, f16[2048,8192]{1,0} %add.3616, f16[2048]{0} %add.3620, /*index=15*/f16[2048]{0} %add.3624, f16[2048]{0} %add.3628, f16[8192,2048]{1,0} %add.3632, f16[2048]{0} %add.3636, f16[2048]{0} %add.3640, /*index=20*/f16[2048]{0} %add.3644, f16[2048,2048]{1,0} %add.3648, f16[6144]{0} %add.3652, f16[2048,6144]{1,0} %add.3656, f16[8192]{0} %add.3660, /*index=25*/f16[2048,8192]{1,0} %add.3664, f16[2048]{0} %add.3668, f16[2048]{0} %add.3672, f16[2048]{0} %add.3676, f16[8192,2048]{1,0} %add.3680, /*index=30*/s32[] %select.1, f32[51200]{0} %add.2463, f32[2048]{0} %add.2471, f32[2048]{0} %add.2479, f32[1024,2048]{1,0} %add.2487, /*index=35*/f32[51200,2048]{1,0} %add.2495, f32[2048]{0} %add.2503, f32[2048]{0} %add.2511, f32[2048]{0} %add.2519, f32[2048,2048]{1,0} %add.2527, /*index=40*/f32[6144]{0} %add.2535, f32[2048,6144]{1,0} %add.2543, f32[8192]{0} %add.2551, f32[2048,8192]{1,0} %add.2559, f32[2048]{0} %add.2567, /*index=45*/f32[2048]{0} %add.2575, f32[2048]{0} 
%add.2583, f32[8192,2048]{1,0} %add.2591, f32[2048]{0} %add.2599, f32[2048]{0} %add.2607, /*index=50*/f32[2048]{0} %add.2615, f32[2048,2048]{1,0} %add.2623, f32[6144]{0} %add.2631, f32[2048,6144]{1,0} %add.2639, f32[8192]{0} %add.2647, /*index=55*/f32[2048,8192]{1,0} %add.2655, f32[2048]{0} %add.2663, f32[2048]{0} %add.2671, f32[2048]{0} %add.2679, f32[8192,2048]{1,0} %add.2687, /*index=60*/f32[51200]{0} %add.2696, f32[2048]{0} %add.2705, f32[2048]{0} %add.2714, f32[1024,2048]{1,0} %add.2723, f32[51200,2048]{1,0} %add.2732, /*index=65*/f32[2048]{0} %add.2741, f32[2048]{0} %add.2750, f32[2048]{0} %add.2759, f32[2048,2048]{1,0} %add.2768, f32[6144]{0} %add.2777, /*index=70*/f32[2048,6144]{1,0} %add.2786, f32[8192]{0} %add.2795, f32[2048,8192]{1,0} %add.2804, f32[2048]{0} %add.2813, f32[2048]{0} %add.2822, /*index=75*/f32[2048]{0} %add.2831, f32[8192,2048]{1,0} %add.2840, f32[2048]{0} %add.2849, f32[2048]{0} %add.2858, f32[2048]{0} %add.2867, /*index=80*/f32[2048,2048]{1,0} %add.2876, f32[6144]{0} %add.2885, f32[2048,6144]{1,0} %add.2894, f32[8192]{0} %add.2903, f32[2048,8192]{1,0} %add.2912, /*index=85*/f32[2048]{0} %add.2921, f32[2048]{0} %add.2930, f32[2048]{0} %add.2939, f32[8192,2048]{1,0} %add.2948, f32[51200]{0} %add.3507, /*index=90*/f32[2048]{0} %add.3508, f32[2048]{0} %add.3509, f32[1024,2048]{1,0} %add.3510, f32[51200,2048]{1,0} %add.3511, f32[2048]{0} %add.3512, /*index=95*/f32[2048]{0} %add.3513, f32[2048]{0} %add.3514, f32[2048,2048]{1,0} %add.3515, f32[6144]{0} %add.3516, f32[2048,6144]{1,0} %add.3517, /*index=100*/f32[8192]{0} %add.3518, f32[2048,8192]{1,0} %add.3519, f32[2048]{0} %add.3520, f32[2048]{0} %add.3521, f32[2048]{0} %add.3522, /*index=105*/f32[8192,2048]{1,0} %add.3523, f32[2048]{0} %add.3524, f32[2048]{0} %add.3525, f32[2048]{0} %add.3526, f32[2048,2048]{1,0} %add.3527, /*index=110*/f32[6144]{0} %add.3528, f32[2048,6144]{1,0} %add.3529, f32[8192]{0} %add.3530, f32[2048,8192]{1,0} %add.3531, f32[2048]{0} %add.3532, /*index=115*/f32[2048]{0} %add.3533, f32[2048]{0} %add.3534, f32[8192,2048]{1,0} %add.3535), sharding={{replicated}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=5*/{devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=10*/{devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, /*index=15*/{replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, /*index=20*/{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=25*/{devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, 
/*index=30*/{replicated}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=35*/{devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=40*/{devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, /*index=45*/{replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, /*index=50*/{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=55*/{devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=60*/{devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=65*/{replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=70*/{devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, /*index=75*/{replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, /*index=80*/{devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=85*/{replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=90*/{replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, /*index=95*/{replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, 
/*index=100*/{devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, /*index=105*/{devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, {replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, /*index=110*/{devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {devices=[1,8,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}, {replicated}, /*index=115*/{replicated}, {replicated}, {devices=[8,1,2]0,8,1,9,2,10,3,11,4,12,5,13,6,14,7,15 last_tile_dim_replicate}}
}
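
A minimal sketch (not part of the dump) of the optimizer math each parameter block above
encodes, reconstructed from the optax source locations in the op metadata (transform.py:81,
:87, :302, :328, :571; update.py:43). The hyperparameter names (lr, b1, b2, eps, wd) and the
step count t are assumptions; the dump additionally computes the (1 - b) * g moment terms in
f16 before accumulating in f32, and afterwards refreshes an f16 copy of each parameter
(alpa model_util.py:327, :334), presumably as part of mixed-precision bookkeeping.

import jax.numpy as jnp

def adamw_step(param, grad, mu, nu, t, lr=1e-3, b1=0.9, b2=0.999, eps=1e-8, wd=0.01):
    # First/second moment updates (transform.py:81).
    mu = (1 - b1) * grad + b1 * mu
    nu = (1 - b2) * jnp.square(grad) + b2 * nu
    # Second-moment bias correction (transform.py:87) and the preconditioned step
    # (transform.py:302); the first-moment bias correction is folded into the
    # denominator, matching the multiply-before-divide pattern in the dump.
    nu_hat = nu / (1 - b2 ** t)
    update = mu / ((1 - b1 ** t) * (jnp.sqrt(nu_hat) + eps))
    # Decoupled weight decay, applied only to the 2-D parameters in the dump
    # (transform.py:571), then step-size scaling (transform.py:328) and the
    # parameter apply (update.py:43).
    update = -lr * (update + wd * param)
    return param + update, mu, nu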