Skip to content

Instantly share code, notes, and snippets.

@Roger-luo
Last active January 10, 2018 09:26
Show Gist options
  • Save Roger-luo/0fd0301f06901fb19d97bdac12c8563d to your computer and use it in GitHub Desktop.
Save Roger-luo/0fd0301f06901fb19d97bdac12c8563d to your computer and use it in GitHub Desktop.
change PyTorch generic type name to whatever you want
"""Convert the generic type name (``real``) used in the PyTorch sources
to a name of your choice.
Author: Roger-luo
"""
import os
import re
import sys
import shutil
import inspect
class Token(object):
    """An identifier from the torch C sources, split into name pieces.

    The text is split on underscores and on capitalised words
    (``[A-Z][a-z]+``).  The pieces are kept in ``names``; joining them
    gives back the identifier.  The first piece that spells one of
    ``typenames`` (ignoring case) is recorded as the token's numerical
    type, readable and replaceable through the ``dtype`` property.
    """

    # type names recognised inside an identifier, in lowercase
    typenames = [
        'real',
        'complex',
        'ntype',
        'double',
        'float',
        'zdouble',
        'zfloat',
        'int',
        'long',
        'short',
    ]

    # Class-level default so that ``dtype`` can be read and assigned even
    # when no type name was found (the setter then raises ValueError
    # instead of AttributeError).
    _dtype = None

    def __init__(self, text, meta=None):
        """Split ``text`` into pieces and run every ``init_*`` method.

        Args:
            text: the identifier to parse.
            meta: accepted but not used by this class.
        """
        self.str = text
        # split by '_' and by capitalised words, keeping the separators
        src = []
        for each in re.split(r'(_)', text):
            src.extend(re.split(r'([A-Z][a-z]+)', each))
        self.names = src
        # 'accreal' / 'ureal' carry a prefix glued to the type name
        self.with_prefix('acc')
        self.with_prefix('u')
        for name, method in inspect.getmembers(self):
            if name.startswith('init_'):
                method()

    def with_prefix(self, prefix):
        """Split every piece equal to ``prefix + 'real'`` (any case) in two."""
        src = []
        for each in self.names:
            if each.lower() == prefix + 'real':
                src.extend([each[:len(prefix)], each[len(prefix):]])
            else:
                src.append(each)
        self.names = src

    @staticmethod
    def pattern(text):
        """Return 'title', 'upper' or 'lower' for the casing of ``text``.

        Raises:
            ValueError: if ``text`` matches none of the three casings.
        """
        if text.istitle():
            return 'title'
        elif text.isupper():
            return 'upper'
        elif text.islower():
            return 'lower'
        else:
            raise ValueError("invalid text")

    def init_dtype(self):
        """Record the first piece that is a known type name, if any."""
        for ind, each in enumerate(self.names):
            lowered = each.lower()
            if lowered not in self.typenames:
                continue
            try:
                casing = self.pattern(each)
            except ValueError:
                # Mixed casing such as 'inT' only spells a type name by
                # accident; leave the piece alone instead of aborting
                # the construction of the token.
                continue
            self._dtype = dict(
                name=lowered,
                pattern=casing,
                index=ind,
            )
            break

    @property
    def dtype(self):
        """Lowercase type name found in the token, or None."""
        info = self._dtype
        if info is None:
            return None
        return info['name']

    @dtype.setter
    def dtype(self, val):
        """Replace the type name piece, keeping its original casing.

        Raises:
            ValueError: if the token contains no type name.
            TypeError: if ``val`` is not one of ``typenames``.
        """
        # check if dtype exists
        if self._dtype is None:
            raise ValueError("variable name does not have dtype")
        # check type
        lower = val.lower()
        if lower not in self.typenames:
            raise TypeError("Invalid type")
        # match the casing of the piece being replaced
        dtype = lower
        if self._dtype['pattern'] == 'title':
            dtype = dtype.title()
        elif self._dtype['pattern'] == 'upper':
            dtype = dtype.upper()
        self._dtype['name'] = lower
        self.names[self._dtype['index']] = dtype

    def __repr__(self):
        dtype = self.dtype if self.dtype is not None else 'none'
        return "TOKEN{" + dtype + "}[" + ''.join(self.names) + "]"

    def __str__(self):
        return ''.join(self.names)
class THTokenName(object):
    """Tokenizer for torch C source text.

    The text is split with every pattern in ``rules``.  All patterns use
    capture groups, so separators are kept and joining the pieces gives
    back the input.  Pieces that start with a letter are then wrapped in
    ``Token`` objects by ``split_varname``.
    """

    rules = [
        re.compile(r'(;)'),  # block
        re.compile(r'(\()'),  # inline parenthesis
        re.compile(r'(\))'),
        re.compile(r'(\[)'),
        re.compile(r'(\])'),
        re.compile(r'({)'),
        re.compile(r'(})'),
        re.compile(r'(<)'),  # CPP/CUDA specifier
        re.compile(r'(>)'),
        re.compile(r'(`)'),  # markdown
        re.compile(r'(\')'),  # python string
        re.compile(r'(")'),
        re.compile(r'(#)'),  # macro
        re.compile(r'(\.)'),  # operator
        re.compile(r'(\\)'),
        re.compile(r'(\w+)(\s*)(\*)'),  # pointers
        re.compile(r'(,)'),  # commas
        re.compile(r'(\s+)'),  # spaces
        re.compile(r'(_)'),  # underlines
    ]

    def manipulate(self, src):
        """Split every string in ``src`` with each rule, in order.

        Args:
            src: list of strings.

        Returns:
            The list of pieces after all rules were applied; ``src``
            itself when ``rules`` is empty.
        """
        for rule in self.rules:
            pieces = []
            for each in src:
                pieces.extend(re.split(rule, each))
            src = pieces
        # Return ``src`` rather than the loop-local list so that an
        # empty ``rules`` does not end in an unbound local name.
        return src

    def tokenize(self, src):
        """Split source code following ``rules`` and the ``split*`` methods.

        After the rules are applied, every attribute of the instance
        whose name begins with ``split`` is called with the current token
        list and must return the new one, e.g.::

            def split_block(self, src):
                out = []
                for each in src:
                    out.extend(re.split(r'(;)', each))
                return out
        """
        tokens = [src]
        tokens = self.manipulate(tokens)
        for name, method in inspect.getmembers(self):
            if name.startswith('split'):
                tokens = method(tokens)
        return tokens

    def split_varname(self, src):
        """Replace, in place, each piece starting with a letter by a Token.

        Returns:
            The same list object that was passed in.
        """
        for ind, each in enumerate(src):
            m = re.match(r'[A-Za-z]+', each)
            if m is not None:
                src[ind] = Token(each)
        return src
class THComplexRename(THTokenName):
    """Copy a torch source tree, replacing the type name ``real``.

    Every token whose type name is ``real`` is rewritten to ``tname``
    (keeping the casing of the original piece) in the directories listed
    by ``rename``.
    """

    # directory, relative to the source root, holding the C libraries
    static_src = 'torch/lib/'
    # library directories under ``static_src`` that are converted
    c_src_dirs = [
        'TH', 'THC', 'THS', 'THCS', 'THD', 'ATen',
        'THNN', 'THCUNN',
    ]

    def __init__(self, src, target,
                 static_src=None,
                 c_src_dirs=None,
                 tname='ntype',
                 ):
        """
        Args:
            src: source directory.
            target: directory the converted tree is written to.
            static_src: overrides the class-level ``static_src`` if given.
            c_src_dirs: overrides the class-level ``c_src_dirs`` if given.
            tname: replacement type name, in lowercase.
        """
        super().__init__()
        self.root = os.path.abspath(src)
        self.target = os.path.abspath(target)
        self.tname = tname
        if static_src is not None:
            self.static_src = static_src
        if c_src_dirs is not None:
            self.c_src_dirs = c_src_dirs

    def rename_src(self, src):
        """Return ``src`` with every ``real`` type name replaced."""
        tokens = self.tokenize(src)
        for each in tokens:
            if isinstance(each, Token) and each.dtype == 'real':
                each.dtype = self.tname
        return ''.join(str(each) for each in tokens)

    def rename_file(self, path):
        """Read the file at ``path`` and return its converted text."""
        with open(path, 'r') as f:
            raw = f.read()
        return self.rename_src(raw)

    def rename_dir(self, path):
        """Convert every file below ``path`` (relative to the roots).

        Files are read from ``self.root/path`` and written to the same
        relative location under ``self.target``.
        """
        src_path = os.path.join(self.root, path)
        target_path = os.path.join(self.target, path)
        # make target directory
        os.makedirs(target_path, exist_ok=True)
        # walk through source directory
        for dirpath, dirnames, filenames in os.walk(src_path):
            sub_dir_relpath = os.path.relpath(dirpath, src_path)
            target_dir = os.path.join(target_path, sub_dir_relpath)
            os.makedirs(target_dir, exist_ok=True)
            for file in filenames:
                src_file = os.path.join(dirpath, file)
                target_file = os.path.join(target_dir, file)
                print('processing: %s' % src_file)
                # Convert BEFORE opening the target for writing: opening
                # with 'w' truncates, so a failure while reading would
                # otherwise leave an empty file behind (and would wipe
                # the source itself if both paths are the same file).
                try:
                    converted = self.rename_file(src_file)
                except UnicodeDecodeError:
                    # Not decodable as text: keep the file byte-for-byte.
                    print('skipping (not text): %s' % src_file)
                    if not os.path.exists(target_file):
                        shutil.copyfile(src_file, target_file)
                    continue
                with open(target_file, 'w') as f:
                    f.write(converted)

    def rename(self):
        """Copy the whole source tree to the target and convert it.

        If the target directory already exists the user is asked on
        stdin whether to delete it; any answer other than ``y`` aborts.
        """
        if os.path.isdir(self.target):
            # flush: the prompt has no trailing newline, so it could stay
            # in the output buffer while we block on stdin
            print("Warning: target dir exist\nrewrite?[y/n]:",
                  end='', flush=True)
            if sys.stdin.read(1) == 'y':
                shutil.rmtree(self.target)
            else:
                return
        shutil.copytree(self.root, self.target)
        # libraries
        for each in self.c_src_dirs:
            self.rename_dir(os.path.join(self.static_src, each))
        # torch csrc
        self.rename_dir('torch/csrc')
        # tools
        self.rename_dir('tools')
        # test
        self.rename_dir('test')
        # copy build script
        shutil.copyfile(
            os.path.join(self.root, self.static_src, 'build_libs.sh'),
            os.path.join(self.target, self.static_src, 'build_libs.sh')
        )
if __name__ == '__main__':
    # Copy ./pytorch to ./complex, rewriting the type name 'real'
    # to the lowercase name given as ``tname``.
    renamer = THComplexRename(
        src='pytorch',
        target='complex',
        tname='ntype',
    )
    renamer.rename()
    # a single directory can be converted with: renamer.rename_dir('TH')
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment