HuggingFace Inference API Handler for RWKV-4-World-7B-v1-OnlyForTest_84%_trained-20230618-ctx4096.pth
########################################################################################################
# The RWKV Language Model - https://github.com/BlinkDL/RWKV-LM
########################################################################################################
# import os
# import copy
import os, gc
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1'  # if '1' then use CUDA kernel for seq mode (much faster)
from typing import Dict, List, Any
import torch
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo

nvmlInit()
gpu_h1 = nvmlDeviceGetHandleByIndex(0)
# gpu_h2 = nvmlDeviceGetHandleByIndex(1)
# gpu_h3 = nvmlDeviceGetHandleByIndex(2)
# gpu_h4 = nvmlDeviceGetHandleByIndex(3)

ctx_limit = 4096
# MODEL_NAME = '/repository/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth'
MODEL_NAME = '/repository/RWKV-4-World-7B-v1-OnlyForTest_84%_trained-20230618-ctx4096.pth'

torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
class EndpointHandler():
    def __init__(self, path=""):
        # load the model
        self.model = RWKV(model=MODEL_NAME, strategy='cuda fp16')
        # create inference pipeline
        self.pipeline = PIPELINE(self.model, "rwkv_vocab_v20230424")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        predictions = []
        inputs = data.pop("inputs", [])
        for inputValue in inputs:
            context = self.generate_prompt(
                self.sanitize_user_input(inputValue.get('instruction', 'Suggest an instruction here.')),
                self.sanitize_user_input(inputValue.get('input', None)),
            )
            predictions.append({
                'response': self.evaluate(context)
            })
        return predictions

    def evaluate(
        self,
        inputContext,
        token_count=800,
        temperature=1.0,
        top_p=0.7,
        presencePenalty=0.1,
        countPenalty=0.1
    ):
        args = PIPELINE_ARGS(
            temperature=max(0.2, float(temperature)),
            top_p=float(top_p),
            alpha_frequency=countPenalty,
            alpha_presence=presencePenalty,
            token_ban=[],   # ban the generation of some tokens
            token_stop=[0]  # stop generation whenever you see any token here
        )
        ctx = inputContext
        all_tokens = []
        out_last = 0
        out_str = ''
        occurrence = {}
        state = None
        for i in range(int(token_count)):
            # first step: feed the whole prompt (truncated to the last ctx_limit tokens);
            # later steps: feed only the token sampled on the previous step
            out, state = self.model.forward(self.pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token], state)
            # penalize tokens that have already been generated (presence + frequency)
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)
            token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
            if token in args.token_stop:
                break
            all_tokens += [token]
            if token not in occurrence:
                occurrence[token] = 1
            else:
                occurrence[token] += 1
            # only append decoded text once it is valid UTF-8 (no replacement character),
            # so multi-byte characters are not split across steps
            tmp = self.pipeline.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                out_str += tmp
                out_last = i + 1
        gpu_info1 = nvmlDeviceGetMemoryInfo(gpu_h1)
        print(f'vram {gpu_info1.total} used {gpu_info1.used} free {gpu_info1.free}')
        # gpu_info2 = nvmlDeviceGetMemoryInfo(gpu_h2)
        # print(f'vram {gpu_info2.total} used {gpu_info2.used} free {gpu_info2.free}')
        # gpu_info3 = nvmlDeviceGetMemoryInfo(gpu_h3)
        # print(f'vram {gpu_info3.total} used {gpu_info3.used} free {gpu_info3.free}')
        # gpu_info4 = nvmlDeviceGetMemoryInfo(gpu_h4)
        # print(f'vram {gpu_info4.total} used {gpu_info4.used} free {gpu_info4.free}')
        del out
        del state
        gc.collect()
        torch.cuda.empty_cache()
        return out_str.strip()

    def sanitize_user_input(self, user_input: str) -> str:
        # normalize line endings and collapse blank lines; returns None for
        # missing or non-string values so generate_prompt can skip them
        if user_input and isinstance(user_input, str):
            return user_input.strip().replace('\r\n', '\n').replace('\n\n', '\n')
        return None

    def generate_prompt(self, instruction, inputValue=None) -> str:
        instruction = instruction.strip().replace('\r\n', '\n').replace('\n\n', '\n')
        if inputValue and isinstance(inputValue, str):
            inputValue = inputValue.strip().replace('\r\n', '\n').replace('\n\n', '\n')
            return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

# Instruction:
{instruction}

# Input:
{inputValue}

# Response:
"""
        else:
            return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.

# Instruction:
{instruction}

# Response:
"""