Last active
July 3, 2024 16:21
-
-
Save classAndrew/c526070f9b9518a1a2814385b86d1076 to your computer and use it in GitHub Desktop.
Ray tracing using Torch tensors
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import time | |
import time | |
import numpy as np | |
import torch | |
# Run on the GPU when one is present; otherwise fall back to the CPU so the
# script still starts on machines without CUDA.
device = "cuda" if torch.cuda.is_available() else "cpu"
# Make factory calls such as torch.zeros / torch.ones allocate on `device`.
torch.set_default_device(device)

# Image size: `n` rows tall and twice as wide.
n = 320
width, height = n * 2, n
aspect_ratio = width / height
# Number of jittered rays traced (and later averaged) per pixel.
samples_per_pixel = 10
# NOTE: these are plain factory functions that return tensors, not classes;
# the capitalised names are kept because the rest of the file calls them.


def Vec(x, y, z):
    """Return ``[x, y, z]`` as a floating-point tensor placed on ``device``."""
    return torch.Tensor([x, y, z]).to(device)


def Ray(start, dir):
    """Return ``[start, dir]`` packed into one tensor placed on ``device``.

    ``start`` and ``dir`` are expected to be equal-length sequences of
    numbers; row 0 of the result is the origin and row 1 the direction.
    """
    return torch.Tensor([start, dir]).to(device)


def Sphere(center, radius):
    """Return ``[*center, radius]`` as a flat tensor placed on ``device``.

    The centre components come first and the radius is the last element.
    """
    return torch.Tensor([*center, radius]).to(device)
class HitRecord:
    """Per-ray intersection data, stored as one tensor per field.

    Entry ``i`` of every tensor belongs to ray ``i``. All fields start out
    zeroed (``False`` for the boolean field).
    """

    def __init__(self, num_rays):
        vec3_shape = (num_rays, 3)
        # Intersection position and surface normal, one 3-vector per ray.
        self.hit_points = torch.zeros(vec3_shape)
        self.hit_normals = torch.zeros(vec3_shape)
        # Ray parameter t at the intersection.
        self.hit_t = torch.zeros(num_rays)
        # True where the ray direction opposes the outward surface normal.
        self.is_front_face = torch.zeros(num_rays, dtype=bool)
        # Material id of the object that was hit.
        self.material_id = torch.zeros(num_rays, dtype=int)
def prod(v):
    """Return the product of all elements of tensor ``v`` as a Python int."""
    product = torch.prod(v)
    return int(product)
def ray_at(ray, t):
    """Return the points ``origin + direction * t`` along each ray.

    ``ray`` holds the origin at index 0 and the direction at index 1 of its
    second-to-last axis; ``t`` supplies one parameter per ray.
    """
    origin = ray[..., 0, :]
    direction = ray[..., 1, :]
    # Append an axis to t so it broadcasts across the vector components.
    return origin + direction * t.unsqueeze(-1)
def vec_norm(v):
    """Return the Euclidean length over the first three entries of the last axis."""
    squares = v[..., :3] ** 2
    sum_of_squares = squares[..., 0] + squares[..., 1] + squares[..., 2]
    return sum_of_squares ** 0.5
def vec_unit(v):
    """Return ``v`` divided by its length along the last axis."""
    length = vec_norm(v)
    # Restore the component axis so the division broadcasts per vector.
    return v / length.unsqueeze(-1)
def set_face_normal(hit_record: "HitRecord", ray, hit_idx, hit_outward_normals):
    """Store front-face flags and ray-opposing normals for the rays in ``hit_idx``.

    Args:
        hit_record: record whose ``is_front_face`` and ``hit_normals`` are
            written at ``hit_idx``.
        ray: ray tensor with the direction at index 1 of the second-to-last axis.
        hit_idx: indices of the rays that hit. A 0-d tensor (a single hit
            produced by ``nonzero().squeeze()``) is accepted as a batch of one.
        hit_outward_normals: outward surface normals for those rays. Normals
            of back-face hits are negated in place.
    """
    # Normalise to a 1-D index so a single hit is processed instead of skipped.
    hit_idx = hit_idx.reshape(-1)
    if hit_idx.numel() == 0:
        return
    normals = hit_outward_normals.reshape(-1, 3)
    hit_ray_dirs = ray[hit_idx][..., 1, :]
    # A ray travelling against the outward normal struck the front face.
    front_faces = torch.einsum("ij,ij->i", hit_ray_dirs, normals) < 0
    hit_record.is_front_face[hit_idx] = front_faces
    # Flip back-face normals so the stored normal always opposes the ray.
    normals[~front_faces] *= -1
    hit_record.hit_normals[hit_idx] = normals
def gamma_correct(colors):
    """Replace positive entries of ``colors`` by their square root, in place.

    Negative entries are set to zero; nothing is returned.
    """
    positive = colors > 0
    colors[positive] = colors[positive].sqrt()
    # Negative values can show up in the input; zero them out.
    negative = colors < 0
    colors[negative] = 0.
def get_many_rng_unit_vecs(size):
    """Return ``size`` random unit vectors as a ``(size, 3)`` tensor on ``device``.

    Candidates are drawn uniformly from the cube [-1, 1]^3, those outside the
    unit ball are rejected, and the survivors are normalised. Sampling is done
    in fixed-size batches until enough vectors have been collected.

    A non-positive ``size`` yields an empty ``(0, 3)`` tensor.
    """
    if size <= 0:
        # torch.vstack raises on an empty list, so answer directly.
        return torch.zeros((0, 3)).to(device)

    N_RNG_BATCH = 18000  # scalars drawn per batch, i.e. 6000 candidate vectors
    candidates_per_batch = N_RNG_BATCH // 3
    generated_count = 0
    results = []
    while generated_count < size:
        rng_seq = torch.FloatTensor(N_RNG_BATCH).uniform_(-1, 1.).reshape(candidates_per_batch, 3)
        lengths = vec_norm(rng_seq)
        # Keep points inside the unit ball; drop an exact zero vector, which
        # would turn into NaN when normalised.
        in_sphere = (lengths <= 1.) & (lengths > 0)
        sphere_vecs = vec_unit(rng_seq[in_sphere])
        results.append(sphere_vecs)
        generated_count += sphere_vecs.shape[0]
    return torch.vstack(results)[:size].to(device)
def get_reflected(ray_dirs, normals):
    """Mirror each row of ``ray_dirs`` about the matching row of ``normals``.

    Computes ``d - 2 * (d . n) * n`` row by row.
    """
    # Row-wise dot product, kept as a column so it broadcasts over components.
    projection = torch.einsum("ij,ij->i", ray_dirs, normals).unsqueeze(-1)
    return ray_dirs - 2 * projection * normals
def get_refracted(uv, normals, etai_over_etat):
    """Return refracted directions for the rows of ``uv``.

    ``uv`` holds the incoming directions, ``normals`` the surface normals and
    ``etai_over_etat`` one refractive-index ratio per row.
    """
    row_count = normals.shape[0]
    # Cosine of the incidence angle, capped at 1.
    cos_theta = torch.fmin(torch.einsum("ij,ij->i", -uv, normals), torch.ones(row_count))
    tangential = uv + cos_theta[:, None] * normals
    out_perp = etai_over_etat[:, None] * tangential
    perp_length = vec_norm(out_perp)
    # Component along the normal, sized so that |perp|^2 + |parallel|^2 = 1.
    out_parallel = -torch.sqrt(torch.abs(1.0 - perp_length ** 2))[:, None] * normals
    return out_perp + out_parallel
def get_reflectance(cosine, refractive_index):
    """Return Schlick's approximation of the reflectance.

    Args:
        cosine: cosine of the incidence angle (number or tensor).
        refractive_index: refractive-index ratio (number or tensor).

    Returns:
        ``r0 + (1 - r0) * (1 - cosine)**5`` with
        ``r0 = ((1 - n) / (1 + n))**2``.
    """
    r0 = (1 - refractive_index) / (1 + refractive_index)
    r0 = r0 * r0
    # Schlick's formula ADDS the base reflectance r0; multiplying by it (as
    # the code previously did) gives 0 at normal incidence instead of r0.
    return r0 + (1 - r0) * (1. - cosine) ** 5
def sphere_hit(sphere, ray, ray_tmin, ray_tmax, hit_record: HitRecord):
    """Intersect a batch of rays with one sphere.

    Args:
        sphere: tensor ``[cx, cy, cz, radius]``.
        ray: ``(N, 2, 3)`` tensor; index 0 is the origin, index 1 the direction.
        ray_tmin: per-ray lower bound on t (exclusive).
        ray_tmax: per-ray upper bound on t (exclusive).
        hit_record: receives ``hit_t`` for every ray with a positive
            discriminant, and ``hit_points``, ``hit_normals`` and
            ``is_front_face`` for every accepted hit.

    Returns:
        Boolean tensor of length N, True where the ray hits the sphere at a
        t strictly between its bounds.
    """
    sphere_center, sphere_radius = sphere[:3].to(device), sphere[3].to(device)
    ray_dirs = ray[..., 1, :]
    oc = sphere_center - ray[..., 0, :]
    # Coefficients of the quadratic |origin + t*dir - center|^2 = radius^2,
    # whose roots are (h -/+ sqrt(h*h - a*c)) / a.
    a = torch.einsum("ij,ij->i", ray_dirs, ray_dirs)
    h = torch.einsum("ij,ij->i", ray_dirs, oc)
    c = torch.einsum("ij,ij->i", oc, oc) - sphere_radius * sphere_radius
    discriminant = h * h - a * c
    has_hit = discriminant > 0

    # Use a 1-D index: nonzero().squeeze() collapses a single candidate to a
    # 0-d tensor, which previously made this function report a hit without
    # range-checking the root or filling the record.
    pos_disc_idx = has_hit.nonzero(as_tuple=True)[0]
    if pos_disc_idx.numel() == 0:
        return has_hit

    sqrtd = torch.sqrt(discriminant[pos_disc_idx])
    h = h[pos_disc_idx]
    a = a[pos_disc_idx]
    t_lower = ray_tmin[pos_disc_idx]
    t_upper = ray_tmax[pos_disc_idx]

    # Try the nearer root first; where it is out of range use the farther one.
    roots = (h - sqrtd) / a
    near_out_of_bounds = (roots <= t_lower) | (roots >= t_upper)
    roots = torch.where(near_out_of_bounds, (h + sqrtd) / a, roots)
    valid_roots = ~((roots <= t_lower) | (roots >= t_upper))

    has_hit[pos_disc_idx] = valid_roots
    hit_record.hit_t[pos_disc_idx] = roots

    hit_idx = has_hit.nonzero(as_tuple=True)[0]
    hit_record.hit_points[hit_idx] = p = ray_at(ray[hit_idx], hit_record.hit_t[hit_idx])
    outward_normal = (p - sphere_center) / sphere_radius
    set_face_normal(hit_record, ray, hit_idx, outward_normal)
    return has_hit
def hit_all(hittable_list, rays, ray_tmin, ray_tmax, recs: HitRecord):
    """Find, for every ray, the closest hit among all objects in the scene.

    Args:
        hittable_list: iterable of ``(sphere, material_id)`` pairs.
        rays: ``(N, 2, 3)`` ray tensor.
        ray_tmin: per-ray lower bound on t.
        ray_tmax: per-ray upper bound on t; left unmodified.
        recs: record that receives the data of the closest hit per ray.

    Returns:
        Boolean tensor of length N, True where the ray hit anything.
    """
    num_rays = rays.shape[0]
    has_hit_anything = torch.zeros(num_rays, dtype=bool)
    # Work on a copy: the bound shrinks as closer hits are found, and the
    # caller's tensor must not be altered as a side effect.
    closest_t = ray_tmax.clone()
    temp_recs = HitRecord(num_rays)
    for sphere, material_id in hittable_list:
        # Only hits nearer than the best one so far are accepted.
        has_hit = sphere_hit(sphere, rays, ray_tmin, closest_t, temp_recs)
        has_hit_anything |= has_hit
        closest_t[has_hit] = temp_recs.hit_t[has_hit]
        recs.hit_normals[has_hit] = temp_recs.hit_normals[has_hit]
        recs.hit_points[has_hit] = temp_recs.hit_points[has_hit]
        recs.hit_t[has_hit] = temp_recs.hit_t[has_hit]
        recs.is_front_face[has_hit] = temp_recs.is_front_face[has_hit]
        recs.material_id[has_hit] = material_id
    return has_hit_anything
def get_indices_by_loop(material_ids, has_hit, num_materials=None):
    """Return, for each material id, the indices of the rays that hit it.

    Args:
        material_ids: 1-D tensor with one material id per ray.
        has_hit: 1-D boolean tensor; needed because a ray that hit nothing
            still carries material id 0.
        num_materials: number of material ids to scan; defaults to the
            module-level ``MAX_MATERIALS``.

    Returns:
        A list with one 1-D index tensor per material id ``0..num_materials-1``.
    """
    if num_materials is None:
        num_materials = MAX_MATERIALS
    # as_tuple=True keeps the result 1-D even for a single match, whereas
    # nonzero().squeeze() would collapse it to a 0-d tensor.
    return [
        ((material_ids == k) & has_hit).nonzero(as_tuple=True)[0]
        for k in range(num_materials)
    ]
def scatter_rays_diffuse(hittable_list, rays, depth, colors, hit_recs: HitRecord, material_idx, attenuation):
    """Scatter the selected rays diffusely and write their colours.

    Each selected ray bounces from its hit point along ``normal + random unit
    vector``; the colour traced along the bounce, scaled by ``attenuation``,
    is stored in ``colors[material_idx]``.

    Args:
        hittable_list: scene passed on to ``ray_color``.
        rays: incoming rays (not read by this material).
        depth: remaining bounce budget; the recursion uses ``depth - 1``.
        colors: ``(N, 3)`` output tensor, updated in place.
        hit_recs: hit data for all N rays.
        material_idx: indices of the rays that hit this material. A 0-d
            tensor (single ray) is treated as a batch of one.
        attenuation: colour multiplier for this material.
    """
    # A lone index may arrive as a 0-d tensor; it used to be dropped silently.
    material_idx = material_idx.reshape(-1)
    num_rays = material_idx.shape[0]
    if num_rays == 0:
        return
    normals = hit_recs.hit_normals[material_idx]
    bounce_dirs = normals + get_many_rng_unit_vecs(num_rays)
    # Degenerate scatter: the random vector cancelled the normal, so all
    # three components are ~0. Fall back to the normal itself.
    zero_dir_mask = (torch.abs(bounce_dirs) < 1e-8).sum(axis=1) == 3
    bounced_rays = torch.zeros((num_rays, 2, 3))
    bounced_rays[..., 0, :] = hit_recs.hit_points[material_idx]
    bounced_rays[..., 1, :] = bounce_dirs
    bounced_rays[zero_dir_mask, 1, :] = normals[zero_dir_mask]
    colors[material_idx] = attenuation * ray_color(hittable_list, bounced_rays, depth - 1)
def scatter_rays_metal(hittable_list, rays, depth, colors, hit_recs: HitRecord, material_idx, attenuation):
    """Reflect the selected rays like a mirror and write their colours.

    Each selected ray bounces from its hit point along its direction
    reflected about the hit normal; the colour traced along the bounce,
    scaled by ``attenuation``, is stored in ``colors[material_idx]``.

    Args:
        hittable_list: scene passed on to ``ray_color``.
        rays: incoming ``(N, 2, 3)`` rays; their directions are reflected.
        depth: remaining bounce budget; the recursion uses ``depth - 1``.
        colors: ``(N, 3)`` output tensor, updated in place.
        hit_recs: hit data for all N rays.
        material_idx: indices of the rays that hit this material. A 0-d
            tensor (single ray) is treated as a batch of one.
        attenuation: colour multiplier for this material.
    """
    # A lone index may arrive as a 0-d tensor; it used to be dropped silently.
    material_idx = material_idx.reshape(-1)
    num_rays = material_idx.shape[0]
    if num_rays == 0:
        return
    normals = hit_recs.hit_normals[material_idx]
    bounced_rays = torch.zeros((num_rays, 2, 3))
    bounced_rays[..., 0, :] = hit_recs.hit_points[material_idx]
    bounced_rays[..., 1, :] = get_reflected(rays[material_idx, 1, :], normals)
    colors[material_idx] = attenuation * ray_color(hittable_list, bounced_rays, depth - 1)
def scatter_rays_fuzz(hittable_list, rays, depth, colors, hit_recs: HitRecord, material_idx, attenuation):
    """Reflect the selected rays with random perturbation and write their colours.

    Each selected ray bounces from its hit point along ``reflected direction
    + random unit vector``; the colour traced along the bounce, scaled by
    ``attenuation``, is stored in ``colors[material_idx]``.

    Args:
        hittable_list: scene passed on to ``ray_color``.
        rays: incoming ``(N, 2, 3)`` rays; their directions are reflected.
        depth: remaining bounce budget; the recursion uses ``depth - 1``.
        colors: ``(N, 3)`` output tensor, updated in place.
        hit_recs: hit data for all N rays.
        material_idx: indices of the rays that hit this material. A 0-d
            tensor (single ray) is treated as a batch of one.
        attenuation: colour multiplier for this material.
    """
    # A lone index may arrive as a 0-d tensor; it used to be dropped silently.
    material_idx = material_idx.reshape(-1)
    num_rays = material_idx.shape[0]
    if num_rays == 0:
        return
    normals = hit_recs.hit_normals[material_idx]
    bounce_dirs = get_reflected(rays[material_idx, 1, :], normals) + get_many_rng_unit_vecs(num_rays)
    # Degenerate scatter: all three components are ~0. Fall back to the normal.
    zero_dir_mask = (torch.abs(bounce_dirs) < 1e-8).sum(axis=1) == 3
    bounced_rays = torch.zeros((num_rays, 2, 3))
    bounced_rays[..., 0, :] = hit_recs.hit_points[material_idx]
    bounced_rays[..., 1, :] = bounce_dirs
    bounced_rays[zero_dir_mask, 1, :] = normals[zero_dir_mask]
    colors[material_idx] = attenuation * ray_color(hittable_list, bounced_rays, depth - 1)
def scatter_rays_dialectric(hittable_list, rays, depth, colors, hit_recs: HitRecord, material_idx, attenuation, ri):
    """Refract or reflect the selected rays and write their colours.

    Per ray the index ratio is ``1/ri`` on a front-face hit and ``ri``
    otherwise. A ray is reflected when ``ratio * sin(theta) > 1`` or when the
    reflectance exceeds a uniform random draw; otherwise it is refracted. The
    colour traced along the bounce, scaled by ``attenuation``, is stored in
    ``colors[material_idx]``.

    Args:
        hittable_list: scene passed on to ``ray_color``.
        rays: incoming ``(N, 2, 3)`` rays.
        depth: remaining bounce budget; the recursion uses ``depth - 1``.
        colors: ``(N, 3)`` output tensor, updated in place.
        hit_recs: hit data for all N rays.
        material_idx: indices of the rays that hit this material. A 0-d
            tensor (single ray) is treated as a batch of one.
        attenuation: colour multiplier for this material.
        ri: refractive index parameter of the material.
    """
    # A lone index may arrive as a 0-d tensor; it used to be dropped silently.
    material_idx = material_idx.reshape(-1)
    num_rays = material_idx.shape[0]
    if num_rays == 0:
        return
    front_face = hit_recs.is_front_face[material_idx]
    normals = hit_recs.hit_normals[material_idx]

    refractive_index = torch.ones(num_rays)
    refractive_index[front_face] = 1 / ri
    refractive_index[~front_face] = ri

    unit_direction = vec_unit(rays[material_idx][..., 1, :])
    cos_theta = torch.fmin(torch.einsum("ij,ij->i", -unit_direction, normals), torch.ones_like(refractive_index))
    sin_theta = torch.sqrt(1.0 - cos_theta * cos_theta)

    # Reflect with probability given by the reflectance, and always when
    # ratio * sin(theta) exceeds 1 (no refracted ray exists).
    random_draw = torch.FloatTensor(num_rays).uniform_(0., 1.).to(device)
    high_reflectance_mask = get_reflectance(cos_theta, refractive_index) > random_draw
    cannot_refract_mask = (refractive_index * sin_theta > 1.) | high_reflectance_mask
    can_refract_mask = ~cannot_refract_mask

    bounced_rays = torch.zeros((num_rays, 2, 3))
    bounced_rays[..., 0, :] = hit_recs.hit_points[material_idx]
    bounced_rays[cannot_refract_mask, 1, :] = get_reflected(unit_direction[cannot_refract_mask], normals[cannot_refract_mask])
    bounced_rays[can_refract_mask, 1, :] = get_refracted(unit_direction[can_refract_mask], normals[can_refract_mask], refractive_index[can_refract_mask])
    colors[material_idx] = attenuation * ray_color(hittable_list, bounced_rays, depth - 1)
# Scene contents as (sphere, material id) pairs. ray_color assigns the ids:
# 0 ground (diffuse), 1 metal, 2 blue diffuse, 3 fuzzy metal, 4 glass.
hittable_list = [
    (Sphere(Vec(0.5, 1, -1), 0.5), 3),
    (Sphere(Vec(-0.5, 1, -1), 0.5), 2),
    (Sphere(Vec(-0.5, 0, -1), 0.5), 1),
    (Sphere(Vec(0, -100.5, -1), 100), 0),
    (Sphere(Vec(0.5, 0, -1), 0.5), 4),
]
# One more than the largest material id used by the scene.
MAX_MATERIALS = 1 + max(material_id for _, material_id in hittable_list)
def ray_color(hittable_list, rays, depth=10):
    """Return an ``(N, 3)`` colour tensor for a batch of N rays.

    Rays that miss everything get a white-to-blue gradient based on the y
    component of their normalised direction. Rays that hit an object are
    handed to the scatter routine of that object's material, which recurses
    with ``depth - 1``. Once ``depth`` reaches 0 every ray is black.
    """
    ray_count = rays.shape[0]
    if depth <= 0:
        return torch.zeros((ray_count, 3))

    hit_recs = HitRecord(ray_count)
    rays_max = torch.ones((ray_count)) * torch.inf
    # A small positive lower bound on t for every ray.
    rays_min = torch.ones((ray_count)) * 0.001
    has_hits = hit_all(hittable_list, rays, rays_min, rays_max, hit_recs)

    colors = torch.zeros((ray_count, 3))

    # Background: blend white and light blue by the direction's y component.
    unhit_rays_unit = vec_unit(rays[~has_hits])
    ray_alphas = 0.5 * (unhit_rays_unit[..., 1, 1] + 1.)
    white_part = torch.outer(1 - ray_alphas, Vec(1., 1., 1.))
    blue_part = torch.outer(ray_alphas, Vec(0.5, 0.7, 1.))
    colors[~has_hits] = white_part + blue_part

    # One index set per material id, in id order.
    (
        ground_idx,
        yellow_metal_idx,
        blue_diffuse_idx,
        yellow_fuzz_idx,
        glass_15_idx,
    ) = get_indices_by_loop(hit_recs.material_id, has_hits)

    common_args = (hittable_list, rays, depth, colors, hit_recs)
    scatter_rays_diffuse(*common_args, ground_idx, Vec(0.8, 0.8, 0))
    scatter_rays_diffuse(*common_args, blue_diffuse_idx, Vec(0.1, 0.2, 0.5))
    scatter_rays_metal(*common_args, yellow_metal_idx, Vec(0.8, 0.6, 0.2))
    scatter_rays_fuzz(*common_args, yellow_fuzz_idx, Vec(0.8, 0.6, 0.2))
    scatter_rays_dialectric(*common_args, glass_15_idx, Vec(1., .8, 1.), 0.8)
    return colors
# Camera placement: eye position, the point it looks at, and the up direction.
camera_origin, camera_lookat, camera_up = (
    Vec(*xyz).to(device)
    for xyz in ((2, 0, 0), (0, 0, -1), (0, 1., 0.))
)
def paint():
    """Render the scene and return it as a NumPy array.

    Builds ``samples_per_pixel`` jittered rays per pixel, traces them with
    ``ray_color``, averages the samples and applies gamma correction.

    Returns:
        NumPy array of shape ``(width, height, 3)``.
    """
    # Camera basis: w points from the look-at target back to the eye, u is
    # the camera's x axis and v completes the basis.
    w = -vec_unit(camera_lookat - camera_origin)
    u = vec_unit(torch.cross(camera_up, w, dim=-1))
    v = torch.cross(w, u, dim=-1)

    # Size of one pixel on the image plane.
    delta_plane_y = 1 / height
    delta_plane_x = 1 / width * aspect_ratio

    # Per-pixel base positions on the z = 0 plane, laid out (width, height, 3).
    x_linspace, y_linspace = np.linspace(-.5, .5, width), np.linspace(-.5, .5, height)
    plane_lower_lefts = np.array([*np.meshgrid(x_linspace, y_linspace), np.zeros(shape=(height, width))]).transpose([2, 1, 0])
    plane_lower_lefts[..., 0] *= aspect_ratio
    plane_lower_lefts = np.repeat(plane_lower_lefts[None, ...], samples_per_pixel, axis=0)
    plane_lower_lefts = torch.Tensor(plane_lower_lefts)

    # Anti-aliasing: jitter every sample by up to one pixel in x and y.
    anti_alias_randomness = torch.FloatTensor(samples_per_pixel, width, height, 3).uniform_(0, 1)
    anti_alias_randomness[..., 2] = 0.  # z axis
    anti_alias_randomness[..., 0] *= delta_plane_x  # x axis
    anti_alias_randomness[..., 1] *= delta_plane_y  # y axis
    plane_positions = (plane_lower_lefts + anti_alias_randomness).to(device)

    rays = torch.zeros((samples_per_pixel, width, height, 2, 3))
    rays[..., 0, :] = camera_origin
    rays[..., 1, :] = plane_positions - Vec(0, 0, 1)
    # Flatten to a plain list of rays.
    rays = rays.reshape((samples_per_pixel * width * height, 2, 3))
    # Express the directions in the camera basis.
    rays[..., 1, :] = (rays[..., 1, :] @ torch.vstack([u, v, w]).to(device))

    render = ray_color(hittable_list, rays)
    render = render.reshape((samples_per_pixel, width, height, 3)).mean(axis=0)
    gamma_correct(render)
    # Tensor.numpy() only works on CPU tensors, so copy to the host first.
    return render.cpu().numpy()
# Time one full render. perf_counter is a monotonic, high-resolution clock,
# so the measured interval is not affected by system clock adjustments.
start = time.perf_counter()
paint()
end = time.perf_counter()
print(f"Render Taken: {end-start:.4}s")
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment