# Last active
# December 6, 2023 22:18
# -
# -
# Save omargfh/13b4128c106417ecc949af5d2ef1d24d to your computer and use it in GitHub Desktop.
# This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
# Learn more about bidirectional Unicode characters
import random
class ConversationBuilder:
    """Fluent builder describing a keyword-driven conversation.

    User input is parsed into a key, keys map to *signals*, and each signal
    owns a list of candidate responses. Every ``add_*`` / ``set_*`` method
    returns ``self`` so calls can be chained; ``build()`` wraps the builder
    in a ``Conversation``.
    """

    def __init__(self):
        self.signal_keys = {}  # parsed key -> signal
        self.signal_responses = {}  # signal -> list of responses
        self.initial_response = None
        self.termination_signal = None
        self.parser = None
        # Both attributes hold the custom signal function. They are kept in
        # sync by set_signal_function() so that readers of either name agree.
        self.user_signal_function = None
        self.signal_function = None

    def add_signal_response(self, signal, response):
        """Register a single response for ``signal``. Returns ``self``."""
        return self.add_signal_responses(signal, [response])

    def add_signal_responses(self, signal, responses):
        """Append ``responses`` to the candidates for ``signal``. Returns ``self``.

        The caller's sequence is never stored directly, so mutating it later
        does not affect the builder.
        """
        self.signal_responses.setdefault(signal, []).extend(responses)
        return self

    def add_signal_key(self, signal, key):
        """Map a single input ``key`` to ``signal``. Returns ``self``."""
        return self.add_signal_keys(signal, [key])

    def add_signal_keys(self, signal, keys):
        """Map every key in ``keys`` to ``signal``. Returns ``self``.

        Keys are normalised with ``parse_input`` at registration time, i.e.
        with whichever parser is configured when this method is called.
        """
        for key in keys:
            self.signal_keys[self.parse_input(key)] = signal
        return self

    def parse_input(self, user_input):
        """Normalise ``user_input`` with the custom parser, or the default one."""
        if self.parser is not None:
            return self.parser(user_input)
        # Default parser: lowercase, then drop leading/trailing ".,!?".
        return user_input.lower().strip(".,!?")

    def get_signal(self, user_input, user_signal_function=None):
        """Return the signal for ``user_input``, or ``None`` if nothing matches.

        When ``user_signal_function`` is given it decides the signal and is
        called as ``user_signal_function(parsed_input, signal_keys, get_signal)``.
        Passing ``None`` selects the plain dictionary lookup, which lets a
        custom function delegate back to the default behaviour.
        """
        parsed_input = self.parse_input(user_input)
        if user_signal_function is not None:
            return user_signal_function(
                parsed_input, self.signal_keys, self.get_signal
            )
        return self.signal_keys.get(parsed_input)

    def set_initial_response(self, response):
        """Set the message shown before any user input. Returns ``self``."""
        self.initial_response = response
        return self

    def set_termination_signal(self, signal):
        """Set the signal that ends the conversation. Returns ``self``.

        Only one termination signal is stored; a later call replaces the
        earlier value.
        """
        self.termination_signal = signal
        return self

    def is_termination_signal(self, signal):
        """Return whether ``signal`` is the configured termination signal.

        With no termination signal configured nothing terminates; otherwise
        an unrecognised input (signal ``None``) would compare equal to the
        unset value and end the conversation.
        """
        if self.termination_signal is None:
            return False
        return signal == self.termination_signal

    def get_response(self, signal):
        """Return a random response for ``signal``, or ``None`` if it has none."""
        responses = self.signal_responses.get(signal)
        # An empty list is treated like an unknown signal instead of letting
        # random.choice raise IndexError.
        if not responses:
            return None
        return random.choice(responses)

    def set_parser(self, parser):
        """Install a custom input parser (``str -> key``). Returns ``self``."""
        self.parser = parser
        return self

    def set_signal_function(self, signal_function):
        """Install a custom signal function. Returns ``self``.

        The function is stored under both ``signal_function`` and
        ``user_signal_function``: the conversation loop checks the latter to
        decide whether a custom function is active, so setting only the
        former left the custom function unused.
        """
        self.signal_function = signal_function
        self.user_signal_function = signal_function
        return self

    def build(self):
        """Return a ``Conversation`` driven by this builder."""
        return Conversation(self)
class Conversation:
    """Interactive read/respond loop driven by a builder-like scaffold.

    The scaffold must provide ``initial_response``, ``get_signal``,
    ``get_response`` and ``is_termination_signal``; it may provide a custom
    signal function under ``signal_function`` or ``user_signal_function``.
    """

    def __init__(self, scaffold):
        self.history = []  # (user_input, signal, response) per turn
        self.scaffold = scaffold

    def _resolve_signal_function(self):
        """Return the scaffold's custom signal function curried with history.

        The custom function may be stored under either attribute name, so
        both are consulted. Returns ``None`` when no custom function is set,
        which makes ``get_signal`` fall back to its default lookup.
        """
        runtime = self.scaffold
        custom = getattr(runtime, "signal_function", None)
        if custom is None:
            custom = getattr(runtime, "user_signal_function", None)
        if custom is None:
            return None

        def curried(parsed_input, signal_keys, get_signal):
            # Inject the running history as the third argument.
            return custom(parsed_input, signal_keys, self.history, get_signal)

        return curried

    def start(self):
        """Run the conversation on stdin/stdout until the termination signal.

        Prints the initial response (if any), then repeatedly reads a line,
        resolves its signal, prints the matching response (or a fallback
        message) and records the turn in ``self.history``.
        """
        runtime = self.scaffold
        if runtime.initial_response is not None:
            print(runtime.initial_response)

        # The closure only captures self, so it is built once, not per turn.
        signal_function = self._resolve_signal_function()

        while True:
            user_input = input("> ")
            signal = runtime.get_signal(user_input, signal_function)
            response = runtime.get_response(signal)
            if response is not None:
                print(response)
            else:
                print("I don't understand.")
            self.history.append((user_input, signal, response))
            if runtime.is_termination_signal(signal):
                break
def main():
    """Build the demo conversation and run it interactively."""

    def signal_function(user_input, signal_keys, history, get_signal):
        # Custom signal function: answer "SIG_ONE" while the history is
        # empty (the first turn), then defer to the default key lookup by
        # calling get_signal without a custom function.
        return "SIG_ONE" if not history else get_signal(user_input, None)

    def parser(user_input):
        # Custom parser: pass the input through unchanged, in place of the
        # default lowercase-and-strip-punctuation normalisation.
        return user_input

    conversation_builder = ConversationBuilder()
    conversation_builder.set_initial_response("2ol 1 or 2")

    (
        conversation_builder
        .add_signal_keys("SIG_ONE", ["1", "one", "wa7ed", "wa7d", "wahed", "wahd"])
        .add_signal_response("SIG_ONE", "sba7 el fol\ntgrb tany")
        .add_signal_response("SIG_ONE", "tb 2ol 2")
    )

    # NOTE: set_termination_signal stores a single value, so the SIG_NO call
    # further down replaces the SIG_TWO one made here.
    (
        conversation_builder
        .add_signal_keys("SIG_TWO", ["2", "two", "etnen", "etnyn", "itnen", "itnyn"])
        .add_signal_response("SIG_TWO", "msa2 el kheir\ntgrb tany")
        .add_signal_response("SIG_TWO", "tb 2ol 1")
        .set_termination_signal("SIG_TWO")
    )

    (
        conversation_builder
        .add_signal_keys("SIG_NO", ["no", "la", "laa", "la2a", "la", "nope", "noo", 'la2', "la'a"])
        .add_signal_response("SIG_NO", "tshrb shai")
        .set_termination_signal("SIG_NO")
    )

    # The parser is installed after the keys above were registered, so those
    # keys were normalised with the default parser.
    (
        conversation_builder
        .set_parser(parser)
        .set_signal_function(signal_function)
    )

    conversation = conversation_builder.build()
    conversation.start()
# Run the demo only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
# Sign up for free
# to join this conversation on GitHub.
# Already have an account?
# Sign in to comment