# Gist BexTuychiev/fd0e5dd8b103f99877230cbdeb5bb9a4 (created May 14, 2024 13:56).
# GitHub notice: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than it appears. To review it, open the
# file in an editor that reveals hidden Unicode characters.
from pathlib import Path

import joblib
import pandas as pd
from google.cloud import storage
from sklearn.preprocessing import LabelEncoder
from xgboost import XGBClassifier
# Location of the training data and model artifact in Cloud Storage.
# URIs are built with f-strings rather than pathlib: Path("gs://bucket")
# collapses the "//" after the scheme and would yield an invalid URI.
gcs_path = "gs://vertex-tutorial-bucket-bex"
dataset_path = f"{gcs_path}/dry_bean.csv"
target_column = "Class"
artifact_filename = "model.joblib"

# Load the dataset. pandas opens gs:// URLs through fsspec, so the gcsfs
# package must be installed in the runtime environment.
beans = pd.read_csv(dataset_path)

# Every column except the target is used as a feature.
X = beans.drop(target_column, axis=1)

# XGBoost expects integer class labels, so encode the class names.
# NOTE(review): only the booster is saved below, not `le`, so the uploaded
# model predicts integer indices into `le.classes_`. Persist the encoder
# as well if consumers need the original class names.
le = LabelEncoder()
y = le.fit_transform(beans[target_column])

# Multi-class model that outputs one probability per class.
model = XGBClassifier(objective="multi:softprob", n_estimators=100)
model.fit(X, y)
print("Model trained successfully!")

# Serialize the raw Booster (not the sklearn wrapper) to the working
# directory. This local copy does not persist, which is why it is uploaded
# next. NOTE(review): the "model.joblib" name presumably targets a Vertex AI
# prebuilt XGBoost serving container; verify against the deployment setup.
local_path = artifact_filename
joblib.dump(model.get_booster(), local_path)

# Upload the artifact to the same bucket as the dataset.
storage_path = f"{gcs_path}/{artifact_filename}"
blob = storage.blob.Blob.from_string(storage_path, client=storage.Client())
blob.upload_from_filename(local_path)
print("Model uploaded successfully!")
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.