Last active
November 24, 2017 07:21
-
-
Save goldengrape/a3cdb08642bd4d1b1fb81eec211f7f4b to your computer and use it in GitHub Desktop.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
# coding: utf-8 | |
# # 给所有函数修饰 | |
# | |
# Numba是一个python的加速器, 最简单的加速方式仅仅是在导入numba以后, 在函数定义之前增加@jit. | |
# | |
# 本程序是为了测试简单增加或者消除所有函数前@jit | |
# In[1]: | |
import os | |
import re | |
# # 打开文件 | |
# In[2]: | |
input_filename='add_new_JIT.ipynb' | |
output_filename='remove_new_JIT.ipynb' | |
_,ext=os.path.splitext(input_filename) | |
with open(input_filename,'rt') as f_input: | |
f_content=f_input.read() | |
# # 依照模式增加或移除 | |
# 使用了正则表达式```'(\n)(\s*)(def)'```, 因为不能确定def之前的缩进有多少, 只知道肯定之前是由换行的. (当然也有可能什么也不导入, 直接就定义函数的py程序, 但那样也太罕见了了吧) | |
# | |
# 正则表达式还不熟练, 不知道```r'(\n)(\s*)(def)'```找到以后如何用group来拆分. 所以干脆取巧, 反正中间的缩进部分是要重复两遍的, 不妨就先把整体重复两遍, 然后再替换掉其中一个 | |
# | |
# In[3]: | |
def add_pattern(text,prefix,target_word,add_string): | |
target_pattern=re.compile(prefix+target_word) | |
def add_core(m): | |
s=m.group() | |
new=s+'\n'+s | |
return (re.sub(target_word+'\n',add_string,new)) | |
return (target_pattern.sub(add_core,text)) | |
def remove_pattern(text,prefix,target_word): | |
target_pattern=re.compile(prefix+target_word) | |
def remove_core(m): | |
s=m.group() | |
return "" | |
return (target_pattern.sub(remove_core,text)) | |
# # 增加/去除@jit | |
# | |
# * add_jit: 在每一个def之前添加@git | |
# * remove_jit: 将每个单行的@jit去除 | |
# In[4]: | |
def add_jit(text): | |
if ext=='.py': | |
prefix='(\n)(\s*)' | |
add_numba='from numba import jit' | |
add_numba_jit='@jit' | |
elif ext=='.ipynb': | |
prefix='(\n)(\s*)(\")(\s*)' | |
add_numba='from numba import jit", ' | |
add_numba_jit='@jit", ' | |
text = add_pattern(text,prefix,'import numpy as np',add_numba) | |
text = add_pattern(text,prefix,'def',add_numba_jit) | |
text = text.replace('jit"','jit\\n"') #此处用re.sub总是会把\n给翻译掉, 试过多种方式 | |
return text | |
def remove_jit(text): | |
if ext=='.py': | |
prefix='(\n)(\s*)' | |
add_numba='from numba import jit' | |
add_numba_jit='@jit' | |
elif ext=='.ipynb': | |
prefix='(\s*)' | |
add_numba='from numba import jit' | |
add_numba_jit='@jit' | |
text = remove_pattern(text,prefix,add_numba) | |
text = remove_pattern(text,prefix,add_numba_jit) | |
return text | |
jit_added =add_jit(f_content) | |
jit_removed=remove_jit(f_content) | |
# # 写入文件 | |
# In[5]: | |
with open(output_filename,'wt') as f_output: | |
f_output.write(jit_removed) | |
# In[ ]: | |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment