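#
# Linear SVM baseline on Open Images features (gist by @innerlee, 2016).
# Trains a LIBLINEAR model on 2048-d "global_pool" features and reports
# top-1 / top-5 accuracy and tag recall on a held-out 50% split.
#
# usage: julia <this_script>.jl [train_num] [remove_frequent]
#   ARGS[1] (optional): number of training samples, default = the whole training half
#   ARGS[2] (optional): quantile for dropping overly frequent labels (0, 0.9, or 0.99)
#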
using HDF5
using JLD
using LIBLINEAR
println("> svm on open image")
# read features (2048,165659)
features = h5read("data/grand5_feature.h5", "global_pool")[1, 1, :, :]
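# the stored global_pool blob is presumably 1×1×2048×N; the [1, 1, :, :] slice
# drops the singleton spatial dims, leaving a 2048×N feature matrix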
ids = readdlm("data/redis_val_list_with_id.txt")[:, 1]
# "000026e7ee790996" "/m/01cbzq" 1.0
raw_labels = readdlm("data/labels.csv", ',')[2:end,[1,3,4]]
label_dict = Dict()
label_dict_multiple = Dict()
# remove overly frequent labels, thresholded by quantile
bad = []
remove_frequent = length(ARGS) < 2 ? 0 : parse(Float64, ARGS[2]) #0, 0.9, 0.99
if remove_frequent > 0
    bad = Set(readdlm("data/bad$remove_frequent.txt")[:])
end
println("remove frequent: ", remove_frequent)
for i = 1:size(raw_labels, 1)
    (id, label, score) = raw_labels[i, :]
    label in bad && continue
    if !haskey(label_dict, id)
        label_dict[id] = score > 0 ? label : "unknown"
        label_dict_multiple[id] = score > 0 ? [label] : []
    elseif score > 0
        label_dict[id] = label
        push!(label_dict_multiple[id], label)
    end
end
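# label_dict:          image id => single label (the last positive one seen, else "unknown")
# label_dict_multiple: image id => list of all positive labels for that image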
# build ground truth; images with no surviving positive label are marked "unknown"
# remove_frequent = 0:    2881 ids end up "unknown"
# remove_frequent = 0.99: 9565 ids end up "unknown"
gt = map(x->get(label_dict, x, "unknown"), ids)
gt_m = map(x->get(label_dict_multiple, x, "unknown"), ids)
fil = gt .!= "unknown"
# 2048×165623
features = features[:, fil]
# 165623
gt = gt[fil]
gt_m = gt_m[fil]
# shuffle them
perm = shuffle(1:length(gt))
features = features[:, perm]
gt = gt[perm]
gt_m = gt_m[perm]
# split train/test 50/50
num_train = floor(Int, length(gt) / 2)
train_features = features[:, 1:num_train]
train_gt = gt[1:num_train]
test_features = features[:, num_train+1:end]
test_gt = gt[num_train+1:end]
test_gt_m = gt_m[num_train+1:end]
## parse args
length(ARGS) == 0 && push!(ARGS, "$(length(train_gt))")
train_num = eval(parse("$(ARGS[1])"))
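# ARGS[1] is evaluated as a Julia expression, so "80000" and "8 * 10^4" both work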
train_features = train_features[:, 1:min(end, train_num)]
train_gt = train_gt[1:min(end, train_num)]
println("train size: ", size(train_features))
println("test size: ", size(test_features))
## train & test
tic()
m = linear_train(train_gt, train_features)
(pred, scores) = linear_predict(m, test_features)
toc()
top5list = [m.labels[sortperm(scores[:, i], rev=true)[1:5]] for i=1:size(scores, 2)]
println("top-1 accuracy: ", mean(pred[:] .== test_gt[:]))
println("top-5 accuracy: ", mean(in.(test_gt, top5list)))
println("top-1 tag recall: ", mean(in.(pred, test_gt_m)))
println("top-5 tag recall: ", mean(length(setdiff(test_gt_m[i], top5list[i])) != length(test_gt_m[i]) for i=1:length(test_gt_m)))
## results
# remove frequent: 0
# train size: (2048,80000)
# test size: (2048,81389)
# elapsed time: 7735.106875664 seconds
# top-1 accuracy: 0.49107373232254975
# top-5 accuracy: 0.7078966445097004
# top-1 tag recall: 0.6163609332956542
# top-5 tag recall: 0.8360835002273035
# remove frequent: 0.99
# train size: (2048,78047)
# test size: (2048,78047)
# elapsed time: 10830.214292915 seconds
# top-1 accuracy: 0.3900854613245865
# top-5 accuracy: 0.6232270298666188
# top-1 tag recall: 0.5067843735185209
# top-5 tag recall: 0.7632195984470895
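# removing the most frequent labels (0.99 quantile) lowers every metric by roughly 7-11 points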