Skip to content

Instantly share code, notes, and snippets.

Created June 25, 2023 16:06
Show Gist options
Save conradlz/8cd6d9d939390f8726d32d303fc9cb4f to your computer and use it in GitHub Desktop.
HuggingFace Inference API Handler for RWKV-4-World-7B-v1-OnlyForTest_84%_trained-20230618-ctx4096.pth
# The RWKV Language Model -
# import os
# import copy
import os, gc
# os.environ["CUDA_VISIBLE_DEVICES"] = "0,1,2,3"
os.environ["RWKV_JIT_ON"] = '1'
os.environ["RWKV_CUDA_ON"] = '1' # if '1' then use CUDA kernel for seq mode (much faster)
from typing import Dict, List, Any
import torch
from rwkv.model import RWKV
from rwkv.utils import PIPELINE, PIPELINE_ARGS
from pynvml import nvmlInit, nvmlDeviceGetHandleByIndex, nvmlDeviceGetMemoryInfo
# NVML has to be initialised before any other pynvml call; without this,
# nvmlDeviceGetHandleByIndex raises NVMLError_Uninitialized at import time.
nvmlInit()
# Handle for GPU 0, used by EndpointHandler.evaluate to log VRAM usage.
gpu_h1 = nvmlDeviceGetHandleByIndex(0)
# Additional handles for a multi-GPU setup (see CUDA_VISIBLE_DEVICES above):
# gpu_h2 = nvmlDeviceGetHandleByIndex(1)
# gpu_h3 = nvmlDeviceGetHandleByIndex(2)
# gpu_h4 = nvmlDeviceGetHandleByIndex(3)

# Maximum number of prompt tokens fed to the model; longer prompts keep only
# their last ctx_limit tokens. Matches the "ctx4096" in the checkpoint name.
ctx_limit = 4096
# MODEL_NAME = '/repository/RWKV-4-World-0.1B-v1-20230520-ctx4096.pth'
MODEL_NAME = '/repository/RWKV-4-World-7B-v1-OnlyForTest_84%_trained-20230618-ctx4096.pth'

# Let cuDNN autotune kernels and allow TF32 matmuls for faster inference.
torch.backends.cudnn.benchmark = True
torch.backends.cudnn.allow_tf32 = True
torch.backends.cuda.matmul.allow_tf32 = True
class EndpointHandler():
    """Inference handler that serves the RWKV model behind an HTTP endpoint.

    The instance is called with a request payload of the form
    ``{"inputs": [{"instruction": str, "input": str}, ...]}`` and returns one
    ``{"response": str}`` dict per input item.
    NOTE(review): presumably constructed once by the Hugging Face Inference
    Endpoints runtime (per the gist title) — confirm.
    """

    def __init__(self, path=""):
        """Load the model and tokenizer pipeline.

        Args:
            path: accepted for interface compatibility but unused; the weights
                are loaded from the module-level ``MODEL_NAME``.
        """
        # load the model
        self.model = RWKV(model=MODEL_NAME, strategy='cuda fp16')
        # create inference pipeline (tokenizer + sampling helpers)
        self.pipeline = PIPELINE(self.model, "rwkv_vocab_v20230424")

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        """Generate a response for every item in ``data["inputs"]``.

        Each item may carry an ``'instruction'`` (defaults to a placeholder
        instruction when the key is absent) and an optional ``'input'``.
        The caller's ``data`` dict is not modified.

        Returns:
            A list of ``{'response': str}`` dicts, in input order.
        """
        predictions = []
        # .get instead of .pop so the caller's payload is left intact.
        inputs = data.get("inputs") or []
        for inputValue in inputs:
            context = self.generate_prompt(
                self.sanitize_user_input(inputValue.get('instruction', 'Suggest an instruction here.')),
                self.sanitize_user_input(inputValue.get('input', None)),
            )
            predictions.append({'response': self.evaluate(context)})
        return predictions

    def evaluate(
        self,
        inputContext: str,
        # NOTE(review): the original defaults for token_count/temperature/top_p
        # were lost from this copy; these mirror the upstream RWKV demo — confirm.
        token_count=200,
        temperature=1.0,
        top_p=0.7,
        presencePenalty = 0.1,
        countPenalty = 0.1,
    ) -> str:
        """Sample up to ``token_count`` tokens continuing ``inputContext``.

        Args:
            inputContext: full prompt text; only its last ``ctx_limit`` tokens
                are fed to the model.
            token_count: maximum number of tokens to generate.
            temperature: sampling temperature, clamped to at least 0.2.
            top_p: nucleus-sampling threshold.
            presencePenalty: flat logit penalty for any already-emitted token.
            countPenalty: extra logit penalty per previous occurrence.

        Returns:
            The generated text, stripped of surrounding whitespace.
        """
        args = PIPELINE_ARGS(
            temperature=max(0.2, float(temperature)),
            top_p=float(top_p),
            alpha_frequency=countPenalty,
            alpha_presence=presencePenalty,
            token_ban=[],  # ban the generation of some tokens
            token_stop=[0],  # stop generation whenever you see any token here
        )

        ctx = inputContext
        all_tokens = []
        out_last = 0  # index into all_tokens of the first not-yet-decoded token
        out_str = ''
        occurrence = {}  # token id -> number of times it has been generated
        state = None
        out = None  # pre-bound so the cleanup below is safe when token_count <= 0

        for i in range(int(token_count)):
            # First step consumes the (truncated) prompt; later steps feed back
            # only the previously sampled token together with the RNN state.
            tokens = self.pipeline.encode(ctx)[-ctx_limit:] if i == 0 else [token]
            out, state = self.model.forward(tokens, state)

            for n in args.token_ban:
                out[n] = -float('inf')
            for n in occurrence:
                out[n] -= (args.alpha_presence + occurrence[n] * args.alpha_frequency)

            token = self.pipeline.sample_logits(out, temperature=args.temperature, top_p=args.top_p)
            if token in args.token_stop:
                break
            all_tokens.append(token)
            # Count starts at 1 on first sighting (previously it jumped to 2).
            occurrence[token] = occurrence.get(token, 0) + 1

            # Multi-byte characters can span several tokens; only flush once the
            # pending tokens decode without U+FFFD replacement characters.
            tmp = self.pipeline.decode(all_tokens[out_last:])
            if '\ufffd' not in tmp:
                out_str += tmp
                out_last = len(all_tokens)

        gpu_info1 = nvmlDeviceGetMemoryInfo(gpu_h1)
        print(f'vram {gpu_info1.total} used {gpu_info1.used} free {gpu_info1.free}')
        # gpu_info2 = nvmlDeviceGetMemoryInfo(gpu_h2)
        # print(f'vram {gpu_info2.total} used {gpu_info2.used} free {gpu_info2.free}')

        # Drop the logits/state tensors and hand cached VRAM back between requests.
        del out
        del state
        gc.collect()
        torch.cuda.empty_cache()
        return out_str.strip()

    def sanitize_user_input(self, user_input: str) -> str:
        """Normalise user text: trim, convert CRLF to LF, remove blank lines.

        Returns ``''`` for empty or non-string input so callers always get a str.
        """
        if not user_input or not isinstance(user_input, str):
            return ''
        text = user_input.strip().replace('\r\n', '\n')
        # Loop until stable: a single replace() pass leaves "\n\n" behind
        # for runs of three or more newlines.
        while '\n\n' in text:
            text = text.replace('\n\n', '\n')
        return text

    def generate_prompt(self, instruction, inputValue=None) -> str:
        """Build the instruction-following prompt fed to the model.

        The "# Input:" section is included only when ``inputValue`` is a
        non-empty string after sanitising.
        NOTE(review): the placeholder and blank-line layout of the template were
        lost from this copy and are reconstructed here — confirm against the
        format the model was tuned on.
        """
        instruction = self.sanitize_user_input(instruction)
        inputValue = self.sanitize_user_input(inputValue)
        if inputValue:
            return (
                "Below is an instruction that describes a task, paired with an input that provides further context. "
                "Write a response that appropriately completes the request.\n"
                "\n"
                f"# Instruction:\n{instruction}\n"
                "\n"
                f"# Input:\n{inputValue}\n"
                "\n"
                "# Response:\n"
            )
        return (
            "Below is an instruction that describes a task. "
            "Write a response that appropriately completes the request.\n"
            "\n"
            f"# Instruction:\n{instruction}\n"
            "\n"
            "# Response:\n"
        )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment