diff --git a/firebase_admin/_retry.py b/firebase_admin/_retry.py new file mode 100644 index 00000000..ef330cbd --- /dev/null +++ b/firebase_admin/_retry.py @@ -0,0 +1,224 @@ +# Copyright 2025 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Internal retry logic module + +This module provides utilities for adding retry logic to HTTPX requests +""" + +from __future__ import annotations +import copy +import email.utils +import random +import re +import time +from types import CoroutineType +from typing import Any, Callable, List, Optional, Tuple +import logging +import asyncio +import httpx + +logger = logging.getLogger(__name__) + + +class HttpxRetry: + """HTTPX based retry config""" + # Status codes to be used for respecting `Retry-After` header + RETRY_AFTER_STATUS_CODES = frozenset([413, 429, 503]) + + # Default maximum backoff time. + DEFAULT_BACKOFF_MAX = 120 + + def __init__( + self, + max_retries: int = 10, + status_forcelist: Optional[List[int]] = None, + backoff_factor: float = 0, + backoff_max: float = DEFAULT_BACKOFF_MAX, + backoff_jitter: float = 0, + history: Optional[List[Tuple[ + httpx.Request, + Optional[httpx.Response], + Optional[Exception] + ]]] = None, + respect_retry_after_header: bool = False, + ) -> None: + self.retries_left = max_retries + self.status_forcelist = status_forcelist + self.backoff_factor = backoff_factor + self.backoff_max = backoff_max + self.backoff_jitter = backoff_jitter + if history: + self.history = history + else: + self.history = [] + self.respect_retry_after_header = respect_retry_after_header + + def copy(self) -> HttpxRetry: + """Creates a deep copy of this instance.""" + return copy.deepcopy(self) + + def is_retryable_response(self, response: httpx.Response) -> bool: + """Determine if a response implies that the request should be retried if possible.""" + if self.status_forcelist and response.status_code in self.status_forcelist: + return True + + has_retry_after = bool(response.headers.get("Retry-After")) + if ( + self.respect_retry_after_header + and has_retry_after + and response.status_code in self.RETRY_AFTER_STATUS_CODES + ): + return True + + return False + + def is_exhausted(self) -> bool: + """Determine if there are anymore more retires.""" + # retries_left is negative + return self.retries_left < 0 + + # Identical implementation of `urllib3.Retry.parse_retry_after()` + def _parse_retry_after(self, retry_after_header: str) -> float | None: + """Parses Retry-After string into a float with unit seconds.""" + seconds: float + # Whitespace: https://tools.ietf.org/html/rfc7230#section-3.2.4 + if re.match(r"^\s*[0-9]+\s*$", retry_after_header): + seconds = int(retry_after_header) + else: + retry_date_tuple = email.utils.parsedate_tz(retry_after_header) + if retry_date_tuple is None: + raise httpx.RemoteProtocolError(f"Invalid Retry-After header: {retry_after_header}") + + retry_date = email.utils.mktime_tz(retry_date_tuple) + seconds = retry_date - time.time() + + seconds = max(seconds, 0) + + return seconds + + def get_retry_after(self, response: httpx.Response) -> float | None: + """Determine the Retry-After time needed before sending the next request.""" + retry_after_header = response.headers.get('Retry-After', None) + if retry_after_header: + # Convert retry header to a float in seconds + return self._parse_retry_after(retry_after_header) + return None + + def get_backoff_time(self): + """Determine the backoff time needed before sending the next request.""" + # attempt_count is the number of previous request attempts + attempt_count = len(self.history) + # Backoff should be set to 0 until after first retry. + if attempt_count <= 1: + return 0 + backoff = self.backoff_factor * (2 ** (attempt_count-1)) + if self.backoff_jitter: + backoff += random.random() * self.backoff_jitter + return float(max(0, min(self.backoff_max, backoff))) + + async def sleep_for_backoff(self) -> None: + """Determine and wait the backoff time needed before sending the next request.""" + backoff = self.get_backoff_time() + logger.debug('Sleeping for backoff of %f seconds following failed request', backoff) + await asyncio.sleep(backoff) + + async def sleep(self, response: httpx.Response) -> None: + """Determine and wait the time needed before sending the next request.""" + if self.respect_retry_after_header: + retry_after = self.get_retry_after(response) + if retry_after: + logger.debug( + 'Sleeping for Retry-After header of %f seconds following failed request', + retry_after + ) + await asyncio.sleep(retry_after) + return + await self.sleep_for_backoff() + + def increment( + self, + request: httpx.Request, + response: Optional[httpx.Response] = None, + error: Optional[Exception] = None + ) -> None: + """Update the retry state based on request attempt.""" + self.retries_left -= 1 + self.history.append((request, response, error)) + + +class HttpxRetryTransport(httpx.AsyncBaseTransport): + """HTTPX transport with retry logic.""" + + DEFAULT_RETRY = HttpxRetry(max_retries=4, status_forcelist=[500, 503], backoff_factor=0.5) + + def __init__(self, retry: HttpxRetry = DEFAULT_RETRY, **kwargs) -> None: + self._retry = retry + + transport_kwargs = kwargs.copy() + transport_kwargs.update({'retries': 0, 'http2': True}) + # We use a full AsyncHTTPTransport under the hood that is already + # set up to handle requests. We also insure that that transport's internal + # retries are not allowed. + self._wrapped_transport = httpx.AsyncHTTPTransport(**transport_kwargs) + + async def handle_async_request(self, request: httpx.Request) -> httpx.Response: + return await self._dispatch_with_retry( + request, self._wrapped_transport.handle_async_request) + + async def _dispatch_with_retry( + self, + request: httpx.Request, + dispatch_method: Callable[[httpx.Request], CoroutineType[Any, Any, httpx.Response]] + ) -> httpx.Response: + """Sends a request with retry logic using a provided dispatch method.""" + # This request config is used across all requests that use this transport and therefore + # needs to be copied to be used for just this request and it's retries. + retry = self._retry.copy() + # First request + response, error = None, None + + while not retry.is_exhausted(): + + # First retry + if response: + await retry.sleep(response) + + # Need to reset here so only last attempt's error or response is saved. + response, error = None, None + + try: + logger.debug('Sending request in _dispatch_with_retry(): %r', request) + response = await dispatch_method(request) + logger.debug('Received response: %r', response) + except httpx.HTTPError as err: + logger.debug('Received error: %r', err) + error = err + + if response and not retry.is_retryable_response(response): + return response + + if error: + raise error + + retry.increment(request, response, error) + + if response: + return response + if error: + raise error + raise AssertionError('_dispatch_with_retry() ended with no response or exception') + + async def aclose(self) -> None: + await self._wrapped_transport.aclose() diff --git a/firebase_admin/_utils.py b/firebase_admin/_utils.py index b6e29254..765d1158 100644 --- a/firebase_admin/_utils.py +++ b/firebase_admin/_utils.py @@ -16,9 +16,11 @@ import json from platform import python_version +from typing import Callable, Optional import google.auth import requests +import httpx import firebase_admin from firebase_admin import exceptions @@ -128,6 +130,36 @@ def handle_platform_error_from_requests(error, handle_func=None): return exc if exc else _handle_func_requests(error, message, error_dict) +def handle_platform_error_from_httpx( + error: httpx.HTTPError, + handle_func: Optional[Callable[..., Optional[exceptions.FirebaseError]]] = None +) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given httpx error. + + This can be used to handle errors returned by Google Cloud Platform (GCP) APIs. + + Args: + error: An error raised by the httpx module while making an HTTP call to a GCP API. + handle_func: A function that can be used to handle platform errors in a custom way. When + specified, this function will be called with three arguments. It has the same + signature as ```_handle_func_httpx``, but may return ``None``. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + + if isinstance(error, httpx.HTTPStatusError): + response = error.response + content = response.content.decode() + status_code = response.status_code + error_dict, message = _parse_platform_error(content, status_code) + exc = None + if handle_func: + exc = handle_func(error, message, error_dict) + + return exc if exc else _handle_func_httpx(error, message, error_dict) + return handle_httpx_error(error) + def handle_operation_error(error): """Constructs a ``FirebaseError`` from the given operation error. @@ -204,6 +236,60 @@ def handle_requests_error(error, message=None, code=None): err_type = _error_code_to_exception_type(code) return err_type(message=message, cause=error, http_response=error.response) +def _handle_func_httpx(error: httpx.HTTPError, message, error_dict) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given GCP error. + + Args: + error: An error raised by the httpx module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError``. + error_dict: Parsed GCP error response. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code or None. + """ + code = error_dict.get('status') + return handle_httpx_error(error, message, code) + + +def handle_httpx_error(error: httpx.HTTPError, message=None, code=None) -> exceptions.FirebaseError: + """Constructs a ``FirebaseError`` from the given httpx error. + + This method is agnostic of the remote service that produced the error, whether it is a GCP + service or otherwise. Therefore, this method does not attempt to parse the error response in + any way. + + Args: + error: An error raised by the httpx module while making an HTTP call. + message: A message to be included in the resulting ``FirebaseError`` (optional). If not + specified the string representation of the ``error`` argument is used as the message. + code: A GCP error code that will be used to determine the resulting error type (optional). + If not specified the HTTP status code on the error response is used to determine a + suitable error code. + + Returns: + FirebaseError: A ``FirebaseError`` that can be raised to the user code. + """ + if isinstance(error, httpx.TimeoutException): + return exceptions.DeadlineExceededError( + message='Timed out while making an API call: {0}'.format(error), + cause=error) + if isinstance(error, httpx.ConnectError): + return exceptions.UnavailableError( + message='Failed to establish a connection: {0}'.format(error), + cause=error) + if isinstance(error, httpx.HTTPStatusError): + print("printing status error", error) + if not code: + code = _http_status_to_error_code(error.response.status_code) + if not message: + message = str(error) + + err_type = _error_code_to_exception_type(code) + return err_type(message=message, cause=error, http_response=error.response) + + return exceptions.UnknownError( + message='Unknown error while making a remote service call: {0}'.format(error), + cause=error) def _http_status_to_error_code(status): """Maps an HTTP status to a platform error code.""" diff --git a/firebase_admin/messaging.py b/firebase_admin/messaging.py index d2ad04a0..7f747db1 100644 --- a/firebase_admin/messaging.py +++ b/firebase_admin/messaging.py @@ -14,22 +14,34 @@ """Firebase Cloud Messaging module.""" +from __future__ import annotations +from typing import Callable, List, Optional import concurrent.futures import json import warnings +import asyncio +import logging import requests +import httpx +from google.auth import credentials +from google.auth.transport import requests as auth_requests from googleapiclient import http from googleapiclient import _auth import firebase_admin -from firebase_admin import _http_client -from firebase_admin import _messaging_encoder -from firebase_admin import _messaging_utils -from firebase_admin import _gapic_utils -from firebase_admin import _utils -from firebase_admin import exceptions - +from firebase_admin import ( + _http_client, + _messaging_encoder, + _messaging_utils, + _gapic_utils, + _utils, + exceptions, + App +) +from firebase_admin._retry import HttpxRetryTransport + +logger = logging.getLogger(__name__) _MESSAGING_ATTRIBUTE = '_messaging' @@ -66,6 +78,7 @@ 'send_all', 'send_multicast', 'send_each', + 'send_each_async', 'send_each_for_multicast', 'subscribe_to_topic', 'unsubscribe_from_topic', @@ -97,10 +110,10 @@ UnregisteredError = _messaging_utils.UnregisteredError -def _get_messaging_service(app): +def _get_messaging_service(app: Optional[App]) -> _MessagingService: return _utils.get_app_service(app, _MESSAGING_ATTRIBUTE, _MessagingService) -def send(message, dry_run=False, app=None): +def send(message, dry_run=False, app: Optional[App] = None): """Sends the given message via Firebase Cloud Messaging (FCM). If the ``dry_run`` mode is enabled, the message will not be actually delivered to the @@ -140,6 +153,9 @@ def send_each(messages, dry_run=False, app=None): """ return _get_messaging_service(app).send_each(messages, dry_run) +async def send_each_async(messages, dry_run=True, app: Optional[App] = None) -> BatchResponse: + return await _get_messaging_service(app).send_each_async(messages, dry_run) + def send_each_for_multicast(multicast_message, dry_run=False, app=None): """Sends the given mutlicast message to each token via Firebase Cloud Messaging (FCM). @@ -321,21 +337,21 @@ def errors(self): class BatchResponse: """The response received from a batch request to the FCM API.""" - def __init__(self, responses): + def __init__(self, responses: List[SendResponse]) -> None: self._responses = responses self._success_count = len([resp for resp in responses if resp.success]) @property - def responses(self): + def responses(self) -> List[SendResponse]: """A list of ``messaging.SendResponse`` objects (possibly empty).""" return self._responses @property - def success_count(self): + def success_count(self) -> int: return self._success_count @property - def failure_count(self): + def failure_count(self) -> int: return len(self.responses) - self.success_count @@ -363,6 +379,53 @@ def exception(self): """A ``FirebaseError`` if an error occurs while sending the message to the FCM service.""" return self._exception +class GoogleAuthCredentialFlow(httpx.Auth): + """Google Auth Credential Auth Flow""" + def __init__(self, credential: credentials.Credentials): + self._credential = credential + self._max_refresh_attempts = 2 + self._refresh_status_codes = (401,) + + def apply_auth_headers(self, request: httpx.Request): + # Build request used to refresh credentials if needed + auth_request = auth_requests.Request() + # This refreshes the credentials if needed and mutates the request headers to + # contain access token and any other google auth headers + self._credential.before_request(auth_request, request.method, request.url, request.headers) + + + def auth_flow(self, request: httpx.Request): + # Keep original headers since `credentials.before_request` mutates the passed headers and we + # want to keep the original in cause we need an auth retry. + _original_headers = request.headers.copy() + + _credential_refresh_attempt = 0 + while _credential_refresh_attempt <= self._max_refresh_attempts: + # copy original headers + request.headers = _original_headers.copy() + # mutates request headers + logger.debug( + 'Refreshing credentials for request attempt %d', + _credential_refresh_attempt + 1) + self.apply_auth_headers(request) + + # Continue to perform the request + # yield here dispatches the request and returns with the response + response: httpx.Response = yield request + + # We can check the result of the response and determine in we need to retry + # on refreshable status codes. Current transport.requests.AuthorizedSession() + # only does this on 401 errors. We should do the same. + if response.status_code in self._refresh_status_codes: + logger.debug( + 'Request attempt %d failed due to unauthorized credentials', + _credential_refresh_attempt + 1) + _credential_refresh_attempt += 1 + else: + break + # Last yielded response is auto returned. + + class _MessagingService: """Service class that implements Firebase Cloud Messaging (FCM) functionality.""" @@ -381,7 +444,7 @@ class _MessagingService: 'UNREGISTERED': UnregisteredError, } - def __init__(self, app): + def __init__(self, app) -> None: project_id = app.project_id if not project_id: raise ValueError( @@ -396,6 +459,12 @@ def __init__(self, app): timeout = app.options.get('httpTimeout', _http_client.DEFAULT_TIMEOUT_SECONDS) self._credential = app.credential.get_credential() self._client = _http_client.JsonHttpClient(credential=self._credential, timeout=timeout) + self._async_client = httpx.AsyncClient( + http2=True, + auth=GoogleAuthCredentialFlow(self._credential), + timeout=timeout, + transport=HttpxRetryTransport() + ) self._build_transport = _auth.authorized_http @classmethod @@ -448,6 +517,40 @@ def send_data(data): message='Unknown error while making remote service calls: {0}'.format(error), cause=error) + async def send_each_async(self, messages: List[Message], dry_run: bool = True) -> BatchResponse: + """Sends the given messages to FCM via the FCM v1 API.""" + if not isinstance(messages, list): + raise ValueError('messages must be a list of messaging.Message instances.') + if len(messages) > 500: + raise ValueError('messages must not contain more than 500 elements.') + + async def send_data(data): + try: + resp = await self._async_client.request( + 'post', + url=self._fcm_url, + headers=self._fcm_headers, + json=data) + # HTTP/2 check + if resp.http_version != 'HTTP/2': + raise Exception('This messages was not sent with HTTP/2') + resp.raise_for_status() + # except httpx.HTTPStatusError as exception: + except httpx.HTTPError as exception: + return SendResponse(resp=None, exception=self._handle_fcm_httpx_error(exception)) + else: + return SendResponse(resp.json(), exception=None) + + message_data = [self._message_data(message, dry_run) for message in messages] + try: + responses = await asyncio.gather(*[send_data(message) for message in message_data]) + return BatchResponse(responses) + except Exception as error: + raise exceptions.UnknownError( + message='Unknown error while making remote service calls: {0}'.format(error), + cause=error) + + def send_all(self, messages, dry_run=False): """Sends the given messages to FCM via the batch API.""" if not isinstance(messages, list): @@ -533,6 +636,11 @@ def _handle_fcm_error(self, error): return _utils.handle_platform_error_from_requests( error, _MessagingService._build_fcm_error_requests) + def _handle_fcm_httpx_error(self, error: httpx.HTTPError) -> exceptions.FirebaseError: + """Handles errors received from the FCM API.""" + return _utils.handle_platform_error_from_httpx( + error, _MessagingService._build_fcm_error_httpx) + def _handle_iid_error(self, error): """Handles errors received from the Instance ID API.""" if error.response is None: @@ -562,6 +670,9 @@ def _handle_batch_error(self, error): return _gapic_utils.handle_platform_error_from_googleapiclient( error, _MessagingService._build_fcm_error_googleapiclient) + def close(self) -> None: + asyncio.run(self._async_client.aclose()) + @classmethod def _build_fcm_error_requests(cls, error, message, error_dict): """Parses an error response from the FCM API and creates a FCM-specific exception if @@ -569,6 +680,19 @@ def _build_fcm_error_requests(cls, error, message, error_dict): exc_type = cls._build_fcm_error(error_dict) return exc_type(message, cause=error, http_response=error.response) if exc_type else None + @classmethod + def _build_fcm_error_httpx( + cls, error: httpx.HTTPError, message, error_dict + ) -> Optional[exceptions.FirebaseError]: + """Parses a httpx error response from the FCM API and creates a FCM-specific exception if + appropriate.""" + exc_type = cls._build_fcm_error(error_dict) + if isinstance(error, httpx.HTTPStatusError): + return exc_type( + message, cause=error, http_response=error.response) if exc_type else None + return exc_type(message, cause=error) if exc_type else None + + @classmethod def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_response): """Parses an error response from the FCM API and creates a FCM-specific exception if @@ -577,7 +701,7 @@ def _build_fcm_error_googleapiclient(cls, error, message, error_dict, http_respo return exc_type(message, cause=error, http_response=http_response) if exc_type else None @classmethod - def _build_fcm_error(cls, error_dict): + def _build_fcm_error(cls, error_dict) -> Optional[Callable[..., exceptions.FirebaseError]]: if not error_dict: return None fcm_code = None @@ -585,4 +709,4 @@ def _build_fcm_error(cls, error_dict): if detail.get('@type') == 'type.googleapis.com/google.firebase.fcm.v1.FcmError': fcm_code = detail.get('errorCode') break - return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) + return _MessagingService.FCM_ERROR_TYPES.get(fcm_code) if fcm_code else None diff --git a/integration/conftest.py b/integration/conftest.py index 71f53f61..bdecca40 100644 --- a/integration/conftest.py +++ b/integration/conftest.py @@ -15,8 +15,9 @@ """pytest configuration and global fixtures for integration tests.""" import json -import asyncio +# import asyncio import pytest +from pytest_asyncio import is_async_test import firebase_admin from firebase_admin import credentials @@ -72,11 +73,17 @@ def api_key(request): with open(path) as keyfile: return keyfile.read().strip() -@pytest.fixture(scope="session") -def event_loop(): - """Create an instance of the default event loop for test session. - This avoids early eventloop closure. - """ - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() +# @pytest.fixture(scope="session") +# def event_loop(): +# """Create an instance of the default event loop for test session. +# This avoids early eventloop closure. +# """ +# loop = asyncio.get_event_loop_policy().new_event_loop() +# yield loop +# loop.close() + +def pytest_collection_modifyitems(items): + pytest_asyncio_tests = (item for item in items if is_async_test(item)) + session_scope_marker = pytest.mark.asyncio(loop_scope="session") + for async_test in pytest_asyncio_tests: + async_test.add_marker(session_scope_marker, append=False) diff --git a/integration/test_messaging.py b/integration/test_messaging.py index 4c1d7d0d..af35ce01 100644 --- a/integration/test_messaging.py +++ b/integration/test_messaging.py @@ -221,3 +221,85 @@ def test_subscribe(): def test_unsubscribe(): resp = messaging.unsubscribe_from_topic(_REGISTRATION_TOKEN, 'mock-topic') assert resp.success_count + resp.failure_count == 1 + +@pytest.mark.asyncio +async def test_send_each_async(): + messages = [ + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + topic='foo-bar', notification=messaging.Notification('Title', 'Body')), + messaging.Message( + token='not-a-token', notification=messaging.Notification('Title', 'Body')), + ] + + batch_response = await messaging.send_each_async(messages, dry_run=True) + + assert batch_response.success_count == 2 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 3 + + response = batch_response.responses[0] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[1] + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) + + response = batch_response.responses[2] + assert response.success is False + assert isinstance(response.exception, exceptions.InvalidArgumentError) + assert response.message_id is None + + +# @pytest.mark.asyncio +# async def test_send_each_async_error(): +# messages = [ +# messaging.Message( +# topic='foo-bar', notification=messaging.Notification('Title', 'Body')), +# messaging.Message( +# topic='foo-bar', notification=messaging.Notification('Title', 'Body')), +# messaging.Message( +# token='not-a-token', notification=messaging.Notification('Title', 'Body')), +# ] + +# batch_response = await messaging.send_each_async(messages, dry_run=True) + +# assert batch_response.success_count == 2 +# assert batch_response.failure_count == 1 +# assert len(batch_response.responses) == 3 + +# response = batch_response.responses[0] +# assert response.success is True +# assert response.exception is None +# assert re.match('^projects/.*/messages/.*$', response.message_id) + +# response = batch_response.responses[1] +# assert response.success is True +# assert response.exception is None +# assert re.match('^projects/.*/messages/.*$', response.message_id) + +# response = batch_response.responses[2] +# assert response.success is False +# assert isinstance(response.exception, exceptions.InvalidArgumentError) +# assert response.message_id is None + +@pytest.mark.asyncio +async def test_send_each_async_500(): + messages = [] + for msg_number in range(500): + topic = 'foo-bar-{0}'.format(msg_number % 10) + messages.append(messaging.Message(topic=topic)) + + batch_response = await messaging.send_each_async(messages, dry_run=True) + + assert batch_response.success_count == 500 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 500 + for response in batch_response.responses: + assert response.success is True + assert response.exception is None + assert re.match('^projects/.*/messages/.*$', response.message_id) diff --git a/requirements.txt b/requirements.txt index fd5b0b39..ba6f2f94 100644 --- a/requirements.txt +++ b/requirements.txt @@ -5,10 +5,12 @@ pytest-cov >= 2.4.0 pytest-localserver >= 0.4.1 pytest-asyncio >= 0.16.0 pytest-mock >= 3.6.1 +respx == 0.22.0 cachecontrol >= 0.12.14 google-api-core[grpc] >= 1.22.1, < 3.0.0dev; platform.python_implementation != 'PyPy' google-api-python-client >= 1.7.8 google-cloud-firestore >= 2.19.0; platform.python_implementation != 'PyPy' google-cloud-storage >= 1.37.1 -pyjwt[crypto] >= 2.5.0 \ No newline at end of file +pyjwt[crypto] >= 2.5.0 +httpx[http2] == 0.28.1 \ No newline at end of file diff --git a/setup.py b/setup.py index 23be6d48..e92d207a 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ 'google-cloud-firestore>=2.19.0; platform.python_implementation != "PyPy"', 'google-cloud-storage>=1.37.1', 'pyjwt[crypto] >= 2.5.0', + 'httpx[http2] == 0.28.1', ] setup( diff --git a/tests/test_messaging.py b/tests/test_messaging.py index b7b5c69b..0200c11e 100644 --- a/tests/test_messaging.py +++ b/tests/test_messaging.py @@ -14,8 +14,11 @@ """Test cases for the firebase_admin.messaging module.""" import datetime +from itertools import chain, repeat import json import numbers +import httpx +import respx from googleapiclient import http from googleapiclient import _helpers @@ -1924,6 +1927,201 @@ def test_send_each(self): assert all([r.success for r in batch_response.responses]) assert not any([r.exception for r in batch_response.responses]) + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async(self): + responses = [ + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id2'}), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id3'}), + ] + msg1 = messaging.Message(topic='foo1') + msg2 = messaging.Message(topic='foo2') + msg3 = messaging.Message(topic='foo3') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + + batch_response = await messaging.send_each_async([msg1, msg2, msg3], dry_run=True) + + # try: + # batch_response = await messaging.send_each_async([msg1, msg2], dry_run=True) + # except Exception as error: + # if isinstance(error.cause.__cause__, StopIteration): + # raise Exception('Received more requests than mocks') + + assert batch_response.success_count == 3 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 3 + assert [r.message_id for r in batch_response.responses] \ + == ['message-id1', 'message-id2', 'message-id3'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + assert route.call_count == 3 + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_error_401_fail_auth_retry(self): + payload = json.dumps({ + 'error': { + 'status': 'UNAUTHENTICATED', + 'message': 'test unauthenticated error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + + responses = repeat(respx.MockResponse(401, http_version='HTTP/2', content=payload)) + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 3 + assert batch_response.success_count == 0 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 1 + exception = batch_response.responses[0].exception + assert isinstance(exception, exceptions.UnauthenticatedError) + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_error_401_pass_on_auth_retry(self): + payload = json.dumps({ + 'error': { + 'status': 'UNAUTHENTICATED', + 'message': 'test unauthenticated error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + responses = [ + respx.MockResponse(401, http_version='HTTP/2', content=payload), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + ] + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 2 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 1 + assert [r.message_id for r in batch_response.responses] == ['message-id1'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_error_500_fail_retry_config(self): + payload = json.dumps({ + 'error': { + 'status': 'INTERNAL', + 'message': 'test INTERNAL error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + + responses = repeat(respx.MockResponse(500, http_version='HTTP/2', content=payload)) + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 5 + assert batch_response.success_count == 0 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 1 + exception = batch_response.responses[0].exception + assert isinstance(exception, exceptions.InternalError) + + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_error_500_pass_on_retry_config(self): + payload = json.dumps({ + 'error': { + 'status': 'INTERNAL', + 'message': 'test INTERNAL error', + 'details': [ + { + '@type': 'type.googleapis.com/google.firebase.fcm.v1.FcmError', + 'errorCode': 'SOME_UNKNOWN_CODE', + }, + ], + } + }) + responses = chain( + [ + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(500, http_version='HTTP/2', content=payload), + respx.MockResponse(200, http_version='HTTP/2', json={'name': 'message-id1'}), + ], + ) + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 5 + assert batch_response.success_count == 1 + assert batch_response.failure_count == 0 + assert len(batch_response.responses) == 1 + assert [r.message_id for r in batch_response.responses] == ['message-id1'] + assert all([r.success for r in batch_response.responses]) + assert not any([r.exception for r in batch_response.responses]) + + @respx.mock + @pytest.mark.asyncio + async def test_send_each_async_request_error(self): + responses = httpx.ConnectError("Test request error", request=httpx.Request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send')) + + msg1 = messaging.Message(topic='foo1') + route = respx.request( + 'POST', + 'https://fcm.googleapis.com/v1/projects/explicit-project-id/messages:send' + ).mock(side_effect=responses) + batch_response = await messaging.send_each_async([msg1], dry_run=True) + + assert route.call_count == 1 + assert batch_response.success_count == 0 + assert batch_response.failure_count == 1 + assert len(batch_response.responses) == 1 + exception = batch_response.responses[0].exception + assert isinstance(exception, exceptions.UnavailableError) + @pytest.mark.parametrize('status', HTTP_ERROR_CODES) def test_send_each_detailed_error(self, status): success_payload = json.dumps({'name': 'message-id'}) diff --git a/tests/test_retry.py b/tests/test_retry.py new file mode 100644 index 00000000..751fdea7 --- /dev/null +++ b/tests/test_retry.py @@ -0,0 +1,454 @@ +# Copyright 2025 Google Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test cases for the firebase_admin._retry module.""" + +import time +import email.utils +from itertools import repeat +from unittest.mock import call +import pytest +import httpx +from pytest_mock import MockerFixture +import respx + +from firebase_admin._retry import HttpxRetry, HttpxRetryTransport + +_TEST_URL = 'http://firebase.test.url/' + +@pytest.fixture +def base_url() -> str: + """Provides a consistent base URL for tests.""" + return _TEST_URL + +class TestHttpxRetryTransport(): + @pytest.mark.asyncio + @respx.mock + async def test_no_retry_on_success(self, base_url: str, mocker: MockerFixture): + """Test that a successful response doesn't trigger retries.""" + retry_config = HttpxRetry(max_retries=3, status_forcelist=[500]) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(return_value=httpx.Response(200, text="Success")) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert response.text == "Success" + assert route.call_count == 1 + mock_sleep.assert_not_called() + + @pytest.mark.asyncio + @respx.mock + async def test_no_retry_on_non_retryable_status(self, base_url: str, mocker: MockerFixture): + """Test that a non-retryable error status doesn't trigger retries.""" + retry_config = HttpxRetry(max_retries=3, status_forcelist=[500, 503]) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(return_value=httpx.Response(404, text="Not Found")) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 404 + assert response.text == "Not Found" + assert route.call_count == 1 + mock_sleep.assert_not_called() + + @pytest.mark.asyncio + @respx.mock + async def test_retry_on_status_code_success_on_last_retry( + self, base_url: str, mocker: MockerFixture + ): + """Test retry on status code from status_forcelist, succeeding on the last attempt.""" + retry_config = HttpxRetry(max_retries=2, status_forcelist=[503, 500], backoff_factor=0.5) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(503, text="Attempt 1 Failed"), + httpx.Response(500, text="Attempt 2 Failed"), + httpx.Response(200, text="Attempt 3 Success"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert response.text == "Attempt 3 Success" + assert route.call_count == 3 + assert mock_sleep.call_count == 2 + # Check sleep calls (backoff_factor is 0.5) + mock_sleep.assert_has_calls([call(0.0), call(1.0)]) + + @pytest.mark.asyncio + @respx.mock + async def test_retry_exhausted_returns_last_response( + self, base_url: str, mocker: MockerFixture + ): + """Test that the last response is returned when retries are exhausted.""" + retry_config = HttpxRetry(max_retries=1, status_forcelist=[500], backoff_factor=0) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(500, text="Attempt 1 Failed"), + httpx.Response(500, text="Attempt 2 Failed (Final)"), + # Should stop after previous response + httpx.Response(200, text="This should not be reached"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 500 + assert response.text == "Attempt 2 Failed (Final)" + assert route.call_count == 2 # Initial call + 1 retry + assert mock_sleep.call_count == 1 # Slept before the single retry + + @pytest.mark.asyncio + @respx.mock + async def test_retry_after_header_seconds(self, base_url: str, mocker: MockerFixture): + """Test respecting Retry-After header with seconds value.""" + retry_config = HttpxRetry( + max_retries=1, respect_retry_after_header=True, backoff_factor=100) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(429, text="Too Many Requests", headers={'Retry-After': '10'}), + httpx.Response(200, text="OK"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 2 + assert mock_sleep.call_count == 1 + # Assert sleep was called with the value from Retry-After header + mock_sleep.assert_called_once_with(10.0) + + @pytest.mark.asyncio + @respx.mock + async def test_retry_after_header_http_date(self, base_url: str, mocker: MockerFixture): + """Test respecting Retry-After header with an HTTP-date value.""" + retry_config = HttpxRetry( + max_retries=1, respect_retry_after_header=True, backoff_factor=100) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + # Calculate a future time and format as HTTP-date + retry_delay_seconds = 60 + time_at_request = time.time() + retry_time = time_at_request + retry_delay_seconds + http_date = email.utils.formatdate(retry_time) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(503, text="Maintenance", headers={'Retry-After': http_date}), + httpx.Response(200, text="OK"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + # Patch time.time() within the test context to control the baseline for date calculation + # Set the mock time to be *just before* the Retry-After time + mocker.patch('time.time', return_value=time_at_request) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 2 + assert mock_sleep.call_count == 1 + # Check that sleep was called with approximately the correct delay + # Allow for small floating point inaccuracies + mock_sleep.assert_called_once() + args, _ = mock_sleep.call_args + assert args[0] == pytest.approx(retry_delay_seconds, abs=2) + + @pytest.mark.asyncio + @respx.mock + async def test_retry_after_ignored_when_disabled(self, base_url: str, mocker: MockerFixture): + """Test Retry-After header is ignored if `respect_retry_after_header` is `False`.""" + retry_config = HttpxRetry( + max_retries=3, respect_retry_after_header=False, status_forcelist=[429], + backoff_factor=0.5, backoff_max=10) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(429, text="Too Many Requests", headers={'Retry-After': '60'}), + httpx.Response(429, text="Too Many Requests", headers={'Retry-After': '60'}), + httpx.Response(429, text="Too Many Requests", headers={'Retry-After': '60'}), + httpx.Response(200, text="OK"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 4 + assert mock_sleep.call_count == 3 + + # Assert sleep was called with the calculated backoff times: + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 0.5 * (2**(2-1)) = 0.5 * 2 = 1.0 + # After retry 2 (attempt 3): delay = 0.5 * (2**(3-1)) = 0.5 * 4 = 2.0 + expected_sleeps = [call(0), call(1.0), call(2.0)] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_retry_after_header_missing_backoff_fallback( + self, base_url: str, mocker: MockerFixture + ): + """Test Retry-After header is ignored if `respect_retry_after_header`is `True` but header is + not set.""" + retry_config = HttpxRetry( + max_retries=3, respect_retry_after_header=True, status_forcelist=[429], + backoff_factor=0.5, backoff_max=10) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(429, text="Too Many Requests"), + httpx.Response(429, text="Too Many Requests"), + httpx.Response(429, text="Too Many Requests"), + httpx.Response(200, text="OK"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 4 + assert mock_sleep.call_count == 3 + + # Assert sleep was called with the calculated backoff times: + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 0.5 * (2**(2-1)) = 0.5 * 2 = 1.0 + # After retry 2 (attempt 3): delay = 0.5 * (2**(3-1)) = 0.5 * 4 = 2.0 + expected_sleeps = [call(0), call(1.0), call(2.0)] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_exponential_backoff(self, base_url: str, mocker: MockerFixture): + """Test that sleep time increases exponentially with `backoff_factor`.""" + # status=3 allows 3 retries (attempts 2, 3, 4) + retry_config = HttpxRetry( + max_retries=3, status_forcelist=[500], backoff_factor=0.1, backoff_max=10.0) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(500, text="Fail 1"), + httpx.Response(500, text="Fail 2"), + httpx.Response(500, text="Fail 3"), + httpx.Response(200, text="Success"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 4 + assert mock_sleep.call_count == 3 + + # Check expected backoff times: + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 0.1 * (2**(2-1)) = 0.1 * 2 = 0.2 + # After retry 2 (attempt 3): delay = 0.1 * (2**(3-1)) = 0.1 * 4 = 0.4 + expected_sleeps = [call(0), call(0.2), call(0.4)] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_backoff_max(self, base_url: str, mocker: MockerFixture): + """Test that backoff time respects `backoff_max`.""" + # status=4 allows 4 retries. backoff_factor=1 causes rapid increase. + retry_config = HttpxRetry( + max_retries=4, status_forcelist=[500], backoff_factor=1, backoff_max=3.0) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(500, text="Fail 1"), + httpx.Response(500, text="Fail 2"), + httpx.Response(500, text="Fail 2"), + httpx.Response(500, text="Fail 4"), + httpx.Response(200, text="Success"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 5 + assert mock_sleep.call_count == 4 + + # Check expected backoff times: + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 1*(2**(2-1)) = 2. Clamped by max(0, min(3.0, 2)) = 2.0 + # After retry 2 (attempt 3): delay = 1*(2**(3-1)) = 4. Clamped by max(0, min(3.0, 4)) = 3.0 + # After retry 3 (attempt 4): delay = 1*(2**(4-1)) = 8. Clamped by max(0, min(3.0, 8)) = 3.0 + expected_sleeps = [call(0.0), call(2.0), call(3.0), call(3.0)] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_backoff_jitter(self, base_url: str, mocker: MockerFixture): + """Test that `backoff_jitter` adds randomness within bounds.""" + retry_config = HttpxRetry( + max_retries=3, status_forcelist=[500], backoff_factor=0.2, backoff_jitter=0.1) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + route = respx.post(base_url).mock(side_effect=[ + httpx.Response(500, text="Fail 1"), + httpx.Response(500, text="Fail 2"), + httpx.Response(500, text="Fail 3"), + httpx.Response(200, text="Success"), + ]) + + mock_sleep = mocker.patch('asyncio.sleep', return_value=None) + response = await client.post(base_url) + + assert response.status_code == 200 + assert route.call_count == 4 + assert mock_sleep.call_count == 3 + + # Check expected backoff times are within the expected range [base - jitter, base + jitter] + # After first attempt: delay = 0 + # After retry 1 (attempt 2): delay = 0.2 * (2**(2-1)) = 0.2 * 2 = 0.4 +/- 0.1 + # After retry 2 (attempt 3): delay = 0.2 * (2**(3-1)) = 0.2 * 4 = 0.8 +/- 0.1 + expected_sleeps = [ + call(pytest.approx(0.0, abs=0.1)), + call(pytest.approx(0.4, abs=0.1)), + call(pytest.approx(0.8, abs=0.1)) + ] + mock_sleep.assert_has_calls(expected_sleeps) + + @pytest.mark.asyncio + @respx.mock + async def test_error_not_retryable(self, base_url): + """Test that non-HTTP errors are raised immediately if not retryable.""" + retry_config = HttpxRetry(max_retries=3) + transport = HttpxRetryTransport(retry=retry_config) + client = httpx.AsyncClient(transport=transport) + + # Mock a connection error + route = respx.post(base_url).mock( + side_effect=repeat(httpx.ConnectError("Connection failed"))) + + with pytest.raises(httpx.ConnectError, match="Connection failed"): + await client.post(base_url) + + assert route.call_count == 1 + + +class TestHttpxRetry(): + _TEST_REQUEST = httpx.Request('POST', _TEST_URL) + + def test_httpx_retry_copy(self, base_url): + """Test that `HttpxRetry.copy()` creates a deep copy.""" + original = HttpxRetry(max_retries=5, status_forcelist=[500, 503], backoff_factor=0.5) + original.history.append((base_url, None, None)) # Add something mutable + + copied = original.copy() + + # Assert they are different objects + assert original is not copied + assert original.history is not copied.history + + # Assert values are the same initially + assert copied.retries_left == original.retries_left + assert copied.status_forcelist == original.status_forcelist + assert copied.backoff_factor == original.backoff_factor + assert len(copied.history) == 1 + + # Modify the copy and check original is unchanged + copied.retries_left = 1 + copied.status_forcelist = [404] + copied.history.append((base_url, None, None)) + + assert original.retries_left == 5 + assert original.status_forcelist == [500, 503] + assert len(original.history) == 1 + + def test_parse_retry_after_seconds(self): + retry = HttpxRetry() + assert retry._parse_retry_after('10') == 10.0 + assert retry._parse_retry_after(' 30 ') == 30.0 + + + def test_parse_retry_after_http_date(self, mocker: MockerFixture): + mocker.patch('time.time', return_value=1000.0) + retry = HttpxRetry() + # Date string representing 1015 seconds since epoch + http_date = email.utils.formatdate(1015.0) + # time.time() is mocked to 1000.0, so delay should be 15s + assert retry._parse_retry_after(http_date) == pytest.approx(15.0) + + def test_parse_retry_after_past_http_date(self, mocker: MockerFixture): + """Test that a past date results in 0 seconds.""" + mocker.patch('time.time', return_value=1000.0) + retry = HttpxRetry() + http_date = email.utils.formatdate(990.0) # 10s in the past + assert retry._parse_retry_after(http_date) == 0.0 + + def test_parse_retry_after_invalid_date(self): + retry = HttpxRetry() + with pytest.raises(httpx.RemoteProtocolError, match='Invalid Retry-After header'): + retry._parse_retry_after('Invalid Date Format') + + def test_get_backoff_time_calculation(self): + retry = HttpxRetry( + max_retries=6, status_forcelist=[503], backoff_factor=0.5, backoff_max=10.0) + response = httpx.Response(503) + # No history -> attempt 1 -> no backoff before first request + # Note: get_backoff_time() is typically called *before* the *next* request, + # so history length reflects completed attempts. + assert retry.get_backoff_time() == 0.0 + + # Simulate attempt 1 completed + retry.increment(self._TEST_REQUEST, response) + # History len 1, attempt 2 -> base case 0 + assert retry.get_backoff_time() == pytest.approx(0) + + # Simulate attempt 2 completed + retry.increment(self._TEST_REQUEST, response) + # History len 2, attempt 3 -> 0.5*(2^1) = 1.0 + assert retry.get_backoff_time() == pytest.approx(1.0) + + # Simulate attempt 3 completed + retry.increment(self._TEST_REQUEST, response) + # History len 3, attempt 4 -> 0.5*(2^2) = 2.0 + assert retry.get_backoff_time() == pytest.approx(2.0) + + # Simulate attempt 4 completed + retry.increment(self._TEST_REQUEST, response) + # History len 4, attempt 5 -> 0.5*(2^3) = 4.0 + assert retry.get_backoff_time() == pytest.approx(4.0) + + # Simulate attempt 5 completed + retry.increment(self._TEST_REQUEST, response) + # History len 5, attempt 6 -> 0.5*(2^4) = 8.0 + assert retry.get_backoff_time() == pytest.approx(8.0) + + # Simulate attempt 6 completed + retry.increment(self._TEST_REQUEST, response) + # History len 6, attempt 7 -> 0.5*(2^5) = 16.0 Clamped to 10 + assert retry.get_backoff_time() == pytest.approx(10.0) diff --git a/tests/test_token_gen.py b/tests/test_token_gen.py index 536a5ec9..916a19f5 100644 --- a/tests/test_token_gen.py +++ b/tests/test_token_gen.py @@ -19,6 +19,7 @@ import json import os import time +from unittest import mock from google.auth import crypt from google.auth import jwt @@ -562,17 +563,34 @@ def test_expired_token(self, user_mgt_app): def test_expired_token_with_tolerance(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - id_token = self.invalid_tokens['ExpiredTokenShort'] + id_token_encoded = self.invalid_tokens['ExpiredTokenShort'] if _is_emulated(): - self._assert_valid_token(id_token, user_mgt_app) + # Emulator mode doesn't perform the same time checks, skip advanced mocking + self._assert_valid_token(id_token_encoded, user_mgt_app) return - claims = auth.verify_id_token(id_token, app=user_mgt_app, - clock_skew_seconds=60) + + # Decode the token to get its actual 'exp' timestamp + # Ensure 'google.auth.jwt' is available for jwt.decode + # This might require `from google.auth import jwt` if not already present + decoded_token = jwt.decode(id_token_encoded, verify=False) + exp_timestamp = decoded_token['exp'] + + # Valid case: mock utcnow to be exactly at exp + clock_skew (boundary of validity) + # The token should be considered valid here. + mock_now_valid = datetime.datetime.utcfromtimestamp(exp_timestamp + 60) + with mock.patch('google.auth._helpers.utcnow', return_value=mock_now_valid): + claims = auth.verify_id_token(id_token_encoded, app=user_mgt_app, + clock_skew_seconds=60) assert claims['admin'] is True assert claims['uid'] == claims['sub'] - with pytest.raises(auth.ExpiredIdTokenError): - auth.verify_id_token(id_token, app=user_mgt_app, - clock_skew_seconds=20) + + # Expired case: mock utcnow to be 1 second after exp + clock_skew (just expired) + # The token should be considered expired here. + mock_now_expired = datetime.datetime.utcfromtimestamp(exp_timestamp + 20 + 1) + with mock.patch('google.auth._helpers.utcnow', return_value=mock_now_expired): + with pytest.raises(auth.ExpiredIdTokenError): + auth.verify_id_token(id_token_encoded, app=user_mgt_app, + clock_skew_seconds=20) def test_project_id_option(self): app = firebase_admin.initialize_app( @@ -741,17 +759,34 @@ def test_expired_cookie(self, user_mgt_app): def test_expired_cookie_with_tolerance(self, user_mgt_app): _overwrite_cert_request(user_mgt_app, MOCK_REQUEST) - cookie = self.invalid_cookies['ExpiredCookieShort'] + cookie_encoded = self.invalid_cookies['ExpiredCookieShort'] if _is_emulated(): - self._assert_valid_cookie(cookie, user_mgt_app) + # Emulator mode doesn't perform the same time checks, skip advanced mocking + self._assert_valid_cookie(cookie_encoded, user_mgt_app) return - claims = auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=False, - clock_skew_seconds=59) + + # Decode the token to get its actual 'exp' timestamp + decoded_cookie = jwt.decode(cookie_encoded, verify=False) + exp_timestamp = decoded_cookie['exp'] + + # Valid case: mock utcnow to be exactly at exp + clock_skew (boundary of validity) + # The cookie should be considered valid here. + mock_now_valid = datetime.datetime.utcfromtimestamp(exp_timestamp + 59) + with mock.patch('google.auth._helpers.utcnow', return_value=mock_now_valid): + claims = auth.verify_session_cookie( + cookie_encoded, app=user_mgt_app, check_revoked=False, + clock_skew_seconds=59) # This clock_skew is used by google.auth.jwt assert claims['admin'] is True assert claims['uid'] == claims['sub'] - with pytest.raises(auth.ExpiredSessionCookieError): - auth.verify_session_cookie(cookie, app=user_mgt_app, check_revoked=False, - clock_skew_seconds=29) + + # Expired case: mock utcnow to be 1 second after exp + clock_skew (just expired) + # The cookie should be considered expired here. + mock_now_expired = datetime.datetime.utcfromtimestamp(exp_timestamp + 29 + 1) + with mock.patch('google.auth._helpers.utcnow', return_value=mock_now_expired): + with pytest.raises(auth.ExpiredSessionCookieError): + auth.verify_session_cookie( + cookie_encoded, app=user_mgt_app, check_revoked=False, + clock_skew_seconds=29) # This clock_skew is used by google.auth.jwt def test_project_id_option(self): app = firebase_admin.initialize_app( pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy