Skip to content

Instantly share code, notes, and snippets.

@Sam-Izdat
Created February 21, 2024 08:27
Show Gist options
  • Save Sam-Izdat/8c4909c4e8a0e401f6a9fec30ca015c0 to your computer and use it in GitHub Desktop.
Save Sam-Izdat/8c4909c4e8a0e401f6a9fec30ca015c0 to your computer and use it in GitHub Desktop.
@classmethod
def normals_to_height(cls,
    normal_map:torch.Tensor,
    self_tiling:bool=False,
    rescaled:bool=False,
    eps:float=torch.finfo(torch.float32).eps) -> (torch.Tensor, torch.Tensor):
    """
    Compute height from normals. Frankot-Chellappa algorithm.

    Only the x and y components of the normals are read; they are used
    directly as the surface gradient (the z component is not consulted).

    :param normal_map: Normal map tensor sized [N, C=3, H, W] or [C=3, H, W]
        as unit vectors of surface normals.
    :param self_tiling: Treat surface as self-tiling. When False the gradient
        field is mirrored to twice its size so that it becomes periodic, and
        the result is cropped back to [H, W].
    :param rescaled: Accept unit vector tensor in [0, 1] value range
        (remapped to [-1, 1] before use).
    :param eps: Small constant added to the squared-frequency denominator to
        avoid division by zero.
    :return: Height tensor sized [N, C=1, H, W] or [C=1, H, W] in [0, 1] range
        and height scale tensor sized [N, 1, 1, 1] or [1, 1, 1] in [0, inf]
        range. The scale is the (max - min) span of the un-normalized
        integrated height divided by 10. A perfectly flat input yields an
        all-zero height and a zero scale.
    """
    ndim = normal_map.dim()
    assert ndim == 3 or ndim == 4, cls.err_size
    nobatch = ndim == 3
    if nobatch:
        normal_map = normal_map.unsqueeze(0)
    assert normal_map.size(1) == 3, cls.err_normal_ch
    if rescaled:
        normal_map = normal_map * 2. - 1.
    device = normal_map.device
    N, _, H, W = normal_map.size()

    # The frequency grid depends only on the (possibly mirrored) spatial size,
    # so it is built once for the whole batch, directly on the target device.
    # fftfreq returns the frequencies already in FFT bin order (bin 0 == DC),
    # so no ifftshift is needed and the DC term really sits at index [0, 0].
    rows, cols = (H, W) if self_tiling else (2 * H, 2 * W)
    u = torch.fft.fftfreq(cols, device=device).unsqueeze(0)  # varies along x
    v = torch.fft.fftfreq(rows, device=device).unsqueeze(1)  # varies along y
    denom = (u ** 2) + (v ** 2) + eps

    res_disp, scales = [], []
    for i in range(N):
        nx, ny = normal_map[i, 0], normal_map[i, 1]
        if not self_tiling:
            # Mirror into a 2H x 2W field: the x-gradient changes sign under a
            # horizontal flip and the y-gradient under a vertical flip, which
            # corresponds to an even (symmetric) extension of the height.
            nx_top = torch.cat([nx, -torch.flip(nx, dims=[1])], dim=1)
            nx_bot = torch.cat([torch.flip(nx, dims=[0]), -torch.flip(nx, dims=[0, 1])], dim=1)
            nx = torch.cat([nx_top, nx_bot], dim=0)
            ny_top = torch.cat([ny, torch.flip(ny, dims=[1])], dim=1)
            ny_bot = torch.cat([-torch.flip(ny, dims=[0]), -torch.flip(ny, dims=[0, 1])], dim=1)
            ny = torch.cat([ny_top, ny_bot], dim=0)

        gx = torch.fft.fft2(-nx)
        gy = torch.fft.fft2(ny)
        # Integrate the gradient field in the frequency domain.
        zf = -1j * (u * gx + v * gy) / denom
        zf[0, 0] = 0.0  # the mean height is unrecoverable from gradients
        z = torch.real(torch.fft.ifft2(zf))

        z_min = torch.min(z)
        span = torch.max(z) - z_min
        # Clamp the divisor so a flat surface (span == 0) gives zeros, not NaN.
        disp = (z - z_min) / span.clamp_min(torch.finfo(z.dtype).tiny)
        if not self_tiling:
            disp = disp[:H, :W]
        res_disp.append(disp.unsqueeze(0).unsqueeze(0))
        scales.append(float(span))

    res_disp = torch.cat(res_disp, dim=0)
    res_scale = torch.tensor(scales, device=res_disp.device).view(-1, 1, 1, 1)
    if nobatch:
        res_disp = res_disp.squeeze(0)
        res_scale = res_scale.squeeze(0)
    # NOTE(review): the fixed divisor of 10 is carried over unchanged; its
    # rationale is not visible in this file.
    return res_disp, res_scale / 10.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment