This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Setup: configure logging, locate the hmdb-6 training data and load its table.
import os

import pandas as pd
from autovideo.utils import set_log_path

set_log_path('log.txt')  # Setup logger
data_dir = 'datasets/hmdb6/'  # Directory containing dataset (hmdb-6)
train_table_path = os.path.join(data_dir, 'train.csv')
train_media_dir = os.path.join(data_dir, 'media')
target_index = 2  # Index of column containing label information
# Read the CSV file
train_dataset = pd.read_csv(train_table_path)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Setup: configure logging and locate the hmdb-6 training data.
from autovideo.utils import set_log_path
import pandas as pd
import os

set_log_path('log.txt')  # Setup logger
data_dir = 'datasets/hmdb6/'  # Directory containing dataset (hmdb-6)
train_table_path = os.path.join(data_dir, 'train.csv')
train_media_dir = os.path.join(data_dir, 'media')
target_index = 2  # Index of column containing label information
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo import extract_frames

# Extract frames from the videos in the media directory.
# NOTE(review): assumes column 1 of train_dataset holds the video file name and
# that every video shares the extension of the first row's file — confirm.
video_ext = train_dataset.iloc[0, 1].split('.')[-1]
extract_frames(train_media_dir, video_ext)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build pipeline based on configs | |
# Here we can specify the hyperparameters defined in each primitive | |
# The default hyperparameters will be used if not specified | |
from autovideo import build_pipeline | |
config = { | |
"algorithm": 'tsn', #Specify the Action Recognition algorithm to use. In this example we use TSN | |
"load_pretrained": True, #To fine-tune from pretrained weights | |
"learning_rate": 0.001, | |
"epochs": 20, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Hyperparams(SupervisedHyperparamsBase): | |
num_workers = hyperparams.Hyperparameter[int]( | |
semantic_types=['https://metadata.datadrivendiscovery.org/types/ResourcesUseParameter'], | |
default=2, | |
description='The number of subprocesses to use for data loading. 0 means that the data will be loaded in the ' | |
'main process.' | |
) | |
batch_size = hyperparams.Hyperparameter[int]( | |
default=2, | |
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'], |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo import fit
import torch

# Fit the pipeline on the training table and media directory.
_, fitted_pipeline = fit(train_dataset=train_dataset,
                         train_media_dir=train_media_dir,
                         target_index=target_index,
                         pipeline=pipeline)
# Save the fitted pipeline (reloaded later via torch.load('fitted_pipeline')).
torch.save(fitted_pipeline, 'fitted_pipeline')
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo.utils import set_log_path, logger
set_log_path('log.txt')
# Load fitted pipeline
import torch

# Place loaded tensors on the first GPU when CUDA is available, otherwise on the CPU.
map_location = "cuda:0" if torch.cuda.is_available() else "cpu"
# NOTE(review): torch.load unpickles arbitrary objects — only load trusted files.
# On PyTorch >= 2.6 the default weights_only=True may reject a whole pipeline
# object; pass weights_only=False there if loading fails.
fitted_pipeline = torch.load('fitted_pipeline', map_location=map_location)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import pandas as pd

# Locate the hmdb-6 splits; test.csv is used as the validation split.
data_dir = 'datasets/hmdb6'
train_table_path = os.path.join(data_dir, 'train.csv')
valid_table_path = os.path.join(data_dir, 'test.csv')
train_media_dir = os.path.join(data_dir, 'media')
# Both splits reference videos in the same media directory.
valid_media_dir = train_media_dir
train_dataset = pd.read_csv(train_table_path)
valid_dataset = pd.read_csv(valid_table_path)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo.searcher import RaySearcher | |
import ray | |
from ray import tune | |
#Initialise the searcher | |
searcher = RaySearcher( | |
train_dataset=train_dataset, | |
train_media_dir=train_media_dir, | |
valid_dataset=valid_dataset, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run the hyperparameter search over `search_space` and print the best configuration found.
best_config = searcher.search(
    search_space=search_space,
    config=config,
)
print("Best config: ", best_config)
OlderNewer