Skip to content

Instantly share code, notes, and snippets.

@iwiwi
Created December 30, 2020 07:23
Embed
What would you like to do?
use crate::*;
/// Inverse CDF (percent-point function) of the standard normal distribution.
///
/// This is a port of the Cephes `ndtri` routine as shipped with SciPy 1.5.4;
/// the coefficient tables below are copied from that source.
mod normal_distribution {
    /// sqrt(2 * pi).
    const S2PI: f64 = 2.50662827463100050242E0;

    /// exp(-2): probabilities closer than this to 0 or 1 are handled by the
    /// tail approximations, everything else by the central approximation.
    const EXP_NEG_2: f64 = 0.13533528323661269189;

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c
    /// Numerator coefficients of the central-region approximation.
    const P0: [f64; 5] = [
        -5.99633501014107895267E1,
        9.80010754185999661536E1,
        -5.66762857469070293439E1,
        1.39312609387279679503E1,
        -1.23916583867381258016E0,
    ];
    /// Denominator coefficients of the central-region approximation
    /// (leading coefficient 1 is implicit, see `p1evl`).
    const Q0: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        1.95448858338141759834E0,
        4.67627912898881538453E0,
        8.63602421390890590575E1,
        -2.25462687854119370527E2,
        2.00260212380060660359E2,
        -8.20372256168333339912E1,
        1.59056225126211695515E1,
        -1.18331621121330003142E0,
    ];
    /// Numerator coefficients of the tail approximation used when
    /// `sqrt(-2 ln y) < 8`.
    const P1: [f64; 9] = [
        4.05544892305962419923E0,
        3.15251094599893866154E1,
        5.71628192246421288162E1,
        4.40805073893200834700E1,
        1.46849561928858024014E1,
        2.18663306850790267539E0,
        -1.40256079171354495875E-1,
        -3.50424626827848203418E-2,
        -8.57456785154685413611E-4,
    ];
    /// Denominator coefficients matching `P1` (implicit leading 1).
    const Q1: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        1.57799883256466749731E1,
        4.53907635128879210584E1,
        4.13172038254672030440E1,
        1.50425385692907503408E1,
        2.50464946208309415979E0,
        -1.42182922854787788574E-1,
        -3.80806407691578277194E-2,
        -9.33259480895457427372E-4,
    ];
    /// Numerator coefficients of the tail approximation used when
    /// `sqrt(-2 ln y) >= 8`.
    const P2: [f64; 9] = [
        3.23774891776946035970E0,
        6.91522889068984211695E0,
        3.93881025292474443415E0,
        1.33303460815807542389E0,
        2.01485389549179081538E-1,
        1.23716634817820021358E-2,
        3.01581553508235416007E-4,
        2.65806974686737550832E-6,
        6.23974539184983293730E-9,
    ];
    /// Denominator coefficients matching `P2` (implicit leading 1).
    const Q2: [f64; 8] = [
        /* 1.00000000000000000000E0, */
        6.02427039364742014255E0,
        3.67983563856160859403E0,
        1.37702099489081330271E0,
        2.16236993594496635890E-1,
        1.34204006088543189037E-2,
        3.28014464682127739104E-4,
        2.89247864745380683936E-6,
        6.79019408009981274425E-9,
    ];

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/polevl.h#L67
    /// Evaluates a polynomial at `x` with Horner's scheme; `coef` is ordered
    /// from the highest-degree coefficient down to the constant term.
    fn polevl(x: f64, coef: &[f64]) -> f64 {
        coef.iter().fold(0.0, |acc, &c| acc * x + c)
    }

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/polevl.h#L90
    /// Same as `polevl`, but with an implicit leading coefficient of 1.0 that
    /// is not stored in `coef`.
    fn p1evl(x: f64, coef: &[f64]) -> f64 {
        coef.iter().fold(1.0, |acc, &c| acc * x + c)
    }

    // https://github.com/scipy/scipy/blob/v1.5.4/scipy/special/cephes/ndtri.c#L134
    /// Returns the value `x` for which the standard normal CDF equals `y0`.
    ///
    /// `ppf(0.0)` is negative infinity and `ppf(1.0)` is positive infinity,
    /// matching the reference implementation.
    ///
    /// # Panics
    ///
    /// Panics if `y0` is NaN or lies outside `[0, 1]`.
    pub fn ppf(y0: f64) -> f64 {
        assert!(0.0 <= y0 && y0 <= 1.0);

        // The tail formula below would evaluate `inf - inf / inf` (NaN) at the
        // endpoints, so handle them explicitly like the reference code does.
        if y0 == 0.0 {
            return f64::NEG_INFINITY;
        }
        if y0 == 1.0 {
            return f64::INFINITY;
        }

        // Fold the upper tail onto the lower one; `negate` records whether the
        // tail result has to be mirrored back to the negative side.
        let (y, negate) = if y0 > (1.0 - EXP_NEG_2) {
            (1.0 - y0, false)
        } else {
            (y0, true)
        };

        // Central region: rational approximation in (y - 0.5).
        if y > EXP_NEG_2 {
            let y = y - 0.5;
            let y2 = y * y;
            let x = y + y * (y2 * polevl(y2, &P0) / p1evl(y2, &Q0));
            return x * S2PI;
        }

        // Tail region: rational approximation in 1 / sqrt(-2 ln y).
        let x = (-2.0 * y.ln()).sqrt();
        let x0 = x - x.ln() / x;
        let z = 1.0 / x;
        let x1 = if x < 8.0 {
            z * polevl(z, &P1) / p1evl(z, &Q1)
        } else {
            z * polevl(z, &P2) / p1evl(z, &Q2)
        };
        let x = x0 - x1;
        if negate {
            -x
        } else {
            x
        }
    }
}
// Probability margin used by `transform_col`: the interpolated probability is
// clamped to [BOUNDS_THRESHOLD - EPSILON, 1 - (BOUNDS_THRESHOLD - EPSILON)]
// before the normal inverse CDF is applied, which keeps the output finite.
const BOUNDS_THRESHOLD: f64 = 1e-7;
/// Maps feature values to normally distributed scores via stored quantiles.
///
/// Instances are built by `from_dump` and applied with `transform`.
#[derive(Debug, Clone)]
pub struct QuantileTransformer {
// Probability levels taken from the dump's `references_`; shared by all features.
references: Vec<f64>,
// Per-feature quantile values (the dump's `quantiles_` transposed):
// `quantiles[i][j]` is the value of feature `i` at level `references[j]`.
quantiles: Vec<Vec<f64>>,
}
/// Transforms a single feature value into a standard-normal score.
///
/// `x` is located among the feature's `quantiles`, the matching probability is
/// obtained by linear interpolation between the corresponding `references`,
/// and the (clamped) probability is passed through the normal inverse CDF.
/// Values at or below the first quantile map to probability 0, values at or
/// above the last quantile map to probability 1 (both before clamping).
///
/// `references` must have at least as many entries as `quantiles`.
///
/// # Panics
///
/// Panics if `quantiles` is empty or if `x` is NaN.
fn transform_col(x: f64, quantiles: &[f64], references: &[f64]) -> f64 {
    let xlb = quantiles[0];
    let xub = *quantiles.last().unwrap();

    let y = if x <= xlb {
        0.0
    } else if x >= xub {
        1.0
    } else {
        // Binary search for the pair of neighbouring quantiles enclosing x.
        // Invariant: quantiles[ilb] < x <= quantiles[iub].
        let mut ilb = 0;
        let mut iub = quantiles.len() - 1;
        while iub - ilb > 1 {
            let imd = (ilb + iub) / 2;
            if quantiles[imd] < x {
                ilb = imd;
            } else {
                iub = imd;
            }
        }
        assert!(quantiles[ilb] <= x);
        assert!(quantiles[iub] >= x);

        // Linear interpolation: each neighbour is weighted by the distance of
        // x to the *other* neighbour.
        let dlb = x - quantiles[ilb];
        let dub = quantiles[iub] - x;
        let wlb = dub / (dlb + dub);
        let wub = dlb / (dlb + dub);
        references[ilb] * wlb + references[iub] * wub
    };

    // Keep the probability strictly inside (0, 1) so the result stays finite.
    let y = y.clamp(
        BOUNDS_THRESHOLD - f64::EPSILON,
        1.0 - (BOUNDS_THRESHOLD - f64::EPSILON),
    );
    normal_distribution::ppf(y)
}
// Subset of the JSON dump that `QuantileTransformer::from_dump` reads.
// Field names double as the JSON keys, hence the trailing underscores.
#[derive(serde::Deserialize)]
struct Dump {
// Name of the target distribution; `from_dump` only accepts "normal".
output_distribution: String,
// Probability levels, one per quantile row.
references_: Vec<f64>,
// Quantile values indexed as `quantiles_[reference][feature]`
// (see the transposition in `from_dump`).
quantiles_: Vec<Vec<f64>>,
}
impl QuantileTransformer {
pub fn from_dump(dump: serde_json::Value) -> R<QuantileTransformer> {
let dump: Dump = serde_json::from_value(dump)?;
// normalしかサポートしない
assert_eq!(dump.output_distribution, "normal");
// 転置しといたほうが便利、ってか元のsklearnの実装も転置しといたほうが便利に見えて仕方ないのに何で転置してないんだろ
let n_features = dump.quantiles_[0].len();
let n_references = dump.references_.len();
let mut quantiles = vec![vec![0.0; n_references]; n_features];
for i in 0..n_features {
for j in 0..n_references {
quantiles[i][j] = dump.quantiles_[j][i];
}
}
// quantilesがユニークじゃない場合は結構変な処理しないといけないが、冷静に俺はそういうの使う予定ないから落とす
for qs in quantiles.iter_mut() {
qs.dedup();
assert_eq!(qs.len(), n_references);
}
Ok(QuantileTransformer {
references: dump.references_,
quantiles,
})
}
pub fn transform(&self, x: &[f64]) -> Vec<f64> {
assert_eq!(x.len(), self.quantiles.len());
x.iter()
.zip(self.quantiles.iter())
.map(|(x, quantiles)| transform_col(*x, quantiles, &self.references))
.collect()
}
}
// Tests for `QuantileTransformer`, driven by a small hand-written dump.
#[cfg(test)]
mod tests {
use super::*;
// Builds a transformer from a dump with 10 reference levels and 3 features.
// Keys other than `output_distribution`, `references_` and `quantiles_` are
// not read by `Dump`.
fn create() -> QuantileTransformer {
let j = serde_json::json!({
"n_quantiles": 10,
"output_distribution": "normal",
"ignore_implicit_zeros": false,
"subsample": 100000,
"random_state": null,
"copy": true,
"n_features_in_": 3,
"n_quantiles_": 10,
"references_": [
0.0,
0.1111111111111111,
0.2222222222222222,
0.3333333333333333,
0.4444444444444444,
0.5555555555555556,
0.6666666666666666,
0.7777777777777777,
0.8888888888888888,
1.0,
],
"quantiles_": [
[-2.613328303719778, -9.321205290168217, 0.13674275875021852],
[-1.1414101426254262, -3.294752914303114, 0.31128319201392723],
[-0.7306000262128471, -1.919245430597138, 0.5628969761610223],
[-0.427117826339735, -0.7238972474428322, 0.7805069173920551],
[-0.05639286493061981, 0.885189368484695, 0.9814794091618718],
[0.1722755897532091, 1.991170360919146, 1.2760902770308562],
[0.39748654705760395, 3.8665254431411706, 1.6340155327513783],
[0.67348557453346, 5.3665644352138, 1.959333216365857],
[0.9437791586194102, 7.205771950040178, 2.8721836987270755],
[2.4620269142769113, 12.26632790171583, 17.47969633690788],
],
});
QuantileTransformer::from_dump(j).unwrap()
}
// Transforms `x` and compares the result element-wise with the expected `y`
// using the approximate-equality macro.
fn check(qt: &QuantileTransformer, x: &[f64], y: &[f64]) {
let z = qt.transform(x);
assert_eq!(y.len(), z.len());
for (a, b) in y.iter().zip(z.iter()) {
assert_approx_eq!(a, b);
}
}
// Each input is one row of `quantiles_`, so every feature sits exactly on a
// stored quantile and all three outputs of a case are expected to be equal.
// NOTE(review): expected values were presumably generated with scikit-learn
// — confirm.
#[test]
fn test_references() {
let qt = create();
let cases = &[
(
[-2.613328303719778, -9.321205290168217, 0.13674275875021852],
[-5.199337582605575, -5.199337582605575, -5.199337582605575],
),
(
[-1.1414101426254262, -3.294752914303114, 0.31128319201392723],
[-1.22064034884735, -1.22064034884735, -1.22064034884735],
),
(
[-0.7306000262128471, -1.919245430597138, 0.5628969761610223],
[-0.764709673786387, -0.764709673786387, -0.764709673786387],
),
(
[-0.427117826339735, -0.7238972474428322, 0.7805069173920551],
[-0.430727299295457, -0.430727299295457, -0.430727299295457],
),
(
[-0.05639286493061981, 0.885189368484695, 0.9814794091618718],
[-0.139710298881862, -0.139710298881862, -0.139710298881862],
),
(
[0.1722755897532091, 1.991170360919146, 1.2760902770308562],
[0.1397102988818621, 0.1397102988818621, 0.1397102988818621],
),
(
[0.39748654705760395, 3.8665254431411706, 1.6340155327513783],
[0.4307272992954574, 0.4307272992954574, 0.4307272992954574],
),
(
[0.67348557453346, 5.3665644352138, 1.959333216365857],
[0.7647096737863867, 0.7647096737863867, 0.7647096737863867],
),
(
[0.9437791586194102, 7.205771950040178, 2.8721836987270755],
[1.2206403488473496, 1.2206403488473496, 1.2206403488473496],
),
(
[2.4620269142769113, 12.26632790171583, 17.47969633690788],
[5.19933758270342, 5.19933758270342, 5.19933758270342],
),
];
for case in cases {
check(&qt, &case.0, &case.1);
}
}
// Arbitrary inputs, including values outside the quantile range of a feature;
// those saturate at roughly -5.1993 / +5.1993 in the expected outputs.
// NOTE(review): expected values were presumably generated with scikit-learn
// — confirm.
#[test]
fn test_random() {
let qt = create();
let cases = &[
(
[9.543169032696227, 13.004265458495848, 1.2817188786196336],
[5.19933758270342, 5.19933758270342, 0.14413444750289997],
),
(
[6.148524766915234, 12.277375912423363, -2.0032741621514276],
[5.19933758270342, 5.19933758270342, -5.199337582605575],
),
(
[3.2479687586644044, 6.333366933638192, -0.18854883910825748],
[5.19933758270342, 0.9788976648571445, -5.199337582605575],
),
(
[12.026790240804463, -8.17807192753417, 7.139275629961631],
[5.19933758270342, -2.0320120702704356, 1.414185316101896],
),
(
[11.297395996011659, 12.389898901401988, 4.3528110754416325],
[5.19933758270342, 5.19933758270342, 1.2824135105950227],
),
(
[-5.548440019781655, 3.9141903012544503, 8.462175829374196],
[-5.199337582605575, 0.44045804894271895, 1.4863658158773896],
),
(
[0.48259293083427224, 1.8454399739866432, -8.433798779937739],
[0.5270731855886761, 0.10273894814593508, -5.199337582605575],
),
(
[2.9464268070547117, 12.73414827971353, 5.139387737583421],
[5.19933758270342, 5.19933758270342, 1.3173195432982814],
),
(
[-3.4636253585345047, 13.086601949163505, 2.851517222764288],
[-5.199337582605575, 5.19933758270342, 1.207464726166755],
),
(
[-6.081831945801352, -6.336175884405989, -7.420982125992383],
[-5.199337582605575, -1.5978724353638714, -5.199337582605575],
),
(
[-6.275065630029899, 7.586341240789082, 1.7097115253806585],
[-5.199337582605575, 1.266007404810665, 0.5030071199147996],
),
(
[-0.13038122929708784, -4.161259056492385, 5.808837900045351],
[-0.1960917782971291, -1.3097800036623548, 1.3483455214298523],
),
(
[7.618117371577743, 5.909956441138961, 10.508511337022796],
[5.19933758270342, 0.8801295864809635, 1.6161969123723285],
),
(
[8.693866471169883, -0.7330487877831544, 9.163553125519751],
[5.19933758270342, -0.4330680381712861, 1.5280004611048847],
),
(
[0.4282314588607754, -1.1967112585064736, -6.889187183469036],
[0.46502685937014465, -0.5551855155241495, -5.199337582605575],
),
(
[7.685794675525781, 4.500672464040191, 14.250996390962012],
[5.19933758270342, 0.5640480936891384, 1.967567697630419],
),
(
[1.5798977254115556, -9.611474340719736, -6.600078504292563],
[1.5176006671813367, -5.199337582605575, -5.199337582605575],
),
(
[0.17534066429800532, -3.5008980831264935, 12.394201107918992],
[0.1435390280777139, -1.2409594504784927, 1.7661839743843688],
),
(
[3.5825769280006945, -1.551988325695696, -6.6376732992024206],
[5.19933758270342, -0.6546087494691977, -5.199337582605575],
),
(
[-4.10776763251768, -9.18756794697149, 9.369452587399543],
[-5.199337582605575, -2.811715563246142, 1.5407399020190642],
),
];
for case in cases {
check(&qt, &case.0, &case.1);
}
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment