Last active
August 19, 2018 16:31
-
-
Save bgobbi/042bbee9b0bedf70995f88c6ce3a9208 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import torch | |
import sys | |
tin = [ 0, 0, 4, 6, 6, 8, 11, 12, 13, 14, 14, 14, 15, 15, 15, 19, 24, 25, 25, | |
725, 726, 726, 727, 727, 727, 728, 728, 730, 731, 732, 733, 734, 735, 737, 737, 739, 739, 739, 740, 740, 742, 743, 744, 747, 748, 748, 750, 752, 753, 757, 757, 758, 758, 759, 761, 761, 762, 762, 764, 765, 766, 766, 766, 767, 767, 772, 773, 773, 774, 775, 775, 775, 776, 778, 779, 779, 779, 780, 781, 781, 781, 784, 785, 785, 786, 790, 790, 791, 791, 793, 793, 793, 795, 796, 796, 797, 798, 798, 800, 800, 801, 803, 804, 806, 807, 807, 808, 809, 809, 810, 811, 811, 812, 813, 813, 814, 814, 814, 816, 816, 821, 823, 824, 828, 828, 828, 829, 830, 830, 831, 831, 833, 836, 837, 838, 839, 839, 840, 840, 841, 843, 843, 843, 843, 844, 845, 846, 846, 846, 847, 847, 848, 849, 853, 854, 854, 857, 859, 862, 864, 864, 867, 869, 869, 872, 874, 875, 876, 876, 877, 880, 880, 881, 881, 882, 882, 883, 884, 884, 884, 885, 886, 887, 889, 890, 890, 891, 893, 896, 897, 898, 898, 898, 899, 900, 900, 901, 902, 904, 905, 906, 908, 908, 909, 910, 910, 910, 911, 913, 914, 915, 915, 916, 916, 917, 917, 918, 918, 921, 921, 922, 923, 924, 925, | |
926, 928, 928, 928, 929, 929, 930, 931, 931, 931, 933, 934, 934, 935, 935, 936, 937, 938, 939, 940, 941, 941, 942, 943, 943, 944, 944, 945, 945, 948, 950, 950, 950, 951, 952, 952, 954, 955, 956, 957, 957, 959, 959, 960, 961, 962, 963, 964, 964, 966, 966, 966, 967, 967, 968, 969, 969, 971, 971, 973, 974, 975, 975, 975, 976, 978, 978, 979, 980, 983, 983, 984, 985, 986, 986, 987, 988, 988, 988, 991, 991, 992, 993, 993, 994, 994, 996, 997, 997, 998, 999, 999, 1001, 1001, 1002, 1002, 1003, 1003, 1004, 1005, 1006, 1007, 1008, 1010, 1011, 1013, 1013, 1016, 1016, 1018, 1018, 1021, 1023, 1023, 1024, 1024, 1025, | |
1026, 1026, 1027, 1027, 1028, 1030, 1031, 1032, 1033, 1033, 1033, 1034, 1034, 1036, 1037, 1037, 1037, 1038, 1038, 1039, 1039, 1039, 1040, 1041, 1042, 1044, 1045, 1047, 1048, 1049, 1049, 1050, 1050, 1051, 1053, 1054, 1055, 1057, 1057, 1058, 1059, 1059, 1062, 1064, 1064, 1068, 1070, 1071, 1071, 1073, 1073, 1074, 1074, 1075, 1075, 1076, 1077, 1080, 1081, 1082, 1083, 1084, 1086, 1088, 1088, 1089, 1089, 1089, 1091, 1091, 1092, 1093, 1094, 1094, 1095, 1095, 1096, 1097, 1098, 1098, 1099, 1099, 1100, 1103, 1103, 1104, 1104, 1105, 1106, 1107, 1108, 1108, 1109, 1111, 1111, 1112, 1113, 1114, 1114, 1114, 1115, 1117, 1117, 1118, 1119, 1119, 1120, 1120, 1120, 1121, 1123, 1123, 1123, 1123, 1124, 1124, 1124, 1124, 1126, 1127, 1129, 1129, 1132, 1134, 1137, 1138, 1139, 1140, 1141, 1143, 1144, 1147, 1147, 1149, 1149, 1151, 1154, 1156, 1158, 1159, 1159, 1161, 1162, 1162, 1163, 1164, 1164, 1166, 1166, 1167, 1169, 1169, 1170, 1170, 1171, 1171, 1172, 1175, 1176, 1177, 1178, 1179, 1179, 1182, 1184, 1190, 1190, 1192, 1193, 1193, 1194, 1195, 1196, 1197, 1197, 1198, 1199, 1201, 1201, 1202, 1203, 1204, 1206, 1207, 1208, 1209, 1210, 1210, 1212, 1213, 1213, 1216, 1218, 1219, 1220, 1221, 1222, 1223, 1223, 1224, 1227, 1228, 1228, 1230, 1230, 1234, 1234, 1236, 1238, 1238, 1240, 1241, 1242, 1242, 1244, 1246, 1247, 1249, 1250, 1251, 1251, 1252, 1253, 1254, 1256, 1257, 1257, 1258, 1258, 1259, 1262, 1262, 1262, 1263, 1265, 1265, 1266, 1266, 1267, 1267, 1269, 1272, 1273, 1274, 1274, 1275, 1275, 1276, 1277, 1277, 1278, 1278, 1279, 1279, 1280, 1280, 1283, 1284, 1284, 1286, 1286, 1287, 1287, 1288, 1288, 1288, 1289, 1289, 1291, 1292, 1292, 1294, 1295, 1296, 1297, 1298, 1299, 1299, 1301, 1304, 1304, 1305, 1307, 1308, 1311, 1312, 1312, 1313, 1317, 1318, 1319, 1319, 1320, 1321, 1322, 1323, 1324, 1324, 1325, | |
1325, 1325, 1326, 1326, 1327, 1329, 1330, 1331, 1332, 1332, 1334, 1334, 1336, 1336, 1338, 1338, 1338, 1339, 1341, 1341, 1341, 1342, 1343, 1343, 1343, 1343, 1344, 1346, 1346, 1349, 1349, 1352, 1352, 1355, 1355, 1356, 1357, 1358, 1361, 1362, 1362, 1364, 1364, 1364, 1366, 1366, 1369, 1370, 1370, 1371, 1374, 1375, 1375, 1376, 1376, 1377, 1377, 1377, 1378, 1380, 1381, 1382, 1383, 1384, 1385, 1385, 1385, 1387, 1388, 1388, 1390, 1391, 1391, 1393, 1394, 1396, 1398, 1399, 1400, 1401, 1404, 1405, 1407, 1410, 1410, 1411, 1414, 1414, 1415, 1416, 1416, 1417, 1417, 1418, 1418, 1421, 1422, 1423, 1427, 1427, 1428, 1431, 1433, 1433, 1434, 1434, 1436, 1436, 1437, 1437, 1438, 1439, 1439, 1441, 1441, 1442, 1442, 1443, 1446, 1447, 1449, 1450, 1451, 1452, 1454, 1455, 1456, 1457, 1457, 1458, 1459, 1463, 1464, 1465, 1465, 1466, 1467, 1467, 1469, 1470, 1471, 1472, 1472, 1473, 1475, 1478, 1478, 1480, 1481, 1483, 1485, 1487, 1487, 1488, 1488, 1489, 1490, 1492, 1492, 1494, 1495, 1495, 1496, 1497, 1498, 1499, 1501, 1501, 1502, 1502, 1503, 1503, 1503, 1504, 1504, 1505, 1505, 1505, 1506, 1506, 1507, 1508, 1509, 1510, 1511, 1515, 1516, 1516, 1517, 1518, 1519, 1520, 1520, 1522, 1523, 1524, 1524, 1524, 1525, | |
1527, 1527, 1530, 1531, 1536, 1537, 1538, 1538, 1538, 1539, 1539, 1540, 1541, 1541, 1543, 1545, 1545, 1546, 1548, 1549, 1550, 1552, 1552, 1553, 1553, 1556, 1557, 1558, 1559, 1561, 1561, 1562, 1563, 1564, 1565, 1565, 1566, 1568, 1569, 1569, 1570, 1570, 1571, 1572, 1572, 1573, 1574, 1574, 1577, 1578, 1578, 1579, 1580, 1581, 1581, 1582, 1583, 1583, 1584, 1585, 1586, 1587, 1589, 1589, 1590, 1590, 1592, 1592, 1595, 1596, 1596, 1597, 1598, 1600, 1601, 1602, 1602, 1603, 1604, 1605, 1609, 1609, 1610, 1613, 1614, 1615, 1616, 1617, 1618, 1620, 1620, 1621, 1622, 1623, 1625, | |
1628, 1628, 1630, 1631, 1632, 1633, 1633, 1635, 1638, 1639, 1640, 1642, 1643, 1644, 1646, 1647, 1647, 1649, 1649, 1651, 1652, 1652, 1652, 1654, 1655, 1656, 1656, 1657, 1658, 1658, 1661, 1661, 1662, 1663, 1663, 1664, 1664, 1665, 1665, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1671, 1672, 1673, 1676, 1677, 1678, 1679, 1680, 1681, 1683, 1685, 1685, 1687, 1687, 1688, 1691, 1691, 1692, 1693, 1694, 1694, 1696, 1696, 1698, 1698, 1700, 1700, 1701, 1701, 1701, 1703, 1704, 1704, 1706, 1707, 1708, 1709, 1711, 1712, 1712, 1715, 1716, 1716, 1717, 1718, 1718, 1719, 1719, 1720, 1722, 1724, 1725, | |
1726, 1727, 1728, 1730, 1735, 1740, 1741, 1744, 1744, 1745, 1745, 1747, 1748, 1750, 1751, 1752, 1752, 1754, 1756, 1756, 1759, 1759, 1760, 1760, 1763, 1764, 1765, 1765, 1766, 1767, 1769, 1769, 1772, 1773, 1773, 1774, 1775, 1776, 1776, 1777, 1779, 1780, 1780, 1781, 1781, 1783, 1784, 1784, 1785, 1785, 1786, 1788, 1788, 1789, 1790, 1790, 1790, 1791, 1794, 1794, 1796, 1796, 1796, 1797, 1797, 1798, 1798, 1799, 1800, 1800, 1800, 1802, 1802, 1804, 1805, 1805, 1806, 1808, 1809, 1809, 1812, 1813, 1814, 1816, 1816, 1819, 1819, 1820, 1823, 1824, 1826, 1827, 1827, 1828, 1828, 1829, 1830, 1831, 1831, 1831, 1833, 1833, 1834, 1836, 1838, 1839, 1839, 1839, 1840, 1842, 1844, 1847, 1848, 1853, 1853, 1854, 1861, 1862, 1865, 1866, 1869, 1869, 1870, 1871, 1872, 1874, 1876, 1877, 1878, 1879, 1880, 1882, 1882, 1882, 1883, 1883, 1884, 1884, 1885, 1885, 1886, 1888, 1890, 1891, 1893, 1893, 1894, 1894, 1894, 1895, 1895, 1897, 1898, 1900, 1900, 1901, 1901, 1902, 1903, 1906, 1908, 1910, 1911, 1912, 1912, 1917, 1917, 1917, 1918, 1919, 1919, 1920, 1920, 1921, 1921, 1922, 1922, 1924, 1925, | |
1925, 1926, 1928, 1932, 1933, 1934, 1934, 1935, 1935, 1936, 1937, 1938, 1938, 1939, 1940, 1941, 1941, 1943, 1944, 1947, 1948, 1948, 1948, 1949, 1951, 1954, 1955, 1957, 1959, 1960, 1967, 1970, 1971, 1973, 1974, 1984, 1985, 1985, 1986, 1987, 1987, 1988, 1989, 1990, 1991, 1992, 1992, 1993, 1994, 1994, 1997, 1999, 1999, 2001, 2002, 2003, 2003, 2004, 2005, 2005, 2007, 2008, 2008, 2009, 2011, 2011, 2013, 2015, 2015, 2016, 2017, 2018, 2020, 2020, 2021, 2022, 2024, 2027, 2029, 2031, 2032, 2034, 2034, 2035, 2036, 2037, 2038, 2038, 2039, 2044, 2045, 2045, 2046, 2048, 2048, 2053, 2053, 2055, 2058, 2059, 2059, 2060, 2062, 2066, 2069, 2070, 2071, 2075, 2075, 2076, 2077, 2078, 2085, 2086, 2089, 2091, 2092, 2096, 2099, 2102, 2102, 2105, 2107, 2108, 2110, 2110, 2112, 2113, 2114, 2114, 2118, 2124, 2125, | |
2127, 2128, 2129, 2134, 2134, 2135, 2136, 2143, 2144, 2145, 2149, 2150, 2154, 2156, 2158, 2159, 2160, 2161, 2163, 2164, 2164, 2165, 2166, 2167, 2169, 2170, 2173, 2174, 2174, 2175, 2176, 2181, 2182, 2184, 2190, 2191, 2194, 2196, 2197, 2197, 2198, 2198, 2199, 2200, 2201, 2202, 2205, 2208, 2212, 2215, 2215, 2216, 2217, 2217, 2219, 2219, 2220, 2220, 2221, 2221, 2225, | |
2226, 2227, 2227, 2228, 2230, 2231, 2233, 2234, 2238, 2238, 2240, 2246, 2246, 2247, 2250, 2251, 2253, 2255, 2256, 2259, 2261, 2266, 2267, 2268, 2269, 2269, 2270, 2270, 2274, 2274, 2281, 2283, 2286, 2288, 2290, 2291, 2292, 2293, 2295, 2295, 2297, 2300, 2301, 2301, 2303, 2303, 2304, 2308, 2309, 2310, 2311, 2313, 2315, 2315, 2316, 2317, 2320, 2321, 2325, | |
2325, 2327, 2328, 2329, 2330, 2330, 2335, 2339, 2341, 2341, 2342, 2342, 2342, 2344, 2344, 2345, 2345, 2346, 2347, 2348, 2351, 2353, 2355, 2357, 2359, 2360, 2368, 2370, 2372, 2373, 2375, 2379, 2382, 2385, 2386, 2387, 2390, 2390, 2391, 2393, 2397, 2402, 2403, 2403, 2404, 2405, 2408, 2409, 2410, 2414, 2414, 2415, 2416, 2418, 2420, 2421, 2424, 2426, 2428, 2429, 2430, 2431, 2433, 2434, 2435, 2436, 2440, 2443, 2443, 2446, 2446, 2447, 2448, 2449, 2450, 2450, 2453, 2454, 2456, 2458, 2460, 2462, 2463, 2474, 2474, 2480, 2481, 2484, 2486, 2487, 2490, 2493, 2494, 2494, 2497, 2501, 2502, 2503, 2503, 2505, 2508, 2508, 2513, 2513, 2514, 2518, 2524, 2525, | |
2527, 2529, 2531, 2543, 2554] | |
def firstLast(tnsr,cFi, cLa): | |
tnsrShiftedDown = torch.empty_like(tnsr) | |
tnsrShiftedDown[1:] = tnsr[:-1] | |
tnsrShiftedDown[0] = -1 # shifted | |
isFirst = tnsrShiftedDown != tnsr # isFirst(!=) | |
####################################################################### | |
# commenting out this line makes the coda return the correct result | |
cs = isFirst.cumsum(0) | |
isLast = isFirst.clone(); | |
assert isLast[0] == 1, "First element of first group shoul always be [0]" | |
isLast[:-1] = isLast[1:]; | |
isLast[-1] = 1 # isLast | |
# isFirst is index into first Element of group and isLast into last | |
# they should have equal number of ones | |
if isLast.sum().cpu().item() != isFirst.sum().cpu().item(): | |
print("countOnes isFirst=%s isLast=%s" %(isFirst.sum().cpu().item(),isLast.sum().cpu().item())) | |
if isinstance(cFi, type(isFirst)): | |
diff = isFirst-cFi | |
print("No mismatches in isFirst %s %s %s" % ( diff.argmax(), diff.argmin(), diff.sum())) | |
diff = isLast-cLa | |
print("pos of mismatch in isLast %s %s %s" % ( diff.argmax(), diff.argmin(), diff.sum())) | |
print("") | |
elif not isinstance(cFi, type(isFirst)): | |
cFi = isFirst.clone() | |
cLa = isLast.clone() | |
return isFirst, isLast, cFi, cLa | |
def test(device,tin): | |
tin = torch.tensor(tin, device=device, dtype=torch.long) | |
cFi = None | |
cLa = None | |
cntEq = 0 | |
cntNEq = 0 | |
for i in range(0,20000): | |
fi, la, cFi, cLa = firstLast(tin, cFi, cLa) | |
if fi.sum().cpu().item() == la.sum().cpu().item(): | |
cntEq += 1 | |
else: | |
cntNEq += 1 | |
return cntEq, cntNEq | |
devices = [torch.device("cpu")] | |
if torch.cuda.is_available(): devices.append(torch.device("cuda")) | |
for dev in devices: | |
print("%s %s" % (dev, test(dev, tin))) |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
The code in this gist does some indexing and simple arithmetic with the exact same inputs 20 000 times.
It is supposed to find the index of the first and last elements in groups of consecutive elements.
When run on a cpu it always returns the same correct result.
However when run on cuda it returns a wrong result 4-20 times out of 20 000 iterations.
I paired the code down as much as I could, however here are a few points:
Could this be related to me moving elements onto overlapping regions of the tensor (lines 20 and 31)?
Any other ideas?
A typical result I see is below.
The first row shows that the CPU version correctly 20000/20000 times.
The last row shows that the GPU version gave incorrect results 6 out of 20000 times