This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import pytorch_lightning as pl | |
import numpy as np | |
import torch | |
from torch.nn import MSELoss | |
from torch.optim import Adam | |
from torch.utils.data import DataLoader, Dataset | |
import torch.nn as nn | |
class SimpleDataset(Dataset): |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
'''
Reference: https://arxiv.org/abs/2312.00858
1. Put this file in ComfyUI/custom_nodes.
2. Load the node from <loaders>.
start_step, end_step: this method is applied only when the timestep is between start_step and end_step.
cache_interval: interval of caching (1 means no caching).
cache_depth: depth of caching.
'''
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import json | |
import pickle | |
import struct | |
import zipfile | |
import numpy as np | |
from sentencepiece import SentencePieceProcessor | |
def rms_norm(x, eps=1e-6):
    """Divide ``x`` by its root-mean-square taken over the last axis.

    Args:
        x: Array-like input; the mean of squares is reduced over axis -1
            with ``keepdims=True`` so the result broadcasts back against ``x``.
        eps: Small constant added to the mean of squares before the square
            root, so an all-zero row divides by ``sqrt(eps)`` instead of 0.
            Defaults to ``1e-6``.

    Returns:
        An array with the same shape as ``x``.
    """
    mean_sq = np.square(x).mean(-1, keepdims=True)
    return x / np.sqrt(mean_sq + eps)
def softmax(x):
    """Compute the softmax of ``x`` along the last axis.

    The per-row maximum is subtracted before exponentiating; this leaves the
    result unchanged mathematically but keeps ``np.exp`` from overflowing on
    large inputs.

    Args:
        x: Array-like input; the reduction runs over axis -1.

    Returns:
        An array with the same shape as ``x`` whose entries along the last
        axis sum to 1.
    """
    # Compute the shifted exponentials once and reuse them for both the
    # numerator and the normalizing sum.
    exps = np.exp(x - np.max(x, axis=-1, keepdims=True))
    return exps / np.sum(exps, axis=-1, keepdims=True)