Skip to content

Instantly share code, notes, and snippets.

@zklhp
Created April 11, 2023 01:40
Show Gist options
  • Save zklhp/a60c4501060383d1cb99b4b6e24109d1 to your computer and use it in GitHub Desktop.
Use rwkv.cpp to talk with Raven model
# python -i chat_with_Raven.py rwkv.cpp-14B-Q4_1_O.bin
import os
import argparse
import pathlib
import sampling
import tokenizers
import rwkv_cpp_model
import rwkv_cpp_shared_library
# Prefix echoed to the terminal before streaming the model's reply; matches
# the "# Response: " section header used in the prompts below.
bot_message_prefix: str = '# Response: '
# Hard cap on the number of tokens generated for a single reply.
max_tokens_per_generation: int = 4096
def generate_prompt(instruction, input=None):
    """Return the full Alpaca-style prompt that opens a new dialog.

    The preamble sentence mentions the input section only when a truthy
    ``input`` is given.  (The parameter name shadows the ``input`` builtin
    but is kept for interface compatibility.)
    """
    preamble = (
        'Below is an instruction that describes a task'
        + (', paired with an input that provides further context' if input else '')
        + '. Write a response that appropriately completes the request.'
    )
    sections = [preamble, f'# Instruction:\n{instruction}']
    if input:
        sections.append(f'# Input:\n{input}')
    sections.append('# Response:\n')
    return '\n'.join(sections)
def generate_next(instruction, input=None):
    """Return the follow-up prompt for an ongoing dialog.

    Same sections as ``generate_prompt`` but without the preamble sentence:
    the model's recurrent state already carries the conversation context.
    """
    sections = [f'# Instruction:\n{instruction}']
    if input:
        sections.append(f'# Input:\n{input}')
    sections.append('# Response:\n')
    return '\n'.join(sections)
# Conversation state shared across evaluate() calls: `state` is the RWKV
# recurrent state, `logits` the last output distribution.  Both reset to
# None to start a fresh dialog (see the '+' command below).
logits, state = None, None
def evaluate(
    prompt,
    temperature=1.0,
    top_p=0.7,
):
    """Feed ``prompt`` through the model and stream the sampled reply to stdout.

    Updates the module-level ``logits``/``state`` so the conversation
    context persists across calls.  Generation stops at token 0
    (end-of-text) or after ``max_tokens_per_generation`` tokens.
    """
    global logits, state

    # Ingest the prompt; `state` is passed as both the input and the output
    # buffer, so the library updates it in place.
    for prompt_token in tokenizer.encode(prompt).ids:
        logits, state = model.eval(prompt_token, state, state, logits)

    print(bot_message_prefix, end='', flush=True)
    produced = 0
    while produced < max_tokens_per_generation:
        next_token = sampling.sample_logits(logits, temperature, top_p)
        if next_token == 0:  # end-of-text token terminates the reply
            break
        print(tokenizer.decode([next_token]), end='', flush=True)
        logits, state = model.eval(next_token, state, state, logits)
        produced += 1
    print()
# --- Command-line arguments ------------------------------------------------
parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()
# --- Tokenizer: the 20B tokenizer JSON must sit next to this script --------
print('Loading 20B tokenizer')
tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))
# --- Native rwkv.cpp shared library and ggml model -------------------------
library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')
print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)
# --- Usage banner ----------------------------------------------------------
print('\nChat initialized! Write something and press Enter.')
print('- Use \'+\' to start a new dialog.')
print('- To fill the input, use \'\\\' at the end of line.\n')
# --- Interactive chat driver -----------------------------------------------
# The first turn uses the full Alpaca-style preamble (generate_prompt); later
# turns send only the instruction/input sections (generate_next) because the
# model state already carries the conversation context.

def _read_nonempty(marker='> '):
    """Prompt until the user enters a non-empty line and return it.

    Replaces the original ``assert text != ''`` validation, which is
    stripped under ``python -O`` and crashed the script on empty input
    instead of re-prompting.
    """
    while True:
        text = input(marker)
        if text:
            return text
        print('Prompt must not be empty', flush=True)

user_instruction = _read_nonempty()
if user_instruction.endswith('\\'):
    # Trailing backslash: the next line becomes the "# Input:" section.
    # NOTE(review): the backslash itself stays in the instruction text, as
    # in the original script — confirm this is intended.
    user_input = _read_nonempty()
    evaluate(generate_prompt(user_instruction, user_input))
else:
    evaluate(generate_prompt(user_instruction))

while True:
    user_instruction = _read_nonempty()
    if user_instruction.endswith('\\'):
        user_input = _read_nonempty()
        evaluate(generate_next(user_instruction, user_input))
    elif user_instruction[0] == '+':
        # '+' starts a fresh dialog: drop the accumulated model state.
        logits, state = None, None
        print('Open a new dialog.', flush=True)
        evaluate(generate_prompt(user_instruction[1:]))
    else:
        evaluate(generate_next(user_instruction))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment