Skip to content

Instantly share code, notes, and snippets.

@EssamWisam
Last active March 31, 2023 12:20
Show Gist options
  • Save EssamWisam/7598b862a8ee3a4876e28ef0612253c5 to your computer and use it in GitHub Desktop.
Save EssamWisam/7598b862a8ee3a4876e28ef0612253c5 to your computer and use it in GitHub Desktop.
Updated variable names
using MLUtils
using CategoricalArrays
using StatsPlots
using DataFrames
"""
    generate_imbalanced_data(num_rows, num_features, majority_ratio)

Generate random imbalanced data with the specified number of rows and features, and an approximate majority class ratio.

# Arguments
- `num_rows`: The number of rows in the data.
- `num_features`: The number of columns in the data.
- `majority_ratio`: The probability that a row of the target variable `y` is assigned to the majority class `A`. The remaining rows are assigned to `B` or `C` with equal probability.

# Returns
- `X`: A DataFrame of random integers drawn from `1:num_rows`, with `num_rows` rows and `num_features` columns.
- `y`: A CategoricalArray containing the target variable, in which the majority class makes up approximately `majority_ratio` of the rows (labels are drawn at random, so the proportion is not exact).
"""
function generate_imbalanced_data(num_rows, num_features, majority_ratio)
    # Features: a num_rows × num_features table of random integers drawn from 1:num_rows,
    # with auto-generated column names.
    X = DataFrame(rand(1:num_rows, num_rows, num_features), :auto)
    # Labels: each row is "A" with probability `majority_ratio`; otherwise a second draw
    # picks "B" or "C" with equal probability. The class proportions are therefore only
    # approximately `majority_ratio`, so no exact (integer) majority size is computed --
    # doing so with `Int(majority_ratio * num_rows)` throws an InexactError whenever the
    # product is not a whole number.
    y = CategoricalArray([
        rand() < majority_ratio ? "A" : (rand() > 0.5 ? "B" : "C") for _ in 1:num_rows
    ])
    return X, y
end
"""
    plot_data(y_before, y_after, X_before, X_after)

Plot the label frequencies (bar charts) and the first two features (scatter plots) of a dataset before and after oversampling, arranged in a 2x2 grid.

# Arguments
- `y_before`: The target variable of the data before oversampling.
- `y_after`: The target variable of the data after oversampling.
- `X_before`: The data before oversampling.
- `X_after`: The data after oversampling.
"""
function plot_data(y_before, y_after, X_before, X_after)
    # Pass every input through `getobs` before inspecting or indexing it
    y_before = getobs(y_before)
    y_after = getobs(y_after)
    X_before = getobs(X_before)
    X_after = getobs(X_after)

    # Frequency bars: the label set is taken from the *original* target only
    classes = unique(y_before)
    tally(targets) = [count(==(class), targets) for class in classes]
    freq_before = bar(
        classes,
        tally(y_before);
        xlabel="Label",
        ylabel="Count",
        title="\nBefore Oversampling",
        legend=false,
    )
    freq_after = bar(
        classes,
        tally(y_after);
        xlabel="Label",
        ylabel="Count",
        title="\nAfter Oversampling",
        legend=false,
    )

    # Scatter of the first two columns, with the row count shown in the title
    rows_before = first(size(X_before))
    rows_after = first(size(X_after))
    cloud_before = scatter(
        X_before[:, 1],
        X_before[:, 2];
        xlabel="X1",
        ylabel="X2",
        title="Before Oversampling with size $(rows_before)",
        legend=false,
    )
    cloud_after = scatter(
        X_after[:, 1],
        X_after[:, 2];
        xlabel="X1",
        ylabel="X2",
        title="After Oversampling with size $(rows_after)",
        legend=false,
    )

    # Combine the four panels into one 2x2 figure
    return plot(
        freq_before, freq_after, cloud_before, cloud_after; layout=(2, 2), size=(900, 900)
    )
end
"""
    naive_random_oversampling(X, y)

Oversample the minority classes in a dataset using the Naive Random Oversampling method.
This randomly resamples (with replacement) the observations of each minority class until every class is equal in size to the majority class.

# Arguments
- `X`: The data to oversample, which implements the MLUtils.jl data container interface.
- `y`: An abstract vector for the target variable. This will be oversampled synchronously with `X`.

# Returns
- `Xover`: A view (`obsview`) of the oversampled data.
- `yover`: A view (`obsview`) of the oversampled target variable.
"""
function naive_random_oversampling(X, y)
    # Number of observations per label
    labels_counts = group_counts(y)
    isempty(labels_counts) &&
        throw(ArgumentError("`y` must contain at least one observation"))
    # Every class is grown to the size of the largest one
    majority_count = maximum(values(labels_counts))
    # Mapping from (X, y) to (Xover, yover) expressed as observation indices.
    # A concretely typed vector grown in place avoids both `Any`-typed indices and the
    # repeated full copy that rebuilding the vector by concatenation would incur.
    all_indices = Int[]
    sizehint!(all_indices, majority_count * length(labels_counts))
    for (label, label_count) in labels_counts
        # How many extra observations this label needs (zero for the majority class)
        num_new_samples = majority_count - label_count
        # Positions i where y[i] is this label. `isequal` is used on the assumption that
        # `group_counts` keys its result by `isequal`/`hash`, so that labels such as NaN
        # or `missing` are found (`==` would miss them) -- TODO confirm against MLUtils.
        label_indices = findall(isequal(label), y)
        # Draw the extra observations at random (with replacement) from those positions
        new_label_indices = randobs(label_indices, num_new_samples)
        # Resampled observations first, then the original ones
        append!(all_indices, new_label_indices)
        append!(all_indices, label_indices)
    end
    # Lazy views: X and y are indexed with the same indices so they stay aligned
    yover = obsview(y, all_indices)
    Xover = obsview(X, all_indices)
    return Xover, yover
end
# Test I: oversample a synthetic DataFrame / CategoricalArray pair and plot the result
X, y = generate_imbalanced_data(1000, 10, 0.8)
Xover, yover = naive_random_oversampling(X, y)
plot_data(y, yover, X, Xover)
# Test II: oversample a column table (NamedTuple of vectors) with Boolean labels
X = (x1=[1, 2, 3, 4], x2=[5, 6, 7, 8])
# `Tables` is not expected to be in scope from this file's `using` lines (none of them is
# known to export it), so it is reached through DataFrames, which imports it.
@assert DataFrames.Tables.istable(X)
y = [true, false, true, true]
Xover, yover = naive_random_oversampling(X, y)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment