Skip to content

Instantly share code, notes, and snippets.

@thehesiod
Created August 29, 2018 12:42
Show Gist options
  • Star 0 You must be signed in to star a gist
  • Fork 0 You must be signed in to fork a gist
  • Save thehesiod/c51e3e39850763d4883772ac4b8435ca to your computer and use it in GitHub Desktop.
Save thehesiod/c51e3e39850763d4883772ac4b8435ca to your computer and use it in GitHub Desktop.
upvote aiohttp client
import asyncio
import argparse
import logging
import functools
# Third Party
from google.oauth2 import service_account
from google.oauth2 import _client
from google.auth import transport
from google.auth.transport import requests as gauth_requests
import google.auth.credentials
import aiohttp
import requests
import yarl
# OAuth 2.0 scopes applied to the service-account credentials in main()
# (via ``credentials.with_scopes``) before any request is issued.
_OAUTH_SCOPES = [
    'https://www.googleapis.com/auth/appengine.apis',
    'https://www.googleapis.com/auth/userinfo.email',
]
class AuthorizedSession(aiohttp.ClientSession):
    """An aiohttp session that attaches Google credentials to every request.

    This class is used to perform requests to API endpoints that require
    authorization::

        async with AuthorizedSession(credentials) as authed_session:
            response = await authed_session.request(
                'GET', 'https://www.googleapis.com/storage/v1/b')

    The underlying :meth:`_request` implementation handles adding the
    credentials' headers to the request and refreshing credentials as needed.

    The blocking ``requests``-based transport used for credential refreshes
    is created in :meth:`__aenter__` and torn down in :meth:`__aexit__`, so
    the session is meant to be used as an async context manager.

    Args:
        credentials (google.auth.credentials.Credentials): The credentials to
            add to the request.
        refresh_status_codes (Sequence[int]): Which HTTP status codes indicate
            that credentials should be refreshed and the request should be
            retried.
        max_refresh_attempts (int): The maximum number of times to attempt to
            refresh the credentials and retry the request.
        refresh_timeout (Optional[int]): The timeout value in seconds for
            credential refresh HTTP requests.
        logger (Optional[logging.Logger]): Logger used to report credential
            refreshes. Defaults to this module's logger.
        kwargs: Additional arguments passed to the
            :class:`aiohttp.ClientSession` constructor.
    """

    def __init__(self, credentials: google.auth.credentials.Credentials,
                 refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES,
                 max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS,
                 refresh_timeout=None,
                 logger=None,
                 **kwargs):
        super().__init__(**kwargs)
        self.credentials = credentials
        # The previous default was the ``logging.Logger`` *class*, which cannot
        # be used to emit records; treat it like "no logger supplied".
        if logger is None or logger is logging.Logger:
            logger = logging.getLogger(__name__)
        self._logger = logger
        self._refresh_status_codes = refresh_status_codes
        self._max_refresh_attempts = max_refresh_attempts
        self._refresh_timeout = refresh_timeout
        # Serialises the blocking credential calls that are pushed to the
        # executor so concurrent requests do not refresh at the same time.
        self._refresh_lock = asyncio.Lock()
        # Populated in __aenter__.
        self._auth_request: gauth_requests.Request = None

    async def __aenter__(self):
        """Create the blocking auth transport, then enter the aiohttp session."""
        self._auth_request_session = requests.Session().__enter__()

        # Using an adapter to make HTTP requests robust to network errors.
        # This adapter retries HTTP requests when network errors occur
        # and the request seems safely retryable.
        retry_adapter = requests.adapters.HTTPAdapter(max_retries=3)
        self._auth_request_session.mount("https://", retry_adapter)

        # Request instance used by internal methods (for example,
        # credentials.refresh).
        # Do not pass `self` as the session here, as it can lead to infinite
        # recursion.
        self._auth_request = gauth_requests.Request(self._auth_request_session)
        return await super().__aenter__()

    async def __aexit__(self, exc_type, exc_val, exc_tb):
        """Close the auth transport and then the aiohttp session itself."""
        try:
            self._auth_request_session.__exit__(exc_type, exc_val, exc_tb)
        finally:
            # Always close the aiohttp session, even if closing the
            # requests session raised.
            await super().__aexit__(exc_type, exc_val, exc_tb)

    async def _request(self, method: str, url: str, *, headers=None, **kwargs):
        """Send a request with credential headers, retrying after a refresh.

        Args:
            method: HTTP method name.
            url: Target URL.
            headers: Optional caller-supplied headers; never mutated.
            kwargs: Passed through to ``aiohttp.ClientSession._request``.

        Returns:
            The response of the final attempt.
        """
        # pylint: disable=arguments-differ
        # aiohttp accepts many request arguments, but only two (method, url)
        # are required. Everything else is passed through to super, so there
        # is no need to list it exhaustively here.

        # Use a kwarg for this instead of an attribute so that concurrent
        # requests on the same session do not share the counter.
        _credential_refresh_attempt = kwargs.pop(
            '_credential_refresh_attempt', 0)

        # Make a copy of the headers. They will be modified by the credentials
        # and we want to pass the original headers if we recurse.
        request_headers = headers.copy() if headers is not None else {}

        # The google-auth calls below are blocking, so they run in the default
        # executor. We are inside a coroutine, so this is the running loop.
        loop = asyncio.get_event_loop()

        async with self._refresh_lock:
            await loop.run_in_executor(
                None, self.credentials.before_request,
                self._auth_request, method, url, request_headers)

        response = await super()._request(
            method, url, headers=request_headers, **kwargs)

        # If the response indicated that the credentials needed to be
        # refreshed, then refresh the credentials and re-attempt the
        # request.
        # A stored token may expire between the time it is retrieved and
        # the time the request is made, so we may need to try twice.
        if (response.status in self._refresh_status_codes
                and _credential_refresh_attempt < self._max_refresh_attempts):
            self._logger.info(
                'Refreshing credentials due to a %s response. Attempt %s/%s.',
                response.status,
                _credential_refresh_attempt + 1,
                self._max_refresh_attempts)

            # The rejected response is discarded; release it so it does not
            # keep holding its connection while we retry.
            response.release()

            auth_request_with_timeout = functools.partial(
                self._auth_request, timeout=self._refresh_timeout)
            async with self._refresh_lock:
                await loop.run_in_executor(
                    None, self.credentials.refresh, auth_request_with_timeout)

            # Recurse. Pass in the original headers, not our modified set.
            return await self._request(
                method, url, headers=headers,
                _credential_refresh_attempt=_credential_refresh_attempt + 1,
                **kwargs)

        return response
async def main():
    """Query the upvote service's vote endpoint with service-account auth.

    Reads the service endpoint and the path to a service-account JSON key
    file from the command line, builds credentials scoped with
    ``_OAUTH_SCOPES``, performs an authenticated GET against
    ``<endpoint>/api/web/votes/query`` and prints the response object.
    """
    logging.basicConfig(level=logging.INFO)

    parser = argparse.ArgumentParser(description='Upvote service helper.')
    parser.add_argument('endpoint', type=yarl.URL, help="Endpoint of upvote service. ex: https://upvote.appspot.com")
    parser.add_argument('service_account_json', help='Path to service account json file.')
    app_args = parser.parse_args()

    credentials: service_account.Credentials = (
        service_account.Credentials
        .from_service_account_file(app_args.service_account_json)
        .with_scopes(_OAUTH_SCOPES)
    )

    async with AuthorizedSession(credentials) as session:
        response = await session.get(app_args.endpoint / 'api/web/votes/query')
        # Read the payload to completion; the bytes themselves are not used,
        # only the response object is printed.
        await response.read()
        print(response)
if __name__ == '__main__':
    # asyncio.run() creates a fresh event loop, runs main() to completion and
    # closes the loop afterwards. Calling asyncio.get_event_loop() with no
    # running loop is deprecated and fails on recent Python versions.
    asyncio.run(main())
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment