Skip to content

Instantly share code, notes, and snippets.

View tchaton's full-sized avatar
👻
Always up for it !

thomas chaton tchaton

👻
Always up for it !
View GitHub Profile
from typing import List, Optional, NamedTuple
# Type aliases used in annotations; per the original note these exist to make
# the model jittable (presumably torch.jit / TorchScript scripting).
# NOTE(review): `Tensor` is not imported in this snippet — presumably
# `from torch import Tensor` appears elsewhere in the real file; confirm.
OptTensor = Optional[Tensor]  # a Tensor or None
ListTensor = List[Tensor]  # a list of Tensors


class TensorBatch(NamedTuple):
    """Immutable, typed bundle of the tensors passed around as one batch.

    Fields: ``x`` (a Tensor), ``edge_index`` (a list of Tensors) and
    ``edge_attr`` (a Tensor or None). The field names suggest graph data
    (node features / edge indices / edge attributes) — TODO confirm.
    """

    x: Tensor
    # List of tensors rather than a single tensor — TODO confirm what each
    # entry corresponds to (e.g. one per layer or per sample).
    edge_index: ListTensor
    # Optional: may be None.
    edge_attr: OptTensor
def get_single_batch(datamodule):
for batch in datamodule.test_dataloader():
return datamodule.gather_data(batch, 0)
def run(args):
    """Build the datamodule and model from ``args`` and call ``jittable()``.

    The model is printed once before and once after ``model.jittable()`` so
    the two representations can be compared. ``jittable()``'s return value
    is discarded, so it is presumably expected to modify the model in place —
    TODO confirm.
    """
    dm: LightningDataModule = instantiate_datamodule(args)
    lightning_model: LightningModule = instantiate_model(args, dm)

    print(lightning_model)  # representation before conversion
    lightning_model.jittable()
    print(lightning_model)  # representation after conversion
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: test_params_groups_and_state_are_accessible",
"type": "python",
"request": "launch",
ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABgQDcodbGh2c/27Qgbls/yQvrfydHy/RnGgldPk8CyNbpsYTFX5MyMZ9gs988iaZ9DQSZVyWZni+8dXd4/aiCyN6OvlrLCQ+vxmi4IXfopqVWkxPRAsfzNFil6dgreepG0r4aZUkGu28JpKecNynPlkCFH0o0BWp5kEloy63l+4+cfvMDk81vbt4khtRDv5o/Eo2X1dJ7MLEvSEjQ8eB0i5Gxjzha3Zgb7SoWQxS/Qmbkp1poaeGgQCbAVauyCbqtjpR3OH+Ea5GXiCPTIOGlDR2NsvjmNmsJOhJtOPIwe957YxBnR4PUb7yFOJ0KMjWzD1I8nEZitYvo111lUMTECX1AzzzV4/TkAv/vakN9SeJgG0rHshCxrySnkgOY8KmRgxVv6nQwwOmTnWOaJ1iw2Qi9AcqESva1FXW+7Pt7YgUcIzhNyFkJfG7rZEFxZfp3pOsF8z8HWnbY0BijUPspUtrEYkoNaFL3EQFrFSENymJ26dWnxjmHNZAEHuLkmCsx7Ok= thomas@thomass-MacBook-Pro.local
# Video classifier on the "x3d_xs" backbone; the number of output classes is
# read from `datamodule`, and `Labels()` is passed as the serializer.
model = VideoClassifier(
    backbone="x3d_xs",
    num_classes=datamodule.num_classes,
    serializer=Labels(),
)
@tchaton
tchaton / ptv_2.py
Last active May 11, 2021 10:23
VideoClassifier
# Build the video classifier: "x3d_xs" backbone, class count taken from
# `datamodule.num_classes`, and `Labels()` as the serializer.
model = VideoClassifier(backbone="x3d_xs", num_classes=datamodule.num_classes, serializer=Labels())
@tchaton
tchaton / ptv_3.py
Created May 11, 2021 10:31
Flash prediction
# Run inference on the sample Kinetics prediction folder shipped under
# flash.PROJECT_ROOT, then show the result.
predictions = model.predict(
    os.path.join(flash.PROJECT_ROOT, "data/kinetics/predict")
)
print(predictions)
# Example output:
# ['archery', 'bowling', 'flying_kite', 'high_jump', 'marching']
@tchaton
tchaton / imports.py
Created May 11, 2021 11:11
imports
import os
from typing import Callable, List
import kornia.augmentation as K
import torch
from pytorchvideo.transforms import ApplyTransformToKey, RandomShortSideScale, UniformTemporalSubsample
from torch.utils.data.sampler import RandomSampler
from torchvision.transforms import CenterCrop, Compose, RandomCrop, RandomHorizontalFlip
import flash
# Step 1: fetch the zipped Kinetics video-clip sample dataset.
download_data("https://pl-flash-data.s3.amazonaws.com/kinetics.zip")
@tchaton
tchaton / ptv_transform.py
Created May 11, 2021 11:14
ptv_transform
# Step 2 (optional): transforms used during training and validation.

# Shared by both splits: uniformly subsample 8 frames, then rescale the short
# side to a random size between 256 and 320.
post_tensor_transform = [
    UniformTemporalSubsample(8),
    RandomShortSideScale(min_size=256, max_size=320),
]

# Per-channel normalization (mean 0.45, std 0.225 for each of 3 channels),
# applied per batch on device as the variable name indicates.
per_batch_transform_on_device = [
    K.Normalize(torch.tensor([0.45, 0.45, 0.45]), torch.tensor([0.225, 0.225, 0.225])),
]

# Training appends a random crop and a horizontal flip; validation appends a
# center crop. Each list is a new list built from the shared prefix.
# NOTE(review): 244 may be a typo for the common 224 crop size — confirm.
# NOTE(review): RandomShortSideScale is random, so validation preprocessing is
# not deterministic — confirm that is intended.
train_post_tensor_transform = [*post_tensor_transform, RandomCrop(244), RandomHorizontalFlip(p=0.5)]
val_post_tensor_transform = [*post_tensor_transform, CenterCrop(244)]

# Binds the same list object (not a copy) as per_batch_transform_on_device.
train_per_batch_transform_on_device = per_batch_transform_on_device