Skip to content

Instantly share code, notes, and snippets.

@psinger
Last active November 5, 2019 14:09
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save psinger/f6cb3b99dd9837606717339382c2dbb5 to your computer and use it in GitHub Desktop.
Save psinger/f6cb3b99dd9837606717339382c2dbb5 to your computer and use it in GitHub Desktop.
tta segmentation #ml #pytorch #python
#https://github.com/qubvel/ttach
import torch
import ttach as tta
# Test-time augmentation (TTA) for segmentation with ttach:
# run the model on several augmented copies of the input batch, undo each
# augmentation on the predicted mask, then average the de-augmented masks.
#
# `model` and `images` must be defined by the surrounding code.
# NOTE(review): `images` is presumably a (B, C, H, W) tensor batch -- confirm
# against the caller. The model is presumably already in eval mode
# (model.eval()); that is not done here.
transforms = tta.Compose(
    [
        tta.HorizontalFlip(),
        tta.VerticalFlip(),
        # Optional extra augmentations; uncomment to enable:
        # tta.Rotate90(angles=[0, 180]),
        # tta.Scale(scales=[1, 2, 4]),
        # tta.Multiply(factors=[0.9, 1, 1.1]),
    ]
)

# Inference only: disable autograd so the computation graph of every
# augmented forward pass is not kept in memory until the mean is taken.
with torch.no_grad():
    preds = []
    for transformer in transforms:
        augmented_images = transformer.augment_image(images)
        model_output = model(augmented_images)
        # Undo this transformer's augmentation on the predicted mask so all
        # predictions are aligned before averaging.
        deaug_mask = transformer.deaugment_mask(model_output)
        preds.append(deaug_mask)
    # Stack the per-augmentation masks along a new leading axis and average
    # over it; every de-augmented mask must have the same shape for this.
    preds = torch.mean(torch.stack(preds), dim=0)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment