Skip to content

Instantly share code, notes, and snippets.

Created March 12, 2020 17:14
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Eric-Arellano/9a1cb63f9db0c43e61c82cbf809342ad to your computer and use it in GitHub Desktop.
Save Eric-Arellano/9a1cb63f9db0c43e61c82cbf809342ad to your computer and use it in GitHub Desktop.
Target API - AsyncField
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from abc import ABC, ABCMeta, abstractmethod
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, Iterable, List, Optional, Tuple, Type, TypeVar, Union, cast
from pants.engine.objects import union
from pants.engine.rules import UnionMembership
from pants.util.collections import ensure_str_list
from pants.util.memo import memoized_property
from pants.util.meta import frozen_after_init
@frozen_after_init
@dataclass(unsafe_hash=True)
class Field(ABC):
    """Base class for every field that can be declared on a `Target`.

    Subclasses set the class-level `alias`, the name users write in BUILD files. An instance
    wraps `raw_value` exactly as the user wrote it; `None` means the field was not explicitly
    defined. `Target.__init__` builds instances with `field_type(value)` and
    `field_type(raw_value=None)`, which is why this class is a dataclass.
    """

    alias: ClassVar[str]
    raw_value: Optional[Any]  # None indicates that the field was not explicitly defined

    def __repr__(self) -> str:
        return f"{self.__class__}(alias={repr(self.alias)}, raw_value={self.raw_value})"
class PrimitiveField(Field, metaclass=ABCMeta):
    """A Field that does not need the engine in order to be hydrated.

    This should be subclassed by the majority of fields.
    """

    def __str__(self) -> str:
        return f"{self.alias}={self.value}"

    @memoized_property
    @abstractmethod
    def value(self) -> Any:
        """Get the field's value.

        The value will possibly be first hydrated and/or validated, such as using a default value
        if the field was not defined or ensuring that an int value is positive.

        This property is memoized because hydration and validation can often be costly. This
        hydration is lazy, i.e. it will only happen when a downstream rule explicitly requests this
        field.
        """
class AsyncField(Field, metaclass=ABCMeta):
    """A field that needs the engine in order to be hydrated.

    You should create a corresponding Result class and define a rule to go from this AsyncField to
    the Result. For example:

        class Sources(AsyncField):
            alias: ClassVar = "sources"
            raw_value: Optional[List[str]]


        @dataclass(frozen=True)
        class SourcesResult:
            snapshot: Snapshot


        @rule
        async def hydrate_sources(sources: Sources) -> SourcesResult:
            sources.validate_pre_hydration()
            result = await Get[Snapshot](PathGlobs(sources.raw_value))
            sources.validate_post_hydration(result)
            return SourcesResult(result)


        def rules():
            return [hydrate_sources]

    Then, call sites can `await Get` if they need to hydrate the field:

        sources = await Get[SourcesResult](Sources, my_tgt.get(Sources))
    """

    def validate_pre_hydration(self) -> None:
        """Any validation that can be done on the original `raw_value`.

        It is cheaper to do any possible validation here, rather than in `validate_post_hydration`,
        because we can short-circuit if an invariant is violated before the expense of hydration.

        The default implementation accepts everything; subclasses override it to add checks. The
        hydration rule is responsible for calling it.
        """

    def validate_post_hydration(self, result: Any) -> None:
        """Any validation that must be done after hydration by inspecting the result of that
        hydration.

        For example, a `PythonSources` field may validate that all hydrated source files end in
        `.py`.

        The default implementation accepts everything; subclasses override it to add checks. The
        hydration rule is responsible for calling it.
        """
class PluginField:
    """Allows plugin authors to add new fields to pre-existing target types via UnionRules.

    When defining a Target, authors should create a corresponding PluginField class marked with
    `@union`. Then, plugin authors simply need to create whatever new `Field` they want and, in
    their plugin's `rules()` function, register a `UnionRule`. For example, to add a
    `TypeChecked` field to `python_library`, register `UnionRule(PythonLibraryField, TypeChecked)`.

        @union
        class PythonLibraryField(PluginField):
            pass


        class PythonLibrary(Target):
            core_fields = (Compatibility, PythonSources, ...)
            plugin_field_type = PythonLibraryField


        class TypeChecked(PrimitiveField):
            ...


        def rules():
            return [UnionRule(PythonLibraryField, TypeChecked)]
    """
# Used by `Target.get` so that the return type matches the requested field type.
_F = TypeVar("_F", bound=Field)


class Target(ABC):
    """A target type: a fixed tuple of core fields plus any fields registered by plugins.

    Subclasses must define `alias`, `core_fields`, and `plugin_field_type`. Plugin fields are the
    union members registered against `plugin_field_type` in the `UnionMembership` passed to the
    constructor.
    """

    # Subclasses must define these
    alias: ClassVar[str]
    core_fields: ClassVar[Tuple[Type[Field], ...]]
    plugin_field_type: ClassVar[Type[PluginField]]

    # These get calculated in the constructor
    plugin_fields: Tuple[Type[Field], ...]
    field_values: Dict[Type[Field], Any]

    def __init__(
        self,
        unhydrated_values: Dict[str, Any],
        union_membership: Optional[UnionMembership] = None,
    ) -> None:
        """Create the target from raw values keyed by field alias.

        :param unhydrated_values: raw field values keyed by alias. Values are stored as-is; any
            validation/hydration is deferred until the field is requested.
        :param union_membership: where plugin fields registered for `plugin_field_type` are
            looked up. When omitted, the target has no plugin fields.
        :raises ValueError: if an alias does not belong to any of this target's field types.
        """
        self.plugin_fields = cast(
            Tuple[Type[Field], ...],
            ()
            if union_membership is None
            else tuple(union_membership.union_rules.get(self.plugin_field_type, ())),
        )

        self.field_values = {}
        aliases_to_fields = {field.alias: field for field in self.field_types}
        for alias, value in unhydrated_values.items():
            if alias not in aliases_to_fields:
                raise ValueError(
                    f"Unrecognized field `{alias}={value}` for target type `{self.alias}`."
                )
            field = aliases_to_fields[alias]
            self.field_values[field] = field(value)
        # For undefined fields, mark the raw value as None. Walk `field_types` in declaration
        # order (rather than a set difference, whose iteration order is arbitrary) so that
        # `field_values`, and therefore `__str__`/`__repr__`, are deterministic.
        for field in self.field_types:
            if field not in self.field_values:
                self.field_values[field] = field(raw_value=None)

    @memoized_property
    def field_types(self) -> Tuple[Type[Field], ...]:
        """All field types of this target: core fields first, then plugin fields."""
        return (*self.core_fields, *self.plugin_fields)

    def __repr__(self) -> str:
        return (
            f"{self.__class__}("
            f"alias={repr(self.alias)}, "
            f"plugin_field_type={self.plugin_field_type}, "
            f"core_fields={list(self.core_fields)}, "
            f"plugin_fields={list(self.plugin_fields)}, "
            f"field_values={list(self.field_values.values())})"
        )

    def __str__(self) -> str:
        fields = ", ".join(str(field) for field in self.field_values.values())
        return f"{self.alias}({fields})"

    def get(self, field: Type[_F]) -> _F:
        """Return this target's instance of `field`.

        :raises KeyError: if `field` is not one of this target's field types.
        """
        try:
            return cast(_F, self.field_values[field])
        except KeyError:
            valid = ", ".join(field_type.__name__ for field_type in self.field_types)
            raise KeyError(
                f"The target type `{self.alias}` does not have the field "
                f"`{getattr(field, '__name__', field)}`. Valid fields: {valid}."
            ) from None

    def has_fields(self, fields: Iterable[Type[Field]]) -> bool:
        """Return True if every type in `fields` is exactly one of this target's field types."""
        # TODO: consider if this should support subclasses. For example, if a target has a
        # field PythonSources(Sources), then .has_fields(Sources) should still return True. Why?
        # This allows overriding how fields behave for custom target types, e.g. a `python3_library`
        # subclassing the Field Compatibility with its own custom implementation. When adding
        # this, be sure to update `.get()` to allow looking up by subclass, too. (Is it possible
        # to do that in a performant way?)
        return all(field in self.field_types for field in fields)
class Sources(AsyncField):
    """The `sources` field: file names/globs that are hydrated asynchronously by the engine."""

    alias: ClassVar = "sources"
    raw_value: Optional[Iterable[str]]
class BinarySources(Sources):
    """A `sources` field for binary targets, which may list at most one entry."""

    def validate_pre_hydration(self) -> None:
        """Reject more than one source entry before paying the cost of hydration.

        This replaces the former `value_request`, which delegated to a `super().value_request`
        that no base class defines.

        :raises ValueError: if more than one entry is given.
        """
        super().validate_pre_hydration()
        # NOTE(review): `list()` would drain a one-shot iterator; this assumes `raw_value` is a
        # re-iterable collection such as a list or tuple, as BUILD files provide.
        if self.raw_value is not None and len(list(self.raw_value)) > 1:
            raise ValueError("Binary targets must have only 0 or 1 source files.")
class Compatibility(PrimitiveField):
    """The `compatibility` field: a single string or an iterable of strings."""

    alias: ClassVar = "compatibility"
    raw_value: Optional[Union[str, Iterable[str]]]

    @memoized_property
    def value(self) -> Optional[List[str]]:
        """Normalize the raw value to a list of strings, or None if the field was not defined."""
        if self.raw_value is None:
            return None
        return ensure_str_list(self.raw_value)
class Coverage(PrimitiveField):
    """The `coverage` field: a single string or an iterable of strings."""

    alias: ClassVar = "coverage"
    raw_value: Optional[Union[str, Iterable[str]]]

    @memoized_property
    def value(self) -> Optional[List[str]]:
        """Normalize the raw value to a list of strings, or None if the field was not defined."""
        if self.raw_value is None:
            return None
        return ensure_str_list(self.raw_value)
class Timeout(PrimitiveField):
    """The `timeout` field: an optional positive integer."""

    alias: ClassVar = "timeout"
    raw_value: Optional[int]

    @memoized_property
    def value(self) -> Optional[int]:
        """Validate and return the timeout, or None if the field was not defined.

        :raises ValueError: if the value is not an `int` (a `bool` is rejected too) or is not
            positive.
        """
        if self.raw_value is None:
            return None
        # `bool` is a subclass of `int`, so without the explicit check `timeout=True` would be
        # accepted as a timeout of 1.
        if isinstance(self.raw_value, bool) or not isinstance(self.raw_value, int):
            raise ValueError(
                f"The `timeout` field must be an `int`. Was {type(self.raw_value)} "
                f"({self.raw_value})."
            )
        if self.raw_value <= 0:
            raise ValueError(f"The `timeout` field must be > 0. Was {self.raw_value}.")
        return self.raw_value
class EntryPoint(PrimitiveField):
    """The `entry_point` field: an optional string, returned unchanged."""

    alias: ClassVar = "entry_point"
    raw_value: Optional[str]

    @memoized_property
    def value(self) -> Optional[str]:
        """Return the raw value as-is (None if the field was not defined)."""
        return self.raw_value
class ZipSafe(PrimitiveField):
    """The `zip_safe` field: an optional bool that defaults to True."""

    alias: ClassVar = "zip_safe"
    raw_value: Optional[bool]

    @memoized_property
    def value(self) -> bool:
        """Return the raw value, or True if the field was not defined."""
        if self.raw_value is None:
            return True
        return self.raw_value
class AlwaysWriteCache(PrimitiveField):
    """The `always_write_cache` field: an optional bool that defaults to False."""

    alias: ClassVar = "always_write_cache"
    raw_value: Optional[bool]

    @memoized_property
    def value(self) -> bool:
        """Return the raw value, or False if the field was not defined."""
        if self.raw_value is None:
            return False
        return self.raw_value
@union
class PythonBinaryField(PluginField):
    """Union that plugins register against to add fields to `python_binary`."""


@union
class PythonLibraryField(PluginField):
    """Union that plugins register against to add fields to `python_library`."""


@union
class PythonTestsField(PluginField):
    """Union that plugins register against to add fields to `python_tests`."""
# Core fields shared by every Python target type below.
PYTHON_TARGET_FIELDS = (Compatibility,)


class PythonBinary(Target):
    """The `python_binary` target type."""

    alias: ClassVar = "python_binary"
    core_fields: ClassVar = (*PYTHON_TARGET_FIELDS, EntryPoint, ZipSafe, AlwaysWriteCache)
    plugin_field_type: ClassVar = PythonBinaryField
class PythonLibrary(Target):
    """The `python_library` target type."""

    alias: ClassVar = "python_library"
    core_fields: ClassVar = PYTHON_TARGET_FIELDS
    plugin_field_type: ClassVar = PythonLibraryField
class PythonTests(Target):
    """The `python_tests` target type."""

    alias: ClassVar = "python_tests"
    core_fields: ClassVar = (*PYTHON_TARGET_FIELDS, Coverage, Timeout)
    plugin_field_type: ClassVar = PythonTestsField
# Copyright 2020 Pants project contributors (see CONTRIBUTORS.md).
# Licensed under the Apache License, Version 2.0 (see LICENSE).
from collections import OrderedDict
from dataclasses import dataclass
from pathlib import PurePath
from typing import ClassVar, List, Optional, Tuple
import pytest
from pants.engine.fs import EMPTY_DIRECTORY_DIGEST, PathGlobs, Snapshot
from pants.engine.rules import UnionMembership, rule
from pants.engine.selectors import Get
from import AsyncField, PluginField, PrimitiveField, Target
from pants.testutil.engine.util import MockGet, run_rule
from pants.util.collections import ensure_str_list
from pants.util.memo import memoized_property
from pants.util.ordered_set import OrderedSet
class HaskellGhcExtensions(PrimitiveField):
    """Test-only primitive field whose elements must all start with `Ghc`."""

    alias: ClassVar = "ghc_extensions"
    raw_value: Optional[List[str]]

    @memoized_property
    def value(self) -> List[str]:
        """Return the extensions, or [] if undefined.

        :raises ValueError: if any extension is not prefixed by `Ghc`.
        """
        if self.raw_value is None:
            return []
        # Add some arbitrary validation to test that hydration works properly.
        bad_extensions = [
            extension for extension in self.raw_value if not extension.startswith("Ghc")
        ]
        if bad_extensions:
            raise ValueError(
                f"All elements of `{self.alias}` must be prefixed by `Ghc`. Received "
                f"{bad_extensions}."
            )
        return self.raw_value
class HaskellSources(AsyncField):
    """Test-only async field whose hydrated files must all have the `.hs` suffix."""

    alias: ClassVar = "sources"
    raw_value: Optional[List[str]]

    def validate_pre_hydration(self) -> None:
        # Called only for its validation: `ensure_str_list` is expected to raise ValueError on
        # non-`str` elements (test_get_async_field relies on this). An undefined field is valid.
        if self.raw_value is not None:
            ensure_str_list(self.raw_value)

    def validate_post_hydration(self, result: Snapshot) -> None:
        non_haskell_sources = [fp for fp in result.files if PurePath(fp).suffix != ".hs"]
        if non_haskell_sources:
            raise ValueError(
                f"Received non-Haskell sources in {self.alias}: {non_haskell_sources}."
            )
@dataclass(frozen=True)
class HaskellSourcesResult:
    """The hydrated form of `HaskellSources`."""

    snapshot: Snapshot
@rule
async def hydrate_haskell_sources(sources: HaskellSources) -> HaskellSourcesResult:
    """Hydrate a `HaskellSources` field by expanding its own globs into a `Snapshot`.

    Runs the field's pre-hydration validation before requesting the snapshot and its
    post-hydration validation on the result; either may raise ValueError.
    """
    sources.validate_pre_hydration()
    # Expand the field's globs (previously the literal string "*.hs" was passed, which both
    # ignored the field and handed PathGlobs a str instead of an iterable of globs).
    result = await Get[Snapshot](PathGlobs(sources.raw_value or ()))
    sources.validate_post_hydration(result)
    return HaskellSourcesResult(result)
class HaskellField(PluginField):
    """Plugin-field union for `HaskellTarget`; tests supply its members via `UnionMembership`."""
class HaskellTarget(Target):
    """Test-only target type with one primitive and one async core field."""

    alias: ClassVar = "haskell"
    core_fields: ClassVar = (HaskellGhcExtensions, HaskellSources)
    plugin_field_type: ClassVar = HaskellField
def test_invalid_fields_rejected() -> None:
    with pytest.raises(ValueError) as exc:
        HaskellTarget({"invalid_field": True})
    # Inspect the exception itself, not the `ExceptionInfo` wrapper's string form.
    assert "Unrecognized field `invalid_field=True` for target type `haskell`." in str(exc.value)
def test_get_primitive_field() -> None:
    extensions = ["GhcExistentialQuantification"]
    extensions_field = HaskellTarget({HaskellGhcExtensions.alias: extensions}).get(
        HaskellGhcExtensions
    )
    assert extensions_field.raw_value == extensions
    assert extensions_field.value == extensions

    # An undefined field has a raw value of None and falls back to the field's default.
    default_extensions_field = HaskellTarget({}).get(HaskellGhcExtensions)
    assert default_extensions_field.raw_value is None
    assert default_extensions_field.value == []
def test_get_async_field() -> None:
    def hydrate_field(
        *, raw_source_files: List[str], hydrated_source_files: Tuple[str, ...]
    ) -> HaskellSourcesResult:
        """Run `hydrate_haskell_sources` with the snapshot request mocked out."""
        sources_field = HaskellTarget({HaskellSources.alias: raw_source_files}).get(HaskellSources)
        assert sources_field.raw_value == raw_source_files
        result: HaskellSourcesResult = run_rule(
            hydrate_haskell_sources,
            rule_args=[sources_field],
            mock_gets=[
                MockGet(
                    product_type=Snapshot,
                    subject_type=PathGlobs,
                    mock=lambda _: Snapshot(
                        EMPTY_DIRECTORY_DIGEST, files=hydrated_source_files, dirs=()
                    ),
                )
            ],
        )
        return result

    # Normal field
    expected_files = ("monad.hs", "abstract_art.hs", "abstract_algebra.hs")
    assert (
        hydrate_field(
            raw_source_files=["monad.hs", "abstract_*.hs"], hydrated_source_files=expected_files
        ).snapshot.files
        == expected_files
    )

    # Test pre-hydration validation
    with pytest.raises(ValueError) as exc:
        hydrate_field(raw_source_files=[0, 1, 2], hydrated_source_files=())  # type: ignore[arg-type]
    assert "Not all elements of the iterable have type" in str(exc.value)

    # Test post-hydration validation
    with pytest.raises(ValueError) as exc:
        hydrate_field(raw_source_files=["*.js"], hydrated_source_files=("not_haskell.js",))
    assert "Received non-Haskell sources" in str(exc.value)
def test_has_fields() -> None:
    class UnrelatedField(PrimitiveField):
        alias: ClassVar = "unrelated"
        raw_value: Optional[bool]

        @memoized_property
        def value(self) -> bool:
            if self.raw_value is None:
                return False
            return self.raw_value

    tgt = HaskellTarget({})
    # An empty request is vacuously satisfied; any unknown field makes the whole check fail.
    assert tgt.has_fields([]) is True
    assert tgt.has_fields([HaskellGhcExtensions]) is True
    assert tgt.has_fields([UnrelatedField]) is False
    assert tgt.has_fields([HaskellGhcExtensions, UnrelatedField]) is False
def test_field_hydration_is_lazy() -> None:
    bad_extension = "DoesNotStartWithGhc"
    # No error upon creating the Target because validation does not happen until a call site
    # hydrates the specific field.
    tgt = HaskellTarget(
        {HaskellGhcExtensions.alias: ["GhcExistentialQuantification", bad_extension]}
    )
    # When hydrating, we expect a failure.
    with pytest.raises(ValueError) as exc:
        tgt.get(HaskellGhcExtensions).value
    assert "must be prefixed by `Ghc`" in str(exc.value)
    assert bad_extension in str(exc.value)
def test_add_custom_fields() -> None:
    class CustomField(PrimitiveField):
        alias: ClassVar = "custom_field"
        raw_value: Optional[bool]

        @memoized_property
        def value(self) -> bool:
            if self.raw_value is None:
                return False
            return self.raw_value

    # Register CustomField as a plugin field of HaskellTarget's union.
    union_membership = UnionMembership(OrderedDict({HaskellField: OrderedSet([CustomField])}))
    tgt = HaskellTarget({CustomField.alias: True}, union_membership=union_membership)

    # Plugin fields come after the core fields.
    assert tgt.field_types == (HaskellGhcExtensions, HaskellSources, CustomField)
    assert tgt.core_fields == (HaskellGhcExtensions, HaskellSources)
    assert tgt.plugin_fields == (CustomField,)
    assert tgt.get(CustomField).value is True

    default_tgt = HaskellTarget({}, union_membership=union_membership)
    assert default_tgt.get(CustomField).value is False
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment