Last active
January 10, 2018 09:26
-
-
Save Roger-luo/0fd0301f06901fb19d97bdac12c8563d to your computer and use it in GitHub Desktop.
Change the PyTorch generic type name (`real`) to whatever you want.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""Convert generic type name (real) in PyTorch project | |
to whatever you want. | |
Author: Roger-luo | |
""" | |
import os | |
import re | |
import sys | |
import shutil | |
import inspect | |
class Token(object):
    """An identifier-like token from torch C source.

    The token text is split into sub-words (on ``_`` and on capitalised
    words such as ``Real``). The first sub-word whose lowercase form is in
    ``typenames`` is recorded as the token's dtype, together with its letter
    case, so that it can later be replaced while keeping that case, e.g.
    ``THRealTensor`` -> ``THNtypeTensor``.
    """

    # Sub-words recognised as numerical type names (compared lowercased).
    typenames = [
        'real',
        'complex',
        'ntype',
        'double',
        'float',
        'zdouble',
        'zfloat',
        'int',
        'long',
        'short',
    ]

    def __init__(self, text, meta=None):
        """Split ``text`` into sub-words and detect its dtype.

        Args:
            text: raw token text, e.g. ``'THRealTensor'`` or ``'accreal'``.
            meta: accepted for compatibility; not used.
        """
        self.str = text
        # Split by '_' (kept as its own piece) and by capitalised words;
        # both splits keep their separators, so ''.join(self.names) == text.
        src = []
        for each in re.split(r'(_)', text):
            src.extend(re.split(r'([A-Z][a-z]+)', each))
        self.names = src
        # 'accreal' / 'ureal' are single lowercase words; split the prefix
        # off so that the 'real' part is recognised as the dtype.
        self.with_prefix('acc')
        self.with_prefix('u')
        # Run every init_* hook in name order. Looking the hooks up by name
        # (rather than via inspect.getmembers) avoids evaluating unrelated
        # attributes such as the ``dtype`` property on a half-built instance.
        for name in dir(self):
            if name.startswith('init_'):
                getattr(self, name)()

    def with_prefix(self, prefix):
        """Split any sub-word equal (case-insensitively) to ``prefix + 'real'``
        into ``[prefix_part, 'real'_part]``, preserving the original case."""
        src = []
        for each in self.names:
            if each.lower() == prefix + 'real':
                src.extend([each[:len(prefix)], each[len(prefix):]])
            else:
                src.append(each)
        self.names = src

    @staticmethod
    def pattern(text):
        """Classify the letter case of ``text``.

        Returns:
            ``'title'``, ``'upper'`` or ``'lower'``.
        Raises:
            ValueError: if ``text`` is mixed case (e.g. ``'rEAL'``).
        """
        if text.istitle():
            return 'title'
        elif text.isupper():
            return 'upper'
        elif text.islower():
            return 'lower'
        else:
            raise ValueError("invalid text")

    def init_dtype(self):
        """Record the first sub-word naming a type as ``self._dtype``.

        ``self._dtype`` is a dict with keys ``name`` (lowercase type name),
        ``pattern`` (letter case, see ``pattern``) and ``index`` (position
        in ``self.names``). It is left unset if no sub-word matches.
        """
        for ind, each in enumerate(self.names):
            if each.lower() in self.typenames:
                self._dtype = dict(
                    name=each.lower(),
                    pattern=self.pattern(each),
                    index=ind,
                )
                break

    @property
    def dtype(self):
        """Lowercase type name of this token, or ``None`` if it has none."""
        out = getattr(self, '_dtype', None)
        if out is not None:
            return out['name']
        return None

    @dtype.setter
    def dtype(self, val):
        """Replace the dtype sub-word with ``val``, keeping its letter case.

        Raises:
            ValueError: if this token has no dtype sub-word.
            TypeError: if ``val`` (lowercased) is not in ``typenames``.
        """
        # Tokens without a dtype never get a ``_dtype`` attribute, so read it
        # defensively; accessing it directly raised AttributeError instead of
        # the intended ValueError.
        info = getattr(self, '_dtype', None)
        if info is None:
            raise ValueError("variable name does not have dtype")
        lower = val.lower()
        if lower not in self.typenames:
            raise TypeError("Invalid type")
        # reproduce the letter case of the sub-word being replaced
        if info['pattern'] == 'title':
            dtype = lower.title()
        elif info['pattern'] == 'upper':
            dtype = lower.upper()
        else:
            dtype = lower
        info['name'] = lower
        self.names[info['index']] = dtype

    def __repr__(self):
        dtype = self.dtype if self.dtype is not None else 'none'
        return "TOKEN{" + dtype + "}[" + ''.join(self.names) + "]"

    def __str__(self):
        return ''.join(self.names)
class THTokenName(object):
    """Tokenizer for torch source files.

    ``tokenize`` splits source text on the separators in ``rules`` (the
    separators are kept, so joining the pieces reproduces the input) and
    then applies every method whose name starts with ``split``; by default
    ``split_varname`` turns identifier-like pieces into ``Token`` objects.
    """

    rules = [
        re.compile(r'(;)'),  # block
        re.compile(r'(\()'),  # inline parenthesis
        re.compile(r'(\))'),
        re.compile(r'(\[)'),
        re.compile(r'(\])'),
        re.compile(r'({)'),
        re.compile(r'(})'),
        re.compile(r'(<)'),  # CPP/CUDA specifier
        re.compile(r'(>)'),
        re.compile(r'(`)'),  # markdown
        re.compile(r'(\')'),  # python string
        re.compile(r'(")'),
        re.compile(r'(#)'),  # macro
        re.compile(r'(\.)'),  # operator
        re.compile(r'(\\)'),
        re.compile(r'(\w+)(\s*)(\*)'),  # pointers
        re.compile(r'(,)'),  # commas
        re.compile(r'(\s+)'),  # spaces
        re.compile(r'(_)'),  # underlines
    ]

    def manipulate(self, src):
        """Split every string in ``src`` by each rule in ``rules``, in order.

        Args:
            src: list of strings.
        Returns:
            A new list of strings. Every rule captures its whole match, so
            ``''.join`` of the result equals ``''.join(src)``. With no rules
            a copy of ``src`` is returned (previously this raised
            UnboundLocalError).
        """
        pieces = list(src)
        for rule in self.rules:
            split = []
            for each in pieces:
                # re.split accepts both compiled patterns and plain strings,
                # so subclasses may list either in ``rules``.
                split.extend(re.split(rule, each))
            pieces = split
        return pieces

    def tokenize(self, src):
        """Split the source text ``src`` into a list of pieces.

        The text is first split by ``self.rules`` (see ``manipulate``); then
        every method whose name starts with ``split`` is applied in name
        order, each receiving and returning the token list. For example::

            def split_block(self, src):
                out = []
                for each in src:
                    out.extend(re.split(r'(;)', each))
                return out
        """
        tokens = self.manipulate([src])
        for name, method in inspect.getmembers(self):
            if name.startswith('split'):
                tokens = method(tokens)
        return tokens

    def split_varname(self, src):
        """Replace, in place, each piece starting with an ASCII letter by a
        ``Token``; returns the same list."""
        for ind, each in enumerate(src):
            if re.match(r'[A-Za-z]+', each) is not None:
                src[ind] = Token(each)
        return src
class THComplexRename(THTokenName):
    """Rename the generic type name ``real`` throughout a PyTorch checkout.

    ``rename`` copies the whole source tree to the target directory and then
    rewrites every file under the configured source directories, replacing
    the ``real`` sub-word of each token with ``tname`` while keeping its
    letter case (e.g. ``THRealTensor`` -> ``THNtypeTensor``).
    """

    # directory (relative to the checkout root) holding the C libraries
    static_src = 'torch/lib/'
    # sub-directories of ``static_src`` to rewrite
    c_src_dirs = [
        'TH', 'THC', 'THS', 'THCS', 'THD', 'ATen',
        'THNN', 'THCUNN',
    ]

    def __init__(self, src, target,
                 static_src=None,
                 c_src_dirs=None,
                 tname='ntype',
                 ):
        """
        Args:
            src: root directory of the source checkout.
            target: directory the converted tree is written to.
            static_src: if given, overrides the class-level ``static_src``.
            c_src_dirs: if given, overrides the class-level ``c_src_dirs``.
            tname: replacement type name; ``Token``'s dtype setter rejects
                names not listed in ``Token.typenames``.
        """
        super(THComplexRename, self).__init__()
        self.root = os.path.abspath(src)
        self.target = os.path.abspath(target)
        self.tname = tname
        if static_src is not None:
            self.static_src = static_src
        if c_src_dirs is not None:
            self.c_src_dirs = c_src_dirs

    def rename_src(self, src):
        """Return the text ``src`` with every ``real`` dtype renamed to
        ``self.tname``."""
        tokens = self.tokenize(src)
        for each in tokens:
            if isinstance(each, Token) and each.dtype == 'real':
                each.dtype = self.tname
        return ''.join(str(each) for each in tokens)

    def rename_file(self, path):
        """Read the text file at ``path`` and return its renamed contents.

        The file is decoded as UTF-8 with ``newline=''`` so that its
        original line endings survive the round trip.

        Raises:
            UnicodeDecodeError: if the file is not valid UTF-8 text.
        """
        with open(path, 'r', encoding='utf-8', newline='') as f:
            raw = f.read()
        return self.rename_src(raw)

    def rename_dir(self, path):
        """Write renamed copies of all files under ``root/path`` to
        ``target/path``, recreating the directory structure.

        Files that are not UTF-8 text (images, other binaries) are copied
        unchanged instead of aborting the conversion.
        """
        src_path = os.path.join(self.root, path)
        target_path = os.path.join(self.target, path)
        os.makedirs(target_path, exist_ok=True)
        for dirpath, dirnames, filenames in os.walk(src_path):
            sub_dir_relpath = os.path.relpath(dirpath, src_path)
            target_dir = os.path.join(target_path, sub_dir_relpath)
            os.makedirs(target_dir, exist_ok=True)
            for file in filenames:
                self._rename_one(os.path.join(dirpath, file),
                                 os.path.join(target_dir, file))

    def _rename_one(self, src_file, dst_file):
        """Write the renamed contents of ``src_file`` to ``dst_file``."""
        print('processing: %s' % src_file)
        try:
            content = self.rename_file(src_file)
        except UnicodeDecodeError:
            # not text: keep the file byte-for-byte
            shutil.copyfile(src_file, dst_file)
            return
        # Render before opening the destination: opening it with 'w' first
        # truncated the existing copy, so any failure left an empty file.
        with open(dst_file, 'w', encoding='utf-8', newline='') as f:
            f.write(content)

    def rename(self):
        """Copy the checkout to ``self.target`` and rewrite the configured
        directories.

        If the target directory already exists the user is asked whether to
        replace it; any answer not starting with ``y`` (including EOF on
        stdin) returns without changing anything.
        """
        if os.path.isdir(self.target):
            # input() flushes the prompt before blocking, unlike
            # print(end='') followed by sys.stdin.read(1)
            try:
                answer = input("Warning: target dir exist\nrewrite?[y/n]:")
            except EOFError:
                answer = ''
            if not answer.startswith('y'):
                return
            shutil.rmtree(self.target)
        shutil.copytree(self.root, self.target)
        # libraries
        for each in self.c_src_dirs:
            self.rename_dir(os.path.join(self.static_src, each))
        # torch csrc, tools and tests
        for each in ('torch/csrc', 'tools', 'test'):
            self.rename_dir(each)
        # copy build script
        # NOTE(review): copytree above already copied this file and no
        # rename_dir call covers static_src itself; presumably redundant.
        shutil.copyfile(
            os.path.join(self.root, self.static_src, 'build_libs.sh'),
            os.path.join(self.target, self.static_src, 'build_libs.sh')
        )
if __name__ == '__main__':
    import argparse

    # Defaults reproduce the previous hard-coded invocation:
    # ./pytorch -> ./complex with `real` renamed to `ntype`.
    parser = argparse.ArgumentParser(
        description="Convert the generic type name `real` in a PyTorch "
                    "checkout to another type name.")
    parser.add_argument('src', nargs='?', default='pytorch',
                        help="source checkout directory (default: pytorch)")
    parser.add_argument('target', nargs='?', default='complex',
                        help="output directory (default: complex)")
    # fail before copying anything if Token would reject the name
    parser.add_argument('--tname', default='ntype', type=str.lower,
                        choices=Token.typenames,
                        help="replacement type name (default: ntype)")
    args = parser.parse_args()

    renamer = THComplexRename(args.src, args.target, tname=args.tname)
    renamer.rename()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment