Created
May 8, 2021 02:41
-
-
Save maajor/34acc3cd9eed9cd563d1531ff079ff26 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# TressFX with taichi | |
# author: info@ma-yidong.com | |
# some code adopted from https://github.com/lyd405121/OpenClothPy | |
import taichi as ti | |
ti.init(arch=ti.gpu, kernel_profiler=True)

# Solver substeps executed per rendered frame.
steps = 1

# Strand parameters.
n_strand = 100  # number of strands
n_strand_split = 32  # vertices per strand
stiffness_local = 0.9
stiffness_global = 0.005

# Simulation buffers; per-vertex fields are indexed [strand, vertex].
_strand_grid = (n_strand, n_strand_split)
transform_root = ti.Matrix.field(4, 4, float, shape=n_strand)  # per-strand root transform
pos = ti.Vector.field(3, float, shape=_strand_grid)  # current positions
pos_prev = ti.Vector.field(3, float, shape=_strand_grid)  # previous positions (Verlet history)
pos_rest = ti.Vector.field(3, float, shape=_strand_grid)  # rest pose
length_rest = ti.Vector.field(1, float, shape=_strand_grid)  # rest length of segment j -> j+1
time_elapsed = ti.field(float, shape=1)  # animation clock advanced by drive_root

# Rendering and integration parameters.
imgSize = 720
img = ti.Vector.field(3, float, shape=[imgSize, imgSize])
screenRes = ti.Vector([imgSize, imgSize])
gravity = ti.Vector([0.0, -9.8, 0.0])
deltaT = 0.0167
@ti.func
def get_length2(v):
    # Euclidean length of the (x, y) part of v; any further components are ignored.
    squared = v.x * v.x + v.y * v.y
    return ti.sqrt(squared)
@ti.func
def quat_normalize(q):
    # Return the quaternion q = (x, y, z, w) scaled to unit length.
    # If |q|^2 < 1e-10 the vector is not rescaled; only w is set to 1.0.
    norm_sq = q.dot(q)
    if norm_sq < 1e-10:
        q.w = 1.0
    else:
        inv_len = 1.0 / ti.sqrt(norm_sq)
        q *= inv_len
    return q
@ti.func
def quat_from_two_unit_vector(u, v):
    # Rotation quaternion (x, y, z, w) taking unit vector u onto unit vector v,
    # built as (u x v, 1 + u.v) and then normalized.
    w = 1.0 + u.dot(v)
    axis = ti.Vector([0.0, 0.0, 0.0])
    if w < 1e-7:
        # u and v are (nearly) opposite: use a 180-degree turn (w = 0) about
        # an axis perpendicular to u.
        w = 0.0
        if ti.abs(u.x) > ti.abs(u.z):
            axis = ti.Vector([-u[1], u[0], 0.0])
        else:
            axis = ti.Vector([0.0, -u[2], u[1]])
    else:
        axis = u.cross(v)
    return quat_normalize(ti.Vector([axis[0], axis[1], axis[2], w]))
@ti.func
def mul_quat_and_vector(q, v):
    # Rotate the 3-vector v by quaternion q = (x, y, z, w):
    #   v' = v + 2w (q_xyz x v) + 2 q_xyz x (q_xyz x v)
    # This form is a pure rotation only when q is unit length.
    axis = ti.Vector([q[0], q[1], q[2]])
    t = axis.cross(v)
    return v + (2.0 * q[3]) * t + 2.0 * axis.cross(t)
@ti.func
def make_matrix_rotation_x(angle):
    # 4x4 homogeneous rotation about the x axis. Applied to a column vector it
    # maps (y, z) -> (c*y + s*z, -s*y + c*z) and leaves x and w unchanged.
    c = ti.cos(angle)
    s = ti.sin(angle)
    return ti.Matrix([
        [1, 0, 0, 0],
        [0, c, s, 0],
        [0, -s, c, 0],
        [0, 0, 0, 1]])
@ti.func
def make_matrix_translation(translation):
    # 4x4 homogeneous translation by translation.xyz; the offset sits in the
    # last column, so M @ (x, y, z, 1) adds it to the point.
    tx, ty, tz = translation.x, translation.y, translation.z
    return ti.Matrix([
        [1, 0, 0, tx],
        [0, 1, 0, ty],
        [0, 0, 1, tz],
        [0, 0, 0, 1]])
@ti.func
def make_homogeneous(vec):
    # Lift a 3-vector to homogeneous coordinates: (x, y, z) -> (x, y, z, 1).
    x, y, z = vec.x, vec.y, vec.z
    return ti.Vector([x, y, z, 1])
@ti.func
def make_3d(vec):
    # Keep only the x, y, z components (drops w without dividing by it).
    x, y, z = vec.x, vec.y, vec.z
    return ti.Vector([x, y, z])
@ti.func
def fill_pixel(v, z, c):
    # Write colour c into img at integer pixel v when v lies inside the screen.
    # z is accepted but unused: no depth test is performed.
    if 0 <= v.x and v.x < screenRes.x and 0 <= v.y and v.y < screenRes.y:
        img[v] = c
@ti.func
def transform(vec):
    # Map a world-space point to pixel coordinates with a fixed orthographic
    # camera. The steps are:
    #   1. scale by 0.1 and offset;
    #   2. rotate in the x-z plane by phi (~90 deg);
    #   3. tilt by theta (~32 deg) and drop depth.
    # The returned z is always 0.5.
    phi = 90 * 3.14 / 180.0
    theta = 32 * 3.14 / 180.0
    scaled = vec * 0.1
    x = scaled.x - 0.2
    y = scaled.y - 0.3
    z = scaled.z
    cos_p, sin_p = ti.cos(phi), ti.sin(phi)
    cos_t, sin_t = ti.cos(theta), ti.sin(theta)
    # Both right-hand sides use the pre-rotation x and z.
    x, z = x * cos_p + z * sin_p, z * cos_p - x * sin_p
    screen_u = x
    screen_v = y * cos_t + z * sin_t
    return ti.Vector([(screen_u + 0.5) * imgSize, (screen_v + 0.5) * imgSize, 0.5])
# Integer Bresenham rasterization, after
# https://github.com/miloyip/line/blob/master/line_bresenham.c (can be further optimized)
@ti.func
def draw_line(v0, v1):
    # Project the world-space segment v0 -> v1 with transform() and plot it
    # into img, both end pixels included. At most 64 pixels are plotted, so
    # longer on-screen segments are truncated.
    v0 = transform(v0)
    v1 = transform(v1)
    s0 = ti.Vector([ti.cast(v0.x, ti.i32), ti.cast(v0.y, ti.i32)])
    s1 = ti.Vector([ti.cast(v1.x, ti.i32), ti.cast(v1.y, ti.i32)])
    dis = get_length2(s1 - s0)
    x0 = s0.x
    y0 = s0.y
    z0 = v0.z
    x1 = s1.x
    y1 = s1.y
    dz = v1.z - z0
    dx = abs(x1 - x0)
    sx = -1
    if x0 < x1:
        sx = 1
    dy = abs(y1 - y0)
    sy = -1
    if y1 > y0:
        sy = 1
    err = 0
    if dx > dy:
        err = ti.cast(dx / 2, ti.i32)
    else:
        err = ti.cast(-dy / 2, ti.i32)
    for step in range(0, 64):
        # t is the fraction of the way from s0 to s1 (0 at the start pixel,
        # 1 at the end). The guard keeps a single-pixel segment (dis == 0)
        # from dividing by zero.
        t = 0.0
        if dis > 0:
            t = 1.0 - get_length2(ti.Vector([x1, y1]) - ti.Vector([x0, y0])) / dis
        fill_pixel(ti.Vector([x0, y0]), z0 + dz * t, ti.Vector([0.64, 0.804, 0.902]))
        # Test for the end only after plotting, so the final pixel is drawn too.
        if (x0 == x1) and (y0 == y1):
            break
        e2 = err
        if e2 > -dx:
            err -= dy
            x0 += sx
        if e2 < dy:
            err += dx
            y0 += sy
@ti.kernel
def draw():
    # Render every strand as a polyline, one segment per pair of adjacent
    # vertices. The iteration for a strand's last vertex draws nothing.
    for strand, vtx in pos:
        if vtx + 1 < n_strand_split:
            draw_line(pos[strand, vtx], pos[strand, vtx + 1])
@ti.kernel
def clear():
    # Reset every pixel of the framebuffer to the background colour.
    for px, py in img:
        img[px, py] = ti.Vector([0.06, 0.184, 0.255])
@ti.kernel
def drive_root():
    # Advance the animation clock and give every strand the same root
    # transform: a rotation about the x axis through `center`. The angle is
    # 0.2 * sin(2*pi*frac) (pi approximated as 3.1415), where frac is the
    # fractional part of the clock.
    time_elapsed[0] += deltaT * 0.3
    center = ti.Vector([0.0, 6.0, 0.0])
    frac = ti.abs(time_elapsed[0] - ti.floor(time_elapsed[0]))
    angle = ti.sin(frac * 2 * 3.1415) * 0.2
    for i in range(n_strand):
        # Column-vector convention: translation sits in the last column and
        # substep applies transform_root[i] @ p. So T(c) @ R @ T(-c) rotates
        # about c. The reversed order T(-c) @ R @ T(c) would pivot about -center.
        transform_root[i] = (make_matrix_translation(center)
                             @ make_matrix_rotation_x(angle)
                             @ make_matrix_translation(-center))
@ti.kernel
def substep():
    # Advance every strand by one solver step:
    #   1. Verlet-integrate gravity for non-root vertices and pin roots
    #      (j == 0) to their skinned rest positions.
    #   2. Global shape constraint: pull each vertex toward its skinned rest
    #      position by stiffness_global.
    #   3. Local shape constraint, walked root-to-tip along each strand.
    #   4. Edge-length constraint, walked root-to-tip along each strand.
    # Taichi parallelizes each outermost loop. The j loops nested inside the
    # per-strand loops run sequentially along the strand.
    for i, j in pos:
        coord = ti.Vector([i, j])
        rest = pos_rest[coord]
        # apply skinning
        initial_pos = transform_root[i] @ ti.Vector([rest.x, rest.y, rest.z, 1])
        # gravity and integrate
        if j > 0:
            acc = gravity
            tmp = pos[coord]
            pos[coord] = (2 * pos[coord] - pos_prev[coord]) + acc * deltaT * deltaT
            pos_prev[coord] = tmp
        else:  # root
            pos[coord] = ti.Vector([initial_pos.x, initial_pos.y, initial_pos.z])
        # global shape constraints
        pos[coord] += stiffness_global * (ti.Vector([initial_pos.x, initial_pos.y, initial_pos.z]) - pos[coord])
    # local shape constraints
    # NOTE(review): the strength below is stiffness_global; stiffness_local is
    # not read anywhere in the visible source. Confirm which was intended.
    # NOTE(review): when transform_root is rigid, mul_quat_and_vector(rot_global,
    # vec_prv_bind) reproduces last_vec, the unskinned rest segment j-1 -> j.
    # The target therefore does not follow the root transform. Confirm that
    # this is intended.
    for i in range(n_strand):
        bone_mat = transform_root[i]
        for j in range(1, n_strand_split - 1):
            bind_pos = make_3d(bone_mat @ make_homogeneous(pos_rest[i, j]))
            bind_pos_before = make_3d(bone_mat @ make_homogeneous(pos_rest[i, j - 1]))
            vec_prv_bind = bind_pos - bind_pos_before
            last_vec = pos_rest[i, j] - pos_rest[i, j - 1]
            rot_global = quat_from_two_unit_vector(vec_prv_bind.normalized(), last_vec.normalized())
            orgPos_i_plus_1_InGlobalFrame = mul_quat_and_vector(rot_global, vec_prv_bind) + pos[i, j]
            dist = stiffness_global * (orgPos_i_plus_1_InGlobalFrame - pos[i, j + 1])
            pos[i, j] -= dist
            pos[i, j + 1] += dist
    # edge length constraint
    for it in ti.static(range(1)):
        for i in range(n_strand):
            for j in range(0, n_strand_split - 1):
                delta = pos[i, j + 1] - pos[i, j]
                length = delta.norm()
                # Skip coincident vertices. Otherwise rest / 0 gives inf, and
                # the resulting NaN spreads along the strand for good.
                if length > 1e-8:
                    delta *= 1.0 - length_rest[i, j][0] / length
                    if j == 0:
                        # The root is pinned, so the whole correction goes to j+1.
                        pos[i, j + 1] -= delta
                    else:
                        pos[i, j] += delta * 0.5
                        pos[i, j + 1] -= delta * 0.5
    # collision to add
@ti.kernel
def init():
    # Build the rest pose and reset the simulation state.
    # Each strand starts at a root (x, 5, z), with x and z drawn as
    # ti.random() * 0.2. Vertex k is placed k * strand_seg_len below the root,
    # plus a lateral offset that grows with k. A fresh random phase is drawn
    # for every vertex. pos, pos_prev and pos_rest are all set to that point.
    # NOTE(review): length_rest stores only the vertical spacing
    # strand_seg_len. Consecutive rest vertices also differ laterally, so their
    # real distance is generally larger. The edge-length constraint in substep
    # therefore pulls away from the rest pose. Confirm that this is intended.
    strand_seg_len = 5.0 / n_strand_split
    for strand in range(n_strand):
        root = ti.Vector([ti.random() * 0.2, 5.0, ti.random() * 0.2])
        for k in range(n_strand_split):
            phase = ti.random() * 5
            offset = ti.Vector([
                k * root.x * 0.2 + k * 0.02 * ti.cos(phase + k / 0.5),
                -k * strand_seg_len,
                k * root.z * 0.2 + k * 0.02 * ti.sin(phase + k / 0.5)])
            p = root + offset
            pos[strand, k] = p
            pos_prev[strand, k] = p
            pos_rest[strand, k] = p
            length_rest[strand, k] = ti.Vector([strand_seg_len])
def _advance_and_render():
    # One frame: animate the roots, run `steps` solver substeps, then redraw img.
    drive_root()
    for _ in range(steps):
        substep()
    clear()
    draw()


init()
gui = ti.GUI('TressFx Demo', res=(imgSize, imgSize))
# Loop until the window stops running or an Escape key event is received,
# then print the per-kernel profiler report.
while gui.running and not gui.get_event(gui.ESCAPE):
    _advance_and_render()
    gui.set_image(img.to_numpy())
    gui.show()
ti.kernel_profiler_print()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment