@unai-ndz
Last active August 21, 2023 15:22
Wrapper around civitai-minified.py to automate the hashing and downloading of info for your models
import os
import io
import json
import requests
import hashlib
import argparse
from pathlib import Path
import subprocess
import re
# Compute the hash of all the models in a folder
# Use the hash to get the civitai id
# Spawn minified-civitai with the ids of your models to download their info
# Generate a small markdown file for each model from the merged.json downloaded by minified-civitai
# Download the preview images of the models
# Both the markdown file and the images are saved to the folder where the model is located
# If you move or rename a model its hash will be computed again, since the cache is keyed by file path
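# Example invocation (the script filename below is a placeholder, use whatever name you saved this file as):
#   python civitai_helper.py -p /path/to/your/models
#   python civitai_helper.py -p /path/to/your/models -c   # rebuild the hash cache from scratch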
parser = argparse.ArgumentParser()
parser.add_argument('-p', '--path', type=str, required=True, help='The path where your models are stored')
parser.add_argument('-c', '--regenerate-cache', action='store_true', help='Recreate the cache from scratch')
parser.add_argument('-d', '--force-redownload', action='store_true', help='Force minified-civitai to update all the models instead of using its cache')
args = parser.parse_args()
model_dir_path = Path(args.path)
CACHE_FILE = 'cache.json' # The name of the JSON cache file
def read_chunks(file, size=io.DEFAULT_BUFFER_SIZE):
    """Yield pieces of data from a file-like object until EOF."""
    while True:
        chunk = file.read(size)
        if not chunk:
            break
        yield chunk
def gen_file_sha256(filename):
    blocksize = 1 << 20
    h = hashlib.sha256()
    with open(filename, 'rb') as f:
        for block in read_chunks(f, size=blocksize):
            h.update(block)
    hash_value = h.hexdigest()
    # print('sha256: ' + hash_value)
    return hash_value
def_headers = {'User-Agent': 'Mozilla/5.0 (iPad; CPU OS 12_2 like Mac OS X) AppleWebKit/605.1.15 (KHTML, like Gecko) Mobile/15E148'}
hash_url = 'https://civitai.com/api/v1/model-versions/by-hash/'
# curl https://civitai.com/api/v1/model-versions/by-hash/$HASH \
# -H "Content-Type: application/json" \
# -X GET
# use this sha256 to get model info from civitai
# return: model info dict
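# Note: the by-hash endpoint returns a model-version object; only its 'modelId'
# field is used further down (in scan_models).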
def get_model_info_by_hash(hash: str):
    print('Request model info from civitai')
    if not hash:
        print('hash is empty')
        return
    r = requests.get(hash_url + hash, headers=def_headers)
    if not r.ok:
        if r.status_code == 404:
            # this is not a civitai model
            print('Civitai does not have this model')
            return False
        else:
            print('Get error code: ' + str(r.status_code))
            print(r.text)
            return
    # try to get content
    content = None
    try:
        content = r.json()
    except Exception as e:
        print('Parse response json failed')
        print(str(e))
        print('response:')
        print(r.text)
        return
    if not content:
        print('error, content from civitai is None')
        return
    # print(content)
    return content
CACHE = {} # A dictionary to hold the cached file data
if not args.regenerate_cache:
    # Load the cache file
    if os.path.exists(CACHE_FILE):
        with open(CACHE_FILE, 'r') as f:
            CACHE = json.load(f)
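# Shape of cache.json, inferred from how scan_models() fills it in below:
#   {"/path/to/model.safetensors": {"hash": "<sha256>", "source": "civitai", "id": 12345}}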
exts = ('.bin', '.pt', '.safetensors', '.ckpt')
vae_suffix = '.vae'
# scan model to generate SHA256, then use this SHA256 to get model info from civitai
def scan_models():
    print('Scan models')
    model_ids = []
    model_count = 0
    new_model_count = 0
    skipped = 0
    for root, dirs, files in os.walk(model_dir_path, followlinks=True):
        for filename in files:
            # check ext
            item = os.path.join(root, filename)
            base, ext = os.path.splitext(item)
            if ext in exts:
                # ignore vae file
                if len(base) > 4:
                    if base[-4:] == vae_suffix:
                        # print('This is a vae file: ' + filename)
                        continue
                model_count = model_count + 1
                # If the file is not already in the cache, compute its hash and add it
                print(filename)
                file_path = os.path.join(root, filename)
                # Calculate hash
                if file_path not in CACHE:
                    hash = gen_file_sha256(file_path)
                    if not hash:
                        print('Failed generating SHA256 for model: ' + filename)
                        continue
                    CACHE[file_path] = {'hash': hash}
                    new_model_count = new_model_count + 1
                # Get it from cache
                hash = CACHE[file_path]['hash']
                if 'id' in CACHE[file_path]:
                    model_ids.append(str(CACHE[file_path]['id']))
                    print(CACHE[file_path]['id'])
                else:
                    if 'source' in CACHE[file_path] and CACHE[file_path]['source'] != 'civitai':
                        continue
                    if skipped < 4:
                        model_info = get_model_info_by_hash(hash)
                    else:
                        print('Civitai failed too many times, still caching hashes but no longer requesting civitai info')
                        skipped = skipped + 1
                        continue
                    # delay 1 second for ti
                    # if model_type == 'ti':
                    #     print('Delay 1 second for TI')
                    #     time.sleep(1)
                    if model_info is None:
                        print(f'{filename}: Connect to Civitai API service failed. Wait a while and try again')
                        skipped = skipped + 1
                        continue
                    elif model_info:
                        CACHE[file_path]['source'] = 'civitai'
                        CACHE[file_path]['id'] = model_info['modelId']
                        model_ids.append(str(CACHE[file_path]['id']))
                    else:
                        CACHE[file_path]['source'] = 'NA'
    # Update the cache file
    with open(CACHE_FILE, 'w') as f:
        json.dump(CACHE, f)
    print(f'Scanned {model_count} total models, {new_model_count} new models')
    if skipped > 0:
        print(f'Skipped models because of too many civitai errors: {skipped}')
    return model_ids
def sanitize_filename(filename):
    # Remove single quotes
    filename = re.sub(r"'", "", filename)
    # Replace any non-alphanumeric characters with spaces, trim, and title-case the words
    filename = re.sub(r'[^a-zA-Z0-9]', ' ', filename).strip().title()
    # Remove the spaces, leaving a CamelCase name
    filename = re.sub(r' +', '', filename)
    # If the resulting file name is empty, change it to 'default'
    if not filename:
        filename = 'default'
    # Truncate the file name if it is longer than 60 characters
    if len(filename) > 60:
        filename = filename[:60]
    return filename
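# e.g. sanitize_filename("Tom's LoRA v1.5!") returns 'TomsLoraV15' (hypothetical name, just to show the transformation)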
def get_folder(type):
    if type.lower() == 'textualinversion':
        folder = '00.final/embeddings'
    elif type.lower() == 'hypernetwork':
        folder = '00.final/hypernetwork'
    elif type.lower() == 'checkpoint':
        folder = '00.final/models'
    elif type.lower() == 'lora':
        folder = '00.final/lora'
    elif type.lower() == 'locon':
        folder = '00.final/locon'
    else:
        folder = '00.final/unknownCivitai'
    return folder
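# e.g. get_folder('LORA') -> '00.final/lora'; unrecognised types fall back to '00.final/unknownCivitai'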
def get_path_for_model(id: int):
    return next(filter(lambda x: CACHE[x].get('id') == id, CACHE.keys()), None)
def generate_markdown():
    # Load the JSON file
    with open('merged.json', 'r') as f:
        models = json.load(f)
    model_path_by_id = {}
    for path, model in CACHE.items():
        if 'id' in model:
            model_path_by_id[model['id']] = path
    # Loop through the models and write a markdown file for each one
    for _, model_json in models.items():
        # Get the name and description values
        data = model_json['pageData']['props']['pageProps']['trpcState']['json']['queries'][0]['state']['data']
        id = str(data['id'])
        name = sanitize_filename(data['name'])
        description = data['description']
        type = data['type']
        url = 'https://civitai.com/models/' + id
        model_path = Path(model_path_by_id[int(id)])
        # Get the directory where the model is
        dir = Path(model_path).parent
        basename = Path(model_path).stem
        # For the purpose of SD each version of a model should get its own markdown file and images
        # Each one can have different trigger words, images, etc. (The only things shared are the name, url and description?)
        # version_headers = ['Version', 'trainedWords', 'baseModel', 'epochs', 'Description']
        versions = []
        for i, v in enumerate(data['modelVersions']):
            versions.append({
                'version': v['name'] or '',
                'description': v['description'] or '',
                'trainedWords': v['trainedWords'] or '',
                'baseModel': v['baseModel'] or '',
                'epochs': v['epochs'] or '',
                'images': v['images'],
            })
        first_version = versions[0]  # The most recent release
        for i, img_data in enumerate(first_version['images']):
            if i == 0:
                i = 'preview'
            img_file = os.path.join(f'{dir}', f'{basename}.{str(i)}.png')
            if not os.path.exists(img_file):
                with open(img_file, 'wb') as f:
                    img_url = f'https://imagecache.civitai.com/xG1nkqKTMzGDvpLrqFT7WA/{img_data["url"]}/width=400'
                    f.write(requests.get(img_url).content)
        # Write the markdown file
        with open(f'{dir}/{basename}.md', 'w') as f:
            # Write title with extra markdown to convert it into a link
            f.write(f'# [{name}][1]\n\n')
            if first_version['trainedWords']:
                print(first_version['trainedWords'])
                trigger_words = ', '.join(map(str, first_version['trainedWords']))
                f.write(f'Trigger Words: {trigger_words}\n\n')
            # Write description
            f.write(description + '\n\n')
            f.write('Type: ' + type + '\n\n')
            # Add links
            f.write(f'[1]: <{url}/> "Go to the model\'s page"\n\n')
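# The generated <basename>.md ends up looking roughly like:
#   # [ModelName][1]
#   Trigger Words: word1, word2
#   <description html from civitai>
#   Type: LORA
#   [1]: <https://civitai.com/models/<id>/> "Go to the model's page"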
# Scan directory for models and download info
model_ids = scan_models()
if args.force_redownload:
    cmd = ['python', 'get.py', '-p', '--no-base64-images', '--json']
else:
    cmd = ['python', 'get.py', '-o', '-p', '--no-base64-images', '--json']
subprocess.run(cmd + model_ids)
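# get.py (minified-civitai) is expected to write merged.json to the working directory;
# generate_markdown() below reads that file (per the comments at the top of this script).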
generate_markdown()