Piscabo / forward_of_sdxl_original_unet.py
Created February 17, 2024 13:10 — forked from kohya-ss/forward_of_sdxl_original_unet.py
Mitigate composition breakdown at high resolutions in SDXL
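The forward pass below first broadcasts `timesteps` to the batch dimension (`timesteps.expand(x.shape[0])`). It then encodes them with `get_timestep_embedding` before the learned `self.time_embed` layers. The following is a minimal plain-Python sketch of those two steps, under stated assumptions. The helper names `broadcast_timesteps` and `sinusoidal_embedding` are hypothetical. The embedding follows the common Latent Diffusion convention: cosine half first, frequencies `max_period ** (-i / half)`, and `max_period = 10000`. The gist's own `get_timestep_embedding` may differ in ordering or frequency shift, so treat this as an illustration rather than its implementation.

```python
import math
from typing import List, Sequence, Union


def broadcast_timesteps(timesteps: Union[float, Sequence[float]], batch_size: int) -> List[float]:
    """Hypothetical helper mirroring `timesteps.expand(batch_size)`.

    A scalar or a length-1 sequence is repeated once per sample; a sequence
    that already has `batch_size` entries is returned unchanged (as floats).
    """
    if isinstance(timesteps, (int, float)):
        return [float(timesteps)] * batch_size
    ts = [float(t) for t in timesteps]
    if len(ts) == 1:
        return ts * batch_size
    if len(ts) != batch_size:
        # expand() likewise refuses to resize a non-singleton dimension
        raise ValueError(f"cannot broadcast {len(ts)} timesteps to batch size {batch_size}")
    return ts


def sinusoidal_embedding(timesteps: Sequence[float], dim: int, max_period: float = 10000.0) -> List[List[float]]:
    """Hypothetical sinusoidal timestep embedding (LDM-style, cos half first).

    With h = dim // 2 and f_i = max_period ** (-i / h), each timestep t maps to
    [cos(t*f_0), ..., cos(t*f_{h-1}), sin(t*f_0), ..., sin(t*f_{h-1})],
    zero-padded by one entry when dim is odd.
    """
    half = dim // 2
    # exp(-ln(max_period) * i / half) == max_period ** (-i / half)
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    rows = []
    for t in timesteps:
        args = [t * f for f in freqs]
        row = [math.cos(a) for a in args] + [math.sin(a) for a in args]
        if dim % 2 == 1:
            row.append(0.0)  # pad so every row has exactly `dim` entries
        rows.append(row)
    return rows


if __name__ == "__main__":
    # One timestep shared by a batch of 2, embedded at width 320
    # (SDXL's base channel count, used here only as an example size).
    ts = broadcast_timesteps(999, batch_size=2)
    emb = sinusoidal_embedding(ts, dim=320)
    print(len(emb), len(emb[0]))  # 2 320
```

Plain lists keep the sketch dependency-free. In the actual UNet the same computation runs on `torch` tensors, and the result is cast to `x.dtype` before it reaches `self.time_embed`.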
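def forward(self, x, timesteps=None, context=None, y=None, **kwargs):
    # broadcast timesteps to batch dimension
    timesteps = timesteps.expand(x.shape[0])
    # collects input-block outputs for the UNet skip connections
    hs = []
    # sinusoidal timestep embedding, cast to the input dtype, then passed through the time-embedding layers
    t_emb = get_timestep_embedding(timesteps, self.model_channels)  # , repeat_only=False)
    t_emb = t_emb.to(x.dtype)
    emb = self.time_embed(t_emb)
    # y (the vector conditioning) must have one row per sample in x
    assert x.shape[0] == y.shape[0], f"batch size mismatch: {x.shape[0]} != {y.shape[0]}"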