# Radiance cascades (heightfield GI) - GitHub Gist by @Sam-Izdat, last active July 2, 2024 18:08.
# https://gist.github.com/Sam-Izdat/8b68b5d7a65bcf688109d60d0d9b82fe
import time
import math
import torch
import taichi as ti
import taichi.math as tm
import tinytex as ttex
from tinycio import fsio
from .util import *
from .logger import Logging
from .texture import Texture2D, WrapMode
@ti.data_oriented
class RadianceCascades:
    """Radiance-cascades global illumination over a heightfield, computed with Taichi.

    ``compute`` walks the canvas in ``kernel_shape``-sized chunks. For each chunk it
    gathers radiance into every cascade level (``_gather``), merges the cascades from
    the coarsest level down to level 0 (``_merge``), and integrates cascade 0 against
    the surface normal into ``self.radiance`` (``_integrate``). The whole pass is
    repeated ``bounces`` times.

    The class attributes below are tunable defaults (subclass to change them).
    """

    n_cascades = 5              # number of cascade levels
    c0_directions = 32          # azimuth bins (n_phi) at cascade 0; doubles per level
    c0_probe_size = 2           # probe spacing in pixels at cascade 0; doubles per level
    c0_interval_length = 12     # lower bound on the cascade-0 ray interval length
    max_theta = 256             # upper bound on polar-angle bins (n_theta)
    probe_h_offset = 0.1        # added to the probe height when computing occlusion angles
    depth_bias = 3.5            # strength of the distance term in the bilateral weights
    height_scale_factor = 5.    # multiplier applied to the input height in compute()
    probe_pad = 1               # extra probes on each side of every chunk
    n_theta_n = 128             # NOTE(review): not referenced within this class
    n_phi_n = 32                # NOTE(review): not referenced within this class

    def __init__(self, canvas_shape:tuple, kernel_shape:tuple, env_map:torch.Tensor=None, repeat:bool=False):
        """Precompute the cascade parameters and set up the environment map.

        Args:
            canvas_shape: (height, width) of the full canvas in pixels.
            kernel_shape: (height, width) of one processing chunk in pixels;
                compute() visits ``canvas_shape // kernel_shape`` chunks per axis.
            env_map: Image used for rays that escape the last cascade. When None,
                a default EXR is loaded from a path relative to the working directory.
            repeat: If True, textures wrap (REPEAT) and rays may leave the canvas;
                otherwise textures clamp and rays stop at the canvas edge.
        """
        self.canvas_shape = canvas_shape
        self.kernel_shape = kernel_shape
        self.repeat = int(repeat)
        self.wrap_mode = WrapMode.REPEAT if repeat else WrapMode.CLAMP
        self.log = Logging.get_logger()
        self.time = time.time()
        if env_map is None:
            # NOTE(review): relative path, resolved against the current working directory.
            env_fp = '../../data/scene/carpentry_shop_02_1k.exr'
            env_map = fsio.load_image(env_fp, graphics_format=fsio.GraphicsFormat.SFLOAT32)
        self.log.debug('-------- RC INIT --------')
        self._compute_params()
        self._log_timed('Computed params')
        # Resample the environment to the last cascade's (n_theta, n_phi) so that
        # _merge can index it directly with that cascade's (j, k) bins.
        last_cascade = self.cascades[-1]
        env_shape = (last_cascade['n_theta'], last_cascade['n_phi'])
        self.env_map = Texture2D(ttex.Resampling.resize(env_map, shape=env_shape))
        self._log_timed('Set up IBL')

    def compute(self,
                base_color:torch.Tensor,
                height:torch.Tensor,
                normal:torch.Tensor,
                emissive:torch.Tensor=None,
                unorm=False,
                opengl_normals:bool=True,
                bounces:int=3):
        """Run the solver and store the lit result in ``self.radiance``.

        Inputs are indexed channel-first: ``shape[1:]`` is read as (height, width).
        Each input is fitted to the canvas with ``_size_to_canvas``.

        Args:
            base_color: Surface albedo.
            height: Heightfield. It is scaled by the canvas/source size ratio and by
                ``height_scale_factor`` before use.
            normal: Surface normals in [-1, 1], unless ``unorm`` is set.
            emissive: Emitted radiance. Defaults to zeros of shape (3, H, W).
            unorm: If True, ``normal`` is in [0, 1] and is remapped to [-1, 1].
            opengl_normals: Currently has no effect (see the note in the body).
            bounces: Number of gather/merge/integrate passes over the whole canvas.
        """
        H, W = self.canvas_shape
        device = base_color.device
        if emissive is None:
            emissive = torch.zeros(3, H, W, device=device)
        self.bounces = bounces
        # Size of the source heightfield before it is fitted to the canvas.
        im_h, im_w = height.shape[1:]
        base_color = self._size_to_canvas(base_color).to(device)
        height = self._size_to_canvas(height).to(device)
        normal = self._size_to_canvas(normal).to(device)
        emissive = self._size_to_canvas(emissive).to(device)
        if unorm:
            normal = normal * 2. - 1.
        # NOTE(review): this used to read `normal[1:2] * -1`. That expression discards
        # its result, so the green channel has never been flipped, and every direction
        # convention downstream (e.g. `v.x = -v.x` in _integrate) has only been
        # exercised with unflipped normals. Confirm the intended normal-map convention
        # before making `opengl_normals` actually flip anything.
        # Scale height by the canvas/source ratio along the longer source axis.
        dim_scale = H / im_h if im_h > im_w else W / im_w
        height *= dim_scale * self.height_scale_factor
        total_time = time.time()
        self.log.debug('------ RC COMPUTE -------')
        self._log_timed(f'Computing {H} x {W} on device: {device}')
        # height remains denormalized
        grid_xy = meshgrid_2d(H, W).permute(0, 3, 1, 2).squeeze(0).to(device) * 0.5 + 0.5
        pos = torch.cat([grid_xy, height], dim=0)
        self.positions = Texture2D(pos, wrap_mode=self.wrap_mode)
        self._log_timed('Populated positions')
        self.base_color = Texture2D(base_color, wrap_mode=self.wrap_mode)
        self._log_timed('Populated base color')
        self.radiance = Texture2D(emissive, wrap_mode=self.wrap_mode)
        self._log_timed('Populated radiance')
        self.emissive = Texture2D(emissive, wrap_mode=self.wrap_mode)
        self._log_timed('Populated emissive')
        self.normal = Texture2D(normal, wrap_mode=self.wrap_mode)
        self._log_timed('Populated normals')
        for _ in range(self.bounces):
            for cx in range(self.canvas_shape[1] // self.kernel_shape[1]):
                for cy in range(self.canvas_shape[0] // self.kernel_shape[0]):
                    kernel = tm.ivec2(cx, cy)
                    self.log.debug(f'[RC] Chunk [{cx},{cy}]')
                    self._gather(quadruple=False, kernel=kernel)
                    self._log_timed('Completed gather')
                    self._merge(kernel=kernel)
                    self._log_timed('Merged cascades')
                    self._integrate(kernel=kernel)
                    self._log_timed('Integrated cascades')
        self.log.info(f'[RC] Done. TOTAL time elapsed: {time.time() - total_time:.4f}')

    def _size_to_canvas(self, im:torch.Tensor) -> torch.Tensor:
        """Fit ``im`` inside the canvas, preserving its aspect ratio, then tile it to the canvas shape.

        The input is cloned first, so the caller's tensor is never returned or aliased.
        """
        im = im.clone()
        im_height, im_width = im.shape[1:]
        canvas_height, canvas_width = self.canvas_shape
        im_aspect = im_width / im_height
        canvas_aspect = canvas_width / canvas_height
        if im_aspect > canvas_aspect:
            # Relatively wider than the canvas: width is the limiting dimension.
            target_width = canvas_width
            target_height = int(canvas_width / im_aspect)
        else:
            target_height = canvas_height
            target_width = int(canvas_height * im_aspect)
        im = ttex.Resampling.resize(im, (target_height, target_width))
        im = ttex.Resampling.tile(im, self.canvas_shape)
        return im

    def _log_timed(self, msg:str) -> bool:
        """Debug-log ``msg`` with the seconds since the previous call, then reset the timer. Returns True."""
        self.log.debug(f'[RC][{(time.time() - self.time):.4f}] {msg}')
        self.time = time.time()
        return True

    def _compute_params(self):
        """Rebuild ``self.cascades``, one parameter dict per cascade level.

        Each dict holds the probe grid layout (including padding), the ray interval
        ``t_min``..``t_max``, the angular resolution ``n_theta`` x ``n_phi``, the
        ray-march step count and size, a Taichi field ``radiance`` of shape
        (n_probes, n_theta, n_phi), and ``occ_vec``, a u32 vector type used as an
        occlusion bitfield over theta bins.
        """
        shape = self.kernel_shape
        # Per-instance list. The old class-level `cascades = []` was shared by every
        # instance, so a second solver appended to it and then indexed the first
        # solver's cascades.
        self.cascades = []
        # Cascade n reaches t1 * 4**n. Pick t1 so the last cascade reaches at least
        # sqrt(kernel_h * kernel_w), but never go below the configured minimum.
        self.c0_interval_length = max((math.sqrt(shape[0]*shape[1]) / (float(1 << 2 * (self.n_cascades - 1)))), self.c0_interval_length)
        for n in range(self.n_cascades):
            t1 = self.c0_interval_length
            p = {}
            p['grid_shape'] = (
                int(shape[0] / self.c0_probe_size // math.pow(2, n)) + self.probe_pad * 2,
                int(shape[1] / self.c0_probe_size // math.pow(2, n)) + self.probe_pad * 2)
            p['probe_shape'] = (
                int(shape[0] // (p['grid_shape'][0] - self.probe_pad * 2)),
                int(shape[1] // (p['grid_shape'][1] - self.probe_pad * 2)))
            p['n_probes'] = int(p['grid_shape'][0] * p['grid_shape'][1])
            p['t_min'] = 0.0 if n == 0 else t1 * float(1 << 2 * (n - 1))
            p['t_max'] = t1 * float(1 << 2 * n)
            p['n_phi'] = int(self.c0_directions * math.pow(2, n))
            p['n_theta'] = int(min(p['n_phi']//2, self.max_theta))
            p['t_min_arclength'] = ((p['t_min'] * 2) * math.pi / p['n_phi'])
            p['t_max_arclength'] = ((p['t_max'] * 2) * math.pi / p['n_phi'])
            p['t_min_chord'] = max(2. * p['t_min'] * math.sin(p['t_min_arclength'] / max(2 * p['t_min'], 1)), 1.)
            p['t_max_chord'] = max(2. * p['t_max'] * math.sin(p['t_max_arclength'] / max(2 * p['t_max'], 1)), 1.)
            # Roughly one step per chord length at t_min, at least 1 step.
            p['n_steps'] = max(int(min(math.floor((p['t_max'] - p['t_min']) / (p['t_min_chord'])), p['t_max'] - p['t_min'])), 1)
            p['step_size'] = float(p['t_max'] - p['t_min']) / float(p['n_steps'])
            p['map_tex_res'] = (int(p['n_phi'] * p['grid_shape'][0]), int(p['n_theta'] * p['grid_shape'][1]))
            p['radiance'] = ti.field(tm.vec4, shape=(p['n_probes'], p['n_theta'], p['n_phi']))
            p['occ_vec'] = ti.types.vector(int(p['n_theta'] / 32) + 1, dtype=ti.u32)
            # occ_vec bitfield is excessive for heightfield occlusion, but leaves the door open to changes in future
            self.cascades.append(p)

    def info(self, printout=True):
        """Summarise every cascade's scalar parameters.

        Args:
            printout: If True, print the summary and return None; otherwise return it.
        """
        lines = []
        for n, cascade in enumerate(self.cascades):
            lines.append('--------------------')
            lines.append('C ' + str(n))
            for param, value in cascade.items():
                if param in ('radiance', 'occ_vec'):
                    continue
                lines.append(param + ' ' + str(value))
        info = ''.join(line + '\n' for line in lines)
        if printout:
            print(info)
        else:
            return info

    # Azimuth -> nearest phi bin. NOTE(review): not called within this class.
    @ti.func
    def _phi_to_k(self, phi:float, n_phi:float) -> int:
        if (n_phi > self.c0_directions): phi += tm.pi / n_phi
        # we need the 0.5 to 'wrap around' in lieu of a round here
        return int(tm.round(((phi * n_phi) / (tm.pi * 2)) % (n_phi - 0.5)))

    # Polar angle in [0, pi] -> nearest theta bin in [0, n_theta - 1].
    @ti.func
    def _theta_to_j(self, theta:float, n_theta:float) -> int:
        return int(tm.round((theta * (n_theta - 1.)) / tm.pi))

    # Phi bin -> azimuth in [0, 2*pi). Levels with more directions than cascade 0
    # are shifted back by half a bin (pi / n_phi).
    @ti.func
    def _k_to_phi(self, k:float, n_phi:float) -> float:
        phi = k * ((tm.pi * 2) / n_phi)
        if (n_phi > self.c0_directions): phi -= tm.pi / n_phi
        return phi % (tm.pi * 2)

    # Theta bin -> polar angle; bin 0 is 0 and bin n_theta - 1 is pi.
    @ti.func
    def _j_to_theta(self, j:float, n_theta:float) -> float:
        return j * (tm.pi / (n_theta - 1))

    # (theta bin, phi bin) -> unit direction (sin t cos p, sin t sin p, cos t).
    @ti.func
    def _jk_to_vec(self, j:int, k:int, n_theta:int, n_phi:int) -> tm.vec3:
        theta_h = self._j_to_theta(j, n_theta)
        phi_h = self._k_to_phi(k, n_phi)
        return tm.vec3(tm.sin(theta_h) * tm.cos(phi_h), tm.sin(theta_h) * tm.sin(phi_h), tm.cos(theta_h))

    # Point at distance rho from center along angle (phi - pi). Despite the name,
    # rho is applied linearly.
    @ti.func
    def _log_polar(self, center:tm.vec2, rho:float, phi:float) -> tm.vec2:
        phi = phi - tm.pi
        return tm.vec2(center.x + rho * tm.cos(phi), center.y + rho * tm.sin(phi))

    # Theta bin of the elevation angle toward a sample at horizontal distance
    # `euclidian_distance` whose height differs by `height_difference`.
    @ti.func
    def _compute_j(self, euclidian_distance:float, height_difference:float, n_theta:int) -> int:
        theta_radians = tm.clamp(tm.atan2(-height_difference, euclidian_distance) + (tm.pi * 0.5), 0., tm.pi)
        return self._theta_to_j(theta_radians, n_theta)

    # Average of the 2x2 (theta, phi) block starting at bin (j, k) for probe p.
    # Theta is polar ([0, pi]) and is clamped, so the two poles are never mixed.
    # Previously it wrapped, which happens once n_theta is capped by max_theta.
    # Phi is periodic and wraps. Previously it was clamped to n_phi, one past the
    # last valid index. For the indices _merge produces in the default
    # configuration, both forms read the same bins.
    @ti.func
    def _downsample_cascade(self, field:ti.template(), p:int, n_theta:int, n_phi:int, j:int, k:int) -> tm.vec4:
        j0 = tm.clamp(j, 0, n_theta - 1)
        j1 = tm.clamp(j + 1, 0, n_theta - 1)
        k0 = k % n_phi
        k1 = (k + 1) % n_phi
        s0 = field[p, j0, k0]
        s1 = field[p, j1, k0]
        s2 = field[p, j0, k1]
        s3 = field[p, j1, k1]
        return (s0 + s1 + s2 + s3) / 4.

    @ti.func
    def _bilateral_interp_coeffs(self,
            xy_grid:tm.vec2,
            probe_elevation:float,
            grid_height:int,
            grid_width:int,
            probe_height:int,
            probe_width:int,
            kernel:tm.ivec2) -> (tm.ivec4, tm.vec4):
        """Return the 4 surrounding probe indices and their normalized bilateral weights.

        ``xy_grid`` is the query point in pixels, relative to the unpadded chunk
        origin. The grid and probe sizes describe the cascade being sampled. Each
        bilinear weight is attenuated by exp(-0.01 * depth_bias * d), where d is the
        3D distance from (query pixel, probe_elevation) to (probe pixel, probe height).
        Indices are ordered (x0,y0), (x0,y1), (x1,y0), (x1,y1).
        """
        canvas_shape = tm.vec2(float(self.canvas_shape[1]), float(self.canvas_shape[0]))
        kernel_shape = tm.vec2(float(self.kernel_shape[1]), float(self.kernel_shape[0]))
        probe_halfheight = probe_height * 0.5
        probe_halfwidth = probe_width * 0.5
        hp = 1. / (canvas_shape * 2.)
        padded_kernel = tm.vec2(kernel_shape.x + self.probe_pad * probe_width * 2, kernel_shape.y + self.probe_pad * probe_height * 2)
        # Shift into padded-chunk coordinates.
        xy_grid = tm.vec2(xy_grid.x + self.probe_pad * probe_width, xy_grid.y + self.probe_pad * probe_height)
        uv_grid = xy_grid / padded_kernel
        pos = tm.vec2(grid_width * uv_grid.x - 0.5, grid_height * uv_grid.y - 0.5)
        x0, y0, x1, y1 = int(tm.floor(pos.x)), int(tm.floor(pos.y)), 0, 0
        x0, y0 = tm.clamp(x0, 0, grid_width-1), tm.clamp(y0, 0, grid_height-1)
        x1, y1 = tm.min(x0+1, grid_width-1), tm.min(y0+1, grid_height-1)
        dx = (pos.x + 1.) - (float(x0) + 1.)
        dy = (pos.y + 1.) - (float(y0) + 1.)
        q00 = int((y0 * grid_width) + x0)
        q01 = int((y1 * grid_width) + x0)
        q10 = int((y0 * grid_width) + x1)
        q11 = int((y1 * grid_width) + x1)
        indices = tm.ivec4(q00, q01, q10, q11)
        depth_bias = self.depth_bias
        # Canvas-pixel origin of the padded chunk.
        co = tm.vec2(kernel.x * kernel_shape.x - probe_width * self.probe_pad, kernel.y * kernel_shape.y - probe_height * self.probe_pad)
        uv00 = tm.vec2(co.x + x0 * probe_width + probe_halfwidth, co.y + y0 * probe_height + probe_halfheight) / canvas_shape.xy + hp
        uv01 = tm.vec2(co.x + x0 * probe_width + probe_halfwidth, co.y + y1 * probe_height + probe_halfheight) / canvas_shape.xy + hp
        uv10 = tm.vec2(co.x + x1 * probe_width + probe_halfwidth, co.y + y0 * probe_height + probe_halfheight) / canvas_shape.xy + hp
        uv11 = tm.vec2(co.x + x1 * probe_width + probe_halfwidth, co.y + y1 * probe_height + probe_halfheight) / canvas_shape.xy + hp
        d00 = self.positions.sample_bilinear(uv00).z
        d01 = self.positions.sample_bilinear(uv01).z
        d10 = self.positions.sample_bilinear(uv10).z
        d11 = self.positions.sample_bilinear(uv11).z
        # Query point in canvas pixels (padded origin + padded offset), comparable
        # with the probe positions `uvXX * canvas_shape` below. This previously used
        # `co + uv_grid`, which adds a [0, 1]-normalized coordinate to a pixel
        # origin, so distances were measured from near the chunk corner instead.
        sample_pos = tm.vec3(co + xy_grid, probe_elevation)
        weights = tm.vec4(0.)
        weights[0] = (1. - dx) * (1. - dy)
        weights[1] = (1. - dx) * dy
        weights[2] = dx * (1. - dy)
        weights[3] = dx * dy
        weights[0] *= tm.exp(-(tm.distance(sample_pos, tm.vec3(uv00 * canvas_shape, d00)) * 0.01) * depth_bias)
        weights[1] *= tm.exp(-(tm.distance(sample_pos, tm.vec3(uv01 * canvas_shape, d01)) * 0.01) * depth_bias)
        weights[2] *= tm.exp(-(tm.distance(sample_pos, tm.vec3(uv10 * canvas_shape, d10)) * 0.01) * depth_bias)
        weights[3] *= tm.exp(-(tm.distance(sample_pos, tm.vec3(uv11 * canvas_shape, d11)) * 0.01) * depth_bias)
        weights /= tm.dot(tm.vec4(1.), weights)
        return indices, weights

    # Ray-march every probe and azimuth bin of every cascade level for one chunk.
    # Each ray steps through the cascade's t_min..t_max interval. At each step, the
    # elevation angle to the heightfield sample selects a starting theta bin. Bins
    # from there up to n_theta that are not yet marked occluded receive the sample's
    # radiance with alpha 1 (divided by the tap count) and are then marked occluded.
    # With quadruple=True, cascade i casts 2**i evenly offset sub-rays per bin.
    @ti.kernel
    def _gather(self, quadruple:bool, kernel:tm.ivec2) -> bool:
        canvas_shape = tm.vec2(float(self.canvas_shape[1]), float(self.canvas_shape[0]))
        kernel_shape = tm.vec2(float(self.kernel_shape[1]), float(self.kernel_shape[0]))
        twopi = 2 * tm.pi
        hp = 1. / (canvas_shape * 2.)
        for i in ti.static(range(self.n_cascades)):
            n_probes = self.cascades[i]['n_probes']
            t_min = self.cascades[i]['t_min']
            n_phi = self.cascades[i]['n_phi']
            n_theta = self.cascades[i]['n_theta']
            n_steps = self.cascades[i]['n_steps']
            step_size = self.cascades[i]['step_size']
            grid_width = self.cascades[i]['grid_shape'][1]
            probe_height = self.cascades[i]['probe_shape'][0]
            probe_width = self.cascades[i]['probe_shape'][1]
            probe_halfheight = probe_height * 0.5
            probe_halfwidth = probe_width * 0.5
            n_taps = int(1 << i if quadruple else 1)
            self.cascades[i]['radiance'].fill(0.)
            for p, k in ti.ndrange(n_probes, n_phi):
                for tap in range(n_taps):
                    row = int(p / grid_width)
                    col = int(p % grid_width)
                    # Probe centre in canvas pixels (padding probes lie outside the chunk).
                    x = kernel.x * kernel_shape.x - probe_width * self.probe_pad + (float(col) * probe_width) + probe_halfwidth
                    y = kernel.y * kernel_shape.y - probe_height * self.probe_pad + (float(row) * probe_height) + probe_halfheight
                    xy_grid = tm.vec2(x, y)
                    z_probe = self.positions.sample_bilinear(xy_grid / canvas_shape + hp).z
                    occlusion = self.cascades[i]['occ_vec'](0.)
                    phi = self._k_to_phi(k, n_phi)
                    if i > 0: phi += tm.pi / float(n_theta)
                    if tap > 0: phi = (phi + (twopi / (n_phi * n_taps) * tap))
                    phi = phi % twopi
                    for s in range(n_steps):
                        rho = tm.max(t_min + step_size * s, 1.)
                        uv = self._log_polar(xy_grid, rho, phi) / canvas_shape + hp
                        # Without repeat, rays stop at the canvas edge.
                        if (self.repeat == 0) and (uv.x < 0. or uv.y < 0. or uv.x > 1. or uv.y > 1.): break
                        z_sample = self.positions.sample_bilinear(uv).z
                        sj = n_theta - self._compute_j(rho, (z_probe + self.probe_h_offset) - z_sample, n_theta)
                        for j in range(sj, n_theta):
                            occ_idx, occ_val = int(j / 32.), 1 << int(j % 32)
                            if occlusion[occ_idx] & occ_val: break
                            result = tm.vec4(self.radiance.sample_bilinear(uv), 1.) / n_taps
                            result += self.cascades[i]['radiance'][p, j, k]
                            self.cascades[i]['radiance'][p, j, k] = result
                            occlusion[occ_idx] |= occ_val
        return True

    # Merge cascades from the coarsest to the finest level for one chunk. Entries
    # with alpha < 1 are blended, as mix(beyond, current, alpha), with what lies
    # beyond their interval. For the last cascade that is the environment map;
    # otherwise it is a bilaterally weighted, 2x2-downsampled lookup into cascade i+1.
    @ti.kernel
    def _merge(self, kernel:tm.ivec2) -> bool:
        canvas_shape = tm.vec2(float(self.canvas_shape[1]), float(self.canvas_shape[0]))
        kernel_shape = tm.vec2(float(self.kernel_shape[1]), float(self.kernel_shape[0]))
        hp = 1. / (canvas_shape * 2.)
        for i in ti.static(range(self.n_cascades-1, -1, -1)):
            n_probes = self.cascades[i]['n_probes']
            n_phi = self.cascades[i]['n_phi']
            n_theta = self.cascades[i]['n_theta']
            grid_width = self.cascades[i]['grid_shape'][1]
            probe_height = self.cascades[i]['probe_shape'][0]
            probe_width = self.cascades[i]['probe_shape'][1]
            probe_halfheight = probe_height * 0.5
            probe_halfwidth = probe_width * 0.5
            for p, k in ti.ndrange(n_probes, n_phi):
                row = int(p / grid_width)
                col = int(p % grid_width)
                x = (float(col) * probe_width) + probe_halfwidth
                y = (float(row) * probe_height) + probe_halfheight
                px = int(kernel.x * kernel_shape.x - probe_width * self.probe_pad + x)
                py = int(kernel.y * kernel_shape.y - probe_height * self.probe_pad + y)
                # Padding probes lie outside the chunk, and at canvas borders outside the
                # canvas, so the previous direct `positions.data[py, px]` read could go
                # out of bounds. Sample through the texture the same way _gather does, so
                # the configured wrap mode applies (in-bounds, this is the pixel centre).
                probe_elevation = self.positions.sample_bilinear(tm.vec2(float(px), float(py)) / canvas_shape + hp).z
                xy_grid = tm.vec2(x - probe_width * self.probe_pad, y - probe_height * self.probe_pad)
                for j in range(n_theta):
                    rad_current = self.cascades[i]['radiance'][p, j, k]
                    if rad_current.a < 1.:
                        if ti.static(i == self.n_cascades-1):
                            rad_current.rgb = tm.mix(self.env_map.data[j, k].rgb, rad_current.rgb, rad_current.a)
                            self.cascades[i]['radiance'][p, j, k] = tm.vec4(rad_current)
                        else:
                            uv_sphere = tm.vec2(float(k) / float(n_phi), float(j) / float(n_theta))
                            m_grid_height = self.cascades[i+1]['grid_shape'][0]
                            m_grid_width = self.cascades[i+1]['grid_shape'][1]
                            m_probe_height = float(self.cascades[i+1]['probe_shape'][0])
                            m_probe_width = float(self.cascades[i+1]['probe_shape'][1])
                            indices, weights = self._bilateral_interp_coeffs(
                                xy_grid=xy_grid,
                                probe_elevation=probe_elevation,
                                grid_height=m_grid_height,
                                grid_width=m_grid_width,
                                probe_height=m_probe_height,
                                probe_width=m_probe_width,
                                kernel=kernel)
                            m_n_theta = self.cascades[i+1]['n_theta']
                            m_n_phi = self.cascades[i+1]['n_phi']
                            # Same relative direction in the finer angular grid of cascade i+1.
                            m_j = int(m_n_theta * uv_sphere.y)
                            m_k = int(m_n_phi * uv_sphere.x)
                            q00 = self._downsample_cascade(self.cascades[i+1]['radiance'], indices[0], m_n_theta, m_n_phi, m_j, m_k) * weights[0]
                            q01 = self._downsample_cascade(self.cascades[i+1]['radiance'], indices[1], m_n_theta, m_n_phi, m_j, m_k) * weights[1]
                            q10 = self._downsample_cascade(self.cascades[i+1]['radiance'], indices[2], m_n_theta, m_n_phi, m_j, m_k) * weights[2]
                            q11 = self._downsample_cascade(self.cascades[i+1]['radiance'], indices[3], m_n_theta, m_n_phi, m_j, m_k) * weights[3]
                            rad_merge = q00 + q01 + q10 + q11
                            rad_current.rgb = tm.mix(rad_merge.rgb, rad_current.rgb, rad_current.a)
                            self.cascades[i]['radiance'][p, j, k] = rad_current
        return True

    # Integrate cascade 0 over the directions with positive dot(v, normal) at every
    # pixel of one chunk, and write emissive + base_color / pi * sum into self.radiance.
    # NOTE(review): the per-bin weight sin(theta) * 2*pi / (n_phi * n_theta) omits the
    # polar step pi / (n_theta - 1). Together with the / pi, a uniform unit environment
    # therefore yields about base_color / pi rather than base_color. Confirm this scale
    # is intended.
    @ti.kernel
    def _integrate(self, kernel:tm.ivec2) -> bool:
        kernel_shape = tm.vec2(float(self.kernel_shape[1]), float(self.kernel_shape[0]))
        grid_height = self.cascades[0]['grid_shape'][0]
        grid_width = self.cascades[0]['grid_shape'][1]
        probe_height = float(self.cascades[0]['probe_shape'][0])
        probe_width = float(self.cascades[0]['probe_shape'][1])
        n_phi = self.cascades[0]['n_phi']
        n_theta = self.cascades[0]['n_theta']
        for y, x in ti.ndrange(self.kernel_shape[0], self.kernel_shape[1]):
            px = int(kernel.x * kernel_shape.x + x)
            py = int(kernel.y * kernel_shape.y + y)
            base_color = self.base_color.data[py, px]
            normal = self.normal.data[py, px]
            probe_elevation = self.positions.data[py, px].z
            xy_grid = tm.vec2(x, y)
            radiance = tm.vec4(0.)
            indices, weights = self._bilateral_interp_coeffs(
                xy_grid=xy_grid,
                probe_elevation=probe_elevation,
                grid_height=grid_height,
                grid_width=grid_width,
                probe_height=probe_height,
                probe_width=probe_width,
                kernel=kernel)
            for j, k in ti.ndrange(n_theta, n_phi):
                v = self._jk_to_vec(j, k, n_theta, n_phi)
                # x is mirrored, presumably to match the ray directions produced by _log_polar.
                v.x = -v.x
                q00 = self.cascades[0]['radiance'][indices[0], j, k] * weights[0]
                q01 = self.cascades[0]['radiance'][indices[1], j, k] * weights[1]
                q10 = self.cascades[0]['radiance'][indices[2], j, k] * weights[2]
                q11 = self.cascades[0]['radiance'][indices[3], j, k] * weights[3]
                light = q00 + q01 + q10 + q11
                dot_nl = tm.dot(v, normal)
                if dot_nl > 0.:
                    theta = self._j_to_theta(j, n_theta)
                    radiance += light * tm.sin(theta) * ((2 * tm.pi) / (n_phi * n_theta)) * dot_nl
            self.radiance.data[py, px] = self.emissive.data[py, px] + (base_color/tm.pi * radiance.rgb )
        return True
# Example:
#
# rc = tr.RadianceCascades(canvas_shape=(1024,1024), kernel_shape=(64,64), env_map=env_tensor)
# rc.compute(base_color=base_color, height=height, normal=normal, emissive=emissive, bounces=1)
# rc.radiance.data.to_torch().permute(2, 0, 1)
# End of gist.