Created May 19, 2021 02:48
-
-
Save bschifferer/d56a23d327c11b3477f4f64e289d2e33 to your computer and use it in GitHub Desktop.
HugeCTR low level API
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# HugeCTR low-level API example: build a solver configuration, create a
# training session from a JSON model config, and run a fixed-length training
# loop with periodic loss logging and evaluation.
#
# NOTE(review): Session, solver_parser_helper and get_learning_rate_scheduler
# are not imported in this snippet; presumably they come from the `hugectr`
# package in an earlier cell -- confirm before running this standalone.

# GPU device ids to train on. Despite its name, NUM_GPUS holds a list of ids
# rather than a count; the name is kept so other cells referencing it still work.
NUM_GPUS = [0, 1, 2, 3]

# Set config file
json_file = "dlrm_config.json"

# Set solver config
solver_config = solver_parser_helper(
    seed=0,
    batchsize=16384,
    batchsize_eval=16384,
    model_file="",
    embedding_files=[],
    # Nested list wrapping the GPU id list; presumably one inner list per
    # node (a single node here) -- verify against the HugeCTR docs.
    vvgpu=[NUM_GPUS],
    use_mixed_precision=False,
    scaler=1.0,
    i64_input_key=True,
    use_algorithm_search=True,
    use_cuda_graph=True,
    repeat_dataset=True,
)

# Set learning rate (the scheduler is built from the same JSON config)
lr_sch = get_learning_rate_scheduler(json_file)

# Train model
sess = Session(solver_config, json_file)
sess.start_data_reading()
for i in range(10000):
    # Pull the next learning rate from the scheduler before each step.
    lr = lr_sch.get_next()
    sess.set_learning_rate(lr)
    sess.train()

    # Log the training loss every 100 iterations (including iteration 0).
    if i % 100 == 0:
        loss = sess.get_current_loss()
        print("[HUGECTR][INFO] iter: {}; loss: {}".format(i, loss))

    # Evaluate every 3000 iterations, skipping iteration 0.
    if i % 3000 == 0 and i != 0:
        sess.check_overflow()
        sess.copy_weights_for_evaluation()
        metrics = sess.evaluation()
        print("[HUGECTR][INFO] iter: {}, {}".format(i, metrics))

# NOTE(review): the original indentation was lost, so the nesting of this call
# is reconstructed. It is placed after the loop so the parameters are written
# once, tagged with the total number of completed iterations (i + 1) --
# confirm this was not meant to run at every evaluation step instead.
sess.download_params_to_files("./", i + 1)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.