Skip to content

Instantly share code, notes, and snippets.

@krdln

krdln/main.rs Secret

Created March 8, 2017 03:16
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krdln/8466fb45ad68b7667a29c2558059cbde to your computer and use it in GitHub Desktop.
Save krdln/8466fb45ad68b7667a29c2558059cbde to your computer and use it in GitHub Desktop.
iterator version
extern crate clap;
extern crate csv;
/// A univariate time series of `f64` samples.
#[derive(Debug, Clone, PartialEq)]
struct TimeSerie {
    data: Vec<f64>,
}

impl TimeSerie {
    /// Number of samples in the series.
    fn length(&self) -> usize {
        self.data.len()
    }

    /// Squared difference between `self.data[self_idx]` and `other.data[other_idx]`.
    ///
    /// # Panics
    /// Panics if either index is out of bounds.
    fn square_dist(&self, self_idx: usize, other_idx: usize, other: &Self) -> f64 {
        let dif = self.data[self_idx] - other.data[other_idx];
        dif * dif
    }

    /// Dynamic Time Warping cost between `self` and `other`, using the squared
    /// sample difference as the local cost.
    ///
    /// The two series may have different lengths. When both are empty the cost
    /// is `0.0`; when exactly one is empty no warping path exists and the result
    /// is `f64::INFINITY`. The returned value is the accumulated squared cost
    /// (no square root is taken).
    ///
    /// Only two rows of the cost matrix are kept, so extra memory is
    /// proportional to `other.length()`.
    fn compute_dtw(&self, other: &Self) -> f64 {
        // Rows of the cost matrix index `self`, columns index `other`.
        let rows = self.length();
        let cols = other.length();
        if rows == 0 || cols == 0 {
            return if rows == cols { 0.0 } else { f64::INFINITY };
        }

        let mut curr = vec![0f64; cols];
        let mut prev = vec![0f64; cols];

        // First row: each cell is reachable only from its left neighbour.
        curr[0] = self.square_dist(0, 0, other);
        for col in 1..cols {
            curr[col] = curr[col - 1] + self.square_dist(0, col, other);
        }

        for row in 1..rows {
            std::mem::swap(&mut curr, &mut prev);
            // First column: each cell is reachable only from the cell above.
            curr[0] = prev[0] + self.square_dist(row, 0, other);
            // `left` carries curr[col - 1], so the loop never has to read `curr`
            // while it is mutably borrowed.
            let mut left = curr[0];
            let diag_up = prev.iter().zip(&prev[1..]);
            for (((&diag, &up), cell), col) in diag_up.zip(&mut curr[1..]).zip(1..) {
                let best_prev = min2(diag, up);
                let best = if left < best_prev { left } else { best_prev };
                left = best + self.square_dist(row, col, other);
                *cell = left;
            }
        }

        curr[cols - 1]
    }

    // --- --- --- static functions

    /// Wraps `data` in a `TimeSerie` (takes ownership, no copy).
    fn new(data: Vec<f64>) -> TimeSerie {
        TimeSerie { data }
    }
}
/// Returns the smaller of `x` and `y`.
///
/// `x` is chosen only when `x < y` holds; on a tie, or when the comparison
/// is false because an operand is NaN, `y` is returned.
fn min2(x: f64, y: f64) -> f64 {
    let x_is_smaller = x < y;
    if x_is_smaller {
        return x;
    }
    y
}
/// Returns the smallest of `x`, `y` and `z`.
///
/// Takes the lesser of `x` and `y` first, then compares it against `z`.
/// Every comparison uses strict `<`, so ties and NaN comparisons resolve to
/// the later operand.
fn min3(x: f64, y: f64, z: f64) -> f64 {
    let best_xy = if x < y { x } else { y };
    if best_xy < z {
        best_xy
    } else {
        z
    }
}
/// Parses the process command line.
///
/// Declares one required positional argument, `INPUT FILE`, naming the CSV
/// file to load. NOTE(review): with clap 2.x, `get_matches` is expected to
/// print usage and exit on invalid arguments — confirm for the pinned version.
fn check_cli<'a>() -> clap::ArgMatches<'a> {
    let input_arg = clap::Arg::with_name("INPUT FILE")
        .required(true)
        .index(1)
        .help("Input file, must be a csv");

    clap::App::new("ts")
        .version("0.0")
        .about("Working with time series")
        .arg(input_arg)
        .get_matches()
}
/// Loads every record of the CSV named on the command line as a time series,
/// then prints the wall time taken to sum DTW costs over all pairs `(i, j)`
/// with `i <= j` (self-pairs included), followed by that total.
///
/// Each record is read as `class, v1, v2, ...`; the leading class field is
/// discarded and the remaining fields are parsed as `f64`.
///
/// # Panics
/// Panics with a descriptive message if the file cannot be opened, a record
/// cannot be read, a record is empty, or a value is not a valid `f64`.
fn main() {
    // --- 0: Get the command line arguments
    let matches = check_cli();
    let file = matches
        .value_of("INPUT FILE")
        .expect("INPUT FILE is declared required(true) in check_cli");

    // --- 1: Load the CSV
    let mut rdr = csv::Reader::from_file(file)
        .unwrap_or_else(|e| panic!("cannot open {}: {}", file, e));
    let mut series: Vec<TimeSerie> = Vec::new();
    for (idx, record) in rdr.records().enumerate() {
        let recno = idx + 1;
        let row = record.unwrap_or_else(|e| panic!("CSV record {}: {}", recno, e));
        let mut fields = row.into_iter();
        let _class: String = fields
            .next()
            .unwrap_or_else(|| panic!("CSV record {} is empty", recno));
        let data: Vec<f64> = fields
            .map(|s| {
                s.parse::<f64>().unwrap_or_else(|e| {
                    panic!("CSV record {}: cannot parse {:?} as f64: {}", recno, s, e)
                })
            })
            .collect();
        series.push(TimeSerie::new(data));
    }

    // --- 2: Compute sum of DTW
    // `Instant` is monotonic (unlike `SystemTime`), so measuring cannot fail.
    let start = std::time::Instant::now();
    let mut total_e = 0f64;
    for (i, a) in series.iter().enumerate() {
        for b in &series[i..] {
            total_e += a.compute_dtw(b);
        }
    }
    let elapsed = start.elapsed();
    // Keep sub-second precision; whole seconds alone would print "0 s" for fast runs.
    let secs = elapsed.as_secs() as f64 + f64::from(elapsed.subsec_nanos()) * 1e-9;
    println!("{:.3} s", secs);
    println!("Total error: {}", total_e);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment