Simplified version of Dually Grounded Back-Translation for AutoFormalization (pseudocode):
def train_to_af_for_maf(mdl: causal_lm,
                        formal_data_set,    # e.g., an ITP library like Mathlib
                        informal_data_set,  # e.g., time-tested maths textbooks, e.g., Rudin, CLRS
                        ):
    # fl_star / nl_star denote the high-quality (ground-truth) side of each pair
    for it, ((nl, fl_star), (nl_star, fl)) in enumerate(zip(formal_data_set, informal_data_set)):
        # -- Learn to Formalize: nl_i -> fl_star, via fl_star -> [nl_i]_i -> fl_star
        nl_samples = mdl("informalize " + fl_star, sampling='top_p', num_out=k)  # sampling noise is good for robustness!
        # - Train step to formalize, grounded in the high-quality formal dataset ~ opt.step((nl_i -> fl_star)_i)
        loss = loss_fn(mdl, [(nl_i, fl_star) for nl_i in nl_samples]); loss.backward()
        # perhaps save the generated (nl_i, fl_star) pairs if they lower validation loss (e.g., on ProofNet)
        # -- Learn to Informalize: fl_j -> nl_star, via nl_star -> [fl_j]_j -> nl_star
        fl_samples = mdl("formalize " + nl_star, sampling='top_p', num_out=k)
        # - Train step to informalize, grounded in the high-quality informal dataset ~ opt.step((fl_j -> nl_star)_j)
        loss = loss_fn(mdl, [(fl_j, nl_star) for fl_j in fl_samples]); loss.backward()
        # perhaps save the generated (fl_j, nl_star) pairs if they lower validation loss (e.g., on a reversed ProofNet)
        # -- Jointly train everything (for better hardware usage)
        opt.step()  # one step over all tasks: nl -> fl, fl -> nl, ...; see notes for more: https://github.com/brando90/evals-for-autoformalization/blob/main/notes/research_proposals_projs/Dual-AF-PreTrain_-_Dual_Informal-Formal_Back_Translation_Training_for_AutoFormalization.md
        opt.zero_grad()  # zero grads of only the params the opt is optimizing
        # -- Stop when the target number of iterations / tokens is met
        if it == target_its:
            break
    return mdl  # returned for clarity of code, but the opt likely mutates mdl's params in place
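The calls mdl("informalize " + ..., sampling='top_p', num_out=k) and loss_fn above are placeholders. Below is a minimal sketch of what they could look like, assuming a HuggingFace-style causal LM and tokenizer; the names sample_translations, compute_backtranslation_loss, tokenizer, and device, the prompt format, and the label = -100 masking are illustrative assumptions, not part of the method above.

import torch

def sample_translations(mdl, tokenizer, task, text, k=4, top_p=0.95, device='cpu'):
    # Sketch of the mdl('informalize ' + fl_star, sampling='top_p', num_out=k) call:
    # draw k top-p samples and return only the newly generated text.
    enc = tokenizer(f"{task} {text}\n", return_tensors='pt').to(device)
    outs = mdl.generate(**enc, do_sample=True, top_p=top_p, num_return_sequences=k,
                        max_new_tokens=256, pad_token_id=tokenizer.eos_token_id)
    new_tokens = outs[:, enc.input_ids.shape[1]:]  # strip the prompt tokens
    return tokenizer.batch_decode(new_tokens, skip_special_tokens=True)

def compute_backtranslation_loss(mdl, tokenizer, pairs, task='formalize', device='cpu'):
    # Sketch of loss_fn: teacher-forced cross-entropy on the target side only,
    # with the prompt tokens masked out via the label value -100.
    # (Assumes the prompt tokenizes identically as a prefix of prompt + tgt.)
    losses = []
    for src, tgt in pairs:
        prompt = f"{task} {src}\n"
        prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
        full_ids = tokenizer(prompt + tgt, return_tensors='pt').input_ids.to(device)
        labels = full_ids.clone()
        labels[:, :prompt_ids.shape[1]] = -100  # ignore loss on the prompt
        out = mdl(input_ids=full_ids, labels=labels)  # HF computes the shifted cross-entropy
        losses.append(out.loss)
    return torch.stack(losses).mean()

With helpers like these, the two loss.backward() calls in the loop accumulate gradients for both directions, and the single opt.step() updates the shared parameters jointly.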
if __name__ == '__main__':
    # Train with AF4MAF back-translation
    mdl = train_to_af_for_maf(mdl, formal_data_set, informal_data_set)
    # Evaluate whether the training procedure for MAF improved the model
    print('---- Display if our AF4MAF training improved eval metrics on benchmarks ----')
    eval_af(mdl, eval_data_set=af_dataset, metrics=[ppl])
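Similarly, eval_af, af_dataset, and ppl above are placeholders. A minimal sketch of a perplexity eval on held-out (nl, fl) pairs (e.g., ProofNet-style statement pairs), under the same assumed HuggingFace-style interface; eval_af_perplexity and eval_pairs are illustrative names:

import math
import torch

@torch.no_grad()
def eval_af_perplexity(mdl, tokenizer, eval_pairs, device='cpu'):
    # eval_pairs: held-out (nl, fl) pairs; returns the perplexity of the formal
    # statement given a 'formalize' prompt (lower is better).
    mdl.eval()
    total_nll, total_target_tokens = 0.0, 0
    for nl, fl in eval_pairs:
        prompt = f"formalize {nl}\n"
        prompt_ids = tokenizer(prompt, return_tensors='pt').input_ids.to(device)
        full_ids = tokenizer(prompt + fl, return_tensors='pt').input_ids.to(device)
        labels = full_ids.clone()
        labels[:, :prompt_ids.shape[1]] = -100  # score only the formal target tokens
        out = mdl(input_ids=full_ids, labels=labels)
        n_target = (labels != -100).sum().item()
        total_nll += out.loss.item() * n_target  # out.loss is the mean NLL over target tokens
        total_target_tokens += n_target
    return math.exp(total_nll / max(total_target_tokens, 1))

Perplexity is only a cheap proxy; stricter autoformalization metrics (e.g., whether the generated formal statements type-check) could complement it.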