Skip to content

Instantly share code, notes, and snippets.

@tamanobi
Last active October 21, 2022 10:55
Show Gist options
  • Save tamanobi/91b243405bb1282e0cb1b82a1ebb3724 to your computer and use it in GitHub Desktop.
Save tamanobi/91b243405bb1282e0cb1b82a1ebb3724 to your computer and use it in GitHub Desktop.
DTO や Entity などのクラス間をいい感じに変換する Decoder の構想
# coding: utf-8
from dataclasses import dataclass
class Result:
    """Base class for a value-or-error container (subclassed by Ok / Err).

    Subclasses keep their payload in ``self._value``.  The predicate and
    accessor methods here are placeholders that return ``None``; concrete
    subclasses override them.
    """

    def __str__(self):
        class_name = type(self).__name__
        return f"<{class_name}: {self._value}>"

    def is_ok(self) -> bool:
        """Placeholder: subclasses report whether this is a success."""

    def ok(self):
        """Placeholder: subclasses return the success value (or None)."""

    def is_error(self) -> bool:
        """Placeholder: subclasses report whether this is a failure."""

    def error(self):
        """Placeholder: subclasses return the error payload (or None)."""

    def unwrap(self):
        """Return the stored payload, whether it is a value or an error."""
        payload = self._value
        return payload
class Ok(Result):
    """Successful result wrapping a value.

    ``str()`` is inherited from ``Result`` and renders as ``<Ok: value>``.
    """

    # Payload of any type (the original annotation used the builtin
    # function ``any``, which is not a type).
    _value: object

    def __init__(self, value):
        self._value = value

    def is_ok(self):
        """Always True for a success."""
        return True

    def ok(self):
        """Return the wrapped value."""
        return self._value

    def is_error(self):
        """Always False for a success (the negation of ``is_ok``)."""
        # Previously returned ``self.is_ok()``, i.e. True, which reported
        # every success as an error.
        return not self.is_ok()

    def error(self):
        """A success carries no error; return None."""
        return None
class Err(Result):
    """Failed result carrying an error payload.

    The payload is typically an exception instance, or an iterable of
    exception instances when several checks failed at once (``DecoderT``
    builds ``Err(list_of_errors)``).
    """

    # Error payload of any type (the original annotation used the builtin
    # function ``any``, which is not a type).
    _value: object

    def __init__(self, value=None):
        self._value = value

    def __str__(self):
        """Render the error's class name(s) rather than its full repr."""
        name = self.__class__.__name__
        value = self._value
        if isinstance(value, Exception):
            return f"<{name}: {value.__class__.__name__}>"
        if isinstance(value, (str, bytes)):
            # Strings are iterable; listing the class of every character
            # ("str, str, str, ...") is useless, so show the text itself.
            return f"<{name}: {value}>"
        if hasattr(value, "__iter__"):
            names_string = ", ".join(v.__class__.__name__ for v in value)
            return f"<{name}: [{names_string}]>"
        return f"<{name}: ???>"

    def is_ok(self):
        """Always False for a failure."""
        return False

    def ok(self):
        """A failure carries no success value; return None."""
        # ``self`` was missing, so calling ``ok()`` raised TypeError.
        return None

    def is_error(self):
        """Always True for a failure (the negation of ``is_ok``)."""
        return not self.is_ok()

    def error(self):
        """Return the error payload."""
        # ``self`` was missing, so this raised TypeError instead of
        # returning the payload.
        return self._value
# ---------------------------------------------------------------------
class Transformable:
    """Mixin that lets an object be piped through a conversion function."""

    def transform(self, func):
        """Call ``func(self)`` and return its result unchanged.

        Example: ``person.transform(PersonData.to_dto)``.
        """
        converted = func(self)
        return converted
@dataclass()
class Person(Transformable):
    """Domain entity for a person.

    ``age`` is annotated ``int`` (dataclasses do not enforce this at
    runtime); ``PersonData.to_entity`` does the ``str`` -> ``int`` parse.
    """

    name: str
    age: int
    address: str
@dataclass
class PersonData:
    """Transport-layer (DTO) representation of a person.

    All fields are annotated ``str``; ``age`` is converted to and from
    ``int`` when mapping to and from the ``Person`` entity.
    """

    name: str
    age: str
    address: str

    @classmethod
    def from_dict(cls, data: dict):
        """Build an instance from a mapping whose keys are the field names."""
        return cls(**data)

    def to_entity(self):
        """Return the equivalent ``Person``, parsing ``age`` with ``int()``."""
        parsed_age = int(self.age)
        return Person(name=self.name, age=parsed_age, address=self.address)

    @classmethod
    def to_dto(cls, source: Person):
        """Build a DTO from a ``Person``, turning ``age`` back into a string."""
        age_text = str(source.age)
        return cls(source.name, age_text, source.address)
# Demo: dict -> DTO -> entity -> DTO round trip, printing each stage.
data = PersonData.from_dict(dict(name="yasutaka", age="30", address="niigata"))
print(data)
print(data.to_entity())
print(data.to_entity().transform(PersonData.to_dto))
class D:
    """Composable single-value validator/converter pipeline.

    Build a pipeline by chaining steps with ``>>``::

        age_decoder = D() >> D.min_length(EmptyError, 1) >> D.int(ValueError)

    Each step is a callable ``value -> Result``.  Calling the pipeline feeds
    the input to the first step and the unwrapped ``Ok`` value of each step
    to the next one.  It returns the first ``Err`` it meets, or the last
    step's ``Ok``.

    Every factory takes *error*, an exception class that is instantiated
    with no arguments and wrapped in ``Err`` when the check fails.

    NOTE: ``>>`` appends to this instance and returns it (it does not copy),
    so extending a shared, partially built pipeline mutates it.
    """

    callables: list

    def __init__(self):
        self.callables = []

    @classmethod
    def choice(cls, error: type, choices: list):
        """Step: accept *value* only if ``value in choices``."""
        def inner(value):
            if value in choices:
                return Ok(value)
            return Err(error())
        return inner

    @classmethod
    def min_length(cls, error: type, length: int):
        """Step: accept *value* only if ``len(value) >= length``.

        A value without a length (e.g. None) yields ``Err`` instead of
        raising, matching how ``D.int`` treats TypeError.
        """
        def inner(value):
            try:
                long_enough = len(value) >= length
            except TypeError:
                return Err(error())
            return Ok(value) if long_enough else Err(error())
        return inner

    @classmethod
    def max_length(cls, error: type, length: int):
        """Step: accept *value* only if ``len(value) <= length``.

        A value without a length yields ``Err`` instead of raising.
        """
        def inner(value):
            try:
                short_enough = len(value) <= length
            except TypeError:
                return Err(error())
            return Ok(value) if short_enough else Err(error())
        return inner

    @classmethod
    def max_bound(cls, error: type, max_: int):
        """Step: accept *value* only if ``value <= max_``.

        A value that cannot be compared with *max_* yields ``Err``.
        """
        def inner(value):
            try:
                within_bound = value <= max_
            except TypeError:
                return Err(error())
            return Ok(value) if within_bound else Err(error())
        return inner

    def __rshift__(self, func):
        """Append step *func* to this pipeline and return the same pipeline."""
        self.callables.append(func)
        return self

    def __call__(self, value):
        """Run *value* through every step.

        Returns the first failing step's result, otherwise the last step's
        result.  An empty pipeline returns ``Ok(value)``.
        """
        if not self.callables:
            # Previously ``result`` was never bound here (UnboundLocalError).
            return Ok(value)
        for func in self.callables:
            result = func(value)
            if not result.is_ok():
                return result
            # Always forward the unwrapped value, even when it is None.
            # The old ``prev is None`` sentinel re-fed the original input
            # to the next step whenever a step produced Ok(None).
            value = result.unwrap()
        return result

    # NOTE: ``int`` is defined last on purpose.  Once this classmethod is
    # defined, the name ``int`` in the class body refers to it, so any
    # ``length: int`` / ``max_: int`` annotation written after it would
    # point at the method instead of the builtin type.  (Inside ``inner``
    # the builtin is still used: nested functions do not see class-body
    # names.)
    @classmethod
    def int(cls, error: type):
        """Step: convert *value* with the builtin ``int()``.

        Returns ``Err`` on ValueError or TypeError.
        """
        def inner(value):
            try:
                number = int(value)
            except (ValueError, TypeError):
                return Err(error())
            return Ok(number)
        return inner
class DecoderT:
    """Decode an object field-by-field into an instance of a target class.

    Usage::

        decoder = (
            DecoderT(to=Person)
            >> DecoderT.field("name", name_decoder)
            >> DecoderT.field("age", age_decoder)
        )
        decoder(source)  # -> Ok(Person(...)) or Err([error, ...])

    Each field's decoder is applied to ``getattr(source, name)``.  The decoded
    values are passed to the target class *positionally*, in the order the
    fields were added, so that order must match the target's constructor.
    """

    class_ = None
    decoders: list

    def __init__(self, to=class_):
        self.class_ = to
        # Per-instance list.  The old class-level ``decoders = list()`` was
        # shared by every DecoderT, so fields added to one decoder leaked
        # into all others.
        self.decoders = []

    def __rshift__(self, value):
        """Append a ``(name, decoder)`` pair (see ``field``) and return self."""
        self.decoders.append(value)
        return self

    @classmethod
    def field(cls, name, decoder):
        """Pair a source attribute name with the decoder applied to it."""
        return (name, decoder)

    def __call__(self, value):
        """Decode *value*.

        Returns ``Ok(target_instance)`` if every field decodes, otherwise
        ``Err(list_of_error_payloads)``.  All fields are decoded even after
        a failure, so every error is reported at once.
        """
        args = []
        errors = []
        for name, decoder in self.decoders:
            result = decoder(getattr(value, name))
            if result.is_ok():
                args.append(result.unwrap())
            else:
                errors.append(result.unwrap())
        if errors:
            return Err(errors)
        return Ok(self.class_(*args))
# ---------------------------------------------------------------------
class MaxExceeded(Exception):
    """Error marker: a value is above the allowed upper bound."""


class EmptyError(Exception):
    """Error marker: a value failed a minimum-length check (e.g. empty)."""


class MaxLengthError(Exception):
    """Error marker: a value is longer than the allowed maximum length."""


class InvalidAddress(Exception):
    """Error marker: an address is not one of the accepted choices."""
# Field-level decoders.  ``x >>= step`` rebinds ``x`` to ``x >> step``;
# ``D.__rshift__`` returns the same pipeline, so this builds the chain in
# place.  The asserts are quick sanity checks (skipped under ``python -O``).
age_decoder = D()
age_decoder >>= D.min_length(EmptyError, 1)
age_decoder >>= D.int(ValueError)
age_decoder >>= D.max_bound(MaxExceeded, 99)
assert age_decoder("10").is_ok()

name_decoder = D()
name_decoder >>= D.min_length(EmptyError, 1)
name_decoder >>= D.max_length(MaxLengthError, 10)
assert name_decoder("こんにちは。").is_ok()

address_decoder = D()
address_decoder >>= D.min_length(EmptyError, 1)
address_decoder >>= D.choice(InvalidAddress, ["新潟", "東京"])
assert address_decoder("新潟").is_ok()
# Whole-object decoder: reads name/age/address from the source object and
# passes the decoded values positionally to ``Person``.
decoder = DecoderT(to=Person)
decoder >>= DecoderT.field("name", name_decoder)
decoder >>= DecoderT.field("age", age_decoder)
decoder >>= DecoderT.field("address", address_decoder)
@dataclass()
class PersonInputData:
    """Raw, not-yet-validated input for a person; fields annotated ``str``."""

    name: str
    age: str
    address: str
# Demo: decode one invalid and one valid input and print each Result.
invalid_data = PersonInputData(name="", age="100", address="神奈川")
valid_data = PersonInputData(name="kohki", age="30", address="東京")
print(decoder(invalid_data))
print(decoder(valid_data))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment