@HarshTrivedi
Created December 17, 2021 20:20
#########################################################
local setting = std.extVar("setting"); # options: "full", "fixture"
local num_cores = std.parseInt(std.extVar("num_cores"));
#########################################################
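# A minimal sketch of how this config might be invoked (assuming an AllenNLP-style
# setup, which resolves std.extVar from environment variables; the run name and
# serialization directory below are illustrative, not from the original):
#   setting=full num_cores=4 on_beaker=false WANDB_RUN_NAME=my-run \
#     allennlp train this_config.jsonnet -s output/my-run
# Note: num_cores is parsed above but not referenced anywhere else in this config.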
# Set this for memory optimization
local activation_checkpointing = false;
local seed = 100;
local transformer_model_name = "nielsr/nt5-small-rc1";
local batch_size = 32;
local accumulation_steps = 1;
local max_context_tokens = 700;
local max_question_tokens = 100;
local shuffle_context = false;
local skip_context = true;
local use_program_as_question = false;
local use_program_last_step_as_question = false;
local generate_final_answers = true;
local generate_intermediate_answers = true;
local generate_step_instructions = true;
local intermediate_chain_or_output = "output";
local max_answer_tokens = (
  (if generate_final_answers then 1 else 0) * 50 +
  (if generate_intermediate_answers && intermediate_chain_or_output == "chain" then 1 else 0) * 500 +
  (if generate_intermediate_answers && intermediate_chain_or_output == "output" then 1 else 0) * 200 +
  (if generate_step_instructions then 1 else 0) * 100
);
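# With the flags set above (final answers, intermediate "output" answers, and
# step instructions all enabled), this works out to 50 + 200 + 100 = 350 tokens.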
local target_is_serialized = true;
local fixture_path = "fixtures/synthetic_data/sampled_for_mtl_one_hop.jsonl";
local train_data_path =
  if setting == "full"
  then "processed_data/synthetic_data/sampled_for_mtl/one_hop_train.jsonl"
  else if setting == "fixture"
  then fixture_path
  else "-";
local validation_data_path =
  if setting == "full"
  then "processed_data/synthetic_data/sampled_for_mtl/one_hop_dev.jsonl"
  else if setting == "fixture"
  then fixture_path
  else "-";
local num_epochs =
  if setting == "full"
  then 20
  else if setting == "fixture"
  then 30
  else 2;
# The "fixture" and fallback branches both resolve to num_epochs, so patience
# only needs the two-way split.
local patience =
  if setting == "full"
  then 20
  else num_epochs;
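# "full" trains on the complete one-hop data for up to 20 epochs; "fixture"
# loops a small fixture file for sanity checks; any other setting falls through
# to the "-" placeholder paths and short-epoch defaults above.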
local dataset_reader = {
  "type": "synthetic_dbqa",
  "transformer_model_name": transformer_model_name,
  "max_context_tokens": max_context_tokens,
  "max_question_tokens": max_question_tokens,
  "max_answer_tokens": max_answer_tokens,
  "shuffle_context": shuffle_context,
  "skip_context": skip_context,
  "use_program_as_question": use_program_as_question,
  "use_program_last_step_as_question": use_program_last_step_as_question,
  "generate_final_answers": generate_final_answers,
  "generate_intermediate_answers": generate_intermediate_answers,
  "generate_step_instructions": generate_step_instructions,
  "intermediate_chain_or_output": intermediate_chain_or_output,
  "add_additional_tokens": true
};
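# In the data loader below, the computed-field syntax
#   [if setting == "fixture" then "batches_per_epoch"]: ...
# is standard Jsonnet: when the condition is false the key evaluates to null
# and the field is omitted entirely, so "batches_per_epoch" only exists in the
# fixture setting.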
local data_loader = {
  "batch_size": batch_size,
  "shuffle": true,
  "num_workers": 20,
  "max_instances_in_memory": batch_size * 50,
  [if setting == "fixture" then "batches_per_epoch"]: 200 * accumulation_steps,
};
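# Both callbacks are defined here, but wandb is only attached when the trainer's
# "callbacks" conditional (below) sees on_beaker == "true".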
local tensorboard_callback = {"type": "tensorboard"};
local wandb_callback = {
  "type": "wandb",
  "project": "synth2realmh",
  "entity": "harshtrivedi",
  "name": std.extVar("WANDB_RUN_NAME"),
  "watch_model": false,
  "summary_interval": 1,
  "should_log_parameter_statistics": false,
  "should_log_learning_rate": false,
};
{
"train_data_path": train_data_path,
"validation_data_path": validation_data_path,
"dataset_reader": dataset_reader,
"validation_dataset_reader": dataset_reader,
"model": {
"type": "qa_t5",
"model_name": transformer_model_name,
"beam_search": {
"beam_size": 5,
"max_steps": max_answer_tokens,
},
"target_is_serialized": target_is_serialized,
[if activation_checkpointing then "checkpoint_wrapper"]: {
"type": "fairscale",
"offload_to_cpu": true,
"maintain_forward_counter": true,
},
},
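  # The validation loader starts from the training loader via Jsonnet object
  # merge (+) and overrides a few fields to null, which AllenNLP-style configs
  # presumably treat as "use the default": validation then iterates the full
  # set once per epoch instead of capping in-memory instances or batch counts.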
"data_loader": data_loader,
"validation_data_loader": self.data_loader + {
"max_instances_in_memory": null,
"batches_per_epoch": null,
"batch_size": batch_size,
},
"vocabulary": {
"type": "empty",
},
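  # Single-GPU training with Adafactor (a common choice for T5-family models)
  # and gradient clipping at 1.0; early stopping tracks df_match_score, where
  # the "+" prefix marks a higher-is-better metric.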
"trainer": {
"cuda_device": 0,
"use_amp": false,
"num_epochs": num_epochs,
"patience": patience,
"num_gradient_accumulation_steps": accumulation_steps,
"optimizer": {
"type": "huggingface_adafactor",
},
"grad_norm": 1.0,
"callbacks": if std.extVar("on_beaker") == "true"
then [tensorboard_callback, wandb_callback]
else [tensorboard_callback],
"validation_metric": "+df_match_score",
},
"random_seed": seed,
"numpy_seed": seed,
"pytorch_seed": seed,
}