Skip to content

Instantly share code, notes, and snippets.

@ScienceDuck
Last active September 12, 2019 14:52
Show Gist options
  • Star 1 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save ScienceDuck/5f263970776404da0f9e0e11e57ec01a to your computer and use it in GitHub Desktop.
Save ScienceDuck/5f263970776404da0f9e0e11e57ec01a to your computer and use it in GitHub Desktop.
Fastai v1 multiple images
def open_rcic_image(fn, n_channels=6, ext='.png'):
    """Open one multi-channel sample stored as one grayscale file per channel.

    Channel files are looked up at ``f'{fn}{k}{ext}'`` for k = 1..n_channels
    (with the defaults: ``<fn>1.png`` ... ``<fn>6.png``). Each file is read as a
    single-channel grayscale array. The channels are stacked along a new last
    axis with ``np.dstack``, converted with ``pil2tensor`` to float32, and
    divided by 255.

    Args:
        fn: filename stem (str or path-like) that the channel index and
            extension are appended to.
        n_channels: number of per-channel files to read (default 6).
        ext: file extension of the channel files (default '.png').

    Returns:
        A fastai ``Image`` wrapping the stacked, /255-scaled float tensor.

    Raises:
        FileNotFoundError: if OpenCV cannot read one of the channel files.
    """
    channels = []
    for k in range(1, n_channels + 1):
        # f-string formatting works for str, numpy str_ and pathlib stems alike.
        file_name = f'{fn}{k}{ext}'
        # Read a single grayscale channel directly. imread's default mode yields
        # BGR-ordered 3-channel data, on which an RGB->GRAY conversion would
        # weight the channels incorrectly for colour inputs.
        im = cv2.imread(file_name, cv2.IMREAD_GRAYSCALE)
        if im is None:
            # cv2.imread reports unreadable/missing files by returning None.
            raise FileNotFoundError(f'Could not read image file: {file_name}')
        channels.append(im)
    image = np.dstack(channels)
    return Image(pil2tensor(image, np.float32).div_(255))
class MultiChannelImageList(ImageList):
    """`ImageList` whose items are loaded with `open_rcic_image`.

    Each item is passed to `open_rcic_image` as a filename stem; that function
    reads the per-channel files for the stem and stacks them into a single
    multi-channel `Image`.
    """

    def open(self, fn):
        "Load the multi-channel image for the filename stem `fn`."
        image = open_rcic_image(fn)
        return image
class ImageTuple(ItemBase):
    """A pair of images handled as a single fastai item.

    `obj` is the tuple of the two `Image` objects. `data` is a list of their
    tensors mapped with ``t -> 2*t - 1``, so [0, 1]-valued image data ends up
    in [-1, 1]. Both are kept in sync with `img1`/`img2`, including after
    transforms are applied.
    """

    def __init__(self, img1, img2):
        self.img1, self.img2 = img1, img2
        self._sync()

    def _sync(self):
        "Recompute `obj` and the rescaled `data` from the current `img1`/`img2`."
        self.obj = (self.img1, self.img2)
        self.data = [-1+2*self.img1.data, -1+2*self.img2.data]

    def __str__(self): return str(self.obj)

    def apply_tfms(self, tfms, **kwargs):
        """Apply `tfms` to both images and refresh `obj`/`data`; returns `self`.

        Each image goes through its own `apply_tfms` call, so random transforms
        may resolve to different parameters for the two images — confirm this
        is the intended behavior.
        """
        self.img1 = self.img1.apply_tfms(tfms, **kwargs)
        self.img2 = self.img2.apply_tfms(tfms, **kwargs)
        # Previously only `data` was refreshed, leaving `obj` (and therefore
        # __str__) describing the untransformed images.
        self._sync()
        return self

    def to_one(self):
        """Join both images into one `Image` for display.

        Undoes the [-1, 1] rescaling and concatenates the tensors along dim 2
        (the width axis, assuming C x H x W tensors).
        """
        # NOTE(review): the joined image keeps every input channel; matplotlib's
        # imshow only accepts 1, 3 or 4 channels, so 6-channel pairs likely need
        # a channel selection before being shown — confirm.
        return Image(0.5+torch.cat(self.data, 2)/2)
class MultiChannelImageTupleList(MultiChannelImageList):
    """Item list yielding `ImageTuple`s: two multi-channel images per row.

    The first image of pair `i` is opened from `self.items[i]`, the second from
    `self.itemsB[i]`, so the two sequences must stay index-aligned.
    """

    def __init__(self, items, itemsB=None, **kwargs):
        super().__init__(items, **kwargs)
        self.itemsB = itemsB
        # Listed in `copy_new` so that lists derived from this one through
        # fastai's `new` are constructed with `itemsB` as well.
        self.copy_new.append('itemsB')

    def __getitem__(self, idxs):
        """Index like the parent class while keeping `itemsB` aligned.

        For a non-integer index the parent builds a new list from
        `self.items[idxs]`, but that list receives the *full* `itemsB` through
        `copy_new`. Because `get` pairs the sequences by position, the
        sub-list's `itemsB` is re-indexed with the same `idxs`. Without this,
        sub-lists such as the train/valid splits pair images from different
        rows.
        """
        res = super().__getitem__(idxs)
        if isinstance(res, MultiChannelImageTupleList) and self.itemsB is not None:
            res.itemsB = np.asarray(self.itemsB)[idxs]
        return res

    def get(self, i):
        "Return the `ImageTuple` pairing the images at `items[i]` and `itemsB[i]`."
        img1 = super().get(i)
        # Use this list's own `open` (same loader as for `items`) for the second image.
        return ImageTuple(img1, self.open(self.itemsB[i]))

    def reconstruct(self, t:Tensor):
        "Rebuild an `ImageTuple` from `t[0]`/`t[1]`, undoing the [-1, 1] rescaling."
        return ImageTuple(Image(t[0]/2+0.5), Image(t[1]/2+0.5))

    @classmethod
    def from_dfs(cls, df, path, cols=(0, 1), **kwargs):
        """Build the list from `df` with one pair of filename stems per row.

        Column `cols[0]` supplies the first images and column `cols[1]` the
        second images. Each selection also carries the last column of `df`
        (``iloc[:, [col, -1]]``). Presumably this keeps the label column in the
        list's `inner_df` for `label_from_df` — confirm. `kwargs` go to
        `from_df` for the first-image list.
        """
        # Tuple default instead of a (shared, mutable) list; indexing is unchanged.
        itemsB = MultiChannelImageList.from_df(df=df.iloc[:, [cols[1], -1]], path=path).items
        res = super().from_df(df=df.iloc[:, [cols[0], -1]], path=path, itemsB=itemsB, **kwargs)
        res.path = path
        return res

    def show_xys(self, xs, ys, figsize:Tuple[int,int]=(12,6), **kwargs):
        "Show the `xs` and `ys` on a figure of `figsize`. `kwargs` are passed to the show method."
        # Square grid of side floor(sqrt(len(xs))): items beyond rows*rows are
        # not drawn, and `ys` is not displayed.
        rows = int(math.sqrt(len(xs)))
        # squeeze=False always returns a 2-D array of Axes, including the 1x1 case.
        _, axs = plt.subplots(rows, rows, figsize=figsize, squeeze=False)
        for ax, x in zip(axs.flatten(), xs):
            x.to_one().show(ax=ax, **kwargs)
        plt.tight_layout()

    def show_xyzs(self, xs, ys, zs, figsize:Tuple[int,int]=None, **kwargs):
        """Show `xs` (inputs), `ys` (targets) and `zs` (predictions) on a figure of `figsize`.
        `kwargs` are passed to the show method."""
        figsize = ifnone(figsize, (12,3*len(xs)))
        # squeeze=False keeps `axs` 2-D when len(xs) == 1; plain subplots(1, 2)
        # returns a 1-D array there and `axs[i, 0]` would raise.
        fig, axs = plt.subplots(len(xs), 2, figsize=figsize, squeeze=False)
        fig.suptitle('Ground truth / Predictions', weight='bold', size=14)
        # NOTE(review): the left column shows the input `x` (not the target `y`),
        # and each `z` must provide `to_one()` (i.e. be an `ImageTuple`) — confirm.
        for i, (x, z) in enumerate(zip(xs, zs)):
            x.to_one().show(ax=axs[i, 0], **kwargs)
            z.to_one().show(ax=axs[i, 1], **kwargs)
# Assemble the DataBunch step by step:
# 1. paired item list from `proc_df`;
# 2. train/valid split by `val_idx`;
# 3. labels from the dataframe;
# 4. transforms at `size`;
# 5. batching.
# There is no `.normalize(stats)` step here: normalization for the image
# pairs is attached to `data` further down via `normalize_funcs_tuple`.
data = MultiChannelImageTupleList.from_dfs(proc_df, data_folder + '/train/')
data = data.split_by_idx(val_idx)
data = data.label_from_df()
data = data.transform(tfms, size=size)
data = data.databunch(bs=bs, num_workers=0)
def _normalize_tuple_batch(b:Tuple[Tensor,Tensor], mean:FloatTensor, std:FloatTensor, do_x:bool=True, do_y:bool=False)->Tuple[Tensor,Tensor]:
    """Normalize a collated batch whose input side is a pair of image tensors.

    `b` is unpacked into `(x, y)`; `x` is indexed as `x[0]` and `x[1]`, one
    tensor per image of the tuple (so despite the annotation, `b[0]` is itself
    a pair). `mean` and `std` are first moved to the device of `x[0]`.

    If `do_x` is set, both input tensors go through `normalize(t, mean, std)`
    and `x` becomes a 2-tuple. If `do_y` is set and `y` is 4-D, `y` is
    normalized with the same statistics.

    Returns `(x, y)`.
    """
    inputs, target = b
    dev = inputs[0].device
    mean, std = mean.to(dev), std.to(dev)
    if do_x:
        inputs = tuple(normalize(t, mean, std) for t in (inputs[0], inputs[1]))
    if do_y and len(target.shape) == 4:
        target = normalize(target, mean, std)
    return inputs, target
def denormalize_tuple(x:Tuple[Tensor,Tensor], mean:FloatTensor,std:FloatTensor, do_x:bool=True)->TensorImage:
    """Undo per-channel normalization on both tensors of the pair `x`.

    With `do_x`, each of `x[0]` and `x[1]` is moved to the CPU, cast to float
    and mapped back as ``t * std + mean``. `mean`/`std` get two trailing
    singleton axes so they broadcast over the last two dimensions. Without
    `do_x`, both tensors are only moved to the CPU.

    Always returns a 2-tuple.
    """
    if not do_x:
        return (x[0].cpu(), x[1].cpu())
    scale, shift = std[..., None, None], mean[..., None, None]
    return tuple(t.cpu().float() * scale + shift for t in (x[0], x[1]))
def normalize_funcs_tuple(mean:FloatTensor, std:FloatTensor, do_x:bool=True, do_y:bool=False)->Tuple[Callable,Callable]:
    """Build the (normalize, denormalize) callables for batches of image pairs.

    `mean` and `std` are converted with `tensor(...)`. They are then bound,
    together with the flags, into `partial`s of `_normalize_tuple_batch` and
    `denormalize_tuple`. In this script the pair is stored as
    `data.norm`/`data.denorm`, and the first one is registered as a batch
    transform.
    """
    bound = dict(mean=tensor(mean), std=tensor(std))
    norm_fn = partial(_normalize_tuple_batch, **bound, do_x=do_x, do_y=do_y)
    denorm_fn = partial(denormalize_tuple, **bound, do_x=do_x)
    return norm_fn, denorm_fn
# Attach tuple-aware normalization to the DataBunch. `stats` is unpacked into
# normalize_funcs_tuple; presumably it is (mean, std) per channel — confirm
# where `stats` is defined. Both callables are stored on `data` as `norm` and
# `denorm`, and `norm` is registered as a batch transform via add_tfm.
data.norm,data.denorm = normalize_funcs_tuple(*stats)
data.add_tfm(data.norm)
# NOTE(review): display goes through ImageTuple.to_one(), which keeps all input
# channels; matplotlib's imshow only accepts 1, 3 or 4 channels, so this call
# may fail for 6-channel images — confirm.
data.show_batch()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment