Created October 7, 2021 16:09
Save shuklak13/3239454ad0cc8fcaf9dc47e4f19801e8 to your computer and use it in GitHub Desktop.
model_card_component.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class ModelCardArtifact(Artifact):
  """TFX artifact type for the model card produced by ModelCardExecutor."""

  # Artifact type name under which this artifact is registered.
  TYPE_NAME = 'ModelCard'
class ModelCardComponentSpec(component_spec.ComponentSpec):
  """Component spec for generating a model card from pipeline artifacts.

  Parameters:
    model_card_proto: proto message whose contents are copied onto the
      scaffolded model card.
    model_card_schema: generated proto class forwarded to the toolkit as
      `proto_schema`.
    model_card_template: path of the template passed to `export_format` as
      `template_path`.
  """

  PARAMETERS = {
      'model_card_proto':
          component_spec.ExecutionParameter(type=Message),
      'model_card_schema':
          component_spec.ExecutionParameter(type=GeneratedProtocolMessageType),
      # The executor passes this value as `template_path=`, so it must be a
      # path string rather than a proto class.
      'model_card_template':
          component_spec.ExecutionParameter(type=str),
  }
  INPUTS = {
      'example_statistics_artifact':
          component_spec.ChannelParameter(
              type=standard_artifacts.ExampleStatistics),
      'model_artifact':
          component_spec.ChannelParameter(type=standard_artifacts.Model),
      'model_evaluation_artifact':
          component_spec.ChannelParameter(
              type=standard_artifacts.ModelEvaluation),
  }
  OUTPUTS = {
      'model_card_artifact':
          component_spec.ChannelParameter(type=ModelCardArtifact),
  }
class ModelCardExecutor(BaseBeamExecutor):
  """Executor that builds a model card from TFX artifacts and exports it."""

  @staticmethod
  def _get_single_artifact(artifacts_by_key: Dict[str, List[Artifact]],
                           key: str) -> Artifact:
    """Returns the only artifact stored under `key`.

    Channels arrive as lists of artifacts (see the `Do` signature); this
    component expects exactly one artifact per channel.

    Raises:
      KeyError: if `key` is not present.
      ValueError: if the list under `key` does not hold exactly one artifact.
    """
    artifacts = artifacts_by_key[key]
    if len(artifacts) != 1:
      raise ValueError(
          f'Expected exactly one artifact for {key!r}, got {len(artifacts)}.')
    return artifacts[0]

  def Do(self, input_dict: Dict[str, List[Artifact]],
         output_dict: Dict[str, List[Artifact]],
         exec_properties: Dict[str, Any]) -> None:
    """Scaffolds a model card, fills it from a proto, and exports it.

    Args:
      input_dict: input channels keyed by the names in
        ModelCardComponentSpec.INPUTS.
      output_dict: output channels keyed by the names in
        ModelCardComponentSpec.OUTPUTS.
      exec_properties: values for ModelCardComponentSpec.PARAMETERS.

    Raises:
      KeyError: if an expected channel or property is missing.
      ValueError: if a channel does not hold exactly one artifact.
    """
    get_single = self._get_single_artifact
    # Initialize the toolkit with single input artifacts and the output
    # artifact's URI as the output directory.
    mct = TfxModelCardToolkit(
        # The spec declares this input as 'example_statistics_artifact'.
        example_statistics_artifact=get_single(
            input_dict, 'example_statistics_artifact'),
        model_artifact=get_single(input_dict, 'model_artifact'),
        model_evaluation_artifact=get_single(
            input_dict, 'model_evaluation_artifact'),
        output_dir=get_single(output_dict, 'model_card_artifact').uri,
        proto_schema=exec_properties['model_card_schema'])
    # Create the model card from the input artifacts, then overlay the
    # user-supplied proto.
    model_card = mct.scaffold_assets()
    model_card.copy_from_proto(exec_properties['model_card_proto'])
    # Push the edited card back into the toolkit so the export renders it.
    # NOTE(review): relies on ModelCardToolkit.update_model_card (MCT >= 1.0);
    # confirm against the installed model-card-toolkit version.
    mct.update_model_card(model_card)
    # Write the model card into the output artifact.
    mct.export_format(template_path=exec_properties['model_card_template'])
class TfxModelCardToolkit(ModelCardToolkit):
  """ModelCardToolkit that takes its inputs as TFX artifacts.

  The artifacts are stored on the instance for use when scaffolding.
  `output_dir` and `proto_schema` are forwarded to ModelCardToolkit.__init__.
  """

  def __init__(self,
               example_statistics_artifact: Optional[
                   standard_artifacts.ExampleStatistics] = None,
               model_artifact: Optional[standard_artifacts.Model] = None,
               model_evaluation_artifact: Optional[
                   standard_artifacts.ModelEvaluation] = None,
               output_dir: Optional[str] = None,
               proto_schema: Optional[GeneratedProtocolMessageType] = None):
    self._example_statistics_artifact = example_statistics_artifact
    self._model_artifact = model_artifact
    self._model_evaluation_artifact = model_evaluation_artifact
    super().__init__(output_dir=output_dir, proto_schema=proto_schema)

  def _scaffold_model_card(self) -> ModelCard:
    """Returns the initial model card.

    This is presumably invoked by the base toolkit while scaffolding assets;
    confirm this against ModelCardToolkit.

    Placeholder: returns an empty ModelCard. The stored artifacts are not used
    yet.
    """
    # TODO(karanshukla): populate ModelCard using input artifacts
    return ModelCard()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.