Skip to content

Instantly share code, notes, and snippets.

@eloquentarduino
Created February 10, 2023 07:35
Show Gist options
  • Save eloquentarduino/7ca38d0f123941a6faf58fcf2b784139 to your computer and use it in GitHub Desktop.
Save eloquentarduino/7ca38d0f123941a6faf58fcf2b784139 to your computer and use it in GitHub Desktop.
import numpy as np
from principalfft import PrincipalFFT
from numpy.fft import rfft
from sklearn.model_selection import train_test_split
from sklearn.datasets import load_digits
from sklearn.ensemble import RandomForestClassifier
# Benchmark a random forest on the sklearn digits dataset with three feature sets:
#   1. the raw pixel values,
#   2. the magnitude of the full real FFT of each flattened image,
#   3. the subset of FFT components selected by PrincipalFFT (n_components=8).
#
# All three models are trained and scored on the SAME seeded train/test partition:
# sample indices are split once and reused for every feature matrix. The printed
# scores are therefore comparable and reproducible across runs.

RANDOM_STATE = 0  # seeds both the data split and the random forests
TEST_SIZE = 0.3

mnist = load_digits()
X, y = mnist.data, mnist.target

# Split row indices once (stratified, so class proportions match in both halves);
# every feature set below is sliced with these same rows.
train_idx, test_idx = train_test_split(
    np.arange(len(y)), test_size=TEST_SIZE, stratify=y, random_state=RANDOM_STATE
)
X_train, X_test = X[train_idx], X[test_idx]
y_train, y_test = y[train_idx], y[test_idx]

# rfft + abs is applied row by row and learns nothing from the data, so computing it
# on all of X before slicing is equivalent to computing it per split.
Xfft_full = np.abs(rfft(X))
Xfft_full_train, Xfft_full_test = Xfft_full[train_idx], Xfft_full[test_idx]

# PrincipalFFT exposes fit_transform, i.e. it is a fitted transformer. Fit it on the
# training rows only, so the test rows cannot influence which components it keeps
# (fitting on all of X would leak test data into the features).
# NOTE(review): relies on PrincipalFFT providing the usual fit/transform pair.
pfft = PrincipalFFT(n_components=8)
Xfft_train = pfft.fit_transform(X_train)
Xfft_test = pfft.transform(X_test)


def _score(features_train, labels_train, features_test, labels_test):
    """Fit the shared random-forest configuration and return its test accuracy."""
    model = RandomForestClassifier(
        n_estimators=50, min_samples_leaf=5, random_state=RANDOM_STATE
    )
    return model.fit(features_train, labels_train).score(features_test, labels_test)


print("Raw score", _score(X_train, y_train, X_test, y_test))
print("FFT (full) score", _score(Xfft_full_train, y_train, Xfft_full_test, y_test))
print("FFT (k=1/4) score", _score(Xfft_train, y_train, Xfft_test, y_test))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment