Created July 8, 2023 22:21
-
-
Save simonw/f3686efc447678d5eb5e98331e5f18e6 to your computer and use it in GitHub Desktop.
Example Markov-chain model plugin for llm; refs https://github.com/simonw/llm/issues/53
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import llm | |
import random | |
import time | |
from typing import Optional | |
from pydantic import field_validator | |
@llm.hookimpl
def register_models(register):
    """Plugin hook: pass a single ``Markov`` model instance to ``register``."""
    markov_model = Markov()
    register(markov_model)
def build_markov_table(text):
    """Map each word of *text* to the list of words that directly follow it.

    *text* is split on whitespace. Successors are appended in order of
    appearance and duplicates are kept, so a follower seen twice appears
    twice in the list. Keys are inserted in order of first appearance as a
    predecessor. The final word gets no key of its own unless it also occurs
    earlier with a successor; an empty or single-word *text* yields ``{}``.
    """
    tokens = text.split()
    table = {}
    # Pair every word with its right-hand neighbour; zip stops at the
    # shorter sequence, so the last word is never used as a predecessor.
    for current, following in zip(tokens, tokens[1:]):
        if current not in table:
            table[current] = []
        table[current].append(following)
    return table
def generate(transitions, length, start_word=None):
    """Yield up to *length* words by walking the Markov table *transitions*.

    Args:
        transitions: mapping of word -> list of observed next words, in the
            shape returned by ``build_markov_table``.
        length: maximum number of words to yield.
        start_word: first word to emit; when falsy, a random key of
            *transitions* is chosen instead.

    Each following word is drawn with ``random.choice`` from the current
    word's successor list, so repeated successors are proportionally more
    likely. A word with no recorded successors falls back to a random word
    from the whole vocabulary (the keys of *transitions*).

    When *transitions* is empty (e.g. the prompt was empty or a single word),
    this yields fewer than *length* words -- nothing at all without a
    *start_word*, or just *start_word* with one -- instead of raising
    ``IndexError`` from ``random.choice([])``.
    """
    vocabulary = list(transitions)
    if start_word:
        current = start_word
    elif vocabulary:
        current = random.choice(vocabulary)
    else:
        # Empty table and no start word: there is nothing to generate.
        return
    for _ in range(length):
        yield current
        options = transitions.get(current) or vocabulary
        if not options:
            # Dead end with an empty vocabulary; only reachable when the
            # table is empty but a start word was supplied.
            return
        current = random.choice(options)
class Markov(llm.Model):
    """Toy streaming model that emits a Markov chain built from the prompt text."""

    model_id = "markov"
    can_stream = True

    class Options(llm.Options):
        """User-tunable options for the Markov model."""

        # Number of words to generate; iter_prompt falls back to 10 when unset.
        length: Optional[int] = None
        # Seconds to pause after each emitted word; unset or 0 means no pause.
        delay: Optional[float] = None

        @field_validator("length")
        @classmethod
        def validate_length(cls, length):
            """Reject lengths below 2; ``None`` (option unset) passes through."""
            if length is None:
                return None
            if length < 2:
                raise ValueError("length must be >= 2")
            return length

        @field_validator("delay")
        @classmethod
        def validate_delay(cls, delay):
            """Require ``0 <= delay <= 10``; ``None`` (option unset) passes through."""
            if delay is None:
                return None
            if not 0 <= delay <= 10:
                raise ValueError("delay must be between 0 and 10")
            return delay

    class Response(llm.Response):
        """Response that produces output by yielding generated words one at a time."""

        def iter_prompt(self, prompt):
            """Stream Markov-generated words for *prompt*.

            Builds a transition table from ``prompt.prompt`` and walks it for
            ``prompt.options.length`` words (10 when unset). Each word is
            yielded followed by a single space. When ``prompt.options.delay``
            is set and non-zero, sleeps that many seconds after every word,
            including the last.
            """
            transitions = build_markov_table(prompt.prompt)
            length = prompt.options.length or 10
            delay = prompt.options.delay
            for word in generate(transitions, length):
                yield word + ' '
                if delay:
                    time.sleep(delay)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Packaging metadata for the llm-markov plugin.
[project]
name = "llm-markov"
version = "0.1"

# Expose the llm_markov module under the "llm" entry-point group; this is
# the group llm scans to load plugins (such as the register_models hook).
[project.entry-points.llm]
markov = "llm_markov"
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.