Skip to content

Instantly share code, notes, and snippets.

@catnipan
Last active January 4, 2021 19:53
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save catnipan/07a60ba3420ef496744b9e4c3c60733c to your computer and use it in GitHub Desktop.
Save catnipan/07a60ba3420ef496744b9e4c3c60733c to your computer and use it in GitHub Desktop.
recurrence — a Rust macro example
// Credit to https://danielkeep.github.io/practical-intro-to-macros.html

/// Counts the comma-separated expressions passed to it, expanding to a flat
/// integer sum such as `0 + 1 + 1` that can initialise a `const`.
///
/// Each expression is replaced by `1` via a one-level helper rule instead of
/// recursing once per element, so expansion depth does not grow with the
/// number of arguments. The expressions are only counted, never evaluated.
macro_rules! count_exprs {
    // Internal rule: swap a single expression for the literal `1`.
    (@one $_e:expr) => { 1 };
    ($($e:expr),*) => { 0 $(+ count_exprs!(@one $e))* };
}

/// Builds an infinite iterator over a sequence defined by a recurrence.
///
/// Syntax: `recurrence![a[n]: T = <expr>, init_0, init_1, ..., init_{k-1}]`
///
/// * `a` and `n` are names bound inside `<expr>`: `n` is the (`usize`) index
///   of the term being computed and `a[i]` reads an earlier term.
/// * `T` is the element type; it must implement `Clone` (terms are cloned
///   out of the look-behind window) and `Debug` (the iterator derives it).
/// * The `k` initial values are yielded first. Every later term is the value
///   of `<expr>`, which may only look back at `a[n - k] ..= a[n - 1]`.
///
/// # Panics
/// The iterator panics if `<expr>` indexes a term outside that window, e.g.
/// `a[n]` itself or more than `k` terms back.
macro_rules! recurrence {
    ( $seq:ident [ $ind:ident ]: $sty:ty = $recur:expr, $($inits:expr),+ ) => {
        {
            // Number of initial values == size of the look-behind window.
            const MEMORY: usize = count_exprs!($($inits),+);

            #[derive(Debug)]
            struct Recurrence {
                // The last MEMORY terms, oldest first.
                mem: [$sty; MEMORY],
                // Index of the next term to yield.
                pos: usize,
            }

            // Read-only view that maps absolute sequence indices onto `mem`.
            struct IndexOffset<'a> {
                slice: &'a [$sty; MEMORY],
                // Absolute index of the term currently being computed.
                offset: usize,
            }

            impl<'a> ::std::ops::Index<usize> for IndexOffset<'a> {
                type Output = $sty;

                #[inline]
                fn index(&self, index: usize) -> &$sty {
                    // Only a[offset - MEMORY] ..= a[offset - 1] are retained; fail
                    // loudly instead of via a subtraction overflow or a wrapped
                    // out-of-bounds index.
                    assert!(
                        index < self.offset && self.offset - index <= MEMORY,
                        "recurrence index {} is outside the retained window {}..{}",
                        index,
                        self.offset.saturating_sub(MEMORY),
                        self.offset
                    );
                    &self.slice[MEMORY - (self.offset - index)]
                }
            }

            impl Iterator for Recurrence {
                type Item = $sty;

                #[inline]
                fn next(&mut self) -> Option<$sty> {
                    if self.pos < MEMORY {
                        // Still replaying the initial values.
                        let next_val = self.mem[self.pos].clone();
                        self.pos += 1;
                        Some(next_val)
                    } else {
                        let next_val = {
                            let $ind = self.pos;
                            let $seq = IndexOffset { slice: &self.mem, offset: $ind };
                            $recur
                        };
                        // Slide the window: drop the oldest term, append the new one.
                        self.mem.rotate_left(1);
                        self.mem[MEMORY - 1] = next_val.clone();
                        self.pos += 1;
                        Some(next_val)
                    }
                }
            }

            Recurrence { mem: [$($inits),+], pos: 0 }
        }
    };
}
/// Exercises `recurrence!` on three sequences and checks the first ten terms
/// of each against hand-computed values.
fn main() {
    // Fibonacci: every term is the sum of the previous two.
    let fibonacci: Vec<u64> = recurrence![a[n]: u64 = a[n-1] + a[n-2], 0, 1]
        .take(10)
        .collect();
    assert_eq!(fibonacci, vec![0, 1, 1, 2, 3, 5, 8, 13, 21, 34]);

    // Factorial: n! = n * (n - 1)!, seeded with 0! = 1.
    let factorials: Vec<usize> = recurrence![a[n]: usize = n * a[n-1], 1]
        .take(10)
        .collect();
    assert_eq!(factorials, vec![1, 1, 2, 6, 24, 120, 720, 5040, 40320, 362880]);

    // Alternating rule seeded with 3: add one at even indices, double at odd ones.
    let alternating: Vec<usize> = recurrence![a[n]: usize = match n % 2 {
        0 => a[n - 1] + 1,
        _ => a[n - 1] * 2,
    }, 3]
        .take(10)
        .collect();
    assert_eq!(alternating, vec![3, 6, 7, 14, 15, 30, 31, 62, 63, 126]);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment