MAF dual backtranslation self-improving loop

A simplified version of Dually Grounded Back-Translation for AutoFormalization:

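import torch

@torch.no_grad()
def sample(mdl, tok, prompt: str, k: int, top_p: float = 0.9) -> list[str]:
    """Draw k top-p samples from mdl and return only the generated continuations."""
    ids = tok(prompt, return_tensors='pt').input_ids.to(mdl.device)
    out = mdl.generate(ids, do_sample=True, top_p=top_p, num_return_sequences=k,
                       max_new_tokens=256, pad_token_id=tok.eos_token_id)
    return tok.batch_decode(out[:, ids.shape[1]:], skip_special_tokens=True)

def seq_loss(mdl, tok, prompt: str, target: str) -> torch.Tensor:
    """Mean cross-entropy of the target tokens given prompt (prompt positions masked with -100)."""
    p = tok(prompt, return_tensors='pt').input_ids
    t = tok(' ' + target, return_tensors='pt', add_special_tokens=False).input_ids
    ids = torch.cat([p, t], dim=1).to(mdl.device)
    labels = ids.clone()
    labels[:, :p.shape[1]] = -100
    return mdl(input_ids=ids, labels=labels).loss

def train_to_af_for_maf(mdl,                # causal LM, assumed to be a Hugging Face AutoModelForCausalLM
                        tok,                # its tokenizer
                        opt,                # optimizer over mdl's params
                        formal_data_set,    # formal statements fl*, e.g., ITP lib like mathlib
                        informal_data_set,  # informal statements nl*, e.g., time-tested maths textbooks like Rudin, CLRS
                        target_its: int,
                        k: int = 4,         # back-translations sampled per example (arbitrary default)
                        ):
    for it, (fl_star, nl_star) in enumerate(zip(formal_data_set, informal_data_set)):
        # -- Learn to Formalize: nl_i->fl* from fl* -> [nl_i]_i -> fl*
        nl_is = sample(mdl, tok, 'informalize ' + fl_star, k)  # noise is good for robustness!
        # - Train step to formalize from high-quality formal dataset ~ opt.step((nl_i -> fl*)_i)
        loss = sum(seq_loss(mdl, tok, 'formalize ' + nl_i, fl_star) for nl_i in nl_is) / k
        loss.backward()
        # perhaps save generated data if it lowers validation loss (e.g., on ProofNet)

        # -- Learn to Informalize: fl_j->nl* from nl* -> [fl_j]_j -> nl*
        fl_js = sample(mdl, tok, 'formalize ' + nl_star, k)
        # - Train step to informalize from high-quality informal dataset ~ opt.step((fl_j -> nl*)_j)
        loss = sum(seq_loss(mdl, tok, 'informalize ' + fl_j, nl_star) for fl_j in fl_js) / k
        loss.backward()
        # perhaps save generated data if it lowers validation loss (e.g., on a reverse ProofNet)

        # -- Jointly train everything (for better hardware usage)
        opt.step()  # trains all tasks: nl->fl, fl->nl, ...; see notes for more: https://github.com/brando90/evals-for-autoformalization/blob/main/notes/research_proposals_projs/Dual-AF-PreTrain_-_Dual_Informal-Formal_Back_Translation_Training_for_AutoFormalization.md
        opt.zero_grad()  # zero grads of only the params the opt is optimizing

        # -- Stop when target its are met (a num_tokens budget would work the same way)
        if it + 1 >= target_its:
            break
    return mdl  # for clarity of code; opt already mutates mdl's params in place

if __name__ == '__main__':
    from transformers import AutoModelForCausalLM, AutoTokenizer
    # Assumed setup: 'gpt2' is only an illustrative checkpoint; formal_data_set, informal_data_set,
    # eval_af, af_dataset and ppl are placeholders from the original sketch (not defined here).
    tok = AutoTokenizer.from_pretrained('gpt2')
    mdl = AutoModelForCausalLM.from_pretrained('gpt2')
    opt = torch.optim.AdamW(mdl.parameters(), lr=1e-5)
    # Train with AF4MAF back-translation
    mdl = train_to_af_for_maf(mdl, tok, opt, formal_data_set, informal_data_set, target_its=1_000)

    # Eval whether the train procedure improved the model for MAF
    print('---- Display if our AF4MAF training improved eval metrics on benchmarks ----')
    eval_af(mdl, eval_data_set=af_dataset, metrics=[ppl])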
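The heart of the loop is how each back-translation step builds its training pairs: the model-generated text is only ever the input, while the trusted human-written statement (fl* from the formal library, nl* from the textbook) is always the target. Below is a minimal, model-free sketch of that pair construction, assuming a hypothetical `sampler(prompt, k)` callable that stands in for top-p sampling from the LM; `backtranslation_pairs` and `toy_sampler` are illustrative names, not part of the original code.

```python
from typing import Callable, List, Tuple

# Hypothetical interface: (prompt, k) -> k sampled completions.
# In the real loop this would be top-p sampling from the causal LM.
Sampler = Callable[[str, int], List[str]]


def backtranslation_pairs(
    sampler: Sampler,
    fl_star: str,  # ground-truth formal statement (e.g., from mathlib)
    nl_star: str,  # ground-truth informal statement (e.g., from a textbook)
    k: int,
) -> List[Tuple[str, str]]:
    """Build (prompt, target) pairs in both directions.

    Only the prompt side is model-generated; the target is always
    the trusted statement the sample was back-translated from.
    """
    pairs: List[Tuple[str, str]] = []
    # fl* -> [nl_i]_i, then train nl_i -> fl* (learn to formalize)
    for nl_i in sampler("informalize " + fl_star, k):
        pairs.append(("formalize " + nl_i, fl_star))
    # nl* -> [fl_j]_j, then train fl_j -> nl* (learn to informalize)
    for fl_j in sampler("formalize " + nl_star, k):
        pairs.append(("informalize " + fl_j, nl_star))
    return pairs


def toy_sampler(prompt: str, k: int) -> List[str]:
    # Stand-in for the LM: returns k tagged copies of the input text.
    task, _, text = prompt.partition(" ")
    return [f"<{task} #{i}: {text}>" for i in range(k)]


pairs = backtranslation_pairs(toy_sampler, "theorem t : 1 + 1 = 2", "one plus one is two", k=2)
for prompt, target in pairs:
    print(prompt, "=>", target)
```

Because every target is fixed ground truth, noisy samples only perturb the inputs, which is consistent with the loop's comment that sampling noise is good for robustness.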

ref: https://github.com/brando90/evals-for-autoformalization/blob/main/notes/research_proposals_projs/Dual-AF-PreTrain_-_Dual_Informal-Formal_Back_Translation_Training_for_AutoFormalization.md
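The loop's comment about saving generated data only if it lowers validation loss (e.g., on ProofNet) leaves the mechanism open. One possible reading, assumed here, is a greedy accept/reject filter: train on a batch of synthetic pairs, keep the batch (and the update) if held-out loss drops, otherwise roll the model back and discard the batch. `filter_by_validation` and its callables (`train_on`, `val_loss`, `snapshot`, `restore`) are hypothetical; the usage below swaps in a one-parameter toy model for an LM.

```python
from typing import Callable, List, Sequence, Tuple

Pair = Tuple[str, str]  # (prompt, target)


def filter_by_validation(
    candidate_batches: Sequence[List[Pair]],
    train_on: Callable[[List[Pair]], None],  # hypothetical: one training step on a batch
    val_loss: Callable[[], float],           # hypothetical: held-out loss, e.g., on ProofNet
    snapshot: Callable[[], object],          # hypothetical: save model state
    restore: Callable[[object], None],       # hypothetical: roll model state back
) -> List[Pair]:
    """Greedily keep only the synthetic batches whose update lowers validation loss."""
    kept: List[Pair] = []
    best = val_loss()
    for batch in candidate_batches:
        state = snapshot()
        train_on(batch)
        new = val_loss()
        if new < best:
            best = new
            kept.extend(batch)  # batch helped: keep the data and the update
        else:
            restore(state)      # batch did not help: undo the update, drop the data
    return kept


# Toy usage: the "model" is one scalar w; validation loss is (w - 3)^2.
model = {"w": 0.0}


def train_on(batch: List[Pair]) -> None:
    # "good" targets move w toward the optimum, "bad" ones move it away.
    for _, target in batch:
        model["w"] += 1.0 if target == "good" else -1.0


batches = [[("p", "good")], [("p", "bad")], [("p", "good"), ("p", "good")]]
kept = filter_by_validation(
    batches,
    train_on,
    val_loss=lambda: (model["w"] - 3.0) ** 2,
    snapshot=lambda: dict(model),
    restore=lambda s: model.update(s),
)
print(kept, model)
```

Evaluating validation loss after every batch would be costly for a real LM, so checking less often is a plausible design trade-off; the original notes do not settle this.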
