This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Set up paths for the HMDB-6 test split. produce / extract_frames /
# compute_accuracy_with_preds are imported for use further on (not shown here).
from autovideo import produce, extract_frames, compute_accuracy_with_preds
import torch
import os
import pandas as pd
from autovideo.utils import set_log_path, logger

set_log_path('log.txt')  # Set up the logger to write to log.txt

# Test-split table and the directory holding the video files.
test_table_path = os.path.join('datasets/hmdb6', 'test.csv')
test_media_dir = os.path.join('datasets/hmdb6', 'media')
target_index = 2  # Index of the column containing label information
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Run the hyperparameter search and report the best configuration found.
# `searcher`, `search_space` and `config` are defined elsewhere in the example.
best_config = searcher.search(
    search_space=search_space,
    config=config,
)
print("Best config: ", best_config)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo.searcher import RaySearcher | |
import ray | |
from ray import tune | |
#Initialise the searcher | |
searcher = RaySearcher( | |
train_dataset=train_dataset, | |
train_media_dir=train_media_dir, | |
valid_dataset=valid_dataset, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import os
import pandas as pd

# Dataset layout: CSV tables plus a media directory of video files.
data_dir = 'datasets/hmdb6'
train_table_path = os.path.join(data_dir, 'train.csv')
# NOTE(review): the test split is used as the validation set here.
valid_table_path = os.path.join(data_dir, 'test.csv')
train_media_dir = os.path.join(data_dir, 'media')
valid_media_dir = train_media_dir  # Both splits read videos from the same folder
train_dataset = pd.read_csv(train_table_path)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo.utils import set_log_path, logger

set_log_path('log.txt')  # Set up the logger to write to log.txt

# Load fitted pipeline: map tensors onto the first GPU when CUDA is
# available, otherwise onto the CPU.
import torch

# NOTE(review): torch.load unpickles arbitrary objects, so only load trusted
# files. On PyTorch >= 2.6 the default weights_only=True may reject a pickled
# pipeline object; pass weights_only=False there if loading fails.
fitted_pipeline = torch.load(
    'fitted_pipeline',
    map_location="cuda:0" if torch.cuda.is_available() else "cpu",
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo import fit
import torch

# Fit the pipeline on the training split; the first return value of fit()
# is not used here. train_dataset, train_media_dir, target_index and
# pipeline are defined elsewhere in the example.
_, fitted_pipeline = fit(
    train_dataset=train_dataset,
    train_media_dir=train_media_dir,
    target_index=target_index,
    pipeline=pipeline,
)
# Save the fitted pipeline
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Hyperparams(SupervisedHyperparamsBase): | |
num_workers = hyperparams.Hyperparameter[int]( | |
semantic_types=['https://metadata.datadrivendiscovery.org/types/ResourcesUseParameter'], | |
default=2, | |
description='The number of subprocesses to use for data loading. 0 means that the data will be loaded in the ' | |
'main process.' | |
) | |
batch_size = hyperparams.Hyperparameter[int]( | |
default=2, | |
semantic_types=['https://metadata.datadrivendiscovery.org/types/TuningParameter'], |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Build pipeline based on configs | |
# Here we can specify the hyperparameters defined in each primitive | |
# The default hyperparameters will be used if not specified | |
from autovideo import build_pipeline | |
config = { | |
"algorithm": 'tsn', #Specify the Action Recognition algorithm to use. In this example we use TSN | |
"load_pretrained": True, #To fine-tune from pretrained weights | |
"learning_rate": 0.001, | |
"epochs": 20, |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo import extract_frames

# Extract frames from the videos in train_media_dir.
# The video extension is taken from row 0, column 1 of the training table
# (presumably the video file name; confirm against train.csv). This assumes
# every video in the directory shares that extension.
video_ext = train_dataset.iloc[0, 1].split('.')[-1]
extract_frames(train_media_dir, video_ext)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
from autovideo.utils import set_log_path
import pandas as pd
import os

set_log_path('log.txt')  # Set up the logger to write to log.txt

data_dir = 'datasets/hmdb6/'  # Directory containing the dataset (HMDB-6)
train_table_path = os.path.join(data_dir, 'train.csv')
train_media_dir = os.path.join(data_dir, 'media')
target_index = 2  # Index of the column containing label information
Newer Older