""" This script provides functions to train and test a CNN model with/without the Snorkel framework. """ import torch from data import get_data_loader_snorkel, get_data_loader from inference import test_model from model import CNNModel from train import train device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') def Train(): """ Train a CNN model with/without the Snorkel framework. """ model = CNNModel() dataloaders = { 'original': get_data_loader('data_split_manual', splits=['train', 'val']), 'snorkel_hard': get_data_loader_snorkel('data_snorkel', splits=['train', 'val'], label_type='hard'), 'snorkel_soft': get_data_loader_snorkel('data_snorkel', splits=['train', 'val'], label_type='soft') } criterion = torch.nn.CrossEntropyLoss() optimizer = torch.optim.SGD(model.parameters(), lr=0.0001, momentum=0.9, weight_decay=1e-4) lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, factor=0.9, patience=5) for run_name, dataloader in dataloaders.items(): train( model=model, dataloaders=dataloader, optimizer=optimizer, criterion=criterion, scheduler=lr_scheduler, device=device, run_name=run_name, num_epochs=20 ) def Test(): """ Test a trained CNN model using the test dataset. """ test_model(weight_folder='weights', data_root='data_split_manual', device=device) if __name__ == '__main__': Train() Test()