Last active
March 12, 2020 07:22
-
-
Save Eric-Arellano/7377ac5f5a183172bf4cf0d2462476c6 to your computer and use it in GitHub Desktop.
Target API V2
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). | |
# Licensed under the Apache License, Version 2.0 (see LICENSE). | |
from abc import ABC, ABCMeta, abstractmethod | |
from dataclasses import dataclass | |
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, cast | |
from pants.engine.objects import union | |
from pants.engine.rules import UnionMembership | |
from pants.util.collections import ensure_str_list | |
from pants.util.memo import memoized_property | |
from pants.util.meta import frozen_after_init | |
@dataclass(frozen=True)
class Field(ABC):
    """Base class for one field of a `Target`.

    Subclasses set the class-level `alias`, which is the key a target uses to match a raw value
    to this field. Instances are immutable and carry the user-supplied `raw_value`.
    """

    alias: ClassVar[str]
    # None indicates that the field was not explicitly defined.
    raw_value: Optional[Any]

    def __repr__(self) -> str:
        cls = self.__class__
        return f"{cls}(alias={self.alias!r}, raw_value={self.raw_value})"
class PrimitiveField(Field, metaclass=ABCMeta):
    """A Field whose value can be computed without the engine.

    Most fields should subclass this rather than `AsyncField`.
    """

    def __str__(self) -> str:
        return "{}={}".format(self.alias, self.value)

    @memoized_property
    @abstractmethod
    def value(self) -> Any:
        """Return the field's (possibly hydrated and/or validated) value.

        Implementations may, for example, substitute a default when the field was not defined,
        or check that an int is positive.

        This is memoized because hydration and validation can be costly. It is also lazy: the
        work happens only when a caller actually reads `.value`.
        """
class AsyncField(Field, metaclass=ABCMeta):
    """A field that needs the engine in order to be hydrated."""

    # TODO: what should this be?
    @memoized_property
    @abstractmethod
    def value_request(self) -> Any:
        """Return what the engine should be asked for to hydrate this field (shape still TBD)."""
class PluginField:
    """Lets plugin authors add new fields to pre-existing target types via UnionRules.

    Each Target type should declare a corresponding `PluginField` subclass marked with
    `@union` and point its `plugin_field_type` at it. A plugin author then defines whatever new
    `Field` they want and registers it with a `UnionRule` in their `register.py`'s `rules()`
    function. For example, to add a `TypeChecked` field to `python_library`, register
    `UnionRule(PythonLibraryField, TypeChecked)`:

        @union
        class PythonLibraryField(PluginField):
            pass

        class PythonLibrary(Target):
            core_fields = (Compatibility, PythonSources, ...)
            plugin_field_type = PythonLibraryField

        class TypeChecked(PrimitiveField):
            ...

        def rules():
            return [
                UnionRule(PythonLibraryField, TypeChecked),
            ]
    """
# Lets `Target.get(SomeField)` be typed as returning an instance of `SomeField`.
_F = TypeVar("_F", bound=Field)


@frozen_after_init
@dataclass(unsafe_hash=True)
class Target(ABC):
    """A target type: a named collection of `Field`s.

    Subclasses must define `alias`, `core_fields`, and `plugin_field_type`. Plugins can add
    further fields by registering `UnionRule(plugin_field_type, SomeField)`. Those fields are
    looked up through the `UnionMembership` passed to the constructor.

    NOTE(review): `unsafe_hash=True` generates `__hash__` over `plugin_fields` and
    `field_values`. Because `field_values` is a dict, `hash(target)` raises `TypeError`.
    Confirm whether Targets are meant to be hashable before relying on it.
    """

    # Subclasses must define these
    alias: ClassVar[str]
    core_fields: ClassVar[Tuple[Type[Field], ...]]
    plugin_field_type: ClassVar[Type[PluginField]]

    # These get calculated in the constructor
    plugin_fields: Tuple[Type[Field], ...]
    field_values: Dict[Type[Field], Any]

    def __init__(
        self,
        unhydrated_values: Dict[str, Any],
        *,
        union_membership: Optional[UnionMembership] = None,
    ) -> None:
        """Instantiate every field type of this target from its raw values.

        :param unhydrated_values: raw values keyed by field alias. Field types absent from this
            mapping are still instantiated, with `raw_value=None`.
        :param union_membership: where plugin fields registered for `plugin_field_type` are
            looked up. When None, the target has only its `core_fields`.
        :raises ValueError: if a key of `unhydrated_values` is not the alias of any field type.
        """
        self.plugin_fields = cast(
            Tuple[Type[Field], ...],
            (
                ()
                if union_membership is None
                else tuple(union_membership.union_rules.get(self.plugin_field_type, ()))
            ),
        )

        aliases_to_fields = {field.alias: field for field in self.field_types}
        for alias, value in unhydrated_values.items():
            if alias not in aliases_to_fields:
                raise ValueError(
                    f"Unrecognized field `{alias}={value}` for target type `{self.alias}`."
                )

        # Iterate `field_types` (core fields, then plugin fields) rather than a `set` so that the
        # mapping's order, and therefore `__str__`/`__repr__`, is deterministic across runs.
        # Undefined fields get `raw_value=None`. Only construction happens here; hydration via
        # `.value` stays lazy.
        self.field_values = {
            field: field(raw_value=unhydrated_values.get(field.alias))
            for field in self.field_types
        }

    @property
    def field_types(self) -> Tuple[Type[Field], ...]:
        """All field types of this target: core fields followed by plugin fields."""
        return (*self.core_fields, *self.plugin_fields)

    def __repr__(self) -> str:
        return (
            f"{self.__class__}("
            f"alias={repr(self.alias)}, "
            f"plugin_field_type={self.plugin_field_type}, "
            f"core_fields={list(self.core_fields)}, "
            f"plugin_fields={list(self.plugin_fields)}, "
            f"raw_field_values={list(self.field_values.values())}"
            f")"
        )

    def __str__(self) -> str:
        fields = ", ".join(str(field) for field in self.field_values.values())
        return f"{self.alias}({fields})"

    def get(self, field: Type[_F]) -> _F:
        """Return this target's instance of the exact field type `field`.

        :raises KeyError: if `field` is not one of this target's `field_types`.
        """
        return cast(_F, self.field_values[field])

    def has_fields(self, fields: Iterable[Type[Field]]) -> bool:
        """Return True if every type in `fields` is one of this target's field types."""
        # TODO: consider if this should support subclasses. For example, if a target has a
        # field PythonSources(Sources), then .has_fields(Sources) should still return True. Why?
        # This allows overriding how fields behave for custom target types, e.g. a `python3_library`
        # subclassing the Field Compatibility with its own custom implementation. When adding
        # this, be sure to update `.get()` to allow looking up by subclass, too. (Is it possible
        # to do that in a performant way?)
        return all(field in self.field_types for field in fields)
class Sources(AsyncField):
    """The `sources` field: an optional iterable of source-file strings."""

    alias: ClassVar = "sources"
    raw_value: Optional[Iterable[str]]

    @memoized_property
    def value_request(self) -> Any:
        """Currently passes the raw value through unchanged."""
        request = self.raw_value
        return request
class BinarySources(Sources):
    """`sources` for binary targets, which may list at most one file."""

    @memoized_property
    def value_request(self):
        raw = self.raw_value
        if raw is not None:
            # NOTE(review): `list(...)` consumes a one-shot iterator (e.g. a generator), which
            # would leave `super().value_request` returning an exhausted iterable. This is fine
            # for lists/tuples; confirm callers never pass generators.
            source_count = len(list(raw))
            if source_count > 1:
                raise ValueError("Binary targets must have only 0 or 1 source files.")
        return super().value_request
class Compatibility(PrimitiveField):
    """The `compatibility` field: one string or an iterable of strings."""

    alias: ClassVar = "compatibility"
    raw_value: Optional[Union[str, Iterable[str]]]

    @memoized_property
    def value(self) -> Optional[List[str]]:
        """Return the value as a list of strings via `ensure_str_list`, or None if undefined."""
        raw = self.raw_value
        return None if raw is None else ensure_str_list(raw)
class Coverage(PrimitiveField):
    """The `coverage` field: one string or an iterable of strings."""

    alias: ClassVar = "coverage"
    raw_value: Optional[Union[str, Iterable[str]]]

    @memoized_property
    def value(self) -> Optional[List[str]]:
        """Return the value as a list of strings via `ensure_str_list`, or None if undefined."""
        raw = self.raw_value
        return None if raw is None else ensure_str_list(raw)
class Timeout(PrimitiveField):
    """The `timeout` field: an optional positive integer."""

    alias: ClassVar = "timeout"
    raw_value: Optional[int]

    @memoized_property
    def value(self) -> Optional[int]:
        """Validate and return the timeout.

        :returns: None if the field was not defined, otherwise the positive int.
        :raises ValueError: if the value is not an int (bools are rejected) or is not > 0.
        """
        if self.raw_value is None:
            return None
        # `bool` is a subclass of `int`, so without the explicit check `timeout=True` would be
        # silently accepted as a timeout of 1.
        if isinstance(self.raw_value, bool) or not isinstance(self.raw_value, int):
            raise ValueError(
                f"The `timeout` field must be an `int`. Was {type(self.raw_value)} "
                f"({self.raw_value})."
            )
        if self.raw_value <= 0:
            # The check rejects 0 and accepts 1, so the bound in the message is > 0 (not > 1).
            raise ValueError(f"The `timeout` field must be > 0. Was {self.raw_value}.")
        return self.raw_value
class EntryPoint(PrimitiveField):
    """The `entry_point` field: an optional string, returned as-is."""

    alias: ClassVar = "entry_point"
    raw_value: Optional[str]

    @memoized_property
    def value(self) -> Optional[str]:
        """Return the raw entry point without validation (None if undefined)."""
        entry_point = self.raw_value
        return entry_point
class ZipSafe(PrimitiveField):
    """The `zip_safe` field: an optional bool that defaults to True."""

    alias: ClassVar = "zip_safe"
    raw_value: Optional[bool]

    @memoized_property
    def value(self) -> bool:
        """Return the raw value, or True when the field was not defined."""
        return True if self.raw_value is None else self.raw_value
class AlwaysWriteCache(PrimitiveField):
    """The `always_write_cache` field: an optional bool that defaults to False."""

    alias: ClassVar = "always_write_cache"
    raw_value: Optional[bool]

    @memoized_property
    def value(self) -> bool:
        """Return the raw value, or False when the field was not defined."""
        return False if self.raw_value is None else self.raw_value
@union
class PythonBinaryField(PluginField):
    """Union that plugin fields for `python_binary` targets are registered against."""
@union
class PythonLibraryField(PluginField):
    """Union that plugin fields for `python_library` targets are registered against."""
@union
class PythonTestsField(PluginField):
    """Union that plugin fields for `python_tests` targets are registered against."""
# Core fields shared by every Python target type defined below.
PYTHON_TARGET_FIELDS: Tuple[Type[Field], ...] = (Compatibility,)
class PythonBinary(Target):
    """The `python_binary` target type."""

    alias: ClassVar = "python_binary"
    plugin_field_type: ClassVar = PythonBinaryField
    core_fields: ClassVar = (
        *PYTHON_TARGET_FIELDS,
        EntryPoint,
        ZipSafe,
        AlwaysWriteCache,
    )
class PythonLibrary(Target):
    """The `python_library` target type; it has only the shared Python core fields."""

    alias: ClassVar = "python_library"
    plugin_field_type: ClassVar = PythonLibraryField
    core_fields: ClassVar = PYTHON_TARGET_FIELDS
class PythonTests(Target):
    """The `python_tests` target type."""

    alias: ClassVar = "python_tests"
    plugin_field_type: ClassVar = PythonTestsField
    core_fields: ClassVar = (
        *PYTHON_TARGET_FIELDS,
        Coverage,
        Timeout,
    )
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). | |
# Licensed under the Apache License, Version 2.0 (see LICENSE). | |
from collections import OrderedDict | |
from typing import ClassVar, List, Optional | |
import pytest | |
from pants.engine.rules import UnionMembership | |
from pants.engine.target import PluginField, PrimitiveField, Target | |
from pants.util.memo import memoized_property | |
from pants.util.ordered_set import OrderedSet | |
class HaskellGhcExtensions(PrimitiveField):
    """Test field whose hydration validates that every extension starts with `Ghc`."""

    alias: ClassVar = "ghc_extensions"
    raw_value: Optional[List[str]]

    @memoized_property
    def value(self) -> List[str]:
        # Add some arbitrary validation to test that hydration works properly.
        extensions = self.raw_value
        if extensions is None:
            return []
        # `next` over a generator stops at the first offender, like a plain loop would.
        first_invalid = next(
            (ext for ext in extensions if not ext.startswith("Ghc")), None
        )
        if first_invalid is not None:
            raise ValueError(
                f"All elements of `ghc_extensions` must be prefixed by `Ghc`. Received "
                f"{first_invalid}"
            )
        return extensions
class HaskellField(PluginField):
    """Plugin-field base for `HaskellTarget`; tests register fields via `UnionMembership`."""
class HaskellTarget(Target):
    """Minimal target type used to exercise the Target API in these tests."""

    alias: ClassVar = "haskell"
    plugin_field_type: ClassVar = HaskellField
    core_fields: ClassVar = (HaskellGhcExtensions,)
def test_invalid_fields_rejected() -> None:
    """Unknown aliases are rejected eagerly, when the target is constructed."""
    with pytest.raises(ValueError) as exc:
        HaskellTarget({"invalid_field": True})
    # Check the raised exception (`exc.value`). `str(exc)` stringifies the `ExceptionInfo`
    # wrapper, whose format is not the exception message.
    assert "Unrecognized field `invalid_field=True` for target type `haskell`." in str(exc.value)
def test_get_field() -> None:
    """`get` returns the field instance for both defined and undefined fields."""
    ghc_extensions = ["GhcExistentialQuantification"]

    explicit = HaskellTarget({HaskellGhcExtensions.alias: ghc_extensions}).get(
        HaskellGhcExtensions
    )
    assert explicit.raw_value == ghc_extensions
    assert explicit.value == ghc_extensions

    # An undefined field still exists, with raw_value None and the hydrated default.
    implicit = HaskellTarget({}).get(HaskellGhcExtensions)
    assert implicit.raw_value is None
    assert implicit.value == []
def test_has_fields() -> None:
    """`has_fields` is True only when every requested field type belongs to the target."""

    class UnrelatedField(PrimitiveField):
        alias: ClassVar = "unrelated"
        raw_value: Optional[bool]

        @memoized_property
        def value(self) -> bool:
            return False if self.raw_value is None else self.raw_value

    tgt = HaskellTarget({})
    cases = [
        ([], True),
        ([HaskellGhcExtensions], True),
        ([UnrelatedField], False),
        ([HaskellGhcExtensions, UnrelatedField], False),
    ]
    for requested, expected in cases:
        assert tgt.has_fields(requested) is expected
def test_field_hydration_is_lazy() -> None:
    """Validation runs when `.value` is read, not when the target is constructed."""
    bad_extension = "DoesNotStartWithGhc"
    # No error upon creating the Target because validation does not happen until a call site
    # hydrates the specific field.
    tgt = HaskellTarget(
        {HaskellGhcExtensions.alias: ["GhcExistentialQuantification", bad_extension]}
    )
    # When hydrating, we expect a failure.
    with pytest.raises(ValueError) as exc:
        tgt.get(HaskellGhcExtensions).value
    # Check the raised exception itself, not the `ExceptionInfo` wrapper's string form.
    assert "must be prefixed by `Ghc`" in str(exc.value)
def test_add_custom_fields() -> None:
    """Plugin fields registered for the target's union are appended after the core fields."""

    class CustomField(PrimitiveField):
        alias: ClassVar = "custom_field"
        raw_value: Optional[bool]

        @memoized_property
        def value(self) -> bool:
            return False if self.raw_value is None else self.raw_value

    membership = UnionMembership(OrderedDict({HaskellField: OrderedSet([CustomField])}))

    defined = HaskellTarget({CustomField.alias: True}, union_membership=membership)
    assert defined.field_types == (HaskellGhcExtensions, CustomField)
    assert defined.core_fields == (HaskellGhcExtensions,)
    assert defined.plugin_fields == (CustomField,)
    assert defined.get(CustomField).value is True

    undefined = HaskellTarget({}, union_membership=membership)
    assert undefined.get(CustomField).value is False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment