Last active
June 25, 2020 21:17
-
-
Save sgibson91/44a1c3a6bbf34257dbdbb621a98dab0d to your computer and use it in GitHub Desktop.
Use the GitHub API to update the base branch of open Pull Requests
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
"""
Script to pull a list of open Pull Requests for a GitHub repository and change
their base branch using the GitHub API. For this script to run successfully,
you must have write access to the target repository. If the repository belongs
to an organisation, then you must be a member of that organisation with write
access.

This script requires a GitHub Personal Access Token, which it prompts for
interactively. To update Pull Requests in public repos, the token only needs
the `public_repo` scope. To update Pull Requests in private repos, the token
needs the full `repo` scope.

Author: Sarah Gibson (https://github.com/sgibson91)
Python version: >= 3.7 (developed with 3.8)
Package Requirements: requests
    $ pip install requests
"""
import argparse
import getpass
import json
import logging
import sys

import requests

# Setup log config: every record is prefixed with a timestamp and its level
# name, and the root logger lets through everything from DEBUG upwards.
logging.basicConfig(
    datefmt="%Y-%m-%d %H:%M:%S",
    format="[%(asctime)s %(levelname)s] %(message)s",
    level=logging.DEBUG,
)


def parse_args(args):
    """Parse command line arguments

    Args:
        args (list): Command line arguments, excluding the program name
            (e.g. ``sys.argv[1:]``). If None, argparse falls back to
            ``sys.argv[1:]`` itself.

    Returns:
        (Namespace object): The parsed command line arguments, with the
            attributes ``repo_owner``, ``repo_name``, ``old_branch`` and
            ``new_branch``
    """
    parser = argparse.ArgumentParser(
        description="Update the base branch of open Pull Requests in a GitHub repository"
    )

    # Required arguments
    parser.add_argument(
        "repo_owner", type=str, help="The owner of the GitHub repository",
    )
    parser.add_argument(
        "repo_name", type=str, help="The name of the GitHub repository",
    )

    # Optional arguments with defaults
    parser.add_argument(
        "-o",
        "--old-branch",
        type=str,
        default="master",
        help="The old base branch name",
    )
    parser.add_argument(
        "-n", "--new-branch", type=str, default="main", help="The new base branch name",
    )

    # Parse the list we were actually given. Calling parse_args() with no
    # argument would silently ignore `args` and re-read sys.argv instead.
    return parser.parse_args(args)


def construct_api_url(repo_owner, repo_name):
    """Build the base GitHub REST API URL for a repository

    Args:
        repo_owner (str): The account (or organisation) that owns the
            target repository
        repo_name (str): The name of the target repository

    Returns:
        (str): A URL of the form
            ``https://api.github.com/repos/<owner>/<name>/``. The trailing
            slash lets endpoint paths such as ``pulls`` be appended directly.
    """
    api_root = "https://api.github.com/repos"
    return f"{api_root}/{repo_owner}/{repo_name}/"


def get_request(url, header, params=None):
    """Perform a HTTP GET request to a URL with headers and additional parameters.

    Where responses are paginated (return with extra links), all results are
    compiled from all pages. Every page's status is checked. If any request
    fails, an error is logged and the script exits with status 1.

    Args:
        url (str): The URL to make the GET request to
        header (dict): The headers to be sent with the request
        params (dict, optional): Any additional queries to send with the
            request. Defaults to None.

    Returns:
        (dict or list): The JSON content of the response. For paginated
            responses, this is the concatenation of the JSON lists from every
            page.
    """
    resp = _checked_get(url, header, params=params)
    full_resp = resp.json()

    # Follow the "next" Link relations until the last page. Non-paginated
    # responses have no "next" link, so the first JSON body is returned as-is.
    while "next" in resp.links:
        resp = _checked_get(resp.links["next"]["url"], header, params=params)
        # Only successful pages reach this point, so an error body (a JSON
        # object) can never be spliced into the list of results.
        full_resp.extend(resp.json())

    return full_resp


def _checked_get(url, header, params=None):
    """Perform a single HTTP GET request, exiting the script if it fails.

    Args:
        url (str): The URL to make the GET request to
        header (dict): The headers to be sent with the request
        params (dict, optional): Any additional queries to send with the
            request. Defaults to None.

    Returns:
        (requests.Response): The response, only when it evaluated as truthy
            (i.e. was not an error status)
    """
    resp = requests.get(url, headers=header, params=params)
    if not resp:
        logging.error(
            "GET request to following URL failed: %s (status %s)",
            url,
            resp.status_code,
        )
        sys.exit(1)
    return resp


def patch_request(url, json, headers):
    """Send a HTTP PATCH request carrying a JSON body

    Args:
        url (str): The URL to send the request to
        json (dict): The object to serialise as the JSON request body. Note
            that this parameter name shadows the ``json`` module within this
            function.
        headers (dict): The HTTP headers to include with the request

    Returns:
        (requests.Response): The unmodified response object
    """
    response = requests.patch(url, json=json, headers=headers)
    return response


def get_open_prs(repo_url, old_branch, token):
    """List the numbers of open Pull Requests that target a given base branch

    Args:
        repo_url (str): The GitHub API URL for the target repository (with a
            trailing slash)
        old_branch (str): Only Pull Requests whose base is this branch are
            returned
        token (str): A GitHub API token to authenticate with

    Returns:
        (list of ints): The numbers of the matching open Pull Requests
    """
    query = {
        "state": "open",
        "base": old_branch,
    }
    auth_header = {"Authorization": "token %s" % token}

    # get_request follows pagination, so this holds every matching PR
    open_pulls = get_request(repo_url + "pulls", auth_header, params=query)
    logging.info("Total number of Pull Requests: %s", len(open_pulls))

    return [pull["number"] for pull in open_pulls]


def get_pr_details(repo_url, pr_num, token):
    """Fetch the title and body of a single Pull Request

    Args:
        repo_url (str): The GitHub API URL for the target repository (with a
            trailing slash)
        pr_num (int): The number of the Pull Request to look up
        token (str): A GitHub API token to authenticate with

    Returns:
        (dict): A dict with exactly the keys ``"title"`` and ``"body"``, copied
            from the API response
    """
    pr_url = "%spulls/%s" % (repo_url, pr_num)
    pr_info = get_request(pr_url, {"Authorization": "token %s" % token})
    return {field: pr_info[field] for field in ("title", "body")}


def update_pr(repo_url, pr_num, payload, new_branch, token):
    """Use a HTTP PATCH request to update an open Pull Request

    Args:
        repo_url (str): The GitHub API URL for the target repository (with a
            trailing slash)
        pr_num (int): The number of the Pull Request to be updated
        payload (dict): The info the Pull Request is to be updated with
            (e.g. its title and body). It is not modified.
        new_branch (str): The new base branch of the Pull Request
        token (str): A GitHub API token to authenticate with

    Returns:
        (bool): True if the PATCH request succeeded, False otherwise
    """
    # Work on a copy so the caller's dict is not changed as a side effect
    body = dict(payload)
    body["state"] = "open"
    # NOTE(review): GitHub's "update a pull request" endpoint documents this
    # flag as `maintainer_can_modify`; `maintainers_can_edit` is not among its
    # documented parameters, so it is presumably ignored. Confirm before
    # renaming, because actually setting the flag can be rejected for some PRs.
    body["maintainers_can_edit"] = True
    body["base"] = new_branch  # This is redefining the base branch of the PR

    header = {"Authorization": "token %s" % token}
    resp = patch_request(repo_url + "pulls/" + str(pr_num), body, header)

    if resp.ok:
        logging.info("Pull Request #%s successfully updated" % pr_num)
    else:
        # Log the status and GitHub's error message rather than only the
        # uninformative `<Response [NNN]>` repr
        logging.warning(
            "Pull Request #%s was not updated\n%s %s: %s",
            pr_num,
            resp.status_code,
            resp.reason,
            resp.text,
        )
    return resp.ok


def main():
    """Run the script: move every open Pull Request to the new base branch

    Reads the repository owner/name and the old/new branch names from the
    command line. Prompts for a GitHub API token via getpass. Then updates
    each open Pull Request based on the old branch.
    """
    cli = parse_args(sys.argv[1:])
    repo_url = construct_api_url(cli.repo_owner, cli.repo_name)
    token = getpass.getpass("Please provide a GitHub API token: ")

    pr_numbers = get_open_prs(repo_url, cli.old_branch, token)
    logging.info("Number of open Pull Requests to be updated: %s", len(pr_numbers))

    for number in pr_numbers:
        # Re-send the current title/body together with the new base branch
        details = get_pr_details(repo_url, number, token)
        update_pr(repo_url, number, details, cli.new_branch, token)


if __name__ == "__main__":
    main()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment