Created December 15, 2020 20:57
-
-
Save caleb-kaiser/f6fc18dc83a269aa982009ba88d5bc9d to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os | |
from metaflow import FlowSpec, current, step | |
import cortex | |
def select_model(a, b):
    """Return whichever of ``a`` and ``b`` compares greater.

    Only ``a > b`` is evaluated: ``a`` wins when it is strictly greater,
    otherwise (including ties) ``b`` is returned.
    """
    return a if a > b else b
class TrainingFlow(FlowSpec):
    """
    Train two candidate models in parallel, keep the better one and deploy it.

    The flow performs the following steps:
    1) start: fans out into two parallel training branches.
    2) In parallel branches:
        - A) Fits the model with approach A.
        - B) Fits the model with approach B.
    3) join: compares both models with `select_model` and records a
       pathspec pointing at the winning model artifact.
    4) deploy: deploys a Cortex RealtimeAPI whose predictor is given that
       pathspec.
    5) end: prints a completion message.

    NOTE(review): training is still a placeholder -- each branch stores a
    constant integer as its "model", and `select_model` keeps the larger one.
    """

    @step
    def start(self):
        """
        Fan out into the two parallel training branches.

        TODO: the intended design is to use the Metaflow client to retrieve
        the latest successful run of NewDataFlow and initialize a new model
        here; neither is implemented yet, so this step only branches.
        """
        self.next(self.train_approach_a, self.train_approach_b)

    @step
    def train_approach_a(self):
        """
        Train model A.

        Placeholder: stores a constant instead of a trained model.
        """
        self.model = 42  # Store resulting trained model
        self.next(self.join)

    @step
    def train_approach_b(self):
        """
        Train model B.

        Placeholder: stores a constant instead of a trained model.
        """
        self.model = 43  # Store resulting trained model
        self.next(self.join)

    @step
    def join(self, inputs):
        """
        Pick the better of the two branch models and record where it lives.

        `self.model_artifact` is this join task's pathspec with "/model"
        appended, i.e. a reference to the `model` artifact stored by this
        task. The deploy step forwards it to the Cortex predictor config.
        """
        # Evaluate best model across both branches.
        self.model = select_model(inputs.train_approach_a.model,
                                  inputs.train_approach_b.model)
        self.model_artifact = '/'.join([current.pathspec, 'model'])
        self.next(self.deploy)

    @step
    def deploy(self):
        """
        Deploy the selected model to a Cortex cluster as a RealtimeAPI.

        The API runs `predictor.py` from the current project directory. It
        receives `model_artifact` in its config, plus Metaflow datastore and
        metadata settings in its environment (presumably so the predictor
        can load the artifact itself -- confirm against predictor.py).

        The S3 datastore root and the metadata service URL are read from the
        deploying process's environment variables of the same name. When
        they are unset, the original "XXX" placeholders are sent unchanged.
        The response of `cortex_client.deploy` is stored as the
        `deployments` artifact.
        """
        cortex_client = cortex.client("aws")
        api_config = {
            "name": "api-classifier",
            "kind": "RealtimeAPI",
            "predictor": {
                "type": "python",
                "path": "predictor.py",
                "config": {
                    "model_artifact": self.model_artifact
                },
                "env": {
                    "USERNAME": "cortex",
                    # Previously hard-coded "XXX"; now overridable from the
                    # environment without editing this file.
                    "METAFLOW_DATASTORE_SYSROOT_S3": os.environ.get(
                        "METAFLOW_DATASTORE_SYSROOT_S3", "XXX"),
                    "METAFLOW_DEFAULT_DATASTORE": "s3",
                    "METAFLOW_DEFAULT_METADATA": "service",
                    "METAFLOW_SERVICE_URL": os.environ.get(
                        "METAFLOW_SERVICE_URL", "XXX"),
                }
            }
        }
        self.deployments = cortex_client.deploy(api_config, project_dir=".")
        self.next(self.end)

    @step
    def end(self):
        """
        Final step: report completion.
        """
        print("All done!")
if __name__ == "__main__":
    # Per Metaflow convention, constructing the flow object starts its
    # command-line interface (e.g. `python <this file> run`).
    TrainingFlow()
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.