Skip to content

Instantly share code, notes, and snippets.

@harupy
Created August 18, 2023 09:54
Show Gist options
  • Save harupy/7483a548d2033b74e89c3c64007ce98c to your computer and use it in GitHub Desktop.
Save harupy/7483a548d2033b74e89c3c64007ce98c to your computer and use it in GitHub Desktop.
from __future__ import annotations
import ast
import os
import random
import subprocess
import textwrap
import openai
class DocstringVisitor(ast.NodeVisitor):
    """Collect the docstring nodes of public functions documented with ``:param``.

    After ``visit(tree)`` the matching string-literal nodes are available in
    ``docstring_nodes`` in the order they were visited.

    Only ``ast.FunctionDef`` nodes are handled, and ``generic_visit`` is not
    called for them, so definitions nested inside a function body are not
    inspected. Methods are still reached because class bodies are traversed
    by the default visitor.
    """

    def __init__(self):
        # String-constant AST nodes of every docstring selected for rewriting.
        self.docstring_nodes = []

    def visit_FunctionDef(self, node: ast.FunctionDef):
        """Record the docstring of ``node`` if it qualifies for rewriting.

        A function qualifies when its name does not start with ``_`` and its
        first statement is a string literal containing ``:param``.

        Args:
            node: The function definition being visited.
        """
        if not node.body or node.name.startswith("_"):
            return

        first_statement = node.body[0]
        if not isinstance(first_statement, ast.Expr):
            return

        literal = first_statement.value
        # ``ast.Str`` was removed in Python 3.12; string literals are
        # ``ast.Constant`` nodes. The ``str`` check also rejects bytes and
        # numeric constants, on which the ``in`` test would fail or mislead.
        if (
            isinstance(literal, ast.Constant)
            and isinstance(literal.value, str)
            and ":param" in literal.value
        ):
            self.docstring_nodes.append(literal)
def transform(docstring: str) -> str:
    """Ask GPT-4 to rewrite a ``:param``-style docstring into an ``Args:``/``Returns:`` layout.

    The docstring is interpolated into a fixed prompt that shows a before/after
    example and a list of transformation rules, and is sent as a single user
    message through ``openai.ChatCompletion.create``.

    Args:
        docstring: Source text of the docstring to rewrite, as it should be
            shown to the model.

    Returns:
        The message content of the first completion choice, unmodified. The
        prompt asks for the new docstring inside a python-fenced block, so the
        caller still has to extract it.
    """
    # NOTE(review): the before/after examples inside the prompt show no
    # indentation under "Args:"/"Returns:" here -- confirm the prompt text
    # matches what was intended.
    # ``{{new_docstring}}`` is an escaped brace pair, so the model sees the
    # literal placeholder ``{new_docstring}``; only ``{docstring}`` is filled in.
    res = openai.ChatCompletion.create(
        model="gpt-4",
        messages=[
            {
                "role": "user",
                "content": f"""
Hi GPT4, I'd like you to rewrite python docstrings in a more readable format. Here's an example:
# Before
```python
\"\"\"
This is a docstring
:param artifact_path: The run-relative path to which to log model artifacts.
:param custom_objects: A Keras ``custom_objects`` dictionary mapping names (strings) to
custom classes or functions associated with the Keras model. MLflow saves
...
:return: This is a return value.
...a
\"\"\"
```
# After (similar to google docstrings, but no need to add types)
```python
\"\"\"
This is a docstring
Args:
artifact_path: The run-relative path to which to log model artifacts.
custom_objects: A Keras ``custom_objects`` dictionary mapping names (strings)
to custom classes or functions associated with the Keras model. MLflow saves
...
Returns:
This is a return value.
...
\"\"\"
```
# Transformation Rules:
- Be sure to prserve the indentation of the original docstring.
- Be sure to preserve the quotes of the original docstring.
- Be sure to avoid the line length exceeding 100 characters.
- Be sure to only update the parameters and returns sections.
- The Returns section should is optional. If the original docstring doesn't have
':return:' or ':returns:' entries, then don't add a 'Returns' section.
- Be sure to use the following format for the new docstring:
```python
{{new_docstring}}
```
Given these rules, can you rewrite the following docstring? Thanks for your help!
```python
{docstring}
```
""",
            }
        ],
    )
    # Only the first choice is used.
    return res.choices[0].message.content
def node_to_char_range(docstring_node: ast.Constant, line_lengths: list[int]) -> tuple[int, int]:
    """Convert a docstring node's line/column position into offsets in the source text.

    Args:
        docstring_node: AST node of the docstring literal. Its ``lineno``,
            ``col_offset``, ``end_lineno`` and ``end_col_offset`` are read.
        line_lengths: Length of every source line including its line ending,
            indexed from the first line of the file.

    Returns:
        A ``(start, end)`` pair such that ``source[start:end]`` is the literal,
        opening and closing quotes included.

    Note:
        ``ast`` reports column offsets as UTF-8 byte offsets. They are used here
        as character offsets, which is exact only when the text before the
        literal on its first line, and the literal's last line, are ASCII.
    """
    # Offset of a position = total length of the lines before it + its column.
    start = sum(line_lengths[: docstring_node.lineno - 1]) + docstring_node.col_offset
    # Deriving the end from the end position directly (rather than adding the
    # first line's remainder) is also correct when the literal starts and ends
    # on the same line.
    end = sum(line_lengths[: docstring_node.end_lineno - 1]) + docstring_node.end_col_offset
    return start, end
def extract_code(s: str) -> str | None:
import re
if m := re.search(r"```python\n(.*)```", s, re.DOTALL):
return m.group(1)
return None
def format_code(code: str, indent: str, opening_quote: str, closing_quote: str) -> str:
    """Normalize a rewritten docstring and wrap it in the original quotes.

    The quotes found in ``code`` are removed, the text is dedented and then
    re-indented with ``indent``, and the result is wrapped so that the text
    starts on the line after ``opening_quote`` and ``closing_quote`` sits on
    its own indented line.

    Args:
        code: The rewritten docstring, normally still wrapped in triple quotes.
        indent: Whitespace to put in front of every non-blank docstring line
            and in front of the closing quote.
        opening_quote: Opening delimiter (with any string prefix) to emit.
        closing_quote: Closing delimiter to emit.

    Returns:
        The formatted literal. Its first line is not indented because it
        replaces text that starts at the opening quote.
    """
    body = code.strip()

    # Remove exactly one delimiter at each end. Stripping by character set
    # would also eat a leading "r" or a trailing '"' that belongs to the text.
    prefix_len = len(body) - len(body.lstrip("rRfFuUbB"))
    for delimiter in ('"""', "'''"):
        if body.startswith(delimiter, prefix_len):
            body = body[prefix_len + len(delimiter) :]
            if body.endswith(delimiter):
                body = body[: -len(delimiter)]
            break

    lines = body.rstrip().split("\n")
    if lines[0].strip():
        # Text begins on the opening-quote line and therefore carries no
        # indentation; dedent the remaining lines on their own so the shared
        # margin is not computed as empty.
        summary = lines[0].strip()
        remainder = textwrap.dedent("\n".join(lines[1:]))
        text = f"{summary}\n{remainder}" if len(lines) > 1 else summary
    else:
        # Skip blank lines after the opening quote but keep the indentation of
        # the first content line, which dedent needs to find the margin.
        first_content = next(
            (idx for idx, line in enumerate(lines) if line.strip()), len(lines)
        )
        text = textwrap.dedent("\n".join(lines[first_content:]))

    text = textwrap.indent(text, indent)
    return f"{opening_quote}\n{text}\n{indent}{closing_quote}"
def leading_quote(s: str) -> str:
    """Return the opening delimiter of a string literal, including its prefix.

    Args:
        s: Source text of a string literal, e.g. ``r\"\"\"doc\"\"\"``.

    Returns:
        The string prefix letters (if any) followed by the opening quote:
        a triple quote when present, otherwise a single quote character.

    Raises:
        ValueError: If ``s`` does not start with a (optionally prefixed) quote.
    """
    # Prefix letters are matched case-insensitively.
    prefix_len = len(s) - len(s.lstrip("rRfFuUbB"))
    # Triple quotes are tried first, and the scan stops at the delimiter so
    # that content starting with "r", "f" or a quote character is not absorbed.
    for delimiter in ('"""', "'''", '"', "'"):
        if s.startswith(delimiter, prefix_len):
            return s[: prefix_len + len(delimiter)]
    raise ValueError("No leading quote found")
def trailing_quote(s: str) -> str:
    """Return the closing delimiter of a string literal.

    Args:
        s: Source text of a string literal, e.g. ``\"\"\"doc\"\"\"``.

    Returns:
        The closing quote: a triple quote when ``s`` ends with one, otherwise
        the single quote character it ends with.

    Raises:
        ValueError: If ``s`` does not end with a quote character.
    """
    # Triple quotes are tried first. Matching a fixed delimiter, rather than
    # scanning back over every quote character, keeps a quote that ends the
    # content (e.g. '''say "hi"''') out of the result.
    for delimiter in ('"""', "'''", '"', "'"):
        if s.endswith(delimiter):
            return delimiter
    raise ValueError("No trailing quote found")
def main():
    """Rewrite ``:param``-style docstrings in the tracked ``mlflow/*.py`` files in place.

    For every file listed by ``git ls-files`` (processed in random order), the
    matching docstrings are collected, each one is sent through ``transform``,
    and the reformatted result is spliced back into the file's source text.
    Docstrings for which no fenced code block can be extracted from the
    response are left unchanged.

    Raises:
        RuntimeError: If the ``OPENAI_API_KEY`` environment variable is not set.
    """
    # Checked explicitly because an ``assert`` is stripped under ``python -O``.
    if "OPENAI_API_KEY" not in os.environ:
        raise RuntimeError("The OPENAI_API_KEY environment variable must be set")

    py_files = subprocess.check_output(["git", "ls-files", "mlflow/*.py"]).decode().splitlines()
    random.shuffle(py_files)
    for py_file in py_files:
        # An explicit encoding keeps the read independent of the platform's
        # default locale encoding.
        with open(py_file, encoding="utf-8") as f:
            src = f.read()

        visitor = DocstringVisitor()
        visitor.visit(ast.parse(src))
        if not visitor.docstring_nodes:
            continue

        # Text-mode reading has already normalized line endings to "\n", which
        # is what ``ast`` line numbers count. ``str.splitlines`` would also
        # split on form feeds and similar characters and shift every offset
        # after them. The "+ 1" accounts for the newline removed by ``split``.
        line_lengths = [len(line) + 1 for line in src.split("\n")]
        new_src = src
        # Net growth of ``new_src`` from earlier replacements; positions from
        # the original source must be shifted by it. This relies on the nodes
        # being visited in source order.
        offset = 0
        for node in visitor.docstring_nodes:
            print(f"Transforming {py_file}:{node.lineno}:{node.col_offset + 1}")
            start, end = node_to_char_range(node, line_lengths)
            indent = " " * node.col_offset
            original = src[start:end]
            # The indentation is prepended so the model sees the first line
            # aligned with the rest of the docstring.
            transformed = transform(indent + original)
            code = extract_code(transformed)
            if code is None:
                continue
            code = format_code(
                code,
                indent,
                leading_quote(original),
                trailing_quote(original),
            )
            new_src = new_src[: start + offset] + code + new_src[end + offset :]
            offset += len(code) - (end - start)

        # Leave the file untouched when no docstring was actually replaced.
        if new_src != src:
            with open(py_file, "w", encoding="utf-8") as f:
                f.write(new_src)
# Run the rewrite only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment