Created November 7, 2022 16:26
Save grahamannett/5c46135ca36001ae95c519a2426a7603 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from concurrent.futures import ThreadPoolExecutor
import os
import time

import cohere
import torch

# The API key is read from the COHERE_KEY environment variable.
co = cohere.Client(os.environ.get("COHERE_KEY"), check_api_key=False)

# NOTE(review): torch.load may unpickle the file contents; only load trusted files.
zipped_args = torch.load("zipped_args.pt")
# Benchmark only the first 16 (option, prompt, return_likelihood) tuples.
zipped_args = zipped_args[0:16]

# Baseline: a single sequential request through the SDK.
option, prompt, return_likelihood = zipped_args[0]
t1 = time.perf_counter()
response = co.generate(prompt=prompt, max_tokens=0, model="xlarge", return_likelihoods=return_likelihood)
# Stop the clock before printing so the measurement covers only the request.
t2 = time.perf_counter()
print(response.generations[0].likelihood)
print(f"Time elapsed: {t2-t1}")
def _func(args):
    """Send one generate request and return the first generation's likelihood.

    ``args`` is an ``(option, prompt, return_likelihood)`` tuple. ``option`` is
    ignored. The request asks the "xlarge" model for ``max_tokens=0`` and
    forwards ``return_likelihood`` as ``return_likelihoods``.
    """
    _option, text, likelihood_mode = args
    result = co.generate(
        prompt=text,
        max_tokens=0,
        model="xlarge",
        return_likelihoods=likelihood_mode,
    )
    first_generation = result.generations[0]
    return first_generation.likelihood
t3 = time.perf_counter()
with ThreadPoolExecutor(32) as thread_pool:
    # Executor.map only raises a worker's exception when that result is read.
    # Consuming the iterator here surfaces failed requests (e.g. HTTP 400s).
    # Otherwise they would be silently discarded and timed as successes.
    _lh = list(thread_pool.map(_func, zipped_args))
t4 = time.perf_counter()
print(f"Time elapsed with ThreadPool Exec: {t4-t3}")
# Pause before the next benchmark; without it the Cohere API appeared to return HTTP 400.
time.sleep(3)
import json | |
import requests | |
from urllib.parse import urljoin | |
from typing import List, Dict | |
import sys | |
# Base URL of the Cohere REST API and the relative path of the generate endpoint.
COHERE_API_URL = "https://api.cohere.ai"
GENERATE_URL = "generate"
# Base URL that the request helpers join endpoint paths onto.
api_url = COHERE_API_URL
# NOTE: urljoin replaces the last path segment of the base URL. This yields
# ".../generate" only because COHERE_API_URL has no path component.
url = urljoin(COHERE_API_URL, GENERATE_URL)
def __request(json_body, endpoint):
    """POST a pre-serialized JSON body to a Cohere endpoint and decode the reply.

    Headers are built from the module-level client ``co``: its API key and, if
    non-empty, its ``cohere_version``. Extra keyword arguments from
    ``co.request_dict`` are forwarded to ``requests.request``.

    Args:
        json_body: Request payload, already serialized to a JSON string.
        endpoint: Endpoint path joined onto ``api_url`` (e.g. ``GENERATE_URL``).

    Returns:
        The parsed JSON response. API error payloads are returned as-is; this
        helper does not inspect them.

    Raises:
        Exception: If the response body is not valid JSON (the decode error is
            chained as the cause).
    """
    headers = {
        "Authorization": "BEARER {}".format(co.api_key),
        "Content-Type": "application/json",
        "Request-Source": "python-sdk",
    }
    if co.cohere_version != "":
        headers["Cohere-Version"] = co.cohere_version
    url = urljoin(api_url, endpoint)
    response = requests.request("POST", url, headers=headers, data=json_body, **co.request_dict)
    try:
        res = json.loads(response.text)
    except ValueError as err:  # json.JSONDecodeError is a subclass of ValueError
        raise Exception("Invalid response from server: {}".format(response.text)) from err
    return res
async def a_request(json_body, endpoint):
    """Asynchronously POST a pre-serialized JSON body to a Cohere endpoint.

    This is the async counterpart of ``__request``. It builds the same headers
    from the module-level client ``co`` and sends the body with
    ``httpx.AsyncClient``.

    Args:
        json_body: Request payload, already serialized to a JSON string.
        endpoint: Endpoint path joined onto ``api_url`` (e.g. ``GENERATE_URL``).

    Returns:
        The parsed JSON response. API error payloads are returned as-is.

    Raises:
        Exception: If the response body is not valid JSON (the decode error is
            chained as the cause).
    """
    # Local import: the module-level ``import httpx`` appears later in the file.
    import httpx

    headers = {
        "Authorization": "BEARER {}".format(co.api_key),
        "Content-Type": "application/json",
        "Request-Source": "python-sdk",
    }
    if co.cohere_version != "":
        headers["Cohere-Version"] = co.cohere_version
    url = urljoin(api_url, endpoint)
    # NOTE(review): the sync path forwards co.request_dict to ``requests``.
    # Those options are not forwarded here because their keys may not be valid
    # httpx arguments.
    async with httpx.AsyncClient() as client:
        # ``content=`` sends the already-serialized string body unchanged.
        response = await client.post(url, headers=headers, content=json_body)
    try:
        return json.loads(response.text)
    except ValueError as err:  # json.JSONDecodeError is a subclass of ValueError
        raise Exception("Invalid response from server: {}".format(response.text)) from err
def get_body(
    prompt: str = None,
    prompt_vars: object = None,
    model: str = None,
    preset: str = None,
    num_generations: int = 1,
    max_tokens: int = None,
    temperature: float = 1.0,
    k: int = 0,
    p: float = 0.75,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    stop_sequences: List[str] = None,
    return_likelihoods: str = "NONE",
    truncate: str = None,
    logit_bias: Dict[int, float] = None,
):
    """Build the JSON-serializable request body for a Cohere generate call.

    Every argument is copied into the returned dict under the key of the same
    name. No validation is performed.

    Args:
        prompt_vars: Sent as a fresh ``{}`` when None (the default).
        logit_bias: Sent as a fresh ``{}`` when None (the default).
        All other arguments are passed through unchanged.

    Returns:
        A new dict with the keys model, prompt, prompt_vars, preset,
        num_generations, max_tokens, temperature, k, p, frequency_penalty,
        presence_penalty, stop_sequences, return_likelihoods, truncate and
        logit_bias, in that order.
    """
    # Create a new dict per call. A shared ``{}`` default would be aliased
    # into every returned body, so mutating one body would leak into later calls.
    if prompt_vars is None:
        prompt_vars = {}
    if logit_bias is None:
        logit_bias = {}
    return {
        "model": model,
        "prompt": prompt,
        "prompt_vars": prompt_vars,
        "preset": preset,
        "num_generations": num_generations,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "k": k,
        "p": p,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stop_sequences": stop_sequences,
        "return_likelihoods": return_likelihoods,
        "truncate": truncate,
        "logit_bias": logit_bias,
    }
def generate(
    prompt: str = None,
    prompt_vars: object = {},
    model: str = None,
    preset: str = None,
    num_generations: int = 1,
    max_tokens: int = None,
    temperature: float = 1.0,
    k: int = 0,
    p: float = 0.75,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    stop_sequences: List[str] = None,
    return_likelihoods: str = "NONE",
    truncate: str = None,
    logit_bias: Dict[int, float] = {},
):
    """Send a synchronous generate request and return the decoded JSON reply.

    Every argument is placed unchanged into the request body (built by
    ``get_body``). The body is serialized with ``json.dumps`` and POSTed to
    ``GENERATE_URL`` via ``__request``.
    """
    payload = get_body(
        prompt=prompt,
        prompt_vars=prompt_vars,
        model=model,
        preset=preset,
        num_generations=num_generations,
        max_tokens=max_tokens,
        temperature=temperature,
        k=k,
        p=p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        stop_sequences=stop_sequences,
        return_likelihoods=return_likelihoods,
        truncate=truncate,
        logit_bias=logit_bias,
    )
    serialized = json.dumps(payload)
    return __request(serialized, GENERATE_URL)
import httpx | |
import asyncio | |
async def main():
    """Time concurrent generate requests for every entry of ``zipped_args`` with httpx.

    Builds one request body per ``(option, prompt, return_likelihood)`` tuple,
    sends them all at once with ``asyncio.gather`` to the module-level ``url``,
    and prints the elapsed wall-clock time. Responses with an HTTP error status
    are counted and reported. They still complete (often quickly), so without
    this report they would silently distort the timing.
    """
    headers = {
        "Authorization": "BEARER {}".format(co.api_key),
        "Content-Type": "application/json",
        "Request-Source": "python-sdk",
    }
    if co.cohere_version != "":
        headers["Cohere-Version"] = co.cohere_version
    bodies = [
        get_body(
            prompt=prompt,
            max_tokens=0,
            model="xlarge",
            return_likelihoods=return_likelihood,
        )
        for (_option, prompt, return_likelihood) in zipped_args
    ]
    t3 = time.perf_counter()
    # A single client scoped by ``async with`` so its connections are closed.
    async with httpx.AsyncClient() as c:
        out = await asyncio.gather(*[c.post(url, headers=headers, json=body) for body in bodies])
    t4 = time.perf_counter()
    print(f"Time elapsed for async: {t4-t3}")
    failed = [r.status_code for r in out if r.status_code >= 400]
    if failed:
        print(f"{len(failed)}/{len(out)} async requests failed with status codes: {failed}")


asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.