Skip to content

Instantly share code, notes, and snippets.

@5shekel
Last active September 3, 2023 23:30
Show Gist options
  • Save 5shekel/664db89d050da2186231d7d7590b814c to your computer and use it in GitHub Desktop.
Save 5shekel/664db89d050da2186231d7d7590b814c to your computer and use it in GitHub Desktop.
demo controlnet dwpose
import streamlit as st
import json
import base64
import requests
from PIL import Image
import os, io
# Directory (relative to the current working directory) where the input image,
# the pose render and the raw API response are written.
output_folder = "../output"
# exist_ok=True replaces the os.path.exists() + makedirs() pair, which could
# race (and raise FileExistsError) if the folder appears between the two calls.
os.makedirs(output_folder, exist_ok=True)
# Base URL of the WebUI serving the Mikubill/sd-webui-controlnet API.
# Defaults to the original local address; set CONTROLNET_API_URL to point the
# demo at another host without editing the script.
url = os.environ.get("CONTROLNET_API_URL", "http://127.0.0.1:8787")

st.title("DWpose ControlNet Demo")
st.write("Uploaded images larger than 512x512 are downsized to fit within 512x512 (aspect ratio preserved).")

# File upload (restricted to common raster formats)
uploaded_file = st.file_uploader("Choose an image...", type=["jpg", "png", "jpeg"])

# Two columns: uploaded image on the left, processed pose image on the right
col1, col2 = st.columns(2)
if uploaded_file:
    image = Image.open(uploaded_file)

    # Downscale large uploads so neither side exceeds max_size; thumbnail()
    # resizes in place and preserves the aspect ratio.
    max_size = 512
    if image.size[0] > max_size or image.size[1] > max_size:
        image.thumbnail((max_size, max_size))

    # The upload name is supplied by the client; keep only its final path
    # component so nothing can be written outside output_folder. We are inside
    # `if uploaded_file`, so the name is always available here.
    filename = os.path.basename(uploaded_file.name)
    # Stem/extension shared by every artifact derived from this upload
    # (<name>_pose<ext> and <name>.json). Computed up front so both saves can
    # use it regardless of which parts of the API response are present.
    name, ext = os.path.splitext(filename)

    # Save the (possibly resized) input image
    input_filename = os.path.join(output_folder, filename)
    image.save(input_filename)

    # Display input image
    with col1:
        st.image(image, caption='Uploaded Image.', use_column_width=True)

    # Encode the image as base64 PNG for the JSON API payload
    img_byte_arr = io.BytesIO()
    image.save(img_byte_arr, format='PNG')
    encoded_image = base64.b64encode(img_byte_arr.getvalue()).decode('utf-8')

    # Dropdown for choosing the ControlNet preprocessor module
    selected_module = st.selectbox(
        'Choose ControlNet Module:',
        ('dw_openpose_full', 'openpose_full')
    )

    payload = {
        "controlnet_module": selected_module,
        "controlnet_input_images": [encoded_image],
        "controlnet_processor_res": 512,
        "controlnet_threshold_a": 64,
        "controlnet_threshold_b": 64,
    }

    # Seconds to wait for the detect endpoint. Kept generous because server-side
    # processing may be slow; without any timeout, requests can block forever.
    request_timeout = 300

    # Trigger API
    if st.button('Run API Call to Mikubill/sd-webui-controlnet extension'):
        with st.spinner('Running API call...'):
            try:
                response = requests.post(
                    f'{url}/controlnet/detect', json=payload, timeout=request_timeout
                )
                response.raise_for_status()
                response_json = response.json()

                # Decode, display and save the first returned image, if any
                images = response_json.get('images')
                if images:
                    decoded_image = base64.b64decode(images[0])
                    processed_image = Image.open(io.BytesIO(decoded_image))
                    with col2:
                        st.image(processed_image, caption='Processed Image.', use_column_width=True)

                    # JPEG cannot store alpha/palette modes; convert so saving
                    # under the upload's .jpg/.jpeg extension does not fail.
                    if ext.lower() in ('.jpg', '.jpeg') and processed_image.mode not in ('RGB', 'L'):
                        processed_image = processed_image.convert('RGB')
                    pose_filename = os.path.join(output_folder, f"{name}_pose{ext}")
                    processed_image.save(pose_filename)

                # Display keypoints if available
                if 'keypoints' in response_json:
                    st.json(response_json['keypoints'])

                # Save the full API response. `name` was bound before the API
                # call, so this also works when no image was returned.
                json_filename = os.path.join(output_folder, f"{name}.json")
                with open(json_filename, 'w', encoding='utf-8') as f:
                    json.dump(response_json, f)

                # Display the API response as JSON
                st.subheader('API Response:')
                st.json(response_json)
            except requests.RequestException as e:
                # Must precede the OSError clause: RequestException subclasses IOError.
                st.error(f"API call failed: {e}")
            except (OSError, ValueError) as e:
                # Invalid JSON or base64 (binascii.Error is a ValueError), image
                # data Pillow cannot identify, or a failed file write.
                st.error(f"Failed to process API response: {e}")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment