Skip to content

Instantly share code, notes, and snippets.

@srishanbhattarai
Created March 13, 2020 08:20
Show Gist options
  • Save srishanbhattarai/66fed4241a304ddca77902ad25e7e71d to your computer and use it in GitHub Desktop.
Save srishanbhattarai/66fed4241a304ddca77902ad25e7e71d to your computer and use it in GitHub Desktop.
Basic Markov chain in Rust (clones data around, so it is not the most performant, but it gets the job done)
use std::collections::HashMap;
use std::error::Error;
use std::fs::File;
use std::io::{BufRead, BufReader};
use std::path::Path;
/// A word-level Markov chain whose states (contexts) are `order` consecutive words.
#[derive(Debug, Clone)]
struct MarkovChain {
    /// Number of consecutive words that make up one context.
    order: usize,
    /// For each context, the words observed immediately after it and how often.
    chain: HashMap<Vec<String>, HashMap<String, usize>>,
    /// For each context, the total number of transitions recorded from it.
    freqs: HashMap<Vec<String>, usize>,
}

impl Default for MarkovChain {
    /// Returns an empty first-order chain (one-word contexts).
    fn default() -> Self {
        MarkovChain::with_order(1)
    }
}

impl MarkovChain {
    /// Creates an empty chain whose contexts are `order` words long.
    fn with_order(order: usize) -> Self {
        MarkovChain {
            order,
            chain: HashMap::new(),
            freqs: HashMap::new(),
        }
    }

    /// Records every `order`-word context -> next-word transition found in `s`.
    ///
    /// Words are separated by runs of whitespace, so repeated, leading or
    /// trailing spaces never produce empty-string words. A sentence with
    /// `order` or fewer words contains no transition and leaves the chain
    /// unchanged.
    fn train_sentence(&mut self, s: String) {
        let words: Vec<String> = s.split_whitespace().map(String::from).collect();

        // Each window is one context followed by the word that came after it.
        // `windows` yields nothing when there are too few words, so short
        // sentences need no special case (and no `len - order` underflow).
        // `saturating_add` only matters for the absurd `order == usize::MAX`.
        for window in words.windows(self.order.saturating_add(1)) {
            let context = &window[..self.order];
            let next = window[self.order].clone();

            *self
                .chain
                .entry(context.to_vec())
                .or_default()
                .entry(next)
                .or_default() += 1;
            *self.freqs.entry(context.to_vec()).or_default() += 1;
        }
    }

    /// Trains the chain on every line of the file at `p`, one sentence per line.
    ///
    /// # Errors
    ///
    /// Returns an error if the file cannot be opened or a line cannot be read
    /// (for example, when the file is not valid UTF-8). Lines read before the
    /// failing one have already been added to the chain.
    pub fn train_file(&mut self, p: &Path) -> Result<(), Box<dyn Error>> {
        let reader = BufReader::new(File::open(p)?);
        for line in reader.lines() {
            self.train_sentence(line?);
        }
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Trains a default (first-order) chain from a small fixture file.
    ///
    /// The fixture is written to the system temp directory, so the test does
    /// not depend on a `test.txt` being present in the working directory.
    #[test]
    fn it_works() {
        let path = std::env::temp_dir()
            .join(format!("markov_it_works_{}.txt", std::process::id()));
        std::fs::write(&path, "x y\nx z\n").expect("failed to write test fixture");

        let mut chain: MarkovChain = Default::default();
        let result = chain.train_file(&path);
        // Remove the fixture before asserting so a failure does not leak it.
        let _ = std::fs::remove_file(&path);
        assert!(result.is_ok(), "training failed: {:?}", result.err());

        let x = vec!["x".to_string()];
        assert_eq!(chain.chain[&x]["y"], 1);
        assert_eq!(chain.chain[&x]["z"], 1);
        assert_eq!(chain.freqs[&x], 2);
    }
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment