Skip to content

Instantly share code, notes, and snippets.

@ebsmothers
Created February 8, 2024 00:04
Show Gist options
  • Save ebsmothers/ea631a8f0857d79ece5245c8ab1e8e52 to your computer and use it in GitHub Desktop.
Save ebsmothers/ea631a8f0857d79ece5245c8ab1e8e52 to your computer and use it in GitHub Desktop.
def validate_state_dict_for_lora(
    *,
    lora_modules: List[str],
    full_model_state_dict_keys: List[str],
    lora_state_dict_keys: Optional[List[str]] = None,
    base_model_state_dict_keys: Optional[List[str]] = None,
) -> None:
    """Check that state-dict keys are split correctly between LoRA and base weights.

    A key of the full model counts as a LoRA parameter when it contains the
    substring ``"lora"`` and also contains at least one entry of
    ``lora_modules`` as a substring. Every LoRA key must be present in
    ``lora_state_dict_keys`` and absent from ``base_model_state_dict_keys``;
    every other key must be present in ``base_model_state_dict_keys`` and
    absent from ``lora_state_dict_keys``. Each check is skipped when the
    corresponding list is ``None``. When both lists are given they must also
    be disjoint and together contain exactly the full model's keys.

    Args:
        lora_modules: Substrings identifying the modules LoRA was applied to.
        full_model_state_dict_keys: Keys of the full (base + LoRA) state dict.
        lora_state_dict_keys: Keys expected to hold only LoRA parameters,
            or ``None`` to skip checks against this list.
        base_model_state_dict_keys: Keys expected to hold only non-LoRA
            parameters, or ``None`` to skip checks against this list.

    Raises:
        AssertionError: If any of the conditions above does not hold. Raised
            explicitly (not via ``assert``) so the checks survive ``python -O``.
    """

    def is_lora_param(key: str) -> bool:
        return "lora" in key and any(module in key for module in lora_modules)

    # Convert to sets once so each membership test below is O(1) instead of a
    # linear scan of a list for every key of the full model.
    lora_keys = None if lora_state_dict_keys is None else set(lora_state_dict_keys)
    base_keys = (
        None if base_model_state_dict_keys is None else set(base_model_state_dict_keys)
    )

    for key in full_model_state_dict_keys:
        if not is_lora_param(key):
            if base_keys is not None and key not in base_keys:
                raise AssertionError(
                    f"Non-LoRA key {key!r} missing from base model state dict"
                )
            if lora_keys is not None and key in lora_keys:
                raise AssertionError(
                    f"Non-LoRA key {key!r} unexpectedly present in LoRA state dict"
                )
        else:
            if base_keys is not None and key in base_keys:
                raise AssertionError(
                    f"LoRA key {key!r} unexpectedly present in base model state dict"
                )
            if lora_keys is not None and key not in lora_keys:
                raise AssertionError(f"LoRA key {key!r} missing from LoRA state dict")

    # Full model is the disjoint union of base model and LoRA weights.
    if lora_keys is not None and base_keys is not None:
        shared = lora_keys & base_keys
        # Test emptiness by truthiness: `shared == {}` compares a set with an
        # empty dict, which is never equal, so that form always fails.
        if shared:
            raise AssertionError(
                f"Keys present in both LoRA and base state dicts: {sorted(shared)}"
            )
        full_keys = set(full_model_state_dict_keys)
        combined = lora_keys | base_keys
        if combined != full_keys:
            raise AssertionError(
                "LoRA and base state dict keys do not match full model keys; "
                f"missing: {sorted(full_keys - combined)}, "
                f"extra: {sorted(combined - full_keys)}"
            )
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment