@zhuhaozhe · Created May 6, 2024 06:33

Benchmark script comparing the single-tensor and fused CPU Adagrad implementations in torch.optim.

import copy
import os
import time

import torch
from torch.optim.adagrad import _single_tensor_adagrad, _fused_adagrad

device = 'cpu'
dtype = torch.float

# Tensor size and number of parameter tensors are configurable via environment variables.
TENSOR_SIZE = (int(os.getenv('TENSOR_SIZE', 512 * 512)), )
NPARAM = int(os.getenv('NPARAM', 4))

kwargs = {}
kwargs['params'] = [torch.randn(TENSOR_SIZE, device=device, dtype=dtype) for _ in range(NPARAM)]
kwargs['grads'] = [torch.randn(TENSOR_SIZE, device=device, dtype=dtype) for _ in range(NPARAM)]
# Adagrad takes the square root of the state sums, so keep them non-negative to avoid NaNs.
kwargs['state_sums'] = [torch.rand(TENSOR_SIZE, device=device, dtype=dtype) for _ in range(NPARAM)]
kwargs['state_steps'] = [torch.tensor([10], device=device, dtype=torch.float64) for _ in range(NPARAM)]
kwargs['grad_scale'] = None
kwargs['found_inf'] = None
kwargs['lr'] = 0.1
kwargs['lr_decay'] = 0.1
kwargs['weight_decay'] = 0.01
kwargs['eps'] = 0.1
kwargs['has_sparse_grad'] = False
kwargs['has_complex'] = False
kwargs['maximize'] = False
kwargs['differentiable'] = False

# Independent deep copies so both implementations start from identical state.
kwargs_a = copy.deepcopy(kwargs)
kwargs_b = copy.deepcopy(kwargs)

# a and b are allocated once, outside cache_flush, to avoid reallocating memory on every call.
# We assume the cache size is <= 512MB, so touching these two 256MB buffers
# evicts the benchmark tensors from cache between iterations.
a = torch.ones(256 * 1024 * 1024 // 4, dtype=torch.float)
b = torch.ones(256 * 1024 * 1024 // 4, dtype=torch.float)


def cache_flush():
    global a, b
    a += b


def bench(fn, kwargs, warmup=100, bench_iters=100):
    # Warm-up runs are not timed.
    for _ in range(warmup):
        cache_flush()
        fn(**kwargs)
    # Flush the cache before each timed iteration so every run starts cold.
    total_time = 0.0
    for _ in range(bench_iters):
        cache_flush()
        start_time = time.time()
        fn(**kwargs)
        total_time += time.time() - start_time
    print(f"{fn.__name__} time: {total_time:.4f} seconds")


# Compare the loop-based single-tensor path against the fused implementation.
bench(_single_tensor_adagrad, kwargs_a)
bench(_fused_adagrad, kwargs_b)
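
The same comparison can also be driven through the public optimizer API rather than the private _single_tensor_adagrad / _fused_adagrad helpers. The sketch below is a minimal, unofficial variant assuming a PyTorch build in which torch.optim.Adagrad accepts the fused flag (recent releases with CPU fused-optimizer support); hyperparameters mirror the kwargs above, and the cache flush is omitted for brevity.

import time
import torch

def bench_adagrad(impl_kwargs, steps=100, label=""):
    # Four parameter tensors with pre-set gradients; only step() is timed.
    params = [torch.randn(512 * 512, requires_grad=True) for _ in range(4)]
    for p in params:
        p.grad = torch.randn_like(p)
    opt = torch.optim.Adagrad(
        params, lr=0.1, lr_decay=0.1, weight_decay=0.01, eps=0.1, **impl_kwargs
    )
    start = time.time()
    for _ in range(steps):
        opt.step()
    print(f"{label}: {time.time() - start:.4f} seconds")

bench_adagrad({"foreach": False}, label="single-tensor")  # loop-based path
bench_adagrad({"fused": True}, label="fused")             # fused kernel, if supported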