Skip to content

Instantly share code, notes, and snippets.

@kujirahand
Created February 24, 2024 16:17
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save kujirahand/29155d11b6b2cb80dfe1a5df5b129791 to your computer and use it in GitHub Desktop.
Save kujirahand/29155d11b6b2cb80dfe1a5df5b129791 to your computer and use it in GitHub Desktop.
k近傍法を使って、アヤメの分類を行うプログラム
// k近傍法(k-nn)によるアヤメの分類プログラム
use rand::seq::SliceRandom;
// データとラベルを持つ構造体を定義 --- (*1)
#[derive(Debug, Clone)]
struct KnnItem {
data: Vec<f64>,
label: String,
}
// k近傍法でデータを予測する --- (*2)
fn knn_predict(items: &[KnnItem], test: &[f64], k: usize) -> String {
// testとitemsの距離を求める --- (*3)
let mut distances = items.iter().enumerate().map(|(i, item)| {
(i, calc_distance(&item.data, test))
}).collect::<Vec<_>>();
// 距離が近い順にソート --- (*4)
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());
// 最も近いk個のラベルを取得 --- (*5)
let mut votes = std::collections::HashMap::new();
distances.iter().take(k).for_each(|(i, _distance)| {
let label = &items[*i].label;
// println!(" - {}: distance={}", label, distance);
*votes.entry(label).or_insert(0) += 1;
});
// 最も多いラベルを返す --- (*6)
let label = votes.into_iter().max_by_key(|&(_, count)| count).unwrap().0;
label.to_string()
}
// ユークリッド距離を求める --- (*7)
fn calc_distance(p1: &[f64], p2: &[f64]) -> f64 {
let mut distance = 0.0;
for (i, d) in p1.iter().enumerate() {
distance += (d - p2[i]).powi(2);
}
distance.sqrt()
}
// 複数データを一度に予測 --- (*8)
fn knn_predict_all(items: &[KnnItem], tests: &[Vec<f64>], k: usize) -> Vec<String> {
tests.iter().map(|test| knn_predict(items, test, k)).collect()
}
fn main() {
// アヤメの分類データのCSVを読み込む --- (*9)
let text = std::fs::read_to_string("iris.csv").unwrap();
// CSVを行に分割し、各行をカンマで分割して、KnnItemに変換 --- (*10)
let mut items: Vec<KnnItem> = vec![];
for (i, line) in text.lines().enumerate() {
if i == 0 { continue; } // ヘッダをスキップ
if line.trim().is_empty() { continue; } // 空行ならスキップ
// カンマで分割してVec<f64>に変換 --- (*11)
let parts: Vec<&str> = line.split(',').collect();
let (cols, label) = parts.split_at(4); // 4:1に分ける
let cols: Vec<f64> = cols.into_iter().map(|s| s.trim().parse::<f64>().unwrap()).collect();
items.push(KnnItem { data: cols, label: label[0].trim().to_string() });
}
// データを評価用とテスト用に分ける --- (*12)
items.shuffle(&mut rand::thread_rng());
let (train, test) = items.split_at(100); // 100:50に分ける
let test_x = test.iter().map(|item| item.data.clone()).collect::<Vec<_>>();
// テスト用データを使って正解率を求める --- (*13)
let k = 7;
let test_y = knn_predict_all(&train, &test_x, k);
// 正解率を調べる
let ok = test.iter().zip(test_y.iter()).filter(|(item, label)| item.label == **label).count();
let accuracy = ok as f64 / test.len() as f64;
println!("正解率: {}/{} = {}", ok, test.len(), accuracy);
// 適当なデータを与えてアヤメを予測する --- (*14)
let test_data = vec![5.9, 2.5, 4.4, 1.2];
let label = knn_predict(&items, &test_data, k);
println!("{:?} => {}", test_data, label);
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment