Created
February 24, 2024 16:17
-
-
Save kujirahand/29155d11b6b2cb80dfe1a5df5b129791 to your computer and use it in GitHub Desktop.
k近傍法を使って、アヤメの分類を行うプログラム
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
// k近傍法(k-nn)によるアヤメの分類プログラム | |
use rand::seq::SliceRandom; | |
// データとラベルを持つ構造体を定義 --- (*1) | |
#[derive(Debug, Clone)] | |
struct KnnItem { | |
data: Vec<f64>, | |
label: String, | |
} | |
// k近傍法でデータを予測する --- (*2) | |
fn knn_predict(items: &[KnnItem], test: &[f64], k: usize) -> String { | |
// testとitemsの距離を求める --- (*3) | |
let mut distances = items.iter().enumerate().map(|(i, item)| { | |
(i, calc_distance(&item.data, test)) | |
}).collect::<Vec<_>>(); | |
// 距離が近い順にソート --- (*4) | |
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap()); | |
// 最も近いk個のラベルを取得 --- (*5) | |
let mut votes = std::collections::HashMap::new(); | |
distances.iter().take(k).for_each(|(i, _distance)| { | |
let label = &items[*i].label; | |
// println!(" - {}: distance={}", label, distance); | |
*votes.entry(label).or_insert(0) += 1; | |
}); | |
// 最も多いラベルを返す --- (*6) | |
let label = votes.into_iter().max_by_key(|&(_, count)| count).unwrap().0; | |
label.to_string() | |
} | |
// ユークリッド距離を求める --- (*7) | |
fn calc_distance(p1: &[f64], p2: &[f64]) -> f64 { | |
let mut distance = 0.0; | |
for (i, d) in p1.iter().enumerate() { | |
distance += (d - p2[i]).powi(2); | |
} | |
distance.sqrt() | |
} | |
// 複数データを一度に予測 --- (*8) | |
fn knn_predict_all(items: &[KnnItem], tests: &[Vec<f64>], k: usize) -> Vec<String> { | |
tests.iter().map(|test| knn_predict(items, test, k)).collect() | |
} | |
fn main() { | |
// アヤメの分類データのCSVを読み込む --- (*9) | |
let text = std::fs::read_to_string("iris.csv").unwrap(); | |
// CSVを行に分割し、各行をカンマで分割して、KnnItemに変換 --- (*10) | |
let mut items: Vec<KnnItem> = vec![]; | |
for (i, line) in text.lines().enumerate() { | |
if i == 0 { continue; } // ヘッダをスキップ | |
if line.trim().is_empty() { continue; } // 空行ならスキップ | |
// カンマで分割してVec<f64>に変換 --- (*11) | |
let parts: Vec<&str> = line.split(',').collect(); | |
let (cols, label) = parts.split_at(4); // 4:1に分ける | |
let cols: Vec<f64> = cols.into_iter().map(|s| s.trim().parse::<f64>().unwrap()).collect(); | |
items.push(KnnItem { data: cols, label: label[0].trim().to_string() }); | |
} | |
// データを評価用とテスト用に分ける --- (*12) | |
items.shuffle(&mut rand::thread_rng()); | |
let (train, test) = items.split_at(100); // 100:50に分ける | |
let test_x = test.iter().map(|item| item.data.clone()).collect::<Vec<_>>(); | |
// テスト用データを使って正解率を求める --- (*13) | |
let k = 7; | |
let test_y = knn_predict_all(&train, &test_x, k); | |
// 正解率を調べる | |
let ok = test.iter().zip(test_y.iter()).filter(|(item, label)| item.label == **label).count(); | |
let accuracy = ok as f64 / test.len() as f64; | |
println!("正解率: {}/{} = {}", ok, test.len(), accuracy); | |
// 適当なデータを与えてアヤメを予測する --- (*14) | |
let test_data = vec![5.9, 2.5, 4.4, 1.2]; | |
let label = knn_predict(&items, &test_data, k); | |
println!("{:?} => {}", test_data, label); | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment