Skip to content

Instantly share code, notes, and snippets.

@gingerwizard
Created January 21, 2020 11:26
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save gingerwizard/2ea39c676e25600413ec2352181872b7 to your computer and use it in GitHub Desktop.
Save gingerwizard/2ea39c676e25600413ec2352181872b7 to your computer and use it in GitHub Desktop.
Code for image search
download_images.py
import argparse
import glob
import gzip
import imghdr
import urllib.request
import os
import ast
import multiprocessing
from multiprocessing import Pool, Value
import requests
from PIL import Image
# Command-line options for the downloader. Numeric options are parsed as ints
# by argparse; the explicit int() conversions below are kept so the module
# level names always hold integers.
parser = argparse.ArgumentParser()
parser.add_argument('--num_images', dest='num_images', type=int, default=1000)
parser.add_argument('--report_every', dest='report_every', type=int, default=10)
parser.add_argument('--images_path', dest='images_path', default='./data')
parser.add_argument('--min_width', dest='min_width', type=int, default=240)
# This option used to be registered under a second '--min_width' flag, which
# makes argparse abort with "conflicting option string" before parsing
# anything. The flag now matches its destination name.
parser.add_argument('--max_height', dest='max_height', type=int, default=350)
args = parser.parse_args()

num_images = int(args.num_images)      # stop after this many new images
report_every = int(args.report_every)  # progress line every N downloads
images_path = args.images_path         # directory the JPEG files are written to
data_path = 'metadata.json.gz'         # local copy of the product metadata
counter = None                         # replaced per worker process by init()
def init(args):
    """Remember the object passed to a worker process as the module-level counter.

    Used as the ``initializer`` of the process pool, which hands it the
    shared counter through ``initargs``.
    """
    global counter
    counter = args
# Number of worker processes. Worker k handles the metadata lines whose index
# modulo NUM_CPU equals k. An earlier revision used
# multiprocessing.cpu_count() * 10 here; a single worker is the current setting.
NUM_CPU = 1

# exist_ok=True replaces the separate isdir() check, which could race with
# another process creating the same directory between the check and makedirs.
os.makedirs(images_path, exist_ok=True)
def download_file(url, filename, timeout=60):
    """Stream ``url`` to ``filename`` in 8 KiB chunks.

    The body is first written to ``filename + '.part'`` and only renamed to
    ``filename`` once the transfer finished, so an interrupted download never
    leaves a truncated file that a later ``os.path.exists`` check would
    mistake for a complete one.

    Args:
        url: Address to fetch.
        filename: Destination path.
        timeout: Value handed to ``requests.get`` so a stalled connection
            raises instead of blocking forever.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    print('Downloading %s to %s' % (url, filename), flush=True)
    partial = filename + '.part'
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(partial, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        os.replace(partial, filename)
    finally:
        # Only still present when the transfer failed part-way through.
        if os.path.exists(partial):
            os.remove(partial)
def parse(num_cpu, modulo, path='metadata.json.gz'):
    """Yield this worker's share of the records in a gzipped metadata file.

    Each line of the file is a Python literal; it is decoded with
    ``ast.literal_eval``, which only accepts literals and never executes code.

    Args:
        num_cpu: Total number of workers splitting the file.
        modulo: Index of this worker; only lines whose zero-based index
            satisfies ``index % num_cpu == modulo`` are yielded.
        path: Location of the gzipped file. Defaults to the archive name the
            scripts download.

    Yields:
        The evaluated literal of every selected, non-blank line.
    """
    with gzip.open(path, 'rb') as f:
        for index, raw in enumerate(f):
            if index % num_cpu != modulo:
                continue
            text = raw.decode().strip()
            # Blank lines still count towards the index but carry no record;
            # literal_eval would raise SyntaxError on them.
            if text:
                yield ast.literal_eval(text)
def download_files(modulo):
    """Download the JPEG images referenced by this worker's metadata records.

    A record is used when it has a non-null ``imUrl`` ending in ``jpg`` and a
    ``categories`` key. The image is stored as ``<asin>.jpg`` under
    ``images_path``; files that already exist are skipped and not counted.
    Files that do not turn out to be JPEG data are removed again. The shared
    ``counter`` tracks successful downloads across all workers and the worker
    stops once it reaches ``num_images``.

    Args:
        modulo: Index of this worker, passed on to ``parse``.
    """
    for data in parse(NUM_CPU, modulo):
        # Stop before issuing another request once the quota is met.
        if counter.value >= num_images:
            break
        if not ('imUrl' in data and data['imUrl'] is not None
                and 'categories' in data
                and data['imUrl'].split('.')[-1] == 'jpg'):
            continue
        url = data['imUrl']
        try:
            path = os.path.join(images_path, data['asin'] + '.jpg')
            if os.path.isfile(path):
                continue
            r = requests.get(url, allow_redirects=True, timeout=1)
            if r.status_code != 200:
                print('Unable to download {} - response {}'.format(url, r.status_code))
                continue
            # Context manager so the handle is closed (and the data flushed)
            # before the file type is inspected.
            with open(path, 'wb') as f:
                f.write(r.content)
            kind = imghdr.what(path)
            if kind != 'jpeg':
                print('Removed {} it is a {}'.format(path, kind))
                os.remove(path)
                continue
            with counter.get_lock():
                if counter.value >= num_images:
                    # Another worker filled the quota while this file was
                    # downloading; do not leave a surplus image behind.
                    os.remove(path)
                    break
                counter.value += 1
                if counter.value % report_every == 0:
                    print('Downloaded %s files' % counter.value)
        except Exception as exc:
            # Best effort: one failing record must not stop the worker, but
            # KeyboardInterrupt/SystemExit are no longer swallowed.
            print('Error downloading {} - {}'.format(url, exc))
# Entry point. The guard keeps worker processes, which re-import this module
# under the "spawn" start method, from starting pools of their own.
if __name__ == '__main__':
    if not os.path.exists(data_path):
        download_file('https://s3.us-east-2.amazonaws.com/mxnet-public/stanford_amazon/metadata.json.gz', data_path)
    # Shared integer counting the images downloaded by all workers.
    counter = Value('i', 0)
    # The context manager releases the worker processes once map() returned.
    with Pool(processes=NUM_CPU, initializer=init, initargs=(counter,)) as pool:
        results = pool.map(download_files, range(NUM_CPU))
create_features.py
import argparse
import ast
import glob
import gzip
import os
import shutil
import tempfile
import multiprocessing
import mxnet as mx
from multiprocessing import Pool, Value
import requests
import sys
from elasticsearch import Elasticsearch
from elasticsearch.helpers import parallel_bulk
from mxnet import nd
from mxnet.gluon.data.vision import ImageFolderDataset
from mxnet.gluon.model_zoo import vision
from os.path import join
from mxnet.image import image
# Command-line interface. The destination of every option is derived by
# argparse from the flag name, which yields the same attribute names.
parser = argparse.ArgumentParser()
parser.add_argument('--images_path', default='./data')
parser.add_argument('--report_every', default=10)
parser.add_argument('--features_file', default='features.csv')
parser.add_argument('--es_host', required=True)
parser.add_argument('--es_user', required=True)
parser.add_argument('--es_password', required=True)
args = parser.parse_args()

# Elasticsearch connection settings and reporting interval.
es_host, es_user, es_password = args.es_host, args.es_user, args.es_password
report_every = int(args.report_every)

# One worker process per CPU reported by the OS.
NUM_CPU = multiprocessing.cpu_count()

# Target size used when resizing/cropping, and the per-channel mean and
# standard deviation applied by transform().
SIZE = (224, 224)
MEAN_IMAGE = nd.array([0.485, 0.456, 0.406])
STD_IMAGE = nd.array([0.229, 0.224, 0.225])

data_path = 'metadata.json.gz'
counter = None  # set in each worker process by init()
def init(args):
    """Remember the object passed to a worker process as the module-level counter.

    Used as the ``initializer`` of the process pool, which hands it the
    shared counter through ``initargs``.
    """
    global counter
    counter = args
# Run on a GPU when MXNet lists at least one, otherwise on the CPU.
ctx = mx.gpu() if mx.test_utils.list_gpus() else mx.cpu()

# Feature extractor: the `features` part of a pretrained ResNet-18 v2.
net = vision.resnet18_v2(pretrained=True, ctx=ctx).features
net.hybridize()
# One forward pass over a dummy 1x3x224x224 batch — presumably needed so the
# hybridized graph exists before export(); confirm against the MXNet docs.
net(mx.nd.ones((1, 3, 224, 224), ctx=ctx))

# Export into a freshly created 'model' directory, discarding any old one.
if os.path.exists('model'):
    shutil.rmtree('model')
os.mkdir('model')
net.export(join('model', 'visualsearch'))
def download_file(url, filename, timeout=60):
    """Stream ``url`` to ``filename`` in 8 KiB chunks.

    The body is first written to ``filename + '.part'`` and only renamed to
    ``filename`` once the transfer finished, so an interrupted download never
    leaves a truncated file that a later ``os.path.exists`` check would
    mistake for a complete one.

    Args:
        url: Address to fetch.
        filename: Destination path.
        timeout: Value handed to ``requests.get`` so a stalled connection
            raises instead of blocking forever.

    Raises:
        requests.HTTPError: If the server answers with an error status.
    """
    print('Downloading %s to %s' % (url, filename), flush=True)
    partial = filename + '.part'
    try:
        with requests.get(url, stream=True, timeout=timeout) as r:
            r.raise_for_status()
            with open(partial, 'wb') as f:
                for chunk in r.iter_content(chunk_size=8192):
                    if chunk:  # filter out keep-alive new chunks
                        f.write(chunk)
        os.replace(partial, filename)
    finally:
        # Only still present when the transfer failed part-way through.
        if os.path.exists(partial):
            os.remove(partial)
def transform(image, label):
    """Turn an image into the normalized tensor fed to the network.

    Steps: resize the shorter side to ``SIZE[0]``, convert to float32,
    center-crop to ``SIZE``, divide by 255, normalize with ``MEAN_IMAGE`` and
    ``STD_IMAGE``, then reorder the axes with ``(2, 0, 1)``.

    Args:
        image: Image array accepted by ``mx.image.resize_short``.
        label: Passed through untouched.

    Returns:
        Tuple ``(tensor, label)``.
    """
    pixels = mx.image.resize_short(image, SIZE[0]).astype('float32')
    # center_crop also returns the crop rectangle, which is not needed here.
    pixels, _ = mx.image.center_crop(pixels, SIZE)
    pixels /= 255.
    pixels = mx.image.color_normalize(pixels, mean=MEAN_IMAGE, std=STD_IMAGE)
    return nd.transpose(pixels, (2, 0, 1)), label
# Fetch the product metadata archive unless a local copy is already present.
if not os.path.exists(data_path):
    download_file(
        'https://s3.us-east-2.amazonaws.com/mxnet-public/stanford_amazon/metadata.json.gz',
        'metadata.json.gz',
    )

# An ImageFolderDataset rooted at a freshly created, empty temporary directory.
empty_folder = tempfile.mkdtemp()
dataset = ImageFolderDataset(root=empty_folder, transform=transform)

# Number of .jpg files currently in the images directory; the workers stop
# once the shared counter reaches this value.
num_images = len(glob.glob(join(args.images_path, '*.jpg')))
def generate_feature(num_cpu, modulo):
    """Yield one Elasticsearch bulk action per downloaded image of this worker.

    Reads ``metadata.json.gz`` line by line, keeps the lines whose index
    satisfies ``index % num_cpu == modulo``, and for every record whose image
    ``<asin>.jpg`` exists under ``args.images_path`` runs the image through
    ``net``. The resulting vector is stored in the record under ``image`` and
    its sum under ``sum``. Reading stops once the shared ``counter`` reaches
    ``num_images``.

    Args:
        num_cpu: Total number of workers splitting the file.
        modulo: Index of this worker.

    Yields:
        Dicts with ``_index``, ``_source`` and ``_id`` keys for ``parallel_bulk``.
    """
    with gzip.open('metadata.json.gz', 'rb') as f:
        for i, l in enumerate(f):
            if counter.value >= num_images:
                break
            if i % num_cpu != modulo:
                continue
            image_doc = ast.literal_eval(l.decode())
            if not ('imUrl' in image_doc and image_doc['imUrl'] is not None
                    and 'categories' in image_doc
                    and image_doc['imUrl'].split('.')[-1] == 'jpg'):
                continue
            path = os.path.join(args.images_path, image_doc['asin'] + '.jpg')
            if not os.path.isfile(path):
                continue
            try:
                img = image.imread(path, 1)
                tensor = transform(img, None)[0]
                # Stacking the single tensor adds the leading batch axis; the
                # batch is moved to the context the network was created on.
                batch = nd.stack(tensor).as_in_context(ctx)
                output = net(batch).asnumpy().squeeze()
                image_doc['image'] = output.tolist()
                image_doc['sum'] = output.sum().item()
            except Exception as exc:
                print('Unable to process {} - {}'.format(path, exc))
                continue
            # The yield sits outside the try block on purpose: when the
            # consumer stops early, GeneratorExit is raised here and must
            # propagate instead of being reported as a processing failure.
            yield {
                '_index': 'images',
                '_source': image_doc,
                '_id': image_doc['asin']
            }
def index_images(modulo):
    """Index this worker's image documents into Elasticsearch.

    Documents come from ``generate_feature(NUM_CPU, modulo)`` and are sent
    with ``parallel_bulk`` in chunks of 100 using 4 threads. The shared
    ``counter`` is advanced in steps of 100 successful documents, plus the
    remainder once the worker has consumed all of its documents.

    Args:
        modulo: Index of this worker.
    """
    cnt = 0
    es = Elasticsearch(hosts=[es_host], http_auth=(es_user, es_password), use_ssl=True,
                       verify_certs=True)
    for success, info in parallel_bulk(
        es,
        generate_feature(NUM_CPU, modulo),
        thread_count=4,
        chunk_size=100,
        # Report failed documents through (success, info) instead of raising,
        # otherwise the 'Doc failed' branch below can never run.
        raise_on_error=False
    ):
        if not success:
            print('Doc failed', info)
            continue
        cnt += 1
        if cnt % 100 == 0:
            with counter.get_lock():
                if counter.value >= num_images:
                    break
                counter.value += 100
                if counter.value % report_every == 0:
                    print('Processed %s files' % counter.value)
    # Account for the last, partial batch of fewer than 100 documents; it is
    # zero when the loop above stopped on a batch boundary.
    remainder = cnt % 100
    if remainder:
        with counter.get_lock():
            counter.value += remainder
            print('Processed %s files' % counter.value)
# Entry point. The guard keeps worker processes, which re-import this module
# under the "spawn" start method, from starting pools of their own.
if __name__ == '__main__':
    # Shared integer counting the documents indexed by all workers.
    counter = Value('i', 0)
    # The context manager releases the worker processes once map() returned.
    with Pool(processes=NUM_CPU, initializer=init, initargs=(counter,)) as pool:
        results = pool.map(index_images, range(NUM_CPU))
Example query
{
"query": {
"script_score": {
"query": {
"bool": {
"must_not": [
{
"terms": {
"asin": [
"0831769122"
]
}
}
]
}
},
"script": {
"source": "cosineSimilarity(params.query_vector, doc[\"image\"]) + 1.0",
"params": {
"query_vector": [
0,
0,
1.540963053703308,
0.4241706132888794,
0.20890426635742188,
1.6118087768554688,
0.835717499256134,
0.03161628916859627,
0.4947201907634735,
0.026025529950857162,
0.13415248692035675,
0.9309512972831726,
0.662664532661438,
0.04011522978544235,
0.02185402624309063,
0.09731259942054749,
2.428490161895752,
2.644136428833008,
0.1803264319896698,
0.1870569884777069,
0.19294625520706177,
0,
0.09192559123039246,
0.9319376349449158,
1.5758663415908813,
0.2644614279270172,
0.010325467213988304,
0.5630889534950256,
0.6410951614379883,
0,
0.44744524359703064,
0.26871803402900696,
0.18495744466781616,
0.062309637665748596,
0.041979312896728516,
1.1112490892410278,
1.6697245836257935,
0.40624502301216125,
0.33760592341423035,
1.5447176694869995,
0,
1.024849534034729,
1.4908475875854492,
0.0687083750963211,
0.14537346363067627,
1.7808849811553955,
0.25620320439338684,
1.481380581855774,
0.01321045309305191,
0.7392084002494812,
0.7974129319190979,
0.009287364780902863,
0.03224838152527809,
0.3176394999027252,
0.9530961513519287,
0.029858354479074478,
0.0326579287648201,
1.0348066091537476,
0.0887528508901596,
0.548323392868042,
0.09831473976373672,
0.1980070173740387,
0.4379303455352783,
0.26791760325431824,
0.3927217125892639,
0,
1.7046030759811401,
1.6474181413650513,
0.7436933517456055,
0.4659063518047333,
0,
0.02134300395846367,
0.18011386692523956,
1.6089071035385132,
0.20326927304267883,
1.0782569646835327,
0.20381948351860046,
0.028437014669179916,
0,
0.26501229405403137,
0.2346981167793274,
0.5833343267440796,
0.15589071810245514,
1.1219643354415894,
1.4959543943405151,
0.526347279548645,
0.44229018688201904,
1.7484400272369385,
0,
0.001924035488627851,
0.4323648512363434,
0,
1.7247337102890015,
0.391903281211853,
0.9206516742706299,
0.2868269085884094,
0.20423349738121033,
1.0454078912734985,
5.650066375732422,
0.27723973989486694,
1.406110405921936,
0.1751912534236908,
0.27466142177581787,
0.9901338219642639,
0.36860889196395874,
0,
0.7706685662269592,
0.02179446630179882,
0.46216899156570435,
0.17825061082839966,
0.234298437833786,
1.1009020805358887,
0.09758861362934113,
0.18467773497104645,
1.3925981521606445,
0.06726653128862381,
0.26682764291763306,
0.5038740038871765,
0.8131492137908936,
1.554304599761963,
1.377051830291748,
0.7241430878639221,
0.9470089077949524,
1.4233512878417969,
0.03635390102863312,
0.2702191174030304,
0.24359944462776184,
0,
0.017247596755623817,
0.02519248239696026,
0.3088386654853821,
0.2786160111427307,
0.11465568840503693,
0.022849654778838158,
0.16551275551319122,
0.222029447555542,
1.4945518970489502,
1.0799797773361206,
0.2706669867038727,
0.7009544372558594,
0.0349082425236702,
3.9951984882354736,
0.6507768034934998,
0.7254361510276794,
0.5814064741134644,
0.35163232684135437,
0.06341693550348282,
0.3466331660747528,
0.5005024671554565,
0.013652285560965538,
0.7812651991844177,
0.05506034195423126,
0.6775268912315369,
0.6484507918357849,
0.577008068561554,
0.057117167860269547,
0.27949878573417664,
0.11301720142364502,
0.10693597793579102,
0.5189146995544434,
0.060935214161872864,
1.452130913734436,
0.46898379921913147,
0.13435477018356323,
1.4151155948638916,
0.17511719465255737,
0.045097485184669495,
0.14557816088199615,
0.5012498497962952,
0.286577433347702,
0.33198991417884827,
1.6132311820983887,
1.3566486835479736,
0.004973182920366526,
2.032489061355591,
0.5351644158363342,
0.030874157324433327,
0.8231183886528015,
0.18968677520751953,
0.5130855441093445,
1.2617086172103882,
0.7005868554115295,
0.6212654709815979,
0.0056333523243665695,
0.0019881627522408962,
1.2713096141815186,
0.37036728858947754,
1.669109582901001,
0,
1.08043372631073,
1.1668858528137207,
0.47758758068084717,
0,
0.3623262345790863,
0.22681497037410736,
0.4438260793685913,
0.1676025539636612,
1.3582526445388794,
0.4758371412754059,
0.41051623225212097,
0.41929686069488525,
0.7219406962394714,
0,
0.7498679161071777,
0.023717718198895454,
0.5757825970649719,
0.23130734264850616,
0.00329641904681921,
0.3130616545677185,
0.7579660415649414,
1.1737068891525269,
1.347490906715393,
1.0342113971710205,
0.38839516043663025,
0.17522811889648438,
0.16314129531383514,
0.2899016737937927,
0,
0.6980988383293152,
1.9403613805770874,
0.15018992125988007,
0.0060835424810647964,
0.12880390882492065,
0.16019460558891296,
0.02858414314687252,
1.841732144355774,
0.032862383872270584,
0.32266977429389954,
0,
0.0024477543774992228,
0.03105393797159195,
0.034291695803403854,
0.13315369188785553,
0.06081081181764603,
0.15225747227668762,
0.4152592420578003,
0.11449909210205078,
0.3351048529148102,
0.1491839587688446,
6.732390403747559,
0.18835337460041046,
0.21516548097133636,
0.014231836423277855,
0.5534474849700928,
0.6808311939239502,
0.00991903617978096,
1.1592527627944946,
0.5601569414138794,
0.9500327706336975,
0.0021685557439923286,
0.00243535079061985,
1.0119181871414185,
0.1519746482372284,
0.6072421073913574,
0.9699671864509583,
0.0021363752894103527,
0.21147915720939636,
0.14629802107810974,
0.4021129310131073,
1.0465312004089355,
0.4508935809135437,
0.000870842719450593,
0.00837160274386406,
0.3423082232475281,
0.6309516429901123,
0.2268044650554657,
0.6758613586425781,
0.06232094392180443,
1.7576279640197754,
1.3241137266159058,
0.8758181929588318,
1.5016175508499146,
0.0023167720064520836,
0.0677211806178093,
0.19002526998519897,
0.20503443479537964,
0.09659034013748169,
0.29156309366226196,
0,
0.14233161509037018,
0.202122300863266,
0.2515697777271271,
0.011867965571582317,
0.34800541400909424,
4.7250895500183105,
0.17702384293079376,
0,
0.025476960465312004,
0.6033687591552734,
0.14017435908317566,
0.08375546336174011,
0.3303719758987427,
0.04780208319425583,
0.13803647458553314,
1.6540513038635254,
0.1258217990398407,
0.013266726396977901,
0.8861592411994934,
0.45917490124702454,
0.1663738191127777,
1.5280729532241821,
2.2861108779907227,
1.0124136209487915,
0.003586244536563754,
0.2468062788248062,
1.2693372964859009,
1.5456465482711792,
0.21515671908855438,
0.012422777712345123,
1.3525168895721436,
0.1990843564271927,
1.4386425018310547,
0.8619017601013184,
0.03702832758426666,
0.08740604668855667,
0.46414855122566223,
0.5585793852806091,
1.3815761804580688,
0.5560616254806519,
0.08551550656557083,
0.9624657034873962,
1.7001765966415405,
1.0407840013504028,
0,
0.04150662571191788,
0.7548277378082275,
0.3859689235687256,
1.1470662355422974,
0.37704119086265564,
0.023213719949126244,
0.032876066863536835,
0.04660353809595108,
0,
0.12669360637664795,
2.080057382583618,
0.004261183552443981,
0.5405608415603638,
0.01517476700246334,
0.9162319898605347,
1.3057160377502441,
0.008116023615002632,
0.16229289770126343,
0.031498294323682785,
0.2667829990386963,
1.6943459510803223,
0.01712818071246147,
0.9776731133460999,
0.11202004551887512,
0.9707496166229248,
1.1572446823120117,
0.6737523078918457,
0.029537802562117577,
0.7023208737373352,
0.32520416378974915,
0.2182481586933136,
0.03016565926373005,
1.197949767112732,
0.7400628924369812,
0.49941956996917725,
0.3027303218841553,
0.06849633902311325,
1.3882217407226562,
0.12536801397800446,
1.3533254861831665,
0.22874519228935242,
0.0997890830039978,
0.6829720735549927,
0.6396713256835938,
0.038349978625774384,
0.27607297897338867,
0.0864943414926529,
3.758359909057617,
0.16314372420310974,
1.1048176288604736,
0.23269084095954895,
0.18359535932540894,
0,
0.02708718366920948,
0.030186550691723824,
0.3120630979537964,
1.6162033081054688,
1.5419354438781738,
0.22100111842155457,
1.6052680015563965,
0.8577791452407837,
1.1382622718811035,
0.8976325392723083,
0.884781539440155,
0.5179292559623718,
0.012464682571589947,
0.07565125077962875,
0.20929360389709473,
0,
0.4538927972316742,
0.1732257753610611,
0.14859557151794434,
0.005749203264713287,
0.09260174632072449,
0.5429206490516663,
0.0031830084044486284,
0.28554490208625793,
0.5171560645103455,
0,
0,
0.03407849371433258,
0.02496541291475296,
0.24954873323440552,
1.5160486698150635,
1.2317144870758057,
0.9520125389099121,
1.7308859825134277,
0.13504870235919952,
0.27686750888824463,
0.8378607630729675,
0.39962339401245117,
0.5448357462882996,
0,
0.16238054633140564,
0.27044668793678284,
0.008835949935019016,
0.989281952381134,
0.1033429279923439,
0.16202236711978912,
0.2865678668022156,
0.33581987023353577,
0.10500624775886536,
2.3277668952941895,
0.2940566837787628,
0.0969436764717102,
0.49415504932403564,
0,
0.4888608455657959,
0.2216196209192276,
1.2575459480285645,
0.4807666540145874,
0.10767889767885208,
0.006442478857934475,
0.6898377537727356,
0.047032199800014496,
0.23261453211307526,
0.5975551009178162,
0.6295201182365417,
0.013620519079267979,
0.18975409865379333,
0.011193688027560711,
0.23144879937171936,
1.199135661125183,
0.013570062816143036,
0.19454814493656158,
0.37536388635635376,
0.020356152206659317,
1.1322823762893677,
0.01645093783736229,
0.02044965699315071,
2.033820390701294,
0.9797329306602478,
0.040736425668001175,
0.11488523334264755,
0.45224854350090027,
0.16105255484580994,
0,
0,
1.8725911378860474,
0.2818228304386139,
0.8766666054725647,
0.050560180097818375,
0.0024117210414260626,
0.07966896891593933,
2.299119234085083,
0.15698544681072235,
0.2935243546962738,
0.28591927886009216,
0.10818587243556976,
0.2946636974811554,
0.26486584544181824,
0.03206818923354149,
0.00015118109877221286,
0.008543320000171661,
3.5924277305603027,
0.9469864368438721,
0.46832993626594543,
0,
0.06227591633796692,
0.027538003399968147,
0.47737210988998413,
0,
0.011262851767241955,
0.009047379717230797,
0.27162283658981323,
0.0899040549993515,
0,
0.6813708543777466,
0.37537461519241333,
0,
0.06601111590862274,
0.14864102005958557,
0.3240663707256317,
0.02522111125290394,
0.09855339676141739,
0.49846401810646057,
0.22989211976528168,
0.008297656662762165,
1.890407919883728,
0,
1.8116880655288696,
0.26815447211265564,
1.5780898332595825,
0,
1.2144482135772705,
0.6554135084152222,
0.46947067975997925,
0.8747156858444214
]
}
}
}
},
"size": 6,
"_source": [
"asin",
"title",
"imUrl",
"categories"
]
}
Create query time feature vector
import io
from mxnet import gluon
import mxnet as mx
from mxnet.image import image
import mxnet as mx
from mxnet import gluon, nd
# Target size used when resizing and center-cropping an image.
SIZE = (224, 224)
# Per-channel mean and standard deviation passed to color_normalize —
# presumably the statistics the exported model was trained with; confirm.
MEAN_IMAGE = nd.array([0.485, 0.456, 0.406])
STD_IMAGE = nd.array([0.229, 0.224, 0.225])
class ImageService:
    """Compute query-time feature vectors with an exported MXNet model."""

    # Image modes that can be written as JPEG without conversion; anything
    # else (e.g. RGBA or palette images) is converted to RGB before saving.
    _JPEG_MODES = ('1', 'L', 'RGB', 'RGBX', 'CMYK', 'YCbCr')

    def __init__(self, model_arch, model_params):
        """Load the model on a GPU when one is listed, otherwise on the CPU.

        Args:
            model_arch: Symbol file passed to ``SymbolBlock.imports``.
            model_params: Parameter file passed to ``SymbolBlock.imports``.
        """
        # Kept so that inputs can be moved to the same context as the model.
        self._ctx = mx.gpu() if len(mx.test_utils.list_gpus()) else mx.cpu()
        self._model = gluon.nn.SymbolBlock.imports(
            model_arch, ['data'], model_params, ctx=self._ctx)

    def _transform(self, image):
        """Resize, center-crop, scale to [0, 1], normalize and reorder axes."""
        resized = mx.image.resize_short(image, SIZE[0]).astype('float32')
        cropped, _ = mx.image.center_crop(resized, SIZE)
        cropped /= 255.
        normalized = mx.image.color_normalize(cropped,
                                              mean=MEAN_IMAGE,
                                              std=STD_IMAGE)
        return nd.transpose(normalized, (2, 0, 1))

    def create_feature(self, bytes):
        """Return the model output for an encoded image as a list.

        Args:
            bytes: Encoded image data accepted by ``image.imdecode``.
        """
        image_np = image.imdecode(bytes)
        # Only the first three channels are used.
        image_t = self._transform(nd.array(image_np[:, :, :3]))
        # Stacking the single tensor adds the leading batch axis; the batch is
        # moved to the context the model parameters were loaded on.
        batch = nd.stack(image_t).as_in_context(self._ctx)
        vector = self._model(batch).asnumpy().squeeze()
        return vector.tolist()

    def resize(self, image, width, height):
        """Scale an image, keeping its aspect ratio, and return JPEG bytes.

        When the image is narrower than tall it is scaled to ``width``;
        otherwise it is scaled to ``height``. The other side follows the
        aspect ratio and is never smaller than one pixel.

        Args:
            image: Object exposing ``size``, ``resize`` and ``save`` —
                presumably a PIL image; verify against the caller.
            width: Target width for images narrower than tall.
            height: Target height for all other images.

        Returns:
            The resized image encoded as JPEG.
        """
        src_w, src_h = image.size[0], image.size[1]
        if src_w < src_h:
            scale = width / float(src_w)
            img = image.resize((width, max(1, int(float(src_h) * scale))))
        else:
            scale = height / float(src_h)
            img = image.resize((max(1, int(float(src_w) * scale)), height))
        # Saving e.g. an RGBA image as JPEG raises, so convert it first.
        if getattr(img, 'mode', 'RGB') not in self._JPEG_MODES:
            img = img.convert('RGB')
        byteIO = io.BytesIO()
        img.save(byteIO, format='JPEG')
        return byteIO.getvalue()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment