@dfyz
Created January 22, 2023 01:35
diff --git a/sample.py b/sample.py
index 6ff0ea2..00daada 100644
--- a/sample.py
+++ b/sample.py
@@ -7,10 +7,11 @@ from contextlib import nullcontext
 import torch
 import tiktoken
 from model import GPTConfig, GPT
+import time
 
 # -----------------------------------------------------------------------------
 out_dir = 'out'
-start = "\n" # or "<|endoftext|>" or whatever you like
+start = "How long does it take to travel to Mars?" # or "<|endoftext|>" or whatever you like
 num_samples = 10 # number of samples to draw
 max_new_tokens = 500 # number of tokens generated in each sample
 temperature = 0.8 # higher temperature (up to 1) is more random, lower (down to 0) means more greedy
@@ -31,16 +32,16 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
 ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 
 # model
-ckpt_path = os.path.join(out_dir, 'ckpt.pt')
-checkpoint = torch.load(ckpt_path, map_location=device)
-gptconf = GPTConfig(**checkpoint['model_args'])
-model = GPT(gptconf)
-state_dict = checkpoint['model']
-unwanted_prefix = '_orig_mod.'
-for k,v in list(state_dict.items()):
-    if k.startswith(unwanted_prefix):
-        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-model.load_state_dict(state_dict)
+#ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+#checkpoint = torch.load(ckpt_path, map_location=device)
+#gptconf = GPTConfig(**checkpoint['model_args'])
+model = GPT.from_pretrained('gpt2')
+# state_dict = checkpoint['model']
+# unwanted_prefix = '_orig_mod.'
+# for k,v in list(state_dict.items()):
+#     if k.startswith(unwanted_prefix):
+#         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+# model.load_state_dict(state_dict)
 model.eval()
 model.to(device)
 if compile:
@@ -48,9 +49,9 @@ if compile:
 
 # look for the meta pickle in case it is available in the dataset folder
 load_meta = False
-if 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
-    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
-    load_meta = os.path.exists(meta_path)
+# if 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
+#     meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
+#     load_meta = os.path.exists(meta_path)
 if load_meta:
     print(f"Loading meta from {meta_path}...")
     with open(meta_path, 'rb') as f:
@@ -70,10 +71,16 @@ else:
 start_ids = encode(start)
 x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+start_time = time.time()
 
 # run generation
 with torch.no_grad():
     with ctx:
         for k in range(num_samples):
             y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
-            print(decode(y[0].tolist()))
+            generated = y[0].tolist()
+            print(decode(generated))
+            print('Raw tokens', generated)
             print('---------------')
+end_time = time.time()
+
+print(f'Generation time: {end_time - start_time:.3f} seconds')
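
For reference, here is roughly what the patched flow boils down to as a standalone script. This is a minimal sketch, not part of the gist itself: it assumes it is run from a nanoGPT checkout (so that model.py, with the GPT.from_pretrained classmethod and the generate method, is importable) and that tiktoken is installed. It draws a single sample rather than 10, and top_k=200 is nanoGPT's default, which the diff does not show. The prompt, temperature, token budget, and wall-clock timing mirror the diff above; the tokens/sec figure is an extra derived metric.

# Minimal standalone sketch of the patched sample.py (assumes a nanoGPT
# checkout on the path; GPT.from_pretrained and model.generate are nanoGPT APIs).
import time
import torch
import tiktoken
from model import GPT  # nanoGPT's model.py

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = GPT.from_pretrained('gpt2')  # HF GPT-2 weights instead of a local ckpt.pt
model.eval()
model.to(device)

enc = tiktoken.get_encoding('gpt2')  # GPT-2 BPE, so no meta.pkl lookup is needed
start_ids = enc.encode("How long does it take to travel to Mars?")
x = torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...]

max_new_tokens = 500
if device == 'cuda':
    torch.cuda.synchronize()  # CUDA kernels launch asynchronously; sync before timing
start_time = time.time()
with torch.no_grad():
    # top_k=200 is nanoGPT's default sampling cutoff (assumed; not shown in the diff)
    y = model.generate(x, max_new_tokens, temperature=0.8, top_k=200)
if device == 'cuda':
    torch.cuda.synchronize()  # make sure generation has actually finished
elapsed = time.time() - start_time

generated = y[0].tolist()
print(enc.decode(generated))
print('Raw tokens', generated)
print(f'Generation time: {elapsed:.3f} seconds ({max_new_tokens / elapsed:.1f} tokens/sec)')

Note that the diff itself times the whole 10-sample loop with bare time.time() calls; on a GPU that measures wall-clock latency including the implicit synchronization forced by decoding each sample, so the explicit torch.cuda.synchronize() calls above are only needed for tighter per-sample numbers.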