Created
March 4, 2022 09:01
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
#coding:utf-8 | |
#该代码适用于从测试集中读取原图,进行拼接。 | |
import numpy as np | |
import torch | |
import math | |
import SimpleITK as sitk | |
def generate_test_locations(image, patch_size, stride):
    """Count how many sliding-window patches are needed along each axis.

    Along every axis the count is ``ceil((dim - patch) / stride) + 1``: enough
    windows of size ``patch`` advancing by ``stride`` for the last window to
    reach the end of the axis. When an axis is shorter than the patch, at least
    one window is still reported; the caller is expected to clip that window to
    the volume.

    Args:
        image: object with a 3-element ``.shape`` (e.g. a 3-D numpy volume).
        patch_size: sequence of three patch extents, one per axis.
        stride: sequence of three step sizes, one per axis.

    Returns:
        ``((n0, n1, n2), (d0, d1, d2))``: window count per axis and the image shape.
    """
    d0, d1, d2 = image.shape  # unpacking enforces a 3-D input
    dims = (d0, d1, d2)
    counts = tuple(
        # max(..., 0) guarantees >= 1 window even when dim is far below patch,
        # where the raw formula would give 0 (or fewer) windows.
        max(math.ceil((dims[axis] - patch_size[axis]) / stride[axis]), 0) + 1
        for axis in range(3)
    )
    return counts, dims
def _patch_start(index, stride, dim, patch):
    """Return the start offset of the index-th window along one axis."""
    # min(): the last window is pulled back so it ends exactly at the border.
    # max(0, ...): when the axis is shorter than the patch, dim - patch is
    # negative; a negative start would make numpy wrap around and leave the
    # front of the axis uncovered (which later becomes 0/0 = NaN).
    return max(0, min(stride * index, dim - patch))


def _model_device(model):
    """Return the device of the model's first parameter, or CPU if it has none."""
    parameters = getattr(model, "parameters", None)
    if callable(parameters):
        for param in parameters():
            return param.device
    return torch.device("cpu")


def _predict_patch(model, patch, device):
    """Return one patch's contribution: the model output, or the patch itself."""
    if model is None:
        # Pure stitching: no GPU round-trip needed, the data is unchanged.
        return patch
    batch = torch.from_numpy(patch[np.newaxis, np.newaxis]).to(device)
    with torch.no_grad():
        output = model(batch)
    # Keep only the first batch item / first channel, as the original
    # (commented-out) inference path did with output[0, 0].
    return output[0, 0].detach().cpu().numpy().astype(np.float32)


def infer_tumorandliver(model, ct_array_nor, cube_shape=(129, 512, 512),
                        patch_stride=(60, 256, 256), window=(50, 400)):
    """Tile a 3-D volume into overlapping patches, process each, and stitch them back.

    Each patch of size ``cube_shape`` (stepping by ``patch_stride``) is passed
    through ``model`` when one is given. With ``model=None``, the raw patch is
    copied unchanged, which only exercises the tiling/stitching. Voxels
    covered by several patches receive the average of their contributions.

    Args:
        model: callable taking a ``(1, 1, D, H, W)`` float32 tensor and
            returning a tensor whose ``[0, 0]`` slice matches the patch's
            spatial shape, or ``None`` to stitch the raw patches.
        ct_array_nor: 3-D numpy array to process.
        cube_shape: patch extent per axis.
        patch_stride: step between consecutive patch starts per axis.
        window: ``(low, high)``. After stitching, values below ``low`` become 0
            and values above ``high`` are clipped to ``high``. ``None`` skips this.

    Returns:
        float32 numpy array with the same shape as ``ct_array_nor``.
    """
    patch_size = tuple(cube_shape)
    locations, image_shape = generate_test_locations(ct_array_nor, patch_size, patch_stride)
    print('location', locations, image_shape)

    accum = np.zeros(image_shape, dtype=np.float32)     # summed patch outputs
    coverage = np.zeros(image_shape, dtype=np.float32)  # patches covering each voxel
    print('image shape', accum.shape)

    device = _model_device(model) if model is not None else None

    for z in range(locations[0]):
        zs = _patch_start(z, patch_stride[0], image_shape[0], patch_size[0])
        for x in range(locations[1]):
            xs = _patch_start(x, patch_stride[1], image_shape[1], patch_size[1])
            for y in range(locations[2]):
                ys = _patch_start(y, patch_stride[2], image_shape[2], patch_size[2])
                region = (slice(zs, zs + patch_size[0]),
                          slice(xs, xs + patch_size[1]),
                          slice(ys, ys + patch_size[2]))
                patch = ct_array_nor[region].astype(np.float32)
                accum[region] += _predict_patch(model, patch, device)
                coverage[region] += 1

    # Average overlapping contributions; voxels no patch reached stay 0
    # instead of becoming NaN from 0/0.
    result = np.divide(accum, coverage, out=np.zeros_like(accum), where=coverage > 0)

    # Optional intensity window (original behaviour: <50 -> 0, >400 -> 400).
    if window is not None:
        low, high = window
        result[result < low] = 0
        result[result > high] = high
    return result
if __name__ == "__main__":
    import sys

    # Usage: python <script> [input_image [output_image]]
    # Without arguments, the original hard-coded paths are used.
    default_image_path = r"D:\MyData\3Dircadb1_fusion_date\train\image\image_6.nii"
    default_result_path = r"D:\MyData\3Dircadb1_fusion_date\train\image_6.nii"
    image_path = sys.argv[1] if len(sys.argv) > 1 else default_image_path
    result_path = sys.argv[2] if len(sys.argv) > 2 else default_result_path

    image = sitk.ReadImage(image_path)
    image_array = sitk.GetArrayFromImage(image)
    print("image_array:", image_array.shape)
    # model=None: no network is run; the patches are tiled and stitched back.
    result = infer_tumorandliver(model=None, ct_array_nor=image_array)
    predict_seg = sitk.GetImageFromArray(result)
    # Copy the input's geometry so the output lines up with the source scan.
    predict_seg.SetDirection(image.GetDirection())
    predict_seg.SetOrigin(image.GetOrigin())
    predict_seg.SetSpacing(image.GetSpacing())
    sitk.WriteImage(predict_seg, result_path)
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment