Skip to content

Instantly share code, notes, and snippets.

@Hehehe421
Last active September 6, 2023 02:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save Hehehe421/f76acf0a7a12031943c23e2176b4caf9 to your computer and use it in GitHub Desktop.
Save Hehehe421/f76acf0a7a12031943c23e2176b4caf9 to your computer and use it in GitHub Desktop.
import mlflow
import pandas as pd
from mlflow.models.signature import infer_signature
from mlflow.tracking import MlflowClient
# Smoke-test a RankingModel instance on one sample row, log it to MLflow as a
# pyfunc model (with an inferred signature and input example), then register
# it under REGISTERED_MODEL_NAME and move the new version to "Production".
#
# NOTE(review): RankingModel is not defined in this snippet; it is presumably
# an mlflow.pyfunc.PythonModel subclass defined earlier — confirm.

# Artifact path inside the run and the registry name. Kept in one place so the
# runs:/ URI, the registered model and the stage transition cannot drift apart.
ARTIFACT_PATH = "sql_model"
REGISTERED_MODEL_NAME = "sql_model"

# Create a sample test DataFrame (a single row). It is used to smoke-test
# predict() and also as the logged input example and signature source.
test_data = pd.DataFrame({
    'area': ['Any Region'],
    'user_id': [100],
})

# Initialize the MLflow client (used below for the stage transition).
client = MlflowClient()

# Create an instance of the RankingModel.
model = RankingModel()

# Start an MLflow run. Everything that logs into the run must stay inside
# this block.
with mlflow.start_run() as run:
    # Call predict() directly on the instance; no model context is supplied.
    prediction = model.predict(context=None, model_input=test_data)

    # Infer the model signature from the sample input and its prediction.
    signature = infer_signature(test_data, prediction)

    # Log the model artifact to MLflow under ARTIFACT_PATH in this run.
    mlflow.pyfunc.log_model(
        artifact_path=ARTIFACT_PATH,
        python_model=model,
        input_example=test_data,
        signature=signature,
    )

# Register the logged model. The run has ended, but its id is still available
# on `run`. Then move the new version to Production, archiving any versions
# previously in that stage.
mv = mlflow.register_model(
    f"runs:/{run.info.run_id}/{ARTIFACT_PATH}", REGISTERED_MODEL_NAME
)
# NOTE(review): model-version stages are deprecated in newer MLflow 2.x
# releases in favour of registered-model aliases
# (client.set_registered_model_alias); consider migrating once the deployed
# MLflow version is confirmed.
client.transition_model_version_stage(
    REGISTERED_MODEL_NAME, mv.version, "Production", archive_existing_versions=True
)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment