Skip to content

Instantly share code, notes, and snippets.

@tori29umai0123
Last active May 28, 2024 23:20
Show Gist options
  • Save tori29umai0123/b710efabf3781f137359fa1616da85f4 to your computer and use it in GitHub Desktop.
Save tori29umai0123/b710efabf3781f137359fa1616da85f4 to your computer and use it in GitHub Desktop.
HF_upload_sdxl_gen_img.py
import getpass
import os
import re
import shutil
import sys
import threading
import zipfile
from time import sleep

import wget
from huggingface_hub import HfApi, Repository, upload_file

import gen_img
# --- Worker tuning -----------------------------------------------------------
# Images per uploaded zip archive, and seconds between scans of the output dir.
batch_size = 5_000
sleep_time = 5

# --- Files --------------------------------------------------------------------
# Directory gen_img writes into; the zip worker also stages archives here.
outdir = "output_img"
# Base checkpoint and the LoRA weights passed to gen_img (merged via --network_merge).
ckpt = "animagine-xl-3.1.safetensors"
LoRA = "lcm-animaginexl-3_1.safetensors"
# Full prompt list, consumed one prompt per line.
prompt_file = "dart_prompts.txt"

# --- Cross-thread signals -------------------------------------------------------
# stop_event: set by main() once image generation has returned.
# upload_event: cleared before and set after each archive upload.
stop_event, upload_event = threading.Event(), threading.Event()
def download_file(url, filename):
    """Download ``url`` to ``filename`` unless the file already exists.

    Download errors are printed rather than raised.

    Args:
        url: Source URL.
        filename: Local destination path.

    Returns:
        True if ``filename`` was already present or the download completed,
        False if the download raised an exception.
    """
    if os.path.exists(filename):
        print(f"{filename} は既に存在します。")
        return True
    print(f"{filename} をダウンロード中...")
    try:
        wget.download(url, filename)
    except Exception as e:
        # Terminate any partially drawn progress bar before the error message.
        print()
        print(f"ファイルのダウンロード中にエラーが発生しました: {e}")
        return False
    # wget.download's progress bar does not end its line; close it so the
    # next message starts on a fresh line.
    print()
    return True
def check_repository_access(repo_name, token):
    """Return True if dataset repo ``repo_name`` can be queried with ``token``.

    Any exception raised while fetching the repository metadata (e.g. unknown
    repo, rejected token, network failure) is printed and reported as False.

    NOTE(review): ``repo_info`` is a read request, so this proves read access
    only; a read-only token passes here and fails later at upload time.
    """
    api = HfApi()
    try:
        api.repo_info(repo_name, token=token, repo_type='dataset')
    except Exception as e:
        print(f"リポジトリへのアクセスに失敗しました: {e}")
        return False
    print("リポジトリアクセスを確認しました。")
    return True
def upload_to_hf(zip_file, repo_name, token):
    """Upload one zip archive to the Hugging Face dataset ``repo_name``.

    The archive is copied into the local clone at ``./<repo basename>``,
    cloning it first if that folder does not exist. It is uploaded from that
    copy under its bare filename, and the original ``zip_file`` is deleted.
    ``upload_event`` is set when done. Exceptions from cloning/uploading
    propagate to the caller.
    """
    archive_name = os.path.basename(zip_file)
    clone_dir = os.path.join(os.getcwd(), repo_name.split('/')[-1])

    if not os.path.exists(clone_dir):
        Repository(clone_dir, clone_from=f"https://huggingface.co/datasets/{repo_name}", use_auth_token=token)

    # The copy is left in the local clone. main() lists *.zip in this folder to
    # find its resume point, presumably why it is kept -- confirm before removing.
    shutil.copy(zip_file, clone_dir)
    staged_copy = os.path.join(clone_dir, archive_name)

    upload_file(
        path_or_fileobj=staged_copy,
        path_in_repo=archive_name,
        repo_id=repo_name,
        token=token,
        repo_type='dataset',
    )

    os.remove(zip_file)
    print(f"{zip_file}をHugging Faceのデータセットリポジトリ{repo_name}にアップロードしました。")
    upload_event.set()
def zip_images(directory, batch_size, repo_name, token, zip_counter_start):
    """Background worker that packs generated images into zips and uploads them.

    Polls ``directory`` every ``sleep_time`` seconds while generation runs.
    Whenever at least ``batch_size`` images (.png/.jpg/.jpeg/.webp) are
    present, ``batch_size`` of them are archived and uploaded via
    ``zip_and_upload``. Once ``stop_event`` is set, the remaining images are
    flushed and the function returns when no image is left.

    Args:
        directory: Folder the image generator writes into.
        batch_size: Images per archive while generation is still running.
        repo_name: Target dataset repository id.
        token: Hugging Face API token.
        zip_counter_start: Initial archive counter; main() passes the highest
            number found among archives already in the repository clone.
    """
    image_exts = ('.png', '.jpg', '.jpeg', '.webp')
    zip_counter = zip_counter_start
    while True:
        # Read the stop flag *before* listing: main() sets it only after
        # gen_img.main returns, so a listing taken afterwards should contain
        # every image and the final flush is not split into extra archives.
        stopping = stop_event.is_set()
        files = [f for f in os.listdir(directory) if f.endswith(image_exts)]
        if len(files) >= batch_size:
            zip_and_upload(files, directory, batch_size, repo_name, token, zip_counter)
            zip_counter += batch_size
        elif stopping and files:
            zip_and_upload(files, directory, len(files), repo_name, token, zip_counter)
            zip_counter += len(files)
        elif stopping:
            break
        # No new images can appear after generation ends, so only wait while
        # the generator is still running.
        if not stopping:
            sleep(sleep_time)
def zip_and_upload(files, directory, count, repo_name, token, zip_counter):
    """Archive the first ``count`` images of ``files`` and upload the archive.

    Side effects, in order:
      * ``files`` is sorted in place, so the batch is the first ``count``
        names in sorted order;
      * ``<directory>/<zip_counter + count>.zip`` (zero-padded to at least
        five digits) is written, each image stored under its bare filename;
      * the archived images are deleted from ``directory``;
      * ``upload_event`` is cleared and the archive goes to ``upload_to_hf``.
    """
    files.sort()
    archive_path = os.path.join(directory, f'{zip_counter + count:05d}.zip')
    members = [(os.path.join(directory, name), name) for name in files[:count]]

    with zipfile.ZipFile(archive_path, 'w') as archive:
        for source_path, member_name in members:
            archive.write(source_path, arcname=member_name)
    print(f'Created {archive_path}')

    # The images now live inside the archive; drop the loose copies.
    for source_path, _ in members:
        os.remove(source_path)

    upload_event.clear()
    upload_to_hf(archive_path, repo_name, token)
def create_partial_prompt_file(original_file, new_file, start_line):
    """Write the tail of a prompt file, starting at a 1-based line number.

    Copies lines ``start_line``..end of ``original_file`` into ``new_file``
    (UTF-8, overwriting it). main() uses this to skip prompts whose images
    were already uploaded.

    Args:
        original_file: Path of the full prompt list.
        new_file: Path to write the remaining prompts to.
        start_line: 1-based number of the first line to keep.

    Returns:
        True if ``new_file`` was written. False if ``start_line`` lies past
        the end of ``original_file``. In that case any existing ``new_file``
        is deleted, so a prompt list from an earlier run cannot be reused.

    Raises:
        ValueError: if ``start_line`` is less than 1. A slice index of
            ``start_line - 1`` below zero would silently count from the end.
    """
    if start_line < 1:
        raise ValueError(f"start_line must be >= 1, got {start_line}")
    with open(original_file, 'r', encoding='utf-8') as fin:
        lines = fin.readlines()
    line_count = len(lines)
    if start_line > line_count:
        print(f"指定された開始行 {start_line} は、ファイルの行数 {line_count} を超えています。")
        # Nothing left to process: remove leftovers instead of leaving stale prompts.
        if os.path.exists(new_file):
            os.remove(new_file)
        return False
    with open(new_file, 'w', encoding='utf-8') as fout:
        fout.writelines(lines[start_line - 1:])
    return True
def _latest_zip_counter(repo_local_path):
    """Return the largest number found right before ``.zip`` in ``*.zip`` names in ``repo_local_path``, or 0."""
    pattern = re.compile(r'(\d+)\.zip')
    numbers = []
    for name in os.listdir(repo_local_path):
        if not name.endswith('.zip'):
            continue
        match = pattern.search(name)
        if match:
            numbers.append(int(match.group(1)))
    # default=0 also covers .zip files without a number (max([]) would raise).
    return max(numbers, default=0)


def _run_gen_img(prompt_path):
    """Run gen_img with this script's fixed generation settings on ``prompt_path``."""
    parser = gen_img.setup_parser()
    # parse_args() with no argument reads sys.argv[1:] (assuming setup_parser
    # returns an argparse parser), so the CLI arguments are injected there.
    sys.argv = [
        'script_name', '--ckpt', ckpt, '--n_iter', '1', '--scale', '3',
        '--steps', '12', '--outdir', outdir, '--xformers', '--bf16',
        '--sampler', 'euler_a', '--batch_size', '4', '--vae_batch_size', '2',
        '--from_file', prompt_path,
        '--max_embeddings_multiples', '3', '--seed', '42',
        '--network_module', 'networks.lora', '--network_weights', LoRA,
        '--network_mul', '0.4', '--network_merge',
    ]
    args = parser.parse_args()
    gen_img.main(args)


def main(repo_name, token):
    """Generate images for prompts not yet uploaded and upload them as zips.

    1. Clone the dataset repo and take the highest ``N.zip`` in the clone as
       the number of prompts already processed.
    2. Write the remaining prompts (from line N+1) to ``partial_prompts.txt``.
    3. Start the ``zip_images`` worker, run gen_img on the remaining prompts,
       then signal the worker and wait for it to flush and upload the rest.

    Args:
        repo_name: Dataset repository id.
        token: Hugging Face API token.
    """
    repo_local_path = os.path.join(os.getcwd(), repo_name.split('/')[-1])
    # NOTE(review): huggingface_hub marks Repository as deprecated;
    # HfApi.list_repo_files could find existing archives without a git clone.
    Repository(repo_local_path, clone_from=f"https://huggingface.co/datasets/{repo_name}", use_auth_token=token)
    zip_counter_start = _latest_zip_counter(repo_local_path)
    # Assumes one image per prompt line (n_iter=1), so the archive number
    # doubles as the count of prompt lines already done.
    start_line = zip_counter_start + 1

    new_prompt_file = "partial_prompts.txt"
    # Remove any file left by a previous run; otherwise, when no prompts
    # remain, gen_img would silently regenerate that stale list.
    if os.path.exists(new_prompt_file):
        os.remove(new_prompt_file)
    create_partial_prompt_file(prompt_file, new_prompt_file, start_line)
    has_prompts = os.path.exists(new_prompt_file) and os.path.getsize(new_prompt_file) > 0
    print(f"プロンプトファイルの処理は行 {start_line} から開始します。")

    zip_thread = threading.Thread(target=zip_images, args=(outdir, batch_size, repo_name, token, zip_counter_start))
    zip_thread.start()
    try:
        if has_prompts:
            _run_gen_img(new_prompt_file)
        else:
            print("処理するプロンプトが残っていないため、画像生成をスキップします。")
    finally:
        # Always release the worker. Without this, an exception in gen_img
        # would leave the non-daemon zip thread polling forever and the
        # process would never exit. Leftover images are still flushed.
        stop_event.set()
        zip_thread.join()
    print("すべての処理が完了しました。")
if __name__ == "__main__":
    # Script entry point: fetch model weights, ask for the target dataset repo
    # and token, verify access, then run the generate-and-upload pipeline.
    ckpt_url = "https://huggingface.co/cagliostrolab/animagine-xl-3.1/resolve/main/animagine-xl-3.1.safetensors"
    lora_url = "https://huggingface.co/furusu/SD-LoRA/resolve/main/lcm-animaginexl-3_1.safetensors"
    download_file(ckpt_url, ckpt)
    download_file(lora_url, LoRA)
    # download_file only prints its errors; do not start generation without the weights.
    missing = [path for path in (ckpt, LoRA) if not os.path.exists(path)]
    if missing:
        print(f"必要なファイルが見つかりません: {', '.join(missing)}。実行を停止します。")
        sys.exit(1)
    os.makedirs(outdir, exist_ok=True)
    repo_name = input("Hugging Faceのリポジトリ名を入力してください: ").strip()
    # getpass keeps the API token from being echoed to the terminal.
    token = getpass.getpass("Hugging FaceのAPIトークンを入力してください: ").strip()
    if not check_repository_access(repo_name, token):
        print("アクセスが拒否されたか、無効なリポジトリです。実行を停止します。")
        # sys.exit, unlike the site-provided exit(), is always available.
        sys.exit(1)
    main(repo_name, token)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment