Forked from seanchatmangpt/gen_python_primitive.py
Created
February 13, 2024 15:39
-
-
Save inayet/de5ead91cdca1e6c9dd6ba28e14b35bb to your computer and use it in GitHub Desktop.
Generate a Python primitive
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import ast | |
from dspy import Assert | |
from rdddy.generators.gen_module import GenModule | |
def is_primitive_type(data_type):
    """Return whether *data_type* is one of the supported built-in types.

    The supported types are ``int``, ``float``, ``str``, ``bool``, ``list``,
    ``tuple``, ``dict`` and ``set``. The argument is expected to be the type
    object itself (``int``), not an instance (``3``) or a name (``"int"``).

    Args:
        data_type: The object to check.

    Returns:
        ``True`` if *data_type* is one of the supported types, else ``False``.
    """
    primitive_types = {int, float, str, bool, list, tuple, dict, set}
    try:
        return data_type in primitive_types
    except TypeError:
        # Unhashable arguments (e.g. a list or dict instance) cannot be
        # looked up in a set; they are not one of the types above either.
        return False
class GenPythonPrimitive(GenModule):
    """Generate a value of one Python built-in type from a prompt.

    The generated output is expected to be the text of a Python literal. It
    is parsed with ``ast.literal_eval`` and checked against
    ``primitive_type``.
    """

    def __init__(self, primitive_type, lm=None):
        """Configure the generator for ``primitive_type``.

        Args:
            primitive_type: A built-in type accepted by ``is_primitive_type``.
            lm: Passed through unchanged to ``GenModule.__init__``.

        Raises:
            ValueError: If ``primitive_type`` is not an accepted type.
        """
        if not is_primitive_type(primitive_type):
            # Non-type arguments (e.g. the string "int") have no __name__;
            # fall back to repr so the caller still gets the ValueError.
            type_name = getattr(primitive_type, "__name__", repr(primitive_type))
            raise ValueError(
                f'primitive type {type_name} must be a Python primitive type'
            )
        super().__init__(f"{primitive_type.__name__}_python_primitive_string", lm)
        self.primitive_type = primitive_type

    def validate_primitive(self, output) -> bool:
        """Return whether *output* is a literal of exactly ``primitive_type``.

        Args:
            output: Text to parse with ``ast.literal_eval``.

        Returns:
            ``True`` if *output* parses and the parsed value's type is
            ``primitive_type``; ``False`` if parsing fails or the type differs.
        """
        try:
            value = ast.literal_eval(output)
        except (SyntaxError, ValueError, TypeError, MemoryError, RecursionError):
            # literal_eval raises more than SyntaxError: a bare word such as
            # `hello` is a ValueError ("malformed node or string").
            return False
        # Exact type match rather than isinstance: bool subclasses int, so
        # isinstance would accept "True" as a valid int.
        return type(value) is self.primitive_type

    def validate_output(self, output):
        """Assert that *output* is valid and return the parsed value.

        Args:
            output: Text produced for ``self.output_key``.

        Returns:
            The value parsed from *output*.
        """
        # NOTE(review): Assert comes from dspy; presumably it raises or
        # triggers a retry when the condition is false -- confirm.
        # NOTE(review): output_key is not set in this class; presumably it is
        # provided by GenModule -- confirm.
        Assert(
            self.validate_primitive(output),
            f"You need to create a valid python {self.primitive_type.__name__} "
            f"primitive type for \n{self.output_key}\n"
            f"You will be penalized for not returning only a {self.primitive_type.__name__} for "
            f"{self.output_key}",
        )
        data = ast.literal_eval(output)
        if self.primitive_type is set:
            data = set(data)
        return data

    def __call__(self, prompt):
        """Run ``self.forward`` with *prompt* and return its result."""
        return self.forward(prompt=prompt)
class GenDict(GenPythonPrimitive):
    """``GenPythonPrimitive`` preconfigured for ``dict`` output."""

    def __init__(self, lm=None):
        """Create a ``dict`` generator.

        Args:
            lm: Forwarded to ``GenPythonPrimitive.__init__``. The default
                ``None`` matches the previous no-argument behaviour.
        """
        super().__init__(primitive_type=dict, lm=lm)
class GenList(GenPythonPrimitive):
    """``GenPythonPrimitive`` preconfigured for ``list`` output."""

    def __init__(self, lm=None):
        """Create a ``list`` generator.

        Args:
            lm: Forwarded to ``GenPythonPrimitive.__init__``. The default
                ``None`` matches the previous no-argument behaviour.
        """
        super().__init__(primitive_type=list, lm=lm)
class GenBool(GenPythonPrimitive):
    """``GenPythonPrimitive`` preconfigured for ``bool`` output."""

    def __init__(self, lm=None):
        """Create a ``bool`` generator.

        Args:
            lm: Forwarded to ``GenPythonPrimitive.__init__``. The default
                ``None`` matches the previous no-argument behaviour.
        """
        super().__init__(primitive_type=bool, lm=lm)
class GenInt(GenPythonPrimitive):
    """``GenPythonPrimitive`` preconfigured for ``int`` output."""

    def __init__(self, lm=None):
        """Create an ``int`` generator.

        Args:
            lm: Forwarded to ``GenPythonPrimitive.__init__``. The default
                ``None`` matches the previous no-argument behaviour.
        """
        super().__init__(primitive_type=int, lm=lm)
class GenFloat(GenPythonPrimitive):
    """``GenPythonPrimitive`` preconfigured for ``float`` output."""

    def __init__(self, lm=None):
        """Create a ``float`` generator.

        Args:
            lm: Forwarded to ``GenPythonPrimitive.__init__``. The default
                ``None`` matches the previous no-argument behaviour.
        """
        super().__init__(primitive_type=float, lm=lm)
class GenTuple(GenPythonPrimitive):
    """``GenPythonPrimitive`` preconfigured for ``tuple`` output."""

    def __init__(self, lm=None):
        """Create a ``tuple`` generator.

        Args:
            lm: Forwarded to ``GenPythonPrimitive.__init__``. The default
                ``None`` matches the previous no-argument behaviour.
        """
        super().__init__(primitive_type=tuple, lm=lm)
class GenSet(GenPythonPrimitive):
    """``GenPythonPrimitive`` preconfigured for ``set`` output."""

    def __init__(self, lm=None):
        """Create a ``set`` generator.

        Args:
            lm: Forwarded to ``GenPythonPrimitive.__init__``. The default
                ``None`` matches the previous no-argument behaviour.
        """
        super().__init__(primitive_type=set, lm=lm)
def main():
    """Demo: request the planets as a tuple, check the answer and print it."""
    planet_generator = GenTuple()
    result = planet_generator(
        "Create a list of planets in our solar system sorted by largest to smallest"
    )
    expected_planets = (
        'Jupiter',
        'Saturn',
        'Uranus',
        'Neptune',
        'Earth',
        'Venus',
        'Mars',
        'Mercury',
    )
    assert result == expected_planets
    print(f"The planets of the solar system are {result}")


# Run the demo only when this file is executed as a script.
if __name__ == '__main__':
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment