This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CNN(nn.Module): | |
def __init__(self, im_size, hidden_dim,hidden_dim2,hidden_dim3, kernel_size, n_classes): | |
''' | |
Create components of a CNN classifier and initialize their weights. | |
Arguments: | |
im_size (tuple): A tuple of ints with (channels, height, width) | |
hidden_dim (int): Number of hidden activations to use | |
kernel_size (int): Width and height of (square) convolution filters | |
n_classes (int): Number of classes to score | |
''' |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class CNN(nn.Module): | |
def __init__(self, im_size, hidden_dim,hidden_dim2,hidden_dim3, kernel_size, n_classes): | |
super(CNN, self).__init__() | |
self.conv1 = nn.Conv2d(3,16,3,padding=1) | |
self.bn1 = nn.BatchNorm2d(16) | |
self.conv2 = nn.Conv2d(16,32,3,padding=1) | |
self.bn2 = nn.BatchNorm2d(32) | |
self.conv3 = nn.Conv2d(32,64,3,padding=1) | |
self.bn3 = nn.BatchNorm2d(64) | |
self.conv4 = nn.Conv2d(64,128,3,padding=1) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
with warnings.catch_warnings(): | |
warnings.filterwarnings("ignore", category=FutureWarning) | |
import tensorflow as tf | |
# WandB – Login to your wandb account so you can log all your metrics | |
wandb.login() | |
wandb.init(project="hierarchical_cma", sync_tensorboard=True) | |
wb_config = wandb.config |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_epoch(self, diter, length, batch_size, epoch, writer, train_steps): | |
loss, action_loss, aux_loss = 0, 0, 0 | |
step_id = 0 | |
# high_level_losses=[] | |
# low_level_action_losses =[] | |
# low_level_stop_losses =[] | |
# low_level_total_losses=[] | |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import sys | |
import argparse | |
import cv2 | |
import numpy as np | |
import rosbag | |
from sensor_msgs.msg import Image | |
from cv_bridge import CvBridge | |
import numpy as np |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import os | |
from .kitti360_utils import * | |
from .ray_utils import * | |
from PIL import Image | |
from torchvision import transforms as T | |
import random | |
def read_files(file_path): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import numpy as np | |
import torch | |
import collections | |
from collections import namedtuple | |
from abc import ABCMeta | |
from matplotlib import cm | |
import xml.etree.ElementTree as ET | |
import os | |
from collections import defaultdict |