Skip to content

Instantly share code, notes, and snippets.

@krdln

krdln/main.rs Secret

Last active March 14, 2017 00:06
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save krdln/40cfd8587ffda22e7de4c0b93e5d5341 to your computer and use it in GitHub Desktop.
Save krdln/40cfd8587ffda22e7de4c0b93e5d5341 to your computer and use it in GitHub Desktop.
diagonal
extern crate clap;
extern crate csv;
/// A univariate time series: the numeric samples of one CSV record
/// (the leading class column is not stored here).
#[derive(Debug, Clone, PartialEq)]
struct TimeSerie {
    /// Samples in time order.
    data: Vec<f64>,
}
/// Inner kernel of `compute_dtw`: fills columns `first..=last` of one
/// anti-diagonal of the DTW accumulated-cost table.
///
/// `curr`, `prev` and `prevprev` are three consecutive diagonals, all indexed
/// by column. For every column `j` in `first..=last` this writes
///
///     curr[j] = min(prev[j - 1], prev[j], prevprev[j - 1]) + (s - other_data[j])^2
///
/// (bound below as `left`, `up` and `upleft`), where `s` is read from
/// `self_data_rev` starting at index `dim - 1 - last` and advancing by one per
/// column. `compute_dtw` passes `self.data` reversed as `self_data_rev`, so
/// the self index decreases while the column increases.
///
/// NOTE(review): the pairing only lands on cell (d - j, j) of diagonal `d`
/// when `first + last == d`. `compute_dtw` arranges that, but nothing here
/// checks it. Confirm this at any new call site.
///
/// Panics if `first == 0` (the `first - 1` slices underflow), if
/// `last >= dim`, or if `first..=last` does not fit inside `curr`.
/// `zip` stops at the shortest input, so if any other slice is shorter than
/// `last - first + 1` elements the remaining cells are silently left untouched.
fn do_the_thing(
curr: &mut[f64],
prev: &[f64],
prevprev: &[f64],
self_data_rev: &[f64],
other_data: &[f64],
first: usize,
last: usize,
dim: usize,
) {
// One zip chain over pre-offset slices instead of indexing, so there are no
// per-element bounds checks in the loop body.
let iter = curr[first .. last + 1].iter_mut()
.zip(&prev[first - 1 ..])
.zip(&prev[first ..])
.zip(&prevprev[first - 1 ..])
.zip(
self_data_rev[dim - 1 - last ..].iter()
.zip(&other_data[first ..])
);
for ((((curr, &left), &up), &upleft), (self_value, other_value)) in iter {
let dist = self_value - other_value;
let sqr_dist = dist * dist;
let val = min2(min2(left, up), upleft) + sqr_dist;
*curr = val
}
}
impl TimeSerie {
fn length(&self) -> usize {
self.data.len()
}
fn square_dist(&self, self_idx: usize, other_idx: usize, other: &Self) -> f64 {
let dif = self.data[self_idx] - other.data[other_idx];
dif * dif
}
fn compute_dtw(&self, other: &Self) -> f64 {
let dim = self.length();
// Three consecutive diagonals, each diagonal
// is indexed by column, from bottom-left to top-right.
let mut prevprev = vec![0f64; dim + 1];
let mut prev = vec![0f64; dim];
let mut curr = vec![0f64; dim];
let mut out = vec![0f64; dim];
assert!(dim >= 3);
prev[0] = self.square_dist(0, 0, other);
curr[0] = prev[0] + self.square_dist(1, 0, other);
curr[1] = prev[0] + self.square_dist(0, 1, other);
// Calculate the reverse of self.data, so it's easily vectorized.
let self_data_rev: Vec<f64> = self.data.iter().rev().cloned().collect();
let mid_diagonal_idx = dim - 1;
// --- Compute DTW
for i_diagonal in 2..(2 * dim - 1) {
{
let tmp = prevprev;
prevprev = prev;
prev = curr;
curr = tmp;
}
// Column indices on current diagonal.
// Note that in the lower triangle of the matrix,
// we use the "right" part of arrays to
// store the values.
let (first, last) = if i_diagonal <= mid_diagonal_idx {
(0, i_diagonal)
} else {
(i_diagonal - mid_diagonal_idx, dim - 1)
};
// Shrink the range, so we're left only with what
// can be computed in the main loop.
let (first, last) = if i_diagonal <= mid_diagonal_idx {
// Compute the first and last element separately,
// they have only 1 dependency instead of regular 3.
curr[0] = prev[0] + self.square_dist(last, 0, other);
curr[last] = prev[last - 1] + self.square_dist(0, last, other);
(first + 1, last - 1)
} else {
(first, last)
};
do_the_thing(
&mut curr,
&prev,
&prevprev,
&self_data_rev,
&other.data,
first,
last,
dim,
);
}
curr[dim - 1]
}
// --- --- --- static functions
fn new(data: Vec<f64>) -> TimeSerie {
TimeSerie { data: data }
}
}
/// Returns `x` when `x < y`, and `y` in every other case.
///
/// Unlike `f64::min`, NaNs are not skipped. If `x` is NaN the result is `y`,
/// and if `y` is NaN the result is NaN. For equal values (including `-0.0`
/// vs `0.0`) the second argument wins.
fn min2(x: f64, y: f64) -> f64 {
    let x_is_smaller = x < y;
    if x_is_smaller {
        return x;
    }
    y
}
/// Declares the command line and parses the process arguments against it.
///
/// The only argument is the required positional `INPUT FILE` (the CSV to
/// load), which `main` reads back with `value_of("INPUT FILE")`.
fn check_cli<'a>() -> clap::ArgMatches<'a> {
    let input_file = clap::Arg::with_name("INPUT FILE")
        .required(true)
        .index(1)
        .help("Input file, must be a csv");
    let app = clap::App::new("ts")
        .version("0.0")
        .about("Working with time series")
        .arg(input_file);
    app.get_matches()
}
/// Loads every time series from the CSV named on the command line, sums
/// `compute_dtw` over all pairs `(i, j)` with `j >= i`, and prints the elapsed
/// time followed by the total.
///
/// Each CSV record is `class, x0, x1, ...`; the class column is ignored.
/// Unreadable input aborts with a message naming the file, record and value.
fn main() {
    // --- 0: Get the command line arguments
    let matches = check_cli();
    // clap enforces `required(true)`, so a missing value here would be a bug.
    let file = matches
        .value_of("INPUT FILE")
        .expect("INPUT FILE is a required argument");

    // --- 1: Load the CSV
    let mut rdr = csv::Reader::from_file(file)
        .unwrap_or_else(|e| panic!("cannot open {}: {}", file, e));
    let mut series: Vec<TimeSerie> = Vec::new();
    for (record_idx, record) in rdr.records().enumerate() {
        let record_no = record_idx + 1;
        let row = record
            .unwrap_or_else(|e| panic!("cannot read CSV data record #{}: {}", record_no, e));
        let mut fields = row.into_iter();
        let _class: String = fields
            .next()
            .unwrap_or_else(|| panic!("CSV data record #{} is empty", record_no));
        let data: Vec<f64> = fields
            .map(|s| {
                s.parse::<f64>().unwrap_or_else(|e| {
                    panic!("CSV data record #{}: invalid number {:?}: {}", record_no, s, e)
                })
            })
            .collect();
        series.push(TimeSerie::new(data));
    }

    // --- 2: Compute sum of DTW
    // Instant is monotonic, so unlike SystemTime the measurement cannot fail
    // (and silently vanish) if the wall clock is adjusted mid-run.
    let mut total_e = 0f64;
    let start = std::time::Instant::now();
    for (id, vi) in series.iter().enumerate() {
        // skip(id) keeps the (vi, vi) pair as well as every later series.
        for vj in series.iter().skip(id) {
            total_e += vi.compute_dtw(vj);
        }
    }
    let elapsed = start.elapsed();
    println!(
        "{} s {} ms",
        elapsed.as_secs(),
        elapsed.subsec_nanos() / 1_000_000
    );
    println!("Total error: {}", total_e);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment