Skip to content

Instantly share code, notes, and snippets.

@iwiwi
Created December 30, 2020 07:23
Show Gist options
  • Save iwiwi/10fb477eceaff0d36cdacf9a268db780 to your computer and use it in GitHub Desktop.
Save iwiwi/10fb477eceaff0d36cdacf9a268db780 to your computer and use it in GitHub Desktop.
use crate::*;
/// Standard normal distribution helpers.
///
/// `ppf` is a port of the cephes `ndtri` routine as vendored by SciPy v1.5.4
/// (<https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c>).
mod normal_distribution {
    // The coefficients below are copied verbatim from cephes and carry more
    // digits than an f64 can represent; they are kept exactly as published so
    // they can be compared against the C source.
    #![allow(clippy::excessive_precision)]

    /// sqrt(2 * pi).
    const S2PI: f64 = 2.50662827463100050242E0;

    /// exp(-2). Probabilities in `(EXP_NEG2, 1 - EXP_NEG2]` use the central
    /// approximation; everything closer to 0 or 1 uses a tail approximation.
    const EXP_NEG2: f64 = 0.13533528323661269189;

    // Central region: rational approximation in (y - 1/2)^2 (P0 / Q0).
    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c
    const P0: [f64; 5] = [
        -5.99633501014107895267E1,
        9.80010754185999661536E1,
        -5.66762857469070293439E1,
        1.39312609387279679503E1,
        -1.23916583867381258016E0,
    ];
    // Leading coefficient 1.0 is implicit (evaluated with `p1evl`).
    const Q0: [f64; 8] = [
        1.95448858338141759834E0,
        4.67627912898881538453E0,
        8.63602421390890590575E1,
        -2.25462687854119370527E2,
        2.00260212380060660359E2,
        -8.20372256168333339912E1,
        1.59056225126211695515E1,
        -1.18331621121330003142E0,
    ];

    // Tail approximation for x = sqrt(-2 ln y) in [2, 8),
    // i.e. exp(-32) < y <= exp(-2) (P1 / Q1).
    const P1: [f64; 9] = [
        4.05544892305962419923E0,
        3.15251094599893866154E1,
        5.71628192246421288162E1,
        4.40805073893200834700E1,
        1.46849561928858024014E1,
        2.18663306850790267539E0,
        -1.40256079171354495875E-1,
        -3.50424626827848203418E-2,
        -8.57456785154685413611E-4,
    ];
    // Leading coefficient 1.0 is implicit (evaluated with `p1evl`).
    const Q1: [f64; 8] = [
        1.57799883256466749731E1,
        4.53907635128879210584E1,
        4.13172038254672030440E1,
        1.50425385692907503408E1,
        2.50464946208309415979E0,
        -1.42182922854787788574E-1,
        -3.80806407691578277194E-2,
        -9.33259480895457427372E-4,
    ];

    // Far tail approximation for x = sqrt(-2 ln y) >= 8,
    // i.e. y <= exp(-32) (P2 / Q2).
    const P2: [f64; 9] = [
        3.23774891776946035970E0,
        6.91522889068984211695E0,
        3.93881025292474443415E0,
        1.33303460815807542389E0,
        2.01485389549179081538E-1,
        1.23716634817820021358E-2,
        3.01581553508235416007E-4,
        2.65806974686737550832E-6,
        6.23974539184983293730E-9,
    ];
    // Leading coefficient 1.0 is implicit (evaluated with `p1evl`).
    const Q2: [f64; 8] = [
        6.02427039364742014255E0,
        3.67983563856160859403E0,
        1.37702099489081330271E0,
        2.16236993594496635890E-1,
        1.34204006088543189037E-2,
        3.28014464682127739104E-4,
        2.89247864745380683936E-6,
        6.79019408009981274425E-9,
    ];

    /// Evaluates the polynomial `coef[0] * x^(n-1) + ... + coef[n-1]`
    /// using Horner's scheme.
    ///
    /// See <https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/polevl.h#L67>.
    fn polevl(x: f64, coef: &[f64]) -> f64 {
        coef.iter().fold(0.0, |acc, &c| acc * x + c)
    }

    /// Like `polevl`, but with an implicit leading coefficient of 1.0:
    /// `x^n + coef[0] * x^(n-1) + ... + coef[n-1]`.
    ///
    /// See <https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/polevl.h#L90>.
    fn p1evl(x: f64, coef: &[f64]) -> f64 {
        coef.iter().fold(1.0, |acc, &c| acc * x + c)
    }

    /// Inverse CDF (percent-point function) of the standard normal distribution.
    ///
    /// Returns the `x` for which the area under the standard normal density
    /// to the left of `x` equals `y0`. `ppf(0.0)` is `-inf` and `ppf(1.0)` is
    /// `+inf`, matching cephes `ndtri`.
    ///
    /// See <https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c#L134>.
    ///
    /// # Panics
    ///
    /// Panics if `y0` is NaN or lies outside `[0, 1]`.
    pub fn ppf(y0: f64) -> f64 {
        assert!(
            (0.0..=1.0).contains(&y0),
            "ppf: probability must lie in [0, 1], got {}",
            y0
        );
        // The distribution's limits. Without these the tail formula evaluates
        // sqrt(-2 ln 0) = inf and then inf - inf/inf = NaN.
        if y0 == 0.0 {
            return f64::NEG_INFINITY;
        }
        if y0 == 1.0 {
            return f64::INFINITY;
        }

        // Work with the smaller tail probability. `negate` (cephes' `code`)
        // records that the answer lies in the lower half of the distribution.
        let (y, negate) = if y0 > 1.0 - EXP_NEG2 {
            (1.0 - y0, false)
        } else {
            (y0, true)
        };

        if y > EXP_NEG2 {
            // Central region; the sign comes out of (y - 1/2) directly.
            let y = y - 0.5;
            let y2 = y * y;
            let x = y + y * (y2 * polevl(y2, &P0) / p1evl(y2, &Q0));
            return x * S2PI;
        }

        let x = (-2.0 * y.ln()).sqrt();
        let x0 = x - x.ln() / x;
        let z = 1.0 / x;
        // x < 8 corresponds to y > exp(-32).
        let (p, q) = if x < 8.0 {
            (&P1[..], &Q1[..])
        } else {
            (&P2[..], &Q2[..])
        };
        let x1 = z * polevl(z, p) / p1evl(z, q);
        let x = x0 - x1;
        if negate {
            -x
        } else {
            x
        }
    }
}
/// Margin used to keep interpolated probabilities away from exactly 0 and 1
/// before they are fed to the normal inverse CDF, so the output stays finite.
/// Same value as scikit-learn's `BOUNDS_THRESHOLD` in
/// `sklearn/preprocessing/_data.py`.
const BOUNDS_THRESHOLD: f64 = 1.0e-7;
/// A fitted quantile transformer that maps each feature onto the standard
/// normal distribution. Built from a scikit-learn dump by
/// `QuantileTransformer::from_dump`.
#[derive(Debug, Clone)]
pub struct QuantileTransformer {
/// Probability level of each quantile, shared by all features
/// (scikit-learn's `references_`).
references: Vec<f64>,
/// Quantile values stored feature-major: `quantiles[feature][k]` is paired
/// with `references[k]` (the transpose of scikit-learn's `quantiles_`).
quantiles: Vec<Vec<f64>>,
}
/// Maps a single feature value onto the standard normal distribution.
///
/// The value is first turned into a probability by piecewise-linear
/// interpolation of `references` over `quantiles`. Values at or within
/// `BOUNDS_THRESHOLD` of the first quantile map to 0; values at or within
/// `BOUNDS_THRESHOLD` of the last quantile map to 1. The probability is then
/// clipped away from 0 and 1 and pushed through the standard normal inverse
/// CDF.
///
/// This mirrors the forward, `output_distribution="normal"` path of
/// scikit-learn's `QuantileTransformer._transform_col` for strictly
/// increasing quantiles. A NaN input is returned as NaN, as scikit-learn does.
///
/// `quantiles` is expected to be sorted in increasing order and to have the
/// same length as `references`.
///
/// # Panics
///
/// Panics if `quantiles` is empty, or if `references` is shorter than
/// `quantiles`.
fn transform_col(x: f64, quantiles: &[f64], references: &[f64]) -> f64 {
    // scikit-learn passes missing values through untouched. Without this, a
    // NaN would fail every comparison below and trip the assert inside `ppf`.
    if x.is_nan() {
        return f64::NAN;
    }
    let first = quantiles[0];
    let last = quantiles[quantiles.len() - 1];
    // As in scikit-learn, values within BOUNDS_THRESHOLD of either end are
    // snapped onto that end. The plain `<=` / `>=` tests also cover large |x|,
    // where adding or subtracting the threshold is lost to rounding.
    let y = if x <= first || x - BOUNDS_THRESHOLD < first {
        0.0
    } else if x >= last || x + BOUNDS_THRESHOLD > last {
        1.0
    } else {
        // Binary search for the adjacent pair with
        // quantiles[lo] < x <= quantiles[hi]. The bounds tests above ensure
        // this invariant already holds for the initial lo and hi.
        let mut lo = 0;
        let mut hi = quantiles.len() - 1;
        while hi - lo > 1 {
            let mid = lo + (hi - lo) / 2;
            if quantiles[mid] < x {
                lo = mid;
            } else {
                hi = mid;
            }
        }
        debug_assert!(quantiles[lo] < x && x <= quantiles[hi]);
        // Linear interpolation between the two neighbouring references.
        let d_lo = x - quantiles[lo];
        let d_hi = quantiles[hi] - x;
        let span = d_lo + d_hi;
        references[lo] * (d_hi / span) + references[hi] * (d_lo / span)
    };
    // Clip away from 0 and 1 so the normal quantile stays finite. These are
    // scikit-learn's clip points (`np.spacing(1)` equals `f64::EPSILON`).
    let y = y.clamp(
        BOUNDS_THRESHOLD - f64::EPSILON,
        1.0 - (BOUNDS_THRESHOLD - f64::EPSILON),
    );
    normal_distribution::ppf(y)
}
/// The subset of a scikit-learn `QuantileTransformer` JSON dump that
/// `QuantileTransformer::from_dump` reads. Field names must match the JSON
/// keys; any other keys in the dump are ignored during deserialization.
#[derive(serde::Deserialize)]
struct Dump {
/// Target distribution; `from_dump` only accepts `"normal"`.
output_distribution: String,
/// Probability level of each quantile.
references_: Vec<f64>,
/// Quantile values laid out `[quantile][feature]`, as scikit-learn stores them.
quantiles_: Vec<Vec<f64>>,
}
impl QuantileTransformer {
    /// Reconstructs a transformer from the JSON dump of a fitted
    /// scikit-learn `QuantileTransformer`.
    ///
    /// The dump must provide `output_distribution`, `references_` and
    /// `quantiles_` (laid out `[quantile][feature]`, as scikit-learn stores
    /// it). Any other keys are ignored.
    ///
    /// # Errors
    ///
    /// Returns an error if `dump` cannot be deserialized into those fields.
    ///
    /// # Panics
    ///
    /// Panics if:
    /// - `output_distribution` is not `"normal"`;
    /// - `references_` is empty;
    /// - `quantiles_` does not have one equally long row per reference;
    /// - any feature's quantiles are not strictly increasing. Repeated
    ///   quantiles would need scikit-learn's two-sided interpolation
    ///   averaging, which is not implemented.
    pub fn from_dump(dump: serde_json::Value) -> R<QuantileTransformer> {
        let dump: Dump = serde_json::from_value(dump)?;
        assert_eq!(
            dump.output_distribution, "normal",
            "only output_distribution=\"normal\" is supported"
        );

        let n_references = dump.references_.len();
        assert!(n_references > 0, "references_ must not be empty");
        assert_eq!(
            dump.quantiles_.len(),
            n_references,
            "quantiles_ must have one row per reference"
        );
        let n_features = dump.quantiles_[0].len();
        assert!(
            dump.quantiles_.iter().all(|row| row.len() == n_features),
            "every row of quantiles_ must have the same number of features"
        );

        // scikit-learn keeps quantiles_ as [quantile][feature]. Transpose to
        // [feature][quantile] so each feature's quantiles form one contiguous
        // slice for transform_col.
        let quantiles: Vec<Vec<f64>> = (0..n_features)
            .map(|feature| dump.quantiles_.iter().map(|row| row[feature]).collect())
            .collect();

        // Strict increase covers both the uniqueness requirement and the
        // sorted order the binary search in transform_col relies on. It also
        // rejects NaN quantiles.
        for (feature, qs) in quantiles.iter().enumerate() {
            assert!(
                qs.windows(2).all(|pair| pair[0] < pair[1]),
                "quantiles of feature {} must be strictly increasing",
                feature
            );
        }

        Ok(QuantileTransformer {
            references: dump.references_,
            quantiles,
        })
    }

    /// Transforms one sample (one value per feature) onto the standard
    /// normal scale, feature by feature.
    ///
    /// # Panics
    ///
    /// Panics if `x.len()` differs from the number of features in the dump.
    pub fn transform(&self, x: &[f64]) -> Vec<f64> {
        assert_eq!(
            x.len(),
            self.quantiles.len(),
            "expected one value per feature"
        );
        x.iter()
            .zip(&self.quantiles)
            .map(|(&value, quantiles)| transform_col(value, quantiles, &self.references))
            .collect()
    }
}
// Tests compare QuantileTransformer::transform against expected outputs for
// one fixed fitted model.
#[cfg(test)]
mod tests {
use super::*;
// Builds a transformer from a hard-coded dump of a fitted scikit-learn
// QuantileTransformer (10 quantiles, 3 features, normal output). Presumably
// produced by serializing the Python object's attributes -- TODO confirm.
// Keys other than output_distribution / references_ / quantiles_ are
// ignored by from_dump.
fn create() -> QuantileTransformer {
let j = serde_json::json!({
"n_quantiles": 10,
"output_distribution": "normal",
"ignore_implicit_zeros": false,
"subsample": 100000,
"random_state": null,
"copy": true,
"n_features_in_": 3,
"n_quantiles_": 10,
"references_": [
0.0,
0.1111111111111111,
0.2222222222222222,
0.3333333333333333,
0.4444444444444444,
0.5555555555555556,
0.6666666666666666,
0.7777777777777777,
0.8888888888888888,
1.0,
],
"quantiles_": [
[-2.613328303719778, -9.321205290168217, 0.13674275875021852],
[-1.1414101426254262, -3.294752914303114, 0.31128319201392723],
[-0.7306000262128471, -1.919245430597138, 0.5628969761610223],
[-0.427117826339735, -0.7238972474428322, 0.7805069173920551],
[-0.05639286493061981, 0.885189368484695, 0.9814794091618718],
[0.1722755897532091, 1.991170360919146, 1.2760902770308562],
[0.39748654705760395, 3.8665254431411706, 1.6340155327513783],
[0.67348557453346, 5.3665644352138, 1.959333216365857],
[0.9437791586194102, 7.205771950040178, 2.8721836987270755],
[2.4620269142769113, 12.26632790171583, 17.47969633690788],
],
});
QuantileTransformer::from_dump(j).unwrap()
}
// Asserts that transforming `x` yields `y` element-wise, within the
// tolerance of the crate's assert_approx_eq! macro.
fn check(qt: &QuantileTransformer, x: &[f64], y: &[f64]) {
let z = qt.transform(x);
assert_eq!(y.len(), z.len());
for (a, b) in y.iter().zip(z.iter()) {
assert_approx_eq!(a, b);
}
}
// Each input row is exactly one row of quantiles_, so every feature in a row
// should map to the same value: the normal quantile of that row's reference.
// The first and last rows land on the clipped end values (about -5.1993 and
// +5.1993).
#[test]
fn test_references() {
let qt = create();
let cases = &[
(
[-2.613328303719778, -9.321205290168217, 0.13674275875021852],
[-5.199337582605575, -5.199337582605575, -5.199337582605575],
),
(
[-1.1414101426254262, -3.294752914303114, 0.31128319201392723],
[-1.22064034884735, -1.22064034884735, -1.22064034884735],
),
(
[-0.7306000262128471, -1.919245430597138, 0.5628969761610223],
[-0.764709673786387, -0.764709673786387, -0.764709673786387],
),
(
[-0.427117826339735, -0.7238972474428322, 0.7805069173920551],
[-0.430727299295457, -0.430727299295457, -0.430727299295457],
),
(
[-0.05639286493061981, 0.885189368484695, 0.9814794091618718],
[-0.139710298881862, -0.139710298881862, -0.139710298881862],
),
(
[0.1722755897532091, 1.991170360919146, 1.2760902770308562],
[0.1397102988818621, 0.1397102988818621, 0.1397102988818621],
),
(
[0.39748654705760395, 3.8665254431411706, 1.6340155327513783],
[0.4307272992954574, 0.4307272992954574, 0.4307272992954574],
),
(
[0.67348557453346, 5.3665644352138, 1.959333216365857],
[0.7647096737863867, 0.7647096737863867, 0.7647096737863867],
),
(
[0.9437791586194102, 7.205771950040178, 2.8721836987270755],
[1.2206403488473496, 1.2206403488473496, 1.2206403488473496],
),
(
[2.4620269142769113, 12.26632790171583, 17.47969633690788],
[5.19933758270342, 5.19933758270342, 5.19933758270342],
),
];
for case in cases {
check(&qt, &case.0, &case.1);
}
}
// Arbitrary inputs, many of them outside the fitted quantile range (those
// saturate at the clipped end values). The expected outputs were presumably
// generated with scikit-learn's transform -- TODO confirm.
#[test]
fn test_random() {
let qt = create();
let cases = &[
(
[9.543169032696227, 13.004265458495848, 1.2817188786196336],
[5.19933758270342, 5.19933758270342, 0.14413444750289997],
),
(
[6.148524766915234, 12.277375912423363, -2.0032741621514276],
[5.19933758270342, 5.19933758270342, -5.199337582605575],
),
(
[3.2479687586644044, 6.333366933638192, -0.18854883910825748],
[5.19933758270342, 0.9788976648571445, -5.199337582605575],
),
(
[12.026790240804463, -8.17807192753417, 7.139275629961631],
[5.19933758270342, -2.0320120702704356, 1.414185316101896],
),
(
[11.297395996011659, 12.389898901401988, 4.3528110754416325],
[5.19933758270342, 5.19933758270342, 1.2824135105950227],
),
(
[-5.548440019781655, 3.9141903012544503, 8.462175829374196],
[-5.199337582605575, 0.44045804894271895, 1.4863658158773896],
),
(
[0.48259293083427224, 1.8454399739866432, -8.433798779937739],
[0.5270731855886761, 0.10273894814593508, -5.199337582605575],
),
(
[2.9464268070547117, 12.73414827971353, 5.139387737583421],
[5.19933758270342, 5.19933758270342, 1.3173195432982814],
),
(
[-3.4636253585345047, 13.086601949163505, 2.851517222764288],
[-5.199337582605575, 5.19933758270342, 1.207464726166755],
),
(
[-6.081831945801352, -6.336175884405989, -7.420982125992383],
[-5.199337582605575, -1.5978724353638714, -5.199337582605575],
),
(
[-6.275065630029899, 7.586341240789082, 1.7097115253806585],
[-5.199337582605575, 1.266007404810665, 0.5030071199147996],
),
(
[-0.13038122929708784, -4.161259056492385, 5.808837900045351],
[-0.1960917782971291, -1.3097800036623548, 1.3483455214298523],
),
(
[7.618117371577743, 5.909956441138961, 10.508511337022796],
[5.19933758270342, 0.8801295864809635, 1.6161969123723285],
),
(
[8.693866471169883, -0.7330487877831544, 9.163553125519751],
[5.19933758270342, -0.4330680381712861, 1.5280004611048847],
),
(
[0.4282314588607754, -1.1967112585064736, -6.889187183469036],
[0.46502685937014465, -0.5551855155241495, -5.199337582605575],
),
(
[7.685794675525781, 4.500672464040191, 14.250996390962012],
[5.19933758270342, 0.5640480936891384, 1.967567697630419],
),
(
[1.5798977254115556, -9.611474340719736, -6.600078504292563],
[1.5176006671813367, -5.199337582605575, -5.199337582605575],
),
(
[0.17534066429800532, -3.5008980831264935, 12.394201107918992],
[0.1435390280777139, -1.2409594504784927, 1.7661839743843688],
),
(
[3.5825769280006945, -1.551988325695696, -6.6376732992024206],
[5.19933758270342, -0.6546087494691977, -5.199337582605575],
),
(
[-4.10776763251768, -9.18756794697149, 9.369452587399543],
[-5.199337582605575, -2.811715563246142, 1.5407399020190642],
),
];
for case in cases {
check(&qt, &case.0, &case.1);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment