Skip to content

Instantly share code, notes, and snippets.

@shuklak13
Created October 7, 2021 16:09
Show Gist options
  • Save shuklak13/3239454ad0cc8fcaf9dc47e4f19801e8 to your computer and use it in GitHub Desktop.
Save shuklak13/3239454ad0cc8fcaf9dc47e4f19801e8 to your computer and use it in GitHub Desktop.
model_card_component.py
class ModelCardArtifact(Artifact):
    """Custom TFX artifact type for the model card produced by the pipeline."""

    # Artifact type name for this custom artifact (value kept verbatim).
    TYPE_NAME = 'ModelCard'
class ModelCardComponentSpec(component_spec.ComponentSpec):
    """ComponentSpec for a TFX component that generates a model card.

    Execution parameters:
      model_card_proto: proto message whose contents are copied onto the
        scaffolded model card.
      model_card_schema: proto class forwarded to the toolkit as
        ``proto_schema``.
      model_card_template: file path of the template used to render the
        model card (passed as ``template_path``).

    Inputs are the example-statistics, model and model-evaluation artifacts
    the card is built from; the single output is a ``ModelCardArtifact``.
    """

    PARAMETERS = {
        # NOTE(review): confirm TFX accepts the base ``Message`` class here;
        # proto execution parameters are usually declared with a concrete
        # generated message class.
        'model_card_proto':
            component_spec.ExecutionParameter(type=Message),
        # NOTE(review): this value is a proto *class*, not a primitive or a
        # proto message instance; confirm it can be stored as an execution
        # property.
        'model_card_schema':
            component_spec.ExecutionParameter(
                type=GeneratedProtocolMessageType),
        # The executor hands this value to ``export_format`` as
        # ``template_path``, i.e. a path string, so it must be typed ``str``
        # (a proto-class type would reject any real template path).
        'model_card_template':
            component_spec.ExecutionParameter(type=str),
    }
    INPUTS = {
        'example_statistics_artifact':
            component_spec.ChannelParameter(
                type=standard_artifacts.ExampleStatistics),
        'model_artifact':
            component_spec.ChannelParameter(type=standard_artifacts.Model),
        'model_evaluation_artifact':
            component_spec.ChannelParameter(
                type=standard_artifacts.ModelEvaluation),
    }
    OUTPUTS = {
        'model_card_artifact':
            component_spec.ChannelParameter(type=ModelCardArtifact),
    }
class ModelCardExecutor(BaseBeamExecutor):
    """Executor that scaffolds, fills in, and renders a model card."""

    @staticmethod
    def _single_artifact(artifacts: List[Artifact], key: str) -> Artifact:
        """Return the only artifact in ``artifacts``.

        Args:
          artifacts: artifact list taken from an input or output dict.
          key: dict key the list came from; used only in the error message.

        Raises:
          ValueError: if ``artifacts`` does not contain exactly one artifact.
        """
        if len(artifacts) != 1:
            raise ValueError(
                f'Expected exactly one artifact for {key!r}, '
                f'got {len(artifacts)}.')
        return artifacts[0]

    def Do(self, input_dict: Dict[str, List[Artifact]],
           output_dict: Dict[str, List[Artifact]],
           exec_properties: Dict[str, Any]) -> None:
        """Build the model card and write it to the output artifact.

        Args:
          input_dict: maps each INPUTS key of ``ModelCardComponentSpec`` to a
            list holding exactly one artifact.
          output_dict: maps ``'model_card_artifact'`` to a list holding
            exactly one artifact; its ``uri`` is used as the output dir.
          exec_properties: must contain ``'model_card_proto'``,
            ``'model_card_schema'`` and ``'model_card_template'``.

        Raises:
          KeyError: if a required input, output or property is missing.
          ValueError: if any artifact list does not hold exactly one item.
        """
        single = self._single_artifact

        # Channel values arrive as lists; the toolkit takes single artifacts
        # and a string output directory.
        output_artifact = single(output_dict['model_card_artifact'],
                                 'model_card_artifact')
        mct = TfxModelCardToolkit(
            example_statistics_artifact=single(
                input_dict['example_statistics_artifact'],
                'example_statistics_artifact'),
            model_artifact=single(input_dict['model_artifact'],
                                  'model_artifact'),
            model_evaluation_artifact=single(
                input_dict['model_evaluation_artifact'],
                'model_evaluation_artifact'),
            output_dir=output_artifact.uri,
            proto_schema=exec_properties['model_card_schema'])

        # Create the model card from the input artifacts, then overlay the
        # user-supplied proto.
        model_card = mct.scaffold_assets()
        model_card.copy_from_proto(exec_properties['model_card_proto'])

        # Persist the merged card; without this, export_format renders the
        # scaffolded card and the proto contents are lost.
        mct.update_model_card(model_card)

        # Render the model card into the output artifact's directory.
        mct.export_format(
            template_path=exec_properties['model_card_template'])
class TfxModelCardToolkit(ModelCardToolkit):
    """ModelCardToolkit that keeps the TFX artifacts a model card is built from."""

    def __init__(self,
                 example_statistics_artifact: Optional[
                     standard_artifacts.ExampleStatistics] = None,
                 model_artifact: Optional[standard_artifacts.Model] = None,
                 model_evaluation_artifact: Optional[
                     standard_artifacts.ModelEvaluation] = None,
                 output_dir: Optional[str] = None,
                 proto_schema: Optional[GeneratedProtocolMessageType] = None):
        """Store the source artifacts and initialize the base toolkit.

        Args:
          example_statistics_artifact: statistics artifact to scaffold from.
          model_artifact: model artifact to scaffold from.
          model_evaluation_artifact: evaluation artifact to scaffold from.
          output_dir: forwarded unchanged to ``ModelCardToolkit.__init__``.
          proto_schema: forwarded unchanged to ``ModelCardToolkit.__init__``.
        """
        # Assigned before super().__init__ so they already exist if the base
        # initializer calls overridden hooks. NOTE(review): confirm whether
        # it does; the order is kept from the original either way.
        self._example_statistics_artifact = example_statistics_artifact
        self._model_artifact = model_artifact
        self._model_evaluation_artifact = model_evaluation_artifact
        super().__init__(output_dir=output_dir, proto_schema=proto_schema)

    def _scaffold_model_card(self) -> ModelCard:
        """Return the scaffolded model card for this toolkit.

        Placeholder: returns an empty ``ModelCard`` until population from the
        stored artifacts is implemented. (The original body contained only a
        comment, which is a syntax error.)
        """
        # TODO(karanshukla): populate ModelCard using input artifacts
        return ModelCard()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment