Skip to content

Instantly share code, notes, and snippets.

@matthewjberger
Created February 6, 2024 03:11
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save matthewjberger/ef5f02e619e33aff097f7def840c086e to your computer and use it in GitHub Desktop.
Save matthewjberger/ef5f02e619e33aff097f7def840c086e to your computer and use it in GitHub Desktop.
Fast Fourier transform (FFT) demonstration in Rust
use std::f32::consts::PI;
/// A complex number in Cartesian form, stored as two single-precision floats.
///
/// `Debug` and `PartialEq` are derived so values can be printed with `{:?}` and
/// compared directly (for example in `assert_eq!`); `Default` yields `0 + 0i`.
#[derive(Debug, Clone, Copy, PartialEq, Default)]
struct Complex {
    /// Real part.
    re: f32,
    /// Imaginary part.
    im: f32,
}
// Constructors and elementary arithmetic for `Complex`.
impl Complex {
    /// Builds a complex number from its real part `re` and imaginary part `im`.
    fn new(re: f32, im: f32) -> Self {
        Self { re, im }
    }

    /// Builds a complex number from a magnitude `r` and an angle `theta` in radians.
    fn from_polar(r: f32, theta: f32) -> Self {
        let real = r * theta.cos();
        let imag = r * theta.sin();
        Self::new(real, imag)
    }

    /// Returns `self + other`, computed component-wise.
    fn add(self, other: Complex) -> Complex {
        let Complex { re, im } = other;
        Self {
            re: self.re + re,
            im: self.im + im,
        }
    }

    /// Returns `self - other`, computed component-wise.
    fn sub(self, other: Complex) -> Complex {
        let Complex { re, im } = other;
        Self {
            re: self.re - re,
            im: self.im - im,
        }
    }

    /// Returns the complex product `self * other`:
    /// `(a + bi)(c + di) = (ac - bd) + (ad + bc)i`.
    fn mul(self, other: Complex) -> Complex {
        let (a, b) = (self.re, self.im);
        let (c, d) = (other.re, other.im);
        Self {
            re: a * c - b * d,
            im: a * d + b * c,
        }
    }
}
/// Computes the discrete Fourier transform of `signal`.
///
/// Uses the recursive radix-2 decimation-in-time split (Cooley-Tukey) whenever
/// the current length is even, and falls back to a direct O(n^2) DFT for odd
/// lengths greater than one. Power-of-two inputs therefore run entirely through
/// the radix-2 path, while every other length still yields a correct spectrum
/// instead of one with silently zeroed or wrong bins.
///
/// Returns a vector of the same length as `signal`, where element `k` is
/// `sum_j signal[j] * exp(-2*pi*i*j*k/n)` (no normalisation is applied).
/// Empty and single-element inputs are returned unchanged.
fn fft(signal: &[Complex]) -> Vec<Complex> {
    let n = signal.len();
    if n <= 1 {
        return signal.to_vec();
    }
    // The even/odd split below is only valid when both halves have equal length.
    if n % 2 != 0 {
        return dft_direct(signal);
    }

    let half = n / 2;
    let even: Vec<Complex> = signal.iter().step_by(2).copied().collect();
    let odd: Vec<Complex> = signal.iter().skip(1).step_by(2).copied().collect();
    let fft_even = fft(&even);
    let fft_odd = fft(&odd);

    let mut result = vec![Complex::new(0.0, 0.0); n];
    for k in 0..half {
        let twiddle = Complex::from_polar(1.0, -2.0 * PI * k as f32 / n as f32);
        // The butterfly uses the same product for both output bins; compute it once.
        let rotated = twiddle.mul(fft_odd[k]);
        result[k] = fft_even[k].add(rotated);
        result[k + half] = fft_even[k].sub(rotated);
    }
    result
}

/// Direct-summation DFT, used by `fft` for lengths that cannot be halved evenly.
fn dft_direct(signal: &[Complex]) -> Vec<Complex> {
    let n = signal.len();
    (0..n)
        .map(|k| {
            signal
                .iter()
                .enumerate()
                .fold(Complex::new(0.0, 0.0), |acc, (j, &sample)| {
                    // exp(-2*pi*i*j*k/n) is periodic in j*k with period n; reducing
                    // first keeps the angle within one turn for better f32 accuracy.
                    let step = (j * k) % n;
                    let twiddle =
                        Complex::from_polar(1.0, -2.0 * PI * step as f32 / n as f32);
                    acc.add(twiddle.mul(sample))
                })
        })
        .collect()
}
// Unit tests for the FFT implementation.
#[cfg(test)]
mod tests {
    use super::*;

    /// A real sine wave completing `freq` cycles over `n` samples must put all of
    /// its energy into bins `freq` and `n - freq`, each with magnitude `n / 2`.
    #[test]
    fn test_fft() {
        let n: usize = 8;
        let freq: f32 = 1.0;
        let signal: Vec<Complex> = (0..n)
            .map(|i| Complex::new((2.0 * PI * freq * i as f32 / n as f32).sin(), 0.0))
            .collect();

        let fft_result = fft(&signal);
        assert_eq!(fft_result.len(), n);

        let peak = n as f32 / 2.0;
        for (bin, c) in fft_result.iter().enumerate() {
            let magnitude = (c.re * c.re + c.im * c.im).sqrt();
            if bin == 1 || bin == n - 1 {
                // Bin 1 and its mirror image carry the whole signal.
                assert!(
                    (magnitude - peak).abs() < 1e-3,
                    "bin {} should peak at {}, got {}",
                    bin,
                    peak,
                    magnitude
                );
            } else {
                // Every other bin must be empty up to rounding error.
                assert!(
                    magnitude < 1e-3,
                    "bin {} should be empty, got {}",
                    bin,
                    magnitude
                );
            }
        }
    }
}
fn main() {
let n = 8;
let mut signal = vec![Complex::new(0.0, 0.0); n];
let freq = 1.0;
for i in 0..n {
signal[i] = Complex::new(((2.0 * PI * freq * i as f32 / n as f32).sin()), 0.0);
}
let fft_result = fft(&signal);
for (i, c) in fft_result.iter().enumerate() {
println!("{}: {} + {}i", i, c.re, c.im);
}
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment