Skip to content

Instantly share code, notes, and snippets.

@edraizen
Last active April 1, 2021 21:41
Show Gist options
  • Save edraizen/c3a1898e7cf8f98e0218f45483b899b6 to your computer and use it in GitHub Desktop.
# Copyright (c) Chris Choy (chrischoy@ai.stanford.edu).
#
# Permission is hereby granted, free of charge, to any person obtaining a copy of
# this software and associated documentation files (the "Software"), to deal in
# the Software without restriction, including without limitation the rights to
# use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies
# of the Software, and to permit persons to whom the Software is furnished to do
# so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Please cite "4D Spatio-Temporal ConvNets: Minkowski Convolutional Neural
# Networks", CVPR'19 (https://arxiv.org/abs/1904.08755) if you use any part
# of the code.
import argparse
import glob
import logging
import os
import platform
import subprocess
import sys
import urllib.request
from time import time

import numpy as np

# Must be imported before large libs
try:
    import open3d as o3d
except ImportError:
    raise ImportError('Please install open3d with `pip install open3d`.')

import torch
import torch.nn as nn
import torch.optim as optim
import torch.utils.data

import MinkowskiEngine as ME

from examples.reconstruction import InfSampler, resample_mesh
# Fixed rotation applied to point clouds before visualization so both the
# reconstruction and the input are shown from a telling angle.
M = np.array([[0.80656762, -0.5868724, -0.07091862],
              [0.3770505, 0.418344, 0.82632997],
              [-0.45528188, -0.6932309, 0.55870326]])

# Compare (major, minor) tuples rather than the bare minor component: the old
# `int(version.split('.')[1]) >= 8` test would reject a hypothetical release
# 1.0 (minor 0) even though it is newer than 0.8.
assert tuple(
    int(part) for part in o3d.__version__.split('.')[:2]
) >= (0, 8), f'Requires open3d version >= 0.8, the current version is {o3d.__version__}'

# Fetch the dataset on first run; the helper script is shipped with the
# MinkowskiEngine examples.
if not os.path.exists('ModelNet40'):
    logging.info('Downloading the fixed ModelNet40 dataset...')
    subprocess.run(["sh", "./examples/download_modelnet40.sh"])
###############################################################################
# Utility functions
###############################################################################
def PointCloud(points, colors=None):
    """Wrap an (N, 3) array of points (and optional per-point colors) in an
    open3d PointCloud."""
    cloud = o3d.geometry.PointCloud()
    cloud.points = o3d.utility.Vector3dVector(points)
    if colors is None:
        return cloud
    cloud.colors = o3d.utility.Vector3dVector(colors)
    return cloud
def collate_pointcloud_fn(list_data):
    """Collate (coords, feats, label) samples into one batch dictionary.

    Coordinates are merged into a single batched coordinate tensor; the
    per-sample point arrays are kept as a list of float tensors.
    """
    coords, feats, labels = zip(*list_data)
    batched_coords = ME.utils.batched_coordinates(coords)
    xyzs = [torch.from_numpy(sample).float() for sample in feats]
    return {
        'coords': batched_coords,
        'xyzs': xyzs,
        'labels': torch.LongTensor(labels),
    }
class ModelNet40Dataset(torch.utils.data.Dataset):
    """ModelNet40 'chair' meshes, resampled to dense point clouds and quantized
    to voxel coordinates at ``config.resolution``.

    Only the ``chair`` category is loaded (see the glob below). Resampled
    point clouds are cached in memory per index.
    """

    def __init__(self, phase, transform=None, config=None):
        # phase: subdirectory name, e.g. 'train' or 'test'.
        self.phase = phase
        self.files = []
        # idx -> resampled (N, 3) point array; filled lazily in __getitem__.
        self.cache = {}
        self.data_objects = []
        self.transform = transform
        # Voxel grid resolution; NOTE config must not be None — it is
        # dereferenced unconditionally here.
        self.resolution = config.resolution
        # Last cache-fill percentage that was logged (avoids repeat logs).
        self.last_cache_percent = 0

        self.root = './ModelNet40'
        # Only the 'chair' class of ModelNet40 is used by this example.
        fnames = glob.glob(os.path.join(self.root, f'chair/{phase}/*.off'))
        fnames = sorted([os.path.relpath(fname, self.root) for fname in fnames])
        self.files = fnames
        assert len(self.files) > 0, "No file loaded"
        logging.info(
            f"Loading the subset {phase} from {self.root} with {len(self.files)} files"
        )
        # Target number of surface samples per mesh (passed to resample_mesh).
        self.density = 30000

        # Ignore warnings in obj loader
        o3d.utility.set_verbosity_level(o3d.utility.VerbosityLevel.Error)

    def __len__(self):
        return len(self.files)

    def __getitem__(self, idx):
        """Return (quantized coords, scaled xyz, idx) for sample ``idx``.

        NOTE(review): returns None when the resampled cloud has < 1000 points;
        collate_pointcloud_fn does not filter Nones — confirm such meshes
        cannot occur at density=30000, or add filtering upstream.
        """
        mesh_file = os.path.join(self.root, self.files[idx])
        if idx in self.cache:
            xyz = self.cache[idx]
        else:
            # Load a mesh, over sample, copy, rotate, voxelization
            assert os.path.exists(mesh_file)
            # Despite the name, this is a triangle mesh, not a point cloud.
            pcd = o3d.io.read_triangle_mesh(mesh_file)
            # Normalize to fit the mesh inside a unit cube while preserving aspect ratio
            vertices = np.asarray(pcd.vertices)
            vmax = vertices.max(0, keepdims=True)
            vmin = vertices.min(0, keepdims=True)
            pcd.vertices = o3d.utility.Vector3dVector(
                (vertices - vmin) / (vmax - vmin).max())

            # Oversample points and copy
            xyz = resample_mesh(pcd, density=self.density)
            self.cache[idx] = xyz
            # Log cache progress once per 10% milestone.
            cache_percent = int((len(self.cache) / len(self)) * 100)
            if cache_percent > 0 and cache_percent % 10 == 0 and cache_percent != self.last_cache_percent:
                logging.info(
                    f"Cached {self.phase}: {len(self.cache)} / {len(self)}: {cache_percent}%"
                )
                self.last_cache_percent = cache_percent

        # Use color or other features if available
        feats = np.ones((len(xyz), 1))

        if len(xyz) < 1000:
            logging.info(
                f"Skipping {mesh_file}: does not have sufficient CAD sampling density after resampling: {len(xyz)}."
            )
            return None

        if self.transform:
            xyz, feats = self.transform(xyz, feats)

        # Get coords: scale unit-cube points up to the voxel grid and floor
        # to integer coordinates; deduplicate voxels via sparse_quantize.
        xyz = xyz * self.resolution
        coords = np.floor(xyz)
        # NOTE(review): the exact return of sparse_quantize with
        # return_index=True + return_maps_only=True is ME-version dependent —
        # confirm `inds` is the unique-index array for the installed release.
        inds = ME.utils.sparse_quantize(coords, return_index=True, return_maps_only=True)
        return (coords[inds], xyz[inds], idx)
def make_data_loader(phase, augment_data, batch_size, shuffle, num_workers,
                     repeat, config):
    """Build a DataLoader over the ModelNet40 chair subset for ``phase``.

    When ``repeat`` is true an infinite sampler is installed (useful for
    iteration-based training loops); otherwise plain epoch shuffling is used.
    """
    dataset = ModelNet40Dataset(phase, config=config)

    loader_kwargs = {
        'batch_size': batch_size,
        'num_workers': num_workers,
        'collate_fn': collate_pointcloud_fn,
        'pin_memory': False,
        'drop_last': False,
    }
    if repeat:
        loader_kwargs['sampler'] = InfSampler(dataset, shuffle)
    else:
        loader_kwargs['shuffle'] = shuffle

    return torch.utils.data.DataLoader(dataset, **loader_kwargs)
# Route log records to stdout, prefixed with the short hostname.
ch = logging.StreamHandler(sys.stdout)
logging.getLogger().setLevel(logging.INFO)
logging.basicConfig(
    # platform.node() is portable — os.uname() does not exist on Windows.
    # Keep only the short hostname (everything before the first dot).
    format=platform.node().split('.')[0] + ' %(asctime)s %(message)s',
    datefmt='%m/%d %H:%M:%S',
    handlers=[ch])

# Command-line interface shared by training and visualization modes.
parser = argparse.ArgumentParser()
parser.add_argument('--resolution', type=int, default=128)
parser.add_argument('--max_iter', type=int, default=30000)
parser.add_argument('--val_freq', type=int, default=1000)
parser.add_argument('--batch_size', default=16, type=int)
parser.add_argument('--lr', default=1e-2, type=float)
parser.add_argument('--momentum', type=float, default=0.9)
parser.add_argument('--weight_decay', type=float, default=1e-4)
parser.add_argument('--num_workers', type=int, default=1)
parser.add_argument('--stat_freq', type=int, default=50)
parser.add_argument('--weights', type=str, default='modelnet_vae.pth')
parser.add_argument('--resume', type=str, default=None)
parser.add_argument('--load_optimizer', type=str, default='true')
parser.add_argument('--train', action='store_true')
parser.add_argument('--max_visualization', type=int, default=4)
###############################################################################
# End of utility functions
###############################################################################
class Encoder(nn.Module):
    """Sparse 3D convolutional encoder.

    Seven stride-2 stages shrink the input voxel grid down to a single
    coordinate per sample; a global pool then yields one feature vector per
    sample, projected to a latent mean and log-variance.
    """

    CHANNELS = [16, 32, 64, 128, 256, 512, 1024]

    def __init__(self):
        nn.Module.__init__(self)
        ch = self.CHANNELS

        # Each stage: stride-2 downsampling conv followed by a stride-1
        # refinement conv, both with BatchNorm + ELU. Attribute names
        # (block1..block7) are kept so existing checkpoints still load.
        self.block1 = self._make_block(1, ch[0])
        self.block2 = self._make_block(ch[0], ch[1])
        self.block3 = self._make_block(ch[1], ch[2])
        self.block4 = self._make_block(ch[2], ch[3])
        self.block5 = self._make_block(ch[3], ch[4])
        self.block6 = self._make_block(ch[4], ch[5])
        self.block7 = self._make_block(ch[5], ch[6])

        self.global_pool = ME.MinkowskiGlobalPooling()

        # Latent heads: per-sample mean and log-variance of the posterior.
        self.linear_mean = ME.MinkowskiLinear(ch[6], ch[6], bias=True)
        self.linear_log_var = ME.MinkowskiLinear(ch[6], ch[6], bias=True)

        self.weight_initialization()

    @staticmethod
    def _make_block(in_ch, out_ch):
        """One encoder stage: stride-2 conv + stride-1 conv, each BN + ELU."""
        return nn.Sequential(
            ME.MinkowskiConvolution(
                in_ch, out_ch, kernel_size=3, stride=2, dimension=3),
            ME.MinkowskiBatchNorm(out_ch),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(out_ch, out_ch, kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(out_ch),
            ME.MinkowskiELU(),
        )

    def weight_initialization(self):
        """Kaiming init for conv kernels; unit-scale/zero-shift for BN."""
        for m in self.modules():
            if isinstance(m, ME.MinkowskiConvolution):
                ME.utils.kaiming_normal_(
                    m.kernel, mode='fan_out', nonlinearity='relu')

            if isinstance(m, ME.MinkowskiBatchNorm):
                nn.init.constant_(m.bn.weight, 1)
                nn.init.constant_(m.bn.bias, 0)

    def forward(self, sinput):
        """Encode a sparse voxel tensor to (mean, log_var) sparse tensors."""
        out = self.block1(sinput)
        out = self.block2(out)
        out = self.block3(out)
        out = self.block4(out)
        out = self.block5(out)
        out = self.block6(out)
        out = self.block7(out)
        # (debug prints around the pooling step removed)
        out = self.global_pool(out)
        mean = self.linear_mean(out)
        log_var = self.linear_log_var(out)
        return mean, log_var
class Decoder(nn.Module):
    """Generative sparse decoder.

    Starting from one coordinate per sample at tensor stride 128, six
    transposed-convolution stages double the resolution each time. After every
    stage a 1x1 "cls" conv predicts per-voxel occupancy logits, and voxels
    predicted empty are pruned before the next stage (during training, voxels
    belonging to the target shape are force-kept so gradients keep flowing).

    NOTE(review): this class mixes legacy MinkowskiEngine attribute names
    (``coords_man``, ``coords_key`` in get_target/get_batch_indices) with the
    modern API (``coordinate_manager`` in forward) — confirm which ME release
    this is meant to run against.
    """

    CHANNELS = [1024, 512, 256, 128, 64, 32, 16]
    # Initial tensor stride of the latent input (must match the encoder's
    # total downsampling factor).
    resolution = 128

    def __init__(self):
        nn.Module.__init__(self)

        # Input sparse tensor must have tensor stride 128.
        ch = self.CHANNELS

        # Block 1: two upsampling steps (128 -> 64 -> 32) fused in one Sequential.
        self.block1 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[0],
                ch[0],
                kernel_size=2,
                stride=2,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[0]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[0], ch[0], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[0]),
            ME.MinkowskiELU(),
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[0],
                ch[1],
                kernel_size=2,
                stride=2,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[1]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[1], ch[1], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[1]),
            ME.MinkowskiELU(),
        )

        # Per-voxel occupancy logit head for this resolution.
        self.block1_cls = ME.MinkowskiConvolution(
            ch[1], 1, kernel_size=1, bias=True, dimension=3)

        # Block 2
        self.block2 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[1],
                ch[2],
                kernel_size=2,
                stride=2,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[2]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[2], ch[2], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[2]),
            ME.MinkowskiELU(),
        )

        self.block2_cls = ME.MinkowskiConvolution(
            ch[2], 1, kernel_size=1, bias=True, dimension=3)

        # Block 3
        self.block3 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[2],
                ch[3],
                kernel_size=2,
                stride=2,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[3]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[3], ch[3], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[3]),
            ME.MinkowskiELU(),
        )

        self.block3_cls = ME.MinkowskiConvolution(
            ch[3], 1, kernel_size=1, bias=True, dimension=3)

        # Block 4
        self.block4 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[3],
                ch[4],
                kernel_size=2,
                stride=2,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[4]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[4], ch[4], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[4]),
            ME.MinkowskiELU(),
        )

        self.block4_cls = ME.MinkowskiConvolution(
            ch[4], 1, kernel_size=1, bias=True, dimension=3)

        # Block 5
        self.block5 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[4],
                ch[5],
                kernel_size=2,
                stride=2,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[5]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[5], ch[5], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[5]),
            ME.MinkowskiELU(),
        )

        self.block5_cls = ME.MinkowskiConvolution(
            ch[5], 1, kernel_size=1, bias=True, dimension=3)

        # Block 6
        self.block6 = nn.Sequential(
            ME.MinkowskiGenerativeConvolutionTranspose(
                ch[5],
                ch[6],
                kernel_size=2,
                stride=2,
                dimension=3),
            ME.MinkowskiBatchNorm(ch[6]),
            ME.MinkowskiELU(),
            ME.MinkowskiConvolution(ch[6], ch[6], kernel_size=3, dimension=3),
            ME.MinkowskiBatchNorm(ch[6]),
            ME.MinkowskiELU(),
        )

        self.block6_cls = ME.MinkowskiConvolution(
            ch[6], 1, kernel_size=1, bias=True, dimension=3)

        # pruning: drops coordinates whose keep-mask is False.
        self.pruning = ME.MinkowskiPruning()

    def get_batch_indices(self, out):
        # NOTE(review): legacy ME API (coords_man / coords_key); not called
        # from the visible code paths in this file.
        return out.coords_man.get_row_indices_per_batch(out.coords_key)

    def get_target(self, out, target_key, kernel_size=1):
        """Boolean mask over ``out``'s voxels: True where a voxel coincides
        with the (stride-matched) ground-truth coordinates behind ``target_key``."""
        with torch.no_grad():
            target = torch.zeros(len(out), dtype=torch.bool)
            cm = out.coords_man
            # Bring the target coordinates to the same tensor stride as `out`.
            strided_target_key = cm.stride(
                target_key, out.tensor_stride[0], force_creation=True)
            ins, outs = cm.kernel_map(
                out.coords_key,
                strided_target_key,
                kernel_size=kernel_size,
                region_type=1)
            for curr_in in ins:
                target[curr_in] = 1
        return target

    def valid_batch_map(self, batch_map):
        # A batch map is valid only if every sample still has coordinates.
        for b in batch_map:
            if len(b) == 0:
                return False
        return True

    def forward(self, z, target_key):
        """Decode latent sparse tensor ``z`` into voxel occupancies.

        Returns (per-stage logits list, per-stage target masks list, final
        pruned sparse output).
        """
        out_cls, targets = [], []

        #z.set_tensor_stride(self.resolution)
        # Rebuild z with the initial tensor stride expected by block1, sharing
        # z's coordinate manager.
        z1 = ME.SparseTensor(
            features=z.F,
            coordinates=z.C,
            tensor_stride=self.resolution,
            coordinate_manager=z.coordinate_manager)

        # Block1
        out1 = self.block1(z1)
        out1_cls = self.block1_cls(out1)
        target = self.get_target(out1, target_key)
        targets.append(target)
        out_cls.append(out1_cls)
        # Keep voxels with positive occupancy logit.
        keep1 = (out1_cls.F > 0).cpu().squeeze()

        # If training, force target shape generation, use net.eval() to disable
        if self.training:
            # += on bool tensors acts as element-wise OR here.
            keep1 += target

        # Remove voxels 32
        out1 = self.pruning(out1, keep1.cpu())

        # Block 2
        out2 = self.block2(out1)
        out2_cls = self.block2_cls(out2)
        target = self.get_target(out2, target_key)
        targets.append(target)
        out_cls.append(out2_cls)
        keep2 = (out2_cls.F > 0).cpu().squeeze()

        if self.training:
            keep2 += target

        # Remove voxels 16
        out2 = self.pruning(out2, keep2.cpu())

        # Block 3
        out3 = self.block3(out2)
        out3_cls = self.block3_cls(out3)
        target = self.get_target(out3, target_key)
        targets.append(target)
        out_cls.append(out3_cls)
        keep3 = (out3_cls.F > 0).cpu().squeeze()

        if self.training:
            keep3 += target

        # Remove voxels 8
        out3 = self.pruning(out3, keep3.cpu())

        # Block 4
        out4 = self.block4(out3)
        out4_cls = self.block4_cls(out4)
        target = self.get_target(out4, target_key)
        targets.append(target)
        out_cls.append(out4_cls)
        keep4 = (out4_cls.F > 0).cpu().squeeze()

        if self.training:
            keep4 += target

        # Remove voxels 4
        out4 = self.pruning(out4, keep4.cpu())

        # Block 5
        out5 = self.block5(out4)
        out5_cls = self.block5_cls(out5)
        target = self.get_target(out5, target_key)
        targets.append(target)
        out_cls.append(out5_cls)
        keep5 = (out5_cls.F > 0).cpu().squeeze()

        if self.training:
            keep5 += target

        # Remove voxels 2
        out5 = self.pruning(out5, keep5.cpu())

        # Block 6 (final resolution)
        out6 = self.block6(out5)
        out6_cls = self.block6_cls(out6)
        target = self.get_target(out6, target_key)
        targets.append(target)
        out_cls.append(out6_cls)
        keep6 = (out6_cls.F > 0).cpu().squeeze()

        # Last layer does not require keep
        # if self.training:
        #   keep6 += target

        # Remove voxels 1 — skipped entirely when nothing would survive,
        # since pruning with an all-False mask would yield an empty tensor.
        if keep6.sum() > 0:
            out6 = self.pruning(out6, keep6.cpu())

        return out_cls, targets, out6
class VAE(nn.Module):
    """Sparse voxel VAE: Encoder -> latent (mean, log_var) -> Decoder."""

    def __init__(self):
        nn.Module.__init__(self)
        self.encoder = Encoder()
        self.decoder = Decoder()

    def forward(self, sinput, gt_target):
        """Encode ``sinput``, sample a latent (training only), and decode.

        Returns (per-stage logits, per-stage targets, decoded sparse tensor,
        means, log_vars, sampled latent).
        """
        means, log_vars = self.encoder(sinput)
        # At eval time the latent is the posterior mean (no sampling).
        zs = means

        if self.training:
            # Reparameterization trick: z = mean + std * eps.
            # NOTE(review): the right-hand side is a dense feature tensor while
            # `zs` is a SparseTensor — this relies on ME's SparseTensor+Tensor
            # feature-wise addition; confirm for the targeted ME version.
            zs = zs + torch.exp(0.5 * log_vars.F) * torch.randn_like(log_vars.F)

        out_cls, targets, sout = self.decoder(zs, gt_target)
        return out_cls, targets, sout, means, log_vars, zs
def train(net, dataloader, device, config):
    """Train the sparse VAE with SGD and exponential LR decay.

    Args:
        net: VAE instance, already moved to ``device``.
        dataloader: repeating (InfSampler-backed) training loader.
        device: torch device for the sparse inputs.
        config: parsed CLI namespace (lr, momentum, weight_decay, max_iter,
            stat_freq, val_freq, weights, resume, ...).

    Checkpoints (weights + optimizer + scheduler + iteration) are written to
    ``config.weights`` every ``config.val_freq`` iterations.
    """
    optimizer = optim.SGD(
        net.parameters(),
        lr=config.lr,
        momentum=config.momentum,
        weight_decay=config.weight_decay)
    scheduler = optim.lr_scheduler.ExponentialLR(optimizer, 0.95)
    crit = nn.BCEWithLogitsLoss()

    start_iter = 0
    if config.resume is not None:
        checkpoint = torch.load(config.resume)
        print('Resuming weights')
        net.load_state_dict(checkpoint['state_dict'])
        optimizer.load_state_dict(checkpoint['optimizer'])
        scheduler.load_state_dict(checkpoint['scheduler'])
        start_iter = checkpoint['curr_iter']

    net.train()
    train_iter = iter(dataloader)
    # val_iter = iter(val_dataloader)
    # get_last_lr() replaces get_lr(), which is deprecated and warns (and can
    # return a mid-step value) on recent PyTorch releases.
    logging.info(f'LR: {scheduler.get_last_lr()}')
    for i in range(start_iter, config.max_iter):
        s = time()
        # next(iterator) instead of iterator.next(): DataLoader iterators
        # dropped the Python-2-style .next() method in modern PyTorch.
        data_dict = next(train_iter)
        d = time() - s

        optimizer.zero_grad()
        sin = ME.SparseTensor(
            features=torch.ones(len(data_dict['coords']), 1),
            coordinates=data_dict['coords'].int(),
            device=device
        )
        # Generate target sparse tensor: the input's own coordinates are the
        # reconstruction target.
        target_key = sin.coordinate_map_key
        out_cls, targets, sout, means, log_vars, zs = net(sin, target_key)
        num_layers, BCE = len(out_cls), 0
        losses = []
        # Average the per-stage occupancy losses over all decoder depths.
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(),
                             target.type(out_cl.F.dtype).to(device))
            losses.append(curr_loss.item())
            BCE += curr_loss / num_layers

        # KL divergence of the diagonal-Gaussian posterior from N(0, I).
        # NOTE(review): inner reduction is mean here but sum in visualize() —
        # confirm which KL weighting is intended.
        KLD = -0.5 * torch.mean(
            torch.mean(1 + log_vars.F - means.F.pow(2) - log_vars.F.exp(), 1))
        loss = KLD + BCE

        loss.backward()
        optimizer.step()
        t = time() - s

        if i % config.stat_freq == 0:
            logging.info(
                f'Iter: {i}, Loss: {loss.item():.3e}, Depths: {len(out_cls)} Data Loading Time: {d:.3e}, Tot Time: {t:.3e}'
            )

        if i % config.val_freq == 0 and i > 0:
            torch.save(
                {
                    'state_dict': net.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'scheduler': scheduler.state_dict(),
                    'curr_iter': i,
                }, config.weights)

            # Decay the LR once per checkpoint interval.
            scheduler.step()
            logging.info(f'LR: {scheduler.get_last_lr()}')

            net.train()
def visualize(net, dataloader, device, config):
    """Reconstruct samples and display input/output point clouds side by side.

    For each batch: encode/decode, print the VAE loss, then for each sample
    show the reconstruction (shifted right) next to the original cloud
    (shifted left), both rotated by the fixed matrix ``M``. Stops after
    ``config.max_visualization`` windows.
    """
    net.eval()
    crit = nn.BCEWithLogitsLoss()
    n_vis = 0

    for data_dict in dataloader:
        # Build the input with the same modern keyword API used in train();
        # the old positional form with allow_duplicate_coords / coords_key
        # belongs to a legacy MinkowskiEngine release.
        sin = ME.SparseTensor(
            features=torch.ones(len(data_dict['coords']), 1),
            coordinates=data_dict['coords'].int(),
            device=device)

        # Generate target sparse tensor (consistent with train()).
        target_key = sin.coordinate_map_key

        out_cls, targets, sout, means, log_vars, zs = net(sin, target_key)
        num_layers, BCE = len(out_cls), 0
        losses = []
        for out_cl, target in zip(out_cls, targets):
            curr_loss = crit(out_cl.F.squeeze(),
                             target.type(out_cl.F.dtype).to(device))
            losses.append(curr_loss.item())
            BCE += curr_loss / num_layers

        KLD = -0.5 * torch.mean(
            torch.sum(1 + log_vars.F - means.F.pow(2) - log_vars.F.exp(), 1))
        loss = KLD + BCE

        print(loss)

        batch_coords, batch_feats = sout.decomposed_coordinates_and_features
        for b, (coords, feats) in enumerate(zip(batch_coords, batch_feats)):
            # Reconstruction on the right, original input on the left.
            pcd = PointCloud(coords)
            pcd.estimate_normals()
            pcd.translate([0.6 * config.resolution, 0, 0])
            pcd.rotate(M)
            opcd = PointCloud(data_dict['xyzs'][b])
            opcd.translate([-0.6 * config.resolution, 0, 0])
            opcd.estimate_normals()
            opcd.rotate(M)
            o3d.visualization.draw_geometries([pcd, opcd])

            n_vis += 1
            if n_vis > config.max_visualization:
                return
if __name__ == '__main__':
    config = parser.parse_args()
    logging.info(config)
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    net = VAE()
    net.to(device)
    logging.info(net)

    if config.train:
        dataloader = make_data_loader(
            'train',
            augment_data=True,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            repeat=True,
            config=config)

        train(net, dataloader, device, config)
    else:
        if not os.path.exists(config.weights):
            # Typo fixed ("Downloaing" -> "Downloading"); plain string since
            # there is no placeholder to format.
            logging.info(
                'Downloading pretrained weights. This might take a while...')
            urllib.request.urlretrieve(
                "https://bit.ly/39TvWys", filename=config.weights)

        logging.info(f'Loading weights from {config.weights}')
        checkpoint = torch.load(config.weights)
        net.load_state_dict(checkpoint['state_dict'])

        dataloader = make_data_loader(
            'test',
            augment_data=True,
            batch_size=config.batch_size,
            shuffle=True,
            num_workers=config.num_workers,
            repeat=True,
            config=config)

        with torch.no_grad():
            visualize(net, dataloader, device, config)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment