This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def train_model_from_tfrecords(tr_data_paths, val_data_paths, model):
    """Fit ``model`` on GZIP-compressed TFRecord files and print the wall time.

    The training records are shuffled with a 100_000-element buffer before
    batching; the validation records are only batched. Both streams are
    batched with the module-level ``batch_size`` and then passed through
    ``parse_tfrecord``. Training runs for the module-level ``n_epochs`` and
    forwards ``model_configs`` to ``model.fit`` as extra keyword arguments.

    Args:
        tr_data_paths: TFRecord file path(s) for training.
        val_data_paths: TFRecord file path(s) for validation.
        model: Object exposing a Keras-style ``fit`` method.

    Returns:
        None. The elapsed time (``time.perf_counter`` difference) is printed.
    """
    started_at = time.perf_counter()

    train_records = tf.data.TFRecordDataset(tr_data_paths, compression_type='GZIP')
    train_batches = train_records.shuffle(100_000).batch(batch_size)
    # Parsing is applied after batching, so parse_tfrecord receives whole batches.
    train_examples = train_batches.map(parse_tfrecord)

    val_records = tf.data.TFRecordDataset(val_data_paths, compression_type='GZIP')
    val_examples = val_records.batch(batch_size).map(parse_tfrecord)

    model.fit(
        train_examples,
        validation_data=val_examples,
        epochs=n_epochs,
        **model_configs,
    )

    elapsed = time.perf_counter() - started_at
    print(f"Elapsed time: {elapsed}")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# One named, single-value Keras input per feature column.
features = ['age', 'sales_channel_id', 'club_member_status', 'fashion_news_frequency']
inputs = [tf.keras.layers.Input(shape=(1,), name=f) for f in features]

# Merge the separate inputs into one tensor before the dense stack.
concat_input = tf.keras.layers.concatenate(inputs)

# Three stacked 4-unit ReLU layers; `dense` ends up as the last layer's output.
dense = concat_input
for _hidden_idx in range(3):
    dense = tf.keras.layers.Dense(4, activation='relu')(dense)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Configurations for tfrecords | |
""" | |
def parse_tfrecord(serialized_example): | |
# example = tf.io.parse_single_example(serialized_example, feature_desc) | |
# for batch | |
example = tf.io.parse_example(serialized_example, feature_desc) | |
# since it does not provide projection, need to store it w/o or exclude it manually | |
example.pop("t_dat") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def _bytes_feature(value):
    """Wrap ``value`` in a ``tf.train.Feature`` holding a one-element ``BytesList``.

    If ``value`` is an eager tensor (same type as ``tf.constant(0)``), it is
    converted with ``.numpy()`` first, because ``BytesList`` won't unpack a
    string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap ``value`` in a ``tf.train.Feature`` holding a one-element ``FloatList``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
def _int64_feature(value):
    """Wrap ``value`` in a ``tf.train.Feature`` holding a one-element ``Int64List``."""
    int64_list = tf.train.Int64List(value=[value])
    return tf.train.Feature(int64_list=int64_list)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def parse_dataset(e):
    """Convert one record ``e`` into an ``(X, y)`` pair.

    ``X`` is built by reading the attribute named by each entry of the
    module-level ``features`` list from ``e``, stacking them along axis 1 and
    casting the result to ``tf.float32``. ``y`` is the attribute named by the
    module-level ``target_col``, returned as-is.

    NOTE(review): stacking on axis 1 presumably means each attribute is a
    1-D batch of values -- confirm against the reader that produces ``e``.
    """
    columns = [getattr(e, name) for name in features]
    stacked = tf.stack(columns, axis=1)
    return tf.cast(stacked, tf.float32), getattr(e, target_col)
def train_model_from_petastorm(tr_parquet_path, val_parquet_path, model): | |
start_time = time.perf_counter() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
""" | |
Configurations for petastorm | |
""" | |
from petastorm import make_reader, make_batch_reader | |
from petastorm.tf_utils import make_petastorm_dataset | |
# Number of records per batch.
batch_size = 512
# Number of training epochs.
n_epochs = 10
# Names of the columns used as model features.
features = [
    'age',
    'sales_channel_id',
    'club_member_status',
    'fashion_news_frequency',
]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# A single 4-wide input feeding three 4-unit ReLU layers and a linear 1-unit head.
inputs = tf.keras.layers.Input(shape=(4,))

dense = inputs
for _hidden_idx in range(3):
    dense = tf.keras.layers.Dense(4, activation='relu')(dense)

outputs = tf.keras.layers.Dense(1, activation=None)(dense)
model = tf.keras.Model(inputs=inputs, outputs=outputs)

# Mean-squared-error loss with Adam (lr 0.001); RMSE is tracked as a metric.
model.compile(
    loss=tf.keras.losses.mean_squared_error,
    optimizer=tf.keras.optimizers.Adam(learning_rate=0.001),
    metrics=[tf.keras.metrics.RootMeanSquaredError()],
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
%%time | |
def _bytes_feature(value):
    """Wrap ``value`` in a ``tf.train.Feature`` holding a one-element ``BytesList``.

    If ``value`` is an eager tensor (same type as ``tf.constant(0)``), it is
    converted with ``.numpy()`` first, because ``BytesList`` won't unpack a
    string from an EagerTensor.
    """
    eager_tensor_type = type(tf.constant(0))
    raw = value.numpy() if isinstance(value, eager_tensor_type) else value
    bytes_list = tf.train.BytesList(value=[raw])
    return tf.train.Feature(bytes_list=bytes_list)
def _float_feature(value):
    """Wrap ``value`` in a ``tf.train.Feature`` holding a one-element ``FloatList``."""
    float_list = tf.train.FloatList(value=[value])
    return tf.train.Feature(float_list=float_list)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
def save_multi_parquets(df, folder_path, chunk_size):
    """Split ``df`` row-wise and write every piece as its own Parquet file.

    Files are named ``train_df_<idx>.parquet`` inside ``folder_path`` and are
    written without the DataFrame index.

    Args:
        df: DataFrame to split.
        folder_path: Destination directory; it is created when missing.
        chunk_size: Passed to ``np.array_split`` as its second argument, so
            despite the name an integer here is the NUMBER of files to
            produce, not the number of rows per file.

    Returns:
        None.
    """
    # Imported locally so the function does not rely on the module's import block.
    import os

    # Without this, to_parquet fails when the target folder does not exist yet.
    if folder_path:
        os.makedirs(folder_path, exist_ok=True)

    # Split row positions instead of the DataFrame itself: np.array_split on a
    # DataFrame goes through DataFrame.swapaxes, which pandas has deprecated.
    # The resulting chunk boundaries are the same.
    row_positions = np.arange(len(df))
    for idx, positions in enumerate(np.array_split(row_positions, chunk_size)):
        df.iloc[positions].to_parquet(f'{folder_path}/train_df_{idx}.parquet', index=False)
# Write the training frame twice under single_parquet_path:
# split 5 ways into 'optimal' and 200 ways into 'multi'.
for _subdir, _n_chunks in (('optimal', 5), ('multi', 200)):
    save_multi_parquets(train_df, os.path.join(single_parquet_path, _subdir), chunk_size=_n_chunks)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
self.mp_df_sorted = (self.train_df | |
.groupby("movieId") | |
.agg({"rating": ["mean", "count"]})['rating'] | |
.sort_values(["mean", "count"], ascending=False)) |
NewerOlder