-
-
Save krdln/7b91265e5f80755274db203e6cd8ebe8 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// Loading external crates | |
extern crate clap; // Command line parsing | |
extern crate csv; // CSV loading | |
// use std::cmp::PartialEq; | |
// --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- | |
// Time series type
// --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- | |
/// A labelled time series, built in `main` from one CSV record: the first
/// column is the class label, the remaining columns are the samples.
struct TimeSerie {
    /// Class label from the first CSV column.
    // Stored but never read in this file, hence the lint allowance.
    #[allow(dead_code)]
    class: String,
    /// Sample values, in file order.
    data: Vec<f64>,
}

impl TimeSerie {
    /// Number of samples in the series.
    #[inline]
    fn length(&self) -> usize {
        self.data.len()
    }

    /// Sample at position `idx`.
    ///
    /// # Panics
    /// Panics if `idx >= self.length()` (the previous unchecked access made
    /// that case undefined behaviour inside a safe function).
    #[inline]
    fn at(&self, idx: usize) -> f64 {
        self.data[idx]
    }

    /// Squared difference between `self[self_idx]` and `other[other_idx]`;
    /// this is the local cost used by [`TimeSerie::compute_dtw`].
    ///
    /// # Panics
    /// Panics if either index is out of range for its series.
    #[inline]
    fn square_dist(&self, self_idx: usize, other_idx: usize, other: &Self) -> f64 {
        let dif = self.at(self_idx) - other.at(other_idx);
        dif * dif
    }

    /// Dynamic-time-warping cost between `self` and `other`, using the squared
    /// difference as local cost (no final square root is taken).
    ///
    /// `m` is caller-provided scratch space so repeated calls do not allocate.
    /// Its first `self.length() * other.length()` cells are used as a row-major
    /// matrix: row `i` follows `self`, column `j` follows `other`, and cell
    /// `(i, j)` receives the cheapest cumulative cost of aligning `self[..=i]`
    /// with `other[..=j]` under the usual DTW steps (diagonal, down, right).
    /// The two series may have different lengths.
    ///
    /// Returns `0.0` when both series are empty and `f64::INFINITY` when
    /// exactly one of them is (no alignment exists).
    ///
    /// # Panics
    /// Panics if `m.len() < self.length() * other.length()`.
    fn compute_dtw(&self, other: &Self, m: &mut [f64]) -> f64 {
        let rows = self.length();
        let cols = other.length();
        if rows == 0 || cols == 0 {
            return if rows == cols { 0.0 } else { f64::INFINITY };
        }
        assert!(
            m.len() >= rows * cols,
            "compute_dtw: working area has {} cells, needs at least {} ({} x {})",
            m.len(),
            rows * cols,
            rows,
            cols
        );
        // Restrict to exactly the cells we use so every slice below is in range.
        let m = &mut m[..rows * cols];

        // First row: self[0] can only be aligned by moving right along `other`.
        m[0] = self.square_dist(0, 0, other);
        for col in 1..cols {
            m[col] = m[col - 1] + self.square_dist(0, col, other);
        }

        for row in 1..rows {
            // `done` holds rows 0..row (already computed), `todo` starts at `row`.
            let (done, todo) = m.split_at_mut(row * cols);
            let prev = &done[(row - 1) * cols..];
            let cur = &mut todo[..cols];

            // First column: only reachable from the cell directly above.
            cur[0] = prev[0] + self.square_dist(row, 0, other);

            let sample = self.data[row];
            let mut left = cur[0];
            // `up` is [prev[col - 1], prev[col]]: the diagonal and vertical ancestors.
            for ((cell, up), &target) in cur[1..]
                .iter_mut()
                .zip(prev.windows(2))
                .zip(&other.data[1..])
            {
                let up_best = if up[0] < up[1] { up[0] } else { up[1] };
                let best = if left < up_best { left } else { up_best };
                let dif = sample - target;
                left = best + dif * dif;
                *cell = left;
            }
        }

        m[rows * cols - 1]
    }

    // --- --- --- static functions

    /// Builds a series from its class label and samples.
    fn new(class: String, data: Vec<f64>) -> TimeSerie {
        TimeSerie { class, data }
    }
}
// --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- | |
// Command line building | |
// --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- | |
fn check_cli<'a>() -> clap::ArgMatches<'a> { | |
let matches = clap::App::new("ts") | |
.version("0.0") | |
.about("Working with time series") | |
.arg(clap::Arg::with_name("INPUT FILE") | |
.required(true) | |
.index(1) | |
.help("Input file, must be a csv") | |
).get_matches(); | |
return matches; | |
} | |
// --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- | |
// Main app | |
// --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- --- | |
fn main() { | |
// --- 0: Get the command line arguments | |
let matches = check_cli(); | |
let file = matches.value_of("INPUT FILE").unwrap(); | |
// --- 1: Load the CSV | |
let mut rdr = csv::Reader::from_file(file).unwrap(); | |
let rows = rdr.records().map(|r| r.unwrap()); | |
let mut vec:Vec<TimeSerie> = Vec::new(); | |
for row in rows { | |
let mut iter = row.into_iter(); | |
let class:String = iter.next().unwrap(); | |
let data:Vec<f64> = iter.map( |s| s.parse().unwrap() ).collect(); | |
vec.push( TimeSerie::new(class, data) ); | |
} | |
// --- 2: Compute sum of DTW | |
let mut total_e:f64 = 0.0; | |
let ts_size = vec[0].length(); | |
let mut working_area = vec![0 as f64; ts_size * ts_size]; | |
let now = std::time::SystemTime::now(); | |
for (id, vi) in vec.iter().enumerate() { | |
for vj in vec.iter().skip(id) { | |
total_e += vi.compute_dtw(vj, &mut working_area); | |
} | |
} | |
match now.elapsed() { | |
Ok(elapsed) => { println!("{} s", elapsed.as_secs()); } | |
Err(_) => { () } | |
} | |
println!("Total error: {}", total_e); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment