# Gist: Astroneko404/c4fa3bd1f8f88827b159fe7ca0a3960b
# Created: February 15, 2023 10:21
# Description: tensorflow removed
# (Save this gist to your computer and use it in GitHub Desktop.)
#
# Note from GitHub: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
# (Learn more about bidirectional Unicode characters on GitHub.)
import re

# import tensorflow as tf

from text import symbols
class AttrDict(dict):
    """A ``dict`` whose entries can also be read and written as attributes.

    The instance ``__dict__`` is pointed at the mapping itself, so
    ``obj.key`` and ``obj["key"]`` refer to the same storage: a value set
    through either form is visible through the other.
    """

    def __init__(self, *positional, **named):
        # Fill the mapping exactly as a plain dict constructor would.
        super().__init__(*positional, **named)
        # Alias attribute storage to the mapping so both access styles agree.
        self.__dict__ = self
def _split_hparams_string(hparams_string):
    """Split an override string into its non-empty ``"key:value"`` entries.

    The first character and the last two characters of *hparams_string* are
    discarded before splitting, exactly as the original parser did.
    NOTE(review): presumably these are a wrapping delimiter and a trailing
    character (e.g. a closing delimiter plus newline) added by the caller;
    confirm against the code that builds the string.

    Entries are separated by "-", but only where the "-" is directly followed
    by ``<identifier>:``. This keeps values such as ``1e-3`` or ``-1`` intact
    while still splitting ``"epochs:10-batch_size:16"`` into two entries.
    """
    body = hparams_string[1:-2]
    entries = re.split(r"-(?=[A-Za-z_]\w*:)", body)
    # Skip blank entries (e.g. an empty override body) instead of failing to
    # unpack them.
    return [entry for entry in entries if entry.strip()]


def _coerce_hparam_value(name, default, raw):
    """Convert the override text *raw* to the type of *default*.

    Args:
        name: Hyperparameter name, used only in error messages.
        default: The current value; its type selects the conversion.
        raw: The override value as text.

    Returns:
        *raw* converted to bool, int, float or list of str to match the type
        of *default*; for any other default type, *raw* is returned unchanged.

    Raises:
        ValueError: if *raw* cannot be converted to the required type.
    """
    # bool must be checked before int because bool is a subclass of int, and
    # bool("False") would be True.
    if isinstance(default, bool):
        lowered = raw.lower()
        if lowered in ("true", "1", "yes"):
            return True
        if lowered in ("false", "0", "no"):
            return False
        raise ValueError("Invalid boolean for hparam %r: %r" % (name, raw))
    if isinstance(default, int):
        return int(raw)
    if isinstance(default, float):
        return float(raw)
    if isinstance(default, list):
        # Accept "a,b" or "[a,b]", optionally with quoted items.
        items = raw.strip("[]").split(",")
        return [item.strip().strip("'\"") for item in items if item.strip()]
    return raw


def create_hparams(hparams_string=None, verbose=False):
    """Create model hyperparameters. Parse nondefault from given string.

    Args:
        hparams_string: Optional override string of ``key:value`` entries
            separated by "-" (see ``_split_hparams_string`` for the exact
            format). Unknown keys are reported and skipped. Each value is
            converted to the type of the corresponding default.
        verbose: If True, print the final hyperparameters after all
            overrides have been applied.

    Returns:
        An ``AttrDict`` of hyperparameters; each entry is reachable both as
        ``hparams["name"]`` and ``hparams.name``.

    Raises:
        ValueError: if an override entry has no ":" or its value cannot be
            converted to the type of the default.
    """
    hparams = AttrDict({
        ################################
        # Experiment Parameters        #
        ################################
        "epochs": 100,
        "iters_per_checkpoint": 1000,
        "seed": 1234,
        "dynamic_loss_scaling": True,
        "fp16_run": False,
        "distributed_run": False,
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "cudnn_enabled": True,
        "cudnn_benchmark": False,
        "ignore_layers": ['embedding.weight'],

        ################################
        # Data Parameters             #
        ################################
        "load_mel_from_disk": False,
        "training_files": 'filelists/repeat500_transcript_train.txt',
        "validation_files": 'filelists/repeat500_transcript_val.txt',
        "text_cleaners": ['basic_cleaners'],

        ################################
        # Audio Parameters             #
        ################################
        "max_wav_value": 32768.0,
        "sampling_rate": 22050,
        "filter_length": 1024,
        "hop_length": 256,
        "win_length": 1024,
        "n_mel_channels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": 8000.0,

        ################################
        # Model Parameters             #
        ################################
        "n_symbols": len(symbols),
        "symbols_embedding_dim": 512,

        # Encoder parameters
        "encoder_kernel_size": 5,
        "encoder_n_convolutions": 3,
        "encoder_embedding_dim": 512,

        # Decoder parameters
        "n_frames_per_step": 1,  # currently only 1 is supported
        "decoder_rnn_dim": 1024,
        "prenet_dim": 256,
        "max_decoder_steps": 2000,
        "gate_threshold": 0.5,
        "p_attention_dropout": 0.1,
        "p_decoder_dropout": 0.1,

        # Attention parameters
        "attention_rnn_dim": 1024,
        "attention_dim": 128,

        # Location Layer parameters
        "attention_location_n_filters": 32,
        "attention_location_kernel_size": 31,

        # Mel-post processing network parameters
        "postnet_embedding_dim": 512,
        "postnet_kernel_size": 5,
        "postnet_n_convolutions": 5,

        ################################
        # Optimization Hyperparameters #
        ################################
        "use_saved_learning_rate": False,
        "learning_rate": 1e-3,
        "weight_decay": 1e-6,
        "grad_clip_thresh": 1.0,
        "batch_size": 32,
        "mask_padding": True  # set model's padded outputs to padded values
    })

    if hparams_string:
        for entry in _split_hparams_string(hparams_string):
            # partition() splits on the FIRST ":" only, so values that
            # contain ":" (e.g. "tcp://localhost:54321") survive intact.
            key, sep, raw = entry.partition(":")
            if not sep:
                raise ValueError(
                    "Malformed hparam override %r: expected 'key:value'"
                    % (entry,))
            key = key.strip()
            raw = raw.strip()
            if key not in hparams:
                # Unknown keys are not fatal, but report them so typos in
                # override names do not go unnoticed.
                print("Ignoring unknown hparam: " + key)
                continue
            hparams[key] = _coerce_hparam_value(key, hparams[key], raw)
            print("Set hparam: " + key + " to " + raw)

    if verbose:
        print("Final parsed hparams: %s" % (hparams,))
    return hparams
# Sign up for free to join this conversation on GitHub.
# Already have an account? Sign in to comment.