Created May 15, 2020 17:10
-
-
Save pdxjohnny/14ffd450dea40753fa352d4bbae53a5b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
diff --git a/dffml/feature/feature.py b/dffml/feature/feature.py | |
index 9e9c1341..e6d92cca 100644 | |
--- a/dffml/feature/feature.py | |
+++ b/dffml/feature/feature.py | |
@@ -14,6 +14,7 @@ from typing import List, Dict, Type, Any | |
from .log import LOGGER | |
from ..util.entrypoint import Entrypoint | |
+from ..util.data import parser_helper | |
class Feature(abc.ABC): | |
@@ -66,8 +67,15 @@ class Feature(abc.ABC): | |
# FREQUENCY: Type[Frequency] = Quarterly | |
ENTRYPOINT = "dffml.feature" | |
- def __init__(self, name: str, dtype: Type = int, length: int = 1)-> Any: | |
+ def __init__(self, name: str, dtype: Type = int, length: int = 1) -> Any: | |
super().__init__() | |
+ if name.count(":") == 2: | |
+ tempvar = name.split(":") | |
+ name = tempvar[0] | |
+ dtype = tempvar[1] | |
+ length = parser_helper(tempvar[2]) | |
+ if isinstance(dtype, str): | |
+ dtype = self.convert_dtype(tempvar[1]) | |
self._dtype = dtype | |
self._length = length | |
self.name = name | |
@@ -85,7 +93,7 @@ class Feature(abc.ABC): | |
return "%s(%s)" % (self.name, self.__class__.__qualname__) | |
def __repr__(self): | |
- return "%s[%r, %d]" % (self.__str__(), self.dtype(), self.length()) | |
+ return "%s[%r, %r]" % (self.__str__(), self.dtype(), self.length()) | |
def export(self): | |
return { | |
@@ -95,7 +103,7 @@ class Feature(abc.ABC): | |
} | |
@classmethod | |
- def _fromdict(cls,**kwargs): | |
+ def _fromdict(cls, **kwargs): | |
return Feature(**kwargs) | |
def dtype(self) -> Type: | |
@@ -131,18 +139,6 @@ class Feature(abc.ABC): | |
self.LOGGER.warning("%s length unimplemented", self) | |
return self._length | |
- @classmethod | |
- def load(cls, loading=None): | |
- # CLI or dict compatibility | |
- # TODO Consolidate this | |
- if loading is not None: | |
- if isinstance(loading, dict): | |
- return Feature(loading["name"], loading["dtype"], loading["length"]) | |
- elif loading.count(":") == 2: | |
- tempvar = loading.split(":") | |
- return Feature(tempvar[0], cls.convert_dtype(tempvar[1]), int(tempvar[2])) | |
- return super().load(loading) | |
- | |
# @classmethod | |
# def load_def(cls, name: str, dtype: str, length: str): | |
# return DefFeature(name, cls.convert_dtype(dtype), int(length)) | |
@@ -185,7 +181,6 @@ class Feature(abc.ABC): | |
# return self._length | |
- | |
# def DefFeature(name, dtype, length): | |
# return type("Feature" + name, (Feature,), {})(name=name, | |
diff --git a/tests/test_base.py b/tests/test_base.py | |
index bd1791cc..de9e80f5 100644 | |
--- a/tests/test_base.py | |
+++ b/tests/test_base.py | |
@@ -62,7 +62,7 @@ class TestAutoArgsConfig(unittest.TestCase): | |
}, | |
"features": { | |
"plugin": Arg( | |
- type=Feature.load, | |
+ type=Feature, | |
nargs="+", | |
action=list_action(Features), | |
), | |
@@ -131,9 +131,7 @@ class TestAutoArgsConfig(unittest.TestCase): | |
) | |
self.assertEqual( | |
config.features, | |
- Features( | |
- Feature("Year", int, 1), Feature("Commits", int, 10) | |
- ), | |
+ Features(Feature("Year", int, 1), Feature("Commits", int, 10)), | |
) | |
def test_config_set(self): | |
@@ -170,9 +168,7 @@ class TestAutoArgsConfig(unittest.TestCase): | |
) | |
self.assertEqual( | |
config.features, | |
- Features( | |
- Feature("Year", int, 1), Feature("Commits", int, 10) | |
- ), | |
+ Features(Feature("Year", int, 1), Feature("Commits", int, 10)), | |
) | |
diff --git a/tests/test_cli.py b/tests/test_cli.py | |
index 7c9d121c..12b34b07 100644 | |
--- a/tests/test_cli.py | |
+++ b/tests/test_cli.py | |
@@ -59,9 +59,6 @@ class RecordsTestCase(AsyncExitStackTestCase): | |
new=ModelCMD.arg_model.modify(type=model_load), | |
) | |
) | |
- self._stack.enter_context( | |
- patch("dffml.feature.feature.Feature.load", new=feature_load) | |
- ) | |
self._stack.enter_context( | |
patch("dffml.df.base.OperationImplementation.load", new=opimp_load) | |
) | |
@@ -82,9 +79,9 @@ class FakeConfig: | |
class FakeFeature(Feature): | |
# NAME: str = "fake" | |
- | |
- def __init__(self,name="fake",dt=float,length=1): | |
- super().__init__(name,dt,length) | |
+ | |
+ def __init__(self, name="fake", dt=float, length=1): | |
+ super().__init__(name, dt, length) | |
# def dtype(self): | |
# return float # pragma: no cov | |
@@ -116,12 +113,6 @@ class FakeModel(Model): | |
CONFIG = FakeConfig | |
-def feature_load(loading=None): | |
- if loading == "fake": | |
- return FakeFeature() | |
- return [FakeFeature()] | |
- | |
- | |
def model_load(loading): | |
if loading == "fake": | |
return FakeModel | |
@@ -409,7 +400,7 @@ class TestPredict(RecordsTestCase): | |
"-model", | |
"fake", | |
"-model-features", | |
- "fake", | |
+ "fake:float:[10,10]", | |
"-model-predict", | |
"fake", | |
) |
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.