Skip to content

Instantly share code, notes, and snippets.

@arm2arm
Created February 23, 2023 22:04
Show Gist options
  • Save arm2arm/3445a50e8020c9774a262a280c901c9f to your computer and use it in GitHub Desktop.
Save arm2arm/3445a50e8020c9774a262a280c901c9f to your computer and use it in GitHub Desktop.
# Import required libraries
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from sklearn.model_selection import train_test_split
import xgboost as xgb
import numpy as np
from tensorflow.keras.applications.vgg16 import VGG16
from tensorflow.keras.preprocessing.image import ImageDataGenerator
def _extract_flat_features(model, images):
    """Run *model* over *images* and flatten each output to a 1-D feature vector.

    Args:
        model: Keras model used as a frozen feature extractor.
        images: Image batch already preprocessed for *model*.

    Returns:
        A 2-D NumPy array of shape ``(len(images), n_features)``.
    """
    features = model.predict(images)
    return np.reshape(features, (features.shape[0], -1))


# Load CIFAR10 dataset (labels are returned as (N, 1) column vectors).
(x_train, y_train), (x_test, y_test) = cifar10.load_data()
# Flatten labels to 1-D so they compare element-wise with predictions;
# comparing an (N, 1) array with an (N,) array would broadcast to (N, N).
y_train = y_train.ravel()
y_test = y_test.ravel()
# Preprocess data the way VGG16's ImageNet weights expect (RGB->BGR and
# per-channel ImageNet mean subtraction on 0-255 values), not a /255 rescale.
preprocess = tf.keras.applications.vgg16.preprocess_input
x_train = preprocess(x_train.astype('float32'))
x_test = preprocess(x_test.astype('float32'))
# Split data into train and validation sets; a fixed seed and stratification
# keep the split reproducible and class-balanced.
x_train, x_val, y_train, y_val = train_test_split(
    x_train, y_train, test_size=0.2, random_state=42, stratify=y_train
)
# Define VGG16 convolutional base (no classifier head) for feature extraction.
vgg16_model = VGG16(include_top=False, weights='imagenet', input_shape=(32, 32, 3))
# Extract flattened features from train, validation, and test sets.
train_features = _extract_flat_features(vgg16_model, x_train)
val_features = _extract_flat_features(vgg16_model, x_val)
test_features = _extract_flat_features(vgg16_model, x_test)
# Define XGBoost model (learning_rate is the sklearn-API name for eta).
xgb_model = xgb.XGBClassifier(
    objective='multi:softmax', num_class=10, max_depth=6, learning_rate=0.3
)
# Train XGBoost model on extracted features.
xgb_model.fit(train_features, y_train)
# Evaluate XGBoost model on validation set (mean accuracy).
accuracy = xgb_model.score(val_features, y_val)
print(f"Validation accuracy: {accuracy}")
# Predict on test set.
y_pred = xgb_model.predict(test_features)
# Evaluate XGBoost model on test set: fraction of correctly labelled images.
accuracy = np.mean(y_pred == y_test)
print(f"Test accuracy: {accuracy}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment