@tnagler
Created June 29, 2022 15:30

Problem statement from Alex Hayes

i have (x, y) pairs generated from two smooth curves, but i don't know which curve each point comes from

is there a way to recover the original curves? https://t.co/h9kL2JEeEV pic.twitter.com/tzy4O1VnrT

— alex hayes (@alexpghayes) June 28, 2022
library(tidyverse)

x <- seq(0, 10, length.out = 100)

unobserved <- tibble(
  x = x, 
  true_curve1 = sin(x),
  true_curve2 = tanh(x) - 0.5,
  coin = as.logical(rbinom(length(x), size = 1, prob = 0.5)),
  y1 = ifelse(coin, true_curve1, true_curve2),
  y2 = ifelse(coin, true_curve2, true_curve1)
)

observed <- unobserved |> 
  pivot_longer(
    c("y1", "y2"),
    values_to = "y"
  ) |> 
  select(x, y)

observed |> 
  ggplot(aes(x, y)) +
  geom_point() +
  labs(
    title = "How to identify which of the two curves each point belongs to?",
    subtitle = "Two points are observed at each of 100 x values, one per smooth curve, but we only see (x, y) pairs"
  ) +
  theme_minimal()

Algorithm that minimizes curve roughness

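Here, roughness means the sum of squared second derivatives, approximated by finite differences over consecutive design points. At each new x value, the candidate points are matched to the existing curves so that the total added roughness is minimal.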

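# Roughness of candidate paths through three consecutive design points.
# x: the three design points; y: matrix with one row per path and one column
# per design point. A repeated design point contributes slope 0.
path_roughness <- function(x, y) {
  drv1 <- if (x[1] != x[2]) (y[, 2] - y[, 1]) / (x[2] - x[1]) else 0
  drv2 <- if (x[2] != x[3]) (y[, 3] - y[, 2]) / (x[3] - x[2]) else 0
  (drv2 - drv1)^2 / (x[3] - x[1])^2
}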
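
To make the weight concrete, here is a minimal sketch for a single triple of points. The helper `roughness3()` is hypothetical: a scalar version of `path_roughness()` with the same formula but without the guards for repeated x values. It is only meant to show that a straight continuation costs nothing while a kink is penalized.

```r
# Hypothetical scalar version of the roughness weight for one path through
# three points (x[1], y[1]), (x[2], y[2]), (x[3], y[3]).
# Assumes strictly increasing x, so no division by zero occurs.
roughness3 <- function(x, y) {
  slope1 <- (y[2] - y[1]) / (x[2] - x[1])
  slope2 <- (y[3] - y[2]) / (x[3] - x[2])
  (slope2 - slope1)^2 / (x[3] - x[1])^2
}

x <- c(0, 1, 2)
straight <- roughness3(x, c(0, 1, 2))  # slopes 1 and 1
kinked   <- roughness3(x, c(0, 1, 0))  # slopes 1 and -1
```

Here `straight` is 0 and `kinked` is 1, so the kinked path would be the more expensive choice in the assignment step.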

assign_curves <- function(x, y) {
  x <- unname(unlist(x))
  y <- as.matrix(y)
  d <- ncol(y)  # number of curves
  
  # trick: first design point is replicated to allow for computing 2nd derivative
  # (which really is just 1st deriv there)
  curves <- matrix(NA, length(x) + 1, d + 1)
  curves[, 1] <- c(x[1], x)
  curves[1, -1] <- curves[2, -1] <- y[1, ]
  
  for (i in seq_len(nrow(y))[-1]) {  # i = 2, ..., n; no iterations if n = 1
    # construct possible paths
    y_poss <- expand.grid(
      curves[i, -1], 
      y[i, ]
    ) 
    y_poss <- cbind(curves[i - 1, -1], as.matrix(y_poss))
    
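    # roughness of each path serves as weight for the assignment problem
    # (rows: existing curves, columns: candidate points at the new design point)
    roughness <- path_roughness(curves[i + seq(-1, 1), 1], y_poss)
    sol <- RcppHungarian::HungarianSolver(matrix(roughness, d, d))
    
    # extend each curve with the candidate point assigned to it
    curves[i + 1, 1 + sol$pairs[, 1]] <- matrix(y_poss[, 3], d, d)[sol$pairs]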
  }
  
  colnames(curves) <- c("x", paste0("y", 1:d))
  curves[-1, ]  # remove duplicated first design point
}
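
For two curves, the assignment problem solved in each iteration reduces to comparing two options: keep the labels or swap them. The sketch below spells this out in base R without `RcppHungarian`. `assign_two()` and the example costs are hypothetical; the cost matrix follows the layout used in `assign_curves()` (rows are existing curves, columns are candidate points).

```r
# Hypothetical brute-force solver for a 2 x 2 assignment problem.
# cost[j, k] is the roughness added when curve j continues with candidate k.
assign_two <- function(cost) {
  stopifnot(all(dim(cost) == c(2, 2)))
  keep <- cost[1, 1] + cost[2, 2]  # curve 1 -> candidate 1, curve 2 -> candidate 2
  swap <- cost[1, 2] + cost[2, 1]  # curve 1 -> candidate 2, curve 2 -> candidate 1
  if (keep <= swap) c(1, 2) else c(2, 1)
}

# Made-up costs: continuing each curve with the "other" candidate is cheaper
cost <- matrix(c(0.9, 0.1,
                 0.2, 0.8), nrow = 2, byrow = TRUE)
assignment <- assign_two(cost)  # candidate index chosen for each curve
```

With these costs the swap is cheaper (0.3 versus 1.7), so `assignment` is `c(2, 1)`. With more than two curves the number of possible labelings grows factorially, which is where a Hungarian solver becomes useful.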

Results


# convert to wide format
observed <- observed |> 
  group_by(x) |> 
  mutate(id = paste0("V", seq_along(y))) |> 
  pivot_wider(values_from = y, names_from = id) |> 
  ungroup()  # grouping is no longer needed after reshaping

# run algorithm and plot
res <- assign_curves(observed[, 1], observed[, -1])
res |>
  as_tibble() |>
  pivot_longer(-x) |>
  ggplot(aes(x, value, color = name)) +
  geom_point() +
  theme_minimal()
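
Because the curve labels are arbitrary, any numerical check against the true curves has to allow for a label swap. The following is a minimal sketch, assuming the true curve values are available as a two-column matrix. `label_invariant_error()` and the toy data are hypothetical and not part of the original analysis.

```r
# Hypothetical helper: mean absolute error between estimated and true curves,
# taking the better of the two possible labelings (two curves only).
label_invariant_error <- function(est, truth) {
  stopifnot(ncol(est) == 2, all(dim(est) == dim(truth)))
  direct  <- mean(abs(est - truth))
  swapped <- mean(abs(est[, 2:1] - truth))
  min(direct, swapped)
}

# Toy data: perfect recovery, but with the labels swapped
x <- seq(0, 1, length.out = 5)
truth <- cbind(sin(x), tanh(x) - 0.5)
est <- truth[, 2:1]
err <- label_invariant_error(est, truth)
```

In the setting above, one could presumably pass `res[, -1]` as `est` and `as.matrix(unobserved[, c("true_curve1", "true_curve2")])` as `truth`.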
