Created December 17, 2024 00:16
-
-
Save thisismattmiller/39e6ca24c18cf8c5bcc90033afbb35b1 to your computer and use it in GitHub Desktop.
Using LLaVA API via Ollama
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Tag woodblock-print images with LLaVA through a local Ollama server.

For every PNG matching IMAGE_GLOB, the image is flattened onto an opaque
white background, sent to Ollama's /api/generate endpoint with the "llava"
model, and the returned tags are stored in RESULTS_PATH keyed by file id.
Progress is checkpointed so an interrupted run resumes where it stopped.
"""
import base64
import glob
import io
import json
import os

import requests
from PIL import Image

IMAGE_GLOB = "/mnt/f/woodblock/*.png"
RESULTS_PATH = "../llava_tags.json"
OLLAMA_URL = "http://localhost:11434/api/generate"
# Save progress each time the number of tagged images reaches a multiple of this.
CHECKPOINT_EVERY = 25
# Seconds to wait for a reply. Generous because a request can be slow, but
# bounded so a hung server cannot stall the whole run forever.
REQUEST_TIMEOUT = 600
PROMPT = (
    "Generate three good search tags that describe this woodblock print. "
    "Do not include tags that describe the medium or style of the drawing "
    "such as woodblock, classical art, history painting, etc. "
    "Return the tags as a JSON array with the key named tags."
)


def _load_results(path: str) -> dict:
    """Return the tags saved by a previous run, or {} if none exist yet.

    A file that exists but cannot be parsed raises json.JSONDecodeError
    instead of being silently replaced, so earlier work is never discarded.
    """
    try:
        with open(path, encoding="utf-8") as fh:
            return json.load(fh)
    except FileNotFoundError:
        return {}


def _save_results(results: dict, path: str) -> None:
    """Write *results* to *path* through a temp file plus rename.

    os.replace swaps the file in one step, so a crash mid-write cannot leave
    a truncated checkpoint behind.
    """
    tmp_path = path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as fh:
        json.dump(results, fh)
    os.replace(tmp_path, path)


def _file_id(img_path: str) -> str:
    """Return the file name up to its first '.'; these names are UUIDs."""
    return os.path.basename(img_path).split(".")[0]


def _encode_on_white(img_path: str) -> str:
    """Flatten the image onto white and return it as base64-encoded PNG text.

    Converting to RGBA first guarantees an alpha band exists, so RGB,
    greyscale and palette PNGs are handled as well as RGBA ones.

    Raises:
        OSError: if the file cannot be opened or decoded as an image.
    """
    with Image.open(img_path) as image:
        rgba = image.convert("RGBA")
    background = Image.new("RGB", rgba.size, (255, 255, 255))
    background.paste(rgba, mask=rgba.getchannel("A"))
    # Encode in memory rather than round-tripping through a temp file on disk.
    buffer = io.BytesIO()
    background.save(buffer, "PNG")
    return base64.b64encode(buffer.getvalue()).decode("ascii")


def _request_tags(encoded_image: str) -> requests.Response:
    """POST one base64 image to Ollama and return the raw HTTP response.

    Raises:
        requests.RequestException: on connection failure or timeout.
    """
    payload = {
        "model": "llava",
        "format": "json",
        "prompt": PROMPT,
        "options": {
            "temperature": 0
        },
        "images": [encoded_image],
        "stream": False,
    }
    # json= serialises the payload and labels the body application/json.
    return requests.post(OLLAMA_URL, json=payload, timeout=REQUEST_TIMEOUT)


def main() -> None:
    """Tag every image not already in RESULTS_PATH, checkpointing as it goes."""
    files = glob.glob(IMAGE_GLOB)
    # In case the script crashes we can load the last output and pick up
    # from there.
    results = _load_results(RESULTS_PATH)
    try:
        for img in files:
            fileid = _file_id(img)
            if fileid in results:
                print(fileid, 'alredy did this one skipp')
                continue

            try:
                encoded = _encode_on_white(img)
            except OSError as err:
                # Unreadable or corrupt image: skip it instead of aborting.
                print("Could not read image", img, err)
                continue
            print(fileid, len(results))

            try:
                response = _request_tags(encoded)
            except requests.RequestException:
                print("Connection ERROR on this one", img)
                continue

            try:
                # The model's own output is a JSON string under "response".
                parsed = json.loads(response.json()["response"])
                tags = parsed["tags"]
            except (ValueError, KeyError, TypeError) as err:
                # ValueError covers invalid JSON at either decoding level;
                # KeyError/TypeError cover replies missing the expected keys.
                print("Bad line", img, err)
                continue

            results[fileid] = tags
            print(parsed)
            # Only checkpoint when a new result was actually added.
            if len(results) % CHECKPOINT_EVERY == 0:
                _save_results(results, RESULTS_PATH)
    finally:
        # Persist everything gathered so far, even if the loop was interrupted.
        _save_results(results, RESULTS_PATH)


if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment.