Skip to content

Instantly share code, notes, and snippets.

@pdxjohnny
Created May 15, 2020 17:10
Show Gist options
  • Star 0 (you must be signed in to star a gist)
  • Fork 0 (you must be signed in to fork a gist)
  • Save pdxjohnny/14ffd450dea40753fa352d4bbae53a5b to your computer and use it in GitHub Desktop.
diff --git a/dffml/feature/feature.py b/dffml/feature/feature.py
index 9e9c1341..e6d92cca 100644
--- a/dffml/feature/feature.py
+++ b/dffml/feature/feature.py
@@ -14,6 +14,7 @@ from typing import List, Dict, Type, Any
from .log import LOGGER
from ..util.entrypoint import Entrypoint
+from ..util.data import parser_helper
class Feature(abc.ABC):
@@ -66,8 +67,15 @@ class Feature(abc.ABC):
# FREQUENCY: Type[Frequency] = Quarterly
ENTRYPOINT = "dffml.feature"
- def __init__(self, name: str, dtype: Type = int, length: int = 1)-> Any:
+ def __init__(self, name: str, dtype: Type = int, length: int = 1) -> Any:
super().__init__()
+ if name.count(":") == 2:
+ tempvar = name.split(":")
+ name = tempvar[0]
+ dtype = tempvar[1]
+ length = parser_helper(tempvar[2])
+ if isinstance(dtype, str):
+ dtype = self.convert_dtype(tempvar[1])
self._dtype = dtype
self._length = length
self.name = name
@@ -85,7 +93,7 @@ class Feature(abc.ABC):
return "%s(%s)" % (self.name, self.__class__.__qualname__)
def __repr__(self):
- return "%s[%r, %d]" % (self.__str__(), self.dtype(), self.length())
+ return "%s[%r, %r]" % (self.__str__(), self.dtype(), self.length())
def export(self):
return {
@@ -95,7 +103,7 @@ class Feature(abc.ABC):
}
@classmethod
- def _fromdict(cls,**kwargs):
+ def _fromdict(cls, **kwargs):
return Feature(**kwargs)
def dtype(self) -> Type:
@@ -131,18 +139,6 @@ class Feature(abc.ABC):
self.LOGGER.warning("%s length unimplemented", self)
return self._length
- @classmethod
- def load(cls, loading=None):
- # CLI or dict compatibility
- # TODO Consolidate this
- if loading is not None:
- if isinstance(loading, dict):
- return Feature(loading["name"], loading["dtype"], loading["length"])
- elif loading.count(":") == 2:
- tempvar = loading.split(":")
- return Feature(tempvar[0], cls.convert_dtype(tempvar[1]), int(tempvar[2]))
- return super().load(loading)
-
# @classmethod
# def load_def(cls, name: str, dtype: str, length: str):
# return DefFeature(name, cls.convert_dtype(dtype), int(length))
@@ -185,7 +181,6 @@ class Feature(abc.ABC):
# return self._length
-
# def DefFeature(name, dtype, length):
# return type("Feature" + name, (Feature,), {})(name=name,
diff --git a/tests/test_base.py b/tests/test_base.py
index bd1791cc..de9e80f5 100644
--- a/tests/test_base.py
+++ b/tests/test_base.py
@@ -62,7 +62,7 @@ class TestAutoArgsConfig(unittest.TestCase):
},
"features": {
"plugin": Arg(
- type=Feature.load,
+ type=Feature,
nargs="+",
action=list_action(Features),
),
@@ -131,9 +131,7 @@ class TestAutoArgsConfig(unittest.TestCase):
)
self.assertEqual(
config.features,
- Features(
- Feature("Year", int, 1), Feature("Commits", int, 10)
- ),
+ Features(Feature("Year", int, 1), Feature("Commits", int, 10)),
)
def test_config_set(self):
@@ -170,9 +168,7 @@ class TestAutoArgsConfig(unittest.TestCase):
)
self.assertEqual(
config.features,
- Features(
- Feature("Year", int, 1), Feature("Commits", int, 10)
- ),
+ Features(Feature("Year", int, 1), Feature("Commits", int, 10)),
)
diff --git a/tests/test_cli.py b/tests/test_cli.py
index 7c9d121c..12b34b07 100644
--- a/tests/test_cli.py
+++ b/tests/test_cli.py
@@ -59,9 +59,6 @@ class RecordsTestCase(AsyncExitStackTestCase):
new=ModelCMD.arg_model.modify(type=model_load),
)
)
- self._stack.enter_context(
- patch("dffml.feature.feature.Feature.load", new=feature_load)
- )
self._stack.enter_context(
patch("dffml.df.base.OperationImplementation.load", new=opimp_load)
)
@@ -82,9 +79,9 @@ class FakeConfig:
class FakeFeature(Feature):
# NAME: str = "fake"
-
- def __init__(self,name="fake",dt=float,length=1):
- super().__init__(name,dt,length)
+
+ def __init__(self, name="fake", dt=float, length=1):
+ super().__init__(name, dt, length)
# def dtype(self):
# return float # pragma: no cov
@@ -116,12 +113,6 @@ class FakeModel(Model):
CONFIG = FakeConfig
-def feature_load(loading=None):
- if loading == "fake":
- return FakeFeature()
- return [FakeFeature()]
-
-
def model_load(loading):
if loading == "fake":
return FakeModel
@@ -409,7 +400,7 @@ class TestPredict(RecordsTestCase):
"-model",
"fake",
"-model-features",
- "fake",
+ "fake:float:[10,10]",
"-model-predict",
"fake",
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment