"""Discord Midjourney Image Automation.

Originally published as GitHub Gist feliche93/0c928a9ca2ee8bc9b907173a007b3868
by @feliche93 (created July 3, 2023).
"""
import asyncio
import os
from getpass import getpass
from pathlib import Path
from typing import Dict, List, Optional
import boto3
import requests
from dotenv import load_dotenv
from playwright.async_api import Page, async_playwright
from sqlalchemy import create_engine, text
from sqlalchemy.engine.base import Engine
import time
from sqlalchemy.exc import OperationalError
from sqlalchemy.exc import OperationalError
# Populate os.environ from a local .env file before the module-level settings
# below are read. override=True makes values from .env take precedence over
# variables that are already set in the process environment.
load_dotenv(override=True)
def download_image(image_url: str, image_path: str, timeout: int = 5) -> str:
    """
    Download an image from a URL and save it to a local path.

    The parent directory of ``image_path`` is created when it does not exist yet,
    so callers may point at a scratch folder that has not been created.

    Args:
        image_url (str): URL of the image to download.
        image_path (str): Local destination (a str or path-like object), including the file name.
        timeout (int): Maximum time, in seconds, to wait for the server's response. Default is 5 seconds.

    Raises:
        requests.HTTPError: If the server answered with an unsuccessful status code.
        requests.Timeout: If the request times out.

    Returns:
        str: The ``image_path`` argument unchanged, i.e. where the image was written.
    """
    response = requests.get(image_url, timeout=timeout)
    response.raise_for_status()  # Fail before touching the filesystem on a bad response.

    # Without this, open() raises FileNotFoundError the first time the target
    # folder (e.g. a fresh "temp_images" directory) is used.
    Path(image_path).parent.mkdir(parents=True, exist_ok=True)
    with open(image_path, "wb") as f:
        f.write(response.content)
    return image_path
def upload_to_s3(
    image_path: str,
    bucket: str,
    s3_image_name: str,
    aws_access_key_id: str,
    aws_secret_access_key: str,
    region_name: str,
) -> str:
    """
    Upload an image file to S3 under ``blog_post_covers/`` and return its public-style URL.

    The local file is deleted only after the upload succeeded; if the upload
    raises, the file is left in place.

    Args:
        image_path (str): Path to the image file to upload.
        bucket (str): Name of the S3 bucket to upload to.
        s3_image_name (str): Object name inside the ``blog_post_covers/`` prefix.
        aws_access_key_id (str): AWS access key ID.
        aws_secret_access_key (str): AWS secret access key.
        region_name (str): AWS region of the bucket; used both for the client and the returned URL.

    Returns:
        str: ``https://{bucket}.s3.{region_name}.amazonaws.com/blog_post_covers/{s3_image_name}``.

    Raises:
        botocore.exceptions.ClientError: If S3 rejects the upload.
    """
    s3 = boto3.client(
        "s3",
        aws_access_key_id=aws_access_key_id,
        aws_secret_access_key=aws_secret_access_key,
        # Sign requests for the bucket's region; the URL returned below assumes
        # this region too, so client and URL must agree.
        region_name=region_name,
    )
    s3_path = "blog_post_covers/" + s3_image_name  # prepend the S3 'folder' name

    with open(image_path, "rb") as f:
        s3.upload_fileobj(f, bucket, s3_path)

    # Remove the local copy only after the handle is closed (Windows cannot
    # delete open files) and only once the upload above did not raise.
    os.remove(image_path)

    return f"https://{bucket}.s3.{region_name}.amazonaws.com/{s3_path}"
async def login_to_discord(
    page: Page,
    server_id: str,
    channel_id: str,
    email: Optional[str] = None,
    password: Optional[str] = None,
    auth_code: Optional[str] = None,
) -> None:
    """
    Sign in to Discord in the browser and land on the given server channel.

    Any credential not supplied as an argument is requested interactively on
    the terminal (the password via ``getpass`` so it is not echoed).

    Args:
        page (Page): Playwright browser page instance.
        server_id (str): Discord server ID whose channel URL is opened.
        channel_id (str): Discord channel ID whose channel URL is opened.
        email (Optional[str]): Login email; prompted for when falsy.
        password (Optional[str]): Login password; prompted for when falsy.
        auth_code (Optional[str]): Two-factor / backup code; prompted for when falsy.

    Raises:
        TimeoutError: If any page action does not complete within Playwright's default timeout.
    """
    await page.goto(f"https://discord.com/channels/{server_id}/{channel_id}")
    await page.get_by_role("button", name="Continue in browser").click()

    # Step 1: email + password form.
    email_field = page.get_by_label("Email or Phone Number*")
    await email_field.click()
    email = email or input("Please enter your email: ")
    await email_field.fill(email)
    await email_field.press("Tab")

    password = password or getpass("Please enter your password: ")
    await page.get_by_label("Password*").fill(password)

    log_in_button = page.get_by_role("button", name="Log In")
    await log_in_button.click()

    # Step 2: two-factor authentication form.
    auth_code = auth_code or input("Please enter your authentication code: ")
    await page.get_by_placeholder(
        "6-digit authentication code/8-digit backup code"
    ).fill(auth_code)
    await log_in_button.click()
async def post_prompt(page: Page, prompt: str, channel_name: str = "general") -> None:
    """
    Post an ``/imagine`` command with the given prompt in a Discord channel.

    Args:
        page (Page): Playwright browser page instance, already showing the channel.
        prompt (str): Text placed into the command's prompt option.
        channel_name (str): Channel name used to find the message box
            (its accessible name is ``"Message #<channel_name>"``). Default is "general".

    Raises:
        ValueError: If ``prompt`` is empty or None.
        TimeoutError: If any page action does not complete within its timeout.
    """
    if not prompt:
        raise ValueError("prompt must be a non-empty string")

    message_box = page.get_by_role("textbox", name=f"Message #{channel_name}").nth(0)
    await message_box.fill("/imagine ")

    # NOTE(review): this is a hashed, build-generated Discord class name and will
    # break whenever Discord redeploys its CSS; consider a role/label locator.
    prompt_input = page.locator(".optionPillValue-2uxsMp").nth(0)
    await prompt_input.fill(prompt, timeout=2000)
    await message_box.press("Enter", timeout=2000)
async def upscale_image(
    page: Page, check_interval: float = 5, max_wait: Optional[float] = 600
) -> None:
    """
    Click the U1 (upscale) button on the last message of the channel once it appears.

    Args:
        page (Page): Playwright browser page instance.
        check_interval (float): Seconds to sleep between visibility checks. Default is 5.
        max_wait (Optional[float]): Seconds to wait for the button before giving up.
            Default is 600. Pass None to wait indefinitely.

    Raises:
        TimeoutError: If the button does not appear within ``max_wait`` seconds,
            or the click does not complete within its timeout.
    """
    last_message = page.locator(selector="li").last
    upscale_1 = last_message.locator("button", has_text="U1")

    # Without a deadline a failed or rejected generation would block the script forever.
    deadline = None if max_wait is None else time.monotonic() + max_wait
    while not await upscale_1.is_visible():
        if deadline is not None and time.monotonic() >= deadline:
            raise TimeoutError(
                f"Waited {max_wait} seconds but the 'U1' upscale button did not appear."
            )
        print("Upscale button is not yet available, waiting...")
        await asyncio.sleep(check_interval)

    print("Upscale button is now available, clicking...")
    await upscale_1.click(timeout=1000)
async def get_image_url(
    page: Page, timeout: int = 1000, check_interval: int = 5, max_wait: int = 30
) -> str:
    """
    Return the href of the last original-image link in the channel's last message.

    Polls until both the 'Vary (Strong)' button and at least one image link with a
    non-empty href are present in the last message.

    Args:
        page (Page): Playwright browser page instance.
        timeout (int): Milliseconds allowed for reading the link's href. Default is 1000.
        check_interval (int): Seconds to wait between checks. Default is 5.
        max_wait (int): Seconds to keep polling before giving up. Default is 30.

    Returns:
        str: The href attribute of the last image link.

    Raises:
        TimeoutError: If the button and a usable link do not appear within ``max_wait`` seconds.
    """
    last_message = page.locator(selector="li").last
    vary_strong = last_message.locator("button", has_text="Vary (Strong)")
    image_links = last_message.locator("xpath=//a[starts-with(@class, 'originalLink-')]")

    # Monotonic clock: immune to wall-clock adjustments during the wait.
    start_time = time.monotonic()
    while True:
        if await vary_strong.is_visible() and await image_links.count() > 0:
            href = await image_links.last.get_attribute("href", timeout=timeout)
            # get_attribute returns None when the attribute is missing; keep
            # polling instead of returning a non-str.
            if href:
                print("Image link is present, returning it.")
                return href

        print("Waiting for 'Vary (Strong)' button to appear and for image link to appear...")
        if time.monotonic() - start_time > max_wait:
            raise TimeoutError(
                f"Waited for {max_wait} seconds but 'Vary (Strong)' button did not appear "
                "and image link did not appear."
            )
        await asyncio.sleep(check_interval)
def update_db_record(
    engine: Engine, s3_path: str, keyword_value: str, max_retries: int = 5, retry_wait: int = 2
) -> None:
    """
    Set ``keywords.blog_post_cover_image_url`` for the row with the given slug.

    The UPDATE runs in its own transaction and is retried on OperationalError.

    Args:
        engine (Engine): SQLAlchemy Engine instance.
        s3_path (str): URL stored in the blog_post_cover_image_url column.
        keyword_value (str): Slug identifying the row to update.
        max_retries (int): Total number of attempts before giving up. Default is 5.
        retry_wait (int): Seconds to wait between attempts. Default is 2.

    Raises:
        ValueError: If ``max_retries`` is smaller than 1.
        OperationalError: The last error, re-raised once every attempt has failed.
        SQLAlchemyError: Any other SQLAlchemy error propagates immediately without retrying.
    """
    if max_retries < 1:
        raise ValueError(f"max_retries must be at least 1, got {max_retries}")

    query = text(
        "UPDATE keywords SET blog_post_cover_image_url = :s3_path WHERE slug = :keyword_value"
    )
    # Bind parameters as a dict: keyword-argument binding was removed in SQLAlchemy 2.0.
    params = {"s3_path": s3_path, "keyword_value": keyword_value}

    for attempt in range(1, max_retries + 1):
        try:
            # engine.begin() commits on success and rolls back on error; a plain
            # engine.connect() block does not commit under SQLAlchemy 2.0 semantics.
            with engine.begin() as connection:
                connection.execute(query, params)
            return
        except OperationalError:
            print(f"OperationalError occurred. Retry {attempt} of {max_retries}.")
            if attempt == max_retries:
                # Re-raise inside the handler so the real OperationalError surfaces
                # (a bare raise outside it would only produce a RuntimeError).
                raise
            time.sleep(retry_wait)
def get_records_with_null_cover_image(engine: Engine) -> List[Dict[str, str]]:
    """
    Fetch every keyword row that does not have a cover image URL yet.

    Args:
        engine (Engine): SQLAlchemy Engine instance.

    Returns:
        List[Dict[str, str]]: One dict per row, with keys 'slug' and
        'blog_post_cover_prompt', in the order the database returns them.

    Raises:
        SQLAlchemyError: If the query fails.
    """
    select_missing_covers = text(
        "SELECT slug, blog_post_cover_prompt FROM keywords WHERE blog_post_cover_image_url IS NULL"
    )
    with engine.connect() as conn:
        rows = conn.execute(select_missing_covers)
        # Materialise the rows before the connection is released.
        return [
            {"slug": slug, "blog_post_cover_prompt": cover_prompt}
            for slug, cover_prompt in rows
        ]
# --- Settings ---------------------------------------------------------------
# S3 and database settings are read from the environment; each is None when unset.
S3_BUCKET_NAME = os.environ.get("S3_BUCKET_NAME")
S3_ACCESS_KEY_ID = os.environ.get("S3_ACCESS_KEY_ID")
S3_SECRET_ACCESS_KEY = os.environ.get("S3_SECRET_ACCESS_KEY")
S3_REGION_NAME = os.environ.get("S3_REGION_NAME")
DATABASE_URL = os.environ.get("DATABASE_URL")

# Discord server/channel the script opens and posts prompts to. The literal IDs
# are defaults; set DISCORD_SERVER_ID / DISCORD_CHANNEL_ID to target another channel.
DISCORD_SERVER_ID = os.environ.get("DISCORD_SERVER_ID", "1124815914815201481")
DISCORD_CHANNEL_ID = os.environ.get("DISCORD_CHANNEL_ID", "1124815915297542217")
# Original (misspelled) name kept so existing references keep working.
DISCORD_CHANEL_ID = DISCORD_CHANNEL_ID

# Scratch folder next to this file for images between download and S3 upload.
IMAGE_PATH = Path(__file__).parent / "temp_images"
async def _process_record(page: Page, engine: Engine, slug: str, prompt: str) -> str:
    """Run one prompt through Discord, store the upscaled image in S3 and the DB; return the S3 URL."""
    await post_prompt(page=page, prompt=prompt)
    await upscale_image(page=page)
    image_url = await get_image_url(page=page)

    local_image_path = IMAGE_PATH / f"{slug}.png"
    image_path = download_image(image_url=image_url, image_path=local_image_path)
    s3_url = upload_to_s3(
        image_path=image_path,
        aws_access_key_id=S3_ACCESS_KEY_ID,
        aws_secret_access_key=S3_SECRET_ACCESS_KEY,
        bucket=S3_BUCKET_NAME,
        region_name=S3_REGION_NAME,
        s3_image_name=f"{slug}.png",
    )
    update_db_record(
        engine=engine,
        s3_path=s3_url,
        keyword_value=slug,
    )
    return s3_url


async def main(start_index: int = 0) -> None:
    """
    Generate and store a cover image for every keyword row that has none yet.

    For each pending row the prompt is posted in Discord, the U1 upscale is
    requested, and the resulting image is downloaded, uploaded to S3, and its URL
    is written back to the database. A failure on one row is reported and the
    batch continues; that row keeps a NULL cover URL and is picked up next run.

    Args:
        start_index (int): Number of pending rows to skip. Default is 0. The query
            already excludes rows that have a cover URL, so resuming a run does not
            require an offset.
    """
    IMAGE_PATH.mkdir(parents=True, exist_ok=True)
    engine = create_engine(DATABASE_URL)
    try:
        # Query before launching the browser: fail fast on DB problems and skip
        # the interactive login entirely when there is nothing to do.
        records = get_records_with_null_cover_image(engine)[start_index:]
        if not records:
            print("No records without a cover image; nothing to do.")
            return

        async with async_playwright() as playwright:
            browser = await playwright.chromium.launch(headless=False)
            try:
                context = await browser.new_context()
                page = await context.new_page()
                await login_to_discord(
                    page=page,
                    server_id=DISCORD_SERVER_ID,
                    channel_id=DISCORD_CHANEL_ID,
                )

                for record in records:
                    slug = record["slug"]
                    prompt = record["blog_post_cover_prompt"]
                    if not prompt:
                        print(f"Skipping {slug}: no blog_post_cover_prompt.")
                        continue
                    # Batch boundary: report and move on so one bad prompt does
                    # not abort the whole run.
                    try:
                        s3_url = await _process_record(page, engine, slug, prompt)
                    except Exception as exc:
                        print(f"Failed to generate cover for {slug}: {exc!r}")
                    else:
                        print(f"Stored cover for {slug}: {s3_url}")

                await context.close()
            finally:
                # Runs on error too; closing the browser also closes its contexts and pages.
                await browser.close()
    finally:
        engine.dispose()
# Run the automation only when executed as a script, not when imported.
if __name__ == "__main__":
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment