import labelbox as lb
import transformers
import torch
import torch.nn.functional as F
from PIL import Image
import requests
import numpy as np

# Add your API key
API_KEY = "your_api_key_here"
client = lb.Client(API_KEY)

# Get images from a Labelbox dataset
DATASET_ID = "your_dataset_id_here"
dataset = client.get_dataset(DATASET_ID)

# Start an export of the dataset's data rows and block until it finishes.
export_task = dataset.export_v2()
export_task.wait_till_done()

# Report export problems up front; otherwise a failed export only shows up
# later as an opaque error while iterating the result.
if export_task.errors:
    print(export_task.errors)

# Each exported record is read as record['data_row']['row_data'], which is
# used below as the URL the image is downloaded from.
data_row_urls = [record['data_row']['row_data'] for record in export_task.result]

# Get ResNet-50 from HuggingFace
image_processor = transformers.AutoImageProcessor.from_pretrained("microsoft/resnet-50")
model = transformers.ResNetModel.from_pretrained("microsoft/resnet-50")

# Upper bound (seconds) on connecting to / reading from an image host, so a
# stalled server cannot hang the whole run.
REQUEST_TIMEOUT_S = 30

# One NumPy array per URL, appended in the same order as data_row_urls so the
# two lists can later be zipped together.
img_emb = []

for url in data_row_urls:
    # The context manager releases the streamed connection even if decoding
    # fails; raise_for_status() turns an HTTP error response into a clear
    # exception instead of an image-decoding failure on the error body.
    with requests.get(url, stream=True, timeout=REQUEST_TIMEOUT_S) as response:
        response.raise_for_status()
        # convert() forces PIL to read the pixel data, so the image is fully
        # loaded before the connection is closed.
        image = Image.open(response.raw).convert('RGB').resize((224, 224))

    img_hf = image_processor(image, return_tensors="pt")

    # Inference only: no gradients are needed.
    with torch.no_grad():
        # Only last_hidden_state is used, so the intermediate hidden states
        # are not requested.
        last_layer = model(**img_hf).last_hidden_state
        # Average-pool the spatial grid down to 1x1, then flatten everything
        # after the batch dimension into a single feature vector per image.
        resnet_embeddings = F.adaptive_avg_pool2d(last_layer, (1, 1))
        resnet_embeddings = torch.flatten(resnet_embeddings, start_dim=1, end_dim=3)
        img_emb.append(resnet_embeddings.cpu().numpy())

# Register the custom embedding that the vectors below are attached to.
# ResNet-50's final feature map has 2048 channels, so each pooled vector has
# 2048 entries and the embedding must be declared with matching dimensions.
# NOTE(review): this creates a new embedding on every run; if one with this
# name already exists in the workspace, reuse its id instead — confirm the
# lookup call against the Labelbox custom-embeddings documentation.
EMBEDDING_DIMS = 2048
new_custom_embedding_id = client.create_embedding(
    name="image_custom_embedding_resnet", dims=EMBEDDING_DIMS
).id

# One upload payload per image: the source URL plus its embedding vector.
# embedding[0] selects the single row of each per-image array.
data_rows = [
    {
        "row_data": url,
        "embeddings": [{"embedding_id": new_custom_embedding_id, "vector": embedding[0].tolist()}],
    }
    for url, embedding in zip(data_row_urls, img_emb)
]

# Upload the data rows (with embeddings attached) into a fresh dataset.
dataset = client.create_dataset(name='image_custom_embedding_resnet', iam_integration=None)
task = dataset.create_data_rows(data_rows)
# The import runs asynchronously; wait for it so that the errors printed
# below reflect the finished task rather than one still in progress.
task.wait_till_done()
print(task.errors)