Created
January 5, 2022 21:33
-
-
Save DarioSucic/c1fffe1b9607836292ae3d234e884bf9 to your computer and use it in GitHub Desktop.
Example of machine learning and differentiable programming with Flux + Zygote
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
using Flux | |
using Flux.Losses | |
using Flux.Data: DataLoader | |
using Zygote | |
using Zygote: Buffer, @ignore | |
using LinearAlgebra | |
using StaticArrays | |
using Base.Threads: @spawn | |
using TensorCast | |
using Colors | |
using Makie | |
using GLMakie | |
using FileIO | |
# Render through the GLMakie backend and allow the window to redraw at up to 120 fps.
GLMakie.activate!()
set_window_config!(framerate = 120)
##
# Global simulation / optimisation constants. They are `const` so the functions that
# read them (collide, advance, forward, optimize) see concrete types.
const steps = 1024       # simulation steps per forward pass
const Δt = 0.003         # integration time step
const η = 0.05           # ADAM learning rate used by `optimize`
const elasticity = 0.67  # restitution factor in the collision impulse (see `collide`)
## | |
"""
    Setup(billiard_layers, n_balls, target_ball, goal, radius)
    Setup(billiard_layers, target_ball, goal, radius)

Static description of one billiards experiment. The 4-argument form derives
`n_balls` from `billiard_layers`: one cue ball plus a triangular rack holding
`billiard_layers * (billiard_layers + 1) ÷ 2` balls.
"""
struct Setup
    billiard_layers::Int   # number of rows in the triangular rack
    n_balls::Int           # total ball count, cue ball included
    target_ball::Int       # column index of the ball used in the optimisation loss
    goal::Vector{Float32}  # 2-D goal position (drawn as an extra ball in the scene)
    radius::Float32        # radius shared by all balls
end

function Setup(billiard_layers, target_ball, goal, radius)
    rack_size = (1 + billiard_layers) * billiard_layers ÷ 2
    return Setup(billiard_layers, 1 + rack_size, target_ball, goal, radius)
end
## | |
"""
    create_grid(s, init_x) -> Matrix{Float32}

Return the initial ball positions for setup `s` as a 2×n_balls matrix.

Column 1 is the cue ball, placed at `init_x` (a 2-element position). Columns
2:n_balls form a triangular rack of `s.billiard_layers` rows centred on x = 0:
row `row` holds `s.billiard_layers - row + 1` balls spaced one diameter apart, and
the rack is offset by `dist = 1` along +y.

The matrix is filled through a Zygote `Buffer` so that the result remains
differentiable with respect to `init_x`.

Throws `ArgumentError` if `s.n_balls` disagrees with the ball count implied by
`s.billiard_layers`. The rest of the simulation sizes its arrays by `s.n_balls`, so
without this check a mismatch surfaces later as a confusing dimension error.
"""
function create_grid(s, init_x)
    layers = s.billiard_layers
    r = s.radius
    n_balls = 1 + (1 + layers) * layers ÷ 2
    s.n_balls == n_balls || throw(ArgumentError(
        "Setup has n_balls = $(s.n_balls), but $layers billiard layers need $n_balls balls"))
    x = Buffer(zeros(Float32, 2, n_balls), false)
    dist = 1  # +y offset of the rack relative to the origin
    col = 1   # column 1 is reserved for the cue ball
    for row in 1:layers
        x0 = (row - layers - 1) * r
        for j in 0:layers - row
            col += 1
            x[1, col] = x0 + (2j + 1) * r
            x[2, col] = dist + (layers - row) * √3 * r + r - r * 0.25 * row
        end
    end
    x[:, 1] = init_x
    return copy(x)
end
"""
    create_scene(s) -> (scene, xs, ys)

Build the 3-D GLMakie scene for setup `s`. The scene contains:
- a green table plane one radius below the ball centres,
- the initial rack with the cue ball at the origin,
- the goal position drawn as one extra ball.

Returns the scene and the two `Node`s holding the x and y ball coordinates. The last
entry of each node is the goal. Assigning new vectors to the nodes moves the balls
on screen.
"""
function create_scene(s)
    rack = create_grid(s, [0.0, 0.0])

    scene = Scene(
        resolution = (960, 640),
        show_axis = false,
        limits = FRect3D(Vec3f0(-1), Vec3f0(2)),
    )
    # Screen-space ambient-occlusion settings.
    scene[:SSAO][:radius][] = 1.0
    scene[:SSAO][:blur][] = 2
    scene[:SSAO][:bias][] = 0.01
    cam3d!(
        scene;
        fov = 25,
        eyeposition = Vec3f0(0, 0, 0.75),
        lookat = Vec3f0(-0.5, 1, 0.05),
        rotationspeed = 0.025,
    )
    light = Vec3f0(0, 0, 0.5)

    # Table quad spanning x ∈ [-1, 1], y ∈ [-1, 2], one radius below the ball centres.
    z = -s.radius
    table_vertices = [
        -1 -1 z
        +1 -1 z
        +1 +2 z
        -1 +2 z
    ]
    table_faces = [1 2 3; 3 4 1]
    mesh!(scene, table_vertices, table_faces;
          color = :green, shading = false, lightposition = light, ssao = true)

    # Ball coordinates, with the goal appended as one extra, differently coloured ball.
    xs = Node(vcat(rack[1, :], s.goal[1]))
    ys = Node(vcat(rack[2, :], s.goal[2]))
    colors = map(1:s.n_balls + 1) do i
        i == 1             ? colorant"#FAFAFA" :  # cue ball
        i == s.target_ball ? colorant"#F0A220" :  # target ball
        i == s.n_balls + 1 ? colorant"#BABAFF" :  # goal marker
                             colorant"#EF5050"    # remaining rack balls
    end
    meshscatter!(scene, xs, ys;
                 color = colors, markersize = s.radius, lightposition = light, ssao = true)

    return scene, xs, ys
end
"""
    advance(s::Setup, x, v, x_inc, impulse) -> (x, v)

Advance the state by one time step `Δt` using semi-implicit Euler:
1. add the collision `impulse` to the velocities;
2. move the positions with the updated velocities;
3. add the collision position correction `x_inc`.

`s` is not used here.
"""
function advance(s::Setup, x, v, x_inc, impulse)
    v_next = v + impulse
    x_next = x + Δt * v_next + x_inc
    return x_next, v_next
end
"""
    collide(s::Setup, x, v) -> (x_inc, impulse)

Resolve ball–ball collisions for one time step. The computation is vectorised over
all ordered pairs (j, k) with TensorCast.

`x` and `v` are the position and velocity matrices, one column per ball (as built
by `forward`). A pair contributes only if all three hold:
1. the centres, predicted one step `Δt` ahead, are closer than two radii;
2. j ≠ k;
3. the balls are approaching each other.

For each such pair, ball j receives the impulse `-(1 + elasticity)/2 · (n ⋅ Δv) · n`
along the contact normal `n`. It also receives a position correction
`min(toi - Δt, 0)` times that impulse, where `toi` is an estimated time of impact.

Returns the per-ball position correction `x_inc` and velocity change `impulse`,
summed over partners k.
"""
function collide(s::Setup, x, v)
    # sep[:, j, k]: predicted centre offset p_j - p_k after moving every ball by Δt·v.
    @cast sep[i, j, k] := x[i, j] - x[i, k] + Δt * (v[i, j] - v[i, k])
    @cast gap[j, k] := norm(sep[:, j, k])
    touching = gap .< 2 * s.radius
    # Float 0/1 matrix that removes self-pairs (j == k).
    off_diag = ones(s.n_balls, s.n_balls) - I
    # Approximately unit contact normals; the 1e-6 keeps coincident centres finite.
    @cast normal[i, j, k] := sep[i, j, k] / (gap[j, k] + 1e-6)
    @cast dv[i, j, k] := v[i, j] - v[i, k]
    # Relative speed along the normal; negative means the pair is closing in.
    @cast closing[j, k] := normal[:, j, k] ⋅ dv[:, j, k]
    approaching = closing .< 0.0
    @cast kick[i, j, k] := -(1 + elasticity) * 0.5 * closing[j, k] * normal[i, j, k]
    # Time of impact; min(-1e-6, ·) keeps the denominator strictly negative.
    @cast toi[j, k] := (gap[j, k] - 2 * s.radius) / min(-1e-6, closing[j, k])
    @cast shift[i, j, k] := min(toi[j, k] - Δt, 0) * kick[i, j, k]
    @reduce x_inc[i, j] := sum(k) shift[i, j, k] * touching[j, k] * off_diag[j, k] * approaching[j, k]
    @reduce impulse[i, j] := sum(k) kick[i, j, k] * touching[j, k] * off_diag[j, k] * approaching[j, k]
    return x_inc, impulse
end
# @code_warntype collide(s, [0.0f0, 0.0f0], [0.0f0, 0.0f0]) | |
## | |
"""
    forward(s::Setup, init_x, init_v) -> (x_final, x_hist)

Run the differentiable simulation for `steps` time steps. The cue ball starts at
`init_x` with velocity `init_v`; all other balls start at rest in the rack built by
`create_grid`.

Returns:
- the final position matrix, which is what Zygote differentiates through;
- a Float32 array of size 2 × `s.n_balls` × `steps` recording every step, intended
  for visualisation only.
"""
function forward(s::Setup, init_x, init_v)
    pos = create_grid(s, init_x)
    vel = hcat(init_v, zeros(2, s.n_balls - 1))
    trajectory = zeros(Float32, 2, s.n_balls, steps)
    # Recording mutates an array, which Zygote cannot differentiate; keep it off the tape.
    @ignore trajectory[:, :, 1] .= pos
    for t in 2:steps
        pos_inc, kick = collide(s, pos, vel)
        pos, vel = advance(s, pos, vel, pos_inc, kick)
        @ignore trajectory[:, :, t] .= pos
    end
    return pos, trajectory
end
"""
    optimize(s::Setup; iterations=10, opt=ADAM(η)) -> Vector{Array{Float32,3}}

Fit the cue ball's initial position and velocity by gradient descent through the
differentiable simulation `forward`, starting from `init_x = [0, 0]` and
`init_v = [0, 1]`.

The loss is the Euclidean distance between the target ball (`s.target_ball`) and
the cue ball at the end of the simulation. At most `iterations` optimiser steps are
taken. The loop stops early if the loss is exactly equal to the previous step's.

Returns the trajectory (`x_hist` from `forward`) recorded at each completed
iteration, for replay.

# Keywords
- `iterations`: maximum number of optimisation steps (default 10).
- `opt`: Flux optimiser used to update the parameters (default `ADAM(η)`; another
  Flux optimiser such as `Descent(0.01)` can be passed instead).
"""
function optimize(s::Setup; iterations::Integer=10, opt=ADAM(η))
    states = Vector{Array{Float32, 3}}()
    init_x = [0.0, 0.0]
    init_v = [0.0, 1.0]
    θ = params(init_x, init_v)
    prev_loss = nothing  # no previous loss before the first iteration
    println("Training loop...")
    for i in 1:iterations
        local loss, x_hist
        gs = gradient(θ) do
            x, x_hist = forward(s, init_x, init_v)
            # Alternative objective: norm(x[:, s.target_ball] - s.goal)^2
            loss = norm(x[:, s.target_ball] - x[:, 1])
        end
        if prev_loss !== nothing && loss == prev_loss
            println("Loss unchanged. Breaking early.")
            break
        end
        prev_loss = loss
        println("\t$i :: loss = $loss \t x_0=$init_x v_0=$init_v")
        push!(states, x_hist)
        Flux.update!(opt, θ, gs)
    end
    return states
end
## | |
# Five-layer rack (1 cue ball + 15 rack balls), ball 3 as the target, goal drawn at
# (-0.2, 1.5), ball radius 0.05.
s = Setup(5, 3, [-0.2, 1.5], 0.05)
scene, xs, ys = create_scene(s)
display(scene)
## | |
# Run the optimisation and report how long it took.
states = @time optimize(s);
## | |
# Replay each recorded optimisation iteration. The initial frame is held for 0.1 s;
# after that every 16th simulation step is shown, with a 1/120 s pause between frames.
for (iteration, state) in enumerate(states)
    println("Iteration: $iteration")
    for t in [1; 2:16:steps]
        xs[] = vcat(state[1, :, t], s.goal[1])
        ys[] = vcat(state[2, :, t], s.goal[2])
        sleep(t == 1 ? 0.1 : 1/120)
    end
end
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment