Skip to content

Instantly share code, notes, and snippets.

@YimianDai
Created August 3, 2019 18:48
Show Gist options
  • Save YimianDai/c44c344bfe369f38ffa97792ea968a38 to your computer and use it in GitHub Desktop.
Save YimianDai/c44c344bfe369f38ffa97792ea968a38 to your computer and use it in GitHub Desktop.
Count Object Scale
from data import IceSegmentation
from model import PhaseFourierTransform
import numpy as np
import scipy as sp
import matplotlib.pyplot as plt
import platform, os
from skimage import measure, color
from mxnet import nd

#######################################################
################## Hyper-Parameter ####################
#######################################################
# NOTE(review): none of these hyper-parameters, show_flag, or data_root are
# referenced elsewhere in this script as visible here (the dataset below is
# built without a root argument); confirm whether they are still needed.
sal_size = 64
orig_size = 512
std_coef = 12.5
kernel_size = 5

show_flag = False

# Per-OS dataset root. os.path.join does not expand '~', so it is expanded
# explicitly; otherwise the path would refer to a directory literally named '~'.
_os_name = platform.system()
if _os_name == "Darwin":
    data_root = os.path.expanduser(
        os.path.join('~', 'Nutstore Files', 'Dataset', 'Iceberg'))
elif _os_name == "Linux":
    data_root = os.path.expanduser(os.path.join('~', 'datasets', 'Iceberg'))
else:
    raise ValueError(
        f"Unsupported platform {_os_name!r}: set data_root for this system")


def _region_scales(label_mask):
    """Return the scale of every connected foreground region in *label_mask*.

    A region's scale is the shorter side, in pixels, of its bounding box.
    Regions are found with ``skimage.measure.label`` treating 0 as background.

    Args:
        label_mask: 2-D numpy array; non-zero pixels are foreground.

    Returns:
        list[int]: one scale (>= 1) per connected region.
    """
    labels = measure.label(label_mask, background=0)
    scales = []
    for region in measure.regionprops(labels):
        # regionprops bboxes are half-open: rows [min_row, max_row) and
        # cols [min_col, max_col), so the extent is max - min with no +1.
        min_row, min_col, max_row, max_col = region.bbox
        scales.append(min(max_row - min_row, max_col - min_col))
    return scales


def _scale_histogram(scales, min_bins=512):
    """Count objects per scale: entry k is the number of objects of scale k + 1.

    The result has at least *min_bins* entries and grows to fit the largest
    scale, so an unexpectedly large object cannot index past the array.

    Args:
        scales: iterable of positive integer scales.
        min_bins: minimum length of the returned histogram.

    Returns:
        numpy.ndarray of float64 counts.
    """
    counts = np.bincount(np.asarray(scales, dtype=np.intp),
                         minlength=min_bins + 1)
    # Bin 0 is dropped: every region has at least one pixel, so scale >= 1.
    return counts[1:].astype(np.float64)


train_dataset = IceSegmentation(split="trainvaltest", mode='testval', base_size=512, crop_size=512)

# NOTE(review): 'testval' mode may return masks at their original resolution
# rather than 512x512; _scale_histogram grows past 512 bins if needed.
all_scales = []
for sample in train_dataset:
    # sample[1] is the mask that gets split into connected regions; it is
    # converted from an mxnet NDArray to numpy for scikit-image.
    all_scales.extend(_region_scales(sample[1].asnumpy()))

hist_arr = _scale_histogram(all_scales)
print(hist_arr)

# np.save appends the '.npy' suffix, writing hist_arr.npy to the current
# working directory.
np.save("hist_arr", hist_arr)
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment