This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# Cross-package call: t1 imports and invokes the helper defined in pkg2/t2.
from project.pkg2.t2 import print_t2

print_t2("t1")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def print_t2(x="t2"):
    """Print ``"<x> called"`` to standard output.

    Args:
        x: Label to report in the message; defaults to ``"t2"``.
    """
    print(f"{x} called")
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Reference code from https://stackoverflow.com/questions/49994496/mixing-multiple-tf-data-dataset
"""


def stack_windows(*windows):
    """Merge several ``(features, labels)`` windows into a single pair.

    Each positional argument is indexed as ``window[0]`` (features) and
    ``window[1]`` (labels). The features of all windows are concatenated
    along axis 0, and the labels are concatenated the same way.

    Args:
        *windows: Two-element ``(features, labels)`` pairs of tensors.
            NOTE(review): the features (and, separately, the labels)
            presumably share every dimension except the first, since
            ``tf.concat`` requires that. Confirm against the caller.

    Returns:
        tuple: ``(features, labels)`` holding the concatenated tensors.
    """
    features = tf.concat([window[0] for window in windows], 0)
    labels = tf.concat([window[1] for window in windows], 0)
    return (features, labels)
def make_dataset(self:MultiSeriesWindowGenerator, data:tf.Tensor) -> tf.data.Dataset: |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def preprocess_dataset(self:MultiSeriesWindowGenerator, data:pd.DataFrame): | |
try: | |
if np.vstack(data.index).shape[1] != 1: | |
data = data.reset_index() | |
by = self.GROUPBY + [DATE] | |
labels = self.label_columns + self.regressor_columns + self.static_columns | |
data = data.set_index(by).unstack(-1) | |
data = tf.stack([data[label] for label in labels], axis=-1) | |
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
class MultiSeriesWindowGenerator(): | |
def __init__(self, | |
input_width, label_width, shift, batch_size, | |
label_columns=[], GROUPBY=None, regressor_columns=[], static_columns=[] | |
): | |
self.batch_size = batch_size | |
# Work out the label column indices. | |
self.label_columns = label_columns |
NewerOlder