Created
January 29, 2025 19:21
-
-
Save Bornholm/3975122c2e9864b2a57108a605a902dd to your computer and use it in GitHub Desktop.
Génération de haïkus et illustration associée avec IA générative
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
package main | |
import ( | |
"bytes" | |
"context" | |
"encoding/base64" | |
"encoding/json" | |
"fmt" | |
"image" | |
"image/color" | |
"image/jpeg" | |
"log" | |
"net/http" | |
"os" | |
"regexp" | |
"strings" | |
"github.com/fogleman/gg" | |
"github.com/openai/openai-go" | |
"github.com/openai/openai-go/option" | |
"github.com/pkg/errors" | |
) | |
// haikuSystemPrompt instructs the LLM to answer using a fixed "HAIKU:"
// layout so the haiku can later be extracted with a regular expression
// (see extractHaiku).
const haikuSystemPrompt = `
You are an AI assistant specialized in writing beautiful haikus, privileging creativity and avoiding repetitions.
Use the following layout:
HAIKU:
<haiku>
Example:
HAIKU:
Coder, compiler
Les bugs brillent dans la nuit
Mise en prod zen
`

// haikuUserPrompt is the user-side request; the single %s placeholder
// receives the theme(s) passed on the command line.
const haikuUserPrompt = `
Write a meaningful haiku in french inspired by the following themes:
%s
`

// imagePromptSystemPrompt asks the LLM to produce a Stable-Diffusion-style
// prompt/negative-prompt pair following the "PROMPT:" / "NEGATIVE PROMPT:"
// pattern expected by extractPrompt and extractNegativePrompt.
const imagePromptSystemPrompt = `
You are an AI assistant specialized in writing image generation prompts. The prompt and the negative prompt should specify that no human character should be present.
Respond using the following pattern:
PROMPT:
<prompt>
NEGATIVE PROMPT:
<negative_prompt>
Example:
PROMPT:
A developer’s desk at night, softly lit by the gentle glow of a computer screen displaying code. Small glowing fireflies symbolizing 'bugs' float around, creating a magical and contemplative atmosphere. There is a steaming coffee cup, adding a cozy touch. The scene has a peaceful, stress-free vibe, as if the production release is going smoothly. The style is minimalist, with a soft color palette in shades of blue, purple, and orange.
NEGATIVE PROMPT:
blurry, disfigured, deformed, human character, watermark, text, signature, low quality, pixelated, overexposed, cropped, grainy, duplicate, bad proportions, mutation, clone, glitch
`

// imagePromptUserPrompt is the user-side request; the %s placeholder
// receives the generated haiku to illustrate.
const imagePromptUserPrompt = `
Write a prompt in 60 words maximum illustrating the following haiku:
%s
`
func main() { | |
llmBaseURL := os.Getenv("LLM_BASE_URL") | |
if llmBaseURL == "" { | |
// Par défaut, utilisation du service ollama s'exécutant sur la machine locale | |
llmBaseURL = "http://127.0.0.1:11434/v1/" | |
} | |
llmModel := os.Getenv("LLM_MODEL") | |
if llmModel == "" { | |
llmModel = "llama3.2:3b" | |
} | |
imageGenBaseURL := os.Getenv("IMG_GEN_BASE_URL") | |
if imageGenBaseURL == "" { | |
// Par défaut, utilisation du service FastSDCPU en local avec l'image Docker | |
// docker run -it --rm -p 8000:8000 docker.io/bornholm/fastsdcpu-api:v1.0.0-beta.100 | |
imageGenBaseURL = "http://localhost:8000" | |
} | |
theme := os.Args[1:] | |
log.Printf("[INFO] using theme: %s", theme) | |
client := openai.NewClient( | |
option.WithBaseURL(llmBaseURL), | |
) | |
// On génère tout d'abord la haïku indiquant au LLM de répondre suivant un patron défini | |
params := openai.ChatCompletionNewParams{ | |
Model: openai.F(llmModel), | |
Temperature: openai.Float(0.3), | |
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ | |
openai.SystemMessage(haikuSystemPrompt), | |
openai.UserMessage(fmt.Sprintf(haikuUserPrompt, theme)), | |
}), | |
} | |
ctx := context.Background() | |
log.Println("[INFO] generating haiku") | |
completion, err := client.Chat.Completions.New(ctx, params) | |
if err != nil { | |
log.Fatalf("[FATAL] %+v", errors.WithStack(err)) | |
} | |
response := completion.Choices[0].Message.Content | |
// On extrait le haïku de la réponse du LLM | |
haiku := extractHaiku(response) | |
if haiku == "" { | |
log.Fatalf("[FATAL] could not extract haiku from llm response:\n%s", response) | |
} | |
log.Printf("[INFO] haiku generated:\n%s", haiku) | |
// On génère ensuite les prompts dédiés à la création de l'illustration | |
// du haïku | |
params = openai.ChatCompletionNewParams{ | |
Model: openai.F(llmModel), | |
Temperature: openai.Float(0.7), | |
Messages: openai.F([]openai.ChatCompletionMessageParamUnion{ | |
openai.SystemMessage(imagePromptSystemPrompt), | |
openai.UserMessage(fmt.Sprintf(imagePromptUserPrompt, haiku)), | |
}), | |
} | |
log.Println("[INFO] generating haiku illustration prompt") | |
completion, err = client.Chat.Completions.New(ctx, params) | |
if err != nil { | |
log.Fatalf("[FATAL] %+v", errors.WithStack(err)) | |
} | |
response = completion.Choices[0].Message.Content | |
prompt := extractPrompt(response) | |
if prompt == "" { | |
log.Fatalf("[FATAL] could not extract image prompt from llm response:\n%s", response) | |
} | |
negativePrompt := extractNegativePrompt(response) | |
if negativePrompt == "" { | |
log.Fatalf("[FATAL] could not extract image negative prompt from llm response:\n%s", response) | |
} | |
log.Printf("[INFO] haiku illustration prompt generated:\n%s\n\n%s", prompt, negativePrompt) | |
log.Println("[INFO] generating illustration") | |
img, err := generateImage(imageGenBaseURL, prompt, negativePrompt) | |
if err != nil { | |
log.Fatalf("[FATAL] %+v", errors.WithStack(err)) | |
} | |
img, err = drawCartouche(img, haiku) | |
if err != nil { | |
log.Fatalf("[FATAL] %+v", errors.WithStack(err)) | |
} | |
if err := saveImage(img, "my_haiku.jpg"); err != nil { | |
log.Fatalf("[FATAL] %+v", errors.WithStack(err)) | |
} | |
log.Println("[INFO] haiku image saved") | |
} | |
// haikuRegExp captures everything following a "HAIKU:" marker; the s flag
// lets "." span newlines so a multi-line haiku is captured whole, and the
// i flag tolerates case variations in the marker.
var haikuRegExp = regexp.MustCompile(`(?msi)HAIKU\s*:\n+(.+)`)

// extractHaiku returns the haiku text from an LLM response following the
// "HAIKU:" layout, or "" when the marker is absent.
func extractHaiku(str string) string {
	// Only the first match is relevant; FindStringSubmatch avoids scanning
	// the rest of the string for additional matches.
	matches := haikuRegExp.FindStringSubmatch(str)
	if len(matches) < 2 {
		return ""
	}
	return matches[1]
}
// promptRegExp captures the single line following a "PROMPT:" marker.
// The ^ anchor (with the m flag) is essential: without it, "PROMPT:" also
// matches the tail of "NEGATIVE PROMPT:", so a response listing the negative
// prompt first would yield the wrong text.
var promptRegExp = regexp.MustCompile("(?mi)^PROMPT:\n+([^\n]+)")

// extractPrompt returns the image generation prompt from an LLM response
// following the "PROMPT:" / "NEGATIVE PROMPT:" layout, or "" when absent.
func extractPrompt(str string) string {
	matches := promptRegExp.FindStringSubmatch(str)
	if len(matches) < 2 {
		return ""
	}
	return matches[1]
}
// negativePromptRegExp captures everything after the "NEGATIVE PROMPT:"
// marker up to the end of the response (the s flag makes "." span newlines).
var negativePromptRegExp = regexp.MustCompile("(?msi)NEGATIVE PROMPT:\n+(.+)")

// extractNegativePrompt returns the negative-prompt section of an LLM
// response, or "" when the marker is missing.
func extractNegativePrompt(str string) string {
	groups := negativePromptRegExp.FindStringSubmatch(str)
	if len(groups) < 2 {
		return ""
	}
	return groups[1]
}
// N'ayant pas accès à un GPU, j'utilise ici l'API proposée par le projet | |
// https://github.com/rupeshs/fastsdcpu | |
func generateImage(baseURL string, prompt string, negativePrompt string) (image.Image, error) { | |
payload := map[string]any{ | |
"prompt": prompt, | |
"negative_prompt": negativePrompt, | |
} | |
var buff bytes.Buffer | |
encoder := json.NewEncoder(&buff) | |
if err := encoder.Encode(payload); err != nil { | |
return nil, errors.WithStack(err) | |
} | |
res, err := http.Post(baseURL+"/api/generate", "application/json", &buff) | |
if err != nil { | |
return nil, errors.WithStack(err) | |
} | |
defer res.Body.Close() | |
decoder := json.NewDecoder(res.Body) | |
type apiResult struct { | |
Images []string `json:"images"` | |
Latency float64 `json:"latency"` | |
} | |
var result apiResult | |
if err := decoder.Decode(&result); err != nil { | |
return nil, errors.WithStack(err) | |
} | |
if len(result.Images) < 1 { | |
return nil, errors.Errorf("unexpected number of returned images %d", len(result.Images)) | |
} | |
imageData, err := base64.StdEncoding.DecodeString(result.Images[0]) | |
if err != nil { | |
return nil, errors.WithStack(err) | |
} | |
imageBuff := bytes.NewBuffer(imageData) | |
img, err := jpeg.Decode(imageBuff) | |
if err != nil { | |
return nil, errors.WithStack(err) | |
} | |
return img, nil | |
} | |
// drawCartouche overlays the haiku text on the illustration inside a
// semi-transparent grey box. The text is drawn twice (black, then white
// offset by one pixel) to create a drop-shadow effect that keeps it
// readable on any background.
func drawCartouche(img image.Image, text string) (image.Image, error) {
	drawCtx := gg.NewContextForImage(img)
	imgWidth := img.Bounds().Dx()
	imgHeight := img.Bounds().Dy()
	// The cartouche area is at most ~2/3 of the image width and half its height.
	maxWidth := float64(imgWidth) * 0.666
	maxHeight := float64(imgHeight) * 0.5
	// Wrap the text with a 20% horizontal margin inside the cartouche area;
	// 1.5 is the line spacing used for both measuring and drawing.
	lines := drawCtx.WordWrap(text, maxWidth-(maxWidth*0.2))
	textWidth, textHeight := drawCtx.MeasureMultilineString(strings.Join(lines, "\n"), 1.5)
	// NOTE(review): this centers the text within the top-left
	// maxWidth x maxHeight region rather than the whole image — presumably
	// intentional, to keep the cartouche off-center; confirm.
	deltaX := (maxWidth / 2) - (textWidth / 2)
	deltaY := (maxHeight / 2) - (textHeight / 2)
	// Semi-transparent grey background box with padding around the text
	// (horizontal padding is symmetric; vertical padding is 5 above / 15
	// below — NOTE(review): possibly a deliberate baseline tweak, confirm).
	drawCtx.SetRGBA(0.5, 0.5, 0.5, 0.5)
	drawCtx.DrawRectangle(deltaX-10, deltaY-5, textWidth+20, textHeight+20)
	drawCtx.Fill()
	// Shadow pass in black, offset by one pixel…
	drawCtx.SetColor(color.Black)
	drawCtx.DrawStringWrapped(text, deltaX+1, deltaY+1, 0, 0, maxWidth-(maxWidth*0.2), 1.5, gg.AlignLeft)
	// …then the actual text in white on top.
	drawCtx.SetColor(color.White)
	drawCtx.DrawStringWrapped(text, deltaX, deltaY, 0, 0, maxWidth-(maxWidth*0.2), 1.5, gg.AlignLeft)
	return drawCtx.Image(), nil
}
func saveImage(img image.Image, path string) error { | |
file, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE, 0640) | |
if err != nil { | |
return errors.WithStack(err) | |
} | |
defer file.Close() | |
if err := jpeg.Encode(file, img, nil); err != nil { | |
return errors.WithStack(err) | |
} | |
return nil | |
} |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment