Skip to content

Instantly share code, notes, and snippets.

@PiotrCzapla
Last active April 8, 2018 09:19
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save PiotrCzapla/d06c82a362fddaf9ebb5181edcec0b66 to your computer and use it in GitHub Desktop.
Save PiotrCzapla/d06c82a362fddaf9ebb5181edcec0b66 to your computer and use it in GitHub Desktop.
fastai transforms.py v2

Here is an example of how transformations could look in the future:

class CoordTransform(Transform):
    """Transform whose sampled actions can also remap coordinate targets.

    ``determ`` draws a state via ``new_state`` and binds it, together with
    this transform, into a ``CoordAction``. Subclasses override
    ``new_state`` to sample their random parameters; the default samples
    nothing.
    """

    def new_state(self):
        """Return the keyword arguments later passed to ``do``; empty by default."""
        return dict()

    def determ(self):
        """Return a ``CoordAction`` bound to a freshly sampled state."""
        sampled = self.new_state()
        return CoordAction(self, sampled)

class RandomRotate(CoordTransform):
    """Randomly rotate an item, applying the rotation with probability ``p``.

    Args:
        deg: bound passed to ``rand0`` when sampling the rotation angle
            (presumably a symmetric range around 0 -- confirm against rand0).
        p: probability that the sampled rotation is actually applied.
        mode: cv2 border mode used when ``tfmtype`` is neither
            ``TfmType.COORD`` nor ``TfmType.CLASS``.
    """

    def __init__(self, deg, p=0.75, mode=cv2.BORDER_REFLECT):
        super().__init__()
        # `do` reads self.mode as the default border mode, so it must be stored.
        self.deg, self.p, self.mode = deg, p, mode

    def new_state(self):
        """Sample the per-item state (angle and apply-flag) consumed by ``do``."""
        return dict(rdeg=rand0(self.deg),
                    rp=random.random() < self.p)

    def do(self, x, tfmtype, is_y, rdeg, rp):
        """Rotate ``x`` by ``rdeg`` degrees if ``rp`` is true, else return it unchanged."""
        # Nearest-neighbour interpolation for targets (is_y), area interpolation otherwise.
        interpolation = cv2.INTER_NEAREST if is_y else cv2.INTER_AREA
        # COORD/CLASS data is padded with a constant border instead of the configured mode.
        mode = cv2.BORDER_CONSTANT if tfmtype in (TfmType.COORD, TfmType.CLASS) else self.mode
        if rp: x = rotate_cv(x, rdeg, mode=mode, interpolation=interpolation)
        return x

    def undo(self, **kwargs):
        """Inverse rotation is not implemented; returns None."""
        pass

class CenterCrop(CoordTransform):
    """Apply ``center_crop`` with separate sizes for inputs and targets.

    Args:
        sz: size passed to ``center_crop`` for inputs (``is_y`` false).
        sz_y: size passed to ``center_crop`` for targets (``is_y`` true);
            forwarded unchanged, even when it is None.
    """

    def __init__(self, sz, sz_y=None):
        super().__init__()
        self.min_sz = sz
        self.sz_y = sz_y

    def do(self, x, tfmtype, is_y):
        """Return ``center_crop(x, size)`` where size depends on ``is_y``."""
        crop_sz = self.sz_y if is_y else self.min_sz
        return center_crop(x, crop_sz)

    def undo(self, x, tfmtype, is_y):
        """Inverse crop is not implemented; returns None."""
        return None

And the action class would look like this:

class Action:
    """A transform bound to one sampled state, callable on data.

    Args:
        trans: the transform whose ``do`` method performs the work.
        state: keyword arguments forwarded to ``trans.do`` on every call;
            defaults to a new empty dict per instance.
    """

    def __init__(self, trans, state=None):
        self.trans = trans
        # A `{}` default would be one dict shared by every Action built without
        # an explicit state; create a fresh one per instance instead.
        self.state = {} if state is None else state

    def __call__(self, x, tfmtype, is_y=False):
        """Apply the bound transform to ``x``; ``TfmType.NO`` returns ``x`` untouched."""
        if tfmtype == TfmType.NO:    return x
        return self.trans.do(x, tfmtype, is_y, **self.state)
    
class CoordAction(Action):
    """Action that can also remap coordinate (bounding-box) targets.

    Mostly adapted from fastai's original ``CoordTransform``. The shape of the
    last non-target input (``is_y=False``) is remembered so that later
    ``TfmType.COORD`` calls can rasterise each box onto a same-sized mask,
    transform that mask, and convert it back to coordinates.
    """

    # Shape of the most recent x seen with is_y=False; None until one is seen.
    shape = None

    def __call__(self, x, tfmtype, is_y=False):
        """Apply the action; COORD data is routed through ``do_coord``."""
        if not is_y: self.shape = x.shape  # used later by map_y
        if tfmtype == TfmType.COORD: return self.do_coord(x)
        # Delegate to Action.__call__ (handles TfmType.NO and calls trans.do).
        return super().__call__(x, tfmtype, is_y)

    def map_y(self, y0):
        """Transform one box ``y0`` by rasterising it, transforming, and re-boxing.

        Raises:
            ValueError: if no input with ``is_y=False`` has been processed yet.
        """
        if self.shape is None:
            raise ValueError(
                "You cannot run TfmType.COORD transformation without running first transformation of x (is_y=False)")
        # presumably make_square draws box y0 into a mask of self.shape -- confirm.
        y = make_square(y0, self.shape)
        y_tr = self.trans.do(y, TfmType.PIXEL, True, **self.state)
        return to_bb(y_tr, y)

    def do_coord(self, y):
        """Map every group of 4 values in ``y`` through ``map_y`` and concatenate."""
        boxes = partition(y, 4)
        return np.concatenate([self.map_y(box) for box in boxes])

    def __str__(self):
        # trans is an instance, so its class provides the readable name.
        return f"CoordAction({type(self.trans).__name__}, {self.state})"
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment