Created
March 7, 2017 02:55
-
-
Save matklad/0d1912d3bcd4a8a8a5b3e867e7466881 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate clap; | |
extern crate csv; | |
struct TimeSerie { | |
data: Vec<f64>, | |
} | |
impl TimeSerie { | |
fn length(&self) -> usize { | |
self.data.len() | |
} | |
fn square_dist(&self, self_idx: usize, other_idx: usize, other: &Self) -> f64 { | |
let dif = self.data[self_idx] - other.data[other_idx]; | |
dif * dif | |
} | |
fn compute_dtw(&self, other: &Self) -> f64 { | |
let dim = self.length(); | |
let mut curr = vec![0f64; dim]; | |
let mut prev = vec![0f64; dim]; | |
// --- Init 0,0 | |
curr[0] = self.square_dist(0, 0, other); | |
// --- Init the two "axis" | |
// --- --- first line, along columns (0, ..): self | |
// --- --- first column, along lines (.., 0): other | |
for x in 1..dim { | |
curr[x] = curr[x - 1] + self.square_dist(0, x, other); // First line | |
} | |
// --- Compute DTW | |
for idx_line in 1..dim { | |
std::mem::swap(&mut curr, &mut prev); | |
curr[0] = prev[0] + self.square_dist(idx_line, 0, other); | |
for idx_col in 1..dim { | |
let d = { | |
// Compute ancestors | |
let d11 = prev[idx_col - 1]; | |
let d01 = curr[idx_col - 1]; | |
let d10 = prev[idx_col]; | |
// Take the smallest ancestor and add the current distance | |
min3(d11, d01, d10) + self.square_dist(idx_line, idx_col, other) | |
// The next line actually call cmath | |
// d11.min(d01).min(d10) + self.square_dist(idx_line, idx_col, other) | |
}; | |
curr[idx_col] = d; | |
} | |
} | |
let last = dim - 1; | |
curr[last] //.sqrt() | |
} | |
// --- --- --- static functions | |
fn new(data: Vec<f64>) -> TimeSerie { | |
TimeSerie { data: data } | |
} | |
} | |
/// Returns the smallest of three `f64` values using plain `<` comparisons.
///
/// The smaller of `x` and `y` is picked first (ties go to `y`), then compared
/// against `z` (ties go to `z`). Because only `<` is used, a NaN operand makes
/// the comparison it takes part in evaluate to `false`, so the right-hand
/// candidate of that comparison is selected.
///
/// The original author noted that a `match`-based formulation benchmarked
/// marginally faster on their machine.
fn min3(x: f64, y: f64, z: f64) -> f64 {
    let smaller_xy = if x < y { x } else { y };
    if smaller_xy < z {
        smaller_xy
    } else {
        z
    }
}
fn check_cli<'a>() -> clap::ArgMatches<'a> { | |
let matches = clap::App::new("ts") | |
.version("0.0") | |
.about("Working with time series") | |
.arg(clap::Arg::with_name("INPUT FILE") | |
.required(true) | |
.index(1) | |
.help("Input file, must be a csv") | |
).get_matches(); | |
return matches; | |
} | |
fn main() { | |
// --- 0: Get the command line arguments | |
let matches = check_cli(); | |
let file = matches.value_of("INPUT FILE").unwrap(); | |
// --- 1: Load the CSV | |
let mut rdr = csv::Reader::from_file(file).unwrap(); | |
let rows = rdr.records().map(|r| r.unwrap()); | |
let mut vec: Vec<TimeSerie> = Vec::new(); | |
for row in rows { | |
let mut iter = row.into_iter(); | |
let _class: String = iter.next().unwrap(); | |
let data: Vec<f64> = iter.map(|s| s.parse().unwrap()).collect(); | |
vec.push(TimeSerie::new(data)); | |
} | |
// --- 2: Compute sum of DTW | |
let mut total_e = 0f64; | |
let now = std::time::SystemTime::now(); | |
for (id, vi) in vec.iter().enumerate() { | |
for vj in vec.iter().skip(id) { | |
total_e += vi.compute_dtw(vj); | |
} | |
} | |
match now.elapsed() { | |
Ok(elapsed) => { println!("{} s", elapsed.as_secs()); } | |
Err(_) => { () } | |
} | |
println!("Total error: {}", total_e); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment