-
-
Save krdln/40cfd8587ffda22e7de4c0b93e5d5341 to your computer and use it in GitHub Desktop.
diagonal
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
extern crate clap; | |
extern crate csv; | |
/// A single time series: its sample values, in the order they appear in the
/// input row.
#[derive(Debug, Clone, PartialEq)]
struct TimeSerie {
    data: Vec<f64>,
}
/// Fills the cells `curr[first..=last]` of one anti-diagonal of the DTW cost matrix.
///
/// Every buffer is indexed by column `j` of the matrix. For the cell at column `j`,
/// the three predecessors are read from:
/// - `prev[j - 1]`: same row, previous column;
/// - `prev[j]`: previous row, same column;
/// - `prevprev[j - 1]`: previous row and previous column.
///
/// The cell becomes the smallest predecessor plus the squared difference of the
/// paired samples.
///
/// `self_data_rev` is the row series reversed. Its values are read starting at index
/// `dim - 1 - last`, which pairs the cell at column `first` with row `last`. That
/// pairing holds for the ranges `compute_dtw` passes in.
///
/// # Panics
/// Panics if `first == 0` or if any slice bound is out of range. If a buffer is
/// merely too short, iteration stops early and the remaining cells are left untouched.
fn do_the_thing(
    curr: &mut [f64],
    prev: &[f64],
    prevprev: &[f64],
    self_data_rev: &[f64],
    other_data: &[f64],
    first: usize,
    last: usize,
    dim: usize,
) {
    let targets = curr[first..last + 1].iter_mut();
    let lefts = &prev[first - 1..];
    let ups = &prev[first..];
    let diagonals = &prevprev[first - 1..];
    let row_samples = &self_data_rev[dim - 1 - last..];
    let col_samples = &other_data[first..];

    let neighbours = lefts.iter().zip(ups).zip(diagonals);
    let samples = row_samples.iter().zip(col_samples);

    for (cell, (((&left, &up), &diag), (&a, &b))) in targets.zip(neighbours.zip(samples)) {
        // Same comparison chain as a nested `min2`: keep the left operand only on a
        // strict `<`, so ties and NaN fall through to the right operand.
        let best_lu = if left < up { left } else { up };
        let best = if best_lu < diag { best_lu } else { diag };
        let delta = a - b;
        *cell = best + delta * delta;
    }
}
impl TimeSerie { | |
fn length(&self) -> usize { | |
self.data.len() | |
} | |
fn square_dist(&self, self_idx: usize, other_idx: usize, other: &Self) -> f64 { | |
let dif = self.data[self_idx] - other.data[other_idx]; | |
dif * dif | |
} | |
fn compute_dtw(&self, other: &Self) -> f64 { | |
let dim = self.length(); | |
// Three consecutive diagonals, each diagonal | |
// is indexed by column, from bottom-left to top-right. | |
let mut prevprev = vec![0f64; dim + 1]; | |
let mut prev = vec![0f64; dim]; | |
let mut curr = vec![0f64; dim]; | |
let mut out = vec![0f64; dim]; | |
assert!(dim >= 3); | |
prev[0] = self.square_dist(0, 0, other); | |
curr[0] = prev[0] + self.square_dist(1, 0, other); | |
curr[1] = prev[0] + self.square_dist(0, 1, other); | |
// Calculate the reverse of self.data, so it's easily vectorized. | |
let self_data_rev: Vec<f64> = self.data.iter().rev().cloned().collect(); | |
let mid_diagonal_idx = dim - 1; | |
// --- Compute DTW | |
for i_diagonal in 2..(2 * dim - 1) { | |
{ | |
let tmp = prevprev; | |
prevprev = prev; | |
prev = curr; | |
curr = tmp; | |
} | |
// Column indices on current diagonal. | |
// Note that in the lower triangle of the matrix, | |
// we use the "right" part of arrays to | |
// store the values. | |
let (first, last) = if i_diagonal <= mid_diagonal_idx { | |
(0, i_diagonal) | |
} else { | |
(i_diagonal - mid_diagonal_idx, dim - 1) | |
}; | |
// Shrink the range, so we're left only with what | |
// can be computed in the main loop. | |
let (first, last) = if i_diagonal <= mid_diagonal_idx { | |
// Compute the first and last element separately, | |
// they have only 1 dependency instead of regular 3. | |
curr[0] = prev[0] + self.square_dist(last, 0, other); | |
curr[last] = prev[last - 1] + self.square_dist(0, last, other); | |
(first + 1, last - 1) | |
} else { | |
(first, last) | |
}; | |
do_the_thing( | |
&mut curr, | |
&prev, | |
&prevprev, | |
&self_data_rev, | |
&other.data, | |
first, | |
last, | |
dim, | |
); | |
} | |
curr[dim - 1] | |
} | |
// --- --- --- static functions | |
fn new(data: Vec<f64>) -> TimeSerie { | |
TimeSerie { data: data } | |
} | |
} | |
/// Returns the smaller of `x` and `y`.
///
/// `x` is chosen only when `x < y` strictly holds. Whenever that comparison is
/// false (a tie, or either operand NaN), the result is `y`. Unlike `f64::min`,
/// this does not skip NaN.
fn min2(x: f64, y: f64) -> f64 {
    match x.partial_cmp(&y) {
        Some(std::cmp::Ordering::Less) => x,
        _ => y,
    }
}
fn check_cli<'a>() -> clap::ArgMatches<'a> { | |
let matches = clap::App::new("ts") | |
.version("0.0") | |
.about("Working with time series") | |
.arg(clap::Arg::with_name("INPUT FILE") | |
.required(true) | |
.index(1) | |
.help("Input file, must be a csv") | |
).get_matches(); | |
return matches; | |
} | |
fn main() { | |
// --- 0: Get the command line arguments | |
let matches = check_cli(); | |
let file = matches.value_of("INPUT FILE").unwrap(); | |
// --- 1: Load the CSV | |
let mut rdr = csv::Reader::from_file(file).unwrap(); | |
let rows = rdr.records().map(|r| r.unwrap()); | |
let mut vec: Vec<TimeSerie> = Vec::new(); | |
for row in rows { | |
let mut iter = row.into_iter(); | |
let _class: String = iter.next().unwrap(); | |
let data: Vec<f64> = iter.map(|s| s.parse().unwrap()).collect(); | |
vec.push(TimeSerie::new(data)); | |
} | |
// --- 2: Compute sum of DTW | |
let mut total_e = 0f64; | |
let now = std::time::SystemTime::now(); | |
for (id, vi) in vec.iter().enumerate() { | |
for vj in vec.iter().skip(id) { | |
total_e += vi.compute_dtw(vj); | |
} | |
} | |
match now.elapsed() { | |
Ok(elapsed) => { println!( | |
"{} s {} ms", | |
elapsed.as_secs(), | |
elapsed.subsec_nanos() / 1_000_000 | |
); } | |
Err(_) => { () } | |
} | |
println!("Total error: {}", total_e); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment