Add GLU support to the Torch-Pruning library (WIP)

In dependency.py, add a new index-mapping update function to DependencyGraph:

    def _update_glu_index_mapping(self, glu_node: Node):
        if glu_node.type != ops.OPTYPE.GLU:
            return

        # GLU splits its input into two halves along the channel dimension and
        # gates one half with the sigmoid of the other, so it halves the
        # number of channels.
        input_node = glu_node.inputs[0]
        in_channels = self.get_out_channels(input_node)
        out_channels = in_channels // 2

        # Remap dependencies coming into the GLU node
        # TODO: Need to check
        for in_node in glu_node.inputs:
            for dep in in_node.dependencies:
                if dep.target == glu_node:
                    dep.index_mapping[0] = _helpers._GLUIndexMapping(out_channels)

        # Remap dependencies going out of the GLU node
        # TODO: Need to check
        for out_node in glu_node.outputs:
            for dep in out_node.dependencies:
                if dep.target == glu_node:
                    dep.index_mapping[0] = _helpers._GLUIndexMapping(out_channels)
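
As a quick sanity check for the in_channels // 2 assumption, nn.GLU applied along the channel dimension halves the channel count (standard PyTorch behaviour; the snippet below is purely illustrative):

    import torch
    from torch import nn

    glu = nn.GLU(dim=1)          # gate along the channel dimension
    x = torch.randn(1, 8, 16)    # (batch, channels, length)
    y = glu(x)
    print(y.shape)               # torch.Size([1, 4, 16]) -> channels halved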

Call it whenever update_index_mapping is executed:

    def update_index_mapping(self):
        for module, node in self.module2node.items():
            ...
            if node.type == ops.OPTYPE.GLU:
                self._update_glu_index_mapping(node)
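
To see whether the GLU node is traced and the mapping gets attached, a dependency graph can be built over a tiny Conv1d -> GLU -> Conv1d model and the resulting pruning group inspected. This is a minimal sketch, assuming the torch-pruning v1.x functional API; TinyGLUNet and its channel sizes are made up for illustration:

    import torch
    from torch import nn
    import torch_pruning as tp  # assumes torch-pruning v1.x API

    class TinyGLUNet(nn.Module):
        def __init__(self):
            super().__init__()
            self.conv1 = nn.Conv1d(3, 8, kernel_size=3, padding=1)
            self.glu = nn.GLU(dim=1)                      # 8 -> 4 channels
            self.conv2 = nn.Conv1d(4, 16, kernel_size=3, padding=1)

        def forward(self, x):
            return self.conv2(self.glu(self.conv1(x)))

    model = TinyGLUNet()
    DG = tp.DependencyGraph().build_dependency(model, example_inputs=torch.randn(1, 3, 32))

    # Pruning conv1's output channels has to propagate through the GLU node
    # into conv2's input channels via the new index mapping.
    group = DG.get_pruning_group(model.conv1, tp.prune_conv_out_channels, idxs=[0, 1])
    print(group)  # inspect whether the GLU dependency and remapped indices show up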

Register GLU in ops.py:

    class GLUPruner(DummyPruner):
        # nn.GLU has no parameters of its own, so a no-op pruner is sufficient
        pass

    # Standard Modules
    TORCH_CONV = nn.modules.conv._ConvNd
    ...
    TORCH_GLU = nn.GLU

    class OPTYPE(IntEnum):
        CONV = 0
        ...
        GLU = 18  # nn.GLU

    def module2type(module):
        ...
        elif isinstance(module, TORCH_GLU):
            return OPTYPE.GLU

    def type2class(op_type):
        ...
        elif op_type == OPTYPE.GLU:
            return TORCH_GLU
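
A quick round-trip check that the new op type is picked up by module2type and type2class (illustrative only; it assumes ops.py is importable as torch_pruning.ops, which is what the ops.OPTYPE references in dependency.py suggest):

    from torch import nn
    from torch_pruning import ops  # assumption: ops.py sits at the package top level

    assert ops.module2type(nn.GLU()) == ops.OPTYPE.GLU
    assert ops.type2class(ops.OPTYPE.GLU) is nn.GLU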

Add the GLU index mapping in _helpers.py:

    class _GLUIndexMapping(object):
        def __init__(self, out_channels):
            self.out_channels = out_channels

        def __call__(self, idxs: _HybridIndex):
            # TODO: Update this
            return [_HybridIndex(idx=i.idx % self.out_channels, root_idx=i.root_idx) for i in idxs]
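
The modulo encodes the pairing GLU induces between channels: input channels k and k + out_channels both feed output channel k. A standalone illustration, using a namedtuple stand-in for torch-pruning's _HybridIndex so it runs without the library internals:

    from collections import namedtuple

    # Stand-in for torch_pruning's _HybridIndex, only for this demo
    _HybridIndex = namedtuple("_HybridIndex", ["idx", "root_idx"])

    out_channels = 4  # GLU output channels (input has 2 * 4 = 8 channels)
    idxs = [_HybridIndex(idx=i, root_idx=i) for i in (1, 5)]  # paired input channels 1 and 1 + 4

    mapped = [_HybridIndex(idx=i.idx % out_channels, root_idx=i.root_idx) for i in idxs]
    print([i.idx for i in mapped])  # [1, 1] -> both map to output channel 1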