Skip to content

Instantly share code, notes, and snippets.

@htahir1
Created October 9, 2023 17:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save htahir1/f2cf08f4f2c7bfc3987bcb3ffb73e31c to your computer and use it in GitHub Desktop.
Save htahir1/f2cf08f4f2c7bfc3987bcb3ffb73e31c to your computer and use it in GitHub Desktop.
"""
To be used with the ZenML SagemakerOrchestrator
https://github.com/zenml-io/zenml/blob/main/src/zenml/integrations/aws/orchestrators/sagemaker_orchestrator.py
as a small duck-typed stand-in for the SageMaker network config object:
https://github.com/aws/sagemaker-python-sdk/blob/master/src/sagemaker/network.py#L24
"""
from __future__ import absolute_import
from typing import Union, Optional, List
from pydantic import BaseModel
class NetworkConfig(BaseModel):
    """Hold network configuration parameters and convert them to a request dict.

    The parameters are declared as pydantic fields, all optional and defaulting
    to ``None``; ``_to_request_dict`` turns whichever ones are set into a plain
    dictionary.
    """

    # Declared as pydantic fields rather than as constructor arguments.
    enable_network_isolation: Optional[bool] = None
    security_group_ids: Optional[List[str]] = None
    subnets: Optional[List[str]] = None
    encrypt_inter_container_traffic: Optional[bool] = None

    def _to_request_dict(self):
        """Build the request dictionary from the fields set on this instance.

        Returns:
            A dict that always contains ``"EnableNetworkIsolation"``. It also
            contains ``"EnableInterContainerTrafficEncryption"`` when that
            field is not ``None``, and a ``"VpcConfig"`` sub-dict when at
            least one of ``security_group_ids`` / ``subnets`` is not ``None``.
        """
        # Network isolation is always reported; an unset value means False.
        isolation = self.enable_network_isolation
        if isolation is None:
            isolation = False
        request = {"EnableNetworkIsolation": isolation}

        traffic_encryption = self.encrypt_inter_container_traffic
        if traffic_encryption is not None:
            request["EnableInterContainerTrafficEncryption"] = traffic_encryption

        # Collect the VPC settings first; the sub-dict is attached only when
        # at least one of them was provided (an empty list still counts).
        vpc_settings = {}
        if self.security_group_ids is not None:
            vpc_settings["SecurityGroupIds"] = self.security_group_ids
        if self.subnets is not None:
            vpc_settings["Subnets"] = self.subnets
        if vpc_settings:
            request["VpcConfig"] = vpc_settings

        return request
# You can now use this as follows:
"""
sagemaker_orchestrator_settings = SagemakerOrchestratorSettings(
    processor_args={
        "network_config": NetworkConfig(
            security_group_ids=["sg-"],
            subnets=["subnet-"],
        )
    }
)
"""
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment