{
  "success": true,
  "credits_left": 20,
  "rate_limit_left": 20,
  "person": {
    "publicIdentifier": "ghutson",
    "linkedInIdentifier": "ACoAAAaZQFQBTwymlGveZW3Ac1Ey9ZDrmAsXakc",
    "memberIdentifier": "110706772",
    "linkedInUrl": "https://www.linkedin.com/in/ghutson",
    "firstName": "Gary",
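The response above is truncated. Below is a minimal standard-library sketch of reading such a payload. It assumes only the top-level fields visible here (`success`, `credits_left`, `rate_limit_left`, `person`). The helper name `read_profile` and the sample values are hypothetical.

```python
import json


def read_profile(raw: str) -> dict:
    """Parse a profile-lookup response and return its 'person' object.

    Assumes the top-level shape shown above (hypothetical helper); raises
    if the call reports failure or no person record is present.
    """
    payload = json.loads(raw)
    if not payload.get("success"):
        raise RuntimeError("lookup failed")
    person = payload.get("person")
    if person is None:
        raise KeyError("response has no 'person' object")
    return person


# Hypothetical sample payload with the same top-level keys as above.
sample = (
    '{"success": true, "credits_left": 20, "rate_limit_left": 20, '
    '"person": {"publicIdentifier": "example", "firstName": "Ada"}}'
)
print(read_profile(sample)["firstName"])
# → Ada
```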
import yaml
import argparse as AP
from huggingface_hub import HfApi, ModelCard, create_repo, get_full_repo_name

# Load settings from YAML file
with open('dreambooth_param.yml', 'r') as train:
    params = yaml.safe_load(train)
train_params = params['train_params']
"""
Name: dreambooth_train.py
Author: Gary Hutson
Date: 23/05/2023
Usage: python dreambooth_train.py
"""
from dreambooth.dataloader import pull_dataset_from_hf_hub, DreamBoothDataset
from dreambooth.image import image_grid
from dreambooth.collator import collate_fn
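The `image_grid` helper imported from `dreambooth.image` is not shown. Grid helpers of this kind typically paste equally sized images left to right and top to bottom onto one canvas. The standard-library sketch below computes only the paste offsets for such a layout. The name `grid_boxes` is hypothetical, and the sketch is not necessarily how the author's `image_grid` works.

```python
def grid_boxes(n_images, rows, cols, width, height):
    """Return the (left, top) pixel offset of each image in a rows x cols grid.

    Images are placed row by row and are assumed to share one size; the
    full canvas would be (cols * width) x (rows * height) pixels.
    """
    if n_images != rows * cols:
        raise ValueError(f"expected {rows * cols} images, got {n_images}")
    return [((i % cols) * width, (i // cols) * height) for i in range(n_images)]


print(grid_boxes(4, rows=2, cols=2, width=512, height=512))
# → [(0, 0), (512, 0), (0, 512), (512, 512)]
```

With Pillow, each offset would be passed as the `box` argument of `Image.paste` on a canvas created with `Image.new`.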
# Train the model
model = train_dreambooth(
    text_encoder=text_encoder,
    vae=vae,
    unet=unet,
    tokenizer=tokenizer,
    feature_extractor=feature_extractor,
    train_dataset=train_dataset,
    train_batch_size=train_batch_size,
    max_train_steps=max_train_steps,
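The call is truncated, so the way `train_dreambooth` uses its batch settings is not shown. The sketch below assumes it follows the common gradient-accumulation pattern, as in the Hugging Face DreamBooth example, where `max_train_steps` counts optimizer updates rather than micro-batches. Under that assumption, it shows which micro-batches trigger an optimizer step. The `grad_accum_steps` value comes from the YAML config, and `accumulation_schedule` is a hypothetical name.

```python
def accumulation_schedule(max_train_steps, grad_accum_steps):
    """Yield (micro_batch_index, optimizer_step) whenever an update would occur.

    Gradients from grad_accum_steps consecutive micro-batches are summed
    before each optimizer update; training stops after max_train_steps updates.
    """
    optimizer_step = 0
    micro_batch = 0
    while optimizer_step < max_train_steps:
        micro_batch += 1
        # A forward and backward pass on one micro-batch would happen here.
        if micro_batch % grad_accum_steps == 0:
            optimizer_step += 1
            yield micro_batch, optimizer_step


updates = list(accumulation_schedule(max_train_steps=400, grad_accum_steps=8))
print(updates[0], updates[-1])
# → (8, 1) (3200, 400)
```

Under these assumptions, `train_bs: 1` gives an effective batch size of 8 on a single device. Training then draws 3,200 examples over 400 updates.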
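from diffusers import AutoencoderKL, UNet2DConditionModel
from transformers import CLIPFeatureExtractor, CLIPTextModel

# Load the text encoder, VAE and UNet from the Stable Diffusion checkpoint,
# plus the CLIP feature extractor named in the YAML config
text_encoder = CLIPTextModel.from_pretrained(model_id, subfolder="text_encoder")
vae = AutoencoderKL.from_pretrained(model_id, subfolder="vae")
unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
feature_extractor = CLIPFeatureExtractor.from_pretrained(FEATURE_EXTRACTOR)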
from torch.utils.data import Dataset
from torchvision import transforms


class DreamBoothDataset(Dataset):
    def __init__(self, dataset, instance_prompt, tokenizer, size=512):
        self.dataset = dataset
        self.instance_prompt = instance_prompt
        self.tokenizer = tokenizer
        self.size = size
        self.transforms = transforms.Compose(
            [
                transforms.Resize(size),
                transforms.CenterCrop(size),
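The class is cut off after the transforms. A PyTorch map-style dataset needs only `__len__` and `__getitem__`. The dependency-free sketch below shows one plausible shape for those two methods. It assumes each record has an `"image"` key and that the tokenizer returns an object with `input_ids`. The class name `DreamBoothDatasetSketch`, the output keys, and the stub tokenizer and transform are hypothetical; this is not the author's code.

```python
from types import SimpleNamespace


class DreamBoothDatasetSketch:
    """Map-style dataset: pairs every image with the same tokenized instance prompt."""

    def __init__(self, dataset, instance_prompt, tokenizer, transform):
        self.dataset = dataset
        self.instance_prompt = instance_prompt
        self.tokenizer = tokenizer
        self.transform = transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, index):
        image = self.dataset[index]["image"]
        return {
            "instance_images": self.transform(image),
            "instance_prompt_ids": self.tokenizer(self.instance_prompt).input_ids,
        }


# Stand-ins for a real image dataset, tokenizer and torchvision transform.
records = [{"image": "img0"}, {"image": "img1"}]
fake_tokenizer = lambda text: SimpleNamespace(input_ids=[len(w) for w in text.split()])
ds = DreamBoothDatasetSketch(records, "a photo of sks dog", fake_tokenizer, str.upper)
print(len(ds), ds[1])
# → 2 {'instance_images': 'IMG1', 'instance_prompt_ids': [1, 5, 2, 3, 3]}
```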
# Create a train dataset from the DreamBooth data loader
train_dataset = DreamBoothDataset(dataset, instance_prompt, tokenizer)
if __name__ == '__main__':
    # Load the image dataset from the Hugging Face Hub
    dataset = pull_dataset_from_hf_hub(dataset_id=hf_data_location)
    # Name your concept and set of images
    name_of_your_concept = name_of_your_concept
    type_of_thing = object_type
    instance_prompt = f"a photo of {name_of_your_concept} {type_of_thing}"
    print(f"Instance prompt: {instance_prompt}")
# Set project constants and variables
with open('dreambooth_param.yml', 'r') as train:
    params = yaml.safe_load(train)
# Set the parameters
train_params = params['train_params']
STABLE_DIFFUSION_NAME = train_params['stable_diffusion_backbone']
FEATURE_EXTRACTOR = train_params['feature_extractor']
hf_data_location = train_params['hugging_face_image_store']
learning_rate = float(train_params['learning_rate'])
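The explicit `float()` around `learning_rate` matters. PyYAML follows YAML 1.1, whose float pattern requires a decimal point, so a value written as `2e-06` loads as the string `'2e-06'`. That is presumably why the script converts it. The sketch below gathers the parameters into one typed object, assuming the key names in `dreambooth_param.yml`. The `TrainConfig` dataclass and the `from_params` helper are hypothetical.

```python
from dataclasses import dataclass


@dataclass
class TrainConfig:
    stable_diffusion_backbone: str
    feature_extractor: str
    hugging_face_image_store: str
    learning_rate: float
    max_train_steps: int
    train_bs: int
    grad_accum_steps: int


def from_params(train_params: dict) -> TrainConfig:
    """Build a TrainConfig from the 'train_params' mapping, coercing numeric fields."""
    return TrainConfig(
        stable_diffusion_backbone=train_params["stable_diffusion_backbone"],
        feature_extractor=train_params["feature_extractor"],
        hugging_face_image_store=train_params["hugging_face_image_store"],
        # PyYAML may return '2e-06' as a string, so convert explicitly.
        learning_rate=float(train_params["learning_rate"]),
        max_train_steps=int(train_params["max_train_steps"]),
        train_bs=int(train_params["train_bs"]),
        grad_accum_steps=int(train_params["grad_accum_steps"]),
    )


# A dict shaped like yaml.safe_load's output for the config file.
raw = {
    "stable_diffusion_backbone": "CompVis/stable-diffusion-v1-4",
    "feature_extractor": "openai/clip-vit-base-patch32",
    "hugging_face_image_store": "StatsGary/dreambooth-hackathon-images",
    "learning_rate": "2e-06",
    "max_train_steps": 400,
    "train_bs": 1,
    "grad_accum_steps": 8,
}
cfg = from_params(raw)
print(cfg.learning_rate, cfg.train_bs * cfg.grad_accum_steps)
# → 2e-06 8
```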
train_params:
  stable_diffusion_backbone: CompVis/stable-diffusion-v1-4
  feature_extractor: openai/clip-vit-base-patch32
  hugging_face_image_store: StatsGary/dreambooth-hackathon-images
  learning_rate: 2e-06
  max_train_steps: 400
  resolution: 512
  train_bs: 1
  grad_accum_steps: 8
  max_gradient_norm: 1.0
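`max_gradient_norm: 1.0` is presumably passed to a clip-by-global-norm step such as `torch.nn.utils.clip_grad_norm_`. The standard-library sketch below reproduces that rule on plain lists. It computes one L2 norm over all gradients together. If that norm exceeds the limit, it scales every gradient by the same factor, adding a small epsilon to the denominator as PyTorch does. The name `clip_by_global_norm` is hypothetical.

```python
import math


def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Scale gradient vectors so their combined L2 norm is at most max_norm.

    Returns the clipped gradients and the total norm measured before clipping.
    """
    total_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
    # Factor is capped at 1.0, so gradients already under the limit are unchanged.
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in vec] for vec in grads]
    return clipped, total_norm


clipped, norm = clip_by_global_norm([[3.0, 0.0], [0.0, 4.0]], max_norm=1.0)
print(norm)
# → 5.0
```

The first vector is clipped to roughly `[0.6, 0.0]` and the second to roughly `[0.0, 0.8]`. Their direction is preserved and their combined norm drops to about 1.0.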