Created
April 11, 2023 01:40
-
-
Save zklhp/a60c4501060383d1cb99b4b6e24109d1 to your computer and use it in GitHub Desktop.
Use rwkv.cpp to talk with the Raven model
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Usage: python -i chat_with_Raven.py rwkv.cpp-14B-Q4_1_O.bin
import os | |
import argparse | |
import pathlib | |
import sampling | |
import tokenizers | |
import rwkv_cpp_model | |
import rwkv_cpp_shared_library | |
# Label printed in the terminal right before each generated reply.
bot_message_prefix: str = '# Response: '
# Upper bound on the number of tokens sampled for a single reply.
max_tokens_per_generation: int = 4096
def generate_prompt(instruction, input=None):
    """Build the opening prompt of a dialog in instruction/response format.

    The prompt consists of a one-sentence preamble, an ``# Instruction:``
    section, an optional ``# Input:`` section and an open ``# Response:``
    header left for the model to complete.

    Args:
        instruction: Text placed under the ``# Instruction:`` header.
        input: Optional extra context placed under ``# Input:``. Any falsy
            value (``None`` or an empty string) omits the section and selects
            the shorter preamble. The name shadows the builtin but is kept so
            that keyword callers keep working.

    Returns:
        The prompt string, ending with ``# Response:`` and a newline.
    """
    # NOTE(review): sections are separated by single newlines, as in the
    # visible source; confirm no blank lines are expected by the model.
    if input:
        return (
            "Below is an instruction that describes a task, paired with an "
            "input that provides further context. Write a response that "
            "appropriately completes the request.\n"
            f"# Instruction:\n{instruction}\n"
            f"# Input:\n{input}\n"
            "# Response:\n"
        )
    return (
        "Below is an instruction that describes a task. Write a response "
        "that appropriately completes the request.\n"
        f"# Instruction:\n{instruction}\n"
        "# Response:\n"
    )
def generate_next(instruction, input=None):
    """Build the prompt for a follow-up turn of an ongoing dialog.

    Produces the same sections as the opening prompt but without the
    preamble sentence, so it can be appended to an existing context.

    Args:
        instruction: Text placed under the ``# Instruction:`` header.
        input: Optional extra context placed under ``# Input:``. Any falsy
            value (``None`` or an empty string) omits the section. The name
            shadows the builtin but is kept so that keyword callers keep
            working.

    Returns:
        The prompt string, ending with ``# Response:`` and a newline.
    """
    sections = [f"# Instruction:\n{instruction}\n"]
    if input:
        sections.append(f"# Input:\n{input}\n")
    sections.append("# Response:\n")
    return "".join(sections)
# Latest model output and recurrent state. They are module-level so that
# consecutive calls to evaluate() continue the same dialog; resetting both to
# None starts from an empty context.
logits, state = None, None


def evaluate(
    prompt,
    temperature=1.0,
    top_p=0.7,
):
    """Feed ``prompt`` to the model and stream the sampled reply to stdout.

    Reads and updates the module-level ``logits`` and ``state``, so the
    context accumulated by earlier calls is preserved. Relies on the
    module-level ``tokenizer``, ``model``, ``sampling``,
    ``bot_message_prefix`` and ``max_tokens_per_generation``.

    Args:
        prompt: Text to encode and evaluate before sampling starts.
        temperature: Passed through to ``sampling.sample_logits``.
        top_p: Passed through to ``sampling.sample_logits``.

    Returns:
        None. The reply is printed as it is generated, followed by a newline.
    """
    global logits, state

    # Consume the prompt token by token. The current state and logits are
    # handed back to model.eval on every step (presumably so the model can
    # reuse them as output buffers; confirm against rwkv_cpp_model).
    for token in tokenizer.encode(prompt).ids:
        logits, state = model.eval(token, state, state, logits)

    print(bot_message_prefix, end='', flush=True)

    # A single character can be spread over several tokens; decoding such a
    # token on its own yields U+FFFD. Tokens are therefore buffered until the
    # buffer decodes cleanly. The cap keeps output flowing if the replacement
    # character is genuinely part of the generated text.
    max_pending = 8
    pending = []
    for _ in range(max_tokens_per_generation):
        token = sampling.sample_logits(logits, temperature, top_p)
        # Token 0 ends the reply.
        if token == 0:
            break
        pending.append(token)
        text = tokenizer.decode(pending)
        if '\ufffd' not in text or len(pending) >= max_pending:
            print(text, end='', flush=True)
            pending.clear()
        logits, state = model.eval(token, state, state, logits)

    # Emit whatever is still buffered so no generated text is dropped.
    if pending:
        print(tokenizer.decode(pending), end='', flush=True)
    print()
# Command line: the only argument is the path to the model file.
parser = argparse.ArgumentParser(description='Provide terminal-based chat interface for RWKV model')
parser.add_argument('model_path', help='Path to RWKV model in ggml format')
args = parser.parse_args()

# The tokenizer definition is looked up in the directory of this script.
print('Loading 20B tokenizer')
tokenizer_path = pathlib.Path(os.path.abspath(__file__)).parent / '20B_tokenizer.json'
tokenizer = tokenizers.Tokenizer.from_file(str(tokenizer_path))

library = rwkv_cpp_shared_library.load_rwkv_shared_library()
print(f'System info: {library.rwkv_get_system_info_string()}')

print('Loading RWKV model')
model = rwkv_cpp_model.RWKVModel(library, args.model_path)

# Usage banner shown once before the first prompt.
print('\nChat initialized! Write something and press Enter.')
print('- Use \'+\' to start a new dialog.')
print('- To fill the input, use \'\\\' at the end of line.\n')
# Interactive loop. The first turn of a dialog is sent with the full prompt
# (including the preamble); later turns only append a new instruction block
# to the context already held in `state`.
new_dialog = True
while True:
    user_instruction = input('> ')

    # A leading '+' discards the accumulated context and starts over.
    if user_instruction.startswith('+'):
        logits, state = None, None
        new_dialog = True
        print('Open a new dialog.', flush=True)
        user_instruction = user_instruction[1:]

    # A trailing backslash means a separate "input" line follows. The marker
    # itself is a control character and is not sent to the model.
    wants_input = user_instruction.endswith('\\')
    if wants_input:
        user_instruction = user_instruction[:-1]

    # Ask again instead of aborting the whole session on an empty line.
    if user_instruction == '':
        print('Prompt must not be empty', flush=True)
        continue

    user_input = None
    if wants_input:
        user_input = input('> ')
        if user_input == '':
            print('Prompt must not be empty', flush=True)
            continue

    build_prompt = generate_prompt if new_dialog else generate_next
    evaluate(build_prompt(user_instruction, user_input))
    new_dialog = False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment