Created
August 29, 2018 12:42
-
-
Save thehesiod/c51e3e39850763d4883772ac4b8435ca to your computer and use it in GitHub Desktop.
upvote aiohttp client
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
import asyncio | |
import argparse | |
import logging | |
import functools | |
# Third Party | |
from google.oauth2 import service_account | |
from google.oauth2 import _client | |
from google.auth import transport | |
from google.auth.transport import requests as gauth_requests | |
import google.auth.credentials | |
import aiohttp | |
import requests | |
import yarl | |
_OAUTH_SCOPES = [ | |
'https://www.googleapis.com/auth/appengine.apis', | |
'https://www.googleapis.com/auth/userinfo.email', | |
] | |
class AuthorizedSession(aiohttp.ClientSession):
    """An aiohttp session that attaches Google credentials to every request.

    This class is used to perform requests to API endpoints that require
    authorization::

        async with AuthorizedSession(credentials) as authed_session:
            response = await authed_session.request(
                'GET', 'https://www.googleapis.com/storage/v1/b')

    The underlying :meth:`_request` implementation handles adding the
    credentials' headers to the request and refreshing credentials as needed.
    Credential calls (``before_request`` / ``refresh``) are blocking; they go
    through a separate ``requests`` session and are run in the event loop's
    default executor.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials to
            add to the request.
        refresh_status_codes (Sequence[int]): Which HTTP status codes indicate
            that credentials should be refreshed and the request should be
            retried.
        max_refresh_attempts (int): The maximum number of times to attempt to
            refresh the credentials and retry the request.
        refresh_timeout (Optional[int]): The timeout value in seconds for
            credential refresh HTTP requests.
        logger (Optional[logging.Logger]): Logger used to report credential
            refreshes. Defaults to this module's logger.
        kwargs: Additional arguments passed to the
            :class:`aiohttp.ClientSession` constructor.
    """

    def __init__(self, credentials: google.auth.credentials.Credentials,
                 refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES,
                 max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS,
                 refresh_timeout=None,
                 logger=None,
                 **kwargs):
        super().__init__(**kwargs)

        # The previous default was the `logging.Logger` *class*, which cannot
        # be used to log. Treat it (and None) as "use the module logger" so
        # callers that passed it explicitly keep working.
        if logger is None or logger is logging.Logger:
            logger = logging.getLogger(__name__)

        self._loop = asyncio.get_event_loop()
        self.credentials = credentials
        self._logger = logger
        self._refresh_status_codes = refresh_status_codes
        self._max_refresh_attempts = max_refresh_attempts
        self._refresh_timeout = refresh_timeout
        # Serializes credential checks/refreshes issued by concurrent requests.
        self._refresh_lock = asyncio.Lock()
        # Blocking transport used for credential calls; created on first use.
        self._auth_request_session = None
        self._auth_request = None

    def _ensure_auth_request(self):
        """Create (once) and return the blocking auth transport."""
        if self._auth_request is None:
            auth_session = requests.Session()

            # Using an adapter to make HTTP requests robust to network errors.
            # This adapter retries HTTP requests when network errors occur
            # and the request seems safely retryable.
            retry_adapter = requests.adapters.HTTPAdapter(max_retries=3)
            auth_session.mount("https://", retry_adapter)

            # Request instance used by internal methods (for example,
            # credentials.refresh).
            # Do not pass `self` as the session here, as it can lead to
            # infinite recursion.
            self._auth_request_session = auth_session
            self._auth_request = gauth_requests.Request(auth_session)
        return self._auth_request

    def _close_auth_session(self):
        """Close the blocking auth transport, if one was created."""
        auth_session = self._auth_request_session
        self._auth_request_session = None
        self._auth_request = None
        if auth_session is not None:
            auth_session.close()

    async def __aenter__(self):
        self._ensure_auth_request()
        return await super().__aenter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        try:
            self._close_auth_session()
        finally:
            await super().__aexit__(exc_type, exc_val, exc_tb)

    async def close(self):
        """Close the auth transport along with the aiohttp session."""
        try:
            self._close_auth_session()
        finally:
            await super().close()

    async def _request(self, method: str, url: str, *, headers=None, **kwargs):
        """Send a request with credential headers, retrying after a refresh.

        Args:
            method: HTTP method.
            url: Target URL.
            headers: Optional caller-supplied headers; never mutated.
            kwargs: Passed through to ``aiohttp.ClientSession._request``.

        Returns:
            The response of the last attempt.
        """
        # pylint: disable=arguments-differ
        # The retry counter travels as a kwarg rather than an attribute so
        # that concurrent requests on the same session do not share it.
        _credential_refresh_attempt = kwargs.pop(
            '_credential_refresh_attempt', 0)

        # Works even when the session is used without `async with`.
        auth_request = self._ensure_auth_request()

        # Make a copy of the headers. They will be modified by the credentials
        # and we want to pass the original headers if we recurse.
        request_headers = headers.copy() if headers is not None else {}

        async with self._refresh_lock:
            await self._loop.run_in_executor(
                None, self.credentials.before_request,
                auth_request, method, url, request_headers)

        response = await super()._request(
            method, url, headers=request_headers, **kwargs)

        # If the response indicated that the credentials needed to be
        # refreshed, then refresh the credentials and re-attempt the
        # request.
        # A stored token may expire between the time it is retrieved and
        # the time the request is made, so we may need to try twice.
        if (response.status in self._refresh_status_codes
                and _credential_refresh_attempt < self._max_refresh_attempts):
            self._logger.info(
                'Refreshing credentials due to a %s response. Attempt %s/%s.',
                response.status,
                _credential_refresh_attempt + 1,
                self._max_refresh_attempts)

            # The rejected response is discarded; hand its connection back
            # before retrying. Deliberately not awaited: release() is a plain
            # method whose return value need not be awaited.
            response.release()

            auth_request_with_timeout = functools.partial(
                auth_request, timeout=self._refresh_timeout)

            async with self._refresh_lock:
                await self._loop.run_in_executor(
                    None, self.credentials.refresh, auth_request_with_timeout)

            # Recurse. Pass in the original headers, not our modified set.
            return await self._request(
                method, url, headers=headers,
                _credential_refresh_attempt=_credential_refresh_attempt + 1,
                **kwargs)

        return response
async def main():
    """Query the upvote service's votes endpoint using a service account.

    Command-line arguments:
        endpoint: Base URL of the upvote service.
        service_account_json: Path to the service account JSON key file.

    Issues an authorized GET to ``<endpoint>/api/web/votes/query``, reads the
    response body and prints the response object.
    """
    logging.basicConfig(level=logging.INFO)

    arg_parser = argparse.ArgumentParser(description='Upvote service helper.')
    arg_parser.add_argument(
        'endpoint',
        type=yarl.URL,
        help="Endpoint of upvote service. ex: https://upvote.appspot.com")
    arg_parser.add_argument(
        'service_account_json',
        help='Path to service account json file.')
    cli = arg_parser.parse_args()

    # Load the key file, then narrow the credentials to the scopes we need.
    scoped_credentials = (
        service_account.Credentials
        .from_service_account_file(cli.service_account_json)
        .with_scopes(_OAUTH_SCOPES)
    )

    async with AuthorizedSession(scoped_credentials) as session:
        votes_url = cli.endpoint / 'api/web/votes/query'
        response = await session.get(votes_url)
        # Read the body to completion; only the response object is printed.
        await response.read()
        print(response)
if __name__ == '__main__':
    # Create the loop explicitly: calling asyncio.get_event_loop() with no
    # running loop relies on implicit loop creation, which is deprecated and
    # no longer works on recent Python versions. Also make sure the loop is
    # closed once main() finishes or fails.
    _event_loop = asyncio.new_event_loop()
    asyncio.set_event_loop(_event_loop)
    try:
        _event_loop.run_until_complete(main())
    finally:
        asyncio.set_event_loop(None)
        _event_loop.close()
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment