Created
October 9, 2023 17:33
-
-
Save htahir1/f2cf08f4f2c7bfc3987bcb3ffb73e31c to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
A pydantic model to be used with the ZenML SagemakerOrchestrator
(https://github.com/zenml-io/zenml/blob/main/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py)
as a small duck-typed stand-in for the SageMaker ``NetworkConfig`` object:
https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/network.py#L24
"""
from __future__ import absolute_import | |
from typing import Union, Optional, List | |
from pydantic import BaseModel | |
class NetworkConfig(BaseModel):
    """Duck-typed, pydantic-based replacement for SageMaker's ``NetworkConfig``.

    Configuration values are declared as pydantic fields instead of being
    taken through a hand-written ``__init__``; ``_to_request_dict`` produces
    the request dictionary from whichever fields were set.
    """

    # Every field is optional. ``None`` means "not specified": the matching
    # request key is left out, except EnableNetworkIsolation, which falls
    # back to False.
    enable_network_isolation: Optional[bool] = None
    security_group_ids: Optional[List[str]] = None
    subnets: Optional[List[str]] = None
    encrypt_inter_container_traffic: Optional[bool] = None

    def _to_request_dict(self):
        """Build the network-configuration request dictionary.

        Returns:
            dict: Always holds ``"EnableNetworkIsolation"`` (``False`` when the
            field is unset). ``"EnableInterContainerTrafficEncryption"`` is
            present only when that field is set. ``"VpcConfig"`` is present
            when at least one of ``security_group_ids`` / ``subnets`` is set,
            and contains only the sub-keys whose fields are set.
        """
        isolation = self.enable_network_isolation
        request = {
            "EnableNetworkIsolation": False if isolation is None else isolation
        }

        encryption = self.encrypt_inter_container_traffic
        if encryption is not None:
            request["EnableInterContainerTrafficEncryption"] = encryption

        # Keep only the VPC entries that were provided, in
        # SecurityGroupIds -> Subnets order; omit VpcConfig entirely if none were.
        vpc_entries = (
            ("SecurityGroupIds", self.security_group_ids),
            ("Subnets", self.subnets),
        )
        vpc_config = {key: value for key, value in vpc_entries if value is not None}
        if vpc_config:
            request["VpcConfig"] = vpc_config

        return request
# You can now use this as follows (replace the "sg-" / "subnet-" placeholders
# with real security group and subnet IDs):
"""
sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
    processor_args={
        "network_config": NetworkConfig(
            security_group_ids=["sg-"],
            subnets=["subnet-"],
        )
    }
)
"""
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment