Skip to content

Instantly share code, notes, and snippets.

@vankesteren
Last active May 30, 2024 15:12
Show Gist options
  • Save vankesteren/6c141f7cabcd3eb47292d78cfca1804d to your computer and use it in GitHub Desktop.
Save vankesteren/6c141f7cabcd3eb47292d78cfca1804d to your computer and use it in GitHub Desktop.
Permuting data to induce complex bivariate relations.
using StatsBase: sample, mean, cor
using LinearAlgebra: norm
using Plots, Random
"""
    permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number=1e-3, max_iter::Int=10_000, max_search::Int=100, verbose::Bool=true, rng::AbstractRNG=Random.default_rng())
    permutefun!(x::Vector, y::Vector, rule::Function; tol, max_iter, max_search, verbose)

Permute the values of `y` in place so that `sum(rule.(x, y))` approaches the target `score`.

Each iteration draws a random index `i` and evaluates how the score would change if `y[i]`
were swapped with each `y[j]`. It then performs the swap that brings the score closest to
`score`, but only if that strictly reduces the loss `abs(score - sum(rule.(x, y)))`.
Only the order of `y` changes, so the sets of `x` and `y` values (and hence their marginal
distributions) are preserved. `x` is never modified.

When `score` is omitted, the target is `length(x)`. For a `Bool`-valued `rule` this asks for
`rule(x[k], y[k])` to hold for every `k`.

# Arguments
- `x::Vector`: the x values (not modified).
- `y::Vector`: the y values; permuted in place. Must have the same length as `x`.
- `rule::Function`: function taking one x value and one y value and returning a number
  (or `Bool`).
- `score::Real`: the target value of `sum(rule.(x, y))`.

# Keywords
- `tol::Number=1e-3`: stop once the loss `abs(score - sum(rule.(x, y)))` is below this value.
- `max_iter::Int=10_000`: maximum number of iterations. For large datasets this may need to
  be increased.
- `max_search::Int=100`: stop after this many consecutive iterations without improvement.
- `verbose::Bool=true`: print each accepted swap and the reason the search stopped.
- `rng::AbstractRNG=Random.default_rng()`: random number generator used to draw indices.

# Returns
`nothing`; the result is the permuted `y`.

# Throws
- `ArgumentError` if `x` and `y` have different lengths.

Each iteration evaluates `rule` `2 * length(x)` times.
"""
function permutefun!(x::Vector, y::Vector, rule::Function, score::Real; tol::Number=1e-3,
                     max_iter::Int=10_000, max_search::Int=100, verbose::Bool=true,
                     rng::Random.AbstractRNG=Random.default_rng())
    N = length(x)
    if N != length(y)
        throw(ArgumentError("x and y should be the same length!"))
    end
    # With no elements there is nothing to permute (and `rand(rng, 1:0)` would throw).
    N == 0 && return nothing

    # current_rule[k] caches rule(x[k], y[k]); an accepted swap only changes two entries.
    current_rule = rule.(x, y)
    current_score = sum(current_rule)
    current_loss = abs(score - current_score)
    iter = 0
    search_iter = 0  # consecutive iterations without improvement
    while iter < max_iter
        # Uniformly random position whose value is a swap candidate.
        i = rand(rng, 1:N)
        # delta_score[j] is the change in score if y[i] and y[j] were swapped: the new pairs
        # (x[j], y[i]) and (x[i], y[j]) replace the current pairs j and i. It is 0 for j == i.
        delta_score = rule.(x, y[i]) .+ rule.(x[i], y) .- current_rule .- current_rule[i]
        # Best partner for i: the j whose resulting score lies closest to the target.
        new_loss, j = findmin(abs.(score .- (current_score .+ delta_score)))
        if new_loss < current_loss
            # Accept the swap and refresh the two affected cache entries.
            y[i], y[j] = y[j], y[i]
            current_rule[i], current_rule[j] = rule(x[i], y[i]), rule(x[j], y[j])
            # Re-sum the cache instead of adding delta_score[j], so floating-point error
            # does not accumulate over many swaps.
            current_score = sum(current_rule)
            current_loss = abs(score - current_score)
            if verbose
                println("Iter $iter | loss $current_loss | $i ↔ $j | score $current_score | search $search_iter")
            end
            search_iter = 0
        else
            search_iter += 1
        end
        # Stopping conditions.
        if search_iter >= max_search
            if verbose
                println("\nNo improvement found in the last $search_iter iterations (stopped at iteration $iter).")
            end
            return nothing
        end
        if current_loss < tol
            if verbose
                println("\nAchieved tolerance!")
            end
            return nothing
        end
        iter += 1
    end
    # Only reached when the iteration budget is exhausted without meeting either criterion.
    if verbose
        println("\nReached the maximum number of iterations ($max_iter).")
    end
    return nothing
end
# Convenience method: use `length(x)` as the target score. For a `Bool`-valued `rule`
# this asks for `rule(x[k], y[k])` to hold for every k. Keyword arguments are forwarded
# unchanged, so their names, defaults and types are exactly those of the method that
# takes an explicit `score`.
function permutefun!(x::Vector, y::Vector, rule::Function; kwargs...)
    return permutefun!(x, y, rule, length(x); kwargs...)
end
"""
    marginplot(x, y, title; nbins=30)

Return a scatter plot of `y` against `x` with marginal histograms of `x` (top) and `y`
(right), titled `title`. The grey styling and hidden legend are applied to this plot only;
global Plots defaults are left untouched.
"""
function marginplot(x, y, title; nbins::Integer=30)
    # 2×2 grid: x-histogram top-left, empty cell top-right, scatter (80% of width and
    # height) bottom-left, y-histogram bottom-right.
    layout = @layout [
        a _
        b{0.8w,0.8h} c
    ]
    # Share limits between the scatter and its marginal histograms so the axes line up.
    xlim = extrema(x)
    ylim = extrema(y)
    plt = plot(layout=layout, link=:none, size=(500, 500), margin=-10Plots.px,
               plot_title=title, legend=false)
    scatter!(plt, x, y, subplot=2, xlim=xlim, ylim=ylim, markercolor=:grey)
    histogram!(plt, x, nbins=nbins, subplot=1, orientation=:v, framestyle=:none,
               bottommargin=-20Plots.px, xlim=xlim, fillcolor=:grey)
    histogram!(plt, y, nbins=nbins, subplot=3, orientation=:h, framestyle=:none,
               leftmargin=-40Plots.px, ylim=ylim, fillcolor=:grey)
    return plt
end
# Generate some data: x uniform on [-0.5, 0.5); y a mixture of two normal clusters
# (sd 1/6 centred at -0.25 and sd 1/8 centred at 0.35).
N = 300
Random.seed!(45)
x = rand(N) .- 0.5
# Split with ÷ so odd N also works (`Int(N / 2)` throws InexactError for odd N).
n_first = N ÷ 2
y = vcat(randn(n_first) ./ 6 .- 0.25, randn(N - n_first) ./ 8 .+ 0.35)
p1 = marginplot(x, y, "Original data")

# Enforce complex constraint: x⁴ < y² for every pair.
x1, y1 = copy(x), copy(y)
permutefun!(x1, y1, (xi, yi) -> (xi^4 < yi^2))
p2 = marginplot(x1, y1, "x⁴ < y²")

# Induce a correlation of 0.7. Pearson's r is sum((x - x̄)(y - ȳ)) / (‖x - x̄‖ ‖y - ȳ‖),
# and permuting y changes neither its mean nor its norm, so target 0.7 times the norms.
x2, y2 = copy(x), copy(y)
xm, ym = mean(x), mean(y)
# Bind the means locally so the closure does not read untyped globals on every call.
cov_rule = let xm = xm, ym = ym
    (xi, yi) -> (xi - xm) * (yi - ym)
end
permutefun!(x2, y2, cov_rule, 0.7 * norm(x .- xm) * norm(y .- ym))
@show cor(x2, y2)
p3 = marginplot(x2, y2, "Correlation = .7")

# Make a hole in the data: no point within distance 0.3 of the origin.
x3, y3 = copy(x), copy(y)
permutefun!(x3, y3, (xi, yi) -> sqrt(xi^2 + yi^2) > 0.3)
p4 = marginplot(x3, y3, "Hole of radius 0.3")

plot(p1, p2, p3, p4, size=(2000, 500), layout=(1, 4), margin=10Plots.px)
savefig("permutefun.pdf")
@vankesteren
Copy link
Author

vankesteren commented May 30, 2024

permutefun

Note that the marginal distributions stay the same while the functions are being applied, because the algorithm only permutes the y values; it never changes them.

@vankesteren
Copy link
Author

And with 30000 samples:

permutefun30k

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment