@brownan
Last active May 21, 2024 18:50
Ollama command line interface with Markdown rendering
#!/usr/bin/env python3
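"""Query one or more Ollama models from the command line and render the
streamed response as Markdown in the terminal, using rich.

The prompt is taken from the positional arguments, or read from stdin if none
are given. The Ollama server is located via the OLLAMA_HOST environment
variable, defaulting to http://localhost:11434.

Example invocations (assuming the gist is saved as "ollama.py" and made
executable; the file name is only illustrative):

    ./ollama.py "Why is the sky blue?"
    ./ollama.py --models llama3,mistral --stats "Why is the sky blue?"
    echo "Summarize this text" | ./ollama.py --raw --temperature 0 --seed 42
"""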
import argparse
import io
import json
import os
import sys
import urllib.parse
import urllib.request
from typing import Optional, NamedTuple
import rich.console
import rich.live
import rich.markdown
import rich.panel
import rich.spinner
import rich.text
import rich.table
import rich.pretty

# Generation endpoint on the Ollama server; OLLAMA_HOST overrides the default local address.
query_url = urllib.parse.urljoin(
    os.environ.get("OLLAMA_HOST", "http://localhost:11434"), "/api/generate"
)


class ExecutionParams(NamedTuple):
    """Options controlling a single generation run."""

    raw: bool
    stats: bool
    seed: Optional[int] = None
    temperature: Optional[float] = None


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--models",
        "--model",
        default="llama3",
        help="The name of the model to run. Multiple models can be specified, "
        "separated by commas.",
    )
    parser.add_argument(
        "--raw",
        action="store_true",
        help="Do not render the results as Markdown; write the raw output to stdout.",
    )
    parser.add_argument(
        "--stats",
        action="store_true",
        help="Output some stats about the generation at the end.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="The seed to use in the model evaluation. "
        "For deterministic output you must also set temperature=0.",
    )
    parser.add_argument(
        "--temperature",
        type=float,
        help="The temperature to use in the model evaluation. A value of 0 means the model"
        " will always pick the highest-probability token. Values above 0 add some"
        " randomness to token selection and generally make the model more creative."
        " For deterministic output, set this to 0 and set the seed to a non-zero constant."
        " The default in ollama is 0.8.",
    )
    parser.add_argument("prompt", nargs="*")
    args = vars(parser.parse_args())

    raw = args["raw"]

    # Take the prompt from the command line if given, otherwise read it from stdin.
    if args["prompt"]:
        prompt = " ".join(args["prompt"])
    else:
        prompt = sys.stdin.read()

    # In raw mode, rich output goes to stderr so that stdout carries only the response text.
    console = rich.console.Console(stderr=raw)

    if not raw:
        console.print(rich.panel.Panel.fit("[bold]Prompt"))
        console.print(rich.text.Text(prompt))

    # Run the same prompt against each requested model in turn.
    models = args["models"].split(",")
    for model in models:
        run_ollama(
            model,
            prompt,
            console,
            ExecutionParams(
                raw=raw,
                stats=args["stats"],
                seed=args["seed"],
                temperature=args["temperature"],
            ),
        )


def run_ollama(model, prompt, console: rich.console.Console, params: ExecutionParams):
    """Stream a single generation from the Ollama server and render it."""
    raw = params.raw
    if not raw:
        console.print()
        console.print(rich.panel.Panel.fit("[bold]" + model))

    req_data = {
        "model": model,
        "prompt": prompt,
        "options": {},
    }
    if params.temperature is not None:
        req_data["options"]["temperature"] = params.temperature
    if params.seed is not None:
        req_data["options"]["seed"] = params.seed

    request = urllib.request.Request(
        query_url,
        data=json.dumps(req_data).encode("utf-8"),
        headers={
            "Content-Type": "application/json; charset=utf-8",
        },
    )

    output = io.StringIO()
    spinner = rich.spinner.Spinner(
        "dots", text="Loading model...", style="status.spinner", speed=1.0
    )
    stats_data: Optional[dict] = None

    with rich.live.Live(
        spinner,
        console=console,
        vertical_overflow="ellipsis",
        refresh_per_second=12.5,
        transient=raw,
    ) as live:
        response = urllib.request.urlopen(request)
        with response:
            # The server streams one JSON object per line.
            response_buf = io.TextIOWrapper(response, encoding="utf-8")
            while True:
                buf = response_buf.readline()
                if not buf:
                    break
                data = json.loads(buf)
                if data["done"]:
                    # The final object carries the generation statistics.
                    stats_data = data
                    if raw:
                        sys.stdout.write("\n")
                    break
                if raw:
                    live.stop()
                    sys.stdout.write(data["response"])
                else:
                    # Accumulate the response so far and re-render it as Markdown.
                    output.write(data["response"])
                    live.update(rich.markdown.Markdown(output.getvalue()))

    if params.stats and stats_data:
        console.print()
        table = rich.table.Table("Generation Stats", box=None)
        table.add_column()
        table.add_column()

        def format_int(val):
            return format(val, "n")

        def format_float(val):
            return format(val, ".2f")

        # Durations are reported by Ollama in nanoseconds; convert to seconds for display.
        table.add_row("[bold]Model", stats_data.get("model"))
        table.add_row("[bold]Created At", stats_data.get("created_at"))
        table.add_row(
            "[bold]Total Duration (s)",
            format_float(int(stats_data.get("total_duration")) / 10**9),
        )
        table.add_row(
            "[bold]Load Duration (s)",
            format_float(int(stats_data.get("load_duration")) / 10**9),
        )
        if "prompt_eval_count" in stats_data:
            table.add_row(
                "[bold]Input Tokens", format_int(int(stats_data.get("prompt_eval_count")))
            )
            table.add_row(
                "[bold]Prompt Evaluation Duration (s)",
                format_float(int(stats_data.get("prompt_eval_duration")) / 10**9),
            )
        table.add_row(
            "[bold]Response Tokens", format_int(int(stats_data.get("eval_count")))
        )
        table.add_row(
            "[bold]Response Evaluation Duration (s)",
            format_float(int(stats_data.get("eval_duration")) / 10**9),
        )
        table.add_row(
            "[bold]Tokens per second",
            format_float(
                int(stats_data.get("eval_count"))
                / int(stats_data.get("eval_duration"))
                * 10**9
            ),
        )
        console.print(table)


if __name__ == "__main__":
    main()
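
A note on the stats: Ollama reports durations in nanoseconds, which is why the stats rows divide by 10**9 before labeling the values in seconds, and tokens per second is computed as eval_count / eval_duration * 10**9. As a purely illustrative example, a response of 425 tokens generated over 8.5 seconds of eval_duration (8,500,000,000 ns) would be reported as 50.00 tokens per second.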