-
-
Save krdln/8466fb45ad68b7667a29c2558059cbde to your computer and use it in GitHub Desktop.
iterator version
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate clap; | |
extern crate csv; | |
/// A univariate time series: an ordered sequence of `f64` samples.
#[derive(Debug, Clone, PartialEq)]
struct TimeSerie {
    data: Vec<f64>,
}

impl TimeSerie {
    /// Number of samples in the series.
    fn length(&self) -> usize {
        self.data.len()
    }

    /// Squared difference between sample `self_idx` of `self` and sample
    /// `other_idx` of `other`; this is the local cost used by `compute_dtw`.
    ///
    /// # Panics
    /// Panics if either index is out of bounds.
    fn square_dist(&self, self_idx: usize, other_idx: usize, other: &Self) -> f64 {
        let dif = self.data[self_idx] - other.data[other_idx];
        dif * dif
    }

    /// Dynamic Time Warping cost between `self` and `other`.
    ///
    /// The local cost is the squared difference (`square_dist`). There is no
    /// warping window and no final square root. Cell `(i, j)` pairs
    /// `self.data[i]` with `other.data[j]`:
    /// `D(i, j) = square_dist(i, j) + min(D(i-1, j-1), D(i-1, j), D(i, j-1))`.
    /// Only two rows of the matrix are kept, so extra memory is
    /// `O(other.length())` and time is `O(self.length() * other.length())`.
    ///
    /// The two series may have different lengths. If both are empty the cost
    /// is `0.0`. If exactly one is empty no alignment exists, so the cost is
    /// `f64::INFINITY`.
    fn compute_dtw(&self, other: &Self) -> f64 {
        // Rows walk `self`, columns walk `other`; each axis is sized from its
        // own series so unequal lengths are handled correctly.
        let rows = self.length();
        let cols = other.length();
        if rows == 0 || cols == 0 {
            return if rows == cols { 0.0 } else { f64::INFINITY };
        }
        // `curr` is the row being filled, `prev` the row directly above it.
        let mut curr = vec![0f64; cols];
        let mut prev = vec![0f64; cols];
        // --- Row 0: only horizontal moves are possible, so costs accumulate.
        curr[0] = self.square_dist(0, 0, other);
        for col in 1..cols {
            curr[col] = curr[col - 1] + self.square_dist(0, col, other);
        }
        // --- Remaining rows.
        for idx_line in 1..rows {
            std::mem::swap(&mut curr, &mut prev);
            // Column 0: only vertical moves are possible.
            curr[0] = prev[0] + self.square_dist(idx_line, 0, other);
            // `left` carries D(idx_line, idx_col - 1) between iterations, so the
            // inner loop reads `curr` only through the zipped iterator.
            let mut left = curr[0];
            for (((&diag, &up), cell), idx_col) in
                prev.iter().zip(&prev[1..]).zip(&mut curr[1..]).zip(1..)
            {
                left = min2(left, min2(diag, up)) + self.square_dist(idx_line, idx_col, other);
                *cell = left;
            }
        }
        curr[cols - 1] //.sqrt()
    }

    // --- --- --- static functions

    /// Wraps `data` as a time series.
    fn new(data: Vec<f64>) -> TimeSerie {
        TimeSerie { data }
    }
}

/// Returns the smaller of `x` and `y`. When `x < y` is false (ties, or either
/// operand is NaN) the result is `y`.
fn min2(x: f64, y: f64) -> f64 {
    if x < y { x } else { y }
}
/// Smallest of three values, decided purely by pairwise `<` comparisons.
///
/// A comparison that is false (a tie, or any NaN operand) falls through to
/// the later argument. If `x < y` fails, the answer is chosen between `y`
/// and `z`; if `x < z` fails, `z` wins over `x`.
fn min3(x: f64, y: f64, z: f64) -> f64 {
    match (x < y, x < z, y < z) {
        (true, true, _) => x,
        (true, false, _) => z,
        (false, _, true) => y,
        (false, _, false) => z,
    }
}
fn check_cli<'a>() -> clap::ArgMatches<'a> { | |
let matches = clap::App::new("ts") | |
.version("0.0") | |
.about("Working with time series") | |
.arg(clap::Arg::with_name("INPUT FILE") | |
.required(true) | |
.index(1) | |
.help("Input file, must be a csv") | |
).get_matches(); | |
return matches; | |
} | |
fn main() { | |
// --- 0: Get the command line arguments | |
let matches = check_cli(); | |
let file = matches.value_of("INPUT FILE").unwrap(); | |
// --- 1: Load the CSV | |
let mut rdr = csv::Reader::from_file(file).unwrap(); | |
let rows = rdr.records().map(|r| r.unwrap()); | |
let mut vec: Vec<TimeSerie> = Vec::new(); | |
for row in rows { | |
let mut iter = row.into_iter(); | |
let _class: String = iter.next().unwrap(); | |
let data: Vec<f64> = iter.map(|s| s.parse().unwrap()).collect(); | |
vec.push(TimeSerie::new(data)); | |
} | |
// --- 2: Compute sum of DTW | |
let mut total_e = 0f64; | |
let now = std::time::SystemTime::now(); | |
for (id, vi) in vec.iter().enumerate() { | |
for vj in vec.iter().skip(id) { | |
total_e += vi.compute_dtw(vj); | |
} | |
} | |
match now.elapsed() { | |
Ok(elapsed) => { println!("{} s", elapsed.as_secs()); } | |
Err(_) => { () } | |
} | |
println!("Total error: {}", total_e); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment