Skip to content

Instantly share code, notes, and snippets.

@edgarcosta
Created June 28, 2024 18:45
Show Gist options
  • Save edgarcosta/888fcf83565871afd635f641341cc14e to your computer and use it in GitHub Desktop.
Parse flint
import sys, os, subprocess
def run_preparser(filename):
    """Configure FLINT and write its preprocessed public headers to *filename*.

    Must be run from the root of a FLINT source checkout. It runs
    ``./configure``, rewrites a few type macros in ``src/flint.h`` in place,
    asks ``make`` for the header list and compiler flags, and then writes the
    output of ``gcc -E`` over every header except ``src/longlong_*.h``.

    :param filename: path of the file that receives the preprocessed text.
    :raises subprocess.CalledProcessError: if any external command fails.
    """
    subprocess.check_call(["./configure"], stdout=subprocess.DEVNULL)

    # Replace some type macros with typedefs so that a C parser sees them as
    # type names. A byte-level replace is used instead of `sed -i`, because
    # the in-place flag differs between GNU and BSD sed. The patterns contain
    # no regex metacharacters, so this matches what `s/.../.../g` did.
    flint_h = os.path.join('src', 'flint.h')
    with open(flint_h, 'rb') as fh:
        original = fh.read()
    text = original
    for alias, target in [
        ('ulong', 'mp_limb_t'),
        ('ulong', 'mp_limb_signed_t'),
        ('slong', 'mp_limb_signed_t'),
        ('flint_bitcnt_t', 'ulong'),
    ]:
        text = text.replace(f'#define {alias} {target}'.encode(),
                            f'typedef {target} {alias};'.encode())
    if text != original:
        with open(flint_h, 'wb') as fh:
            fh.write(text)

    # All the headers except longlong_*.h.
    headers_output = subprocess.check_output(['make', 'print-HEADERS']).decode('utf-8')
    headers = [h for h in headers_output.split('=', 1)[1].split()
               if not h.startswith('src/longlong_')]

    # Neutralise compiler extensions. A function-like macro given to `-D`
    # without "=" is defined as 1, so the trailing "=" makes it expand to
    # nothing. (The previous spelling "__atribute__" never matched anything.)
    defines = [f'-D{d}=' for d in ['__asm(...)', '__asm__(...)', '__attribute__(x)']]
    # Plain object-like macros (defined as 1), used as header/feature guards.
    defines += [f'-D{d}' for d in ['_FILE_DEFINED', '_TIMEVAL_H', 'FLINT_NTL_INT_H']]

    cflags_output = subprocess.check_output(["make", "print-CFLAGS"]).decode('utf-8')
    # split() rather than split(' ') so that repeated spaces do not produce
    # empty arguments, which gcc would treat as input file names.
    cflags = cflags_output.split('=', 1)[1].split()

    # Call the preprocessor.
    with open(filename, 'w') as out:
        subprocess.check_call(
            ['gcc'] + cflags + defines + ['-I../fake_libc_include', '-E'] + headers,
            stdout=out,
        )
import re
def cstatements(s, strip=True):
    """Split C source text into top-level statements.

    A statement ends at a ``;`` that is outside every pair of curly braces.
    It also ends at a ``}`` that brings the brace depth back to zero when the
    text collected so far contains the substring ``inline``; this lets inline
    function bodies, which have no trailing ``;``, form their own statement.
    Any text left after the last terminator becomes a final element.

    :param s: C source text.
    :param strip: when true, each statement has every whitespace run collapsed
        to one space and its leading whitespace removed.
    :returns: list of statement strings.
    """
    pieces = []
    buf = []
    level = 0  # current nesting depth of curly braces
    for ch in s:
        buf.append(ch)
        if ch == '{':
            level += 1
            continue
        if ch == '}':
            level -= 1
            ends = level == 0 and 'inline' in ''.join(buf)
        else:
            ends = ch == ';' and level == 0
        if ends:
            pieces.append(''.join(buf))
            buf = []
    if buf:
        pieces.append(''.join(buf))  # trailing unterminated text
    if not strip:
        return pieces
    # After collapsing, a statement can start with at most one space.
    return [re.sub(r'\s+', ' ', piece).lstrip(' ') for piece in pieces]
from pycparser import c_parser, c_ast, c_generator
class FuncDefVisitor(c_ast.NodeVisitor):
    """Collect the function declarations found in a pycparser AST.

    After ``visit(ast)``, ``self.functions`` maps each function name to a tuple
    ``(return_type, args, node)``:

    * ``return_type`` -- simplified type string (see :meth:`extract_decl`);
    * ``args`` -- list of ``(name, type_string)`` pairs (see :meth:`extract_args`);
    * ``node`` -- the ``c_ast.FuncDecl`` node itself.

    A later declaration with the same name overwrites an earlier one.

    NOTE(review): every ``FuncDecl`` reached by the default traversal is
    recorded. That includes function-pointer declarators such as global
    function-pointer variables or struct members, not only real prototypes.
    """

    def __init__(self):
        self.functions = {}

    def visit_FuncDecl(self, node):
        """Record one function declaration.

        ``generic_visit`` is not called, so the node's children (for example
        function-pointer parameters) are not recorded as functions.
        """
        func_name, ret_type = self.extract_decl(node.type)
        args = self.extract_args(node.args)
        self.functions[func_name] = (ret_type, args, node)

    def extract_decl(self, node, ptr_counter=0):
        """Return ``(name, type_string)`` for a declarator node.

        The type string is simplified:

        * qualifiers such as ``const`` are dropped;
        * struct, union and enum types are reduced to their tag name;
        * each level of pointer indirection appends one ``*``.

        A function declarator (function pointer) gives a name of the form
        ``(* name)(arg types)``. Any other declarator kind (for example an
        array) gives ``('Unknown', 'Unknown')``.

        :param node: a pycparser declarator node.
        :param ptr_counter: number of enclosing pointer levels seen so far.
        """
        if isinstance(node, c_ast.PtrDecl):
            return self.extract_decl(node.type, ptr_counter=ptr_counter + 1)
        if isinstance(node, c_ast.TypeDecl):
            # Struct/Union/Enum nodes carry a tag `name`; IdentifierType nodes
            # carry a list of `names` (e.g. ['unsigned', 'long']).
            if isinstance(node.type, (c_ast.Struct, c_ast.Union, c_ast.Enum)):
                ret_type = node.type.name
            else:
                ret_type = ' '.join(node.type.names)
            return node.declname, ret_type + '*' * ptr_counter
        if isinstance(node, c_ast.FuncDecl):
            decl_name, ret_type = self.extract_decl(node.type, ptr_counter=0)
            types = ", ".join(arg_type for _, arg_type in self.extract_args(node.args))
            ptr = '*' * ptr_counter + ' ' if ptr_counter > 0 else ''
            return f'({ptr}{decl_name})({types})', ret_type
        return 'Unknown', 'Unknown'

    def extract_args(self, args):
        """Return ``[(name, type_string), ...]`` for a parameter list.

        :param args: a ``c_ast.ParamList`` or ``None``; ``None`` gives ``[]``.
            A ``...`` parameter is reported as ``("...", "...")``.
        """
        extracted_args = []
        if args:
            for param in args.params:
                if isinstance(param, c_ast.EllipsisParam):
                    extracted_args.append(("...", "..."))
                else:
                    extracted_args.append(self.extract_decl(param.type))
        return extracted_args
from functools import cached_property
import os
class ParseFlintFunctions:
    """Compare the functions declared in FLINT's headers with those in its docs.

    Every result is computed lazily on first access and then cached
    (``functools.cached_property``).

    :param postparsed_filename: the ``gcc -E`` output written by
        ``run_preparser``, including its ``# <line> "<file>"`` markers.
    :param doc_directory: FLINT's documentation directory; ``.rst`` files are
        read from its ``source`` subdirectory.
    """

    def __init__(self, postparsed_filename, doc_directory):
        self.postparsed_filename = postparsed_filename
        self.doc_directory = doc_directory

    @cached_property
    def headers_text(self):
        """C text built from the preprocessed headers, ready for pycparser.

        The preprocessed file is split at every ``#``, i.e. at each line
        marker. The first line of each chunk is the marker naming the file the
        chunk came from. Chunks whose marker contains ``src`` keep all their
        top-level statements. All other chunks keep only statements that
        contain ``typedef``.

        NOTE(review): ``"src"`` is a substring test on the marker line, so any
        path containing "src" counts as FLINT source.
        """
        with open(self.postparsed_filename, encoding='utf-8') as fh:
            chunks = fh.read().split("#")
        blocks = []
        for chunk in chunks:
            lines = [line for line in chunk.split('\n') if line]
            if len(lines) <= 1:
                continue
            statements = cstatements("\n".join(lines[1:]))
            if "src" not in lines[0]:
                # Outside FLINT, only the type names are needed.
                statements = [stmt for stmt in statements if 'typedef' in stmt]
            blocks.append("\n".join(statements))
        return "\n".join(blocks)

    @cached_property
    def parsed_source(self):
        """pycparser AST of :attr:`headers_text`, with top-level typedefs removed."""
        parser = c_parser.CParser()
        ast = parser.parse(self.headers_text)
        # Function types inside typedefs are not declarations of functions.
        ast.ext = [elt for elt in ast.ext if not isinstance(elt, c_ast.Typedef)]
        return ast

    @cached_property
    def source_functions(self):
        """``{name: (return_type, args, FuncDecl)}`` for the header declarations."""
        visitor = FuncDefVisitor()
        visitor.visit(self.parsed_source)
        return visitor.functions

    @cached_property
    def documentation_functions_statements(self):
        """Sorted list of every C signature given in a function directive.

        Both ``.. function::`` and ``.. c:function::`` are recognized. Each
        non-blank line right after a directive is taken as another signature
        of that directive, and a blank line ends the list.
        """
        is_func = re.compile(r"\.\.( )+(c:)?function( )*::")

        def get_functions(filename):
            """Return the signatures documented in one .rst file, in file order."""
            signatures = []
            in_list = False
            with open(filename, encoding='utf-8') as fh:
                for line in fh:
                    m = is_func.match(line)
                    if m:
                        signatures.append(line[m.end():].strip())
                        in_list = True
                    elif in_list:
                        if line.strip():
                            signatures.append(line.strip())
                        else:
                            in_list = False
            return signatures

        source_dir = os.path.join(self.doc_directory, 'source')
        return sorted(
            signature
            for name in os.listdir(source_dir)
            if name.endswith('.rst')
            for signature in get_functions(os.path.join(source_dir, name))
        )

    @cached_property
    def parsed_documentation(self):
        """AST holding only the documented signatures.

        The signatures are parsed after :attr:`headers_text` so that FLINT's
        type names are known. Top-level typedefs are then removed, and the
        first ``k`` nodes are dropped, where ``k`` is the number of typedef-free
        header nodes in :attr:`parsed_source`.
        """
        k = len(self.parsed_source.ext)
        parser = c_parser.CParser()
        doc_fcns = ";\n".join(self.documentation_functions_statements) + ';'
        ast = parser.parse(self.headers_text + "\n" + doc_fcns)
        ast.ext = [elt for elt in ast.ext if not isinstance(elt, c_ast.Typedef)]
        # Filtering before slicing keeps the header prefix the same length as
        # the (also filtered) parsed_source.ext.
        ast.ext = ast.ext[k:]
        return ast

    @cached_property
    def documentation_functions(self):
        """``{name: (return_type, args, FuncDecl)}`` for the documented signatures."""
        visitor = FuncDefVisitor()
        visitor.visit(self.parsed_documentation)
        return visitor.functions

    @cached_property
    def missing_functions_in_documentation(self):
        """Sorted names that are declared in the headers but not documented."""
        return sorted(set(self.source_functions).difference(self.documentation_functions))

    @cached_property
    def missing_functions_in_source(self):
        """Sorted names that are documented but not declared in the headers."""
        return sorted(set(self.documentation_functions).difference(self.source_functions))

    @cached_property
    def mismatch_statement(self):
        """Functions whose header and documented declarations differ.

        Maps each name present in both to ``(source_code, doc_code)``: the C
        text that pycparser's ``CGenerator`` produces from each ``FuncDecl``.
        Only names whose two texts differ are included.
        """
        generator = c_generator.CGenerator()
        source, docs = self.source_functions, self.documentation_functions
        mismatches = {}
        for name in set(source).intersection(docs):
            pair = (generator.visit(source[name][2]), generator.visit(docs[name][2]))
            if pair[0] != pair[1]:
                mismatches[name] = pair
        return mismatches
import re
import shlex
import time


def _timed(func, *args):
    """Call ``func(*args)``, print its wall-clock time, and return its result.

    Plain-Python replacement for IPython's ``%time`` magic, so this script also
    runs under the ordinary ``python`` interpreter.
    """
    start = time.perf_counter()
    result = func(*args)
    print(f"Wall time: {time.perf_counter() - start:.2f} s")
    return result


def _sed_pattern_escape(text):
    """Escape *text* so sed matches it literally in ``s/PATTERN/.../``.

    Generated C declarations contain ``*`` (and may contain ``[`` or ``.``),
    which are metacharacters in sed basic regular expressions. ``/`` is the
    delimiter.
    """
    return re.sub(r'([\\/.*\[^$])', r'\\\1', text)


def _sed_replacement_escape(text):
    """Escape *text* so sed inserts it literally in ``s/.../REPLACEMENT/``."""
    return re.sub(r'([\\/&])', r'\\\1', text)


# Driver: regenerate the preprocessed headers inside ~/flint, compare them with
# the documentation, and write a shell script that renames header parameters
# to the documented names.
os.chdir(os.path.expanduser(os.path.join('~', 'flint')))
_timed(run_preparser, 'postparsed')
PLF = ParseFlintFunctions('postparsed', 'doc')
_timed(lambda: print(len(PLF.source_functions)))
_timed(lambda: print(len(PLF.documentation_functions)))
_timed(lambda: print(len(PLF.mismatch_statement)))

# Functions whose parameter lists differ in presence or in qualifiers. These
# need a manual look rather than an automatic rename.
mismatch_quals = set()
with open("fix_noninline_source.txt", "w", encoding="utf-8") as out:
    for func_name, (s_code, d_code) in PLF.mismatch_statement.items():
        s_rtype, s_args, s_node = PLF.source_functions[func_name]
        d_rtype, d_args, d_node = PLF.documentation_functions[func_name]
        if type(s_node.args) is not type(d_node.args):
            mismatch_quals.add(func_name)
            continue
        if s_node.args:
            quals = [
                [p.quals for p in node.args.params if not isinstance(p, c_ast.EllipsisParam)]
                for node in (s_node, d_node)
            ]
            if quals[0] != quals[1]:
                mismatch_quals.add(func_name)
                continue
        # Only automate pure renames: same return type, same parameter types,
        # different parameter names.
        if s_rtype != d_rtype or [t for _, t in s_args] != [t for _, t in d_args]:
            continue
        if [n for n, _ in s_args] == [n for n, _ in d_args]:
            continue
        # Drop the leading return-type token (any remaining prefix is the same
        # on both sides because the return types are equal).
        s_nor, d_nor = (code.split(' ', 1)[1].lstrip('*') for code in (s_code, d_code))
        if s_nor + ';' not in PLF.headers_text:
            continue
        sed_expr = f's/{_sed_pattern_escape(s_nor)};/{_sed_replacement_escape(d_nor)};/g'
        out.write(
            f"grep -rlF {shlex.quote(s_nor + ';')} src | grep 'h$' "
            f"| xargs -n1 gsed -i {shlex.quote(sed_expr)}\n"
        )
print(f"{len(mismatch_quals)} functions need manual review "
      "(parameter list or qualifier mismatch)")
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment