Skip to content

Instantly share code, notes, and snippets.

@grahamannett
Created November 7, 2022 16:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save grahamannett/5c46135ca36001ae95c519a2426a7603 to your computer and use it in GitHub Desktop.
Save grahamannett/5c46135ca36001ae95c519a2426a7603 to your computer and use it in GitHub Desktop.
from concurrent.futures import ThreadPoolExecutor
import cohere
import os
# Shared Cohere client used by every benchmark below. The key comes from the
# COHERE_KEY environment variable (None if unset; key validation is skipped).
co = cohere.Client(os.environ.get("COHERE_KEY"), check_api_key=False)
import torch
import time

# NOTE: torch.load unpickles the file, which can execute arbitrary code —
# only point this at a trusted, locally produced file.
zipped_args = torch.load("zipped_args.pt")
# Benchmark only the first 16 (option, prompt, return_likelihood) tuples.
zipped_args = zipped_args[0:16]

# Baseline: a single blocking generate call for the first item.
option, prompt, return_likelihood = zipped_args[0]
t1 = time.perf_counter()
response = co.generate(prompt=prompt, max_tokens=0, model="xlarge", return_likelihoods=return_likelihood)
# Stop the clock before printing so console I/O is not counted in the
# measurement (the thread-pool and async timings below exclude it as well).
t2 = time.perf_counter()
print(response.generations[0].likelihood)
print(f"Time elapsed: {t2-t1}")
def _func(args):
    """Return the likelihood of the first generation for one benchmark item.

    ``args`` is unpacked as ``(option, prompt, return_likelihood)``; ``option``
    is not used. The request goes through the module-level client ``co`` with
    ``max_tokens=0`` and the ``"xlarge"`` model.
    """
    _unused_option, prompt_text, likelihood_mode = args
    reply = co.generate(
        prompt=prompt_text,
        max_tokens=0,
        model="xlarge",
        return_likelihoods=likelihood_mode,
    )
    first_generation = reply.generations[0]
    return first_generation.likelihood
# Benchmark: the same requests fanned out over a pool of 32 threads.
t3 = time.perf_counter()
with ThreadPoolExecutor(32) as thread_pool:
    # Materialize the results. Executor.map only re-raises a worker's exception
    # when that result is retrieved, so an unconsumed iterator silently hides
    # failed requests and makes the timing below misleading.
    _lh = list(thread_pool.map(_func, zipped_args))
t4 = time.perf_counter()
print(f"Time elapsed with ThreadPool Exec: {t4-t3}")
# Pause before the next benchmark: without this delay Cohere was observed to
# answer with HTTP 400 (cause unconfirmed — possibly rate limiting).
time.sleep(3)
import json
import requests
from urllib.parse import urljoin
from typing import List, Dict
import sys
# Root of the Cohere REST API and the relative path of the generate endpoint.
COHERE_API_URL = "https://api.cohere.ai"
GENERATE_URL = "generate"
api_url = COHERE_API_URL
# Absolute URL of the generate endpoint, built from the two constants above.
url = urljoin(api_url, GENERATE_URL)
def __request(json_body, endpoint):
    """POST a pre-serialized JSON body to a Cohere endpoint and return the decoded JSON.

    The headers mirror the SDK's: the bearer key and optional version are
    read from the module-level client ``co``. Any extra ``requests`` options
    come from ``co.request_dict``.

    Args:
        json_body: Request payload, already serialized to a JSON string.
        endpoint: Path resolved against ``api_url`` (e.g. ``GENERATE_URL``).

    Returns:
        The response body parsed with ``json.loads``. API error payloads that
        are still valid JSON are returned unchanged — callers must inspect them.

    Raises:
        Exception: If the response body is not valid JSON.
    """
    headers = {
        "Authorization": "BEARER {}".format(co.api_key),
        "Content-Type": "application/json",
        "Request-Source": "python-sdk",
    }
    if co.cohere_version != "":
        headers["Cohere-Version"] = co.cohere_version
    url = urljoin(api_url, endpoint)
    response = requests.request("POST", url, headers=headers, data=json_body, **co.request_dict)
    try:
        res = json.loads(response.text)
    except ValueError as err:  # json.JSONDecodeError is a ValueError subclass
        raise Exception("Invalid response from server: {}".format(response.text)) from err
    return res
async def a_request(json_body, endpoint):
    """Async counterpart of ``__request``: POST ``json_body`` and return the decoded JSON.

    Builds the same headers as ``__request`` from the module-level client ``co``
    and sends the request with a short-lived ``httpx.AsyncClient``.

    Args:
        json_body: Request payload, already serialized to a JSON string.
        endpoint: Path resolved against ``api_url`` (e.g. ``GENERATE_URL``).

    Returns:
        The response body parsed with ``json.loads``.

    Raises:
        Exception: If the response body is not valid JSON.
    """
    # Local import: the module-level ``import httpx`` appears further down the file.
    import httpx

    headers = {
        "Authorization": "BEARER {}".format(co.api_key),
        "Content-Type": "application/json",
        "Request-Source": "python-sdk",
    }
    if co.cohere_version != "":
        headers["Cohere-Version"] = co.cohere_version
    url = urljoin(api_url, endpoint)
    async with httpx.AsyncClient() as client:
        # ``content=`` sends the already-serialized JSON string as-is.
        response = await client.post(url, headers=headers, content=json_body)
    try:
        return json.loads(response.text)
    except ValueError as err:  # json.JSONDecodeError is a ValueError subclass
        raise Exception("Invalid response from server: {}".format(response.text)) from err
def get_body(
    prompt: str = None,
    prompt_vars: object = None,
    model: str = None,
    preset: str = None,
    num_generations: int = 1,
    max_tokens: int = None,
    temperature: float = 1.0,
    k: int = 0,
    p: float = 0.75,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    stop_sequences: List[str] = None,
    return_likelihoods: str = "NONE",
    truncate: str = None,
    logit_bias: Dict[int, float] = None,
):
    """Build the JSON-serializable request body for the ``generate`` endpoint.

    Each argument is copied into the body under the key of the same name.
    ``None`` values are kept as-is, except ``prompt_vars`` and ``logit_bias``.
    Those two default to a fresh empty dict on every call, and passing ``None``
    is treated the same as omitting them. This means returned bodies never
    share a mutable default object.

    Returns:
        A dict keyed by the API field names, in a fixed order.
    """
    body = {
        "model": model,
        "prompt": prompt,
        "prompt_vars": {} if prompt_vars is None else prompt_vars,
        "preset": preset,
        "num_generations": num_generations,
        "max_tokens": max_tokens,
        "temperature": temperature,
        "k": k,
        "p": p,
        "frequency_penalty": frequency_penalty,
        "presence_penalty": presence_penalty,
        "stop_sequences": stop_sequences,
        "return_likelihoods": return_likelihoods,
        "truncate": truncate,
        "logit_bias": {} if logit_bias is None else logit_bias,
    }
    return body
def generate(
    prompt: str = None,
    prompt_vars: object = None,
    model: str = None,
    preset: str = None,
    num_generations: int = 1,
    max_tokens: int = None,
    temperature: float = 1.0,
    k: int = 0,
    p: float = 0.75,
    frequency_penalty: float = 0.0,
    presence_penalty: float = 0.0,
    stop_sequences: List[str] = None,
    return_likelihoods: str = "NONE",
    truncate: str = None,
    logit_bias: Dict[int, float] = None,
):
    """Call the generate endpoint synchronously and return the decoded JSON response.

    The body is built by ``get_body`` (same keys, same order), serialized with
    ``json.dumps`` and sent through ``__request``. ``prompt_vars`` and
    ``logit_bias`` default to a fresh empty dict per call instead of a shared
    mutable default.
    """
    body = get_body(
        prompt=prompt,
        # Normalize here too so the JSON sent is ``{}`` rather than ``null``
        # for omitted dict arguments.
        prompt_vars={} if prompt_vars is None else prompt_vars,
        model=model,
        preset=preset,
        num_generations=num_generations,
        max_tokens=max_tokens,
        temperature=temperature,
        k=k,
        p=p,
        frequency_penalty=frequency_penalty,
        presence_penalty=presence_penalty,
        stop_sequences=stop_sequences,
        return_likelihoods=return_likelihoods,
        truncate=truncate,
        logit_bias={} if logit_bias is None else logit_bias,
    )
    return __request(json.dumps(body), GENERATE_URL)
import httpx
import asyncio
async def main():
    """Benchmark sending every item of ``zipped_args`` concurrently with httpx.

    Builds one generate body per ``(option, prompt, return_likelihood)`` tuple
    and POSTs them all at once over a single ``httpx.AsyncClient``. It then
    prints the elapsed wall-clock time and returns the list of responses
    (status codes are not checked here).
    """
    headers = {
        "Authorization": "BEARER {}".format(co.api_key),
        "Content-Type": "application/json",
        "Request-Source": "python-sdk",
    }
    if co.cohere_version != "":
        headers["Cohere-Version"] = co.cohere_version
    bodies = [
        get_body(
            prompt=prompt,
            max_tokens=0,
            model="xlarge",
            return_likelihoods=return_likelihood,
        )
        for _option, prompt, return_likelihood in zipped_args
    ]
    t3 = time.perf_counter()
    # A single client, closed by the context manager, shares its connection pool
    # across all requests.
    # NOTE(review): httpx's default timeout applies — confirm it is long enough.
    async with httpx.AsyncClient() as c:
        out = await asyncio.gather(*[c.post(url, headers=headers, json=body) for body in bodies])
    t4 = time.perf_counter()
    print(f"Time elapsed for async: {t4-t3}")
    return out


asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment