Skip to content

Instantly share code, notes, and snippets.

@CodingFu
Created September 28, 2023 12:54
Show Gist options
  • Save CodingFu/e38837301371d14a1712c9df35f404da to your computer and use it in GitHub Desktop.
Save CodingFu/e38837301371d14a1712c9df35f404da to your computer and use it in GitHub Desktop.
from google.cloud import aiplatform
import os
import openai
import httpx
import tempfile
import shutil
import vertexai
import logging
from fastapi import FastAPI, Request, status
from fastapi.exceptions import RequestValidationError
from pydantic import BaseModel
from fastapi import FastAPI, File, HTTPException, UploadFile, Form
from fastapi.responses import JSONResponse, StreamingResponse
from fastapi.middleware.cors import CORSMiddleware
from vertexai.preview.vision_models import Image, ImageGenerationModel
# Grab values from the environment if available
from dotenv import load_dotenv
load_dotenv()

# Set the OpenAI API key and information accordingly.
# Tolerate a missing OPENAI_API_KEY: calling .rstrip() on None would abort the
# whole module import, taking the Google endpoints down with it.
_openai_api_key = os.getenv("OPENAI_API_KEY")
openai.api_key = _openai_api_key.strip() if _openai_api_key else None
# NOTE(review): api_model is not a setting the openai package reads, and it is
# not referenced elsewhere in this file — confirm it is still needed.
openai.api_model = os.getenv("OPENAI_API_MODEL", "gpt-35-turbo")

# Set Google parameters
# google_scope is not referenced elsewhere in this file; kept to document the
# GOOGLE_SCOPE setting.
google_scope = os.getenv("GOOGLE_SCOPE", "https://www.googleapis.com/auth/cloud-platform")
google_project_id = os.getenv("GOOGLE_PROJECT_ID")
google_region = os.getenv("GOOGLE_REGION", "us-central1")


class _LazyImageGenerationModel:
    """Stand-in for the Vertex AI image model that loads it on first use.

    The Google endpoints call methods on the module-level ``google_model``.
    Initialising Vertex AI and fetching the model eagerly would need
    credentials at import time. Instead, the first attribute lookup on this
    proxy runs ``vertexai.init`` and ``ImageGenerationModel.from_pretrained``,
    and later lookups are forwarded to the loaded model.

    Loading is not synchronised: two simultaneous first requests may each
    load the model once.
    """

    def __init__(self, model_name: str):
        self._model_name = model_name
        self._model = None

    def __getattr__(self, name: str):
        # Only invoked for attributes not found on the proxy itself, i.e. the
        # model's own methods such as generate_images / edit_image.
        if self._model is None:
            vertexai.init(project=google_project_id, location=google_region)
            self._model = ImageGenerationModel.from_pretrained(self._model_name)
        return getattr(self._model, name)


google_model = _LazyImageGenerationModel("imagegeneration@002")

# Set the default prompt
DEFAULT_PROMPT = "Colgate toothpaste next to a toothbrush"
DEFAULT_PROVIDERS = ["google", "openai"]
# Create the uploads folder if it doesn't exist. exist_ok=True makes this a
# single call, avoiding the race between an existence check and the creation.
os.makedirs("uploads", exist_ok=True)
class ImageRequest(BaseModel):
    """Parameters for generating an image from a text prompt.

    Used as the request body of ``POST /image/{provider}``. Plain field
    defaults are used rather than FastAPI's ``Form(...)`` markers: this is a
    pydantic model, so the markers only acted as field defaults and wrongly
    suggested form-encoded input.
    """

    # Text prompt sent to the provider.
    prompt: str = DEFAULT_PROMPT
    # Number of images to request from the provider.
    n: int = 1
    # Requested output size; passed to OpenAI only (the Google path ignores it).
    size: str = "1024x1024"


class ImageVariationRequest(ImageRequest):
    """``ImageRequest`` plus a required source image upload.

    NOTE(review): not referenced elsewhere in this file — confirm it is still
    needed.
    """

    # Required source image (no default, equivalent to the former File(...)).
    image_file: UploadFile
def process_google_image(response):
    """Return the first image of a Vertex AI generation response as a PNG stream.

    The image is written to a temporary file with its ``save`` method, then
    read back into memory. The temporary file is removed before the response
    is returned. Reading eagerly is required: cleanup in a ``finally`` around
    ``return StreamingResponse(...)`` runs before the body is streamed, so the
    file must not be needed afterwards.

    Args:
        response: Result of ``google_model.generate_images`` or
            ``google_model.edit_image``; only ``response.images[0]`` is used.

    Returns:
        StreamingResponse yielding the PNG bytes with media type ``image/png``.

    Raises:
        HTTPException: 500 if no image was returned, or it could not be saved
            or read back.
    """
    # TODO: only the first image is returned even when more were requested.
    images = response.images
    if not images:
        raise HTTPException(status_code=500, detail="No image was generated")

    # mkstemp + close leaves a closed, uniquely named file that save() can
    # overwrite by path on any OS (an open NamedTemporaryFile cannot be
    # reopened by name on Windows).
    fd, temp_path = tempfile.mkstemp(suffix=".png")
    os.close(fd)
    try:
        images[0].save(location=temp_path)
        with open(temp_path, "rb") as image_file:
            image_bytes = image_file.read()
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # temp_path is a regular file, so os.remove (not shutil.rmtree) is
        # what actually deletes it. Cleanup is best-effort.
        try:
            os.remove(temp_path)
        except OSError:
            pass

    return StreamingResponse(iter([image_bytes]), media_type="image/png")
# FastAPI based application: metadata shown in the generated OpenAPI docs.
_APP_METADATA = dict(
    title="Colgate GenAI Service",
    description="TBD",
    summary="TBD",
    version="0.0.1",
    terms_of_service="http://example.com/terms/",
    license_info={
        "name": "Apache 2.0",
        "url": "https://www.apache.org/licenses/LICENSE-2.0.html",
    },
)
app = FastAPI(**_APP_METADATA)

# Browser origins allowed to make credentialed cross-origin requests.
origins = [
    "http://localhost.tiangolo.com",
    "https://localhost.tiangolo.com",
    "http://localhost",
    "http://localhost:8080",
]

# Any method and any header are permitted from the origins above.
app.add_middleware(
    CORSMiddleware,
    allow_origins=origins,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
@app.get("/")
def home():
    """Answer ``GET /`` with a fixed greeting string."""
    message = "Hello, FastAPI!"
    return message
@app.post("/image/openai/edit")
def generative_image_openai_edit(prompt: str = Form(DEFAULT_PROMPT),
                                 n: int = Form(1),
                                 size: str = Form("1024x1024"),
                                 image_file: UploadFile = File(...),
                                 mask_file: UploadFile = File(...)):
    """Edit an uploaded image with OpenAI's image-edit API and stream the result.

    The image and mask uploads are saved under ``uploads/`` and sent to
    ``openai.Image.create_edit``. Only the first generated image is
    downloaded and returned, even when ``n > 1``.

    Args:
        prompt: Text description of the desired edit.
        n: Number of images OpenAI should generate.
        size: Output size string, e.g. ``"1024x1024"``.
        image_file: Source image to edit.
        mask_file: Mask image sent alongside the source image.

    Returns:
        StreamingResponse with the downloaded image bytes, using the
        download's Content-Type (``image/png`` if the header is missing).

    Raises:
        HTTPException: 502 if downloading the generated image fails.
    """
    # Keep only the final path component of the client-supplied names so a
    # filename such as "../../app.py" or "/etc/x" cannot escape uploads/.
    # NOTE(review): concurrent uploads with the same name overwrite each other.
    image_name = os.path.basename(image_file.filename or "") or "image.png"
    image_filename = os.path.join("uploads", image_name)
    with open(image_filename, "wb") as buffer:
        shutil.copyfileobj(image_file.file, buffer)

    mask_name = os.path.basename(mask_file.filename or "") or "mask.png"
    mask_filename = os.path.join("uploads", mask_name)
    with open(mask_filename, "wb") as buffer:
        shutil.copyfileobj(mask_file.file, buffer)

    # Close both handles once the API call returns instead of leaking them.
    with open(image_filename, "rb") as image, open(mask_filename, "rb") as mask:
        openai_response = openai.Image.create_edit(image=image, mask=mask,
                                                   prompt=prompt, n=n, size=size)
    # TODO: only the first of the n generated images is returned.
    image_url = openai_response['data'][0]['url']

    with httpx.Client() as client:
        download = client.get(image_url)
    if download.is_error:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to download generated image (HTTP {download.status_code})",
        )
    # client.get() has already read the whole body, so iter_bytes() serves it
    # from memory even though the client is now closed.
    return StreamingResponse(download.iter_bytes(),
                             media_type=download.headers.get("Content-Type", "image/png"))
@app.post("/image/openai/variation")
def generative_image_openai_variation(n: int = Form(1),
                                      size: str = Form("1024x1024"),
                                      image_file: UploadFile = File(...)):
    """Create variations of an uploaded image with OpenAI and stream one back.

    The upload is saved under ``uploads/`` and sent to
    ``openai.Image.create_variation``. Only the first generated image is
    downloaded and returned, even when ``n > 1``.

    Args:
        n: Number of variations OpenAI should generate.
        size: Output size string, e.g. ``"1024x1024"``.
        image_file: Source image to vary.

    Returns:
        StreamingResponse with the downloaded image bytes, using the
        download's Content-Type (``image/png`` if the header is missing).

    Raises:
        HTTPException: 502 if downloading the generated image fails.
    """
    # Keep only the final path component of the client-supplied name so it
    # cannot escape uploads/ via ".." or an absolute path.
    # NOTE(review): concurrent uploads with the same name overwrite each other.
    image_name = os.path.basename(image_file.filename or "") or "image.png"
    image_filename = os.path.join("uploads", image_name)
    with open(image_filename, "wb") as buffer:
        shutil.copyfileobj(image_file.file, buffer)

    # Close the handle once the API call returns instead of leaking it.
    with open(image_filename, "rb") as image:
        openai_response = openai.Image.create_variation(image=image, n=n, size=size)
    # TODO: only the first of the n generated images is returned.
    image_url = openai_response['data'][0]['url']

    with httpx.Client() as client:
        download = client.get(image_url)
    if download.is_error:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to download generated image (HTTP {download.status_code})",
        )
    # client.get() has already read the whole body, so iter_bytes() serves it
    # from memory even though the client is now closed.
    return StreamingResponse(download.iter_bytes(),
                             media_type=download.headers.get("Content-Type", "image/png"))
@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request: Request, exc: RequestValidationError):
    """Log request-validation failures and report them as a JSON 422 response.

    The validation error text is flattened to a single line so it fits on one
    log line and in the ``message`` field.

    Returns:
        JSONResponse with HTTP status 422 and body
        ``{"status_code": 10422, "message": <error text>, "data": None}``.
    """
    # str.split() with no argument splits on every whitespace run (newlines
    # included), so the join collapses the text to single-spaced one line.
    exc_str = " ".join(str(exc).split())
    logging.error("%s: %s", request, exc_str)
    content = {'status_code': 10422, 'message': exc_str, 'data': None}
    return JSONResponse(content=content, status_code=status.HTTP_422_UNPROCESSABLE_ENTITY)
@app.post("/image/google/edit")
def generative_image_google_edit(prompt: str = Form(DEFAULT_PROMPT),
                                 n: int = Form(1),
                                 image_file: UploadFile = File(...),
                                 mask_file: UploadFile = File(...)):
    """Edit an uploaded image within a mask using Vertex AI and stream the result.

    The image and mask uploads are saved under ``uploads/``, loaded with
    ``Image.load_from_file``, and passed to ``google_model.edit_image``.

    Args:
        prompt: Text description of the desired edit.
        n: Number of images to request.
        image_file: Source image to edit.
        mask_file: Mask image passed as ``mask``.

    Returns:
        The response built by ``process_google_image`` (first image only).
    """
    # Keep only the final path component of the client-supplied names so a
    # filename such as "../../app.py" or "/etc/x" cannot escape uploads/.
    # NOTE(review): concurrent uploads with the same name overwrite each other.
    image_name = os.path.basename(image_file.filename or "") or "image.png"
    image_filename = os.path.join("uploads", image_name)
    with open(image_filename, "wb") as buffer:
        shutil.copyfileobj(image_file.file, buffer)

    mask_name = os.path.basename(mask_file.filename or "") or "mask.png"
    mask_filename = os.path.join("uploads", mask_name)
    with open(mask_filename, "wb") as buffer:
        shutil.copyfileobj(mask_file.file, buffer)

    response = google_model.edit_image(
        base_image=Image.load_from_file(image_filename),
        mask=Image.load_from_file(mask_filename),
        prompt=prompt,
        number_of_images=n
    )
    return process_google_image(response)
@app.post("/image/google/variation")
def generative_image_google_variation(prompt: str = Form(DEFAULT_PROMPT),
                                      n: int = Form(1),
                                      image_file: UploadFile = File(...)):
    """Produce a prompt-guided variation of an uploaded image with Vertex AI.

    The upload is saved under ``uploads/`` and passed to
    ``google_model.edit_image`` without a mask.

    Args:
        prompt: Text guiding the variation.
        n: Number of images to request.
        image_file: Source image.

    Returns:
        The response built by ``process_google_image`` (first image only).
    """
    # Keep only the final path component of the client-supplied name so it
    # cannot escape uploads/ via ".." or an absolute path.
    # NOTE(review): concurrent uploads with the same name overwrite each other.
    image_name = os.path.basename(image_file.filename or "") or "image.png"
    image_filename = os.path.join("uploads", image_name)
    with open(image_filename, "wb") as buffer:
        shutil.copyfileobj(image_file.file, buffer)

    response = google_model.edit_image(
        base_image=Image.load_from_file(image_filename),
        prompt=prompt,
        number_of_images=n
    )
    return process_google_image(response)
###
# Generate an Image
###
def generative_image_openai(request: ImageRequest):
    """Generate an image from a text prompt with OpenAI and stream it back.

    Calls ``openai.Image.create`` and downloads the first generated image,
    even when ``request.n > 1``.

    Args:
        request: Prompt, image count and size to send to OpenAI.

    Returns:
        StreamingResponse with the downloaded image bytes, using the
        download's Content-Type (``image/png`` if the header is missing).

    Raises:
        HTTPException: 502 if downloading the generated image fails.
    """
    openai_response = openai.Image.create(prompt=request.prompt,
                                          n=request.n,
                                          size=request.size)
    # TODO: only the first of the n generated images is returned.
    image_url = openai_response['data'][0]['url']

    with httpx.Client() as client:
        download = client.get(image_url)
    if download.is_error:
        raise HTTPException(
            status_code=502,
            detail=f"Failed to download generated image (HTTP {download.status_code})",
        )
    # client.get() has already read the whole body, so iter_bytes() serves it
    # from memory even though the client is now closed.
    return StreamingResponse(download.iter_bytes(),
                             media_type=download.headers.get("Content-Type", "image/png"))
def generative_image_google(request: ImageRequest):
    """Generate images from a text prompt with Vertex AI.

    Args:
        request: Prompt and image count; ``request.size`` is not used here.

    Returns:
        Whatever ``process_google_image`` builds from the generation result.
    """
    generation_args = {
        "prompt": request.prompt,
        "number_of_images": request.n,
    }
    return process_google_image(google_model.generate_images(**generation_args))
@app.post("/image/{provider}")
def generative_image_generic(request: ImageRequest,
                             provider: str = "google"):
    """Generate an image from an ``ImageRequest`` using the named provider.

    Args:
        request: Prompt, image count and size, taken from the request body.
        provider: Path segment selecting the backend; must be one of
            ``DEFAULT_PROVIDERS``.

    Returns:
        The selected provider function's response, or a 400 JSONResponse
        listing the supported providers when ``provider`` is not supported.
    """
    # Explicit dispatch table rather than looking names up in globals(): only
    # these functions can ever be invoked from a URL segment, and renames are
    # visible to linters and IDEs.
    handlers = {
        "google": generative_image_google,
        "openai": generative_image_openai,
    }
    handler = handlers.get(provider) if provider in DEFAULT_PROVIDERS else None
    if handler is None:
        return JSONResponse(content={"message": f"Supported service providers are currently {', '.join(DEFAULT_PROVIDERS)}"}, status_code=400)
    return handler(request)
if __name__ == "__main__":
    import uvicorn

    # HOST / PORT may override the defaults (0.0.0.0:6000). Browsers refuse
    # connections to port 6000 (it is on the Fetch "bad port" list), so
    # browser-based CORS clients need PORT set to something else.
    uvicorn.run(app,
                host=os.getenv("HOST", "0.0.0.0"),
                port=int(os.getenv("PORT", "6000")))
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment