
@al2o3cr
Created June 3, 2024 18:39
Splines in Livebook

Splines

Mix.install([
  {:vega_lite, "~> 0.1.9"},
  {:kino_vega_lite, "~> 0.1.8"},
  {:nx, "~> 0.5"}
])

alias VegaLite, as: Vl

Spline equations

article

Inputs

$$ t_i, i=0..n\newline y_i, i=0..n\newline $$

Initial Calculations

$$ h_i = t_{i+1} - t_i\newline i=0..n-1 $$

$$ S_i(x) = \frac{z_{i+1}}{6h_i}(x-t_i)^3 + \frac{z_i}{6h_i}(t_{i+1}-x)^3 +\left(\frac{y_{i+1}}{h_i} - \frac{h_i z_{i+1}}{6}\right)(x-t_i) +\left(\frac{y_i}{h_i} - \frac{h_i z_i}{6}\right)(t_{i+1}-x) $$

By construction, this is an interpolating polynomial: $$ S_i(t_i) = 0 + \frac{z_i h_i^2}{6} + 0 + \left(y_i - \frac{z_i h_i^2}{6}\right) = y_i \newline S_i(t_{i+1}) = \frac{z_{i+1} h_i^2}{6} + 0 + \left( y_{i+1} - \frac{z_{i+1} h_i^2}{6}\right) + 0 = y_{i+1} $$

The problem is to find $z_i$ for $i=0..n$

Derivatives

$$ \begin{equation*} \begin{split} S'_i(x) &= \frac{z_{i+1}}{2h_i}(x-t_i)^2 - \frac{z_i}{2h_i}(t_{i+1}-x)^2 + \left(\frac{y_{i+1}}{h_i} - \frac{h_i z_{i+1}}{6}\right) - \left(\frac{y_i}{h_i} - \frac{h_i z_i}{6}\right) \\ &= \frac{z_{i+1}}{2h_i}(x-t_i)^2 - \frac{z_i}{2h_i}(t_{i+1}-x)^2 + \left(\frac{y_{i+1} - y_i}{h_i}\right) - \frac{h_i}{6}\left(z_{i+1}- z_i\right) \\ &= \left( \frac{(x-t_i)^2}{2h_i} - \frac{h_i}{6}\right)z_{i+1} - \left( \frac{(t_{i+1}-x)^2}{2h_i} - \frac{h_i}{6}\right)z_i + \frac{y_{i+1} - y_i}{h_i} \end{split} \end{equation*} $$

The derivatives need to match at the endpoints.

$$ S''_i(x) = \frac{x-t_i}{h_i}z_{i+1}+\frac{t_{i+1}-x}{h_i}z_i $$

The $z_i$ are the values of the second derivative at each $t_i$:

$$ \begin{equation*} \begin{split} S''_i(t_i) &= z_i \\ S''_i(t_{i+1}) &= z_{i+1} \\ S''_{i+1}(t_{i+1}) &= z_{i+1} \end{split} \end{equation*} $$

$$ \begin{equation*} \begin{split} S'_i(t_i) &= -\frac{h_i}{6}z_{i+1} - \frac{h_i}{3}z_i + \frac{y_{i+1} - y_i}{h_i} \\ S'_i(t_{i+1}) &= \frac{h_i}{3}z_{i+1} + \frac{h_i}{6}z_i + \frac{y_{i+1} - y_i}{h_i} \\ S'_{i+1}(t_{i+1}) &= - \frac{h_{i+1}}{6}z_{i+2} - \frac{h_{i+1}}{3}z_{i+1} + \frac{y_{i+2} - y_{i+1}}{h_{i+1}} \end{split} \end{equation*} $$

For a smooth spline, the two derivatives must match at each interior point.

$$ \begin{equation*} \begin{split} S'_i(t_{i+1}) - S'_{i+1}(t_{i+1}) &= \frac{h_i}{3}z_{i+1} + \frac{h_i}{6}z_i + \frac{y_{i+1} - y_i}{h_i} + \frac{h_{i+1}}{6}z_{i+2} + \frac{h_{i+1}}{3}z_{i+1} - \frac{y_{i+2} - y_{i+1}}{h_{i+1}} \\ &=\frac{h_{i+1}}{6}z_{i+2} + \frac{h_i + h_{i+1}}{3}z_{i+1} + \frac{h_i}{6}z_i - \left( \frac{y_{i+2} - y_{i+1}}{h_{i+1}} - \frac{y_{i+1} - y_i}{h_i}\right) \\ &= 0 \end{split} \end{equation*} $$

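This condition applies at every interior point, i.e. for $i=0..n-2$. Multiplying through by 6 and moving the $y$ terms to the right-hand side gives: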

$$ h_iz_i + 2(h_i + h_{i+1})z_{i+1} + h_{i+1}z_{i+2} = 6\left( \frac{y_{i+2} - y_{i+1}}{h_{i+1}} - \frac{y_{i+1} - y_i}{h_i}\right) $$
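As a quick sanity check on this equation (a worked special case added here, not part of the original derivation), suppose the knots are evenly spaced, so that $h_i = h$ for every $i$. Dividing both sides by $h$ gives

$$ z_i + 4z_{i+1} + z_{i+2} = \frac{6}{h^2}\left(y_{i+2} - 2y_{i+1} + y_i\right) $$

The right-hand side is six times the central second difference of the data at $t_{i+1}$, which matches the interpretation of the $z_i$ as second-derivative values.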

That's $n-1$ equations in $n+1$ unknowns, so two additional equations need to be specified.
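The interior equations only need the spacings $h_i$ and the differences of successive slopes. The following is a minimal sketch in plain Elixir (no Nx) of how those quantities could be computed from a list of points. The module name `InteriorRowsSketch` and the function `build/1` are hypothetical names introduced for illustration, and the input is assumed to be sorted by strictly increasing $t$.

```elixir
defmodule InteriorRowsSketch do
  # Hypothetical helper, not part of the notebook.
  # Takes a list of {t, y} points sorted by strictly increasing t and
  # returns {hs, rhs}: the n spacings h_i and the n-1 right-hand sides
  # 6 * (slope_{i+1} - slope_i) of the interior equations.
  def build(points) do
    # Consecutive pairs of points, one pair per interval.
    segments = Enum.chunk_every(points, 2, 1, :discard)

    hs = Enum.map(segments, fn [{t0, _}, {t1, _}] -> t1 - t0 end)

    slopes =
      Enum.map(segments, fn [{t0, y0}, {t1, y1}] -> (y1 - y0) / (t1 - t0) end)

    rhs =
      slopes
      |> Enum.chunk_every(2, 1, :discard)
      |> Enum.map(fn [b0, b1] -> 6 * (b1 - b0) end)

    {hs, rhs}
  end
end

InteriorRowsSketch.build([{0, 0}, {1, 0.5}, {2, 2}, {3, 1.5}])
# => {[1, 1, 1], [6.0, -12.0]}
```

Four points give three intervals and two interior equations, leaving two unknowns to be fixed by the boundary conditions.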

The necessary equations are derived from the boundary conditions:

  • "clamped" boundary: the first derivative is given at an endpoint
  • "natural" boundary: the second derivative is specified (usually zero)
  • "periodic" boundary: the derivatives at the endpoints must match

"Natural" boundary conditions are the simplest: they explicitly supply a value for $z_0$ or $z_n$.

$$ \begin{equation*} \begin{split} z_0 &= Z_0 \\ z_n &= Z_n \end{split} \end{equation*} $$

"Clamped" boundary conditions use the derivatives at the far ends of the interval:

$$ \begin{equation*} \begin{split} S'_0(t_0) &= -\frac{h_0}{6}z_1 - \frac{h_0}{3}z_0 + \frac{y_1 - y_0}{h_0} = C_0 \\ S'_{n-1}(t_n) &= \frac{h_{n-1}}{3}z_n + \frac{h_{n-1}}{6}z_{n-1} + \frac{y_n - y_{n-1}}{h_{n-1}} = C_n \end{split} \end{equation*} $$

resulting in the equations:

$$ \begin{equation*} \begin{split} 2h_0z_0 + h_0z_1 &= 6\left(\frac{y_1 - y_0}{h_0} - C_0\right) \\ h_{n-1}z_{n-1} + 2h_{n-1}z_n &= 6\left(C_n - \frac{y_n - y_{n-1}}{h_{n-1}}\right) \end{split} \end{equation*} $$

Periodic boundary conditions combine the two first derivatives: $$ 2h_0z_0 + h_0z_1 + h_{n-1}z_{n-1} + 2h_{n-1}z_n = 6\left(\frac{y_1 - y_0}{h_0} - \frac{y_n - y_{n-1}}{h_{n-1}}\right) $$ and the two second derivatives at the ends: $$ z_0 - z_n = 0 $$

Systems of equations

Natural boundary conditions

$$ \begin{bmatrix} 1 & 0 & 0 & ... \\ h_0 & 2(h_0+h_1) & h_1 & ... \\ 0 & h_1 & 2(h_1+h_2) & h_2 \\ & & ...\\ & & ... & h_{n-2} & 2(h_{n-2}+h_{n-1}) & h_{n-1} \\ & & & & 0 & 1 \end{bmatrix} \begin{bmatrix} z_0\\ z_1\\ z_2\\ ...\\ z_{n-1}\\ z_{n} \end{bmatrix} = \begin{bmatrix} Z_0\\ 6\left( \frac{y_2 - y_1}{h_1} - \frac{y_1 - y_0}{h_0}\right)\\ 6\left( \frac{y_3 - y_2}{h_2} - \frac{y_2 - y_1}{h_1}\right)\\ ...\\ 6\left( \frac{y_n - y_{n-1}}{h_{n-1}} - \frac{y_{n-1} - y_{n-2}}{h_{n-2}}\right)\\ Z_n \end{bmatrix} $$
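For a concrete instance (an illustrative data set, taken from the commented-out `points` list in the code further down), consider the points $(0,0), (1,0.5), (2,2), (3,1.5)$. Then $n=3$, $h_0=h_1=h_2=1$, and with $Z_0 = Z_3 = 0$ the natural system is

$$ \begin{bmatrix} 1 & 0 & 0 & 0 \\ 1 & 4 & 1 & 0 \\ 0 & 1 & 4 & 1 \\ 0 & 0 & 0 & 1 \end{bmatrix} \begin{bmatrix} z_0\\ z_1\\ z_2\\ z_3 \end{bmatrix} = \begin{bmatrix} 0\\ 6\\ -12\\ 0 \end{bmatrix} $$

The middle rows reduce to $4z_1 + z_2 = 6$ and $z_1 + 4z_2 = -12$, which give $z_1 = 2.4$ and $z_2 = -3.6$.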

Clamped boundary conditions

$$ \begin{bmatrix} 2h_0 & h_0 & 0 & ... \\ h_0 & 2(h_0+h_1) & h_1 & ... \\ 0 & h_1 & 2(h_1+h_2) & h_2 \\ & & ...\\ & & ... & h_{n-2} & 2(h_{n-2}+h_{n-1}) & h_{n-1} \\ & & & & h_{n-1} & 2h_{n-1} \end{bmatrix} \begin{bmatrix} z_0\\ z_1\\ z_2\\ ...\\ z_{n-1}\\ z_{n} \end{bmatrix} = \begin{bmatrix} 6\left(\frac{y_1 - y_0}{h_0} - C_0\right)\\ 6\left( \frac{y_2 - y_1}{h_1} - \frac{y_1 - y_0}{h_0}\right)\\ 6\left( \frac{y_3 - y_2}{h_2} - \frac{y_2 - y_1}{h_1}\right)\\ ...\\ 6\left( \frac{y_n - y_{n-1}}{h_{n-1}} - \frac{y_{n-1} - y_{n-2}}{h_{n-2}}\right)\\ 6\left(C_n - \frac{y_n - y_{n-1}}{h_{n-1}}\right) \end{bmatrix} $$
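The natural and clamped matrices above are tridiagonal, so besides a dense solver they can also be handled in linear time with the Thomas algorithm (forward elimination followed by back substitution). The following is a minimal sketch in plain Elixir. The module `TridiagonalSketch` and its argument layout are hypothetical and not part of the notebook, which solves the dense system with `Nx.LinAlg.solve`. The sketch does no pivoting, so it assumes no zero pivot is encountered.

```elixir
defmodule TridiagonalSketch do
  # Hypothetical helper, not part of the notebook.
  # Solves a tridiagonal system with the Thomas algorithm.
  # `lower`, `diag`, `upper` and `rhs` are lists of the same length m,
  # one entry per matrix row; the first entry of `lower` and the last
  # entry of `upper` lie outside the matrix and do not affect the result.
  def solve(lower, diag, upper, rhs) do
    # Forward sweep: eliminate the sub-diagonal, producing the modified
    # super-diagonal entry c' and right-hand side d' for every row.
    {rows, _last} =
      [lower, diag, upper, rhs]
      |> Enum.zip()
      |> Enum.map_reduce({0.0, 0.0}, fn {a, b, c, d}, {c_prev, d_prev} ->
        denom = b - a * c_prev
        row = {c / denom, (d - a * d_prev) / denom}
        {row, row}
      end)

    # Back substitution, walking the rows from last to first.
    rows
    |> Enum.reverse()
    |> Enum.reduce([], fn
      {_c, d}, [] -> [d]
      {c, d}, [z_next | _] = acc -> [d - c * z_next | acc]
    end)
  end
end

# Natural system for the points (0,0), (1,0.5), (2,2), (3,1.5):
TridiagonalSketch.solve([0, 1, 1, 0], [1, 4, 4, 1], [0, 1, 1, 0], [0, 6, -12, 0])
# approximately [0.0, 2.4, -3.6, 0.0]
```

The periodic system below has extra entries in its first and last rows, so this sketch does not apply to it directly.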

Periodic boundary conditions

$$ \begin{bmatrix} 2h_0 & h_0 & ... & ... & h_{n-1} & 2h_{n-1} \\ h_0 & 2(h_0+h_1) & h_1 & ... \\ 0 & h_1 & 2(h_1+h_2) & h_2 \\ & & ...\\ & & ... & h_{n-2} & 2(h_{n-2}+h_{n-1}) & h_{n-1} \\ 1 & & & & & -1 \end{bmatrix} \begin{bmatrix} z_0\\ z_1\\ z_2\\ ...\\ z_{n-1}\\ z_{n} \end{bmatrix} = \begin{bmatrix} 6\left(\frac{y_1 - y_0}{h_0} - \frac{y_n - y_{n-1}}{h_{n-1}}\right)\\ 6\left( \frac{y_2 - y_1}{h_1} - \frac{y_1 - y_0}{h_0}\right)\\ 6\left( \frac{y_3 - y_2}{h_2} - \frac{y_2 - y_1}{h_1}\right)\\ ...\\ 6\left( \frac{y_n - y_{n-1}}{h_{n-1}} - \frac{y_{n-1} - y_{n-2}}{h_{n-2}}\right)\\ 0 \end{bmatrix} $$

defmodule Helpers do
  def pairwise_diff(t) do
    Nx.subtract(t[1..-1//1], t[0..-2//1])
  end

  def pairwise_sum(t) do
    Nx.add(t[1..-1//1], t[0..-2//1])
  end

  # Boolean mask with the same shape as `t`: 1 where column == row + offset.
  def offset_matrix(t, offset) do
    Nx.equal(
      Nx.iota(Nx.shape(t), axis: 0),
      Nx.subtract(
        Nx.iota(Nx.shape(t), axis: 1),
        offset
      )
    )
  end

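  # Square matrix with the elements of vector `t` on the diagonal shifted by
  # `offset`. For offset >= 0, row i holds t[i]; for offset < 0, column j
  # holds t[j]. Elements that would fall outside the matrix are dropped.
  def diag(t, offset) do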
    all_elems =
      if offset < 0 do
        Nx.outer(Nx.broadcast(1, t), t)
      else
        Nx.outer(t, Nx.broadcast(1, t))
      end

    Nx.select(offset_matrix(all_elems, offset), all_elems, 0)
  end

  # Evaluates the cubic S_i(x) on the interval [t, t_next].
  def eval_poly(x, {t, t_next, y, y_next, z, z_next}) do
    dx = x - t
    dx_next = t_next - x
    dx3 = dx * dx * dx
    dx3_next = dx_next * dx_next * dx_next

    h = t_next - t

    dx3 * z_next / (6 * h) +
      dx3_next * z / (6 * h) +
      dx * (y_next / h - h * z_next / 6) +
      dx_next * (y / h - h * z / 6)
  end

  # Linear extrapolation below the first knot, using the slope S'_0(t_0).
  defp eval_low_end(x, {t, t_next, y, y_next, z, z_next}) do
    h = t_next - t
    slope = -h * z_next / 6 - h * z / 3 + (y_next - y) / h
    slope * (x - t) + y
  end

  # Linear extrapolation above the last knot, using the slope S'_{n-1}(t_n).
  defp eval_high_end(x, {t, t_next, y, y_next, z, z_next}) do
    h = t_next - t
    slope = h * z_next / 3 + h * z / 6 + (y_next - y) / h
    slope * (x - t_next) + y_next
  end

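  # Evaluates the spline at a scalar x. Finds the interval containing x and
  # falls back to linear extrapolation when x lies outside the knots.
  def eval(x, xs, ys, zs) do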
    comp = Nx.greater(xs, Nx.broadcast(x, xs))
    comp_0 = Nx.to_number(comp[0])
    last_zero = Nx.to_number(Nx.argmin(comp, tie_break: :high))
    first_one = Nx.to_number(Nx.argmax(comp, tie_break: :low))

    case {last_zero, first_one, comp_0} do
      {idx, idx_next, _} when idx_next == idx + 1 ->
        eval_poly(x, args_at(idx, xs, ys, zs))

      {_, _, 0} ->
        eval_high_end(x, args_at(-2, xs, ys, zs))

      {_, _, 1} ->
        eval_low_end(x, args_at(0, xs, ys, zs))
    end
  end

  defp args_at(idx, xs, ys, zs) do
    {
      Nx.to_number(xs[idx]),
      Nx.to_number(xs[idx + 1]),
      Nx.to_number(ys[idx]),
      Nx.to_number(ys[idx + 1]),
      Nx.to_number(zs[idx]),
      Nx.to_number(zs[idx + 1])
    }
  end
end

x = -1.2
vals = Nx.tensor([0.0, 1.0, 2.0, 3.0])
comp = Nx.greater(vals, Nx.broadcast(x, vals)) |> IO.inspect(label: "comp")
last_zero = Nx.argmin(comp, tie_break: :high)
first_one = Nx.argmax(comp, tie_break: :low)
{last_zero, first_one}

Helpers.diag(Nx.tensor([1, 2, 3, 4]), -1)

import Helpers

points = [{0, 0}, {1, 4}, {3, 2}, {3.5, 0}, {4.2, -3.1}, {6, 12}, {6.5, 3}, {7, 12}, {10, 0}]
# points = [{0, 0}, {1, 0.5}, {2, 2}, {3, 1.5}]
xs = Enum.map(points, fn {x, _} -> x end) |> Nx.tensor()
ys = Enum.map(points, fn {_, y} -> y end) |> Nx.tensor()

hs = pairwise_diff(xs)
both_hs = pairwise_sum(hs)

delta_ys = pairwise_diff(ys)
bs = Nx.divide(delta_ys, hs)
delta_bs = pairwise_diff(bs)

# Natural boundary conditions

z_0 = 0.0
z_n = 0.0

n_super_diagonal = Nx.concatenate([Nx.tensor([0]), hs[1..-1//1], Nx.tensor([0])])
n_diagonal = Nx.concatenate([Nx.tensor([1]), Nx.multiply(both_hs, 2), Nx.tensor([1])])
n_sub_diagonal = Nx.concatenate([hs[0..-2//1], Nx.tensor([0, 0])])

natural_system =
  diag(n_diagonal, 0)
  |> Nx.add(diag(n_super_diagonal, 1))
  |> Nx.add(diag(n_sub_diagonal, -1))

IO.inspect(natural_system, label: "natural system")

n_constant =
  Nx.concatenate([Nx.tensor([z_0]), Nx.multiply(delta_bs, 6), Nx.tensor([z_n])])
  |> IO.inspect(label: "natural constant")

natural_zs = Nx.LinAlg.solve(natural_system, n_constant)

# Clamped boundary conditions

c_0 = 0.0
c_n = 0.0

c_super_diagonal = Nx.concatenate([hs, Nx.tensor([0])])
c_diagonal = Nx.multiply(Nx.concatenate([hs[0..0], both_hs, hs[-1..-1]]), 2)
c_sub_diagonal = Nx.concatenate([hs, Nx.tensor([0])])

clamped_system =
  diag(c_diagonal, 0)
  |> Nx.add(diag(c_super_diagonal, 1))
  |> Nx.add(diag(c_sub_diagonal, -1))

IO.inspect(clamped_system, label: "clamped system")

c_constant =
  Nx.multiply(
    Nx.concatenate([
      Nx.subtract(bs[0..0], Nx.tensor([c_0])),
      delta_bs,
      Nx.subtract(Nx.tensor([c_n]), bs[-1..-1])
    ]),
    6
  )
  |> IO.inspect(label: "clamped constant")

clamped_zs = Nx.LinAlg.solve(clamped_system, c_constant)

dx = 0.05

plot_xs =
  Stream.iterate(Nx.to_number(xs[0]) - 0.5, fn x -> x + dx end)
  |> Stream.take_while(fn x -> x <= Nx.to_number(xs[-1]) + 0.5 end)
  |> Enum.to_list()

clamped_plot_ys = Enum.map(plot_xs, fn x -> eval(x, xs, ys, clamped_zs) end)
natural_plot_ys = Enum.map(plot_xs, fn x -> eval(x, xs, ys, natural_zs) end)

Vl.new(width: 600, height: 450)
|> Vl.layers([
  Vl.new()
  |> Vl.data_from_values(x: plot_xs, y: clamped_plot_ys)
  |> Vl.mark(:point)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative),
  Vl.new()
  |> Vl.data_from_values(x: plot_xs, y: natural_plot_ys)
  |> Vl.mark(:square, color: "red")
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative),
  Vl.new()
  |> Vl.data_from_values(x: Nx.to_flat_list(xs), y: Nx.to_flat_list(ys))
  |> Vl.mark(:point, shape: "triangle-up", color: "green", size: 200)
  |> Vl.encode_field(:x, "x", type: :quantitative)
  |> Vl.encode_field(:y, "y", type: :quantitative)
])