# Source: GitHub gist zhihou7/d3c7563a1b8244bf3914ed0853c7ea04
# Author: @zhihou7 -- created January 18, 2022 03:10
import os
import json
# Load the HICO-DET trainval annotations. The `with` block closes the file
# handle as soon as parsing finishes instead of leaving it to the garbage
# collector.
with open("data/hico_20160224_det/annotations/trainval_hico.json") as train_file:
    train = json.load(train_file)
import sys
# Zero-shot split to build. It selects which HOI category indices are held
# out as "unseen". The default is 3; when this file is run as a script, the
# first command-line argument (an integer) overrides it. The argument is
# ignored when the file is imported, so an importer's own argv is never parsed.
zero_shot_type = 3
if __name__ == "__main__" and len(sys.argv) > 1:
    zero_shot_type = int(sys.argv[1])

# Unseen HOI indices for every supported split, keyed by zero_shot_type.
# The values are compared against `hoi_category_id - 1` when annotations are
# filtered, so they are frozensets to keep each membership test O(1).
UNSEEN_HOI_INDICES = {
    3: frozenset({
        509, 279, 280, 402, 504, 286, 499, 498, 289, 485, 303, 311, 325, 439, 351, 358, 66, 427, 379, 418, 70, 416,
        389, 90, 395, 76, 397, 84, 135, 262, 401, 592, 560, 586, 548, 593, 526, 181, 257, 539, 535, 260, 596, 345, 189,
        205, 206, 429, 179, 350, 405, 522, 449, 261, 255, 546, 547, 44, 22, 334, 599, 239, 315, 317, 229, 158, 195,
        238, 364, 222, 281, 149, 399, 83, 127, 254, 398, 403, 555, 552, 520, 531, 440, 436, 482, 274, 8, 188, 216, 597,
        77, 407, 556, 469, 474, 107, 390, 410, 27, 381, 463, 99, 184, 100, 292, 517, 80, 333, 62, 354, 104, 55, 50,
        198, 168, 391, 192, 595, 136, 581,
    }),
    4: frozenset({
        38, 41, 20, 18, 245, 11, 19, 154, 459, 42, 155, 139, 60, 461, 577, 153, 582, 89, 141, 576, 75, 212, 472, 61,
        457, 146, 208, 94, 471, 131, 248, 544, 515, 566, 370, 481, 226, 250, 470, 323, 169, 480, 479, 230, 385, 73,
        159, 190, 377, 176, 249, 371, 284, 48, 583, 53, 162, 140, 185, 106, 294, 56, 320, 152, 374, 338, 29, 594, 346,
        456, 589, 45, 23, 67, 478, 223, 493, 228, 240, 215, 91, 115, 337, 559, 7, 218, 518, 297, 191, 266, 304, 6, 572,
        529, 312, 9, 308, 417, 197, 193, 163, 455, 25, 54, 575, 446, 387, 483, 534, 340, 508, 110, 329, 246, 173, 506,
        383, 93, 516, 64,
    }),
    # Original author's note for this split (meaning not verifiable from this
    # file; presumably the verbs left without any seen HOI -- TODO confirm):
    # miss [5, 6, 28, 56, 88] verbs: 006 break, 007 brush_with, 029 flip,
    # 057 move, 089 slide
    11: frozenset([
        111, 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125,
        126, 127, 128, 224, 225, 226, 227, 228, 229, 230, 231, 290, 291, 292, 293,
        294, 313, 314, 315, 316, 317, 318, 319, 320, 321, 322, 323, 324, 336, 337,
        338, 339, 340, 341, 418, 419, 420, 421, 422, 423, 424, 425, 426, 427, 428,
        429, 430, 431, 432, 433, 453, 454, 455, 456, 457, 458, 459, 460, 461, 462,
        463, 464, 465, 466, 467, 468, 469, 470, 471, 472, 473, 533, 534, 535, 536,
        537, 558, 559, 560, 561, 595, 596, 597, 598, 599,
    ]),
    # Original note: "24 rare merge of zs3 & zs4". The literal below is the
    # first 24 entries of split 3 followed by the first 96 entries of split 4.
    7: frozenset([
        509, 279, 280, 402, 504, 286, 499, 498, 289, 485, 303, 311, 325, 439, 351, 358, 66, 427, 379, 418, 70, 416, 389,
        90, 38, 41, 20, 18, 245, 11, 19, 154, 459, 42, 155, 139, 60, 461, 577, 153, 582, 89, 141, 576, 75, 212, 472, 61,
        457, 146, 208, 94, 471, 131, 248, 544, 515, 566, 370, 481, 226, 250, 470, 323, 169, 480, 479, 230, 385, 73, 159,
        190, 377, 176, 249, 371, 284, 48, 583, 53, 162, 140, 185, 106, 294, 56, 320, 152, 374, 338, 29, 594, 346, 456, 589,
        45, 23, 67, 478, 223, 493, 228, 240, 215, 91, 115, 337, 559, 7, 218, 518, 297, 191, 266, 304, 6, 572, 529, 312,
        9,
    ]),
}

unseen_idx = UNSEEN_HOI_INDICES.get(zero_shot_type)
if unseen_idx is None:
    # An unrecognised split keeps the previous behaviour (nothing is treated
    # as unseen, so no annotation gets filtered) but is no longer silent.
    print(
        "warning: unknown zero_shot_type {}; no HOI category will be filtered "
        "(known types: {})".format(zero_shot_type, sorted(UNSEEN_HOI_INDICES)),
        file=sys.stderr,
    )
    unseen_idx = frozenset()
# Remove every annotation whose HOI category is held out as unseen, drop
# images that end up with no annotation at all, and write the reduced
# training set next to the original annotation file.
ntrain = []              # images that keep at least one seen annotation
stat = {}                # hoi_category_id -> number of kept annotations
pre_anno_length = 0      # annotation count before filtering
after_anno_length = 0    # annotation count after filtering
for item in train:
    annotations = item['hoi_annotation']
    pre_anno_length += len(annotations)

    # hoi_category_id is shifted by one before the lookup because the
    # unseen indices are compared against `hoi_category_id - 1`.
    kept = [
        anno for anno in annotations
        if anno['hoi_category_id'] - 1 not in unseen_idx
    ]
    for anno in kept:
        category_id = anno['hoi_category_id']
        stat[category_id] = stat.get(category_id, 0) + 1

    # The image record is updated in place, as before, even when it is dropped.
    item['hoi_annotation'] = kept
    after_anno_length += len(kept)
    if kept:
        ntrain.append(item)

print(pre_anno_length, after_anno_length)

# The `with` block guarantees the output is flushed and closed once written.
output_path = "data/hico_20160224_det/annotations/trainval_hico_zs{}.json".format(zero_shot_type)
with open(output_path, 'w') as output_file:
    json.dump(ntrain, output_file)
# Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment