@sourabh2k15
Created November 11, 2022 18:58
JAX training steps progress:
0) loss = 32.684505462646484 grad_norm = 46.189605712890625
1) loss = 32.684505462646484 grad_norm = 46.189605712890625
2) loss = 32.2853889465332 grad_norm = 47.51691818237305
3) loss = 31.46957778930664 grad_norm = 49.439979553222656
4) loss = 30.260791778564453 grad_norm = 53.63009262084961
5) loss = 28.776248931884766 grad_norm = 58.055015563964844
6) loss = 27.165119171142578 grad_norm = 73.13194274902344
7) loss = 25.350372314453125 grad_norm = 83.24612426757812
8) loss = 23.403196334838867 grad_norm = 82.15689086914062
9) loss = 21.36658477783203 grad_norm = 66.725830078125
10) loss = 19.72756576538086 grad_norm = 52.60395812988281
11) loss = 18.716609954833984 grad_norm = 13.802923202514648
12) loss = 19.0362491607666 grad_norm = 39.36472702026367
13) loss = 19.368528366088867 grad_norm = 49.91731262207031
14) loss = 19.246591567993164 grad_norm = 49.75506591796875
15) loss = 18.68264389038086 grad_norm = 42.25536346435547
16) loss = 17.978076934814453 grad_norm = 27.26990509033203
17) loss = 17.526891708374023 grad_norm = 17.5865421295166
18) loss = 17.449459075927734 grad_norm = 32.81880187988281
19) loss = 17.382862091064453 grad_norm = 43.94654083251953
20) loss = 17.09641456604004 grad_norm = 47.34761428833008
21) loss = 16.57782745361328 grad_norm = 45.367431640625
22) loss = 15.897388458251953 grad_norm = 38.91270065307617
23) loss = 15.132328033447266 grad_norm = 30.798185348510742
24) loss = 14.477874755859375 grad_norm = 21.245559692382812
25) loss = 13.980663299560547 grad_norm = 21.04851531982422
26) loss = 13.48892593383789 grad_norm = 22.92694091796875
27) loss = 12.980422973632812 grad_norm = 23.357032775878906
28) loss = 12.4077730178833 grad_norm = 23.73438262939453
29) loss = 11.470919609069824 grad_norm = 44.06322479248047
30) loss = 9.974434852600098 grad_norm = 62.56569290161133
31) loss = 8.121500015258789 grad_norm = 32.62065505981445
32) loss = 7.7279953956604 grad_norm = 12.60580062866211
33) loss = 8.078917503356934 grad_norm = 20.27943229675293
34) loss = 8.428027153015137 grad_norm = 22.718721389770508
35) loss = 8.606433868408203 grad_norm = 23.325977325439453
36) loss = 8.611522674560547 grad_norm = 23.37453269958496
37) loss = 8.46786880493164 grad_norm = 23.161762237548828
38) loss = 8.200492858886719 grad_norm = 22.680795669555664
39) loss = 7.833049774169922 grad_norm = 21.734981536865234
40) loss = 7.39454984664917 grad_norm = 19.83635711669922
41) loss = 6.936933994293213 grad_norm = 15.858280181884766
42) loss = 6.5785231590271 grad_norm = 7.353372097015381
43) loss = 6.590823650360107 grad_norm = 10.804391860961914
44) loss = 6.976533889770508 grad_norm = 28.068567276000977
45) loss = 7.364764213562012 grad_norm = 39.22932052612305
46) loss = 7.52872896194458 grad_norm = 43.379493713378906
47) loss = 7.444374084472656 grad_norm = 42.099246978759766
48) loss = 7.193683624267578 grad_norm = 37.185791015625
49) loss = 6.878286361694336 grad_norm = 30.11634635925293
50) loss = 6.578378677368164 grad_norm = 22.0616512298584
51) loss = 6.342897891998291 grad_norm = 13.930447578430176
52) loss = 6.193991184234619 grad_norm = 6.407179355621338
53) loss = 6.135108470916748 grad_norm = 1.4418846368789673
54) loss = 6.147031784057617 grad_norm = 5.338115692138672
55) loss = 6.180387496948242 grad_norm = 8.108418464660645
56) loss = 6.207060813903809 grad_norm = 9.701889038085938
57) loss = 6.215665817260742 grad_norm = 10.471610069274902
58) loss = 6.202969074249268 grad_norm = 10.607603073120117
59) loss = 6.170167922973633 grad_norm = 10.189746856689453
60) loss = 6.121431827545166 grad_norm = 9.220542907714844
61) loss = 6.063607692718506 grad_norm = 7.641906261444092
62) loss = 6.006581783294678 grad_norm = 5.349120140075684
63) loss = 5.963957786560059 grad_norm = 2.2978594303131104
64) loss = 5.951851844787598 grad_norm = 2.2282962799072266
65) loss = 5.971628665924072 grad_norm = 5.987220764160156
66) loss = 6.005858421325684 grad_norm = 9.195704460144043
67) loss = 6.038926124572754 grad_norm = 11.548829078674316
68) loss = 6.059482097625732 grad_norm = 12.938092231750488
69) loss = 6.061285972595215 grad_norm = 13.334378242492676
70) loss = 6.043198108673096 grad_norm = 12.766958236694336
71) loss = 6.008614540100098 grad_norm = 11.311990737915039
72) loss = 5.964560508728027 grad_norm = 9.086621284484863
73) loss = 5.920508861541748 grad_norm = 6.246153354644775
74) loss = 5.887033462524414 grad_norm = 2.993448495864868
75) loss = 5.874838352203369 grad_norm = 0.8009199500083923
76) loss = 5.884912014007568 grad_norm = 3.272224187850952
77) loss = 5.907811641693115 grad_norm = 5.344062328338623
78) loss = 5.9315948486328125 grad_norm = 6.702033996582031
79) loss = 5.949487686157227 grad_norm = 7.477400779724121
80) loss = 5.958114147186279 grad_norm = 7.759702205657959
81) loss = 5.956571578979492 grad_norm = 7.596379280090332
82) loss = 5.945903301239014 grad_norm = 6.9987993240356445
83) loss = 5.928875923156738 grad_norm = 5.9481329917907715
84) loss = 5.909940242767334 grad_norm = 4.404060363769531
85) loss = 5.895603179931641 grad_norm = 2.398494243621826
86) loss = 5.892603874206543 grad_norm = 0.8986305594444275
87) loss = 5.902677536010742 grad_norm = 2.8863940238952637
88) loss = 5.920483589172363 grad_norm = 5.06048059463501
89) loss = 5.935702800750732 grad_norm = 6.575645446777344
90) loss = 5.942634105682373 grad_norm = 7.361220836639404
91) loss = 5.9386491775512695 grad_norm = 7.407079219818115
92) loss = 5.924259185791016 grad_norm = 6.745035171508789
93) loss = 5.902726173400879 grad_norm = 5.442744731903076
94) loss = 5.87941837310791 grad_norm = 3.6016411781311035
95) loss = 5.86182165145874 grad_norm = 1.5386890172958374
96) loss = 5.854459762573242 grad_norm = 0.8638936281204224
97) loss = 5.856882095336914 grad_norm = 2.5303218364715576
98) loss = 5.86497688293457 grad_norm = 3.910663366317749
99) loss = 5.873371601104736 grad_norm = 4.834494590759277
JAX program execution took 140.81151580810547 seconds
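
For context, below is a minimal sketch of the kind of JAX training loop that would emit lines in the format above ("step) loss = ... grad_norm = ..."), using optax.global_norm for the global L2 gradient norm. The toy linear model, synthetic batch, and Adam optimizer are illustrative stand-ins, not the actual benchmark code behind this gist:

import jax
import jax.numpy as jnp
import optax

optimizer = optax.adam(1e-2)  # stand-in optimizer; the real run's choice is unknown

def loss_fn(params, batch):
    # Toy squared-error loss; the real benchmark presumably runs a full
    # model forward pass here.
    inputs, targets = batch
    preds = inputs @ params['w'] + params['b']
    return jnp.mean((preds - targets) ** 2)

@jax.jit
def train_step(params, opt_state, batch):
    loss, grads = jax.value_and_grad(loss_fn)(params, batch)
    grad_norm = optax.global_norm(grads)  # global L2 norm over all gradient leaves
    updates, opt_state = optimizer.update(grads, opt_state, params)
    params = optax.apply_updates(params, updates)
    return params, opt_state, loss, grad_norm

key = jax.random.PRNGKey(0)
params = {'w': jax.random.normal(key, (16, 1)), 'b': jnp.zeros((1,))}
opt_state = optimizer.init(params)
batch = (jax.random.normal(key, (32, 16)), jnp.ones((32, 1)))

for step in range(100):
    params, opt_state, loss, grad_norm = train_step(params, opt_state, batch)
    print(f'{step}) loss = {loss} grad_norm = {grad_norm}')

Note that the reported wall-clock time for the JAX run includes the one-time jax.jit compilation of train_step on its first invocation.
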
PyTorch training steps progress:
I1111 18:56:44.580436 140143604311872 torch_e2e.py:143] 0) loss = 32.799583435058594, grad_norm = 5.480198383331299
I1111 18:56:44.602942 140143604311872 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.602947 140194652907328 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.602951 140007166039872 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604685 140110968313664 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604720 140128537937728 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604736 139625133123392 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604791 140694398179136 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:44.604785 140680726771520 distributed.py:995] Reducer buckets have been rebuilt in this iteration.
I1111 18:56:45.760965 140143604311872 torch_e2e.py:143] 1) loss = 32.81392288208008, grad_norm = 5.564063549041748
I1111 18:56:46.859724 140143604311872 torch_e2e.py:143] 2) loss = 32.78697204589844, grad_norm = 5.491509914398193
I1111 18:56:47.799092 140143604311872 torch_e2e.py:143] 3) loss = 32.7858772277832, grad_norm = 5.60224723815918
I1111 18:56:48.779518 140143604311872 torch_e2e.py:143] 4) loss = 32.76042175292969, grad_norm = 5.729595184326172
I1111 18:56:49.569550 140143604311872 torch_e2e.py:143] 5) loss = 32.7255973815918, grad_norm = 5.890851020812988
I1111 18:56:50.360571 140143604311872 torch_e2e.py:143] 6) loss = 32.686431884765625, grad_norm = 6.0222039222717285
I1111 18:56:51.153121 140143604311872 torch_e2e.py:143] 7) loss = 32.63671875, grad_norm = 6.230602264404297
I1111 18:56:51.946784 140143604311872 torch_e2e.py:143] 8) loss = 32.585205078125, grad_norm = 6.409563064575195
I1111 18:56:52.741659 140143604311872 torch_e2e.py:143] 9) loss = 32.50416564941406, grad_norm = 6.716286659240723
I1111 18:56:53.535360 140143604311872 torch_e2e.py:143] 10) loss = 32.42424774169922, grad_norm = 7.101415634155273
I1111 18:56:54.330379 140143604311872 torch_e2e.py:143] 11) loss = 32.317264556884766, grad_norm = 7.50632381439209
I1111 18:56:55.125588 140143604311872 torch_e2e.py:143] 12) loss = 32.208160400390625, grad_norm = 7.92339563369751
I1111 18:56:55.921339 140143604311872 torch_e2e.py:143] 13) loss = 32.077720642089844, grad_norm = 8.465384483337402
I1111 18:56:56.714532 140143604311872 torch_e2e.py:143] 14) loss = 31.917701721191406, grad_norm = 9.081777572631836
I1111 18:56:57.510020 140143604311872 torch_e2e.py:143] 15) loss = 31.73249053955078, grad_norm = 9.849095344543457
I1111 18:56:58.305579 140143604311872 torch_e2e.py:143] 16) loss = 31.553165435791016, grad_norm = 10.648540496826172
I1111 18:56:59.098185 140143604311872 torch_e2e.py:143] 17) loss = 31.305742263793945, grad_norm = 11.513188362121582
I1111 18:56:59.893549 140143604311872 torch_e2e.py:143] 18) loss = 31.056217193603516, grad_norm = 12.512486457824707
I1111 18:57:00.690072 140143604311872 torch_e2e.py:143] 19) loss = 30.75790786743164, grad_norm = 13.616098403930664
I1111 18:57:01.483093 140143604311872 torch_e2e.py:143] 20) loss = 30.423843383789062, grad_norm = 14.62498950958252
I1111 18:57:02.277024 140143604311872 torch_e2e.py:143] 21) loss = 30.038217544555664, grad_norm = 15.507715225219727
I1111 18:57:03.072989 140143604311872 torch_e2e.py:143] 22) loss = 29.629596710205078, grad_norm = 16.433927536010742
I1111 18:57:03.873271 140143604311872 torch_e2e.py:143] 23) loss = 29.170137405395508, grad_norm = 17.48212432861328
I1111 18:57:04.667958 140143604311872 torch_e2e.py:143] 24) loss = 28.667102813720703, grad_norm = 18.283262252807617
I1111 18:57:05.461400 140143604311872 torch_e2e.py:143] 25) loss = 28.125272750854492, grad_norm = 19.037086486816406
I1111 18:57:06.256874 140143604311872 torch_e2e.py:143] 26) loss = 27.546295166015625, grad_norm = 19.721418380737305
I1111 18:57:07.051243 140143604311872 torch_e2e.py:143] 27) loss = 26.921031951904297, grad_norm = 20.364442825317383
I1111 18:57:07.846579 140143604311872 torch_e2e.py:143] 28) loss = 26.258169174194336, grad_norm = 20.912199020385742
I1111 18:57:08.642325 140143604311872 torch_e2e.py:143] 29) loss = 25.564598083496094, grad_norm = 21.407268524169922
I1111 18:57:09.437666 140143604311872 torch_e2e.py:143] 30) loss = 24.84899139404297, grad_norm = 21.567232131958008
I1111 18:57:10.234629 140143604311872 torch_e2e.py:143] 31) loss = 24.078826904296875, grad_norm = 21.630910873413086
I1111 18:57:11.031373 140143604311872 torch_e2e.py:143] 32) loss = 23.287574768066406, grad_norm = 21.588590621948242
I1111 18:57:11.826267 140143604311872 torch_e2e.py:143] 33) loss = 22.48159408569336, grad_norm = 21.48728370666504
I1111 18:57:12.623108 140143604311872 torch_e2e.py:143] 34) loss = 21.630659103393555, grad_norm = 21.31056022644043
I1111 18:57:13.418298 140143604311872 torch_e2e.py:143] 35) loss = 20.80036163330078, grad_norm = 21.08628273010254
I1111 18:57:14.212611 140143604311872 torch_e2e.py:143] 36) loss = 19.9207820892334, grad_norm = 20.772138595581055
I1111 18:57:15.009569 140143604311872 torch_e2e.py:143] 37) loss = 19.082284927368164, grad_norm = 20.419795989990234
I1111 18:57:15.806542 140143604311872 torch_e2e.py:143] 38) loss = 18.18767738342285, grad_norm = 19.978591918945312
I1111 18:57:16.604197 140143604311872 torch_e2e.py:143] 39) loss = 17.313161849975586, grad_norm = 19.488441467285156
I1111 18:57:17.398409 140143604311872 torch_e2e.py:143] 40) loss = 16.46230125427246, grad_norm = 18.933177947998047
I1111 18:57:18.195473 140143604311872 torch_e2e.py:143] 41) loss = 15.606142044067383, grad_norm = 18.297391891479492
I1111 18:57:18.992078 140143604311872 torch_e2e.py:143] 42) loss = 14.758169174194336, grad_norm = 17.583349227905273
I1111 18:57:19.790554 140143604311872 torch_e2e.py:143] 43) loss = 13.962091445922852, grad_norm = 16.807153701782227
I1111 18:57:20.587406 140143604311872 torch_e2e.py:143] 44) loss = 13.139028549194336, grad_norm = 15.904189109802246
I1111 18:57:21.384323 140143604311872 torch_e2e.py:143] 45) loss = 12.383018493652344, grad_norm = 14.943195343017578
I1111 18:57:22.180801 140143604311872 torch_e2e.py:143] 46) loss = 11.665107727050781, grad_norm = 13.905163764953613
I1111 18:57:22.979599 140143604311872 torch_e2e.py:143] 47) loss = 10.987071990966797, grad_norm = 12.793498039245605
I1111 18:57:23.774337 140143604311872 torch_e2e.py:143] 48) loss = 10.358452796936035, grad_norm = 11.621541976928711
I1111 18:57:24.569179 140143604311872 torch_e2e.py:143] 49) loss = 9.786215782165527, grad_norm = 10.415675163269043
I1111 18:57:25.367716 140143604311872 torch_e2e.py:143] 50) loss = 9.264466285705566, grad_norm = 9.165142059326172
I1111 18:57:26.164860 140143604311872 torch_e2e.py:143] 51) loss = 8.827374458312988, grad_norm = 7.97666072845459
I1111 18:57:26.960957 140143604311872 torch_e2e.py:143] 52) loss = 8.432060241699219, grad_norm = 6.76027250289917
I1111 18:57:27.761851 140143604311872 torch_e2e.py:143] 53) loss = 8.101480484008789, grad_norm = 5.599133491516113
I1111 18:57:28.555984 140143604311872 torch_e2e.py:143] 54) loss = 7.838068008422852, grad_norm = 4.5281829833984375
I1111 18:57:29.352050 140143604311872 torch_e2e.py:143] 55) loss = 7.619318008422852, grad_norm = 3.4904284477233887
I1111 18:57:30.147860 140143604311872 torch_e2e.py:143] 56) loss = 7.455406665802002, grad_norm = 2.54789662361145
I1111 18:57:30.943805 140143604311872 torch_e2e.py:143] 57) loss = 7.346614360809326, grad_norm = 1.7550326585769653
I1111 18:57:31.742405 140143604311872 torch_e2e.py:143] 58) loss = 7.276298999786377, grad_norm = 1.0543296337127686
I1111 18:57:32.537595 140143604311872 torch_e2e.py:143] 59) loss = 7.240423679351807, grad_norm = 0.5265750885009766
I1111 18:57:33.333628 140143604311872 torch_e2e.py:143] 60) loss = 7.236795425415039, grad_norm = 0.480295866727829
I1111 18:57:34.131216 140143604311872 torch_e2e.py:143] 61) loss = 7.256988048553467, grad_norm = 0.8030412793159485
I1111 18:57:34.932935 140143604311872 torch_e2e.py:143] 62) loss = 7.296460151672363, grad_norm = 1.1590831279754639
I1111 18:57:35.725716 140143604311872 torch_e2e.py:143] 63) loss = 7.349276065826416, grad_norm = 1.4545279741287231
I1111 18:57:36.523608 140143604311872 torch_e2e.py:143] 64) loss = 7.411477565765381, grad_norm = 1.7063922882080078
I1111 18:57:37.323120 140143604311872 torch_e2e.py:143] 65) loss = 7.481031894683838, grad_norm = 1.9173399209976196
I1111 18:57:38.120862 140143604311872 torch_e2e.py:143] 66) loss = 7.552356719970703, grad_norm = 2.088557004928589
I1111 18:57:38.916439 140143604311872 torch_e2e.py:143] 67) loss = 7.622851848602295, grad_norm = 2.2259576320648193
I1111 18:57:39.713130 140143604311872 torch_e2e.py:143] 68) loss = 7.695225715637207, grad_norm = 2.343404769897461
I1111 18:57:40.512876 140143604311872 torch_e2e.py:143] 69) loss = 7.762754440307617, grad_norm = 2.4399759769439697
I1111 18:57:41.310517 140143604311872 torch_e2e.py:143] 70) loss = 7.825161457061768, grad_norm = 2.5152499675750732
I1111 18:57:42.105275 140143604311872 torch_e2e.py:143] 71) loss = 7.885710716247559, grad_norm = 2.581749677658081
I1111 18:57:42.904766 140143604311872 torch_e2e.py:143] 72) loss = 7.9366374015808105, grad_norm = 2.6322216987609863
I1111 18:57:43.701193 140143604311872 torch_e2e.py:143] 73) loss = 7.984954357147217, grad_norm = 2.6764440536499023
I1111 18:57:44.499818 140143604311872 torch_e2e.py:143] 74) loss = 8.027819633483887, grad_norm = 2.7141380310058594
I1111 18:57:45.294919 140143604311872 torch_e2e.py:143] 75) loss = 8.060647964477539, grad_norm = 2.740922689437866
I1111 18:57:46.094676 140143604311872 torch_e2e.py:143] 76) loss = 8.086572647094727, grad_norm = 2.7629334926605225
I1111 18:57:46.896962 140143604311872 torch_e2e.py:143] 77) loss = 8.109407424926758, grad_norm = 2.783541679382324
I1111 18:57:47.694501 140143604311872 torch_e2e.py:143] 78) loss = 8.125061988830566, grad_norm = 2.797971248626709
I1111 18:57:48.491031 140143604311872 torch_e2e.py:143] 79) loss = 8.133414268493652, grad_norm = 2.808274745941162
I1111 18:57:49.288383 140143604311872 torch_e2e.py:143] 80) loss = 8.135396957397461, grad_norm = 2.8145511150360107
I1111 18:57:50.087886 140143604311872 torch_e2e.py:143] 81) loss = 8.132173538208008, grad_norm = 2.819377899169922
I1111 18:57:50.885110 140143604311872 torch_e2e.py:143] 82) loss = 8.122648239135742, grad_norm = 2.819749116897583
I1111 18:57:51.681151 140143604311872 torch_e2e.py:143] 83) loss = 8.107420921325684, grad_norm = 2.8163645267486572
I1111 18:57:52.479457 140143604311872 torch_e2e.py:143] 84) loss = 8.086889266967773, grad_norm = 2.8111491203308105
I1111 18:57:53.276332 140143604311872 torch_e2e.py:143] 85) loss = 8.05954647064209, grad_norm = 2.799006462097168
I1111 18:57:54.073843 140143604311872 torch_e2e.py:143] 86) loss = 8.027091026306152, grad_norm = 2.784994125366211
I1111 18:57:54.869232 140143604311872 torch_e2e.py:143] 87) loss = 7.9884209632873535, grad_norm = 2.764768123626709
I1111 18:57:55.669992 140143604311872 torch_e2e.py:143] 88) loss = 7.947974681854248, grad_norm = 2.742222785949707
I1111 18:57:56.470902 140143604311872 torch_e2e.py:143] 89) loss = 7.900345802307129, grad_norm = 2.7122137546539307
I1111 18:57:57.267645 140143604311872 torch_e2e.py:143] 90) loss = 7.849843978881836, grad_norm = 2.6777143478393555
I1111 18:57:58.063938 140143604311872 torch_e2e.py:143] 91) loss = 7.796028137207031, grad_norm = 2.635615825653076
I1111 18:57:58.861551 140143604311872 torch_e2e.py:143] 92) loss = 7.7386155128479, grad_norm = 2.585268497467041
I1111 18:57:59.659476 140143604311872 torch_e2e.py:143] 93) loss = 7.676787376403809, grad_norm = 2.5245890617370605
I1111 18:58:00.456698 140143604311872 torch_e2e.py:143] 94) loss = 7.611974716186523, grad_norm = 2.4550864696502686
I1111 18:58:01.255294 140143604311872 torch_e2e.py:143] 95) loss = 7.5449748039245605, grad_norm = 2.3730602264404297
I1111 18:58:02.051496 140143604311872 torch_e2e.py:143] 96) loss = 7.480135440826416, grad_norm = 2.2816708087921143
I1111 18:58:02.850461 140143604311872 torch_e2e.py:143] 97) loss = 7.412574768066406, grad_norm = 2.1708152294158936
I1111 18:58:03.647884 140143604311872 torch_e2e.py:143] 98) loss = 7.346048355102539, grad_norm = 2.0456466674804688
PyTorch program execution took 83.153489112854 seconds
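
The PyTorch run is distributed: the eight distinct "Reducer buckets have been rebuilt" messages (one per process ID) indicate an 8-way DistributedDataParallel job, with a single rank emitting the per-step lines via absl-style logging (the "I1111 ... torch_e2e.py:143" prefixes). A minimal sketch of such a loop follows; the toy linear model, data, and optimizer are again stand-ins for whatever torch_e2e.py actually trains:

import torch
import torch.distributed as dist
import torch.nn.functional as F
from torch.nn.parallel import DistributedDataParallel as DDP

def main():
    # Assumes launch via torchrun with one process per GPU (8 here).
    dist.init_process_group('nccl')
    rank = dist.get_rank()
    torch.cuda.set_device(rank)

    model = DDP(torch.nn.Linear(16, 1).cuda(), device_ids=[rank])
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-2)
    inputs = torch.randn(32, 16, device='cuda')
    targets = torch.ones(32, 1, device='cuda')

    for step in range(99):
        optimizer.zero_grad()
        loss = F.mse_loss(model(inputs), targets)
        # DDP all-reduces gradients during backward; its reducer buckets are
        # rebuilt after the first iteration, matching the log messages above.
        loss.backward()
        # Global L2 norm over all parameter gradients, as reported in the log.
        grad_norm = torch.linalg.vector_norm(
            torch.stack([torch.linalg.vector_norm(p.grad)
                         for p in model.parameters()]))
        optimizer.step()
        if rank == 0:
            print(f'{step}) loss = {loss.item()}, grad_norm = {grad_norm.item()}')

    dist.destroy_process_group()

if __name__ == '__main__':
    main()
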