Skip to content

Instantly share code, notes, and snippets.

@gertjana
Created March 20, 2023 08:43
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gertjana/86485fc496233ddb9a561622e2420f19 to your computer and use it in GitHub Desktop.
Save gertjana/86485fc496233ddb9a561622e2420f19 to your computer and use it in GitHub Desktop.
package main
import (
"encoding/json"
"fmt"
"io/ioutil"
"net/http"
"os"
"strings"
"time"
)
// Package-level request configuration.
//
// API_KEY is read once at startup from the LEAP_API_KEY environment variable;
// export that variable before running (an unset variable yields an empty key).
// HEADERS is the header set attached to every request built by req; its
// Authorization entry embeds API_KEY as a bearer token.
var (
	API_KEY = os.Getenv("LEAP_API_KEY")
	HEADERS = map[string]string{
		"Accept":        "application/json",
		"Authorization": "Bearer " + API_KEY,
		"Content-Type":  "application/json",
	}
)

// BASE_API is the root of the Leap REST API; endpoint paths are appended to it.
const BASE_API = "https://api.tryleap.ai/api/v1"
// main creates (or reuses) a Leap image model and prints the URIs of images
// generated from a fixed prompt.
//
// Usage:
//
//	go run . [modelId]
//
// If a modelId is passed as the first argument, that existing model is used
// as-is. Otherwise a new model is created and trained on the sample images
// first. Any failure is reported on stderr and the process exits with
// status 1, so later steps never run against a missing model.
func main() {
	if err := run(); err != nil {
		fmt.Fprintln(os.Stderr, err)
		os.Exit(1)
	}
}

// run performs main's create/train/generate workflow and returns the first
// error encountered.
func run() error {
	const modelName = "GertjanTest"
	// "@me" must match the subjectKeyword used when the model is created.
	const prompt = "Photos of @me with psychedelic colors, high contrast, and a lot of detail."
	// NOTE(review): these look like signed Facebook CDN links (see the `oe`
	// query parameter); such links may stop resolving after a while — confirm
	// they are still reachable before reusing this program.
	sampleImages := []string{
		"https://scontent-ams4-1.xx.fbcdn.net/v/t1.6435-9/38507713_10155827789237239_1223092985231572992_n.jpg?_nc_cat=102&ccb=1-7&_nc_sid=8bfeb9&_nc_ohc=SpI49TU-LZgAX9Tih_6&_nc_ht=scontent-ams4-1.xx&oh=00_AfCVa59wzV2QP-M1iatZt8adOKaQExqMRxFnnjA_R8QNrg&oe=643D0668",
		"https://scontent-ams4-1.xx.fbcdn.net/v/t1.6435-9/36088073_10155734387182239_1766884150402351104_n.jpg?_nc_cat=109&ccb=1-7&_nc_sid=8bfeb9&_nc_ohc=jOgd9s36aT0AX9_l6hE&_nc_ht=scontent-ams4-1.xx&oh=00_AfDLUiY6iH33W4JiKlBWOP6yHewPJd2TXM_c4_Q6mL5dcw&oe=643D22B6",
		"https://scontent-ams4-1.xx.fbcdn.net/v/t1.18169-9/1918748_144973647238_127197_n.jpg?_nc_cat=110&ccb=1-7&_nc_sid=cdbe9c&_nc_ohc=tz6x6BLDyU4AX_MMnSC&_nc_ht=scontent-ams4-1.xx&oh=00_AfDMre6PRPLBSKpGWxXPMW7x1sKuhpTV1Ayu6iEn3Gar2Q&oe=643D07F9",
		"https://scontent-ams2-1.xx.fbcdn.net/v/t1.6435-9/78835646_10156940909177239_7040054755149742080_n.jpg?_nc_cat=105&ccb=1-7&_nc_sid=8bfeb9&_nc_ohc=JINnJArEVt4AX_znDbv&_nc_ht=scontent-ams2-1.xx&oh=00_AfBzLP8aM5BLymF9sdmCVMf6GRsBZ206iSeqQo1w70XdmQ&oe=643D0E79",
		"https://scontent-ams4-1.xx.fbcdn.net/v/t1.6435-9/45342220_10156018880582239_634734194265686016_n.jpg?_nc_cat=107&ccb=1-7&_nc_sid=8bfeb9&_nc_ohc=6Ot2H9ah-QYAX-MpA4e&_nc_ht=scontent-ams4-1.xx&oh=00_AfABy-IuEvCRzXFfqyGytHEJLKVjSKfbDDnMyhRMtnNWRg&oe=643D1E00",
	}

	var modelID string
	if len(os.Args) > 1 {
		modelID = os.Args[1]
		fmt.Println("Using model: ", modelID)
	} else {
		fmt.Println("Creating model: ", modelName)
		id, err := CreateModel(modelName)
		if err != nil {
			return fmt.Errorf("creating model: %w", err)
		}
		modelID = id
		// Only a freshly created model needs training; a model passed on the
		// command line is used as-is.
		if err := TrainModel(modelID, sampleImages); err != nil {
			return fmt.Errorf("training model %s: %w", modelID, err)
		}
	}

	images, err := RunModel(modelID, prompt)
	if err != nil {
		return fmt.Errorf("running model %s: %w", modelID, err)
	}
	for _, image := range images {
		fmt.Println(image)
	}
	return nil
}
// CreateModel creates a new Leap image model named title, with subject
// keyword "@me", and returns the new model's ID.
//
// It returns an error if the request fails, the response cannot be decoded,
// or the response carries no "id".
func CreateModel(title string) (string, error) {
	url := fmt.Sprintf("%s/images/models", BASE_API)
	// Marshal rather than string formatting so that quotes or backslashes in
	// title cannot produce invalid JSON.
	payload, err := json.Marshal(map[string]string{
		"title":          title,
		"subjectKeyword": "@me",
	})
	if err != nil {
		return "", err
	}
	response, err := post(url, string(payload))
	if err != nil {
		return "", err
	}
	// Decode only the field we need: decoding into map[string]string would
	// fail as soon as the response contains any non-string field.
	var model struct {
		ID string `json:"id"`
	}
	if err := json.Unmarshal(response, &model); err != nil {
		return "", fmt.Errorf("decoding create-model response: %w", err)
	}
	if model.ID == "" {
		return "", fmt.Errorf("create-model response has no id: %s", response)
	}
	return model.ID, nil
}
// TrainModel uploads sample_images (image URLs) to the model modelId, queues
// a training job, and polls the resulting model version every 10 seconds
// until its status is "finished".
//
// It returns the first request error. It also returns an error if the version
// reports status "failed", instead of polling forever. NOTE(review): "failed"
// is assumed to be Leap's terminal error status — confirm against the API docs.
func TrainModel(modelId string, sample_images []string) error {
	fmt.Println("Uploading sample images")
	if err := uploadImageSamples(modelId, sample_images); err != nil {
		return err
	}

	fmt.Println("Queueing training job...")
	versionId, status, err := queueTrainingJob(modelId)
	if err != nil {
		return err
	}
	for status != "finished" {
		if status == "failed" {
			return fmt.Errorf("training of model %s (version %s) failed", modelId, versionId)
		}
		// Sleep before polling so a finished job is not followed by an
		// extra, useless wait.
		time.Sleep(10 * time.Second)
		if versionId, status, err = getModelVersion(modelId, versionId); err != nil {
			return err
		}
	}
	return nil
}
// RunModel starts an image-generation job on model modelId with prompt,
// polls it every 10 seconds until its state is "finished", and returns the
// URIs of the generated images.
//
// The job is always fetched at least once, so images are returned even when
// the job is already finished at creation. A "failed" state is returned as an
// error instead of polling forever. NOTE(review): "failed" is assumed to be
// Leap's terminal error state — confirm against the API docs.
func RunModel(modelId string, prompt string) ([]string, error) {
	fmt.Println("Generating image...")
	inferenceId, status, err := generateImage(modelId, prompt)
	if err != nil {
		return nil, err
	}
	fmt.Printf("inferenceId: %s, inference_status: %s\n", inferenceId, status)

	for {
		var images []string
		if inferenceId, status, images, err = getInferenceJob(modelId, inferenceId); err != nil {
			return nil, err
		}
		switch status {
		case "finished":
			return images, nil
		case "failed":
			return nil, fmt.Errorf("inference %s on model %s failed", inferenceId, modelId)
		}
		time.Sleep(10 * time.Second)
	}
}
// get performs an authenticated GET request against url and returns the raw
// response body. GET requests carry no payload.
func get(url string) ([]byte, error) {
	const noPayload = ""
	body, err := req(url, noPayload, "GET")
	return body, err
}
// post sends payload (a JSON document, or "" for an empty body) to url as an
// authenticated POST request and returns the raw response body.
func post(url string, payload string) ([]byte, error) {
	body, err := req(url, payload, "POST")
	return body, err
}
// httpClient is shared by all requests so connections can be reused. Unlike
// http.DefaultClient it has a timeout, so a stalled server cannot hang the
// program forever.
var httpClient = &http.Client{Timeout: 2 * time.Minute}

// req sends an HTTP request with the given method to url, with payload as the
// request body ("" for none) and HEADERS as the request headers, and returns
// the full response body.
//
// A non-2xx response is returned as an error that includes the status and
// body, so callers never decode an API error payload as a successful result.
func req(url string, payload string, method string) ([]byte, error) {
	request, err := http.NewRequest(method, url, strings.NewReader(payload))
	if err != nil {
		return nil, err
	}
	for key, value := range HEADERS {
		request.Header.Set(key, value)
	}

	res, err := httpClient.Do(request)
	if err != nil {
		return nil, err
	}
	defer res.Body.Close()

	body, err := ioutil.ReadAll(res.Body)
	if err != nil {
		return nil, err
	}
	if res.StatusCode < 200 || res.StatusCode > 299 {
		return nil, fmt.Errorf("%s %s: unexpected status %s: %s", method, url, res.Status, body)
	}
	return body, nil
}
func uploadImageSamples(model_id string, sample_images []string) error {
url := fmt.Sprintf("%s/images/models/%s/samples/url", BASE_API, model_id)
payload, _ := json.Marshal(map[string][]string{"images": sample_images})
var err error
if _, err = post(url, string(payload)); err != nil {
return err
}
return nil
}
// queueTrainingJob queues a training job for model model_id and returns the
// ID and status of the model version the API reports for it.
//
// Only the "id" and "status" response fields are decoded. Other fields of any
// JSON type are ignored; decoding into map[string]string used to fail on them.
// A response without a version ID is returned as an error.
func queueTrainingJob(model_id string) (string, string, error) {
	url := fmt.Sprintf("%s/images/models/%s/queue", BASE_API, model_id)
	response, err := post(url, "")
	if err != nil {
		return "", "", err
	}
	var version struct {
		ID     string `json:"id"`
		Status string `json:"status"`
	}
	if err := json.Unmarshal(response, &version); err != nil {
		return "", "", fmt.Errorf("decoding queue response for model %s: %w", model_id, err)
	}
	if version.ID == "" {
		return "", "", fmt.Errorf("queue response for model %s has no version id: %s", model_id, response)
	}
	fmt.Printf("Version ID: %s, Status: %s\n", version.ID, version.Status)
	return version.ID, version.Status, nil
}
// getModelVersion fetches version versionId of model modelId and returns
// versionId unchanged together with the version's current status.
//
// A missing, null, or non-string "status" is reported as an error. The
// previous unchecked type assertion panicked in those cases.
func getModelVersion(modelId, versionId string) (string, string, error) {
	url := fmt.Sprintf("%s/images/models/%s/versions/%s", BASE_API, modelId, versionId)
	response, err := get(url)
	if err != nil {
		return "", "", err
	}
	var version struct {
		Status string `json:"status"`
	}
	if err := json.Unmarshal(response, &version); err != nil {
		return "", "", fmt.Errorf("decoding version %s of model %s: %w", versionId, modelId, err)
	}
	if version.Status == "" {
		return "", "", fmt.Errorf("version %s of model %s has no status: %s", versionId, modelId, response)
	}
	fmt.Printf("Version ID: %s. Status: %s\n", versionId, version.Status)
	return versionId, version.Status, nil
}
// generateImage starts an inference job on model modelId that renders prompt
// with fixed settings: 4 images, 512x512, 50 steps, seed 4523184, with prompt
// enhancement and face restoration enabled. It returns the job ID and the
// initial status reported by the API.
//
// A response without an "id" is returned as an error rather than panicking.
// A missing status is returned as "", which leaves the caller to poll.
func generateImage(modelId, prompt string) (string, string, error) {
	url := fmt.Sprintf("%s/images/models/%s/inferences", BASE_API, modelId)
	payload, err := json.Marshal(map[string]interface{}{
		"prompt":         prompt,
		"steps":          50,
		"width":          512,
		"height":         512,
		"numberOfImages": 4,
		"seed":           4523184,
		"enhancePrompt":  true,
		"restoreFaces":   true,
	})
	if err != nil {
		return "", "", err
	}
	response, err := post(url, string(payload))
	if err != nil {
		return "", "", err
	}
	var job struct {
		ID     string `json:"id"`
		Status string `json:"status"`
	}
	if err := json.Unmarshal(response, &job); err != nil {
		return "", "", fmt.Errorf("decoding inference response for model %s: %w", modelId, err)
	}
	if job.ID == "" {
		return "", "", fmt.Errorf("inference response for model %s has no id: %s", modelId, response)
	}
	fmt.Printf("InferenceId: %s, Status: %s\n", job.ID, job.Status)
	return job.ID, job.Status, nil
}
// getInferenceJob fetches inference job inferenceId of model modelId and
// returns the job ID, its "state" ("" if absent), and the URIs of its
// generated images (empty if the response lists none).
//
// Malformed "state" or "images" fields are reported as a decoding error
// instead of panicking. Image entries without a "uri" are skipped.
func getInferenceJob(modelId string, inferenceId string) (string, string, []string, error) {
	url := fmt.Sprintf("%s/images/models/%s/inferences/%s", BASE_API, modelId, inferenceId)
	response, err := get(url)
	if err != nil {
		return "", "", nil, err
	}
	var job struct {
		State  string `json:"state"`
		Images []struct {
			URI string `json:"uri"`
		} `json:"images"`
	}
	if err := json.Unmarshal(response, &job); err != nil {
		return "", "", nil, fmt.Errorf("decoding inference %s of model %s: %w", inferenceId, modelId, err)
	}

	// Length 0 with capacity only: the earlier make([]string, 5) put five
	// empty strings in front of the appended URIs.
	images := make([]string, 0, len(job.Images))
	for _, image := range job.Images {
		if image.URI != "" {
			images = append(images, image.URI)
		}
	}
	fmt.Printf("Inference ID: %s. State: %s\n", inferenceId, job.State)
	return inferenceId, job.State, images, nil
}
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment