Last active
September 5, 2023 23:02
-
-
Save Hehehe421/e11bc2aa7dc8b4a719b96fd68c3c95e2 to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import mlflow
from mlflow.tracking import MlflowClient
# Define the RankingModel class, which inherits from mlflow.pyfunc.PythonModel
class RankingModel(mlflow.pyfunc.PythonModel):
    """MLflow pyfunc model that ranks listings for a single region.

    ``predict`` reads the first value of ``model_input["region"]`` and returns
    the result of ``rank_listings_by_region`` for that region as a pandas
    DataFrame.
    """

    def predict(self, context, model_input):
        """Return ranked listings for the first region in *model_input*.

        Args:
            context: MLflow model context passed by the pyfunc runtime; not
                used by this model.
            model_input: Object supporting ``model_input["region"]``. This is
                typically a pandas DataFrame with a ``"region"`` column; a
                mapping of ``"region"`` to a sequence is also accepted. Only
                the first region value is used.

        Returns:
            The Spark DataFrame returned by ``rank_listings_by_region``,
            converted with ``.toPandas()``.
        """
        # Extract the region first so malformed input fails before a Spark
        # session is spun up.
        region = self._first_region(model_input)
        # NOTE(review): SparkSession and rank_listings_by_region are not
        # imported/defined in the visible source; this assumes they are in
        # scope (e.g. `from pyspark.sql import SparkSession` elsewhere) - confirm.
        spark = SparkSession.builder.appName('RankingModel').getOrCreate()
        return rank_listings_by_region(region, spark).toPandas()

    @staticmethod
    def _first_region(model_input):
        """Return the first value stored under ``model_input["region"]``."""
        regions = model_input["region"]
        # A pandas Series' ``[0]`` is label-based: it raises KeyError (or picks
        # the wrong row) when the index does not start at 0, and positional
        # fallback for non-integer indexes is deprecated. Use positional
        # ``.iloc`` when available; plain sequences are indexed directly.
        if hasattr(regions, "iloc"):
            return regions.iloc[0]
        return regions[0]
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment