Skip to content

Instantly share code, notes, and snippets.

@DarioSucic
Created January 5, 2022 21:33
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save DarioSucic/c1fffe1b9607836292ae3d234e884bf9 to your computer and use it in GitHub Desktop.
Save DarioSucic/c1fffe1b9607836292ae3d234e884bf9 to your computer and use it in GitHub Desktop.
Example of machine learning and differentiable programming with Flux + Zygote
using Flux
using Flux.Losses
using Flux.Data: DataLoader
using Zygote
using Zygote: Buffer, @ignore
using LinearAlgebra
using StaticArrays
using Base.Threads: @spawn
using TensorCast
using Colors
using Makie
using GLMakie
using FileIO
# Select the GLMakie (OpenGL) backend and request a 120 fps window frame rate.
GLMakie.activate!()
set_window_config!(; framerate = 120)
##
# Frames per simulated trajectory: `forward` records the initial layout plus
# `steps - 1` integration steps.
const steps = 2^10
# Integration time step used by `advance` and `collide`.
const Δt = 3e-3
# ADAM learning rate used by `optimize`.
const η = 5e-2
# Coefficient of restitution: enters the collision impulse as `-(1 + elasticity)`.
const elasticity = 0.67
##
"""
    Setup(billiard_layers, n_balls, target_ball, goal, radius)
    Setup(billiard_layers, target_ball, goal, radius)

Parameters of the billiards scene.

- `billiard_layers`: number of rows in the triangular rack (the widest row holds
  `billiard_layers` balls).
- `n_balls`: total ball count: the cue ball (index 1) plus
  `billiard_layers * (billiard_layers + 1) ÷ 2` rack balls.
- `target_ball`: index of the ball whose final position enters the training loss.
- `goal`: `(x, y)` position of the goal marker.
- `radius`: ball radius.

The four-argument form derives `n_balls` from `billiard_layers` and validates its inputs.
It throws `ArgumentError` if `billiard_layers` is negative, if `target_ball` is outside
`1:n_balls`, or if `goal` does not have exactly two coordinates. This fails early instead
of raising a `BoundsError` deep inside the differentiated simulation.
"""
struct Setup
    billiard_layers::Int
    n_balls::Int
    target_ball::Int
    goal::Vector{Float32}
    radius::Float32
end

function Setup(billiard_layers, target_ball, goal, radius)
    billiard_layers >= 0 ||
        throw(ArgumentError("billiard_layers must be non-negative, got $billiard_layers"))
    # Cue ball + triangular number of rack balls; the product is always even.
    n_balls = 1 + (1 + billiard_layers) * billiard_layers ÷ 2
    1 <= target_ball <= n_balls ||
        throw(ArgumentError("target_ball must be in 1:$n_balls, got $target_ball"))
    length(goal) == 2 ||
        throw(ArgumentError("goal must have 2 coordinates (x, y), got $(length(goal))"))
    return Setup(billiard_layers, n_balls, target_ball, goal, radius)
end
##
"""
    create_grid(s, init_x)

Return the initial ball positions as a `2 × s.n_balls` `Float32` matrix (row 1 = x,
row 2 = y).

Column 1 is the cue ball, placed at `init_x`. Columns `2:s.n_balls` form a triangular
rack of `s.billiard_layers` rows:

- Each row is centered on x = 0, with neighbours `2r` apart.
- Consecutive rows are `√3·r + r/4` apart in y, i.e. slightly more than touching.
- Row 1, the widest, lies farthest from the origin.

Positions are written through a Zygote `Buffer`, so the result stays differentiable with
respect to `init_x`. Throws `ArgumentError` if `s.n_balls` differs from the
`1 + L(L+1)/2` balls implied by `L = s.billiard_layers`; the other simulation code sizes
its arrays by `s.n_balls`.
"""
function create_grid(s, init_x)
    layers, r = s.billiard_layers, s.radius
    expected = 1 + (1 + layers) * layers ÷ 2
    s.n_balls == expected || throw(ArgumentError(
        "s.n_balls = $(s.n_balls) does not match the $expected balls implied by " *
        "billiard_layers = $layers"))
    # Ordinary arrays cannot be mutated under Zygote; `Buffer` permits it while keeping
    # the final `copy` differentiable.
    x = Buffer(zeros(Float32, 2, s.n_balls), false)
    rack_offset = 1  # base y-offset of the rack; the apex ball (row == layers) sits near it
    count = 1        # last column written; column 1 is reserved for the cue ball
    for row in 1:layers
        x0 = (row - layers - 1) * r
        for j in 0:layers - row
            count += 1
            x[1, count] = x0 + (2j + 1) * r
            x[2, count] = rack_offset + (layers - row) * √3 * r + r - r * 0.25 * row
        end
    end
    x[:, 1] = init_x
    return copy(x)
end
"""
    create_scene(s)

Build the 3D Makie scene for setup `s` and return `(scene, xs, ys)`.

The scene is 960×640 with the axis hidden and SSAO tuned. It contains a green table quad
at z = -`s.radius`, plus one sphere per ball at the initial layout
`create_grid(s, [0.0, 0.0])`. A final extra sphere is drawn at `s.goal` as the goal marker.

Colours:

- cue ball (index 1): `#FAFAFA`
- target ball: `#F0A220`
- goal marker: `#BABAFF`
- all other balls: `#EF5050`

`xs` and `ys` are `Node`s holding the x and y coordinates of every ball followed by the
goal marker. The replay code updates them via `xs[] = ...` / `ys[] = ...`.
"""
function create_scene(s)
    rack = create_grid(s, [0.0, 0.0])
    scene = Scene(;
        resolution = (960, 640),
        show_axis = false,
        limits = FRect3D(Vec3f0(-1), Vec3f0(2)),
    )
    # Screen-space ambient occlusion settings, applied in this order.
    for (key, value) in (:radius => 1.0, :blur => 2, :bias => 0.01)
        scene[:SSAO][key][] = value
    end
    cam3d!(scene; fov = 25, eyeposition = Vec3f0(0, 0, 0.75),
           lookat = Vec3f0(-0.5, 1, 0.05), rotationspeed = 0.025)
    light = Vec3f0(0, 0, 0.5)

    # Table: a quad (two triangles) spanning x ∈ [-1, 1], y ∈ [-1, 2] just below the balls.
    z = -s.radius
    table_vertices = [-1 -1 z; 1 -1 z; 1 2 z; -1 2 z]
    table_faces = [1 2 3; 3 4 1]
    mesh!(scene, table_vertices, table_faces; color = :green, shading = false,
          lightposition = light, ssao = true)

    # Ball coordinates with the goal marker appended as the last point.
    xs = Node(push!(rack[1, :], s.goal[1]))
    ys = Node(push!(rack[2, :], s.goal[2]))

    tint(i) = i == 1 ? colorant"#FAFAFA" :
              i == s.target_ball ? colorant"#F0A220" :
              i == s.n_balls + 1 ? colorant"#BABAFF" :
              colorant"#EF5050"
    meshscatter!(scene, xs, ys; color = tint.(1:s.n_balls + 1), markersize = s.radius,
                 lightposition = light, ssao = true)
    return scene, xs, ys
end
"""
    advance(s::Setup, x, v, x_inc, impulse)

Take one integration step and return `(x, v)`.

The velocities are updated first (`v + impulse`). The positions then move by `Δt` times
the *updated* velocity, plus the collision correction `x_inc`. `s` is not used.
"""
function advance(s::Setup, x, v, x_inc, impulse)
    v_next = v + impulse
    x_next = x + Δt * v_next + x_inc
    return x_next, v_next
end
"""
    collide(s::Setup, x, v)

Compute pairwise ball–ball collision responses for one time step and return
`(x_inc, impulse)`.

`x` and `v` are indexed `[i, j]` = (coordinate `i`, ball `j`); `forward` passes
2 × `s.n_balls` matrices. Both outputs are indexed the same way:

- `impulse` is the per-ball velocity change that `advance` adds to `v`.
- `x_inc` is the per-ball position correction that `advance` adds to `x`.

A pair `(j, k)` contributes only if all of the following hold:

- its separation predicted one step ahead is below `2 * s.radius`;
- `j != k`;
- the balls are approaching each other.
"""
function collide(s::Setup, x, v)
# Relative position of ball j w.r.t. ball k, extrapolated one Δt ahead at current velocities.
@cast delta_pos[i, j, k] := x[i, j] - x[i, k] + Δt * (v[i, j] - v[i, k])
# Predicted distance for every pair (j, k).
@cast dists[j, k] := norm(delta_pos[:, j, k])
# Pairs whose predicted separation is less than two radii.
mask = dists .< 2 * s.radius
# Excludes self-pairs: all-ones matrix with a zero diagonal.
mask2 = ones(s.n_balls, s.n_balls) - I
# Unit direction from k to j. Self-pairs have dists == 0; the 1e-6 avoids 0/0 = NaN,
# which the masks could not cancel (NaN * 0 is NaN).
@cast dirs[i, j, k] := delta_pos[i, j, k] / (dists[j, k] + 1e-6)
# Relative velocity of ball j w.r.t. ball k.
@cast rela_v[i, j, k] := v[i, j] - v[i, k]
# Relative speed along the contact direction; negative means the pair is approaching.
@cast projected_vs[j, k] := dirs[:, j, k] ⋅ rela_v[:, j, k]
projected_mask = projected_vs .< 0.0
# Restitution impulse along the contact direction. The 0.5 factor presumably splits the
# exchange equally between the two balls — NOTE(review): assumes equal masses.
@cast imp_contrib[i, j, k] := -(1 + elasticity) * 0.5 * projected_vs[j, k] * dirs[i, j, k]
# Overlap depth divided by approach speed — presumably a time-of-impact estimate.
# min(-1e-6, ...) keeps the denominator strictly negative so it never divides by zero.
@cast tois[j, k] := (dists[j, k] - 2 * s.radius) / min(-1e-6, projected_vs[j, k])
# Non-positive factor that turns the impulse into a position correction.
@cast x_inc_contrib[i, j, k] := min(tois[j, k] - Δt, 0) * imp_contrib[i, j, k]
# Sum over partners k, keeping only overlapping, distinct, approaching pairs.
@reduce x_inc[i, j] := sum(k) x_inc_contrib[i, j, k] * mask[j, k] * mask2[j, k] * projected_mask[j, k]
@reduce impulse[i, j] := sum(k) imp_contrib[i, j, k] * mask[j, k] * mask2[j, k] * projected_mask[j, k]
x_inc, impulse
end
# @code_warntype collide(s, create_grid(s, [0.0f0, 0.0f0]), zeros(Float32, 2, s.n_balls))
##
"""
    forward(s::Setup, init_x, init_v)

Simulate one trajectory and return `(x, x_hist)`.

The simulation starts from `create_grid(s, init_x)`. The cue ball (column 1) starts with
velocity `init_v`; every other ball starts at rest.

- `x` is the final 2 × `s.n_balls` position matrix.
- `x_hist` is a `2 × s.n_balls × steps` `Float32` history whose frame 1 is the initial
  layout.

The history writes are wrapped in `@ignore`, so they are excluded from differentiation.
"""
function forward(s::Setup, init_x, init_v)
    pos = create_grid(s, init_x)
    vel = hcat(init_v, zeros(2, s.n_balls - 1))
    history = zeros(Float32, 2, s.n_balls, steps)
    @ignore history[:, :, 1] .= pos
    for frame in 2:steps
        dx, dv = collide(s, pos, vel)
        pos, vel = advance(s, pos, vel, dx, dv)
        @ignore history[:, :, frame] .= pos
    end
    return pos, history
end
"""
    optimize(s::Setup; iterations::Integer = 10)

Tune the cue ball's initial position and velocity by differentiating through `forward`.
The parameters start at `(0, 0)` and `(0, 1)` respectively and are updated with ADAM
(learning rate `η`).

The loss is the Euclidean distance between ball `s.target_ball` and the cue ball
(column 1) at the end of the simulation; `s.goal` does not enter it.

At most `iterations` gradient steps run. The loop stops early if the loss is exactly equal
to the previous iteration's loss.

Returns a vector of position histories (`x_hist` from `forward`), one per iteration that
was not cut short by the early stop.
"""
function optimize(s::Setup; iterations::Integer = 10)
    states = Vector{Array{Float32, 3}}()
    init_x = [0.0, 0.0]
    init_v = [0.0, 1.0]
    opt = ADAM(η)
    θ = params(init_x, init_v)
    # NaN compares unequal to everything, so the first iteration never stops early; it
    # also keeps `prev_loss` a float throughout instead of switching from Int.
    prev_loss = NaN
    println("Training loop...")
    for i in 1:iterations
        # Declared here so the gradient closure assigns these loop-local variables.
        local loss, x_hist
        gs = gradient(θ) do
            x, x_hist = forward(s, init_x, init_v)
            loss = norm(x[:, s.target_ball] - x[:, 1])
        end
        if loss == prev_loss
            println("Loss unchanged. Breaking early.")
            break
        end
        prev_loss = loss
        println("\t$i :: loss = $loss \t x_0=$init_x v_0=$init_v")
        push!(states, x_hist)
        # Mutates init_x / init_v in place through the implicit parameter set θ.
        Flux.update!(opt, θ, gs)
    end
    return states
end
##
# 5-row rack (16 balls including the cue ball), target ball 3,
# goal marker at (-0.2, 1.5), ball radius 0.05.
s = Setup(5, 3, [-0.2, 1.5], 0.05)
scene, xs, ys = create_scene(s)
display(scene)
##
states = @time optimize(s);  # time the full training run; keep the per-iteration histories
##
# Replay each training iteration's trajectory:
# - show frame 1 and hold it for 0.1 s;
# - then show every 16th frame from frame 2 at ~120 fps.
# The goal marker is re-appended as the last point each frame.
for (iteration, state) in enumerate(states)
    println("Iteration: $iteration")
    for frame in Iterators.flatten((1:1, 2:16:steps))
        xs[] = push!(state[1, :, frame], s.goal[1])
        ys[] = push!(state[2, :, frame], s.goal[2])
        sleep(frame == 1 ? 0.1 : 1/120)
    end
end
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment