diff --git a/sample.py b/sample.py
index 6ff0ea2..00daada 100644
--- a/sample.py
+++ b/sample.py
@@ -7,10 +7,11 @@ from contextlib import nullcontext
 import torch
 import tiktoken
 from model import GPTConfig, GPT
+import time
 # -----------------------------------------------------------------------------
 out_dir = 'out'
-start = "\n" # or "<|endoftext|>" or whatever you like
+start = "How long does it take to travel to Mars?" # or "<|endoftext|>" or whatever you like
 num_samples = 10 # number of samples to draw
 max_new_tokens = 500 # number of tokens generated in each sample
 temperature = 0.8 # higher temperature (up to 1) is more random, lower (down to 0) means more greedy
@@ -31,16 +32,16 @@ ptdtype = {'float32': torch.float32, 'bfloat16': torch.bfloat16, 'float16': torch.float16}[dtype]
 ctx = nullcontext() if device_type == 'cpu' else torch.amp.autocast(device_type=device_type, dtype=ptdtype)
 # model
-ckpt_path = os.path.join(out_dir, 'ckpt.pt')
-checkpoint = torch.load(ckpt_path, map_location=device)
-gptconf = GPTConfig(**checkpoint['model_args'])
-model = GPT(gptconf)
-state_dict = checkpoint['model']
-unwanted_prefix = '_orig_mod.'
-for k,v in list(state_dict.items()):
-    if k.startswith(unwanted_prefix):
-        state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
-model.load_state_dict(state_dict)
+#ckpt_path = os.path.join(out_dir, 'ckpt.pt')
+#checkpoint = torch.load(ckpt_path, map_location=device)
+#gptconf = GPTConfig(**checkpoint['model_args'])
+model = GPT.from_pretrained('gpt2')
+# state_dict = checkpoint['model']
+# unwanted_prefix = '_orig_mod.'
+# for k,v in list(state_dict.items()):
+#     if k.startswith(unwanted_prefix):
+#         state_dict[k[len(unwanted_prefix):]] = state_dict.pop(k)
+# model.load_state_dict(state_dict)
 model.eval()
 model.to(device)
 if compile:
@@ -48,9 +49,9 @@ if compile:
 # look for the meta pickle in case it is available in the dataset folder
 load_meta = False
-if 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
-    meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
-    load_meta = os.path.exists(meta_path)
+# if 'config' in checkpoint and 'dataset' in checkpoint['config']: # older checkpoints might not have these...
+#     meta_path = os.path.join('data', checkpoint['config']['dataset'], 'meta.pkl')
+#     load_meta = os.path.exists(meta_path)
 if load_meta:
     print(f"Loading meta from {meta_path}...")
     with open(meta_path, 'rb') as f:
@@ -70,10 +71,16 @@ else:
 start_ids = encode(start)
 x = (torch.tensor(start_ids, dtype=torch.long, device=device)[None, ...])
+start_time = time.time()
 # run generation
 with torch.no_grad():
     with ctx:
         for k in range(num_samples):
             y = model.generate(x, max_new_tokens, temperature=temperature, top_k=top_k)
-            print(decode(y[0].tolist()))
+            generated = y[0].tolist()
+            print(decode(generated))
+            print('Raw tokens', generated)
             print('---------------')
+end_time = time.time()
+
+print(f'Generation time: {end_time - start_time:.3f} seconds')
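
A couple of notes on what the patch does (mine, not part of the gist): it swaps the local-checkpoint loading for nanoGPT's GPT.from_pretrained('gpt2') classmethod from model.py, which pulls the pretrained GPT-2 weights via HuggingFace, and it comments out the meta.pkl lookup, so load_meta stays False and the script takes its unchanged else: branch, encoding the prompt with tiktoken's GPT-2 BPE. A minimal standalone sketch of that encoding path:

import tiktoken

# GPT-2 BPE fallback that sample.py uses when no meta.pkl is found
# (the path this diff forces by commenting out the checkpoint lookup).
enc = tiktoken.get_encoding('gpt2')
encode = lambda s: enc.encode(s, allowed_special={'<|endoftext|>'})
decode = lambda l: enc.decode(l)

ids = encode('How long does it take to travel to Mars?')
print(ids)          # the token ids that become the prompt tensor x
print(decode(ids))  # round-trips back to the original prompt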
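
The timing added at the end wraps all the samples, so the printed figure is total wall-clock time, not per-sample. If a rate is wanted instead, a rough tokens-per-second estimate could follow the final print; this is a sketch assuming the diff's num_samples = 10 and max_new_tokens = 500:

# Hypothetical follow-up, not in the gist: turn the wall-clock time into throughput.
elapsed = end_time - start_time
total_new_tokens = num_samples * max_new_tokens  # 10 * 500 = 5000 generated tokens
print(f'Throughput: {total_new_tokens / elapsed:.1f} tokens/sec')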