Skip to content

Instantly share code, notes, and snippets.

@iankelk
Last active September 1, 2023 21:26
Clarifai Image Predictions Quick Start
!pip install -q clarifai-grpc && pip install --upgrade --no-deps -q protobuf
import os
from io import BytesIO
import skimage
import IPython.display
import matplotlib.pyplot as plt
from PIL import Image
import numpy as np
%matplotlib inline
from clarifai_grpc.channel.clarifai_channel import ClarifaiChannel
from clarifai_grpc.grpc.api import resources_pb2, service_pb2, service_pb2_grpc
from clarifai_grpc.grpc.api.status import status_pb2, status_code_pb2
# Construct the gRPC communications channel to the Clarifai API
channel = ClarifaiChannel.get_grpc_channel()
# Construct the V2Stub object for accessing all the Clarifai API functionality
stub = service_pb2_grpc.V2Stub(channel)

# Personal Access Token. Prefer supplying it through the CLARIFAI_PAT
# environment variable so the secret is not pasted into (and shared along
# with) the notebook; the placeholder remains as a fallback for in-place edits.
PAT = os.environ.get('CLARIFAI_PAT', 'YOUR_PAT_HERE')
if PAT == 'YOUR_PAT_HERE':
    print("Warning: set the CLARIFAI_PAT environment variable or replace "
          "'YOUR_PAT_HERE' with your Clarifai Personal Access Token.")

# The model called below is looked up under this user/app pair.
USER_ID = 'clarifai'
APP_ID = 'main'
# gRPC call metadata carrying the key-based authorization header.
metadata = (('authorization', 'Key ' + PAT),)
# Sent with each request as `user_app_id` to identify the model's owner.
userDataObject = resources_pb2.UserAppIDSet(user_id=USER_ID, app_id=APP_ID)
# File stems of the scikit-image sample images to classify.
descriptions = [
    "page",
    "chelsea",
    "astronaut",
    "rocket",
    "motorcycle_right",
    "camera",
    "horse",
    "coffee",
]
original_images = []  # images shown next to their predictions later
images = []  # images encoded and sent to the model

# Select matching .png/.jpg files. sorted() makes the processing and plot
# order deterministic, since os.listdir returns entries in arbitrary order.
image_files = [
    fname
    for fname in sorted(os.listdir(skimage.data_dir))
    if fname.endswith((".png", ".jpg"))
    and os.path.splitext(fname)[0] in descriptions
]

# Thumbnails laid out 4 per row; 8 files reproduce the original 2x4 grid.
thumb_cols = 4
thumb_rows = max(1, (len(image_files) + thumb_cols - 1) // thumb_cols)
plt.figure(figsize=(16, 2.5 * thumb_rows))
for filename in image_files:
    # Context manager closes the file handle; convert() loads the pixel
    # data, so the returned image stays usable after the file is closed.
    with Image.open(os.path.join(skimage.data_dir, filename)) as img:
        image = img.convert("RGB")
    plt.subplot(thumb_rows, thumb_cols, len(images) + 1)
    plt.imshow(image)
    plt.title(filename)
    plt.xticks([])
    plt.yticks([])
    original_images.append(image)
    images.append(image)
plt.tight_layout()
def _encode_png(img):
    """Return the PNG-encoded bytes of a PIL image."""
    with BytesIO() as buf:
        img.save(buf, format="PNG")
        return buf.getvalue()


# One Clarifai Input message per image, carrying the PNG bytes in the
# Image message's `base64` field.
inputs = [
    resources_pb2.Input(
        data=resources_pb2.Data(
            image=resources_pb2.Image(base64=_encode_png(img))
        )
    )
    for img in images
]
# Choose the general visual classifier
MODEL_ID = 'general-image-recognition'

# Fail early with a clear message instead of sending an empty batch.
if not inputs:
    raise RuntimeError(
        "No inputs to classify: none of the requested sample images were found."
    )

# Send all inputs to the model in a single PostModelOutputs request.
post_model_outputs_response = stub.PostModelOutputs(
    service_pb2.PostModelOutputsRequest(
        user_app_id=userDataObject,  # user/app owning the model; required when authenticating with a PAT
        model_id=MODEL_ID,
        inputs=inputs
    ),
    metadata=metadata
)
# Any top-level status other than SUCCESS aborts the run.
if post_model_outputs_response.status.code != status_code_pb2.SUCCESS:
    print(post_model_outputs_response.status)
    raise RuntimeError("Post model outputs failed, status: " + post_model_outputs_response.status.description)
outputs = post_model_outputs_response.outputs
# Images are paired with outputs positionally (assumes outputs come back in
# request order). Check the counts so zip() cannot silently drop images.
if len(outputs) != len(original_images):
    raise RuntimeError(
        f"Expected {len(original_images)} model outputs, got {len(outputs)}"
    )

# Two grid cells per image (picture + bar chart), 4 cells per row;
# 8 images reproduce the original 4x4 grid and 16x16 figure.
pred_rows = max(1, (len(original_images) + 1) // 2)
plt.figure(figsize=(16, 4 * pred_rows))
for i, (image, output) in enumerate(zip(original_images, outputs)):
    # First five concepts returned for this image (name -> value);
    # presumably already ordered by descending value by the API.
    top_preds = {
        concept.name: concept.value for concept in output.data.concepts[:5]
    }
    plt.subplot(pred_rows, 4, 2 * i + 1)
    plt.imshow(image)
    plt.axis("off")
    plt.subplot(pred_rows, 4, 2 * i + 2)
    y = np.arange(len(top_preds))
    plt.grid()
    plt.barh(y, list(top_preds.values()))
    plt.gca().invert_yaxis()  # put the first concept at the top
    plt.gca().set_axisbelow(True)  # draw grid lines behind the bars
    plt.yticks(y, list(top_preds.keys()))
    plt.xlabel("probability")
plt.subplots_adjust(wspace=0.6)
plt.show()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment