# Parse flint
#
# Saved from GitHub gist edgarcosta/888fcf83565871afd635f641341cc14e.
# GitHub note: this file contains bidirectional Unicode text that may be
# interpreted or compiled differently than what appears below. To review it,
# open the file in an editor that reveals hidden Unicode characters.
import sys, os, subprocess | |
def run_preparser(filename):
    """Configure FLINT and write its preprocessed public headers to *filename*.

    Must be called from the root of a FLINT checkout: it runs ``./configure``,
    edits ``src/flint.h`` and calls ``make`` with relative paths.

    Steps:
      1. run ``./configure`` (its stdout is discarded);
      2. rewrite a few ``#define`` type aliases in ``src/flint.h`` as
         ``typedef`` declarations (this edits the file in place);
      3. query ``make`` for HEADERS (minus ``src/longlong_*.h``) and CFLAGS;
      4. run ``gcc -E`` over those headers, writing the result to *filename*.

    Raises:
        subprocess.CalledProcessError: if configure, make or gcc fails.
    """
    subprocess.check_call(["./configure"], stdout=subprocess.DEVNULL)

    # Replace some macros with typedefs.  Done in Python instead of `sed -i`,
    # whose in-place flag is not portable (BSD/macOS sed requires a
    # backup-suffix argument).  Bytes are used so the file is otherwise
    # preserved exactly, and it is only rewritten when something changed.
    flint_h = os.path.join('src', 'flint.h')
    with open(flint_h, 'rb') as fh:
        original = fh.read()
    patched = original
    for alias, target in [
        ('ulong', 'mp_limb_t'),
        ('ulong', 'mp_limb_signed_t'),
        ('slong', 'mp_limb_signed_t'),
        ('flint_bitcnt_t', 'ulong'),
    ]:
        patched = patched.replace(
            f'#define {alias} {target}'.encode(),
            f'typedef {target} {alias};'.encode(),
        )
    if patched != original:
        with open(flint_h, 'wb') as fh:
            fh.write(patched)

    def make_variable(name):
        # `make print-NAME` prints "NAME=value".  Split on the first '=' only
        # (values such as CFLAGS contain '=', e.g. -std=c11) and on any
        # whitespace run, so no empty-string arguments reach gcc.
        out = subprocess.check_output(['make', f'print-{name}']).decode('utf-8')
        return out.split('=', 1)[1].split()

    # All the headers except longlong_*.h.
    headers = [h for h in make_variable('HEADERS') if not h.startswith('src/longlong_')]
    # Define these macros to *nothing* ("-Dname=").  A bare "-Dname" would
    # define them as 1, turning e.g. `__asm("x")` into a stray `1`.  The last
    # three look like include guards, so defining them presumably skips those
    # headers (TODO confirm).
    defines = [
        f'-D{d}='
        for d in ['__asm(...)', '__asm__(...)', '__attribute__(x)',
                  '_FILE_DEFINED', '_TIMEVAL_H', 'FLINT_NTL_INT_H']
    ]
    cflags = make_variable('CFLAGS')
    # Call the preprocessor.  NOTE(review): assumes a pycparser-style
    # fake_libc_include directory next to the checkout — confirm.
    with open(filename, 'w') as out:
        subprocess.check_call(
            ['gcc', *cflags, *defines, '-I../fake_libc_include', '-E', *headers],
            stdout=out,
        )
import re | |
def cstatements(s, strip=True):
    """Split C source text *s* into top-level statements.

    A statement ends either at a ``;`` outside any braces, or at a ``}``
    that closes a top-level brace block whose accumulated text contains the
    substring ``'inline'`` (an inline function body has no trailing ``;``).
    Any text left over after the last boundary is returned as a final element.

    When *strip* is true, every whitespace run in a statement is collapsed
    to one space and the leading space (if any) is removed; trailing
    whitespace is kept.
    """
    pieces = []
    buf = []
    level = 0  # current curly-brace nesting depth

    for ch in s:
        buf.append(ch)
        if ch == '{':
            level += 1
        elif ch == '}':
            level -= 1
            if level == 0:
                text = ''.join(buf)
                if 'inline' in text:
                    pieces.append(text)
                    buf = []
        elif ch == ';' and level == 0:
            pieces.append(''.join(buf))
            buf = []

    if buf:
        pieces.append(''.join(buf))  # trailing, unterminated part

    if not strip:
        return pieces
    # After collapsing, a leading whitespace run is exactly one ' '.
    return [re.sub(r'\s+', ' ', piece).lstrip(' ') for piece in pieces]
from pycparser import c_parser, c_ast, c_generator | |
class FuncDefVisitor(c_ast.NodeVisitor):
    """Collect the function declarations found in a pycparser AST.

    After ``visit(ast)``, ``self.functions`` maps each function name to a
    tuple ``(return_type, args, node)``:

    * ``return_type`` is a compact type string: the base type name followed
      by one ``*`` per pointer level (e.g. ``'fmpz*'``); struct/union/enum
      types are represented by their tag only; qualifiers are not included;
    * ``args`` is a list of ``(arg_name, arg_type)`` pairs (see
      ``extract_args``);
    * ``node`` is the ``c_ast.FuncDecl`` itself.
    """

    def __init__(self):
        self.functions = {}

    def visit_FuncDecl(self, node):
        # No generic_visit() here, so FuncDecl nodes nested inside this one
        # (e.g. function-pointer parameters) are not collected as functions.
        func_name, ret_type = self.extract_decl(node.type)
        args = self.extract_args(node.args)
        # A later declaration with the same name replaces an earlier one.
        self.functions[func_name] = (ret_type, args, node)

    def extract_decl(self, node, ptr_counter=0):
        """Return ``(name, type_string)`` for a declarator node.

        ``ptr_counter`` counts the pointer levels already seen above *node*.
        Unsupported declarator kinds yield ``('Unknown', 'Unknown')``.
        """
        if isinstance(node, c_ast.PtrDecl):
            return self.extract_decl(node.type, ptr_counter=ptr_counter + 1)

        if isinstance(node, c_ast.TypeDecl):
            base = node.type
            if isinstance(base, (c_ast.Struct, c_ast.Union, c_ast.Enum)):
                # Tagged types keep only the tag; anonymous ones (name None)
                # fall back to the keyword instead of crashing on None + str.
                type_name = base.name or type(base).__name__.lower()
            else:
                type_name = ' '.join(base.names)
            return node.declname, type_name + '*' * ptr_counter

        if isinstance(node, c_ast.ArrayDecl):
            # `T x[]`: element type plus "[]"; pointer levels applied outside
            # the array (pointer-to-array) are appended after the brackets.
            decl_name, elem_type = self.extract_decl(node.type)
            return decl_name, elem_type + '[]' + '*' * ptr_counter

        if isinstance(node, c_ast.FuncDecl):
            # Function(-pointer) declarator: the name is rendered as
            # "(<stars> name)(argtypes)" and the type is the return type.
            inner_name, ret_type = self.extract_decl(node.type, ptr_counter=0)
            arg_types = ", ".join(arg_type for _, arg_type in self.extract_args(node.args))
            ptr = '*' * ptr_counter + ' ' if ptr_counter > 0 else ''
            return f'({ptr}{inner_name})({arg_types})', ret_type

        return 'Unknown', 'Unknown'

    def extract_args(self, args):
        """Return ``[(arg_name, arg_type), ...]`` for a ParamList.

        A missing parameter list (``None``) gives ``[]``; a variadic ``...``
        parameter is reported as ``("...", "...")``.
        """
        if not args:
            return []
        return [
            ("...", "...") if isinstance(param, c_ast.EllipsisParam)
            else self.extract_decl(param.type)
            for param in args.params
        ]
from functools import cached_property | |
import os | |
class ParseFlintFunctions:
    """Compare function declarations in FLINT's headers with its documentation.

    Parameters:
        postparsed_filename: path of the ``gcc -E`` output (as written by
            ``run_preparser``).
        doc_directory: FLINT's ``doc`` directory; every ``*.rst`` file in its
            ``source`` subdirectory is scanned for function directives.

    All derived data are ``cached_property`` values, computed on first access.
    """

    # Start of an rst function directive: ".. function::" or ".. c:function::".
    _FUNCTION_DIRECTIVE = re.compile(r"\.\.( )+(c:)?function( )*::")

    def __init__(self, postparsed_filename, doc_directory):
        self.postparsed_filename = postparsed_filename
        self.doc_directory = doc_directory

    @cached_property
    def headers_text(self):
        """C text handed to pycparser, one normalized statement per line.

        The preprocessed file is split on '#' (the line markers emitted by the
        preprocessor).  The first line of each chunk is the marker; the rest is
        split into statements with ``cstatements``.  For chunks whose marker
        does not mention "src" (i.e. not a FLINT source header), only the
        statements containing "typedef" are kept.
        """
        with open(self.postparsed_filename, encoding='utf-8') as fh:
            text = fh.read()
        blocks = []
        for chunk in text.split("#"):
            lines = [line for line in chunk.split('\n') if line]
            if len(lines) <= 1:
                continue
            statements = cstatements("\n".join(lines[1:]))
            if "src" not in lines[0]:
                # remove non typedefs
                statements = [x for x in statements if 'typedef' in x]
            blocks.append("\n".join(statements))
        return "\n".join(blocks)

    @cached_property
    def parsed_source(self):
        """pycparser AST of ``headers_text`` with top-level typedefs removed."""
        ast = c_parser.CParser().parse(self.headers_text)
        # We do not want to record functions declared inside typedefs.
        ast.ext = [node for node in ast.ext if not isinstance(node, c_ast.Typedef)]
        return ast

    @cached_property
    def source_functions(self):
        """Functions declared in the headers, as collected by FuncDefVisitor."""
        visitor = FuncDefVisitor()
        visitor.visit(self.parsed_source)
        return visitor.functions

    @classmethod
    def _rst_function_signatures(cls, path):
        """Return the signatures listed by function directives in one rst file.

        The first signature follows the '::' on the directive line; each
        following non-blank line is another signature, up to the first blank line.
        """
        signatures = []
        in_directive = False
        with open(path, encoding='utf-8') as fh:
            for line in fh:
                match = cls._FUNCTION_DIRECTIVE.match(line)
                if match:
                    signatures.append(line[match.end():].strip())
                    in_directive = True
                elif in_directive:
                    if line.strip():
                        signatures.append(line.strip())
                    else:
                        in_directive = False
        return signatures

    @cached_property
    def documentation_functions_statements(self):
        """Sorted list of every signature found in ``<doc_directory>/source/*.rst``."""
        source_dir = os.path.join(self.doc_directory, 'source')
        rst_paths = [
            os.path.join(source_dir, name)
            for name in os.listdir(source_dir)
            if name.endswith('.rst')
        ]
        # Flatten per-file lists with a generator (sum(lists, []) is quadratic).
        return sorted(sig for path in rst_paths for sig in self._rst_function_signatures(path))

    @cached_property
    def parsed_documentation(self):
        """AST containing only the declarations written in the documentation.

        The documented signatures are parsed after ``headers_text`` so the
        header typedef names are in scope; the header part is then sliced off.
        """
        n_header_nodes = len(self.parsed_source.ext)
        doc_fcns = ";\n".join(self.documentation_functions_statements) + ';'
        ast = c_parser.CParser().parse(self.headers_text + "\n" + doc_fcns)
        # Same typedef filtering as parsed_source, so the header prefix yields
        # exactly n_header_nodes nodes and the remainder comes from the docs.
        ast.ext = [node for node in ast.ext if not isinstance(node, c_ast.Typedef)]
        ast.ext = ast.ext[n_header_nodes:]
        return ast

    @cached_property
    def documentation_functions(self):
        """Functions declared in the documentation, as collected by FuncDefVisitor."""
        visitor = FuncDefVisitor()
        visitor.visit(self.parsed_documentation)
        return visitor.functions

    @cached_property
    def missing_functions_in_documentation(self):
        """Sorted names declared in the headers but not in the documentation."""
        return sorted(set(self.source_functions).difference(self.documentation_functions))

    @cached_property
    def missing_functions_in_source(self):
        """Sorted names declared in the documentation but not in the headers."""
        return sorted(set(self.documentation_functions).difference(self.source_functions))

    @cached_property
    def mismatch_statement(self):
        """Functions whose header and documented declarations differ.

        Maps name -> ``(source_code, documentation_code)``, both regenerated
        from the AST with ``c_generator.CGenerator``.  Only names present on
        both sides whose regenerated code differs are included, in sorted
        name order so downstream output is deterministic.
        """
        generator = c_generator.CGenerator()
        mismatches = {}
        for name in sorted(set(self.source_functions).intersection(self.documentation_functions)):
            pair = (
                generator.visit(self.source_functions[name][2]),
                generator.visit(self.documentation_functions[name][2]),
            )
            if pair[0] != pair[1]:
                mismatches[name] = pair
        return mismatches
def _timed(label, func):
    """Call ``func()``, print its wall-clock duration, and return its result.

    Plain-Python stand-in for IPython's ``%time`` so the file runs as a script.
    """
    import time
    start = time.perf_counter()
    result = func()
    print(f"{label}: wall time {time.perf_counter() - start:.3f} s")
    return result


def _sed_literal_pattern(text):
    """Escape *text* so a sed BRE with '/' as delimiter matches it literally.

    Pointer declarations contain '*' and array parameters '[', both of which
    are special in a BRE; unescaped, the generated substitution would not
    match (or would be rejected by sed).
    """
    return re.sub(r'([\\/.*\[^$])', r'\\\1', text)


def _sed_literal_replacement(text):
    """Escape *text* for verbatim use as a sed replacement with '/' delimiter."""
    return re.sub(r'([\\/&])', r'\\\1', text)


def main(flint_dir=None, output="fix_noninline_source.txt", sed="gsed"):
    """Compare FLINT header declarations with the docs and emit fix commands.

    Changes the working directory to *flint_dir* (default ``~/flint``),
    preprocesses the headers, and writes to *output* one shell command per
    function whose header and documented declarations differ only in argument
    names.  Each command rewrites the header declaration with the documented
    names using *sed* (``gsed`` by default, i.e. GNU sed).

    Returns ``(plf, mismatch_quals)``: the ParseFlintFunctions instance and
    the set of functions skipped because their parameter lists or parameter
    qualifiers differ between source and documentation.
    """
    import shlex

    if flint_dir is None:
        flint_dir = os.path.join(os.path.expanduser('~'), 'flint')
    os.chdir(flint_dir)
    _timed('run_preparser', lambda: run_preparser('postparsed'))
    plf = ParseFlintFunctions('postparsed', 'doc')
    print(_timed('source_functions', lambda: len(plf.source_functions)))
    print(_timed('documentation_functions', lambda: len(plf.documentation_functions)))
    print(_timed('mismatch_statement', lambda: len(plf.mismatch_statement)))

    mismatch_quals = set()
    with open(output, "w") as out:
        for func_name, (src_code, doc_code) in plf.mismatch_statement.items():
            src_entry = plf.source_functions[func_name]
            doc_entry = plf.documentation_functions[func_name]
            src_params, doc_params = src_entry[2].args, doc_entry[2].args

            # One side has a parameter list and the other does not.
            if type(src_params) is not type(doc_params):
                mismatch_quals.add(func_name)
                continue
            if src_params:
                quals = [
                    [p.quals for p in params.params if not isinstance(p, c_ast.EllipsisParam)]
                    for params in (src_params, doc_params)
                ]
                if quals[0] != quals[1]:
                    mismatch_quals.add(func_name)
                    continue

            (src_ret, src_args, _), (doc_ret, doc_args, _) = src_entry, doc_entry
            names_differ = [n for n, _ in src_args] != [n for n, _ in doc_args]
            same_types = [t for _, t in src_args] == [t for _, t in doc_args]
            if not (names_differ and same_types and src_ret == doc_ret):
                continue

            # Remove the return type (first word, plus any '*' after it).
            s_nor, d_nor = [code.split(' ', 1)[1].lstrip('*') for code in (src_code, doc_code)]
            if s_nor + ';' not in plf.headers_text:
                continue
            grep_arg = shlex.quote(f'{s_nor};')
            sed_arg = shlex.quote(
                f's/{_sed_literal_pattern(s_nor)};/{_sed_literal_replacement(d_nor)};/g'
            )
            out.write(f"grep -rlF {grep_arg} src | grep 'h$' | xargs -n1 {sed} -i {sed_arg}\n")

    print(f"{len(mismatch_quals)} functions skipped (parameter list/qualifier mismatch)")
    return plf, mismatch_quals


if __name__ == "__main__":
    PLF, mismatch_quals = main()
# (GitHub gist page footer: sign up for free, or sign in, to comment on this gist.)