Skip to content

Instantly share code, notes, and snippets.

@catboxanon
Created August 25, 2023 15:44
Show Gist options
Save catboxanon/c00b16a8afea333b71870f9a17987c36 to your computer and use it in GitHub Desktop.
import torch
from modules import script_callbacks, shared
def on_model_loaded(sd_model) -> None:
    """Verify the CLIP text encoder's ``position_ids`` tensor and repair it in place.

    Runs only when the ``clip_tensor_fix_enabled`` option is set. The tensor is
    looked up under one of two text-encoder layouts. The layout is detected by
    which ``position_ids`` key is present in ``sd_model.state_dict()``.

    The tensor is compared, after casting to int64, with the expected
    ``[[0, 1, ..., 76]]``. On a mismatch the expected values are copied into
    the model's own tensor. Outcomes are reported with ``print``. Any exception
    raised while locating, checking or repairing the tensor is caught and
    printed rather than propagated.

    Args:
        sd_model: the model handed to this function by the webui's
            ``on_model_loaded`` callback.
    """
    enabled = hasattr(shared.opts, 'clip_tensor_fix_enabled') and shared.opts.data.get('clip_tensor_fix_enabled', False)  # type: ignore
    if not enabled:
        return
    try:
        expected = torch.arange(0, 77, dtype=torch.int64).unsqueeze(0)
        # state_dict() creates an entry for every parameter/buffer of the whole
        # model, so build it once rather than once per key lookup.
        state = sd_model.state_dict()
        position_ids = None
        if 'cond_stage_model.wrapped.transformer.text_model.embeddings.position_ids' in state:
            position_ids = sd_model.cond_stage_model.hijack.clip.wrapped.transformer.text_model.embeddings.position_ids
        elif 'cond_stage_model.wrapped.transformer.embeddings.position_ids' in state:
            position_ids = sd_model.cond_stage_model.hijack.clip.wrapped.transformer.embeddings.position_ids

        if position_ids is None:
            print('CLIP IDs tensor not found.')
        elif torch.all(torch.eq(position_ids.to(torch.int64), expected.to(position_ids.device))).item():
            print('CLIP IDs tensor OK!')
        else:
            # Overwrite the model's tensor in place. Merely rebinding a local
            # name to `expected` would leave the model's tensor unchanged.
            # copy_ converts `expected` to the target tensor's dtype and device.
            with torch.no_grad():
                position_ids.copy_(expected)
            print('CLIP IDs tensor repaired!')
    except Exception as e:
        print('Exception thrown when trying to verify/fix CLIP tensor: ', e)
def on_ui_settings():
    """Add the 'clip_tensor_fix_enabled' toggle (default False) to the settings.

    The option is placed in the ('clip_tensor_fix', 'CLIP Tensor Fix') section
    and labelled 'Enable CLIP tensor fix on model load'.
    """
    toggle = shared.OptionInfo(  # type: ignore
        False,
        'Enable CLIP tensor fix on model load',
        section=('clip_tensor_fix', 'CLIP Tensor Fix'),
    )
    shared.opts.add_option('clip_tensor_fix_enabled', toggle)  # type: ignore
# Hook both functions into the webui's callback registry at import time.
# on_model_loaded is presumably invoked by the host after each checkpoint load,
# and on_ui_settings when the settings UI is built. This is inferred from the
# callback names; confirm against modules.script_callbacks.
script_callbacks.on_model_loaded(on_model_loaded)
script_callbacks.on_ui_settings(on_ui_settings)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment