Skip to content

Instantly share code, notes, and snippets.

@samhita-alla
Created October 28, 2022 10:15
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save samhita-alla/013601cd8b3a86a4277cdda54fa51a00 to your computer and use it in GitHub Desktop.
Save samhita-alla/013601cd8b3a86a4277cdda54fa51a00 to your computer and use it in GitHub Desktop.
def to_literal(
    self,
    ctx: FlyteContext,
    python_val: T,
    python_type: Type[T],
    expected: LiteralType,
) -> Literal:
    """Serialize a ``tf.Tensor`` into a Flyte literal.

    The tensor is serialized with ``tf.io.serialize_tensor``, written to a
    fresh local file, and uploaded to a random remote path. The returned
    literal is a two-element collection:

    * ``literals[0]`` -- a single-file blob pointing at the uploaded bytes;
    * ``literals[1]`` -- a string primitive holding the tensor's dtype name,
      which is needed to parse the bytes back (``tf.io.parse_tensor``
      requires an ``out_type``).

    Args:
        ctx: Flyte context providing the file-access helpers.
        python_val: The tensor to serialize.
        python_type: The Python type being converted (not used here).
        expected: The expected literal type (not used here).

    Returns:
        A ``Literal`` wrapping the blob/dtype collection described above.
    """
    meta = BlobMetadata(
        type=_core_types.BlobType(
            format=self.TENSORFLOW_FORMAT,
            dimensionality=_core_types.BlobType.BlobDimensionality.SINGLE,
        )
    )
    local_path = ctx.file_access.get_random_local_path()
    pathlib.Path(local_path).parent.mkdir(parents=True, exist_ok=True)
    # Write the serialized tensor to ``local_path`` -- the same file that is
    # uploaded below. Writing to a fixed relative filename instead would
    # leave ``local_path`` unwritten and drop the data in the CWD.
    tf.io.write_file(local_path, tf.io.serialize_tensor(python_val))
    # The dtype travels alongside the blob so deserialization can recover it.
    tensor_dtype = python_val.dtype.name
    remote_path = ctx.file_access.get_random_remote_path(local_path)
    ctx.file_access.put_data(local_path, remote_path, is_multipart=False)
    return Literal(
        collection=LiteralCollection(
            literals=[
                Literal(scalar=Scalar(blob=Blob(metadata=meta, uri=remote_path))),
                Literal(scalar=Scalar(primitive=Primitive(string=tensor_dtype))),
            ]
        )
    )
def to_python_value(
    self, ctx: FlyteContext, lv: Literal, python_val: T, expected_python_type: Type[T]
) -> T:
    """Deserialize a ``tf.Tensor`` from a blob/dtype literal collection.

    Expects ``lv.collection.literals`` to hold, in order, a blob literal
    whose URI points at bytes produced by ``tf.io.serialize_tensor`` and a
    string primitive naming the tensor's dtype (the layout ``to_literal``
    produces). The blob is downloaded to a random local path, read, and
    parsed back into a tensor of that dtype.

    Args:
        ctx: Flyte context providing the file-access helpers.
        lv: The literal to convert.
        python_val: Unused by this implementation.
        expected_python_type: The requested Python type (used in the error
            message only).

    Returns:
        The reconstructed tensor.

    Raises:
        TypeTransformerFailedError: if ``lv`` does not have the expected
            two-element collection shape.
    """
    try:
        literals = lv.collection.literals
        uri = literals[0].scalar.blob.uri
        dtype_name = literals[1].scalar.primitive.string
    except (AttributeError, IndexError, TypeError) as err:
        # A missing collection/scalar, a short list, or a None list all mean
        # the literal was not produced in the expected layout.
        raise TypeTransformerFailedError(
            f"Cannot convert from {lv} to {expected_python_type}"
        ) from err
    local_path = ctx.file_access.get_random_local_path()
    ctx.file_access.get_data(uri, local_path, is_multipart=False)
    tensor_dtype = tf.dtypes.as_dtype(dtype_name)
    serialized_tensor = tf.io.read_file(local_path)
    return tf.io.parse_tensor(serialized_tensor, out_type=tensor_dtype)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment