@moyix
Created August 29, 2022
#!/usr/bin/env python
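# Resize the context window of Salesforce/codegen-350M-mono from its
# pretrained 2048 positions to 4096. CodeGen uses rotary position
# embeddings, so there are no learned positional weights to grow; the
# only checkpoint tensors tied to the context length are the
# precomputed causal-mask buffers, which are dropped below and
# re-created at the new size when the model is instantiated.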
import torch
from transformers import CodeGenConfig, CodeGenForCausalLM, CodeGenTokenizer
from transformers.utils.hub import cached_file
NEW_SIZE = 4096

# Load the pretrained config and enlarge the maximum context length.
cg_config = CodeGenConfig.from_pretrained('Salesforce/codegen-350M-mono')
cg_config.n_ctx = NEW_SIZE
cg_config.n_positions = NEW_SIZE
weights_file = cached_file('Salesforce/codegen-350M-mono', 'pytorch_model.bin')
state_dict = torch.load(weights_file, map_location='cpu')

# Remove the causal-mask buffers from the state dict: their shape is
# tied to the original 2048-token context, so they would fail the size
# check when loaded into the resized model. With strict=False, the
# model keeps its freshly initialized 4096-token masks instead.
for k in list(state_dict.keys()):
    if k.endswith('causal_mask'):
        del state_dict[k]

model = CodeGenForCausalLM(cg_config)
model.load_state_dict(state_dict, strict=False)
model.cuda()
model.eval()
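# Optional sanity check; the attribute path below is an assumption about
# how this transformers version lays out CodeGen's attention blocks:
# assert model.transformer.h[0].attn.causal_mask.shape[-1] == NEW_SIZE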
# Try to generate something long enough to exercise the extended context.
prompt = 'def hello_world(name):\n print('
tokenizer = CodeGenTokenizer.from_pretrained('Salesforce/codegen-350M-mono')
enc = tokenizer.encode(prompt, return_tensors='pt')
enc = enc.to(torch.device('cuda'))
# Note: len(prompt) counts characters rather than tokens, which
# over-estimates the prompt length and keeps the total safely inside
# the 4096-token window; setting min_length == max_length forces
# generation well past the old 2048-token limit.
out = model.generate(enc,
                     do_sample=True,
                     max_length=NEW_SIZE - (len(prompt) + 1),
                     min_length=NEW_SIZE - (len(prompt) + 1),
                     num_return_sequences=1,
                     )
print(tokenizer.decode(out[0]))
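To reuse the resized model without repeating this surgery, the standard Hugging Face serialization calls should work; a minimal sketch, assuming an arbitrary output directory name:

# Save the resized config, weights, and tokenizer so later runs can
# load the 4096-position model directly. 'codegen-350M-mono-4k' is a
# hypothetical directory name.
model.save_pretrained('codegen-350M-mono-4k')
tokenizer.save_pretrained('codegen-350M-mono-4k')
# Reload later with:
# model = CodeGenForCausalLM.from_pretrained('codegen-350M-mono-4k')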