Skip to content

Instantly share code, notes, and snippets.

@chavinlo
Created November 16, 2022 02:16
Show Gist options
  • Save chavinlo/335266a3a6825ffafbec191e7d0e35bd to your computer and use it in GitHub Desktop.
Save chavinlo/335266a3a6825ffafbec191e7d0e35bd to your computer and use it in GitHub Desktop.
# Install bitsandbytes:
# `nvcc --version` to get CUDA version.
# `pip install -i https://test.pypi.org/simple/ bitsandbytes-cudaXXX` to install for current CUDA.
# Example Usage:
# Single GPU: torchrun --nproc_per_node=1 trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
# Multiple GPUs: torchrun --nproc_per_node=N trainer_dist.py --model="CompVis/stable-diffusion-v1-4" --run_name="liminal" --dataset="liminal-dataset" --hf_token="hf_blablabla" --bucket_side_min=64 --use_8bit_adam=True --gradient_checkpointing=True --batch_size=10 --fp16=True --image_log_steps=250 --epochs=20 --resolution=768 --use_ema=True
import argparse
import socket
import torch
import torchvision
import transformers
import diffusers
import os
import glob
import random
import tqdm
import resource
import psutil
import pynvml
import wandb
import gc
import time
import itertools
import numpy as np
import json
import re
import traceback
#Distributed only
import hivemind
import requests
import zipfile
import shutil
from hivemind.optim import power_sgd_averager
# NVML is optional. When it cannot be initialised, `pynvml` is rebound to None
# so that later code can detect that NVML statistics are unavailable.
try:
    pynvml.nvmlInit()
except pynvml.NVMLError:
    # NVMLError is the base class of every NVML failure, so this covers the
    # missing-library case as well as e.g. a driver that is not loaded.
    pynvml = None
from typing import Iterable
from diffusers import AutoencoderKL, UNet2DConditionModel, DDPMScheduler, PNDMScheduler, DDIMScheduler, StableDiffusionPipeline
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker
from diffusers.optimization import get_scheduler
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from PIL import Image, ImageOps
from typing import Dict, List, Generator, Tuple
from scipy.interpolate import interp1d
torch.backends.cuda.matmul.allow_tf32 = True
# defaults should be good for everyone
# TODO: add custom VAE support. should be simple with diffusers
parser = argparse.ArgumentParser(description='Stable Diffusion Finetuner')
parser.add_argument('--model', type=str, default=None, required=True, help='The name of the model to use for finetuning. Could be HuggingFace ID or a directory')
parser.add_argument('--resume', type=str, default=None, help='The path to the checkpoint to resume from. If not specified, will create a new run.')
parser.add_argument('--run_name', type=str, default=None, required=True, help='Name of the finetune run.')
#parser.add_argument('--dataset', type=str, default=None, required=True, help='The path to the dataset to use for finetuning.')
parser.add_argument('--num_buckets', type=int, default=16, help='The number of buckets.')
parser.add_argument('--bucket_side_min', type=int, default=256, help='The minimum side length of a bucket.')
parser.add_argument('--bucket_side_max', type=int, default=768, help='The maximum side length of a bucket.')
parser.add_argument('--lr', type=float, default=5e-6, help='Learning rate')
parser.add_argument('--epochs', type=int, default=10, help='Number of epochs to train for')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size')
parser.add_argument('--use_ema', type=str, default='False', help='Use EMA for finetuning')
# NOTE: argparse %-formats help strings, so a literal percent sign must be written as `%%`.
parser.add_argument('--ucg', type=float, default=0.1, help='Percentage chance of dropping out the text condition per batch. Ranges from 0.0 to 1.0 where 1.0 means 100%% text condition dropout.') # 10% dropout probability
parser.add_argument('--gradient_checkpointing', dest='gradient_checkpointing', type=str, default='False', help='Enable gradient checkpointing')
parser.add_argument('--use_8bit_adam', dest='use_8bit_adam', type=str, default='False', help='Use 8-bit Adam optimizer')
parser.add_argument('--adam_beta1', type=float, default=0.9, help='Adam beta1')
parser.add_argument('--adam_beta2', type=float, default=0.999, help='Adam beta2')
parser.add_argument('--adam_weight_decay', type=float, default=1e-2, help='Adam weight decay')
parser.add_argument('--adam_epsilon', type=float, default=1e-08, help='Adam epsilon')
parser.add_argument('--lr_scheduler', type=str, default='cosine', help='Learning rate scheduler [`cosine`, `linear`, `constant`]')
parser.add_argument('--lr_scheduler_warmup', type=float, default=0.05, help='Learning rate scheduler warmup steps. This is a percentage of the total number of steps in the training run. 0.1 means 10 percent of the total number of steps.')
parser.add_argument('--seed', type=int, default=42, help='Seed for random number generator, this is to be used for reproduceability purposes.')
parser.add_argument('--output_path', type=str, default='./output', help='Root path for all outputs.')
parser.add_argument('--save_steps', type=int, default=500, help='Number of steps to save checkpoints at.')
parser.add_argument('--resolution', type=int, default=512, help='Image resolution to train against. Lower res images will be scaled up to this resolution and higher res images will be scaled down.')
parser.add_argument('--shuffle', dest='shuffle', type=str, default='True', help='Shuffle dataset')
parser.add_argument('--hf_token', type=str, default=None, required=False, help='A HuggingFace token is needed to download private models for training.')
parser.add_argument('--project_id', type=str, default='diffusers', help='Project ID for reporting to WandB')
parser.add_argument('--fp16', dest='fp16', type=str, default='False', help='Train in mixed precision')
parser.add_argument('--image_log_steps', type=int, default=100, help='Number of steps to log images at.')
parser.add_argument('--image_log_amount', type=int, default=4, help='Number of images to log every image_log_steps')
parser.add_argument('--image_log_inference_steps', type=int, default=50, help='Number of inference steps to use to log images.')
parser.add_argument('--image_log_scheduler', type=str, default="PNDMScheduler", help='Name of the scheduler class to use when logging images.')
parser.add_argument('--clip_penultimate', type=str, default='False', help='Use penultimate CLIP layer for text embedding')
parser.add_argument('--output_bucket_info', type=str, default='False', help='Outputs bucket information and exits')
parser.add_argument('--resize', type=str, default='False', help="Resizes dataset's images to the appropriate bucket dimensions.")
parser.add_argument('--use_xformers', type=str, default='False', help='Use memory efficient attention')
#Modified
parser.add_argument('--wandb', dest='enablewandb', type=str, default='False', help='Enable WeightsAndBiases Reporting')
parser.add_argument('--inference', dest='enableinference', type=str, default='False', help='Enable Inference during training (Consumes 2GB of VRAM)')
#Hivemind only
#parser.add_argument('--hivemind', dest='enablehivemind', type=str, default='True', help='Enable Hivemind usage)')
parser.add_argument('--peers', type=str, default=None, nargs="*", help='MUST BE PASSED AS A LIST! ex.: --peers /ipv4/1.1.1.1 /ipv4/2.2.2.2 | Multiaddrs of one or more active DHT peers. If none it will start a new session.')
#Dataset server
parser.add_argument('--datasetserver', type=str, dest='datasetserver', default=None, help='Address of dataset server')
parser.add_argument('--wantedimages', type=int, dest='wantedimages', default=None, help='Number of wanted images')
parser.add_argument('--workingdirectory', type=str, dest='workingdirectory', default="distributed_data", help='Folder where the downloader is going to do its work')
args = parser.parse_args()

# Boolean switches are declared as strings above; turn any string argument that
# spells "true"/"false" (case-insensitive) into a real bool. Other values,
# including the list-valued --peers, are left untouched.
for arg, value in list(vars(args).items()):
    if isinstance(value, str):
        lowered = value.lower()
        if lowered == 'true':
            setattr(args, arg, True)
        elif lowered == 'false':
            setattr(args, arg, False)
def setup():
    """Initialise the default torch.distributed process group with the NCCL backend and `env://` rendezvous."""
    torch.distributed.init_process_group(backend="nccl", init_method="env://")
def cleanup():
    """Destroy the torch.distributed process group."""
    torch.distributed.destroy_process_group()
def get_rank() -> int:
    """Return this process's distributed rank, or 0 when no process group is initialised."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_rank()
    return 0
def get_world_size() -> int:
    """Return the number of distributed processes, or 1 when no process group is initialised."""
    if torch.distributed.is_initialized():
        return torch.distributed.get_world_size()
    return 1
def get_gpu_ram() -> str:
    """
    Returns memory usage statistics for the CPU, GPU, and Torch.

    The GPU (NVML) and Torch sections are best-effort: they are left empty when
    no CUDA device can be queried, and the NVML section is skipped when `pynvml`
    has been set to None at import time.

    :return: A single-line, human-readable summary string.
    """
    gpu_str = ""
    torch_str = ""
    try:
        cudadev = torch.cuda.current_device()
        # `pynvml` is rebound to None at module import when NVML is unavailable;
        # calling into it unguarded would raise AttributeError.
        if pynvml is not None:
            nvml_device = pynvml.nvmlDeviceGetHandleByIndex(cudadev)
            gpu_info = pynvml.nvmlDeviceGetMemoryInfo(nvml_device)
            gpu_total = int(gpu_info.total / 1E6)
            gpu_free = int(gpu_info.free / 1E6)
            gpu_used = int(gpu_info.used / 1E6)
            gpu_str = f"GPU: (U: {gpu_used:,}mb F: {gpu_free:,}mb " \
                      f"T: {gpu_total:,}mb) "
        torch_reserved_gpu = int(torch.cuda.memory.memory_reserved() / 1E6)
        torch_reserved_max = int(torch.cuda.memory.max_memory_reserved() / 1E6)
        torch_used_gpu = int(torch.cuda.memory_allocated() / 1E6)
        torch_max_used_gpu = int(torch.cuda.max_memory_allocated() / 1E6)
        torch_str = f"TORCH: (R: {torch_reserved_gpu:,}mb/" \
                    f"{torch_reserved_max:,}mb, " \
                    f"A: {torch_used_gpu:,}mb/{torch_max_used_gpu:,}mb)"
    except (AssertionError, RuntimeError):
        # CUDA initialisation can fail with either exception type when no usable
        # device is present; this is a reporting helper and should not abort the run.
        pass
    # Peak resident set size of this process plus its waited-for children.
    cpu_maxrss = int(resource.getrusage(resource.RUSAGE_SELF).ru_maxrss / 1E3 +
                     resource.getrusage(
                         resource.RUSAGE_CHILDREN).ru_maxrss / 1E3)
    cpu_vmem = psutil.virtual_memory()
    cpu_free = int(cpu_vmem.free / 1E6)
    return f"CPU: (maxrss: {cpu_maxrss:,}mb F: {cpu_free:,}mb) " \
           f"{gpu_str}" \
           f"{torch_str}"
import sys

# Module-level bootstrap: resolve the dataset-server settings (prompting on
# stdin when they were not given on the command line), wipe leftovers from a
# previous run and verify that the server answers on /info.
datasetServer = args.datasetserver
wantedImages = args.wantedimages
workingDirectory = args.workingdirectory

if os.path.exists(workingDirectory + "/tmp"):
    # The pause presumably gives the user a chance to abort before data is deleted.
    print("Warning, tmp folder will be cleared in 10 secs")
    time.sleep(10)
    shutil.rmtree(workingDirectory + "/tmp")

if datasetServer is None:
    print("No dataset server chosen.")
    datasetServer = str(input("Dataset Server: "))
else:
    print("Dataset server is: " + datasetServer)

if wantedImages is None:
    wantedImages = int(input("How many images to download each time?: "))
print("Number of images to download each time: " + str(wantedImages))

print("Attempting to get server info...")
#ex.: datasetServer = 127.0.0.1
r = requests.get('http://' + str(datasetServer) + '/info')
if r.status_code == 200:
    data = json.loads(r.text)
    print("Server: " + data['ServerName'])
    print(data['ServerDescription'])
    print("Server Version: " + data['ServerVersion'])
    print("Currently serving " + str(data['FilesBeingServed']) + " Files")
    print("Age: " + data['ExecutedAt'])
else:
    print("Unable to get server info")
    # Exit with a non-zero status so launchers can tell the start-up failed.
    sys.exit(1)

directoryToExtract = workingDirectory + "/tmp/dataset"
print("directoryToExtract: " + directoryToExtract)
print("Wokring: " + workingDirectory)
os.makedirs(workingDirectory, exist_ok=True)
def download_file(url, inputjson, output):
    """
    POST `inputjson` to `url` and stream the response body into the file `output`.

    A 50-character progress bar is drawn on stdout when the server sends a
    Content-Length header; otherwise the body is written in one piece.

    :param url: Endpoint to POST to.
    :param inputjson: JSON-serialisable payload sent as the request body.
    :param output: Path of the file to (over)write with the response body.
    :raises requests.HTTPError: if the server answers with an error status, so
        that an error page is never saved as if it were the requested file.
    """
    print("Downloading %s" % output)
    # The context manager releases the streamed connection even on failure.
    with requests.post(url, stream=True, json=inputjson) as response:
        response.raise_for_status()
        total_length = response.headers.get('content-length')
        with open(output, "wb") as f:
            if total_length is None:  # no content length header
                f.write(response.content)
                return
            dl = 0
            total_length = int(total_length)
            for data in response.iter_content(chunk_size=4096):
                dl += len(data)
                f.write(data)
                # Guard against a zero Content-Length to avoid dividing by zero.
                done = int(50 * dl / total_length) if total_length else 50
                sys.stdout.write("\r[%s%s]" % ('=' * done, ' ' * (50 - done)))
                sys.stdout.flush()
def onlineGather(datasetServer, wantedImages, directoryToExtract):
    """
    Request a batch of tasks from the dataset server, download the matching
    files as a zip archive and extract them into `directoryToExtract`.

    The archive is stored temporarily as `<workingDirectory>/tmp.zip`
    (`workingDirectory` is the module-level setting) and removed afterwards,
    also when the download or the extraction fails.

    :param datasetServer: Host (and optional port) of the dataset server, e.g. "127.0.0.1".
    :param wantedImages: Number of images to request.
    :param directoryToExtract: Directory the archive is extracted into.
    :return: The decoded JSON task list returned by the server; it is later
        used as the receipt passed to `onlineReport`.
    :raises requests.HTTPError: if the task request is answered with an error status.
    """
    #ex.: datasetServer = "127.0.0.1" assuming port is 80
    print("Dataset server is: " + str(datasetServer))
    #Info on how this works should be on a md file soon
    urlDomain = 'http://' + datasetServer
    urlGetTasks = urlDomain + '/v1/get/tasks/' + str(wantedImages)
    requestGetTasks = requests.get(urlGetTasks)
    requestGetTasks.raise_for_status()
    responseAsJson = requestGetTasks.json()

    print("Downloading Files...")
    downloadUrl = urlDomain + "/v1/get/files"
    #TODO: download into an in-memory buffer (BytesIO) instead of a temp file
    tmpZipFilename = workingDirectory + "/tmp.zip"
    try:
        download_file(downloadUrl, responseAsJson, tmpZipFilename)
        print("Unzipping...")
        with zipfile.ZipFile(tmpZipFilename, 'r') as zip_ref:
            print("Extracting to: " + directoryToExtract)
            zip_ref.extractall(directoryToExtract)
            print("Extracted")
    finally:
        # Never leave a partial or corrupt archive behind for the next round.
        if os.path.exists(tmpZipFilename):
            os.remove(tmpZipFilename)
    return responseAsJson
def onlineReport(datasetServer, recipt):
    """
    POST the receipt to the dataset server's epoch-count endpoint.

    :param datasetServer: Host (and optional port) of the dataset server.
    :param recipt: JSON-serialisable receipt to send as the request body.
    :return: True when the server answers with HTTP 200, False otherwise.
    """
    print("Reporting epoch completition...")
    report_url = 'http://' + datasetServer + '/v1/post/epochcount'
    reply = requests.post(report_url, json=recipt)
    return reply.status_code == 200
def _sort_by_ratio(bucket: tuple) -> float:
return bucket[0] / bucket[1]
def _sort_by_area(bucket: tuple) -> float:
return bucket[0] * bucket[1]
class ImageStore:
    """
    Index of the image files found directly inside one directory.

    Files with the extensions jpg, jpeg, png, bmp and webp are collected and
    those that cannot be opened are discarded with a warning. Captions are read
    from a sibling `.txt` file that shares the image's base name.
    """

    def __init__(self, data_dir: str) -> None:
        self.data_dir = data_dir
        self.image_files = []
        for e in ['jpg', 'jpeg', 'png', 'bmp', 'webp']:
            self.image_files.extend(glob.glob(f'{data_dir}' + '/*.' + e))
        self.image_files = [x for x in self.image_files if self.__valid_file(x)]

    def __len__(self) -> int:
        return len(self.image_files)

    def __valid_file(self, f) -> bool:
        """Return True when the file at path `f` can be opened as an image."""
        try:
            # Open inside a context manager so the file handle is released at once.
            with Image.open(f):
                pass
            return True
        except Exception:
            print(f'WARNING: Unable to open file: {f}')
            return False

    # iterator returns images as PIL images and their index in the store
    def entries_iterator(self) -> "Generator[Tuple[Image.Image, int], None, None]":
        for index, path in enumerate(self.image_files):
            # Convert while the file is open, then close it before yielding so
            # no handle stays open while the consumer works on the image.
            with Image.open(path) as img:
                rgb = img.convert(mode='RGB')
            yield rgb, index

    # get image by index
    def get_image(self, ref: Tuple[int, int, int]) -> "Image.Image":
        """Return the image referenced by `ref[0]` as an RGB image."""
        with Image.open(self.image_files[ref[0]]) as img:
            return img.convert(mode='RGB')

    # gets caption by removing the extension from the filename and replacing it with .txt
    def get_caption(self, ref: Tuple[int, int, int]) -> str:
        """Return the contents of the `.txt` file next to the image referenced by `ref[0]`."""
        filename = re.sub(r'\.[^/.]+$', '', self.image_files[ref[0]]) + '.txt'
        with open(filename, 'r', encoding='UTF-8') as f:
            return f.read()
# ====================================== #
# Bucketing code stolen from hasuwoof: #
# https://github.com/hasuwoof/huskystack #
# ====================================== #
class AspectBucket:
    """
    Groups the images of a store into resolution buckets by aspect ratio and
    produces batches whose images all belong to the same bucket.

    Buckets are (width, height) tuples. `bucket_data` maps every bucket to the
    list of store indices assigned to it, trimmed to a multiple of `batch_size`.
    `total_dropped` counts the images that were rejected or trimmed away.
    """

    def __init__(self, store: "ImageStore",
                 num_buckets: int,
                 batch_size: int,
                 bucket_side_min: int = 256,
                 bucket_side_max: int = 768,
                 bucket_side_increment: int = 64,
                 max_image_area: int = 512 * 768,
                 max_ratio: float = 2):
        self.requested_bucket_count = num_buckets
        self.bucket_length_min = bucket_side_min
        self.bucket_length_max = bucket_side_max
        self.bucket_increment = bucket_side_increment
        self.max_image_area = max_image_area
        self.batch_size = batch_size
        self.total_dropped = 0

        # a non-positive max_ratio disables the aspect ratio limit
        if max_ratio <= 0:
            self.max_ratio = float('inf')
        else:
            self.max_ratio = max_ratio

        self.store = store
        self.buckets = []
        self._bucket_ratios = []
        self._bucket_interp = None
        self.bucket_data: Dict[tuple, List[int]] = dict()
        self.init_buckets()
        self.fill_buckets()

    def init_buckets(self):
        """Build the bucket list, the ratio interpolator and the empty `bucket_data` table."""
        possible_lengths = list(range(self.bucket_length_min, self.bucket_length_max + 1, self.bucket_increment))
        possible_buckets = list((w, h) for w, h in itertools.product(possible_lengths, possible_lengths)
                                if w >= h and w * h <= self.max_image_area and w / h <= self.max_ratio)

        buckets_by_ratio = {}

        # group the buckets by their aspect ratios
        for bucket in possible_buckets:
            w, h = bucket
            # use precision to avoid spooky floats messing up your day
            ratio = '{:.4e}'.format(w / h)
            buckets_by_ratio.setdefault(ratio, set()).add(bucket)

        # now we take the list of buckets we generated and pick the largest by area for each (the last sorted)
        # then we put all of those in a list, sorted by the aspect ratio
        # the square bucket (LxL) will be the first
        unique_ratio_buckets = sorted([sorted(buckets, key=_sort_by_area)[-1]
                                       for buckets in buckets_by_ratio.values()], key=_sort_by_ratio)

        # how many buckets to create for each side of the distribution
        bucket_count_each = int(np.clip((self.requested_bucket_count + 1) / 2, 1, len(unique_ratio_buckets)))

        # we know that the requested_bucket_count must be an odd number, so the indices we calculate
        # will include the square bucket and some linearly spaced buckets along the distribution
        indices = {*np.linspace(0, len(unique_ratio_buckets) - 1, bucket_count_each, dtype=int)}

        # make the buckets, make sure they are unique (to remove the duplicated square bucket), and sort them by ratio
        # here we add the portrait buckets by reversing the dimensions of the landscape buckets we generated above
        buckets = sorted({*(unique_ratio_buckets[i] for i in indices),
                          *(tuple(reversed(unique_ratio_buckets[i])) for i in indices)}, key=_sort_by_ratio)

        self.buckets = buckets

        # cache the bucket ratios and the interpolator that will be used for calculating the best bucket later
        # the interpolator makes a 1d piecewise interpolation where the input (x-axis) is the bucket ratio,
        # and the output is the bucket index in the self.buckets array
        # to find the best fit we can just round that number to get the index
        self._bucket_ratios = [w / h for w, h in buckets]
        self._bucket_interp = interp1d(self._bucket_ratios, list(range(len(buckets))), assume_sorted=True,
                                       fill_value=None)

        for b in buckets:
            self.bucket_data[b] = []

    def get_batch_count(self):
        """Return the number of full batches over all buckets."""
        return sum(len(b) // self.batch_size for b in self.bucket_data.values())

    def get_bucket_info(self):
        """Return the buckets and their ratios as a JSON string."""
        return json.dumps({ "buckets": self.buckets, "bucket_ratios": self._bucket_ratios })

    def get_batch_iterator(self) -> Generator[List[Tuple[int, int, int]], None, None]:
        """
        Generator that provides batches where the images in a batch fall on the same bucket

        Each batch is a list of `batch_size` elements of the form:
            (index, w, h)

        where each image is an index into the dataset
        :return:
        """
        max_bucket_len = max(len(b) for b in self.bucket_data.values())
        # one shuffled index order shared by all buckets; each bucket only uses
        # the entries of the schedule that are smaller than its own length
        index_schedule = list(range(max_bucket_len))
        random.shuffle(index_schedule)

        bucket_len_table = {b: len(self.bucket_data[b]) for b in self.buckets}

        # every bucket appears once per batch it can supply, in random order
        bucket_schedule = []
        for i, b in enumerate(self.buckets):
            bucket_schedule.extend([i] * (bucket_len_table[b] // self.batch_size))
        random.shuffle(bucket_schedule)

        # position of each bucket inside index_schedule
        bucket_pos = {b: 0 for b in self.buckets}

        for bucket_index in bucket_schedule:
            b = self.buckets[bucket_index]
            i = bucket_pos[b]
            bucket_len = bucket_len_table[b]

            batch = []
            while len(batch) != self.batch_size:
                # advance in the schedule until we find an index that is contained in the bucket
                k = index_schedule[i]
                if k < bucket_len:
                    batch.append(self.bucket_data[b][k])
                i += 1

            bucket_pos[b] = i
            yield [(idx, *b) for idx in batch]

    def fill_buckets(self):
        """Assign every image of the store to a bucket and trim buckets to full batches."""
        entries = self.store.entries_iterator()
        total_dropped = 0

        for entry, index in tqdm.tqdm(entries, total=len(self.store)):
            if not self._process_entry(entry, index):
                total_dropped += 1

        for b, values in self.bucket_data.items():
            # shuffle the entries for extra randomness and to make sure dropped elements are also random
            random.shuffle(values)

            # make sure the buckets have an exact number of elements for the batch
            to_drop = len(values) % self.batch_size
            self.bucket_data[b] = list(values[:len(values) - to_drop])
            total_dropped += to_drop

        self.total_dropped = total_dropped

    def _process_entry(self, entry: "Image.Image", index: int) -> bool:
        """
        Put the image `entry` (store index `index`) into the best fitting bucket.

        :return: True when the image was assigned, False when it was dropped
            because its aspect ratio exceeds `max_ratio` or lies outside the
            range of ratios covered by the buckets.
        """
        aspect = entry.width / entry.height

        if aspect > self.max_ratio or (1 / aspect) > self.max_ratio:
            return False

        try:
            best_bucket = self._bucket_interp(aspect)
        except ValueError:
            # interp1d raises instead of returning None for ratios outside the
            # bucket ratio range; such images are dropped rather than crashing.
            return False

        if best_bucket is None:
            return False

        bucket = self.buckets[round(float(best_bucket))]
        self.bucket_data[bucket].append(index)
        return True
class AspectBucketSampler(torch.utils.data.Sampler):
    """
    Batch sampler that hands each replica its share of the batches produced by
    an `AspectBucket`: replica `rank` receives every `num_replicas`-th batch,
    starting at position `rank`.
    """

    def __init__(self, bucket: AspectBucket, num_replicas: int = 1, rank: int = 0):
        super().__init__(None)
        self.bucket = bucket
        self.num_replicas = num_replicas
        self.rank = rank

    def __iter__(self):
        # Materialise the full batch list first, then keep this rank's slice.
        all_batches = list(self.bucket.get_batch_iterator())
        own_batches = all_batches[self.rank::self.num_replicas]
        return iter(own_batches)

    def __len__(self):
        total_batches = self.bucket.get_batch_count()
        return total_batches // self.num_replicas
class AspectDataset(torch.utils.data.Dataset):
    """
    Dataset that turns `(index, width, height)` references into training
    examples: a normalised image tensor plus the token ids of its caption.

    With probability `ucg` the caption is replaced by the empty string.
    """

    def __init__(self, store: "ImageStore", tokenizer: "CLIPTokenizer", ucg: float = 0.1):
        self.store = store
        self.tokenizer = tokenizer
        self.ucg = ucg
        self.transforms = torchvision.transforms.Compose([
            torchvision.transforms.RandomHorizontalFlip(p=0.5),
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize([0.5], [0.5])
        ])

    def __len__(self):
        return len(self.store)

    def __getitem__(self, item: Tuple[int, int, int]):
        """
        :param item: `(index, width, height)`; index into the store and the
            bucket dimensions the image is fitted to when `args.resize` is set.
        :return: dict with `pixel_values` (transformed image) and `input_ids`
            (unpadded token ids of the caption).
        """
        return_dict = {'pixel_values': None, 'input_ids': None}

        image_file = self.store.get_image(item)
        if args.resize:
            image_file = ImageOps.fit(
                image_file,
                (item[1], item[2]),
                bleed=0.0,
                centering=(0.5, 0.5),
                method=Image.Resampling.LANCZOS
            )

        return_dict['pixel_values'] = self.transforms(image_file)
        if random.random() > self.ucg:
            caption_file = self.store.get_caption(item)
        else:
            # text condition dropped for this example
            caption_file = ''
        return_dict['input_ids'] = self.tokenizer(caption_file, max_length=self.tokenizer.model_max_length, padding='do_not_pad', truncation=True).input_ids
        return return_dict

    def collate_fn(self, examples):
        """
        Stack the examples of one batch and pad their token ids to equal length.

        :param examples: list of dicts as returned by `__getitem__`; None entries are skipped.
        :return: dict with `pixel_values` (contiguous float32 tensor),
            `input_ids` and `attention_mask` (padded tensors).
        """
        pixel_values = torch.stack([example['pixel_values'] for example in examples if example is not None])
        # `.to()` and `.float()` return new tensors, so the result must be kept.
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()

        input_ids = [example['input_ids'] for example in examples if example is not None]
        padded_tokens = self.tokenizer.pad({'input_ids': input_ids}, return_tensors='pt', padding=True)

        return {
            'pixel_values': pixel_values,
            'input_ids': padded_tokens.input_ids,
            'attention_mask': padded_tokens.attention_mask,
        }
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights

    `shadow_params` holds the averaged copies of the tracked parameters,
    `decay` the configured maximum decay and `optimization_step` the number of
    `step` calls made so far.
    """

    def __init__(self, parameters: Iterable[torch.nn.Parameter], decay=0.9999):
        parameters = list(parameters)
        self.shadow_params = [p.clone().detach() for p in parameters]

        self.decay = decay
        self.optimization_step = 0

    def get_decay(self, optimization_step):
        """
        Compute the update weight (one minus the effective decay) for the
        exponential moving average.

        The effective decay is warmed up as (1 + step) / (10 + step) and capped
        at the configured `self.decay`.
        """
        value = (1 + optimization_step) / (10 + optimization_step)
        return 1 - min(self.decay, value)

    @torch.no_grad()
    def step(self, parameters):
        """
        Move the shadow parameters towards the current `parameters`.

        Parameters that do not require gradients are copied verbatim.
        """
        parameters = list(parameters)

        self.optimization_step += 1
        # Keep the update weight in a local: storing it in `self.decay` would
        # replace the configured decay and corrupt every later `get_decay` call.
        one_minus_decay = self.get_decay(self.optimization_step)

        for s_param, param in zip(self.shadow_params, parameters):
            if param.requires_grad:
                s_param.sub_(one_minus_decay * (s_param - param))
            else:
                s_param.copy_(param)

        torch.cuda.empty_cache()

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages.
        """
        parameters = list(parameters)
        for s_param, param in zip(self.shadow_params, parameters):
            param.data.copy_(s_param.data)

    # From CompVis LitEMA implementation
    def store(self, parameters):
        """
        Save the current parameters for restoring later.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                temporarily stored.
        """
        self.collected_params = [param.clone() for param in parameters]

    def restore(self, parameters):
        """
        Restore the parameters stored with the `store` method.
        Useful to validate the model with EMA parameters without affecting the
        original optimization process. Store the parameters before the
        `copy_to` method. After validation (or model saving), use this to
        restore the former parameters.
        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters.
        """
        for c_param, param in zip(self.collected_params, parameters):
            param.data.copy_(c_param.data)

        # the stored copies are single-use; free them right away
        del self.collected_params
        gc.collect()

    def to(self, device=None, dtype=None) -> None:
        r"""Move internal buffers of the ExponentialMovingAverage to `device`.
        Args:
            device: like `device` argument to `torch.Tensor.to`
            dtype: like `dtype` argument to `torch.Tensor.to`; only applied to
                floating point buffers.
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
            for p in self.shadow_params
        ]
def hivemindWorker(optimizer, peersArg=None, run_id='test_run'):
    """
    Start a hivemind DHT node and wrap `optimizer` in a `hivemind.Optimizer`.

    :param optimizer: The local torch optimizer to wrap.
    :param peersArg: Multiaddrs of existing DHT peers to connect to. When None,
        a DHT is started without initial peers.
    :param run_id: Identifier of the collaborative run passed to
        `hivemind.Optimizer`; defaults to the previously hard-coded 'test_run'.
    :return: The `hivemind.Optimizer` instance.
    """
    dht_kwargs = dict(
        host_maddrs=["/ip4/0.0.0.0/tcp/0", "/ip4/0.0.0.0/udp/0/quic"],
        start=True,
    )
    if peersArg is not None:
        dht = hivemind.DHT(initial_peers=peersArg, **dht_kwargs)
        print("Type: Relay")
    else:
        dht = hivemind.DHT(**dht_kwargs)
        print("Type: New")

    # Print the addresses of this node so they can be handed to other peers.
    print('\n'.join(str(addr) for addr in dht.get_visible_maddrs()))
    print("Global IP:", hivemind.utils.networking.choose_ip_address(dht.get_visible_maddrs()))

    hm_opt = hivemind.Optimizer(
        dht=dht,                  # use a DHT that is connected with other peers
        run_id=run_id,            # unique identifier of this collaborative run
        batch_size_per_step=1,    # each call to opt.step adds this many samples towards the next epoch
        target_batch_size=1000,   # after peers collectively process this many samples, average weights and begin the next epoch
        optimizer=optimizer,      # wrap the local optimizer passed in by the caller
        use_local_updates=True,   # perform optimizer steps with local gradients, average parameters in background
        matchmaking_time=1500.0,  # when averaging parameters, gather peers in background for up to this many seconds
        averaging_timeout=1500.0, # give up on averaging if not successful in this many seconds
        verbose=True,             # print logs incessantly
    )
    return hm_opt
def main():
rank = get_rank()
world_size = get_world_size()
torch.cuda.set_device(rank)
enablewandb = args.enablewandb
enableinference = args.enableinference
if rank == 0:
os.makedirs(args.output_path, exist_ok=True)
if enablewandb:
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb')
else:
run = wandb.init(project=args.project_id, name=args.run_name, config=vars(args), dir=args.output_path+'/wandb', mode="disabled")
# Inform the user of host, and various versions -- useful for debugging issues.
print("RUN_NAME:", args.run_name)
print("HOST:", socket.gethostname())
print("CUDA:", torch.version.cuda)
print("TORCH:", torch.__version__)
print("TRANSFORMERS:", transformers.__version__)
print("DIFFUSERS:", diffusers.__version__)
print("MODEL:", args.model)
print("FP16:", args.fp16)
print("RESOLUTION:", args.resolution)
if args.hf_token is None:
try:
args.hf_token = os.environ['HF_API_TOKEN']
print('It is recommended to set the HF_API_TOKEN environment variable instead of passing it as a command line argument since WandB will automatically log it.')
except Exception:
print("No HF Token detected in arguments or enviroment variable, setting it to none (as in string)")
args.hf_token = "none"
device = torch.device('cuda')
print("DEVICE:", device)
# setup fp16 stuff
scaler = torch.cuda.amp.GradScaler(enabled=args.fp16)
# Set seed
torch.manual_seed(args.seed)
print('RANDOM SEED:', args.seed)
if args.resume:
args.model = args.resume
tokenizer = CLIPTokenizer.from_pretrained(args.model, subfolder='tokenizer', use_auth_token=args.hf_token)
text_encoder = CLIPTextModel.from_pretrained(args.model, subfolder='text_encoder', use_auth_token=args.hf_token)
vae = AutoencoderKL.from_pretrained(args.model, subfolder='vae', use_auth_token=args.hf_token)
unet = UNet2DConditionModel.from_pretrained(args.model, subfolder='unet', use_auth_token=args.hf_token)
#Move the models before initializing the optimizer
weight_dtype = torch.float16 if args.fp16 else torch.float32
# move models to device
vae = vae.to(device, dtype=weight_dtype)
unet = unet.to(device, dtype=torch.float32)
text_encoder = text_encoder.to(device, dtype=weight_dtype)
# Freeze vae and text_encoder
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
if args.gradient_checkpointing:
unet.enable_gradient_checkpointing()
if args.use_xformers:
unet.set_use_memory_efficient_attention_xformers(True)
if args.use_8bit_adam: # Bits and bytes is only supported on certain CUDA setups, so default to regular adam if it fails.
# if str2optimizer8bit_blockwise check https://github.com/TimDettmers/bitsandbytes/issues/62
try:
import bitsandbytes as bnb
optimizer_cls = bnb.optim.AdamW8bit
except:
print('bitsandbytes not supported, using regular Adam optimizer')
optimizer_cls = torch.optim.AdamW
else:
optimizer_cls = torch.optim.AdamW
optimizer = optimizer_cls(
unet.parameters(),
lr=args.lr,
betas=(args.adam_beta1, args.adam_beta2),
eps=args.adam_epsilon,
weight_decay=args.adam_weight_decay,
)
noise_scheduler = DDPMScheduler(
beta_start=0.00085,
beta_end=0.012,
beta_schedule='scaled_linear',
num_train_timesteps=1000,
)
#TODO: put arguments
def trainDataloader():
# load dataset
store = ImageStore(directoryToExtract)
dataset = AspectDataset(store, tokenizer)
bucket = AspectBucket(store, args.num_buckets, args.batch_size, args.bucket_side_min, args.bucket_side_max, 64, args.resolution * args.resolution, 2.0)
sampler = AspectBucketSampler(bucket=bucket, num_replicas=world_size, rank=rank)
print(f'STORE_LEN: {len(store)}')
if args.output_bucket_info:
print(bucket.get_bucket_info())
exit(0)
train_dataloader = torch.utils.data.DataLoader(
dataset,
batch_sampler=sampler,
num_workers=0,
collate_fn=dataset.collate_fn
)
return train_dataloader
#unet = torch.nn.parallel.DistributedDataParallel(unet, device_ids=[rank], output_device=rank, gradient_as_bucket_view=True)
# create ema
if args.use_ema:
ema_unet = EMAModel(unet.parameters())
print(get_gpu_ram())
global_step = 0
if args.resume:
target_global_step = int(args.resume.split('_')[-1])
print(f'resuming from {args.resume}...')
#LR SCHEDULER MOVED TO BE SET IF HIVEMIND DISABLED
# lr_scheduler = get_scheduler(
# args.lr_scheduler,
# optimizer=optimizer,
# num_warmup_steps=int(args.lr_scheduler_warmup * num_steps_per_epoch * args.epochs),
# num_training_steps=args.epochs * num_steps_per_epoch,
# #last_epoch=(global_step // num_steps_per_epoch) - 1,
# )
#probably unnecessary but ok
def gt():
return(time.time_ns())
def save_checkpoint(global_step):
if rank == 0:
if args.use_ema:
ema_unet.store(unet.parameters())
ema_unet.copy_to(unet.parameters())
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
),
safety_checker=StableDiffusionSafetyChecker.from_pretrained("CompVis/stable-diffusion-safety-checker"),
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
)
print(f'saving checkpoint to: {args.output_path}/{args.run_name}_{global_step}')
pipeline.save_pretrained(f'{args.output_path}/{args.run_name}_{global_step}')
if args.use_ema:
ema_unet.restore(unet.parameters())
# barrier
torch.distributed.barrier()
# train!
#forget about local training, use WD instead
# Training driver. Each pass of the outer `while True` loop fetches data with
# onlineGather(), rebuilds the dataloader, trains on every batch once, reports
# the pass with onlineReport() and then removes workingDirectory + "/tmp".
# NOTE(review): hivemindWorker, onlineGather, onlineReport, scaler,
# noise_scheduler, enableinference and enablewandb are defined outside this
# section; confirm their exact semantics there.
finalOptimizer = hivemindWorker(optimizer, args.peers)
# Logged as "train/datasetRunCount" below.
# NOTE(review): nothing in this section ever increments it.
datasetRunCount = 0
try:
while True:
# The value returned here is handed back to onlineReport() after the pass.
recipt = onlineGather(datasetServer=datasetServer, wantedImages=wantedImages, directoryToExtract=directoryToExtract)
#Reload Dataset
print("Reloading Dataset...")
train_dataloader = trainDataloader()
num_steps_per_epoch = len(train_dataloader)
progress_bar = tqdm.tqdm(range(num_steps_per_epoch), desc="Total Steps", leave=False)
# Placeholder loss tensor; replaced by the first computed loss.
loss = torch.tensor(0.0, device=device, dtype=weight_dtype)
unet.train()
for _, batch in enumerate(train_dataloader):
# When resuming, consume batches without training until global_step has
# caught up with the step parsed from args.resume.
if args.resume and global_step < target_global_step:
if rank == 0:
progress_bar.update(1)
global_step += 1
continue
# Start of the per-step timer used for the throughput figures below.
b_start = time.perf_counter()
# Encode the images with the VAE, sample latents from the returned
# distribution and scale them.
# NOTE(review): 0.18215 is presumably the Stable Diffusion latent scaling
# factor -- confirm.
latents = vae.encode(batch['pixel_values'].to(device, dtype=weight_dtype)).latent_dist.sample()
latents = latents * 0.18215
# Sample noise
noise = torch.randn_like(latents)
bsz = latents.shape[0]
# Sample a random timestep for each image
timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
timesteps = timesteps.long()
# Add noise to the latents according to the noise magnitude at each timestep
# (this is the forward diffusion process)
noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)
# Get the text embedding for conditioning
encoder_hidden_states = text_encoder(batch['input_ids'].to(device), output_hidden_states=True)
# With args.clip_penultimate, use the second-to-last hidden layer passed
# through the text model's final layer norm; otherwise the last hidden state.
if args.clip_penultimate:
encoder_hidden_states = text_encoder.text_model.final_layer_norm(encoder_hidden_states['hidden_states'][-2])
else:
encoder_hidden_states = encoder_hidden_states.last_hidden_state
# Predict the noise residual and compute loss
with torch.autocast('cuda', enabled=args.fp16):
noise_pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample
# The loss is computed on float32 copies of the prediction and the target.
loss = torch.nn.functional.mse_loss(noise_pred.float(), noise.float(), reduction="mean")
# Backprop and all reduce
# NOTE(review): scaler.step(finalOptimizer) is followed by a direct
# finalOptimizer.step(). If `scaler` is a torch GradScaler, scaler.step()
# already calls the optimizer's step(), so this may step twice per batch --
# confirm against the definitions of scaler and hivemindWorker.
scaler.scale(loss).backward()
scaler.step(finalOptimizer)
scaler.update()
finalOptimizer.step()
finalOptimizer.zero_grad()
# Update EMA
if args.use_ema:
ema_unet.step(unet.parameters())
# perf
b_end = time.perf_counter()
seconds_per_step = b_end - b_start
steps_per_second = 1 / seconds_per_step
rank_images_per_second = args.batch_size * steps_per_second
world_images_per_second = rank_images_per_second * world_size
# Uses global_step as it is before the increment further down.
samples_seen = global_step * args.batch_size * world_size
# All reduce loss
# The summed loss is divided by world_size when it is logged below.
torch.distributed.all_reduce(loss, op=torch.distributed.ReduceOp.SUM)
if rank == 0:
progress_bar.update(1)
global_step += 1
logs = {
"train/loss": loss.detach().item() / world_size,
"train/datasetRunCount": datasetRunCount,
"train/step": global_step,
"train/samples_seen": samples_seen,
"perf/rank_samples_per_second": rank_images_per_second,
"perf/global_samples_per_second": world_images_per_second,
}
progress_bar.set_postfix(logs)
run.log(logs, step=global_step)
# Checkpoint whenever global_step is a multiple of args.save_steps.
if global_step % args.save_steps == 0:
save_checkpoint(global_step)
# Optional sample generation whenever global_step is a multiple of
# args.image_log_steps.
if enableinference:
if global_step % args.image_log_steps == 0:
if rank == 0:
# get prompt from random batch
prompt = tokenizer.decode(batch['input_ids'][random.randint(0, len(batch['input_ids'])-1)].tolist())
if args.image_log_scheduler == 'DDIMScheduler':
print('using DDIMScheduler scheduler')
scheduler = DDIMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear"
)
else:
print('using PNDMScheduler scheduler')
scheduler=PNDMScheduler(
beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", skip_prk_steps=True
)
pipeline = StableDiffusionPipeline(
text_encoder=text_encoder,
vae=vae,
unet=unet,
tokenizer=tokenizer,
scheduler=scheduler,
safety_checker=None, # disable safety checker to save memory
feature_extractor=CLIPFeatureExtractor.from_pretrained("openai/clip-vit-base-patch32"),
).to(device)
# inference
# With wandb enabled the images are collected for one run.log() call;
# otherwise they are written under args.output_path + "/inference".
if enablewandb:
images = []
else:
saveInferencePath = args.output_path + "/inference"
os.makedirs(saveInferencePath, exist_ok=True)
with torch.no_grad():
with torch.autocast('cuda', enabled=args.fp16):
for _ in range(args.image_log_amount):
if enablewandb:
images.append(
wandb.Image(pipeline(
prompt, num_inference_steps=args.image_log_inference_steps
).images[0],
caption=prompt)
)
else:
from datetime import datetime
images = pipeline(prompt, num_inference_steps=args.image_log_inference_steps).images[0]
# NOTE(review): the two names come from separate time.time_ns() calls, so
# the .png and .txt stems will normally differ; the text file records the
# image filename explicitly.
filenameImg = str(time.time_ns()) + ".png"
filenameTxt = str(time.time_ns()) + ".txt"
images.save(saveInferencePath + "/" + filenameImg)
with open(saveInferencePath + "/" + filenameTxt, 'a') as f:
f.write('Used prompt: ' + prompt + '\n')
f.write('Generated Image Filename: ' + filenameImg + '\n')
f.write('Generated at: ' + str(global_step) + ' steps' + '\n')
f.write('Generated at: ' + str(datetime.now().strftime('%Y-%m-%d %H:%M:%S'))+ '\n')
# log images under single caption
if enablewandb:
run.log({'images': images}, step=global_step)
# cleanup so we don't run out of memory
del pipeline
gc.collect()
torch.distributed.barrier()
# One pass over the gathered data is done: report it using the value
# returned by onlineGather(), and stop the process if the report fails.
print('Did one dataset run. Reporting...')
reportStatus = onlineReport(datasetServer=datasetServer, recipt=recipt)
if reportStatus is True:
print("Report Success")
else:
print("Report failed, exiting...")
exit()
print("Cleaning folder...")
shutil.rmtree(workingDirectory + "/tmp")
# NOTE(review): this handler only prints the error and traceback. Its message
# says a checkpoint is being saved, but the handler itself does not call
# save_checkpoint -- presumably the save at the bottom runs after the whole
# try statement; confirm the intended nesting.
except Exception as e:
print(f'Exception caught on rank {rank} at step {global_step}, saving checkpoint...\n{e}\n{traceback.format_exc()}')
pass
# KeyboardInterrupt does not derive from Exception, so Ctrl-C is handled by
# this separate clause.
except KeyboardInterrupt:
print("Quitting...")
print("Saving checkpoint...")
save_checkpoint(global_step)
print("Checkpoint Saved.")
if __name__ == "__main__":
    # Only run when executed as a script (e.g. via torchrun), not on import:
    # setup() is called first, then main().
    setup()
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment