@aribornstein
Created May 20, 2021 16:05
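from typing import Any
import kornia as K
import torch
from flash.core.data.data_source import DefaultDataKeys  # import paths assumed from Flash ~0.3
from flash.core.data.process import Postprocess

class SemanticSegmentationPostprocess(Postprocess):
    def per_sample_transform(self, sample: Any) -> Any:
        image_original_shape = sample[DefaultDataKeys.METADATA][-2:]  # trailing (H, W) of the original image
        resize = K.geometry.Resize(image_original_shape, interpolation='nearest')  # copies pixel values, no blending
        sample[DefaultDataKeys.PREDS] = resize(torch.stack(sample[DefaultDataKeys.PREDS]))  # stack preds, resize to original size
        return super().per_sample_transform(sample)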
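The postprocess resizes the stacked predictions back to the image's original height and width with `interpolation='nearest'`. As a dependency-free illustration of what nearest-neighbour resizing does to a 2-D map of class ids, here is a minimal sketch; `nearest_resize` is a hypothetical helper (not part of Flash or Kornia), and it assumes the floor index rule PyTorch uses for `mode='nearest'` (source index = floor(dst * in_size / out_size), clamped to the last pixel).

```python
from typing import List, Sequence


def nearest_resize(label_map: Sequence[Sequence[int]], out_h: int, out_w: int) -> List[List[int]]:
    """Resize a 2-D label map with nearest-neighbour sampling (hypothetical helper).

    Each output pixel copies the input pixel at floor(dst * in_size / out_size),
    so every output value already exists in the input map.
    """
    in_h, in_w = len(label_map), len(label_map[0])
    # Precompute which source row/column each output row/column samples from.
    rows = [min((y * in_h) // out_h, in_h - 1) for y in range(out_h)]
    cols = [min((x * in_w) // out_w, in_w - 1) for x in range(out_w)]
    return [[label_map[r][c] for c in cols] for r in rows]


# A 2x2 prediction upsampled to a 4x4 "original" image size.
pred = [[0, 2],
        [1, 3]]
print(nearest_resize(pred, 4, 4))
# → [[0, 0, 2, 2], [0, 0, 2, 2], [1, 1, 3, 3], [1, 1, 3, 3]]
```

Because every output pixel is copied from some input pixel, a resized map of class ids contains only ids that were already predicted; a bilinear resize would instead average neighbouring ids (for example 0 and 2 into 1), producing labels the model never output.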