Skip to content

Instantly share code, notes, and snippets.

import pytorch_lightning as pl
import numpy as np
import torch
from torch.nn import MSELoss
from torch.optim import Adam
from torch.utils.data import DataLoader, Dataset
import torch.nn as nn
class SimpleDataset(Dataset):
    '''
    NOTE(review): the text below describes a ComfyUI custom node (installed
    under ComfyUI/custom_nodes), not a torch ``Dataset``; it looks like it was
    pasted here from another file. No ``__len__`` / ``__getitem__`` methods are
    visible for this class in this view -- confirm against the original source.

    Reference: https://arxiv.org/abs/2312.00858

    Usage:
    1. Put this file in ComfyUI/custom_nodes.
    2. Load the node from <loaders>.

    Parameters:
    start_step, end_step: apply this method only while the timestep is between
        start_step and end_step.
    cache_interval: interval between cache refreshes (1 means no caching).
    cache_depth: depth of caching.
    '''
import json
import pickle
import struct
import zipfile
import numpy as np
from sentencepiece import SentencePieceProcessor
def rms_norm(x, eps=1e-6):
    """Divide ``x`` by the root-mean-square of its last axis.

    Each slice along the last axis is scaled by
    ``1 / sqrt(mean(x ** 2) + eps)``.

    Args:
        x: Array-like input; the mean of squares is taken over the last axis.
        eps: Small constant added under the square root so an all-zero slice
            does not cause a division by zero. Defaults to ``1e-6``.

    Returns:
        An array with the same shape as ``x``.
    """
    # keepdims=True keeps the reduced axis so the RMS broadcasts back over x.
    mean_square = np.mean(np.square(x), axis=-1, keepdims=True)
    return x / np.sqrt(mean_square + eps)
def softmax(x, axis=-1):
    """Return the softmax of ``x`` along ``axis`` (default: the last axis).

    The maximum along ``axis`` is subtracted before exponentiating so large
    inputs do not overflow; the shift cancels out mathematically.

    Args:
        x: Array-like input.
        axis: Axis along which the outputs are normalized to sum to 1.

    Returns:
        An array with the same shape as ``x``.
    """
    # Compute the exponentials once and reuse them for numerator and
    # denominator instead of evaluating np.exp twice.
    shifted_exp = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return shifted_exp / np.sum(shifted_exp, axis=axis, keepdims=True)