|
# -*- coding: utf-8 -*- |
|
""" |
|
|
|
caption_image.py - basic captioning example using lavis |
|
|
|
usage: caption_image.py [-h] -i IMAGE_PATH [-m MODEL_TYPE] [-d DEVICE] [-v VERBOSE] |
|
# lavis |
|
|
|
https://github.com/salesforce/LAVIS |
|
""" |
|
|
|
# pip install git+https://github.com/salesforce/LAVIS.git -q |
|
|
|
import argparse
import logging
import re
import time
from pathlib import Path
from typing import Union

import requests
|
|
|
# Configure root logging once at import time so every logger in this
# process emits timestamped, level-tagged records at INFO and above.
logging.basicConfig(

    level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s"

)

# Module-level logger, named after this module per stdlib convention.
logger = logging.getLogger(__name__)
|
|
|
|
|
import PIL |
|
import torch |
|
from lavis.models import load_model_and_preprocess |
|
from PIL import Image |
|
|
|
|
|
def load_image(impath: Union[str, Path]) -> Image.Image:
    """
    load_image - load an image from a local path or an HTTP(S) URL.

    :param Union[str,Path] impath: filesystem path or http(s) URL of the image
    :return Image.Image: loaded image, converted to RGB
    :raises requests.HTTPError: if a URL download returns an error status
    """
    # Coerce to str first: re.match raises TypeError on a pathlib.Path,
    # and the original annotation `str or Path` evaluated to plain `str`.
    impath = str(impath)
    if re.match(r"^https?://", impath):
        # Stream the body and hand the raw file object straight to PIL;
        # fail loudly on HTTP errors instead of feeding PIL an error page.
        response = requests.get(impath, stream=True, timeout=30)
        response.raise_for_status()
        return Image.open(response.raw).convert("RGB")
    return Image.open(impath).convert("RGB")
|
|
|
|
|
def load_and_caption_image(
    impath: Union[str, Path],
    model_type: str = "base_coco",
    device: str = None,
    verbose: bool = False,
) -> str:
    """
    load_and_caption_image - load an image and caption it with a LAVIS BLIP model.

    :param Union[str,Path] impath: path or URL of the image to caption
    :param str model_type: BLIP caption checkpoint name, defaults to "base_coco"
    :param str device: torch device spec ("cuda"/"cpu"); auto-detected when None
    :param bool verbose: log extra detail about the loaded model
    :return str: the generated caption
    """
    raw_image = load_image(impath)
    if device is None:
        # Prefer the GPU whenever one is visible to torch.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logger.info(f"loading model {model_type} on device {device} ...")
    model, vis_processors, _ = load_model_and_preprocess(
        name="blip_caption", model_type=model_type, is_eval=True, device=device
    )
    if verbose:
        logger.info(f"Loaded model:\t{model_type}")

    # Preprocess to a single-image batch on the target device.
    image = vis_processors["eval"](raw_image).unsqueeze(0).to(device)
    logger.info("running inference ...")
    st = time.perf_counter()
    # generate() returns a list of captions; we only submit one image.
    caption = model.generate({"image": image})[0]
    rt = round(time.perf_counter() - st, 2)
    logger.info(f"Finished inference in {rt} seconds, caption:\t{caption}")
    return caption
|
|
|
|
|
def get_parser() -> argparse.ArgumentParser:
    """
    get_parser - build the command-line argument parser for this script.

    :return argparse.ArgumentParser: parser with -i/-m/-d/-v options
    """
    parser = argparse.ArgumentParser(
        description="caption_image.py - basic captioning example using lavis"
    )
    parser.add_argument(
        "-i",
        "--image_path",
        type=str,
        required=True,
        help="path or URL of the image to caption",
    )
    parser.add_argument(
        "-m",
        "--model_type",
        type=str,
        default="base_coco",
        help="model type to use",
    )
    parser.add_argument(
        "-d",
        "--device",
        type=str,
        default=None,
        help="device to use (e.g. cuda or cpu); auto-detected by default",
    )
    parser.add_argument(
        "-v",
        "--verbose",
        # type=bool is broken for CLI flags: bool("False") is True and any
        # non-empty string parses as True. A store_true flag is the idiom.
        action="store_true",
        help="verbose logging",
    )
    return parser
|
|
|
|
|
if __name__ == "__main__":
    # Parse CLI arguments once (the original built the parser twice and
    # left the first instance unused), caption, and print to stdout.
    args = get_parser().parse_args()
    caption = load_and_caption_image(
        impath=args.image_path,
        model_type=args.model_type,
        device=args.device,
        verbose=args.verbose,
    )
    print(caption)