Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Last active May 30, 2024 15:12
Show Gist options
  • Save vankesteren/6c141f7cabcd3eb47292d78cfca1804d to your computer and use it in GitHub Desktop.
Permuting data to induce complex bivariate relations.
using StatsBase: sample, mean, cor
using LinearAlgebra: norm
using Plots, Random
"""
    permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number = 1e-3, max_iter::Int = 10_000, max_search::Int = 100, verbose::Bool = true, rng::AbstractRNG = Random.default_rng())

Permute the values of `y` in place so that `sum(rule.(x, y))` approximates `score`.

This is a greedy random search. Each iteration draws a random index `i` and computes the
change in score for swapping `y[i]` with every `y[j]`. It applies the best swap only if
that strictly reduces the loss `abs(score - sum(rule.(x, y)))`. `x` is never modified and
`y` is only ever permuted, so the values of both vectors (their marginals) are preserved.

# Arguments
- `x::Vector`: The vector of x values (not modified)
- `y::Vector`: The vector of y values (permuted in place)
- `rule::Function`: Function taking two numbers and outputting a single value
- `score::Real`: The target score, i.e., the target value of `sum(rule.(x, y))`

# Keywords
- `tol::Number`: The tolerance. If the loss `abs(current_score - score)` is below this value, stop the algorithm.
- `max_iter::Int`: Maximum number of iterations. For large datasets, this may need to be increased.
- `max_search::Int`: Stop after this many consecutive iterations without an improving swap.
- `verbose::Bool`: Whether to print debug information
- `rng::AbstractRNG`: Random number generator used to draw indices (defaults to the global RNG)

# Returns
`nothing`; the result is the permuted `y`.

# Throws
- `ArgumentError` if `x` and `y` have different lengths.
"""
function permutefun!(x::Vector, y::Vector, rule::Function, score::Real;
                     tol::Number=1e-3, max_iter::Int=10_000, max_search::Int=100,
                     verbose::Bool=true, rng::AbstractRNG=Random.default_rng())
    N = length(x)
    if N != length(y)
        throw(ArgumentError("x and y should be the same length!"))
    end
    # Nothing to permute; drawing from the empty range 1:0 would throw.
    N == 0 && return nothing
    # current objective value
    current_rule = rule.(x, y)
    current_score = sum(current_rule)
    current_loss = abs(score - current_score)
    # Already within tolerance: return without making any swaps.
    if current_loss < tol
        verbose && println("\nAchieved tolerance!")
        return nothing
    end
    iter::Int = 0
    search_iter::Int = 0
    while iter < max_iter
        # get random index
        i = rand(rng, 1:N)
        # delta_score[j] is the score change if y[i] and y[j] were swapped.
        # The new pairs are (x[j], y[i]) and (x[i], y[j]), and both old contributions
        # are removed. For j == i the delta is 0, so "no swap" is always a candidate.
        delta_score = rule.(x, y[i]) .+ rule.(x[i], y) .- current_rule .- current_rule[i]
        # best candidate partner j for this i
        new_loss, j = findmin(abs.(score .- (current_score .+ delta_score)))
        if new_loss < current_loss
            # Found option! make change
            y[i], y[j] = y[j], y[i]
            current_rule[i], current_rule[j] = rule(x[i], y[i]), rule(x[j], y[j])
            # Recomputing (rather than adding delta_score[j]) avoids accumulating
            # floating-point error over many swaps.
            current_score = sum(current_rule)
            current_loss = abs(score - current_score)
            if verbose
                println("Iter $iter | loss $current_loss | $i ↔ $j | score $current_score | search $search_iter")
            end
            search_iter = 0
        else
            # increment counter of consecutive non-improving iterations
            search_iter += 1
        end
        # stopping conditions
        if search_iter >= max_search
            verbose && println("\nNo improvement found after $(iter-max_search) iterations.")
            break
        end
        if current_loss < tol
            verbose && println("\nAchieved tolerance!")
            break
        end
        # increment iterations
        iter += 1
    end
    # Every `break` happens before `iter` is incremented, so `iter >= max_iter` here
    # means the iteration budget was exhausted.
    if verbose && iter >= max_iter
        println("\nReached max_iter ($max_iter) iterations without achieving tolerance.")
    end
    return nothing
end
"""
    permutefun!(x::Vector, y::Vector, rule::Function; kwargs...)

Permute `y` in place, using `length(x)` as the target value of `sum(rule.(x, y))`.

For a `Bool`-valued `rule`, `sum(rule.(x, y)) == length(x)` means the rule holds for
every pair `(x[k], y[k])`. All keyword arguments (`tol`, `max_iter`, `max_search`,
`verbose`, ...) are forwarded unchanged to the four-argument method. Defaults and
accepted keyword types therefore come from that method and cannot drift out of sync.
"""
function permutefun!(x::Vector, y::Vector, rule::Function; kwargs...)
    return permutefun!(x, y, rule, length(x); kwargs...)
end
"""
    marginplot(x, y, title)

Build a scatter plot of `y` against `x`, framed by marginal histograms. The histogram of
`x` sits above the scatter panel and the histogram of `y` sits to its right. Returns the
assembled plot object.

NOTE(review): `default(...)` changes Plots' session-wide defaults (grey fill and markers,
no legend). Plots created after this call are affected too, not just this one.
"""
function marginplot(x, y, title)
    # Subplot numbering follows the letters:
    #   1 = top-left (x histogram); top-right (`_`) is left empty;
    #   2 = large bottom-left scatter; 3 = bottom-right (y histogram).
    panel_layout = @layout [
        a _
        b{0.8w,0.8h} c
    ]
    # Shared limits keep each marginal histogram aligned with the scatter axes.
    xrange = extrema(x)
    yrange = extrema(y)
    default(fillcolor=:grey, markercolor=:grey, legend=false)
    fig = plot(layout=panel_layout, link=:none, size=(500, 500),
               margin=-10Plots.px, plot_title=title)
    scatter!(x, y, subplot=2, xlim=xrange, ylim=yrange)
    histogram!(x, nbins=30, subplot=1, orientation=:v, framestyle=:none,
               bottommargin=-20Plots.px, xlim=xrange)
    histogram!(y, nbins=30, subplot=3, orientation=:h, framestyle=:none,
               leftmargin=-40Plots.px, ylim=yrange)
    return fig
end
# Generate some data: x is uniform on [-0.5, 0.5). y concatenates two normal groups,
# centred at -0.25 (sd 1/6) and 0.35 (sd 1/8).
N = 300
Random.seed!(45)
x = rand(N) .- 0.5
# `N ÷ 2` and `N - N ÷ 2` keep length(y) == N even for odd N.
# (`Int(N / 2)` throws an InexactError when N is odd.)
y = vcat(randn(N ÷ 2) ./ 6 .- 0.25, randn(N - N ÷ 2) ./ 8 .+ 0.35)
p1 = marginplot(x, y, "Original data")
# Enforce complex constraint: permute y aiming for x⁴ < y² on every pair
x1, y1 = copy(x), copy(y)
permutefun!(x1, y1, (xi, yi) -> (xi^4 < yi^2))
p2 = marginplot(x1, y1, "x⁴ < y²")
# Induce a certain correlation. Permuting y leaves the means and the norms of the
# centred vectors unchanged, and
# cor(x, y) == sum((x .- xm) .* (y .- ym)) / (norm(x .- xm) * norm(y .- ym)).
# Targeting 0.7 * norm(x .- xm) * norm(y .- ym) therefore targets a correlation of 0.7.
x2, y2 = copy(x), copy(y)
xm, ym = mean(x), mean(y)
permutefun!(x2, y2, (xi, yi) -> (xi - xm) * (yi - ym), 0.7 * norm(x .- xm) * norm(y .- ym))
# A bare `cor(x2, y2)` is discarded when run as a script; print the achieved value.
@show cor(x2, y2)
p3 = marginplot(x2, y2, "Correlation = .7")
# make a hole in the data: aim for no point within distance 0.3 of the origin
x3, y3 = copy(x), copy(y)
permutefun!(x3, y3, (xi, yi) -> sqrt(xi^2 + yi^2) > 0.3)
p4 = marginplot(x3, y3, "Hole of radius 0.3")
# combine the four panels side by side and save the result
plot(p1, p2, p3, p4, size=(2000, 500), layout=(1, 4), margin=10Plots.px)
savefig("permutefun.pdf")
@vankesteren
Copy link
Author

And with 30000 samples:

permutefun30k

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment