Created
March 12, 2020 17:14
-
-
Save Eric-Arellano/9a1cb63f9db0c43e61c82cbf809342ad to your computer and use it in GitHub Desktop.
Target API - AsyncField
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). | |
# Licensed under the Apache License, Version 2.0 (see LICENSE). | |
from abc import ABC, ABCMeta, abstractmethod | |
from dataclasses import dataclass | |
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, cast | |
from pants.engine.objects import union | |
from pants.engine.rules import UnionMembership | |
from pants.util.collections import ensure_str_list | |
from pants.util.memo import memoized_property | |
from pants.util.meta import frozen_after_init | |
@dataclass(frozen=True)
class Field(ABC):
    """Base class for one named, immutable attribute of a Target.

    Subclasses declare `alias`, the key under which the field's value is passed when a Target is
    constructed. Each instance stores the user-supplied `raw_value` unchanged.
    """

    alias: ClassVar[str]
    # `None` signals that the field was not explicitly set.
    raw_value: Optional[Any]

    def __repr__(self) -> str:
        cls = self.__class__
        return "{}(alias={!r}, raw_value={})".format(cls, self.alias, self.raw_value)
class PrimitiveField(Field, metaclass=ABCMeta):
    """A Field whose value can be computed without going through the engine.

    Most fields should derive from this class. Subclasses implement `value`, which turns
    `raw_value` into the final value.
    """

    def __str__(self) -> str:
        return "{}={}".format(self.alias, self.value)

    @memoized_property
    @abstractmethod
    def value(self) -> Any:
        """Return the field's hydrated value.

        Implementations may substitute a default when the field was not defined, and/or validate
        `raw_value` (for example, checking that an int is positive).

        Because hydration and validation can be costly, this property is memoized. It is also
        lazy: the work only happens when something reads `value`, typically a downstream rule that
        explicitly asked for this field.
        """
class AsyncField(Field, metaclass=ABCMeta):
    """A field that needs the engine in order to be hydrated.

    You should create a corresponding Result class and define a rule to go from this AsyncField to
    the Result. For example:

        class Sources(AsyncField):
            alias: ClassVar = "sources"
            raw_value: Optional[List[str]]

        @dataclass(frozen=True)
        class SourcesResult:
            snapshot: Snapshot

        @rule
        async def hydrate_sources(sources: Sources) -> SourcesResult:
            sources.validate_pre_hydration()
            result = await Get[Snapshot](PathGlobs(sources.raw_value))
            sources.validate_post_hydration(result)
            return SourcesResult(result)

        def rules():
            return [hydrate_sources]

    Then, call sites can `await Get` if they need to hydrate the field:

        sources = await Get[SourcesResult](Sources, my_tgt.get(Sources))
    """

    def validate_pre_hydration(self) -> None:
        """Any validation that can be done on the original `raw_value`.

        It is cheaper to do any possible validation here, rather than in `validate_post_hydration`,
        because we can short-circuit if an invariant is violated before the expense of hydration.

        The default implementation does nothing. Subclasses override it and should raise when
        `raw_value` is invalid.
        """

    def validate_post_hydration(self, result: Any) -> None:
        """Any validation that must be done after hydration by inspecting the result of that
        hydration.

        For example, a `PythonSources` field may validate that all hydrated source files end in
        `.py`.

        The default implementation does nothing. Subclasses override it and should raise when
        `result` is invalid.
        """
class PluginField:
    """Extension point that lets plugins attach new fields to existing target types.

    Each Target type declares its own PluginField subclass, decorated with `@union`, as its
    `plugin_field_type`. A plugin then defines the `Field` it wants to add and registers it from
    the `rules()` function in its `register.py` with a `UnionRule`. For example, adding a
    `TypeChecked` field to `python_library` means registering
    `UnionRule(PythonLibraryField, TypeChecked)`:

        @union
        class PythonLibraryField(PluginField):
            pass

        class PythonLibrary(Target):
            core_fields = (Compatibility, PythonSources, ...)
            plugin_field_type = PythonLibraryField

        class TypeChecked(PrimitiveField):
            ...

        def rules():
            return [UnionRule(PythonLibraryField, TypeChecked)]
    """
_F = TypeVar("_F", bound=Field)


@frozen_after_init
@dataclass(unsafe_hash=True)
class Target(ABC):
    """A collection of Fields identified by a target-type `alias`.

    Subclasses declare `alias`, `core_fields` and `plugin_field_type`. Additional fields registered
    against `plugin_field_type` are looked up in the `UnionMembership` passed to the constructor.
    """

    # Subclasses must define these
    alias: ClassVar[str]
    core_fields: ClassVar[Tuple[Type[Field], ...]]
    plugin_field_type: ClassVar[Type[PluginField]]

    # These get calculated in the constructor
    plugin_fields: Tuple[Type[Field], ...]
    field_values: Dict[Type[Field], Any]

    # NOTE(review): `unsafe_hash=True` generates a `__hash__` over `field_values`, which is a
    # dict, so `hash(target)` will raise TypeError. If Targets need to be hashable, consider
    # storing an immutable mapping instead.

    def __init__(
        self,
        unhydrated_values: Dict[str, Any],
        *,
        union_membership: Optional[UnionMembership] = None,
    ) -> None:
        """Instantiate one Field per field type from the user-supplied values.

        Args:
            unhydrated_values: Maps field aliases to raw values. Fields not mentioned here are
                created with `raw_value=None`.
            union_membership: If given, supplies the plugin fields registered for
                `plugin_field_type`.

        Raises:
            ValueError: If `unhydrated_values` contains an alias no field on this target uses.
        """
        self.plugin_fields = cast(
            Tuple[Type[Field], ...],
            (
                ()
                if union_membership is None
                else tuple(union_membership.union_rules.get(self.plugin_field_type, ()))
            ),
        )

        # If two field types share an alias, the later one in `field_types` wins.
        aliases_to_fields = {field.alias: field for field in self.field_types}
        explicitly_defined: Dict[Type[Field], Field] = {}
        for alias, value in unhydrated_values.items():
            if alias not in aliases_to_fields:
                raise ValueError(
                    f"Unrecognized field `{alias}={value}` for target type `{self.alias}`."
                )
            field = aliases_to_fields[alias]
            explicitly_defined[field] = field(value)

        # Build the mapping in declaration order (core fields, then plugin fields) so iteration
        # order -- and thus `__str__`/`__repr__` output -- is deterministic. For undefined
        # fields, mark the raw value as None.
        self.field_values = {
            field_type: (
                explicitly_defined[field_type]
                if field_type in explicitly_defined
                else field_type(raw_value=None)
            )
            for field_type in self.field_types
        }

    @property
    def field_types(self) -> Tuple[Type[Field], ...]:
        """All field types on this target: core fields followed by plugin fields."""
        return (*self.core_fields, *self.plugin_fields)

    def __repr__(self) -> str:
        return (
            f"{self.__class__}("
            f"alias={repr(self.alias)}, "
            f"plugin_field_type={self.plugin_field_type}, "
            f"core_fields={list(self.core_fields)}, "
            f"plugin_fields={list(self.plugin_fields)}, "
            f"raw_field_values={list(self.field_values.values())}"
            f")"
        )

    def __str__(self) -> str:
        fields = ", ".join(str(field) for field in self.field_values.values())
        return f"{self.alias}({fields})"

    def get(self, field: Type[_F]) -> _F:
        """Return this target's instance of `field`.

        Raises:
            KeyError: If `field` is not one of this target's field types.
        """
        return cast(_F, self.field_values[field])

    def has_fields(self, fields: Iterable[Type[Field]]) -> bool:
        """Return True if every type in `fields` is a field of this target (vacuously True)."""
        # TODO: consider if this should support subclasses. For example, if a target has a
        # field PythonSources(Sources), then .has_fields(Sources) should still return True. Why?
        # This allows overriding how fields behave for custom target types, e.g. a `python3_library`
        # subclassing the Field Compatibility with its own custom implementation. When adding
        # this, be sure to update `.get()` to allow looking up by subclass, too. (Is it possible
        # to do that in a performant way?)
        #
        # `field_values` has exactly one key per field type, so a dict lookup replaces rebuilding
        # and linearly scanning `field_types` for every queried field.
        return all(field in self.field_values for field in fields)
class Sources(AsyncField):
    """The `sources` field; hydrated through an engine rule rather than directly."""

    alias: ClassVar = "sources"
    # `None` when the target did not set `sources`.
    raw_value: Optional[Iterable[str]]
class BinarySources(Sources):
    """`sources` for binary targets, which may list at most one source file."""

    def validate_pre_hydration(self) -> None:
        """Reject more than one source before paying for hydration.

        Raises:
            ValueError: If `raw_value` contains two or more entries.
        """
        # NOTE(review): if `raw_value` is a one-shot iterator, counting it here consumes it;
        # presumably callers pass a list/tuple — confirm.
        if self.raw_value is not None and len(tuple(self.raw_value)) > 1:
            raise ValueError("Binary targets must have only 0 or 1 source files.")
        super().validate_pre_hydration()
class Compatibility(PrimitiveField):
    """The `compatibility` field: a string or strings, exposed as a list (or None when unset)."""

    alias: ClassVar = "compatibility"
    raw_value: Optional[Union[str, Iterable[str]]]

    @memoized_property
    def value(self) -> Optional[List[str]]:
        # `ensure_str_list` normalizes the str-or-iterable input into a list of strings.
        return None if self.raw_value is None else ensure_str_list(self.raw_value)
class Coverage(PrimitiveField):
    """The `coverage` field: a string or strings, exposed as a list (or None when unset)."""

    alias: ClassVar = "coverage"
    raw_value: Optional[Union[str, Iterable[str]]]

    @memoized_property
    def value(self) -> Optional[List[str]]:
        raw = self.raw_value
        # Unset stays unset; otherwise normalize to a list of strings.
        return ensure_str_list(raw) if raw is not None else None
class Timeout(PrimitiveField):
    """The `timeout` field: an optional positive integer."""

    alias: ClassVar = "timeout"
    raw_value: Optional[int]

    @memoized_property
    def value(self) -> Optional[int]:
        """Return the validated timeout, or None when unset.

        Raises:
            ValueError: If the value is not an int (bools are rejected too) or is not positive.
        """
        if self.raw_value is None:
            return None
        # `bool` is a subclass of `int`, so `timeout=True` would otherwise be accepted as 1.
        if not isinstance(self.raw_value, int) or isinstance(self.raw_value, bool):
            raise ValueError(
                f"The `timeout` field must be an `int`. Was {type(self.raw_value)} "
                f"({self.raw_value})."
            )
        if self.raw_value <= 0:
            raise ValueError(f"The `timeout` field must be > 0. Was {self.raw_value}.")
        return self.raw_value
class EntryPoint(PrimitiveField):
    """The `entry_point` field; its value is the raw string, unvalidated (None when unset)."""

    alias: ClassVar = "entry_point"
    raw_value: Optional[str]

    @memoized_property
    def value(self) -> Optional[str]:
        # No normalization or default: pass the user's value straight through.
        entry_point = self.raw_value
        return entry_point
class ZipSafe(PrimitiveField):
    """The `zip_safe` field; defaults to True when unset."""

    alias: ClassVar = "zip_safe"
    raw_value: Optional[bool]

    @memoized_property
    def value(self) -> bool:
        return True if self.raw_value is None else self.raw_value
class AlwaysWriteCache(PrimitiveField):
    """The `always_write_cache` field; defaults to False when unset."""

    alias: ClassVar = "always_write_cache"
    raw_value: Optional[bool]

    @memoized_property
    def value(self) -> bool:
        return False if self.raw_value is None else self.raw_value
@union
class PythonBinaryField(PluginField):
    """Union that plugins register against to add fields to `python_binary`."""
@union
class PythonLibraryField(PluginField):
    """Union that plugins register against to add fields to `python_library`."""
@union
class PythonTestsField(PluginField):
    """Union that plugins register against to add fields to `python_tests`."""
# Core fields shared by the Python target types.
PYTHON_TARGET_FIELDS = (Compatibility,)


class PythonBinary(Target):
    """The `python_binary` target type."""

    alias: ClassVar = "python_binary"
    core_fields: ClassVar = PYTHON_TARGET_FIELDS + (EntryPoint, ZipSafe, AlwaysWriteCache)
    plugin_field_type: ClassVar = PythonBinaryField
class PythonLibrary(Target):
    """The `python_library` target type; it has only the shared Python fields."""

    alias: ClassVar = "python_library"
    core_fields: ClassVar = PYTHON_TARGET_FIELDS
    plugin_field_type: ClassVar = PythonLibraryField
class PythonTests(Target):
    """The `python_tests` target type."""

    alias: ClassVar = "python_tests"
    core_fields: ClassVar = PYTHON_TARGET_FIELDS + (Coverage, Timeout)
    plugin_field_type: ClassVar = PythonTestsField
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md). | |
# Licensed under the Apache License, Version 2.0 (see LICENSE). | |
from collections import OrderedDict | |
from dataclasses import dataclass | |
from pathlib import PurePath | |
from typing import ClassVar, List, Optional, Tuple | |
import pytest | |
from pants.engine.fs import EMPTY_DIRECTORY_DIGEST, PathGlobs, Snapshot | |
from pants.engine.rules import UnionMembership, rule | |
from pants.engine.selectors import Get | |
from pants.engine.target import AsyncField, PluginField, PrimitiveField, Target | |
from pants.testutil.engine.util import MockGet, run_rule | |
from pants.util.collections import ensure_str_list | |
from pants.util.memo import memoized_property | |
from pants.util.ordered_set import OrderedSet | |
class HaskellGhcExtensions(PrimitiveField):
    """Test-only field listing GHC extensions, each of which must start with `Ghc`."""

    alias: ClassVar = "ghc_extensions"
    raw_value: Optional[List[str]]

    @memoized_property
    def value(self) -> List[str]:
        """Return the extensions, or an empty list when unset.

        Raises:
            ValueError: If any extension lacks the `Ghc` prefix. The rule is arbitrary; it exists
                so tests can check that hydration (and its validation) works properly.
        """
        raw = self.raw_value
        if raw is None:
            return []
        offenders = list(filter(lambda ext: not ext.startswith("Ghc"), raw))
        if not offenders:
            return raw
        raise ValueError(
            f"All elements of `{self.alias}` must be prefixed by `Ghc`. Received "
            f"{offenders}."
        )
class HaskellSources(AsyncField):
    """Test-only `sources` field whose hydrated files must all have a `.hs` suffix."""

    alias: ClassVar = "sources"
    raw_value: Optional[List[str]]

    def validate_pre_hydration(self) -> None:
        # Called only for its validation side effect; the normalized list is discarded.
        ensure_str_list(self.raw_value)

    def validate_post_hydration(self, result: Snapshot) -> None:
        """Raise ValueError if any file in the hydrated snapshot is not a `.hs` file."""
        offending = [path for path in result.files if PurePath(path).suffix != ".hs"]
        if not offending:
            return
        raise ValueError(f"Received non-Haskell sources in {self.alias}: {offending}.")
@dataclass(frozen=True)
class HaskellSourcesResult:
    """Result of hydrating a `HaskellSources` field: the matched files as a Snapshot."""

    snapshot: Snapshot
@rule
async def hydrate_haskell_sources(sources: HaskellSources) -> HaskellSourcesResult:
    """Hydrate a `HaskellSources` field into a Snapshot of the files its globs match.

    Runs the field's pre-hydration validation before requesting the Snapshot and its
    post-hydration validation on the result.
    """
    sources.validate_pre_hydration()
    # Glob the patterns from the field itself, as the AsyncField docstring example does.
    # `ensure_str_list` normalizes whatever `validate_pre_hydration` accepted (presumably a
    # single string or an iterable of strings) into a list of glob strings.
    globs = ensure_str_list(sources.raw_value)
    result = await Get[Snapshot](PathGlobs(globs))
    sources.validate_post_hydration(result)
    return HaskellSourcesResult(result)
class HaskellField(PluginField):
    """Plugin-field union for the test-only `haskell` target type."""
class HaskellTarget(Target):
    """Test-only target type with GHC extensions and Haskell sources as its core fields."""

    alias: ClassVar = "haskell"
    core_fields: ClassVar = tuple([HaskellGhcExtensions, HaskellSources])
    plugin_field_type: ClassVar = HaskellField
def test_invalid_fields_rejected() -> None:
    """Constructing a target with an unknown field alias raises ValueError."""
    with pytest.raises(ValueError) as exc:
        HaskellTarget({"invalid_field": True})
    # Check the exception message itself. `str(exc)` renders the ExceptionInfo wrapper, whose
    # repr of the exception may escape or truncate the message.
    assert "Unrecognized field `invalid_field=True` for target type `haskell`." in str(exc.value)
def test_get_primitive_field() -> None:
    """An explicitly set primitive field round-trips; an unset one hydrates to its default."""
    extensions = ["GhcExistentialQuantification"]
    explicit_tgt = HaskellTarget({HaskellGhcExtensions.alias: extensions})
    explicit = explicit_tgt.get(HaskellGhcExtensions)
    assert explicit.raw_value == extensions
    assert explicit.value == extensions

    unset = HaskellTarget({}).get(HaskellGhcExtensions)
    assert unset.raw_value is None
    assert unset.value == []
def test_get_async_field() -> None:
    """An AsyncField hydrates through its rule, with pre- and post-hydration validation."""

    def hydrate_field(
        *, raw_source_files: List[str], hydrated_source_files: Tuple[str, ...]
    ) -> HaskellSourcesResult:
        """Build a target with `raw_source_files` and run the hydration rule, mocking the
        filesystem so the Snapshot contains `hydrated_source_files`."""
        sources_field = HaskellTarget({HaskellSources.alias: raw_source_files}).get(HaskellSources)
        assert sources_field.raw_value == raw_source_files
        result: HaskellSourcesResult = run_rule(
            hydrate_haskell_sources,
            rule_args=[sources_field],
            mock_gets=[
                MockGet(
                    product_type=Snapshot,
                    subject_type=PathGlobs,
                    mock=lambda _: Snapshot(
                        directory_digest=EMPTY_DIRECTORY_DIGEST,
                        files=hydrated_source_files,
                        dirs=(),
                    ),
                )
            ],
        )
        return result

    # Normal field
    expected_files = ("monad.hs", "abstract_art.hs", "abstract_algebra.hs")
    assert (
        hydrate_field(
            raw_source_files=["monad.hs", "abstract_*.hs"], hydrated_source_files=expected_files
        ).snapshot.files
        == expected_files
    )

    # Test pre-hydration validation. The int list deliberately violates `List[str]`; mypy reports
    # that as `list-item`, not `call-arg`.
    with pytest.raises(ValueError) as exc:
        hydrate_field(raw_source_files=[0, 1, 2], hydrated_source_files=())  # type: ignore[list-item]
    assert "Not all elements of the iterable have type" in str(exc.value)

    # Test post-hydration validation
    with pytest.raises(ValueError) as exc:
        hydrate_field(raw_source_files=["*.js"], hydrated_source_files=("not_haskell.js",))
    assert "Received non-Haskell sources" in str(exc.value)
def test_has_fields() -> None:
    """`has_fields` is True only when every requested field type belongs to the target."""

    class UnrelatedField(PrimitiveField):
        alias: ClassVar = "unrelated"
        raw_value: Optional[bool]

        @memoized_property
        def value(self) -> bool:
            return False if self.raw_value is None else self.raw_value

    tgt = HaskellTarget({})
    expectations = [
        ([], True),
        ([HaskellGhcExtensions], True),
        ([UnrelatedField], False),
        ([HaskellGhcExtensions, UnrelatedField], False),
    ]
    for requested, expected in expectations:
        assert tgt.has_fields(requested) is expected
def test_field_hydration_is_lazy() -> None:
    """Invalid raw values are accepted at construction and only rejected when hydrated."""
    bad_extension = "DoesNotStartWithGhc"
    # No error upon creating the Target because validation does not happen until a call site
    # hydrates the specific field.
    tgt = HaskellTarget(
        {HaskellGhcExtensions.alias: ["GhcExistentialQuantification", bad_extension]}
    )
    # When hydrating, we expect a failure.
    with pytest.raises(ValueError) as exc:
        tgt.get(HaskellGhcExtensions).value
    # Inspect the exception's own message rather than the ExceptionInfo wrapper's repr.
    assert "must be prefixed by `Ghc`" in str(exc.value)
def test_add_custom_fields() -> None:
    """Plugin fields registered for a target's union become fields of that target."""

    class CustomField(PrimitiveField):
        alias: ClassVar = "custom_field"
        raw_value: Optional[bool]

        @memoized_property
        def value(self) -> bool:
            return False if self.raw_value is None else self.raw_value

    membership = UnionMembership(OrderedDict({HaskellField: OrderedSet([CustomField])}))

    explicit_tgt = HaskellTarget({CustomField.alias: True}, union_membership=membership)
    assert explicit_tgt.field_types == (HaskellGhcExtensions, HaskellSources, CustomField)
    assert explicit_tgt.core_fields == (HaskellGhcExtensions, HaskellSources)
    assert explicit_tgt.plugin_fields == (CustomField,)
    assert explicit_tgt.get(CustomField).value is True

    unset_tgt = HaskellTarget({}, union_membership=membership)
    assert unset_tgt.get(CustomField).value is False
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment