@bgobbi
Last active August 19, 2018 16:31
import torch
import sys
tin = [ 0, 0, 4, 6, 6, 8, 11, 12, 13, 14, 14, 14, 15, 15, 15, 19, 24, 25, 25,
725, 726, 726, 727, 727, 727, 728, 728, 730, 731, 732, 733, 734, 735, 737, 737, 739, 739, 739, 740, 740, 742, 743, 744, 747, 748, 748, 750, 752, 753, 757, 757, 758, 758, 759, 761, 761, 762, 762, 764, 765, 766, 766, 766, 767, 767, 772, 773, 773, 774, 775, 775, 775, 776, 778, 779, 779, 779, 780, 781, 781, 781, 784, 785, 785, 786, 790, 790, 791, 791, 793, 793, 793, 795, 796, 796, 797, 798, 798, 800, 800, 801, 803, 804, 806, 807, 807, 808, 809, 809, 810, 811, 811, 812, 813, 813, 814, 814, 814, 816, 816, 821, 823, 824, 828, 828, 828, 829, 830, 830, 831, 831, 833, 836, 837, 838, 839, 839, 840, 840, 841, 843, 843, 843, 843, 844, 845, 846, 846, 846, 847, 847, 848, 849, 853, 854, 854, 857, 859, 862, 864, 864, 867, 869, 869, 872, 874, 875, 876, 876, 877, 880, 880, 881, 881, 882, 882, 883, 884, 884, 884, 885, 886, 887, 889, 890, 890, 891, 893, 896, 897, 898, 898, 898, 899, 900, 900, 901, 902, 904, 905, 906, 908, 908, 909, 910, 910, 910, 911, 913, 914, 915, 915, 916, 916, 917, 917, 918, 918, 921, 921, 922, 923, 924, 925,
926, 928, 928, 928, 929, 929, 930, 931, 931, 931, 933, 934, 934, 935, 935, 936, 937, 938, 939, 940, 941, 941, 942, 943, 943, 944, 944, 945, 945, 948, 950, 950, 950, 951, 952, 952, 954, 955, 956, 957, 957, 959, 959, 960, 961, 962, 963, 964, 964, 966, 966, 966, 967, 967, 968, 969, 969, 971, 971, 973, 974, 975, 975, 975, 976, 978, 978, 979, 980, 983, 983, 984, 985, 986, 986, 987, 988, 988, 988, 991, 991, 992, 993, 993, 994, 994, 996, 997, 997, 998, 999, 999, 1001, 1001, 1002, 1002, 1003, 1003, 1004, 1005, 1006, 1007, 1008, 1010, 1011, 1013, 1013, 1016, 1016, 1018, 1018, 1021, 1023, 1023, 1024, 1024, 1025,
1026, 1026, 1027, 1027, 1028, 1030, 1031, 1032, 1033, 1033, 1033, 1034, 1034, 1036, 1037, 1037, 1037, 1038, 1038, 1039, 1039, 1039, 1040, 1041, 1042, 1044, 1045, 1047, 1048, 1049, 1049, 1050, 1050, 1051, 1053, 1054, 1055, 1057, 1057, 1058, 1059, 1059, 1062, 1064, 1064, 1068, 1070, 1071, 1071, 1073, 1073, 1074, 1074, 1075, 1075, 1076, 1077, 1080, 1081, 1082, 1083, 1084, 1086, 1088, 1088, 1089, 1089, 1089, 1091, 1091, 1092, 1093, 1094, 1094, 1095, 1095, 1096, 1097, 1098, 1098, 1099, 1099, 1100, 1103, 1103, 1104, 1104, 1105, 1106, 1107, 1108, 1108, 1109, 1111, 1111, 1112, 1113, 1114, 1114, 1114, 1115, 1117, 1117, 1118, 1119, 1119, 1120, 1120, 1120, 1121, 1123, 1123, 1123, 1123, 1124, 1124, 1124, 1124, 1126, 1127, 1129, 1129, 1132, 1134, 1137, 1138, 1139, 1140, 1141, 1143, 1144, 1147, 1147, 1149, 1149, 1151, 1154, 1156, 1158, 1159, 1159, 1161, 1162, 1162, 1163, 1164, 1164, 1166, 1166, 1167, 1169, 1169, 1170, 1170, 1171, 1171, 1172, 1175, 1176, 1177, 1178, 1179, 1179, 1182, 1184, 1190, 1190, 1192, 1193, 1193, 1194, 1195, 1196, 1197, 1197, 1198, 1199, 1201, 1201, 1202, 1203, 1204, 1206, 1207, 1208, 1209, 1210, 1210, 1212, 1213, 1213, 1216, 1218, 1219, 1220, 1221, 1222, 1223, 1223, 1224, 1227, 1228, 1228, 1230, 1230, 1234, 1234, 1236, 1238, 1238, 1240, 1241, 1242, 1242, 1244, 1246, 1247, 1249, 1250, 1251, 1251, 1252, 1253, 1254, 1256, 1257, 1257, 1258, 1258, 1259, 1262, 1262, 1262, 1263, 1265, 1265, 1266, 1266, 1267, 1267, 1269, 1272, 1273, 1274, 1274, 1275, 1275, 1276, 1277, 1277, 1278, 1278, 1279, 1279, 1280, 1280, 1283, 1284, 1284, 1286, 1286, 1287, 1287, 1288, 1288, 1288, 1289, 1289, 1291, 1292, 1292, 1294, 1295, 1296, 1297, 1298, 1299, 1299, 1301, 1304, 1304, 1305, 1307, 1308, 1311, 1312, 1312, 1313, 1317, 1318, 1319, 1319, 1320, 1321, 1322, 1323, 1324, 1324, 1325,
1325, 1325, 1326, 1326, 1327, 1329, 1330, 1331, 1332, 1332, 1334, 1334, 1336, 1336, 1338, 1338, 1338, 1339, 1341, 1341, 1341, 1342, 1343, 1343, 1343, 1343, 1344, 1346, 1346, 1349, 1349, 1352, 1352, 1355, 1355, 1356, 1357, 1358, 1361, 1362, 1362, 1364, 1364, 1364, 1366, 1366, 1369, 1370, 1370, 1371, 1374, 1375, 1375, 1376, 1376, 1377, 1377, 1377, 1378, 1380, 1381, 1382, 1383, 1384, 1385, 1385, 1385, 1387, 1388, 1388, 1390, 1391, 1391, 1393, 1394, 1396, 1398, 1399, 1400, 1401, 1404, 1405, 1407, 1410, 1410, 1411, 1414, 1414, 1415, 1416, 1416, 1417, 1417, 1418, 1418, 1421, 1422, 1423, 1427, 1427, 1428, 1431, 1433, 1433, 1434, 1434, 1436, 1436, 1437, 1437, 1438, 1439, 1439, 1441, 1441, 1442, 1442, 1443, 1446, 1447, 1449, 1450, 1451, 1452, 1454, 1455, 1456, 1457, 1457, 1458, 1459, 1463, 1464, 1465, 1465, 1466, 1467, 1467, 1469, 1470, 1471, 1472, 1472, 1473, 1475, 1478, 1478, 1480, 1481, 1483, 1485, 1487, 1487, 1488, 1488, 1489, 1490, 1492, 1492, 1494, 1495, 1495, 1496, 1497, 1498, 1499, 1501, 1501, 1502, 1502, 1503, 1503, 1503, 1504, 1504, 1505, 1505, 1505, 1506, 1506, 1507, 1508, 1509, 1510, 1511, 1515, 1516, 1516, 1517, 1518, 1519, 1520, 1520, 1522, 1523, 1524, 1524, 1524, 1525,
1527, 1527, 1530, 1531, 1536, 1537, 1538, 1538, 1538, 1539, 1539, 1540, 1541, 1541, 1543, 1545, 1545, 1546, 1548, 1549, 1550, 1552, 1552, 1553, 1553, 1556, 1557, 1558, 1559, 1561, 1561, 1562, 1563, 1564, 1565, 1565, 1566, 1568, 1569, 1569, 1570, 1570, 1571, 1572, 1572, 1573, 1574, 1574, 1577, 1578, 1578, 1579, 1580, 1581, 1581, 1582, 1583, 1583, 1584, 1585, 1586, 1587, 1589, 1589, 1590, 1590, 1592, 1592, 1595, 1596, 1596, 1597, 1598, 1600, 1601, 1602, 1602, 1603, 1604, 1605, 1609, 1609, 1610, 1613, 1614, 1615, 1616, 1617, 1618, 1620, 1620, 1621, 1622, 1623, 1625,
1628, 1628, 1630, 1631, 1632, 1633, 1633, 1635, 1638, 1639, 1640, 1642, 1643, 1644, 1646, 1647, 1647, 1649, 1649, 1651, 1652, 1652, 1652, 1654, 1655, 1656, 1656, 1657, 1658, 1658, 1661, 1661, 1662, 1663, 1663, 1664, 1664, 1665, 1665, 1665, 1666, 1667, 1668, 1669, 1670, 1671, 1671, 1672, 1673, 1676, 1677, 1678, 1679, 1680, 1681, 1683, 1685, 1685, 1687, 1687, 1688, 1691, 1691, 1692, 1693, 1694, 1694, 1696, 1696, 1698, 1698, 1700, 1700, 1701, 1701, 1701, 1703, 1704, 1704, 1706, 1707, 1708, 1709, 1711, 1712, 1712, 1715, 1716, 1716, 1717, 1718, 1718, 1719, 1719, 1720, 1722, 1724, 1725,
1726, 1727, 1728, 1730, 1735, 1740, 1741, 1744, 1744, 1745, 1745, 1747, 1748, 1750, 1751, 1752, 1752, 1754, 1756, 1756, 1759, 1759, 1760, 1760, 1763, 1764, 1765, 1765, 1766, 1767, 1769, 1769, 1772, 1773, 1773, 1774, 1775, 1776, 1776, 1777, 1779, 1780, 1780, 1781, 1781, 1783, 1784, 1784, 1785, 1785, 1786, 1788, 1788, 1789, 1790, 1790, 1790, 1791, 1794, 1794, 1796, 1796, 1796, 1797, 1797, 1798, 1798, 1799, 1800, 1800, 1800, 1802, 1802, 1804, 1805, 1805, 1806, 1808, 1809, 1809, 1812, 1813, 1814, 1816, 1816, 1819, 1819, 1820, 1823, 1824, 1826, 1827, 1827, 1828, 1828, 1829, 1830, 1831, 1831, 1831, 1833, 1833, 1834, 1836, 1838, 1839, 1839, 1839, 1840, 1842, 1844, 1847, 1848, 1853, 1853, 1854, 1861, 1862, 1865, 1866, 1869, 1869, 1870, 1871, 1872, 1874, 1876, 1877, 1878, 1879, 1880, 1882, 1882, 1882, 1883, 1883, 1884, 1884, 1885, 1885, 1886, 1888, 1890, 1891, 1893, 1893, 1894, 1894, 1894, 1895, 1895, 1897, 1898, 1900, 1900, 1901, 1901, 1902, 1903, 1906, 1908, 1910, 1911, 1912, 1912, 1917, 1917, 1917, 1918, 1919, 1919, 1920, 1920, 1921, 1921, 1922, 1922, 1924, 1925,
1925, 1926, 1928, 1932, 1933, 1934, 1934, 1935, 1935, 1936, 1937, 1938, 1938, 1939, 1940, 1941, 1941, 1943, 1944, 1947, 1948, 1948, 1948, 1949, 1951, 1954, 1955, 1957, 1959, 1960, 1967, 1970, 1971, 1973, 1974, 1984, 1985, 1985, 1986, 1987, 1987, 1988, 1989, 1990, 1991, 1992, 1992, 1993, 1994, 1994, 1997, 1999, 1999, 2001, 2002, 2003, 2003, 2004, 2005, 2005, 2007, 2008, 2008, 2009, 2011, 2011, 2013, 2015, 2015, 2016, 2017, 2018, 2020, 2020, 2021, 2022, 2024, 2027, 2029, 2031, 2032, 2034, 2034, 2035, 2036, 2037, 2038, 2038, 2039, 2044, 2045, 2045, 2046, 2048, 2048, 2053, 2053, 2055, 2058, 2059, 2059, 2060, 2062, 2066, 2069, 2070, 2071, 2075, 2075, 2076, 2077, 2078, 2085, 2086, 2089, 2091, 2092, 2096, 2099, 2102, 2102, 2105, 2107, 2108, 2110, 2110, 2112, 2113, 2114, 2114, 2118, 2124, 2125,
2127, 2128, 2129, 2134, 2134, 2135, 2136, 2143, 2144, 2145, 2149, 2150, 2154, 2156, 2158, 2159, 2160, 2161, 2163, 2164, 2164, 2165, 2166, 2167, 2169, 2170, 2173, 2174, 2174, 2175, 2176, 2181, 2182, 2184, 2190, 2191, 2194, 2196, 2197, 2197, 2198, 2198, 2199, 2200, 2201, 2202, 2205, 2208, 2212, 2215, 2215, 2216, 2217, 2217, 2219, 2219, 2220, 2220, 2221, 2221, 2225,
2226, 2227, 2227, 2228, 2230, 2231, 2233, 2234, 2238, 2238, 2240, 2246, 2246, 2247, 2250, 2251, 2253, 2255, 2256, 2259, 2261, 2266, 2267, 2268, 2269, 2269, 2270, 2270, 2274, 2274, 2281, 2283, 2286, 2288, 2290, 2291, 2292, 2293, 2295, 2295, 2297, 2300, 2301, 2301, 2303, 2303, 2304, 2308, 2309, 2310, 2311, 2313, 2315, 2315, 2316, 2317, 2320, 2321, 2325,
2325, 2327, 2328, 2329, 2330, 2330, 2335, 2339, 2341, 2341, 2342, 2342, 2342, 2344, 2344, 2345, 2345, 2346, 2347, 2348, 2351, 2353, 2355, 2357, 2359, 2360, 2368, 2370, 2372, 2373, 2375, 2379, 2382, 2385, 2386, 2387, 2390, 2390, 2391, 2393, 2397, 2402, 2403, 2403, 2404, 2405, 2408, 2409, 2410, 2414, 2414, 2415, 2416, 2418, 2420, 2421, 2424, 2426, 2428, 2429, 2430, 2431, 2433, 2434, 2435, 2436, 2440, 2443, 2443, 2446, 2446, 2447, 2448, 2449, 2450, 2450, 2453, 2454, 2456, 2458, 2460, 2462, 2463, 2474, 2474, 2480, 2481, 2484, 2486, 2487, 2490, 2493, 2494, 2494, 2497, 2501, 2502, 2503, 2503, 2505, 2508, 2508, 2513, 2513, 2514, 2518, 2524, 2525,
2527, 2529, 2531, 2543, 2554]
def firstLast(tnsr, cFi, cLa):
    tnsrShiftedDown = torch.empty_like(tnsr)
    tnsrShiftedDown[1:] = tnsr[:-1]
    tnsrShiftedDown[0] = -1  # shifted
    isFirst = tnsrShiftedDown != tnsr  # isFirst(!=)
    #######################################################################
    # commenting out this line makes the code return the correct result
    cs = isFirst.cumsum(0)
    isLast = isFirst.clone()
    assert isLast[0] == 1, "First element of first group should always be [0]"
    isLast[:-1] = isLast[1:]
    isLast[-1] = 1  # isLast
    # isFirst marks the first element of each group and isLast the last one;
    # they should have an equal number of ones
    if isLast.sum().cpu().item() != isFirst.sum().cpu().item():
        print("countOnes isFirst=%s isLast=%s" % (isFirst.sum().cpu().item(), isLast.sum().cpu().item()))
        if isinstance(cFi, type(isFirst)):
            # compare against the stored correct masks to locate the mismatch
            diff = isFirst - cFi
            print("No mismatches in isFirst %s %s %s" % (diff.argmax(), diff.argmin(), diff.sum()))
            diff = isLast - cLa
            print("pos of mismatch in isLast %s %s %s" % (diff.argmax(), diff.argmin(), diff.sum()))
            print("")
    elif not isinstance(cFi, type(isFirst)):
        # first consistent result: keep it as the reference copy
        cFi = isFirst.clone()
        cLa = isLast.clone()
    return isFirst, isLast, cFi, cLa
def test(device, tin):
    tin = torch.tensor(tin, device=device, dtype=torch.long)
    cFi = None
    cLa = None
    cntEq = 0
    cntNEq = 0
    for i in range(0, 20000):
        fi, la, cFi, cLa = firstLast(tin, cFi, cLa)
        if fi.sum().cpu().item() == la.sum().cpu().item():
            cntEq += 1
        else:
            cntNEq += 1
    return cntEq, cntNEq
devices = [torch.device("cpu")]
if torch.cuda.is_available():
    devices.append(torch.device("cuda"))
for dev in devices:
    print("%s %s" % (dev, test(dev, tin)))

bgobbi commented Aug 19, 2018

The code in this gist performs some indexing and simple arithmetic on exactly the same input 20,000 times.
It is supposed to mark the positions of the first and last element in each group of consecutive equal elements.
When run on the CPU it always returns the same, correct result.
When run on CUDA, however, it returns a wrong result 4 to 20 times out of 20,000 iterations.
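
For comparison, here is a minimal pure-Python sketch of what the two masks should contain. It is not part of the gist; the helper name `first_last_flags` is made up for illustration, and it assumes the input is a plain list in which equal values are adjacent.

```python
def first_last_flags(values):
    """Return (is_first, is_last) as lists of 0/1 flags for runs of equal values."""
    n = len(values)
    # An element starts a run if it is the first element or differs from its predecessor.
    is_first = [1 if i == 0 or values[i] != values[i - 1] else 0 for i in range(n)]
    # An element ends a run if it is the last element or differs from its successor.
    is_last = [1 if i == n - 1 or values[i] != values[i + 1] else 0 for i in range(n)]
    return is_first, is_last


sample = [0, 0, 4, 6, 6, 8]
is_first, is_last = first_last_flags(sample)
print(is_first)  # [1, 0, 1, 1, 0, 1]
print(is_last)   # [0, 1, 1, 0, 1, 1]
```

Every run has exactly one first and one last element, so both lists must contain the same number of ones. That is the invariant the test loop checks.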

I pared the code down as much as I could. A few observations:

  • The error is data dependent: if I delete any more of the input data than I already have, the error stops occurring
  • The `cs = isFirst.cumsum(0)` call in line 27, whose result is never used, seems to be part of the problem. Commenting it out stops the error from occurring
  • By keeping a copy of the correct result of isFirst and isLast in cFi and cLa I can see that the error is in isLast
  • By subtracting the correct from the incorrect result in lines 39 onward I can see that there is a single mismatch at position 767
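
The mismatch check could also be written as a direct comparison, which lists every differing index rather than only the argmax of a difference. This is a hedged sketch, not the gist's code: `mismatch_positions` is a hypothetical helper and the tensors below are made-up data.

```python
import torch


def mismatch_positions(result, reference):
    """Return a 1-D tensor with every index where the two tensors differ."""
    # nonzero() yields one row per differing element; flatten() gives plain indices.
    return (result != reference).nonzero().flatten()


reference = torch.tensor([1, 0, 1, 1, 0, 1])
result = torch.tensor([1, 0, 1, 0, 0, 1])  # made-up data: differs at index 3
print(mismatch_positions(result, reference).tolist())  # [3]
```

A comparison of this kind does not depend on how subtraction behaves for the dtype of the masks.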

Could this be related to me copying elements onto overlapping regions of a tensor (the shifted slice assignments in lines 20 and 31, e.g. `isLast[:-1] = isLast[1:]`)?
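
One way to test this hypothesis is to build both masks without any in-place shifted copy. The sketch below does that with `torch.cat`. It is not the gist's code and not a confirmed fix: `first_last_out_of_place` is a hypothetical name, and the function assumes a non-empty 1-D tensor.

```python
import torch


def first_last_out_of_place(tnsr):
    """Flag the first and last element of each run without in-place shifted copies."""
    # changed[i] is set where tnsr[i + 1] differs from tnsr[i], i.e. a run boundary.
    changed = tnsr[1:] != tnsr[:-1]
    edge = torch.ones(1, dtype=changed.dtype, device=tnsr.device)
    is_first = torch.cat([edge, changed])  # element 0 always starts a run
    is_last = torch.cat([changed, edge])   # the final element always ends a run
    return is_first, is_last


is_first, is_last = first_last_out_of_place(torch.tensor([0, 0, 4, 6, 6, 8]))
print(is_first.long().tolist())  # [1, 0, 1, 1, 0, 1]
print(is_last.long().tolist())   # [0, 1, 1, 0, 1, 1]
```

If the CUDA mismatches disappear with a variant like this, that would point toward the overlapping slice assignment. If they persist, the cause lies elsewhere.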

Any other ideas?

A typical result I see is below.
The first row shows that the CPU version was correct 20000 out of 20000 times.
The last row shows that the GPU version gave incorrect results 6 out of 20000 times.

cpu (20000, 0)
countOnes isFirst=1127 isLast=1128
No mismatches in isFirst tensor(0, device='cuda:0') tensor(0, device='cuda:0') tensor(0, device='cuda:0')
pos of mismatch in isLast tensor(767, device='cuda:0') tensor(0, device='cuda:0') tensor(1, device='cuda:0')

countOnes isFirst=1127 isLast=1128
No mismatches in isFirst tensor(0, device='cuda:0') tensor(0, device='cuda:0') tensor(0, device='cuda:0')
pos of mismatch in isLast tensor(767, device='cuda:0') tensor(0, device='cuda:0') tensor(1, device='cuda:0')

countOnes isFirst=1127 isLast=1128
No mismatches in isFirst tensor(0, device='cuda:0') tensor(0, device='cuda:0') tensor(0, device='cuda:0')
pos of mismatch in isLast tensor(767, device='cuda:0') tensor(0, device='cuda:0') tensor(1, device='cuda:0')

countOnes isFirst=1127 isLast=1128
No mismatches in isFirst tensor(0, device='cuda:0') tensor(0, device='cuda:0') tensor(0, device='cuda:0')
pos of mismatch in isLast tensor(767, device='cuda:0') tensor(0, device='cuda:0') tensor(1, device='cuda:0')

countOnes isFirst=1127 isLast=1128
No mismatches in isFirst tensor(0, device='cuda:0') tensor(0, device='cuda:0') tensor(0, device='cuda:0')
pos of mismatch in isLast tensor(767, device='cuda:0') tensor(0, device='cuda:0') tensor(1, device='cuda:0')

countOnes isFirst=1127 isLast=1128
No mismatches in isFirst tensor(0, device='cuda:0') tensor(0, device='cuda:0') tensor(0, device='cuda:0')
pos of mismatch in isLast tensor(767, device='cuda:0') tensor(0, device='cuda:0') tensor(1, device='cuda:0')

cuda (19994, 6)
