Skip to content

Instantly share code, notes, and snippets.

@tejank10
Last active March 3, 2024 12:15
Show Gist options
  • Save tejank10/a116654e644be371fa7296dfc185f73a to your computer and use it in GitHub Desktop.
Save tejank10/a116654e644be371fa7296dfc185f73a to your computer and use it in GitHub Desktop.
def multiview_color_loss(predicts, targets_a, targets_b):
    """Average L1 colour loss over four predicted views.

    ``predicts[0]`` and ``predicts[1]`` are scored against ``targets_a``;
    ``predicts[2]`` and ``predicts[3]`` against ``targets_b``.  Only the first
    three entries along dim 1 are compared (presumably the RGB channels, with
    any extra channel such as alpha ignored -- confirm against callers).

    Returns:
        Scalar tensor: the mean of the four per-view ``nn.L1Loss`` values.
    """
    criterion = nn.L1Loss()
    view_target_pairs = (
        (predicts[0], targets_a),
        (predicts[1], targets_a),
        (predicts[2], targets_b),
        (predicts[3], targets_b),
    )
    per_view = [criterion(view[:, :3], target[:, :3]) for view, target in view_target_pairs]
    return sum(per_view) / 4
# To be placed in model.py
class ColorGen(nn.Module):
    """Predict colour-sampling and colour-selection logits from a latent code.

    ``forward`` returns two raw (un-normalised) tensors:

    * ``col_sampling`` of width ``im_size**2 * Nd``: one logit per
      (pixel, palette entry) pair of an ``im_size`` x ``im_size`` image.
    * ``col_selection`` of width ``Nd * Nc``: one logit per
      (palette entry, face) pair.

    Args:
        filename_obj: path of the template mesh, loaded with
            ``sr.Mesh.from_obj``; ``Nc`` is taken from its faces tensor.
        dim_in: width of the latent code passed to ``forward``.
        im_size: ``im_size**2`` sampling logits are produced per palette entry.
        Nd: colour-palette size (suggested range roughly 10-20).
    """

    def __init__(self, filename_obj, dim_in=512, im_size=64, Nd=15):
        super().__init__()
        self.template_mesh = sr.Mesh.from_obj(filename_obj)
        # The template's faces are laid out B x num_faces x 3, so dim 1 is the
        # number of faces (not a count of sampling points).
        self.Nc = self.template_mesh.faces.shape[1]
        self.Nd = Nd  # colour palette size
        hidden = 1024
        # Keep creation order fc1 -> fc_sampling -> fc_selection: it fixes the
        # parameter order and the seeded initialisation.
        self.fc1 = nn.Linear(dim_in, hidden)
        self.fc_sampling = nn.Linear(hidden, im_size ** 2 * Nd)
        self.fc_selection = nn.Linear(hidden, Nd * self.Nc)

    def forward(self, x):
        """Return ``(col_sampling, col_selection)`` logits for latent ``x``."""
        features = F.relu(self.fc1(x), inplace=True)
        return self.fc_sampling(features), self.fc_selection(features)
def reconstruct(self, images):
    """Encode ``images``, decode a mesh and, optionally, per-face colours.

    The latent code from ``self.encoder`` is decoded into ``(vertices, faces)``
    by ``self.decoder``.  When ``self.tex_gen`` is truthy, ``self.col_gen``
    also predicts per-face colours:

    1. ``col_sampling`` logits are soft-maxed over the H*W pixels, so each of
       the ``Nd`` palette colours is a convex combination of the image's
       first three channels.
    2. ``col_selection`` logits are soft-maxed over the palette, so each of
       the ``Nc`` faces gets a convex combination of palette colours.

    Args:
        images: tensor of shape (B, C, H, W) with C >= 3; channels 0-2 are
            used as colours.

    Returns:
        ``(vertices, faces, textures)`` with ``textures`` of shape
        (B, Nc, 1, 3) when texture generation is enabled, otherwise
        ``(vertices, faces)``.

    Raises:
        RuntimeError: if ``col_gen``'s outputs cannot be viewed as
            (B, H*W, Nd) and (B, Nd, Nc), e.g. when H*W differs from the pixel
            count the colour generator was built for.
    """
    z = self.encoder(images)
    vertices, faces = self.decoder(z)
    if not self.tex_gen:
        return vertices, faces

    Nd, Nc = self.col_gen.Nd, self.col_gen.Nc
    batch_size, _, H, W = images.shape
    # reshape rather than view: a channel slice of a non-contiguous input
    # (e.g. transposed H/W) cannot always be viewed.
    imgs = images[:, :3].reshape(batch_size, 3, H * W)  # B x 3 x H*W
    col_sampling, col_selection = self.col_gen(z)
    # Explicit batch_size (not -1): with -1 a pixel-count mismatch could be
    # absorbed into the batch axis and then silently broadcast by matmul.
    col_sampling = F.softmax(col_sampling.view(batch_size, H * W, Nd), dim=1)  # B x H*W x Nd
    col_selection = F.softmax(col_selection.view(batch_size, Nd, Nc), dim=1)  # B x Nd x Nc
    # Build the colour palette from the image pixels.
    color_palette = torch.matmul(imgs, col_sampling)  # B x 3 x Nd
    # Select a palette mixture for each face.
    textures = torch.matmul(color_palette, col_selection).permute(0, 2, 1)  # B x Nc x 3
    return vertices, faces, textures.unsqueeze(2)  # B x Nc x 1 x 3
# To be placed in losses.py of SoftRas source code and recompile the library
class TexLaplacianLoss(nn.Module):
    """Smoothness penalty on per-face values using a face-adjacency Laplacian.

    Two faces are neighbours when they share an edge (see ``getEdge2Faces``).
    Row ``i`` of the ``tex_laplacian`` buffer holds 1 on the diagonal and
    ``-1/deg(i)`` for each neighbour of face ``i``.  Therefore
    ``tex_laplacian @ values`` is each face's value minus the mean value of its
    neighbours.  Faces with no neighbours get an all-zero row.

    Args:
        faces: integer tensor with three vertex indices per row;
            ``faces.size(0)`` is the number of faces.
        average: if True, ``forward`` returns the batch-summed loss divided by
            the batch size; otherwise the per-sample losses.
    """

    def __init__(self, faces, average=False):
        super(TexLaplacianLoss, self).__init__()
        self.nf = faces.size(0)
        self.average = average

        edge2faces = self.getEdge2Faces(faces)
        # Boundary edges have -1 in the second column.  Used as an index, -1
        # wraps around to the LAST face, so keep only edges joining two faces.
        shared = edge2faces[(edge2faces >= 0).all(axis=1)]

        tex_laplacian = np.zeros([self.nf, self.nf], dtype=np.float32)
        tex_laplacian[shared[:, 0], shared[:, 1]] = -1
        tex_laplacian[shared[:, 1], shared[:, 0]] = -1
        degree = -tex_laplacian.sum(1)
        r, c = np.diag_indices(self.nf)
        tex_laplacian[r, c] = degree
        # Normalise rows by degree; isolated faces (degree 0) keep a zero row
        # instead of becoming 0/0 = NaN.
        has_neighbours = degree > 0
        tex_laplacian[has_neighbours] /= degree[has_neighbours][:, None]
        self.register_buffer('tex_laplacian', torch.from_numpy(tex_laplacian))

    def getEdge2Faces(self, faces):
        """Map every undirected edge to the (at most two) faces that use it.

        Args:
            faces: tensor with three vertex indices per row.

        Returns:
            ``np.int32`` array of shape (num_edges, 2).  Row ``k`` lists the
            faces sharing the k-th distinct edge (in order of first
            appearance); a missing second face is ``-1``.  If an edge is used by
            more than two faces, only the first two are kept and a message is
            printed.
        """
        faces_sorted = np.sort(faces.detach().cpu().numpy(), axis=1)
        edge2idx = {}
        owners = []  # owners[k]: faces seen so far for edge k (at most 2)
        for i, (a, b, c) in enumerate(faces_sorted):
            # ASSUMPTION: One edge is shared by 2 faces only
            for e in ((a, b), (b, c), (a, c)):
                k = edge2idx.setdefault(e, len(owners))
                if k == len(owners):
                    owners.append([])
                if len(owners[k]) < 2:
                    owners[k].append(i)
                else:
                    print(e, ': This edge has more than 2 faces :(')
        edge2face = -np.ones((len(owners), 2), dtype=np.int32)
        for k, fs in enumerate(owners):
            edge2face[k, :len(fs)] = fs
        return edge2face

    def forward(self, x):
        """Return the squared Laplacian response of per-face values ``x``.

        Args:
            x: tensor whose dim 2 is squeezed if it has size 1, after which
                dim -2 must equal the number of faces, e.g.
                (B, num_faces, 1, C) -- TODO confirm layout with callers.

        Returns:
            Per-sample sums of squared responses (shape (B,)), or, when
            ``average`` is set, their total divided by the batch size.
        """
        batch_size = x.size(0)
        x = torch.matmul(self.tex_laplacian, torch.squeeze(x, dim=2))
        dims = tuple(range(1, x.ndimension()))
        x = x.pow(2).sum(dims)
        if self.average:
            return x.sum() / batch_size
        return x
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment