This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class Baseline(tf.keras.Model): | |
def __init__(self, label_index=None): | |
super().__init__() | |
self.label_index = label_index | |
def call(self, inputs): | |
if self.label_index is None: | |
return inputs | |
result = inputs[:, :, self.label_index] | |
# Convert the labels to one hot for the loss functions to be able to calculate the loss correctly. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def plot(self, model=None, plot_col=REGRESSORS[0], label_col=LABELS[0], max_subplots=3): | |
inputs, labels = self.example | |
plt.figure(figsize=(12, 8)) | |
# Earlier this code was inside the for loop. This code does not change based on the iteration. So I moved it outside the loop. | |
plot_col_index = self.column_indices[plot_col] | |
max_n = min(max_subplots, len(inputs)) | |
if self.label_columns: | |
label_col_index = self.label_columns_indices.get(label_col, None) | |
else: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def update_datasets(self:MultiSeriesWindowsGenerator, train_df:pd.DataFrame, val_df:pd.DataFrame, test_df:pd.DataFrame, norm:bool=False): | |
# Store the raw data. We now get the features and targets separately from preprocess_dataset. | |
# We will need to concat them together later to be backward compatible. | |
train_df, train_targets = self.preprocess_dataset(train_df) | |
val_df, val_targets = self.preprocess_dataset(val_df) | |
test_df, test_targets = self.preprocess_dataset(test_df) | |
# We should not normalize the targets. This is the reason why we had to separate the features and targets in the previous step. |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def preprocess_dataset(self:MultiSeriesWindowsGenerator, data:pd.DataFrame): | |
try: | |
if np.vstack(data.index).shape[1] != 1: | |
data = data.reset_index() | |
by = self.GROUPBY + [DATE] | |
labels = self.regressor_columns + self.static_columns | |
data = data.set_index(by).unstack(-1) | |
features = tf.stack([data[label] for label in labels], axis=-1) | |
targets = tf.stack([data[label] for label in self.label_columns], axis=-1) |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Stack the three source frames vertically and renumber the rows 0..n-1
# (drop=True discards the old per-frame index instead of keeping it as a column).
series = pd.concat(
    [series_1, series_2, series_3],
    axis=0,
).reset_index(drop=True)

# Discretize every label column in place: pd.qcut splits the column into 4
# equal-frequency buckets labelled 0..3, and the categorical result is cast
# to float so the column becomes a plain numeric target.
for label in LABELS:
    series[label] = (
        pd.qcut(series[label], q=4, labels=[0, 1, 2, 3])
        .astype(float)
    )
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def print_t2(x="t2"):
    """Print a one-line trace message of the form "<x> called" to stdout.

    Args:
        x: Value shown in front of " called"; defaults to "t2".
    """
    message = "{} called".format(x)
    print(message)
def print_t5(x="t5"):
    """Print a one-line trace message "<x> called 594859.." to stdout.

    Args:
        x: Value shown at the start of the message; defaults to "t5".
    """
    message = "{} called 594859..".format(x)
    print(message)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[build-system]
# Reading the PEP 621 [project] table below requires setuptools 61.0 or newer.
requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"

[project]
# Core package metadata belongs in [project] (PEP 621). setuptools does not
# read a top-level [metadata] table from pyproject.toml (that section name is
# setup.cfg syntax), so name/version placed there would be silently ignored.
name = "project"
version = "0.0.1"

[tool.setuptools.packages]
find = {}  # Scan the project directory with the default parameters
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Let setuptools auto-discover packages by scanning the project directory
# with its default find parameters (an empty find table = all defaults).
[tool.setuptools.packages.find]
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[metadata] | |
name = "project" | |
version = "0.0.1" | |
[project] | |
name = "project" | |
version = "0.0.1" | |
authors = [ | |
{ name="your_name", email="your_email" }, |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
[build-system]
# PEP 517 backend that frontends such as pip call to build the sdist/wheel.
build-backend = "setuptools.build_meta"
# Packages installed into the isolated build environment before building.
requires = ["setuptools"]
Newer Older