Skip to content

Instantly share code, notes, and snippets.

@hanwinbi
Created September 1, 2020 09:33
Show Gist options
  • Save hanwinbi/c94d05014a79648ffa7cbdba6a53976b to your computer and use it in GitHub Desktop.
Save hanwinbi/c94d05014a79648ffa7cbdba6a53976b to your computer and use it in GitHub Desktop.
KITS19测试代码
import os
# Shared configuration constants (other modules read them as `config.<name>`).

# Number of consecutive slices stacked into one network input volume.
depth = 48
# Root directory of the case folders (each case has GT/ and Images/ subfolders).
# This variant includes the data-augmented slices; previously used alternatives:
# '/datasets/KITS2020/TEST/' and '/datasets/users/bihanwen/temp/'.
# Must end with '/' because case paths are built by string concatenation.
data_root = '/datasets/users/bihanwen/data20/kits/'
# Checkpoint (.pth) directory: training saves into it and resumes from it.
model_path = '/datasets/users/bihanwen/model/pth_48slice/'
# NOTE(review): not created or referenced by the code visible here.
result_path = '/datasets/users/bihanwen/result/'
# Absolute folder (with trailing '/') for dataset.json and the split JSON files.
json_path = os.path.abspath('./data/no_aug_48slice') + '/'
# TensorBoard log directory.
log_path = './log/log_48slice/'
def gene_dir(_dir):
    """Create directory ``_dir`` (and any missing parents) if it does not exist.

    ``exist_ok=True`` removes the race in the previous check-then-create
    sequence (``isdir`` followed by ``makedirs``). A non-directory already
    at ``_dir`` still raises ``FileExistsError``.
    """
    os.makedirs(_dir, exist_ok=True)
# Make sure every output directory exists before anything writes into it.
for _required_dir in (model_path, json_path, log_path):
    gene_dir(_required_dir)
del _required_dir  # keep the module namespace unchanged
print(json_path, data_root)
# 用于配置测试的路径
# rela_path = '/datasets/users/bihanwen/temp'
# abs_path = os.path.abspath(rela_path) + '/'
# abs_data_root = os.path.abspath(data_root) + '/'
# print(rela_path, abs_path)
import os
import config
import json
import cv2 as cv
from collections import OrderedDict
# Absolute dataset root and the folder that receives dataset.json.
data_path = config.data_root
output_json_folder = config.json_path
rela_path = output_json_folder  # same value as config.json_path; not referenced in this module
# Summary dictionary that dataset_info() fills and writes to dataset.json.
json_dict = OrderedDict()
def dataset_info(path):
    """Scan every case folder under ``path`` and write ``dataset.json``.

    For each case directory ``<path><case>/``:
    - list ``GT/``;
    - count the non-augmented slices (file names starting with ``'0'``);
    - read the first GT slice to record its image shape;
    - append a summary entry to the module-level ``json_dict``.

    The collected summary, including per-dataset averages, is written to
    ``<output_json_folder>/dataset.json``.

    Args:
        path: dataset root. It must end with a path separator because case
            paths are built by string concatenation.

    Raises:
        ValueError: if ``path`` contains no case folders.
        FileNotFoundError: if a case's GT folder is empty or its first
            slice cannot be decoded by OpenCV.
    """
    cases = sorted(os.listdir(path))  # deterministic case order
    print('cases:', cases)
    if not cases:
        # Without this guard the averages below would divide by zero.
        raise ValueError('no cases found under {!r}'.format(path))
    json_dict['case num'] = len(cases)
    json_dict['case'] = list()
    size_sum = 0  # sum of first-slice heights over all cases (averaged at the end)
    slice_sum = 0  # sum of non-augmented slice counts over all cases
    for case in cases:
        gt_dir = path + case + '/GT/'
        images_dir = path + case + '/Images/'
        slice_names = sorted(os.listdir(gt_dir))
        total_image_num = len(slice_names)
        print('slice name', slice_names)
        if not slice_names:
            raise FileNotFoundError('no GT slices in {!r}'.format(gt_dir))
        # Slices whose name starts with '0' are the original, non-augmented ones.
        count = sum(1 for name in slice_names if name.startswith('0'))
        dirfile = gt_dir + slice_names[0]  # one slice of the case gives its image size
        print('dirfile:', dirfile)
        img = cv.imread(dirfile)
        if img is None:
            # cv.imread signals failure by returning None instead of raising.
            raise FileNotFoundError('cannot read image {!r}'.format(dirfile))
        size = img.shape
        print(size)
        size_sum += size[0]
        slice_sum += count
        print("sum_size:{0},sum_slice:{1}".format(size_sum, slice_sum))
        case_entry = {'GT': gt_dir, "Images": images_dir, "Total Image Num": total_image_num,
                      "Slice num": count, "Img Size": size}
        json_dict['case'].append(case_entry)
        print(case)
    json_dict['Average pic size'] = size_sum / len(cases)
    json_dict['Average slice num'] = slice_sum / len(cases)
    with open(os.path.join(output_json_folder, "dataset.json"), 'w') as f:
        json.dump(json_dict, f, indent=4, sort_keys=True)
# Build dataset.json as soon as this module is executed or imported.
dataset_info(path=data_path)
import os
import cv2
import json
import torch
import random
import config
from collections import OrderedDict
# Folder that holds dataset.json and receives the split files written by randomDiv().
json_dir = config.json_path
# Filled by randomDiv() with the 'train case' / 'test case' / 'validation case' lists.
json_dict = OrderedDict()
def randomDiv(seeds_path, size):
    """Randomly split the cases listed in ``dataset.json`` into train/validation/test.

    Args:
        seeds_path: path of the ``dataset.json`` produced by ``dataset_info``.
        size: ``(train_fraction, validation_fraction, test_fraction)``. Only
            the first two values are used; the test set is whatever remains.

    Returns:
        ``(train_json_path, validation_json_path, test_json_path)``. This
        matches the order the caller unpacks; the previous version returned
        the test path in place of the validation path.

    Side effects:
        Fills the module-level ``json_dict`` and writes ``trainData.json``,
        ``testData.json`` and ``validationData.json`` into ``json_dir``.
    """
    with open(seeds_path, 'r') as load_f:
        load_dict = json.load(load_f)
    seeds = load_dict['case']
    print(seeds)
    lenofsets = len(seeds)
    trainsize = int(size[0] * lenofsets)
    validationsize = int(size[1] * lenofsets)
    # Random indices into `seeds` are drawn for each split.
    idx = list(range(lenofsets))
    trainDataset = random.sample(idx, trainsize)
    # random.sample() rejects sets on Python >= 3.11 (deprecated since 3.9),
    # so the remainder is turned into a sorted list first.
    restDataset = sorted(set(idx) - set(trainDataset))
    validationDataset = random.sample(restDataset, validationsize)
    testDataset = sorted(set(restDataset) - set(validationDataset))
    json_dict['train case'] = get_slice_include_aug(trainDataset, seeds)
    json_dict['test case'] = get_origin_slice(testDataset, seeds)
    json_dict['validation case'] = get_origin_slice(validationDataset, seeds)
    trainData_path = os.path.join(json_dir, 'trainData.json')
    testData_path = os.path.join(json_dir, 'testData.json')
    validationData_path = os.path.join(json_dir, 'validationData.json')
    for out_path, key in ((trainData_path, 'train case'),
                          (testData_path, 'test case'),
                          (validationData_path, 'validation case')):
        with open(out_path, 'w') as f:
            json.dump(json_dict[key], f, indent=4)
    print('train data path', trainData_path)
    return trainData_path, validationData_path, testData_path
def get_slice_include_aug(random_seed, seeds, depth=None):
    """Collect start-slice paths of training windows, including augmented copies.

    File names in a case's ``GT`` folder are assumed to sort into
    ``loop_time`` contiguous groups of ``slice_num`` files each: the original
    slices first, then each augmented copy. (The middle-window indexing
    ``start_pos + slice_num * i`` already relied on this layout.)

    For every group, three windows of ``depth`` consecutive slices are taken:
    the first window, the last window and a centred window. The path of
    each window's first slice is recorded.

    Args:
        random_seed: iterable of indices into ``seeds`` selected for training.
        seeds: case entries from ``dataset.json`` (keys ``'GT'``,
            ``'Slice num'``, ``'Total Image Num'``).
        depth: window length; defaults to ``config.depth``.

    Returns:
        Shuffled list of start-slice paths.
    """
    if depth is None:
        depth = config.depth
    case_list = list()
    for idx in random_seed:
        case = seeds[idx]
        file_list = sorted(os.listdir(case['GT']))
        slice_num = case['Slice num']  # slices per group; differs between cases
        # Number of groups (original + augmented copies), computed per case
        # instead of assuming every case matches seeds[0].
        loop_time = int(case['Total Image Num'] / slice_num)
        start_pos = int((slice_num - depth) / 2)  # offset of the centred window
        for i in range(loop_time):
            group_start = slice_num * i
            # Previously the first/last windows ignored `i` (and "first" used
            # index 1), so the same two windows were repeated for every group.
            for offset in (0, slice_num - depth, start_pos):
                case_list.append(case['GT'] + file_list[group_start + offset])
    random.shuffle(case_list)  # mix windows of different cases
    return case_list
def get_origin_slice(random_seed, seeds, depth=None):
    """Collect start-slice paths of evaluation windows from original slices only.

    For each selected case, three windows of ``depth`` consecutive slices are
    taken from the non-augmented group (the first ``'Slice num'`` sorted
    files): a centred window, the first window and the last window.

    Args:
        random_seed: iterable of indices into ``seeds`` (test or validation).
        seeds: case entries from ``dataset.json`` (keys ``'GT'``, ``'Slice num'``).
        depth: window length; defaults to ``config.depth``.

    Returns:
        Shuffled list of start-slice paths.
    """
    if depth is None:
        depth = config.depth
    case_list = list()
    for idx in random_seed:
        case = seeds[idx]
        file_list = sorted(os.listdir(case['GT']))
        slice_num = case['Slice num']
        start_pos = int((slice_num - depth) / 2)  # offset of the centred window
        # The last window must end at the last *original* slice. Indexing from
        # the end of file_list (file_list[-depth]) landed in the final
        # augmented copy whenever the case had augmented slices.
        for offset in (start_pos, 0, slice_num - depth):
            case_list.append(case['GT'] + file_list[offset])
    random.shuffle(case_list)
    return case_list
# Split the cases recorded in dataset.json 60/20/20 and write the split files.
trainData, validationData, testData = randomDiv(seeds_path=json_dir + 'dataset.json',
                                                size=(0.6, 0.2, 0.2))
import torch
from nnunet.utilities.nd_softmax import softmax_helper
from nnunet.utilities.tensor_utilities import sum_tensor
from torch import nn
from utils import make_one_hot_3d
class SoftDiceLoss(nn.Module):
    """Soft Dice loss for dense multi-class segmentation.

    Args:
        smooth: additive smoothing term in the denominator.
        apply_nonlin: optional callable applied to the raw network output
            (e.g. a softmax) before the Dice terms are computed.
        batch_dice: if True, Dice terms are accumulated over the whole batch
            (``soft_dice_per_batch_2``); otherwise per sample (``soft_dice``).
        do_bg: if False, channel 0 (background) is dropped before the loss.
        smooth_in_nom: if True, ``smooth`` is also added to the numerator.
        background_weight: weight of the first remaining channel in batch
            mode. It must stay 1 when ``do_bg`` is False.
        rebalance_weights: optional per-channel numpy weights applied to the
            batch-mode true-positive / false-negative terms.
        square_nominator: square the numerator term.
        square_denom: square the denominator terms.
            Both square flags are now forwarded in both modes; batch mode
            used to ignore them.
    """

    def __init__(self, smooth=1., apply_nonlin=None, batch_dice=True, do_bg=False, smooth_in_nom=True,
                 background_weight=1, rebalance_weights=None, square_nominator=False, square_denom=False):
        super(SoftDiceLoss, self).__init__()
        self.square_denom = square_denom
        self.square_nominator = square_nominator
        if not do_bg:
            assert background_weight == 1, "if there is no bg, then set background weight to 1 you dummy"
        self.rebalance_weights = rebalance_weights
        self.background_weight = background_weight
        # Stored as the numeric value added to the numerator (0 disables it).
        self.smooth_in_nom = smooth if smooth_in_nom else 0
        self.do_bg = do_bg
        self.batch_dice = batch_dice
        self.apply_nonlin = apply_nonlin
        self.smooth = smooth
        self.y_onehot = None  # kept for attribute compatibility; forward() does not use it

    def forward(self, x, y):
        """Return the Dice loss of network output ``x`` against label map ``y``.

        ``x`` has shape (B, C, *spatial). ``y`` holds integer class indices,
        either (B, *spatial) or (B, 1, *spatial).
        """
        with torch.no_grad():
            y = y.long()
        shp_x = x.shape
        shp_y = y.shape
        if self.apply_nonlin is not None:
            x = self.apply_nonlin(x)
        if len(shp_x) != len(shp_y):
            # Add the singleton channel axis that scatter_ needs.
            y = y.view((shp_y[0], 1, *shp_y[1:]))
        # One-hot encode the labels directly on x's device.
        y_onehot = torch.zeros(shp_x, device=x.device)
        y_onehot.scatter_(1, y, 1)
        if not self.do_bg:
            # Drop the background channel from both prediction and target.
            x = x[:, 1:]
            y_onehot = y_onehot[:, 1:]
        if not self.batch_dice:
            if self.background_weight != 1 or (self.rebalance_weights is not None):
                raise NotImplementedError("nah son")
            l = soft_dice(x, y_onehot, self.smooth, self.smooth_in_nom, self.square_nominator, self.square_denom)
        else:
            l = soft_dice_per_batch_2(x, y_onehot, self.smooth, self.smooth_in_nom,
                                      background_weight=self.background_weight,
                                      rebalance_weights=self.rebalance_weights,
                                      square_nominator=self.square_nominator,
                                      square_denom=self.square_denom)
        return l
def soft_dice_per_batch_2(net_output, gt, smooth=1., smooth_in_nom=1., background_weight=1, rebalance_weights=None,
                          square_nominator=False, square_denom=False, class_weights=(0.4, 0.6), gamma=0.3):
    """Batch-pooled exponential-logarithmic Dice loss.

    Per channel, tp / fn / fp are summed over the batch and all spatial axes:
        dice_c = weight_c * (2 * tp + smooth_in_nom) / (2 * tp + fp + fn + smooth)
    The loss is ``sum_k class_weights[k] * (-log dice_k) ** gamma`` over the
    first ``len(class_weights)`` channels. The defaults (0.4, 0.6) and 0.3
    reproduce the previously hard-coded two-channel weighting.

    Args:
        net_output: probabilities, shape (B, C, *spatial).
        gt: one-hot target with the same shape as ``net_output``.
        smooth: added to the denominator.
        smooth_in_nom: added to the numerator.
        background_weight: multiplier for channel 0's Dice.
        rebalance_weights: optional numpy per-channel weights for tp and fn.
            If it has one entry too many, the background entry is dropped.
        square_nominator: square the numerator tp.
        square_denom: square tp, fp and fn in the denominator.
        class_weights: weight of each channel's log-Dice term.
        gamma: exponent applied to ``-log(dice)``.

    Returns:
        0-d loss tensor.

    Raises:
        ValueError: if there are fewer channels than ``class_weights``.
    """
    if rebalance_weights is not None and len(rebalance_weights) != gt.shape[1]:
        rebalance_weights = rebalance_weights[1:]  # this is the case when use_bg=False
    # Reduce over batch (0) and every spatial axis, keeping only the channel axis.
    axes = tuple([0] + list(range(2, net_output.dim())))
    tp = (net_output * gt).sum(dim=axes)
    fn = ((1 - net_output) * gt).sum(dim=axes)
    fp = (net_output * (1 - gt)).sum(dim=axes)
    weights = torch.ones(tp.shape, device=tp.device)
    weights[0] = background_weight
    if rebalance_weights is not None:
        rebalance_weights = torch.from_numpy(rebalance_weights).float().to(tp.device)
        tp = tp * rebalance_weights
        fn = fn * rebalance_weights
    nominator = tp
    if square_nominator:
        nominator = nominator ** 2
    if square_denom:
        denom = 2 * tp ** 2 + fp ** 2 + fn ** 2
    else:
        denom = 2 * tp + fp + fn
    dice = ((2 * nominator + smooth_in_nom) / (denom + smooth)) * weights
    if dice.shape[0] < len(class_weights):
        raise ValueError('expected at least {} channels, got {}'.format(len(class_weights), dice.shape[0]))
    # Exponential-logarithmic Dice: (-log dice)^gamma, weighted per channel.
    return sum(torch.pow(-torch.log(dice[c]), gamma) * w for c, w in enumerate(class_weights))
def soft_dice(net_output, gt, smooth=1., smooth_in_nom=1., square_nominator=False, square_denom=False):
    """Per-sample, per-channel soft Dice, returned negated and averaged.

    Sums run over the spatial axes only, so every (sample, channel) pair gets
    its own Dice score; the result is ``-mean(dice)``.

    Args:
        net_output: probabilities, shape (B, C, *spatial).
        gt: one-hot target of the same shape.
        smooth: added to the denominator.
        smooth_in_nom: added to the numerator.
        square_nominator: if True, the intersection is ``sum((p * g) ** 2)``.
            The two branches were previously swapped, so the default
            (False) squared the intersection.
        square_denom: if True, the denominator is ``sum(p ** 2 + g ** 2)``.

    Returns:
        0-d tensor: the negated mean Dice.
    """
    axes = tuple(range(2, net_output.dim()))

    def _reduce(t):
        # Sum over the spatial axes. With no spatial axes there is nothing to
        # reduce; torch's sum(dim=()) may instead reduce every axis,
        # depending on the version.
        return t.sum(dim=axes) if axes else t

    if square_nominator:
        intersect = _reduce((net_output * gt) ** 2)
    else:
        intersect = _reduce(net_output * gt)
    if square_denom:
        denom = _reduce(net_output ** 2 + gt ** 2)
    else:
        denom = _reduce(net_output + gt)
    return (- ((2 * intersect + smooth_in_nom) / (denom + smooth))).mean()
class DC_and_CE_loss(nn.Module):
    """Sum of the soft Dice loss and a class-weighted cross-entropy.

    Args:
        aggregate: how to combine the two terms; only ``"sum"`` is implemented.
        mssu: if True, ``net_output`` is a sequence of outputs (multi-scale
            supervision). Each output is compared against the same
            ``target`` and the results are combined with ``mssu_weights``.
        ce_weights: per-class cross-entropy weights.
        mssu_weights: weights of the successive outputs in ``mssu`` mode. The
            defaults reproduce the previous hard-coded five-output weighting.
    """

    def __init__(self, aggregate="sum", mssu=False, ce_weights=(0.28, 0.28, 0.44),
                 mssu_weights=(0.4, 0.2, 0.1, 0.1, 0.1)):
        super(DC_and_CE_loss, self).__init__()
        self.aggregate = aggregate
        # Unweighted CE kept for attribute compatibility; forward() uses the weighted one.
        self.ce = CrossentropyND()
        self.dc = SoftDiceLoss(apply_nonlin=softmax_helper)
        self.mssu = mssu
        self.ce_weights = tuple(ce_weights)
        self.mssu_weights = tuple(mssu_weights)

    def _weighted_ce(self, device):
        """Build the class-weighted cross-entropy with its weights on ``device``.

        The weights used to be placed on ``torch.cuda.current_device()``,
        which fails on CPU-only hosts and mismatches outputs on other GPUs.
        """
        return CrossentropyND(weight=torch.tensor(self.ce_weights, device=device))

    def forward(self, net_output, target):
        """Return Dice + weighted CE of ``net_output`` (raw logits) against ``target``.

        Raises:
            NotImplementedError: for an ``aggregate`` other than ``"sum"``.
            ValueError: in ``mssu`` mode, if there are fewer outputs than ``mssu_weights``.
        """
        if self.aggregate != "sum":
            raise NotImplementedError("nah son")  # reserved for other stuff (later)
        if self.mssu:
            outputs = list(net_output)
            if len(outputs) < len(self.mssu_weights):
                raise ValueError('expected at least {} outputs, got {}'.format(len(self.mssu_weights), len(outputs)))
            ce = self._weighted_ce(outputs[0].device)
            dc_loss = 0
            ce_loss = 0
            # Outputs beyond len(mssu_weights) carry no weight and are skipped.
            for weight, out in zip(self.mssu_weights, outputs):
                dc_loss = dc_loss + self.dc(out, target) * weight
                ce_loss = ce_loss + ce(out, target) * weight
        else:
            ce = self._weighted_ce(net_output.device)
            dc_loss = self.dc(net_output, target)
            ce_loss = ce(net_output, target)
        return ce_loss + dc_loss
class CrossentropyND(torch.nn.CrossEntropyLoss):
    """Cross-entropy for N-dimensional dense predictions.

    The network must output raw logits (NO nonlinearity), shape (B, C, *spatial).
    The channel axis is moved last and everything is flattened to a
    (positions, C) logits matrix and a (positions,) label vector before
    delegating to ``torch.nn.CrossEntropyLoss``.
    """

    def forward(self, inp, target):
        labels = target.long()
        n_classes = inp.size()[1]
        # Move the channel axis to the end while keeping the other axes in
        # order; this is the same layout a chain of adjacent transposes gives.
        order = [0] + list(range(2, inp.dim())) + [1]
        channels_last = inp.permute(*order).contiguous()
        logits = channels_last.view(-1, n_classes)
        return super(CrossentropyND, self).forward(logits, labels.view(-1,))
import torch
import torch.nn as nn
from torchsummary import summary
def Conv_IN_LeRU_2s(in_dim, out_dim, kernel_size, stride, padding, activation):
    """Conv3d -> InstanceNorm3d -> activation -> Conv3d -> InstanceNorm3d.

    The second convolution maps ``out_dim`` to ``out_dim`` with the same
    kernel, stride and padding. No activation follows the final
    normalisation; in UNetStage1 it is applied after the residual addition.
    """
    layers = [
        nn.Conv3d(in_dim, out_dim, kernel_size, stride, padding),
        nn.InstanceNorm3d(out_dim),
        activation,
    ]
    layers += [nn.Conv3d(out_dim, out_dim, kernel_size, stride, padding), nn.InstanceNorm3d(out_dim)]
    return nn.Sequential(*layers)
def stride_conv(in_dim, out_dim, kernel_size, stride, padding):
    """Strided Conv3d used for downsampling, wrapped in a one-layer Sequential.

    The wrapper keeps state-dict keys of the form ``<attr>.0.weight`` /
    ``<attr>.0.bias``.
    """
    down = nn.Conv3d(in_dim, out_dim, kernel_size, stride, padding)
    return nn.Sequential(down)
def ResNet(raw, processed):
    """Residual connection: element-wise sum of a block's input and its output."""
    summed = torch.add(input=raw, other=processed)
    return summed
def conv_trans(in_dim, out_dim, kernel_size, stride, padding):
    """Transposed Conv3d used for upsampling in the decoder.

    Output size per axis: (in - 1) * stride - 2 * padding + kernel_size.
    """
    return nn.ConvTranspose3d(in_channels=in_dim, out_channels=out_dim,
                              kernel_size=kernel_size, stride=stride, padding=padding)
def de_conv_in_relu_2s(in_dim, out_dim, kernel_size, stride, padding, activation):
    """Decoder block: Conv3d -> InstanceNorm3d -> activation -> 1x1x1 Conv3d."""
    first = nn.Conv3d(in_dim, out_dim, kernel_size, stride, padding)
    norm = nn.InstanceNorm3d(out_dim)
    pointwise = nn.Conv3d(out_dim, out_dim, kernel_size=(1, 1, 1), stride=1)
    return nn.Sequential(first, norm, activation, pointwise)
def tri_inter(input, size, mode):
    """Resize ``input`` to ``size`` with ``nn.functional.interpolate`` using ``mode``."""
    resized = nn.functional.interpolate(input, size=size, mode=mode)
    return resized
class UNetStage1(nn.Module):
    """Residual 3D U-Net with six encoder stages and five decoder stages.

    Input shape is (N, 1, D, H, W). The output is raw 3-class scores of
    shape (N, 3, D, H, W), produced by ``self.end``.

    ``stride_conv1`` downsamples only H and W (stride (1, 2, 2)). The four
    later strided convolutions halve all three axes, so D must be divisible
    by 16 and H, W by 32 for every skip connection to line up
    (e.g. 48 x 160 x 160).

    ``res1``..``res4`` are built, and therefore appear in state dicts, but
    are not used by ``forward``. They belong to a disabled multi-scale output.
    """

    def __init__(self):
        super(UNetStage1, self).__init__()
        # Initial 3x3x3 convolution: 1 input channel -> 30 feature channels.
        self.init = nn.Conv3d(1, 30, 3, 1, 1)
        # Stage 1 (full resolution).
        self.encoder1 = Conv_IN_LeRU_2s(30, 30, 3, 1, 1, nn.LeakyReLU())
        self.encoder1_1 = nn.LeakyReLU()  # applied after the residual sum
        # Stage 2: in-plane downsampling only.
        # Padding arithmetic: https://pytorch.org/docs/master/generated/torch.nn.Conv3d.html
        self.stride_conv1 = stride_conv(30, 60, (3, 3, 3), (1, 2, 2), 1)
        self.encoder2 = Conv_IN_LeRU_2s(60, 60, 3, 1, 1, nn.LeakyReLU())
        self.encoder2_1 = nn.LeakyReLU()
        # Stage 3.
        self.stride_conv2 = stride_conv(60, 120, 3, 2, 1)
        self.encoder3 = Conv_IN_LeRU_2s(120, 120, 3, 1, 1, nn.LeakyReLU())
        self.encoder3_1 = nn.LeakyReLU()
        # Stage 4.
        self.stride_conv3 = stride_conv(120, 240, 3, 2, 1)
        self.encoder4 = Conv_IN_LeRU_2s(240, 240, 3, 1, 1, nn.LeakyReLU())
        self.encoder4_1 = nn.LeakyReLU()
        # Stage 5.
        self.stride_conv4 = stride_conv(240, 480, 3, 2, 1)
        self.encoder5 = Conv_IN_LeRU_2s(480, 480, 3, 1, 1, nn.LeakyReLU())
        self.encoder5_1 = nn.LeakyReLU()
        # Stage 6 (bottleneck).
        self.stride_conv5 = stride_conv(480, 960, 3, 2, 1)
        self.encoder6 = Conv_IN_LeRU_2s(960, 960, 3, 1, 1, nn.LeakyReLU())
        self.encoder6_1 = nn.LeakyReLU()

        # Decoder. Transposed-conv output size per axis:
        # out = (in - 1) * stride - 2 * padding + kernel_size
        # https://pytorch.org/docs/master/generated/torch.nn.ConvTranspose3d.html
        self.decoder1 = conv_trans(960, 480, kernel_size=2, stride=2, padding=0)
        self.decoder1_1 = nn.Conv3d(960, 480, 1, 1, 0)  # halve channels after the skip concat
        self.decoder1_2 = de_conv_in_relu_2s(480, 480, 3, 1, 1, nn.LeakyReLU())
        self.decoder1_3 = nn.LeakyReLU()
        self.res1 = nn.Conv3d(480, 3, 3, 1, 1)

        self.decoder2 = conv_trans(480, 240, kernel_size=2, stride=2, padding=0)
        self.decoder2_1 = nn.Conv3d(480, 240, 1, 1, 0)
        self.decoder2_2 = de_conv_in_relu_2s(240, 240, 3, 1, 1, nn.LeakyReLU())
        self.decoder2_3 = nn.LeakyReLU()
        self.res2 = nn.Conv3d(240, 3, 3, 1, 1)

        self.decoder3 = conv_trans(240, 120, kernel_size=2, stride=2, padding=0)
        self.decoder3_1 = nn.Conv3d(240, 120, 1, 1, 0)
        self.decoder3_2 = de_conv_in_relu_2s(120, 120, 3, 1, 1, nn.LeakyReLU())
        self.decoder3_3 = nn.LeakyReLU()
        self.res3 = nn.Conv3d(120, 3, 3, 1, 1)

        self.decoder4 = conv_trans(120, 60, kernel_size=2, stride=2, padding=0)
        self.decoder4_1 = nn.Conv3d(120, 60, 1, 1, 0)
        self.decoder4_2 = de_conv_in_relu_2s(60, 60, 3, 1, 1, nn.LeakyReLU())
        self.decoder4_3 = nn.LeakyReLU()
        self.res4 = nn.Conv3d(60, 3, 3, 1, 1)

        # Last upsampling is in-plane only, mirroring stride_conv1.
        self.decoder5 = conv_trans(60, 30, kernel_size=(1, 2, 2), stride=(1, 2, 2), padding=0)
        self.decoder5_1 = nn.Conv3d(60, 30, 1, 1, 0)
        self.decoder5_2 = de_conv_in_relu_2s(30, 30, 3, 1, 1, nn.LeakyReLU())
        self.decoder5_3 = nn.LeakyReLU()
        self.end = nn.Conv3d(30, 3, 3, 1, 1)  # final 3-class head

    def get_features(self, x):
        """Run the encoder; return ``(sync6, sync5, sync4, sync3, sync2, sync1)``.

        Each stage is: [strided conv] -> Conv/IN/LeakyReLU/Conv/IN ->
        residual add with the stage input -> LeakyReLU. Strided convs use a
        3x3x3 kernel and padding 1, with stride (1, 2, 2) for the first and
        2 otherwise. Shapes in comments assume a (N, 1, 48, 160, 160) input.
        """
        enc0 = self.init(x)  # (N, 30, 48, 160, 160)
        enc1 = self.encoder1(enc0)
        res1 = ResNet(enc0, enc1)
        sync1 = self.encoder1_1(res1)

        enc2 = self.stride_conv1(sync1)  # (N, 60, 48, 80, 80)
        enc2_1 = self.encoder2(enc2)
        res2 = ResNet(enc2, enc2_1)
        sync2 = self.encoder2_1(res2)

        enc3 = self.stride_conv2(sync2)  # (N, 120, 24, 40, 40)
        enc3_1 = self.encoder3(enc3)
        res3 = ResNet(enc3, enc3_1)
        sync3 = self.encoder3_1(res3)

        enc4 = self.stride_conv3(sync3)  # (N, 240, 12, 20, 20)
        enc4_1 = self.encoder4(enc4)
        res4 = ResNet(enc4, enc4_1)
        sync4 = self.encoder4_1(res4)

        enc5 = self.stride_conv4(sync4)  # (N, 480, 6, 10, 10)
        enc5_1 = self.encoder5(enc5)
        res5 = ResNet(enc5, enc5_1)
        sync5 = self.encoder5_1(res5)

        enc6 = self.stride_conv5(sync5)  # (N, 960, 3, 5, 5)
        enc6_1 = self.encoder6(enc6)
        res6 = ResNet(enc6, enc6_1)
        sync6 = self.encoder6_1(res6)
        return sync6, sync5, sync4, sync3, sync2, sync1

    def upSample(self, enc):
        """Run the decoder on the features returned by ``get_features``.

        Each stage:
        - transposed conv (x2 upsampling; in-plane only for the last stage);
        - concat with the matching encoder feature;
        - 1x1x1 conv halving the channels;
        - de_conv_in_relu_2s;
        - residual add with the 1x1x1 conv output;
        - LeakyReLU.

        Returns the output of ``self.end``.
        """
        dec1 = self.decoder1(enc[0])  # (N, 480, 6, 10, 10)
        skip_con1 = torch.cat((enc[1], dec1), dim=1)  # (N, 960, 6, 10, 10)
        dec1_1 = self.decoder1_1(skip_con1)  # (N, 480, 6, 10, 10)
        dec1_2 = self.decoder1_2(dec1_1)
        resnet1 = self.decoder1_3(ResNet(dec1_1, dec1_2))

        dec2 = self.decoder2(resnet1)  # (N, 240, 12, 20, 20)
        skip_con2 = torch.cat((enc[2], dec2), dim=1)  # (N, 480, 12, 20, 20)
        dec2_1 = self.decoder2_1(skip_con2)  # (N, 240, 12, 20, 20)
        dec2_2 = self.decoder2_2(dec2_1)
        resnet2 = self.decoder2_3(ResNet(dec2_1, dec2_2))

        # Stages 3-5 previously computed ResNet(decN_2, decN_2), i.e. 2 * decN_2
        # with no residual path. They now add the stage input, like stages 1-2
        # and the encoder.
        dec3 = self.decoder3(resnet2)  # (N, 120, 24, 40, 40)
        skip_con3 = torch.cat((enc[3], dec3), dim=1)  # (N, 240, 24, 40, 40)
        dec3_1 = self.decoder3_1(skip_con3)  # (N, 120, 24, 40, 40)
        dec3_2 = self.decoder3_2(dec3_1)
        resnet3 = self.decoder3_3(ResNet(dec3_1, dec3_2))

        dec4 = self.decoder4(resnet3)  # (N, 60, 48, 80, 80)
        skip_con4 = torch.cat((enc[4], dec4), dim=1)  # (N, 120, 48, 80, 80)
        dec4_1 = self.decoder4_1(skip_con4)  # (N, 60, 48, 80, 80)
        dec4_2 = self.decoder4_2(dec4_1)
        resnet4 = self.decoder4_3(ResNet(dec4_1, dec4_2))

        dec5 = self.decoder5(resnet4)  # (N, 30, 48, 160, 160)
        skip_con5 = torch.cat((enc[5], dec5), dim=1)  # (N, 60, 48, 160, 160)
        dec5_1 = self.decoder5_1(skip_con5)  # (N, 30, 48, 160, 160)
        dec5_2 = self.decoder5_2(dec5_1)
        resnet5 = self.decoder5_3(ResNet(dec5_1, dec5_2))

        return self.end(resnet5)  # (N, 3, 48, 160, 160)

    def forward(self, x):
        enc = self.get_features(x)
        return self.upSample(enc)
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") # PyTorch
# model = UNetStage1().to(device)
# summary(model, input_size=(1, 32, 160, 160), batch_size=1)
import os
import cv2 as cv
import torch
import numpy as np
import config
import json
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
# Folder holding trainData.json / testData.json / validationData.json.
# NOTE(review): this module-level name shadows the builtin `dir`; it is kept
# unchanged because it is module state other code may reference.
dir = config.json_path
def load_data(path):
    """Load a split file (a JSON list of start-slice paths) and return it as a list.

    Args:
        path: path of a JSON file, e.g. one written by ``randomDiv``.

    Returns:
        A new list with the file's entries in their stored order.

    Raises:
        TypeError: if the JSON document is not a list.
    """
    with open(path, 'r', encoding='utf-8') as load_f:
        load_dict = json.load(load_f)
    if not isinstance(load_dict, list):
        # The previous index loop only worked for lists; fail with a clear message.
        raise TypeError('expected a JSON list in {!r}, got {}'.format(path, type(load_dict).__name__))
    return list(load_dict)
# Load the three splits at import time; the training script imports these names.
trainData, testData, validationData = (
    load_data(dir + split_file)
    for split_file in ('trainData.json', 'testData.json', 'validationData.json')
)
class DataSets(Dataset):
    """Dataset of fixed-depth slice windows for 3D segmentation.

    Each item of ``casedata`` is a start-slice path inside a case's ``GT``
    folder. ``config.depth`` consecutive slices are loaded from that folder
    and from the sibling ``Images`` folder, resized to 160x160 and stacked.

    ``__getitem__`` returns ``(imgs, gts)``:
    - ``imgs``: float32, shape (depth, 160, 160), after ``self.transform``;
    - ``gts``: int64 labels, shape (depth, 160, 160).
    """

    def __init__(self, casedata):
        # NOTE(review): pixel values reach Normalize in [0, 255] (no /255),
        # while this mean/std looks like [0, 1]-scale statistics. Confirm
        # whether scaling was intended; existing checkpoints used this input.
        self.transform = transforms.Compose(
            [transforms.Normalize(mean=(0.485,), std=(0.229,))]
        )
        self.GT = casedata  # list of start-slice paths inside GT folders
        self.slice = config.depth  # slices per sample

    def __len__(self):
        return len(self.GT)

    def __getitem__(self, idx):
        case_data = self.load_case_data(self.GT[idx])
        imgs = self.getOriginImage(case_data)
        gts = self.getGroundTruth(case_data)
        gts = torch.from_numpy(np.array(gts, dtype='int64'))
        imgs = torch.from_numpy(np.array(imgs, dtype='float32'))
        imgs = self.transform(imgs)
        return imgs, gts

    @staticmethod
    def _read_gray(path):
        """Read ``path`` as a single-channel image; raise if it cannot be read."""
        pic = cv.imread(path, 0)
        if pic is None:
            # cv.imread returns None instead of raising on a missing/corrupt file.
            raise FileNotFoundError('cannot read image {!r}'.format(path))
        return pic

    def getOriginImage(self, casedata):
        """Load the input slices matching the GT paths in ``casedata``."""
        imgs = []
        for i in range(config.depth):
            # Swap only the GT directory component for Images. A plain
            # str.replace('GT', ...) would also rewrite any other 'GT' in the path.
            head, sep, tail = casedata[i].rpartition('/GT/')
            origin_image = head + '/Images/' + tail if sep else casedata[i].replace('GT', 'Images')
            imgs.append(cv.resize(self._read_gray(origin_image), (160, 160)))
        return imgs

    def getGroundTruth(self, casedata):
        """Load the label slices in ``casedata`` and map grey levels to classes via /127."""
        gts = []
        for i in range(config.depth):
            gt = self._read_gray(casedata[i])
            # Nearest-neighbour keeps label values intact. The default bilinear
            # interpolation blended neighbouring labels at region borders,
            # creating intermediate grey levels that map to a wrong class below.
            gt = cv.resize(gt, (160, 160), interpolation=cv.INTER_NEAREST)
            gt = gt / 127  # presumably 0/127/254 grey levels -> classes 0/1/2; confirm
            gts.append(gt)
        return gts

    def load_case_data(self, case):
        """Return paths of ``config.depth`` consecutive slices starting at ``case``.

        The four characters before the 4-character extension of ``case`` are
        parsed as the slice number. The following slices reuse the same prefix
        with the number incremented and a ``'.bmp'`` extension.
        """
        start = int(case[-8:-4])  # index of the first slice
        prefix = case[0:-8]
        return [prefix + str("%04d" % (start + i)) + '.bmp' for i in range(config.depth)]
import time
import NetModel
from utils import *
import torch.optim as optim
from torch.autograd import Variable
from torch.optim import lr_scheduler
from torch.utils.data import DataLoader
from Loss import DC_and_CE_loss
from torch.cuda.amp import GradScaler, autocast
from PrepareData import DataSets
from PrepareData import trainData, validationData, testData
from torch.utils.tensorboard import SummaryWriter
# Module-level training state shared by train() below.
# TensorBoard writer logging into config.log_path.
writer = SummaryWriter(config.log_path)
import warnings
# NOTE(review): this silences *all* warnings process-wide, including
# deprecation and numerical warnings. Consider narrowing the filter.
warnings.filterwarnings("ignore")
# ReduceLROnPlateau settings: absolute improvement threshold and patience.
# Patience is counted in scheduler steps, i.e. validation rounds.
lr_scheduler_eps = 1e-3
lr_scheduler_patience = 20
initial_lr = 3e-4  # Adam starting learning rate
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = NetModel.UNetStage1()
model = model.to(device)
# Soft Dice + class-weighted cross-entropy.
criterion_dc_ce = DC_and_CE_loss().to(device)
optimizer = optim.Adam(model.parameters(), initial_lr)
def _checkpoint_epoch(name):
    """Return the epoch encoded in a checkpoint file name ``<prefix>_<epoch>.pth``, or None."""
    stem, ext = os.path.splitext(name)
    if ext != '.pth':
        return None
    number = stem.rsplit('_', 1)[-1]
    return int(number) if number.isdigit() else None


def train(Epoches, mpth):
    """Train ``model`` for epochs ``[start_epoch, Epoches)``.

    If ``mpth`` holds checkpoints, training resumes from the one with the
    highest epoch number. On every epoch with ``epoch % 3 == 0`` it:
    - saves ``mpth + 'v10_NotAll_<epoch>.pth'``;
    - evaluates loss and kidney/tumour Dice on the validation split;
    - steps the plateau scheduler.

    Args:
        Epoches: total number of epochs (exclusive upper bound).
        mpth: checkpoint directory. It must end with a path separator because
            file paths are built by concatenation.
    """
    start_epoch = 0
    # Order by numeric epoch, not by file name: '%.2d' yields 'x_100.pth',
    # which sorts before 'x_99.pth', so the old sorted(names)[-1] resumed
    # from a stale checkpoint after epoch 99. Non-checkpoint files are ignored.
    pthList = sorted((name for name in os.listdir(mpth) if _checkpoint_epoch(name) is not None),
                     key=lambda name: (_checkpoint_epoch(name), name))
    print(pthList)
    if not pthList:
        print('starting train:')
    else:
        print('Continue training:')
        pth = pthList[-1]
        # map_location lets a GPU-saved checkpoint resume on a CPU-only host.
        checkPoint = torch.load(mpth + pth, map_location=device)
        model.load_state_dict(checkPoint['model'])
        optimizer.load_state_dict(checkPoint['optimizer'])
        start_epoch = checkPoint['epoch'] + 1
    scaler = GradScaler()
    # https://pytorch.org/docs/stable/optim.html#torch.optim.lr_scheduler.ReduceLROnPlateau
    lr_scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode='min', factor=0.5, patience=lr_scheduler_patience, verbose=True,
        threshold=lr_scheduler_eps, threshold_mode="abs")
    # The datasets only wrap path lists, so build them once rather than every epoch.
    train_set = DataSets(trainData)
    val_set = DataSets(validationData)
    for epoch in range(start_epoch, Epoches):
        print('------------------------')
        print('this is {} epoch.'.format(epoch))
        print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime(time.time())))
        model.train()
        oneCaseLoss = 0
        dataLoader = DataLoader(train_set, batch_size=1, shuffle=False, num_workers=0)
        start_time = time.time()
        for i, (x, yy) in enumerate(dataLoader):
            optimizer.zero_grad()
            x = x.to(device).unsqueeze(1)  # add the channel axis: (N, 1, D, H, W)
            y = yy.to(device)
            output = model(x)
            loss = criterion_dc_ce(output, y)  # soft Dice + weighted cross-entropy
            iterLoss = loss.item()
            print('loss[%d]:' % i, iterLoss)
            oneCaseLoss += iterLoss
            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()
        end_time = time.time()
        print('run time:{0},one case loss:{1}'.format(end_time - start_time, oneCaseLoss / len(trainData)))
        lr = optimizer.param_groups[0]['lr']
        writer.add_scalar('lr', lr, epoch)
        writer.add_scalar('loss/tr_loss', oneCaseLoss / len(trainData), epoch)
        if epoch % 3 != 0:
            continue
        # Save a checkpoint.
        savepPth = mpth + 'v10_NotAll_' + str('%.2d' % epoch) + '.pth'
        state = {'model': model.state_dict(), 'optimizer': optimizer.state_dict(), 'epoch': epoch}
        torch.save(state, savepPth)
        # Validation.
        model.eval()
        oneCaseValLoss = 0
        kidneyDice = 0
        tumorDice = 0
        dataLoader = DataLoader(val_set, batch_size=1, shuffle=False, num_workers=4)
        with torch.no_grad():
            for x, yy in dataLoader:
                x = x.to(device).unsqueeze(1)
                y = yy.to(device)
                output = model(x)
                loss = criterion_dc_ce(output, y)
                oneCaseValLoss += loss.item()
                output = torch.softmax(output, dim=1)  # probabilities for Dice
                dice_kidney, dice_tumor = Dice(y, output)
                kidneyDice += dice_kidney
                tumorDice += dice_tumor
        # NOTE: steps with the *summed* validation loss, so the absolute
        # threshold effectively scales with the validation-set size.
        lr_scheduler.step(oneCaseValLoss)
        print('********************************************************************')
        print('**** lr: {:.8f} ****'.format(lr))
        print('**** val loss: {:.8f} ****'.format(oneCaseValLoss / len(validationData)))
        print('**** kidneyDice: {:.8f} ****'.format(kidneyDice / len(validationData)))
        print('**** tumorDice: {:.8f} ****'.format(tumorDice / len(validationData)))
        writer.add_scalar('loss/val_loss', oneCaseValLoss / len(validationData), epoch)
        writer.add_scalar('Dice/kidney_dice', kidneyDice / len(validationData), epoch)
        writer.add_scalar('Dice/tumor_dice', tumorDice / len(validationData), epoch)
if __name__ == "__main__":
    modelPath = config.model_path
    print(time.strftime("%Y--%m--%d %H:%M:%S", time.localtime(time.time())))
    try:
        train(800, modelPath)
    finally:
        # Flush pending TensorBoard events even if training crashes or is
        # interrupted (e.g. Ctrl-C); close() flushes before closing.
        writer.close()
import numpy as np
import os
import json
import torch
import datetime
import config
# Default compute device for this module: CUDA when available, otherwise CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
def make_one_hot_3d(x, n):
    """One-hot encode an integer label tensor along a new channel axis 1.

    Args:
        x: labels of shape (N, *spatial), e.g. (N, D, H, W) for volumes. Any
            number of spatial axes is accepted (the old version required
            exactly three). Values are cast to int64 because ``scatter_``
            needs integer indices.
        n: number of classes; every label must lie in ``[0, n)``.

    Returns:
        float32 tensor of shape (N, n, *spatial) on ``x``'s device.
    """
    x = x.long().unsqueeze(1)
    shape = list(x.shape)
    shape[1] = n
    return torch.zeros(shape, device=x.device).scatter_(1, x, 1)
def expand_as_one_hot(input, C, ignore_index=None):
    """
    Converts NxDxHxW label image to NxCxDxHxW, where each label gets converted to its corresponding one-hot vector
    :param input: 4D input image (NxDxHxW), integer labels
    :param C: number of channels/labels
    :param ignore_index: ignore index to be kept during the expansion; every channel at such a
        voxel is set to ignore_index
    :return: 5D output image (NxCxDxHxW), float32, on the same device as input
    """
    assert input.dim() == 4
    # Expand the input tensor to Nx1xDxHxW before scattering.
    input = input.unsqueeze(1)
    # Result tensor shape (NxCxDxHxW).
    shape = list(input.size())
    shape[1] = C
    if ignore_index is not None:
        # Mask of voxels carrying ignore_index, broadcast over all channels.
        mask = input.expand(shape) == ignore_index
        # Clone so the caller's tensor is untouched, then zero out ignore_index
        # so it is a valid scatter index.
        input = input.clone()
        input[input == ignore_index] = 0
        result = torch.zeros(shape, device=input.device).scatter_(1, input, 1)
        # Bring back the ignore_index in the result.
        result[mask] = ignore_index
        return result
    # Allocate on input's device; the previous CPU-only torch.zeros(shape)
    # made scatter_ fail for CUDA inputs in this branch.
    return torch.zeros(shape, device=input.device).scatter_(1, input, 1)
def dice_coeff(pred, target):
    """Soft Dice between two numpy arrays, pooled over every element.

    Computes (2 * sum(pred * target) + 1) / (sum(pred) + sum(target) + 1)
    and returns it as a 0-d torch tensor. The first axis is used only to
    flatten the inputs, which must be viewable as (size(0), -1).
    """
    smooth = 1.
    pred_t, target_t = torch.from_numpy(pred), torch.from_numpy(target)
    batch = pred_t.size(0)
    flat_pred = pred_t.view(batch, -1)
    flat_target = target_t.view(batch, -1)
    overlap = torch.sum(flat_pred * flat_target)
    numerator = 2. * overlap + smooth
    denominator = torch.sum(flat_pred) + torch.sum(flat_target) + smooth
    return numerator / denominator
def Dice(y, output):
    """Return ``(kidney Dice, tumour Dice)`` of a prediction against integer labels.

    Args:
        y: integer labels of shape (N, D, H, W); class 1 is kidney and
            class 2 is tumour.
        output: per-class probabilities of shape (N, 3, D, H, W).

    Returns:
        Tuple of two 0-d tensors ``(dice_kidney, dice_tumor)``. Each is a soft
        Dice pooled over the whole batch.
    """
    y_pred = output.detach().cpu().numpy()
    y_true = make_one_hot_3d(y, 3).detach().cpu().numpy()
    # Index the class axis explicitly. The old squeeze(0) + [k] selected
    # classes only when N == 1; for N > 1 it picked samples instead.
    # (The unused background Dice is no longer computed.)
    dice_kidney = dice_coeff(np.ascontiguousarray(y_pred[:, 1]), np.ascontiguousarray(y_true[:, 1]))
    dice_tumor = dice_coeff(np.ascontiguousarray(y_pred[:, 2]), np.ascontiguousarray(y_true[:, 2]))
    return dice_kidney, dice_tumor
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment