diff --git a/google/auth/transport/aiohttp_requests.py b/google/auth/transport/aiohttp_requests.py index 46816ea5e..29ecb3ad5 100644 --- a/google/auth/transport/aiohttp_requests.py +++ b/google/auth/transport/aiohttp_requests.py @@ -18,6 +18,7 @@ import asyncio import functools +import zlib import aiohttp import six @@ -31,6 +32,57 @@ _DEFAULT_TIMEOUT = 180 # in seconds +class _CombinedResponse(transport.Response): + """ + In order to more closely resemble the `requests` interface, where a raw + and deflated content could be accessed at once, this class lazily reads the + stream in `transport.Response` so both return forms can be used. + + The gzip and deflate transfer-encodings are automatically decoded for you + because the default parameter for autodecompress into the ClientSession is set + to False, and therefore we add this class to act as a wrapper for a user to be + able to access both the raw and decoded response bodies - mirroring the sync + implementation. + """ + + def __init__(self, response): + self._response = response + self._raw_content = None + + def _is_compressed(self): + headers = self._client_response.headers + return "Content-Encoding" in headers and ( + headers["Content-Encoding"] == "gzip" + or headers["Content-Encoding"] == "deflate" + ) + + @property + def status(self): + return self._response.status + + @property + def headers(self): + return self._response.headers + + @property + def data(self): + return self._response.content + + async def raw_content(self): + if self._raw_content is None: + self._raw_content = await self._response.content.read() + return self._raw_content + + async def content(self): + if self._raw_content is None: + self._raw_content = await self._response.content.read() + if self._is_compressed: + d = zlib.decompressobj(zlib.MAX_WBITS | 32) + decompressed = d.decompress(self._raw_content) + return decompressed + return self._raw_content + + class _Response(transport.Response): """ Requests transport response adapter. @@ -79,7 +131,6 @@ class Request(transport.Request): """ def __init__(self, session=None): - self.session = None async def __call__( @@ -89,7 +140,7 @@ async def __call__( body=None, headers=None, timeout=_DEFAULT_TIMEOUT, - **kwargs + **kwargs, ): """ Make an HTTP request using aiohttp. @@ -115,12 +166,14 @@ async def __call__( try: if self.session is None: # pragma: NO COVER - self.session = aiohttp.ClientSession() # pragma: NO COVER + self.session = aiohttp.ClientSession( + auto_decompress=False + ) # pragma: NO COVER requests._LOGGER.debug("Making request: %s %s", method, url) response = await self.session.request( method, url, data=body, headers=headers, timeout=timeout, **kwargs ) - return _Response(response) + return _CombinedResponse(response) except aiohttp.ClientError as caught_exc: new_exc = exceptions.TransportError(caught_exc) @@ -175,6 +228,7 @@ def __init__( max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, refresh_timeout=None, auth_request=None, + auto_decompress=False, ): super(AuthorizedSession, self).__init__() self.credentials = credentials @@ -186,6 +240,7 @@ def __init__( self._auth_request_session = None self._loop = asyncio.get_event_loop() self._refresh_lock = asyncio.Lock() + self._auto_decompress = auto_decompress async def request( self, @@ -195,7 +250,8 @@ async def request( headers=None, max_allowed_time=None, timeout=_DEFAULT_TIMEOUT, - **kwargs + auto_decompress=False, + **kwargs, ): """Implementation of Authorized Session aiohttp request. @@ -230,8 +286,17 @@ async def request( transmitted. The timout error will be raised after such request completes. """ - - async with aiohttp.ClientSession() as self._auth_request_session: + # Headers come in as bytes which isn't expected behavior, the resumable + # media libraries in some cases expect a str type for the header values, + # but sometimes the operations return these in bytes types. + if headers: + for key in headers.keys(): + if type(headers[key]) is bytes: + headers[key] = headers[key].decode("utf-8") + + async with aiohttp.ClientSession( + auto_decompress=self._auto_decompress + ) as self._auth_request_session: auth_request = Request(self._auth_request_session) self._auth_request = auth_request @@ -264,7 +329,7 @@ async def request( data=data, headers=request_headers, timeout=timeout, - **kwargs + **kwargs, ) remaining_time = guard.remaining_timeout @@ -307,7 +372,7 @@ async def request( max_allowed_time=remaining_time, timeout=timeout, _credential_refresh_attempt=_credential_refresh_attempt + 1, - **kwargs + **kwargs, ) return response diff --git a/google/auth/transport/mtls.py b/google/auth/transport/mtls.py index 5b742306b..b40bfbedf 100644 --- a/google/auth/transport/mtls.py +++ b/google/auth/transport/mtls.py @@ -86,9 +86,12 @@ def default_client_encrypted_cert_source(cert_path, key_path): def callback(): try: - _, cert_bytes, key_bytes, passphrase_bytes = _mtls_helper.get_client_ssl_credentials( - generate_encrypted_key=True - ) + ( + _, + cert_bytes, + key_bytes, + passphrase_bytes, + ) = _mtls_helper.get_client_ssl_credentials(generate_encrypted_key=True) with open(cert_path, "wb") as cert_file: cert_file.write(cert_bytes) with open(key_path, "wb") as key_file: diff --git a/google/oauth2/_client_async.py b/google/oauth2/_client_async.py index a6cc3b292..4817ea40e 100644 --- a/google/oauth2/_client_async.py +++ b/google/oauth2/_client_async.py @@ -104,7 +104,8 @@ async def _token_endpoint_request(request, token_uri, body): method="POST", url=token_uri, headers=headers, body=body ) - response_body1 = await response.data.read() + # Using data.read() resulted in zlib decompression errors. This may require future investigation. + response_body1 = await response.content() response_body = ( response_body1.decode("utf-8") diff --git a/google/oauth2/credentials_async.py b/google/oauth2/credentials_async.py index 2081a0be2..199b7eb56 100644 --- a/google/oauth2/credentials_async.py +++ b/google/oauth2/credentials_async.py @@ -61,7 +61,12 @@ async def refresh(self, request): "token_uri, client_id, and client_secret." ) - access_token, refresh_token, expiry, grant_response = await _client.refresh_grant( + ( + access_token, + refresh_token, + expiry, + grant_response, + ) = await _client.refresh_grant( request, self._token_uri, self._refresh_token, diff --git a/tests_async/oauth2/test__client_async.py b/tests_async/oauth2/test__client_async.py index c32a183a6..87982807f 100644 --- a/tests_async/oauth2/test__client_async.py +++ b/tests_async/oauth2/test__client_async.py @@ -63,6 +63,7 @@ def make_request(response_data, status=http_client.OK): data = json.dumps(response_data).encode("utf-8") response.data = mock.AsyncMock(spec=["__call__", "read"]) response.data.read = mock.AsyncMock(spec=["__call__"], return_value=data) + response.content = mock.AsyncMock(spec=["__call__"], return_value=data) request = mock.AsyncMock(spec=["transport.Request"]) request.return_value = response return request
Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.
Alternative Proxies: