@tchaton
Created May 11, 2021 11:14
ptv_transform
from typing import Callable, List

import kornia.augmentation as K
import torch
from pytorchvideo.transforms import ApplyTransformToKey, RandomShortSideScale, UniformTemporalSubsample
from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip

# 2. [Optional] Specify the transforms to be used during training.
# Per-sample transforms applied after the clip has been decoded into a tensor.
post_tensor_transform = [
    UniformTemporalSubsample(8),
    RandomShortSideScale(min_size=256, max_size=320),
]
# Per-batch transforms applied on the accelerator (GPU) once the batch is assembled.
per_batch_transform_on_device = [
    K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225]))
]

train_post_tensor_transform = post_tensor_transform + [RandomCrop(244), RandomHorizontalFlip(p=0.5)]
val_post_tensor_transform = post_tensor_transform + [CenterCrop(244)]
train_per_batch_transform_on_device = per_batch_transform_on_device


def make_transform(
    post_tensor_transform: List[Callable] = post_tensor_transform,
    per_batch_transform_on_device: List[Callable] = per_batch_transform_on_device,
):
    return {
        "post_tensor_transform": Compose([
            ApplyTransformToKey(
                key="video",
                transform=Compose(post_tensor_transform),
            ),
        ]),
        "per_batch_transform_on_device": Compose([
            ApplyTransformToKey(
                key="video",
                transform=K.VideoSequential(
                    *per_batch_transform_on_device,
                    data_format="BCTHW",
                    same_on_frame=False,
                ),
            ),
        ]),
    }
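The train/val transform lists above are defined but never passed anywhere in this snippet. A minimal usage sketch (an assumption about the intended wiring, not part of the original gist) is to call make_transform once per split, swapping in the split-specific lists:

# Hypothetical usage: build separate transform dicts for the training and
# validation splits from the lists defined above.
train_transform = make_transform(
    post_tensor_transform=train_post_tensor_transform,
    per_batch_transform_on_device=train_per_batch_transform_on_device,
)
val_transform = make_transform(post_tensor_transform=val_post_tensor_transform)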