Skip to content

Instantly share code, notes, and snippets.

@SaremS
Created May 9, 2023 15:07
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save SaremS/3d801020cd1083224acc97748a6c289d to your computer and use it in GitHub Desktop.
Save SaremS/3d801020cd1083224acc97748a6c289d to your computer and use it in GitHub Desktop.
from copy import deepcopy

import numpy as np
import pandas as pd
from neuralforecast import NeuralForecast
from neuralforecast.models import NBEATS, NHITS
# Fit NBEATS and NHITS on `train`, forecast the length of `test`, and compare
# their RMSE against the simple baseline `result_mean`.
#
# NOTE(review): `train`, `test` and `result_mean` are defined outside this
# snippet. This code assumes `train`/`test` are a single time series (a Series
# or one-column DataFrame) indexed by timestamp, and that `result_mean` holds
# one baseline prediction per row of `test` -- confirm against the notebook.


def _to_nixtla_frame(series):
    """Return *series* in the long format NeuralForecast expects.

    The index becomes the ``ds`` column, the values become ``y``, and every
    row gets the same ``unique_id`` because only one series is modelled.
    """
    frame = pd.DataFrame(series).reset_index()
    # Fails loudly (length mismatch) if `series` has more than one value
    # column, instead of silently mislabelling data.
    frame.columns = ["ds", "y"]
    frame["unique_id"] = np.ones(len(frame))
    return frame


def _rmse(actual, predicted):
    """Return the root-mean-squared error between two equal-length sequences.

    Both inputs are flattened to 1-D first: subtracting an (n,) array from an
    (n, 1) array (e.g. a one-column ``DataFrame.values``) would otherwise
    broadcast to (n, n) and yield a meaningless RMSE.

    Raises:
        ValueError: if the two inputs contain a different number of values.
    """
    actual = np.asarray(actual, dtype=float).ravel()
    predicted = np.asarray(predicted, dtype=float).ravel()
    if actual.shape != predicted.shape:
        raise ValueError(
            f"length mismatch: {actual.size} actual vs "
            f"{predicted.size} predicted values"
        )
    return float(np.sqrt(np.mean((actual - predicted) ** 2)))


train_nxt = _to_nixtla_frame(train)
test_nxt = _to_nixtla_frame(test)

# Forecast exactly as many steps as the hold-out set contains, using a
# look-back window of twice that length.
horizon = len(test_nxt)
models = [
    NBEATS(input_size=2 * horizon, h=horizon, max_epochs=50),
    NHITS(input_size=2 * horizon, h=horizon, max_epochs=50),
]
# NOTE(review): 'M' is the pandas month-end alias; pandas >= 2.2 deprecates it
# in favour of 'ME'. It should match the spacing of `train`'s timestamps.
nf = NeuralForecast(models=models, freq="M")
nf.fit(df=train_nxt)

# Prediction columns are named after the model classes.
Y_hat_df = nf.predict().reset_index()
nbeats = Y_hat_df["NBEATS"]
nhits = Y_hat_df["NHITS"]

# Forecasts are compared with `test` position by position, so each one must
# contain exactly `horizon` values; `_rmse` enforces that.
actual = test_nxt["y"]
rmse_simple = _rmse(actual, result_mean)
rmse_nbeats = _rmse(actual, nbeats)
rmse_nhits = _rmse(actual, nhits)

rmse_table = pd.DataFrame(
    [rmse_simple, rmse_nbeats, rmse_nhits],
    index=["Simple", "NBEATS", "NHITS"],
    columns=["RMSE"],
)
rmse_table  # displayed when this is the last expression of a notebook cell
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment