Skip to content

Instantly share code, notes, and snippets.

@jeanmidevacc
Last active January 31, 2022 14:32
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save jeanmidevacc/a112d504a2c85b29839c0d040a19ab84 to your computer and use it in GitHub Desktop.
Save jeanmidevacc/a112d504a2c85b29839c0d040a19ab84 to your computer and use it in GitHub Desktop.
import pandas as pd
from kats import models
# Select the base forecasting model
def build_model(model, kts):
    """Return an unfitted Kats forecasting model built around ``kts``.

    Args:
        model: Name of the model family: ``"prophet"``, ``"theta"``,
            ``"holtwinters"``, ``"linear"`` or ``"quadratic"``. Any other
            value prints a notice and falls back to Prophet.
        kts: Series passed straight to the model constructor; presumably a
            Kats ``TimeSeriesData`` (TODO confirm against callers).

    Returns:
        A Kats model instance constructed with that family's default
        ``*Params()``. It still has to be ``fit()`` before predicting.
    """
    if model == "prophet":
        return models.prophet.ProphetModel(kts, params=models.prophet.ProphetParams())
    if model == "theta":
        return models.theta.ThetaModel(kts, params=models.theta.ThetaParams())
    if model == "holtwinters":
        return models.holtwinters.HoltWintersModel(
            kts, params=models.holtwinters.HoltWintersParams()
        )
    if model == "linear":
        return models.linear_model.LinearModel(
            kts, params=models.linear_model.LinearModelParams()
        )
    if model == "quadratic":
        return models.quadratic_model.QuadraticModel(
            kts, params=models.quadratic_model.QuadraticModelParams()
        )

    print(f"{model} is not a right input , using default prophet")
    # Only ``models`` is imported from kats, so the default must be reached
    # through it, exactly like the "prophet" branch above.
    return models.prophet.ProphetModel(kts, params=models.prophet.ProphetParams())
# Load the data.
# NOTE(review): `folder` is not defined anywhere in this file; it must be set
# before this line runs (it is string-concatenated with the relative path,
# so presumably it ends with a path separator) — TODO confirm.
dfp_data = pd.read_csv(folder + "kaggle_playground_january_2022/train.csv")

# One label per country/store/product combination, joined with "-".
# Vectorised Series concatenation instead of a row-wise apply(axis=1): same
# strings for string columns, no Python call per row, and it still yields a
# (possibly empty) Series when the frame has no rows.
dfp_data["category"] = (
    dfp_data["country"] + "-" + dfp_data["store"] + "-" + dfp_data["product"]
)
categories = dfp_data["category"].unique()

# Build the category-specific TimeSeriesData (first category only).
category = categories[0]
# NOTE(review): `build_kats_timeserie` is not defined in this file; presumably
# it turns the filtered frame's "date"/"num_sold" columns into a Kats
# TimeSeriesData — TODO confirm where it comes from.
kts_category = build_kats_timeserie(
    dfp_data[dfp_data["category"] == category], "date", "num_sold"
)

# Train a forecasting model.
model = build_model("theta", kts_category)
model.fit()

# Forecast 365 daily steps ahead.
dfp_forecast = model.predict(steps=365, freq="D")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment