Skip to content

Instantly share code, notes, and snippets.

Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save therealnaveenkamal/789231912084fe15df61d8b2c611c162 to your computer and use it in GitHub Desktop.
Save therealnaveenkamal/789231912084fe15df61d8b2c611c162 to your computer and use it in GitHub Desktop.
This code manually splits the dataset into train, validation, and test sets in a 20:4:1 ratio (10,000 : 2,000 : 500 images).
# Load the master table.
# NOTE(review): relies on pandas being available as `pd` from outside this snippet.
dataframe = pd.read_csv("Data_Entry_2017_v2020.csv")

# Gather every distinct finding label in first-seen order. An insertion-ordered
# dict does the de-duplication; "Image" is seeded first so it stays the leading
# column.
ordered = {"Image": None}
for findings in dataframe["Finding Labels"].values:
    for name in findings.split("|"):
        ordered.setdefault(name, None)
columns = list(ordered)

# The label names are every column except the leading "Image" one.
labels = columns[1:]
# Taking the first 10000 images from the master table as the train dataset.
# Each output row is the image file name followed by one 0/1 flag per label
# column, in the order given by `columns`.
# NOTE(review): the slice is positional, so a table with fewer rows gives a
# shorter split instead of raising.
train_slice = dataframe.iloc[0:10000]
train_rows = []
for image_name, findings in zip(train_slice["Image Index"], train_slice["Finding Labels"]):
    # Match against the individual "|"-separated labels rather than the raw
    # string, so one label can never match as a substring of another.
    present = set(findings.split("|"))
    train_rows.append([image_name] + [int(name in present) for name in columns[1:]])
# Build the frame once from the collected rows; this is much cheaper than
# enlarging the frame one row at a time with .loc.
trainset = pd.DataFrame(train_rows, columns=columns)
# Taking the next 2000 images (rows 10000-11999) from the master table as the
# validation dataset.
# Each output row is the image file name followed by one 0/1 flag per label
# column, in the order given by `columns`.
# NOTE(review): the slice is positional, so a table with fewer rows gives a
# shorter split instead of raising.
val_slice = dataframe.iloc[10000:12000]
val_rows = []
for image_name, findings in zip(val_slice["Image Index"], val_slice["Finding Labels"]):
    # Match against the individual "|"-separated labels rather than the raw
    # string, so one label can never match as a substring of another.
    present = set(findings.split("|"))
    val_rows.append([image_name] + [int(name in present) for name in columns[1:]])
# Build the frame once from the collected rows; this is much cheaper than
# enlarging the frame one row at a time with .loc.
valset = pd.DataFrame(val_rows, columns=columns)
# Taking 500 images (rows 15000-15499) from the master table as the test dataset.
# NOTE(review): rows 12000-14999 are skipped, so this is not literally the
# "next" 500 images after the validation split - confirm the gap is intended.
# Each output row is the image file name followed by one 0/1 flag per label
# column, in the order given by `columns`.
# NOTE(review): the slice is positional, so a table with fewer rows gives a
# shorter split instead of raising.
test_slice = dataframe.iloc[15000:15500]
test_rows = []
for image_name, findings in zip(test_slice["Image Index"], test_slice["Finding Labels"]):
    # Match against the individual "|"-separated labels rather than the raw
    # string, so one label can never match as a substring of another.
    present = set(findings.split("|"))
    test_rows.append([image_name] + [int(name in present) for name in columns[1:]])
# Build the frame once from the collected rows; this is much cheaper than
# enlarging the frame one row at a time with .loc.
testset = pd.DataFrame(test_rows, columns=columns)
# Show the first 16 training images in a 4x4 grid, each titled with the value
# from the second column of its row in the master table.
# NOTE(review): the second column is presumably "Finding Labels" - confirm
# against the CSV header. Relies on `plt` and `os` from outside this snippet.
img_dir = "images"
plt.figure(figsize=(15, 15))
for i in range(16):
    image_name = trainset["Image"][i]
    image_path = os.path.join(img_dir, image_name)
    # Look the image up in the master table by file name.
    master_rows = dataframe[dataframe["Image Index"] == image_name]

    plt.subplot(4, 4, i + 1)
    plt.imshow(plt.imread(image_path), cmap="gray")
    plt.title(master_rows.values[0][1])
plt.tight_layout()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment