Skip to content

Instantly share code, notes, and snippets.

@scottlogic-alex
Created November 3, 2023 17:35
Show Gist options
  • Save scottlogic-alex/9796d5583e5fe7672d7d829734c84c81 to your computer and use it in GitHub Desktop.
Save scottlogic-alex/9796d5583e5fe7672d7d829734c84c81 to your computer and use it in GitHub Desktop.
Measure memory allocated by CUDA for an FFN network
import torch
from torch import FloatTensor, Tensor
from torch.nn import Linear, MSELoss, Module, Sequential, GELU
from torch.cuda.amp import autocast
from torch.optim import AdamW, SGD
from typing import List, Optional, Tuple
from contextlib import nullcontext
def mib_str(bytes: int) -> str:
    """Format a byte count as mebibytes (2**20 bytes) with two decimals.

    >>> mib_str(3 * 2**19)
    '1.50MiB'
    """
    # NOTE: the parameter name shadows the builtin ``bytes``; it is kept so
    # keyword-argument callers keep working.
    mebibytes = bytes / (1 << 20)
    return f'{mebibytes:.2f}MiB'
def pretty_memory_snapshot() -> str:
    """Render one line per segment reported by ``torch.cuda.memory_snapshot()``.

    Each line has a shortened hex tag for the segment address (characters
    3..-5 of the un-prefixed hex string, a compact display tag rather than
    the full address). It is followed by the segment's allocated size and
    its total size, both formatted with ``mib_str`` and right-aligned to 11
    columns.
    """
    rendered = []
    for segment in torch.cuda.memory_snapshot():
        tag = format(segment['address'], '02x')[3:-5]
        used = mib_str(segment['allocated_size']).rjust(11)
        size = mib_str(segment['total_size']).rjust(11)
        rendered.append(f'{tag}: {used} alloc, {size} total')
    return '\n'.join(rendered)
def pretty_mem(
    preamble: str,
    context: str,
    device_ix=0
):
    """Return a one-line summary of CUDA caching-allocator usage for one device.

    Columns, in MiB:
      * "alloc": ``torch.cuda.memory_allocated`` (memory occupied by tensors);
      * "reserved": memory the allocator holds but has not handed out, i.e.
        ``memory_reserved - memory_allocated`` (note: NOT torch's own
        "reserved" figure);
      * "total": ``torch.cuda.memory_reserved``.

    ``preamble`` is emitted verbatim in front of ``context``, which is
    right-aligned to 20 columns.
    """
    in_use: int = torch.cuda.memory_allocated(device_ix)
    held: int = torch.cuda.memory_reserved(device_ix)
    cached_free: int = held - in_use
    alloc_col, free_col, total_col = (
        mib_str(n).rjust(11) for n in (in_use, cached_free, held)
    )
    return f'{preamble}{context.rjust(20)} {alloc_col} alloc {free_col} reserved {total_col} total'
# --- Experiment configuration ------------------------------------------------
device = torch.device('cuda')

# FFN shape: in_dim -> hidden_dim (between every pair of dense layers) -> out_dim
layer_count, in_dim, hidden_dim, out_dim = 7, 4096, 16384, 8192
batch_size = 1024
print(f'batch={batch_size}')

# Whether the forward pass runs under autocast (mixed precision).
use_mixed = True
print('precision: ' + ('mixed' if use_mixed else 'uniform'))

# Passed to autocast as cache_enabled; only reported when autocast is used.
cache_enabled = True
if use_mixed:
    print(f'cache_enabled: {cache_enabled}')

# Re-create x / y_true on every microstep instead of only on the first one.
realloc_each_microstep = True
print(f'realloc_each_microstep: {realloc_each_microstep}')

# Passed to optim.zero_grad(set_to_none=...).
optim_set_to_none = True
print(f'optim_set_to_none: {optim_set_to_none}')
class LoggingSequential(Sequential):
    """A ``Sequential`` that prints CUDA memory usage after each submodule's forward.

    Each report is tagged ``G<k>`` for GELU submodules and ``L<k>`` for any
    other submodule, with ``k`` = position // 2. For the alternating
    Linear/GELU layout built by ``Model``, ``k`` is the dense-layer index.
    """

    def forward(self, input: FloatTensor, step_and_micro_indicator = '') -> FloatTensor:
        """Apply the submodules in order, logging memory after each one.

        ``step_and_micro_indicator`` is used as the preamble of every
        ``pretty_mem`` line.
        """
        activation = input
        for position, submodule in enumerate(self):
            activation = submodule(activation)
            kind = 'G' if isinstance(submodule, GELU) else 'L'
            print(pretty_mem(step_and_micro_indicator, f'after {kind}{position // 2}.forward:'))
            print(pretty_memory_snapshot())
        return activation
class Model(Module):
    """Feed-forward network of ``layer_count`` Linear layers with a GELU between each pair.

    The first Linear takes ``in_dim`` features and the last produces
    ``out_dim``. Every inner boundary is ``hidden_dim`` wide. With
    ``layer_count == 1`` the model is a single ``Linear(in_dim, out_dim)``.
    The stack is held in a ``LoggingSequential``.
    """

    layers: LoggingSequential

    def __init__(self, layer_count: int, in_dim: int, hidden_dim: int, out_dim: int, bias: bool, device=None, dtype=None) -> None:
        super().__init__()
        assert layer_count > 0
        last_ix = layer_count - 1
        # Holds both Linear and GELU modules, so the element type is Module.
        layers: List[Module] = []
        for layer_ix in range(layer_count):
            in_features: int = in_dim if layer_ix == 0 else hidden_dim
            out_features: int = out_dim if layer_ix == last_ix else hidden_dim
            layers.append(Linear(in_features=in_features, out_features=out_features, bias=bias, device=device, dtype=dtype))
            # No activation after the final projection.
            if layer_ix != last_ix:
                layers.append(GELU())
        self.layers = LoggingSequential(*layers)

    def forward(self, x: FloatTensor, step_and_micro_indicator = '') -> FloatTensor:
        """Run the stack; ``step_and_micro_indicator`` is forwarded to ``self.layers`` as the log prefix."""
        return self.layers(x, step_and_micro_indicator=step_and_micro_indicator)
# --- Model, optimizer, loss, precision context -------------------------------
model = Model(
    layer_count=layer_count,
    in_dim=in_dim,
    hidden_dim=hidden_dim,
    out_dim=out_dim,
    device=device,
    bias=False,
)
print(model)
print(pretty_mem('', 'after declare model:'))
print(pretty_memory_snapshot())

# Alternative optimizer (allocates extra per-parameter state):
# optim = AdamW(model.parameters(), lr=2e-5)
momentum = 0.
optim = SGD(model.parameters(), lr=2e-5, momentum=momentum)
optim_extra_desc = f', mom={momentum}' if isinstance(optim, SGD) else ''
print(pretty_mem('', f'after declare optim ({type(optim).__name__}{optim_extra_desc})'))

loss_fn = MSELoss()
# torch.cuda.amp.autocast(...) is deprecated; torch.autocast with
# device_type='cuda' is the equivalent, non-deprecated spelling.
precision_ctx = (
    torch.autocast(device_type=device.type, dtype=torch.bfloat16, cache_enabled=cache_enabled)
    if use_mixed
    else nullcontext()
)
def hook_fn(m: Module, i: Tuple[Tensor, ...], o: Tuple[Tensor, ...]) -> Optional[Tensor]:
print(pretty_mem('', f"after bwd {m.__class__.__name__}:"))
print(pretty_memory_snapshot())
def add_bwd_hook(mod: Module) -> None:
match(mod):
case Linear() | GELU():
mod.register_full_backward_hook(hook_fn)
model.apply(add_bwd_hook)
def pre_hook_fn(m: Module, o: Tuple[Tensor, ...]) -> Optional[Tensor]:
torch.cuda.synchronize()
print(pretty_mem('', f"after b_pre {m.__class__.__name__} {o[0].shape}:"))
print(pretty_memory_snapshot())
model.layers[-1].register_full_backward_pre_hook(pre_hook_fn)
# --- Training loop -------------------------------------------------------------
# `steps` optimizer steps, each accumulating gradients over `microsteps`
# forward/backward passes, with memory logged at every stage.
steps = 1
microsteps = 1
for step in range(steps):
    step_indicator = f'[step {step}] ' if steps > 1 else ''
    for microstep in range(microsteps):
        microstep_indicator = f'[microstep {microstep}] ' if microsteps > 1 else ''
        step_and_micro_indicator = f'{step_indicator}{microstep_indicator}'
        # New inputs every microstep, or only on the very first one.
        if realloc_each_microstep or (step == 0 and microstep == 0):
            x = torch.randn(batch_size, in_dim, device=device, requires_grad=False)
            y_true = torch.randn(batch_size, out_dim, device=device, requires_grad=False)
            print(pretty_mem(step_and_micro_indicator, 'after declare x/y:'))
            print(pretty_memory_snapshot())
        with precision_ctx:
            # Call the module (not .forward) and pass the indicator through so
            # the per-layer forward logs carry the step/microstep prefix.
            y_pred = model(x, step_and_micro_indicator=step_and_micro_indicator)
            # y_pred.retain_grad()
            # print(pretty_mem(step_and_micro_indicator, 'after model.forward:'))
            y_pred2 = y_pred.float()
            del y_pred
            print(pretty_mem(step_and_micro_indicator, 'after y_pred32:'))
            print(pretty_memory_snapshot())
            loss = loss_fn(y_pred2, y_true)
            del y_pred2
            # loss.retain_grad()
            print(pretty_mem(step_and_micro_indicator, 'after loss:'))
            if microsteps > 1:
                # Scale so accumulated gradients average over the microsteps.
                loss /= microsteps
            torch.cuda.synchronize()
            print(pretty_memory_snapshot())
        # Backward runs outside the autocast region, as the AMP docs recommend.
        torch.cuda.synchronize()
        loss.backward()
        print(pretty_mem(step_and_micro_indicator, 'after backward:'))
        print(pretty_memory_snapshot())
    del loss
    print(pretty_mem(step_indicator, 'after del loss'))
    optim.step()
    print(pretty_mem(step_indicator, 'after optim.step'))
    optim.zero_grad(set_to_none=optim_set_to_none)
    print(pretty_mem(step_indicator, f'after optim.zero_grad ({optim_set_to_none})'))
# --- Reference sizes -----------------------------------------------------------
# Theoretical tensor sizes, to help map allocator segments in the snapshots to
# what they hold. "f16" labels the 16-bit autocast dtype (bfloat16 in this
# script); both use 2 bytes per element.
def _print_reference_sizes(dtype_label: str, bytes_per_elem: int) -> None:
    """Print expected parameter and hidden-activation sizes at one element width.

    ``dtype_label`` only appears in the printed text.
    """
    param_count = sum(p.numel() for p in model.parameters())
    print(f'model ({dtype_label}): {mib_str(param_count*bytes_per_elem)}')
    # The first Linear maps in_dim -> hidden_dim, or straight to out_dim when
    # it is the only layer.
    first_out = hidden_dim if layer_count > 1 else out_dim
    print(f'model.in ({dtype_label}): {mib_str(in_dim*first_out*bytes_per_elem)}')
    if layer_count > 2:
        print(f'model.mid ({dtype_label}): {mib_str(hidden_dim**2*bytes_per_elem)}')
    if layer_count > 1:
        # hidden_dim-wide activations exist whenever there are >= 2 layers.
        print(f'activ.mid ({dtype_label}): {mib_str(batch_size*hidden_dim*bytes_per_elem)}')
        print(f'model.out ({dtype_label}): {mib_str(hidden_dim*out_dim*bytes_per_elem)}')


_print_reference_sizes('f32', 4)
if use_mixed:
    _print_reference_sizes('f16', 2)
print(f'x (f32): {mib_str(batch_size*in_dim*4)}')
print(f'y_true (f32): {mib_str(batch_size*out_dim*4)}')
if use_mixed:
    print(f'y_pred (f16): {mib_str(batch_size*out_dim*2)}')
else:
    print(f'y_pred (f32): {mib_str(batch_size*out_dim*4)}')
print(pretty_memory_snapshot())
@scottlogic-alex
Copy link
Author

Memory snapshots at each point in time (i.e. after every layer's forward, and hooked into every module in the backward pass). Annotated with my best guesses about what's being allocated/deallocated each time the snapshot changes.

batch=1024
precision: mixed
cache_enabled: True
realloc_each_microstep: True
optim_set_to_none: True
Model(
  (layers): LoggingSequential(
    (0): Linear(in_features=4096, out_features=16384, bias=False)
    (1): GELU(approximate='none')
    (2): Linear(in_features=16384, out_features=16384, bias=False)
    (3): GELU(approximate='none')
    (4): Linear(in_features=16384, out_features=16384, bias=False)
    (5): GELU(approximate='none')
    (6): Linear(in_features=16384, out_features=16384, bias=False)
    (7): GELU(approximate='none')
    (8): Linear(in_features=16384, out_features=16384, bias=False)
    (9): GELU(approximate='none')
    (10): Linear(in_features=16384, out_features=16384, bias=False)
    (11): GELU(approximate='none')
    (12): Linear(in_features=16384, out_features=8192, bias=False)
  )
)
after declare model: 5888.00MiB alloc, 0.00MiB reserved, 5888.00MiB total
8800:   512.00MiB alloc,   512.00MiB total L6 f32 model.out
8a00:  1024.00MiB alloc,  1024.00MiB total L5 f32 model.mid
8e00:  1024.00MiB alloc,  1024.00MiB total L4 f32 model.mid
9200:  1024.00MiB alloc,  1024.00MiB total L3 f32 model.mid
9600:  1024.00MiB alloc,  1024.00MiB total L2 f32 model.mid
9a00:  1024.00MiB alloc,  1024.00MiB total L1 f32 model.mid
9e00:   256.00MiB alloc,   256.00MiB total L0 f32 model.in
after declare optim (SGD, mom=0.0): 5888.00MiB alloc, 0.00MiB reserved, 5888.00MiB total
  after declare x/y:  5936.00MiB alloc     0.00MiB reserved  5936.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total new f32 x
a0e0:    32.00MiB alloc,    32.00MiB total new f32 y_true
   after L0.forward:  6112.12MiB alloc     3.88MiB reserved  6116.00MiB total
8780:   128.00MiB alloc,   128.00MiB total new f16 model.in
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total new L0 f16 activ.mid
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after G0.forward:  6144.12MiB alloc     3.88MiB reserved  6148.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total new G0 f16 activ.mid
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after L1.forward:  6688.12MiB alloc     3.88MiB reserved  6692.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total new L1 f16 model.mid
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total new L1 f16 activ.mid
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after G1.forward:  6720.12MiB alloc     3.88MiB reserved  6724.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total new G1 f16 activ.mid
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after L2.forward:  7264.12MiB alloc     3.88MiB reserved  7268.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total new L2 f16 model.mid
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total new L2 f16 activ.mid
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after G2.forward:  7296.12MiB alloc     3.88MiB reserved  7300.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total new G2 f16 activ.mid
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after L3.forward:  7840.12MiB alloc     3.88MiB reserved  7844.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total new L3 f16 model.mid
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total new L3 f16 activ.mid
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after G3.forward:  7872.12MiB alloc     3.88MiB reserved  7876.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total new G3 f16 activ.mid
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after L4.forward:  8416.12MiB alloc     3.88MiB reserved  8420.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total new L4 f16 activ.mid
7de0:   512.00MiB alloc,   512.00MiB total new L4 f16 model.mid
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after G4.forward:  8448.12MiB alloc     3.88MiB reserved  8452.00MiB total
7da0:    32.00MiB alloc,    32.00MiB total new G4 f16 activ.mid
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after L5.forward:  8992.12MiB alloc     3.88MiB reserved  8996.00MiB total
7b80:    32.00MiB alloc,    32.00MiB total new L5 f16 activ.mid
7ba0:   512.00MiB alloc,   512.00MiB total new L5 f16 model.mid
7da0:    32.00MiB alloc,    32.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after G5.forward:  9024.12MiB alloc     3.88MiB reserved  9028.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total new G5 f16 activ.mid
7b80:    32.00MiB alloc,    32.00MiB total
7ba0:   512.00MiB alloc,   512.00MiB total
7da0:    32.00MiB alloc,    32.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after L6.forward:  9296.12MiB alloc     3.88MiB reserved  9300.00MiB total
7a60:   256.00MiB alloc,   256.00MiB total new L6 f16 model.out
7b60:    32.00MiB alloc,    32.00MiB total
7b80:    32.00MiB alloc,    32.00MiB total
7ba0:   512.00MiB alloc,   512.00MiB total
7da0:    32.00MiB alloc,    32.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a014:    16.00MiB alloc,    16.00MiB total new f16 y_pred
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
     after y_pred32:  9312.12MiB alloc    19.88MiB reserved  9332.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total new f32 y_pred
7a60:   256.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:    32.00MiB alloc,    32.00MiB total
7ba0:   512.00MiB alloc,   512.00MiB total
7da0:    32.00MiB alloc,    32.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:   128.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total dealloc f16 y_pred
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
         after loss:  9344.12MiB alloc    53.88MiB reserved  9398.00MiB total
7a00:     0.00MiB alloc,    32.00MiB total new dealloc f32 y_pred-sized (squares?)
7a20:    32.00MiB alloc,    32.00MiB total new f32 y_pred-sized differences
7a40:    32.00MiB alloc,    32.00MiB total
7a60:   256.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:    32.00MiB alloc,    32.00MiB total
7ba0:   512.00MiB alloc,   512.00MiB total
7da0:    32.00MiB alloc,    32.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total dealloc L0 f16 model.in
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total alloc loss + small page padding
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
after b_pre L6 Linear torch.Size([1024, 8192]):  9200.13MiB alloc   197.87MiB reserved  9398.00MiB total
7a00:     0.00MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total no change? was f32 differences; repurposed as f32 y_pred grad?
7a40:     0.00MiB alloc,    32.00MiB total dealloc f32 y_pred
7a60:   256.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:    32.00MiB alloc,    32.00MiB total
7ba0:   512.00MiB alloc,   512.00MiB total
7da0:    32.00MiB alloc,    32.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a014:    16.00MiB alloc,    16.00MiB total repurpose as f16 y_pred grad
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after bwd L6 Linear:  9464.25MiB alloc   701.75MiB reserved 10166.00MiB total
7680:   512.00MiB alloc,   512.00MiB total new L6 f32 model.out grad
7880:     0.00MiB alloc,   256.00MiB total new dealloc L6 f16 model.out grad
7a00:     8.12MiB alloc,    32.00MiB total repurpose as cuBLAS workspace (first matmul of bwd)
7a20:    32.00MiB alloc,    32.00MiB total hmmm if this was y_pred grad then we don't need it. something else now?
7a40:    32.00MiB alloc,    32.00MiB total repurposed… G5 f16 activ.mid grad?
7a60:     0.00MiB alloc,   256.00MiB total dealloc L6 f16 model.out
7b60:     0.00MiB alloc,    32.00MiB total dealloc G5 f16 activ.mid
7b80:    32.00MiB alloc,    32.00MiB total
7ba0:   512.00MiB alloc,   512.00MiB total
7da0:    32.00MiB alloc,    32.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a014:    16.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
     after bwd G5 GELU:  9448.25MiB alloc   717.75MiB reserved 10166.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total repurpose f16 L5 activ.mid grad
7b80:     0.00MiB alloc,    32.00MiB total dealloc f16 L5 activ.mid
7ba0:   512.00MiB alloc,   512.00MiB total
7da0:    32.00MiB alloc,    32.00MiB total
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total dealloc f16 y_pred grad
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after bwd L5 Linear:  9928.25MiB alloc  1773.75MiB reserved 11702.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total alloc L5 f32 model.mid grad
7480:     0.00MiB alloc,   512.00MiB total new dealloc f16 model.mid grad
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total dealloc L5 f16 model.mid
7da0:     0.00MiB alloc,    32.00MiB total dealloc G4 f16 activ.mid
7dc0:    32.00MiB alloc,    32.00MiB total
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
     after bwd G4 GELU:  9896.25MiB alloc  1805.75MiB reserved 11702.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total dealloc L4 f16 activ.mid
7de0:   512.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:    32.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after bwd L4 Linear: 10376.25MiB alloc  2349.75MiB reserved 12726.00MiB total
6c80:  1024.00MiB alloc,  1024.00MiB total alloc L4 f32 model.mid grad
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total dealloc L4 f16 model.mid
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:    32.00MiB alloc,    32.00MiB total
9ff4:     0.00MiB alloc,    32.00MiB total dealloc f16 G3 activ.mid
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
     after bwd G3 GELU: 10344.25MiB alloc  2381.75MiB reserved 12726.00MiB total
6c80:  1024.00MiB alloc,  1024.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total
7fe0:   512.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:    32.00MiB alloc,    32.00MiB total
9fd4:     0.00MiB alloc,    32.00MiB total dealloc f16 L3 activ.mid
9ff4:     0.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after bwd L3 Linear: 10824.25MiB alloc  2925.75MiB reserved 13750.00MiB total
6880:  1024.00MiB alloc,  1024.00MiB total new L3 f32 model.mid.grad
6c80:  1024.00MiB alloc,  1024.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total
7fe0:     0.00MiB alloc,   512.00MiB total dealloc L3 f16 model.mid
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:    32.00MiB alloc,    32.00MiB total
9fb4:     0.00MiB alloc,    32.00MiB total dealloc G2 f16 activ.mid
9fd4:     0.00MiB alloc,    32.00MiB total
9ff4:     0.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
     after bwd G2 GELU: 10792.25MiB alloc  2957.75MiB reserved 13750.00MiB total
6880:  1024.00MiB alloc,  1024.00MiB total
6c80:  1024.00MiB alloc,  1024.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total
7fe0:     0.00MiB alloc,   512.00MiB total
81e0:   512.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:    32.00MiB alloc,    32.00MiB total
9f94:     0.00MiB alloc,    32.00MiB total dealloc L2 f16 activ.mid
9fb4:     0.00MiB alloc,    32.00MiB total
9fd4:     0.00MiB alloc,    32.00MiB total
9ff4:     0.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after bwd L2 Linear: 11272.25MiB alloc  3501.75MiB reserved 14774.00MiB total
6480:  1024.00MiB alloc,  1024.00MiB total alloc L2 f32 model.mid
6880:  1024.00MiB alloc,  1024.00MiB total
6c80:  1024.00MiB alloc,  1024.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total
7fe0:     0.00MiB alloc,   512.00MiB total
81e0:     0.00MiB alloc,   512.00MiB total dealloc L2 f16 model.mid
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:    32.00MiB alloc,    32.00MiB total
9f74:     0.00MiB alloc,    32.00MiB total dealloc G1 f16 activ.mid
9f94:     0.00MiB alloc,    32.00MiB total
9fb4:     0.00MiB alloc,    32.00MiB total
9fd4:     0.00MiB alloc,    32.00MiB total
9ff4:     0.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
     after bwd G1 GELU: 11240.25MiB alloc  3533.75MiB reserved 14774.00MiB total
6480:  1024.00MiB alloc,  1024.00MiB total
6880:  1024.00MiB alloc,  1024.00MiB total
6c80:  1024.00MiB alloc,  1024.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total
7fe0:     0.00MiB alloc,   512.00MiB total
81e0:     0.00MiB alloc,   512.00MiB total
83e0:   512.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:    32.00MiB alloc,    32.00MiB total
9f54:     0.00MiB alloc,    32.00MiB total dealloc L1 f16 activ.mid
9f74:     0.00MiB alloc,    32.00MiB total
9f94:     0.00MiB alloc,    32.00MiB total
9fb4:     0.00MiB alloc,    32.00MiB total
9fd4:     0.00MiB alloc,    32.00MiB total
9ff4:     0.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after bwd L1 Linear: 11720.25MiB alloc  4077.75MiB reserved 15798.00MiB total
6080:  1024.00MiB alloc,  1024.00MiB total new L1 f32 model.mid
6480:  1024.00MiB alloc,  1024.00MiB total
6880:  1024.00MiB alloc,  1024.00MiB total
6c80:  1024.00MiB alloc,  1024.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total
7fe0:     0.00MiB alloc,   512.00MiB total
81e0:     0.00MiB alloc,   512.00MiB total
83e0:     0.00MiB alloc,   512.00MiB total dealloc L1 f16 model.mid
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:    32.00MiB alloc,    32.00MiB total
9f34:     0.00MiB alloc,    32.00MiB total dealloc G0 f16 activ.mid
9f54:     0.00MiB alloc,    32.00MiB total
9f74:     0.00MiB alloc,    32.00MiB total
9f94:     0.00MiB alloc,    32.00MiB total
9fb4:     0.00MiB alloc,    32.00MiB total
9fd4:     0.00MiB alloc,    32.00MiB total
9ff4:     0.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
     after bwd G0 GELU: 11688.25MiB alloc  4109.75MiB reserved 15798.00MiB total
6080:  1024.00MiB alloc,  1024.00MiB total
6480:  1024.00MiB alloc,  1024.00MiB total
6880:  1024.00MiB alloc,  1024.00MiB total
6c80:  1024.00MiB alloc,  1024.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:    32.00MiB alloc,    32.00MiB total
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total
7fe0:     0.00MiB alloc,   512.00MiB total
81e0:     0.00MiB alloc,   512.00MiB total
83e0:     0.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:     0.00MiB alloc,    32.00MiB total dealloc L0 f16 activ.mid
9f34:     0.00MiB alloc,    32.00MiB total
9f54:     0.00MiB alloc,    32.00MiB total
9f74:     0.00MiB alloc,    32.00MiB total
9f94:     0.00MiB alloc,    32.00MiB total
9fb4:     0.00MiB alloc,    32.00MiB total
9fd4:     0.00MiB alloc,    32.00MiB total
9ff4:     0.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
   after bwd L0 Linear: 11656.25MiB alloc  4141.75MiB reserved 15798.00MiB total
6080:  1024.00MiB alloc,  1024.00MiB total
6480:  1024.00MiB alloc,  1024.00MiB total
6880:  1024.00MiB alloc,  1024.00MiB total
6c80:  1024.00MiB alloc,  1024.00MiB total
7080:  1024.00MiB alloc,  1024.00MiB total
7480:     0.00MiB alloc,   512.00MiB total
7680:   512.00MiB alloc,   512.00MiB total
7880:     0.00MiB alloc,   256.00MiB total
7a00:     8.12MiB alloc,    32.00MiB total
7a20:    32.00MiB alloc,    32.00MiB total
7a40:     0.00MiB alloc,    32.00MiB total dealloc unknown
7a60:     0.00MiB alloc,   256.00MiB total
7b60:    32.00MiB alloc,    32.00MiB total
7b80:     0.00MiB alloc,    32.00MiB total
7ba0:     0.00MiB alloc,   512.00MiB total
7da0:     0.00MiB alloc,    32.00MiB total
7dc0:     0.00MiB alloc,    32.00MiB total
7de0:     0.00MiB alloc,   512.00MiB total
7fe0:     0.00MiB alloc,   512.00MiB total
81e0:     0.00MiB alloc,   512.00MiB total
83e0:     0.00MiB alloc,   512.00MiB total
8780:     0.00MiB alloc,   128.00MiB total
8800:   512.00MiB alloc,   512.00MiB total
8a00:  1024.00MiB alloc,  1024.00MiB total
8e00:  1024.00MiB alloc,  1024.00MiB total
9200:  1024.00MiB alloc,  1024.00MiB total
9600:  1024.00MiB alloc,  1024.00MiB total
9a00:  1024.00MiB alloc,  1024.00MiB total
9e00:   256.00MiB alloc,   256.00MiB total
9f00:    16.12MiB alloc,    20.00MiB total
9f14:     0.00MiB alloc,    32.00MiB total
9f34:     0.00MiB alloc,    32.00MiB total
9f54:     0.00MiB alloc,    32.00MiB total
9f74:     0.00MiB alloc,    32.00MiB total
9f94:     0.00MiB alloc,    32.00MiB total
9fb4:     0.00MiB alloc,    32.00MiB total
9fd4:     0.00MiB alloc,    32.00MiB total
9ff4:     0.00MiB alloc,    32.00MiB total
a014:     0.00MiB alloc,    16.00MiB total
a024:     0.00MiB alloc,     2.00MiB total
a02e:    16.00MiB alloc,    16.00MiB total
a0e0:    32.00MiB alloc,    32.00MiB total
     after backward: 11872.25MiB alloc  3925.75MiB reserved 15798.00MiB total
      after del loss 11840.25MiB alloc  3957.75MiB reserved 15798.00MiB total
    after optim.step 11840.25MiB alloc  3957.75MiB reserved 15798.00MiB total
after optim.zero_grad (True)  5952.25MiB alloc  9845.75MiB reserved 15798.00MiB total
model     (f32): 5888.00MiB
model.in  (f32): 256.00MiB
model.mid (f32): 1024.00MiB
activ.mid (f32): 64.00MiB
model.out (f32): 512.00MiB
model     (f16): 2944.00MiB
model.in  (f16): 128.00MiB
model.mid (f16): 512.00MiB
activ.mid (f16): 32.00MiB
model.out (f16): 256.00MiB
x         (f32): 16.00MiB
y_true    (f32): 32.00MiB
y_pred    (f16): 16.00MiB

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment