diff --git a/docs/topics/keepalive.rst b/docs/topics/keepalive.rst index e63c2f8f5..fd8300183 100644 --- a/docs/topics/keepalive.rst +++ b/docs/topics/keepalive.rst @@ -136,8 +136,8 @@ measured during the last exchange of Ping and Pong frames:: Alternatively, you can measure the latency at any time by calling :attr:`~asyncio.connection.Connection.ping` and awaiting its result:: - pong_waiter = await websocket.ping() - latency = await pong_waiter + pong_received = await websocket.ping() + latency = await pong_received Latency between a client and a server may increase for two reasons: diff --git a/src/websockets/asyncio/connection.py b/src/websockets/asyncio/connection.py index 61c300d63..e7af71fc5 100644 --- a/src/websockets/asyncio/connection.py +++ b/src/websockets/asyncio/connection.py @@ -101,10 +101,10 @@ def __init__( self.close_deadline: float | None = None # Protect sending fragmented messages. - self.fragmented_send_waiter: asyncio.Future[None] | None = None + self.send_in_progress: asyncio.Future[None] | None = None # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[asyncio.Future[float], float]] = {} + self.pending_pings: dict[bytes, tuple[asyncio.Future[float], float]] = {} self.latency: float = 0 """ @@ -468,8 +468,8 @@ async def send( """ # While sending a fragmented message, prevent sending other messages # until all fragments are sent. - while self.fragmented_send_waiter is not None: - await asyncio.shield(self.fragmented_send_waiter) + while self.send_in_progress is not None: + await asyncio.shield(self.send_in_progress) # Unfragmented message -- this case must be handled first because # strings and bytes-like objects are iterable. @@ -502,8 +502,8 @@ async def send( except StopIteration: return - assert self.fragmented_send_waiter is None - self.fragmented_send_waiter = self.loop.create_future() + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() try: # First fragment. if isinstance(chunk, str): @@ -549,8 +549,8 @@ async def send( raise finally: - self.fragmented_send_waiter.set_result(None) - self.fragmented_send_waiter = None + self.send_in_progress.set_result(None) + self.send_in_progress = None # Fragmented message -- async iterator. @@ -561,8 +561,8 @@ async def send( except StopAsyncIteration: return - assert self.fragmented_send_waiter is None - self.fragmented_send_waiter = self.loop.create_future() + assert self.send_in_progress is None + self.send_in_progress = self.loop.create_future() try: # First fragment. if isinstance(chunk, str): @@ -610,8 +610,8 @@ async def send( raise finally: - self.fragmented_send_waiter.set_result(None) - self.fragmented_send_waiter = None + self.send_in_progress.set_result(None) + self.send_in_progress = None else: raise TypeError("data must be str, bytes, iterable, or async iterable") @@ -635,7 +635,7 @@ async def close(self, code: int = 1000, reason: str = "") -> None: # The context manager takes care of waiting for the TCP connection # to terminate after calling a method that sends a close frame. async with self.send_context(): - if self.fragmented_send_waiter is not None: + if self.send_in_progress is not None: self.protocol.fail( CloseCode.INTERNAL_ERROR, "close during fragmented message", @@ -677,9 +677,9 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: :: - pong_waiter = await ws.ping() + pong_received = await ws.ping() # only if you want to wait for the corresponding pong - latency = await pong_waiter + latency = await pong_received Raises: ConnectionClosed: When the connection is closed. @@ -696,19 +696,19 @@ async def ping(self, data: Data | None = None) -> Awaitable[float]: async with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pong_waiters: + if data in self.pending_pings: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pong_waiters: + while data is None or data in self.pending_pings: data = struct.pack("!I", random.getrandbits(32)) - pong_waiter = self.loop.create_future() + pong_received = self.loop.create_future() # The event loop's default clock is time.monotonic(). Its resolution # is a bit low on Windows (~16ms). This is improved in Python 3.13. - self.pong_waiters[data] = (pong_waiter, self.loop.time()) + self.pending_pings[data] = (pong_received, self.loop.time()) self.protocol.send_ping(data) - return pong_waiter + return pong_received async def pong(self, data: Data = b"") -> None: """ @@ -757,7 +757,7 @@ def acknowledge_pings(self, data: bytes) -> None: """ # Ignore unsolicited pong. - if data not in self.pong_waiters: + if data not in self.pending_pings: return pong_timestamp = self.loop.time() @@ -766,20 +766,20 @@ def acknowledge_pings(self, data: bytes) -> None: # Acknowledge all previous pings too in that case. ping_id = None ping_ids = [] - for ping_id, (pong_waiter, ping_timestamp) in self.pong_waiters.items(): + for ping_id, (pong_received, ping_timestamp) in self.pending_pings.items(): ping_ids.append(ping_id) latency = pong_timestamp - ping_timestamp - if not pong_waiter.done(): - pong_waiter.set_result(latency) + if not pong_received.done(): + pong_received.set_result(latency) if ping_id == data: self.latency = latency break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pong_waiters. + # Remove acknowledged pings from self.pending_pings. for ping_id in ping_ids: - del self.pong_waiters[ping_id] + del self.pending_pings[ping_id] def abort_pings(self) -> None: """ @@ -791,16 +791,16 @@ def abort_pings(self) -> None: assert self.protocol.state is CLOSED exc = self.protocol.close_exc - for pong_waiter, _ping_timestamp in self.pong_waiters.values(): - if not pong_waiter.done(): - pong_waiter.set_exception(exc) + for pong_received, _ping_timestamp in self.pending_pings.values(): + if not pong_received.done(): + pong_received.set_exception(exc) # If the exception is never retrieved, it will be logged when ping # is garbage-collected. This is confusing for users. # Given that ping is done (with an exception), canceling it does # nothing, but it prevents logging the exception. - pong_waiter.cancel() + pong_received.cancel() - self.pong_waiters.clear() + self.pending_pings.clear() async def keepalive(self) -> None: """ @@ -821,7 +821,7 @@ async def keepalive(self) -> None: # connection to be closed before raising ConnectionClosed. # However, connection_lost() cancels keepalive_task before # it gets a chance to resume excuting. - pong_waiter = await self.ping() + pong_received = await self.ping() if self.debug: self.logger.debug("% sent keepalive ping") @@ -830,9 +830,9 @@ async def keepalive(self) -> None: async with asyncio_timeout(self.ping_timeout): # connection_lost cancels keepalive immediately # after setting a ConnectionClosed exception on - # pong_waiter. A CancelledError is raised here, + # pong_received. A CancelledError is raised here, # not a ConnectionClosed exception. - latency = await pong_waiter + latency = await pong_received self.logger.debug("% received keepalive pong") except asyncio.TimeoutError: if self.debug: @@ -1201,7 +1201,7 @@ def broadcast( if connection.protocol.state is not OPEN: continue - if connection.fragmented_send_waiter is not None: + if connection.send_in_progress is not None: if raise_exceptions: exception = ConcurrencyError("sending a fragmented message") exceptions.append(exception) diff --git a/src/websockets/asyncio/messages.py b/src/websockets/asyncio/messages.py index 1fd41811c..1fe33f709 100644 --- a/src/websockets/asyncio/messages.py +++ b/src/websockets/asyncio/messages.py @@ -81,8 +81,7 @@ class Assembler: """ - # coverage reports incorrectly: "line NN didn't jump to the function exit" - def __init__( # pragma: no cover + def __init__( self, high: int | None = None, low: int | None = None, diff --git a/src/websockets/sync/connection.py b/src/websockets/sync/connection.py index bedbf4def..04f65f3f6 100644 --- a/src/websockets/sync/connection.py +++ b/src/websockets/sync/connection.py @@ -104,7 +104,7 @@ def __init__( self.send_in_progress = False # Mapping of ping IDs to pong waiters, in chronological order. - self.pong_waiters: dict[bytes, tuple[threading.Event, float, bool]] = {} + self.pending_pings: dict[bytes, tuple[threading.Event, float, bool]] = {} self.latency: float = 0 """ @@ -629,8 +629,9 @@ def ping( :: - pong_event = ws.ping() - pong_event.wait() # only if you want to wait for the pong + pong_received = ws.ping() + # only if you want to wait for the corresponding pong + pong_received.wait() Raises: ConnectionClosed: When the connection is closed. @@ -647,17 +648,17 @@ def ping( with self.send_context(): # Protect against duplicates if a payload is explicitly set. - if data in self.pong_waiters: + if data in self.pending_pings: raise ConcurrencyError("already waiting for a pong with the same data") # Generate a unique random payload otherwise. - while data is None or data in self.pong_waiters: + while data is None or data in self.pending_pings: data = struct.pack("!I", random.getrandbits(32)) - pong_waiter = threading.Event() - self.pong_waiters[data] = (pong_waiter, time.monotonic(), ack_on_close) + pong_received = threading.Event() + self.pending_pings[data] = (pong_received, time.monotonic(), ack_on_close) self.protocol.send_ping(data) - return pong_waiter + return pong_received def pong(self, data: Data = b"") -> None: """ @@ -707,7 +708,7 @@ def acknowledge_pings(self, data: bytes) -> None: """ with self.protocol_mutex: # Ignore unsolicited pong. - if data not in self.pong_waiters: + if data not in self.pending_pings: return pong_timestamp = time.monotonic() @@ -717,21 +718,21 @@ def acknowledge_pings(self, data: bytes) -> None: ping_id = None ping_ids = [] for ping_id, ( - pong_waiter, + pong_received, ping_timestamp, _ack_on_close, - ) in self.pong_waiters.items(): + ) in self.pending_pings.items(): ping_ids.append(ping_id) - pong_waiter.set() + pong_received.set() if ping_id == data: self.latency = pong_timestamp - ping_timestamp break else: raise AssertionError("solicited pong not found in pings") - # Remove acknowledged pings from self.pong_waiters. + # Remove acknowledged pings from self.pending_pings. for ping_id in ping_ids: - del self.pong_waiters[ping_id] + del self.pending_pings[ping_id] def acknowledge_pending_pings(self) -> None: """ @@ -740,11 +741,11 @@ def acknowledge_pending_pings(self) -> None: """ assert self.protocol.state is CLOSED - for pong_waiter, _ping_timestamp, ack_on_close in self.pong_waiters.values(): + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): if ack_on_close: - pong_waiter.set() + pong_received.set() - self.pong_waiters.clear() + self.pending_pings.clear() def keepalive(self) -> None: """ @@ -762,15 +763,14 @@ def keepalive(self) -> None: break try: - pong_waiter = self.ping(ack_on_close=True) + pong_received = self.ping(ack_on_close=True) except ConnectionClosed: break if self.debug: self.logger.debug("% sent keepalive ping") if self.ping_timeout is not None: - # - if pong_waiter.wait(self.ping_timeout): + if pong_received.wait(self.ping_timeout): if self.debug: self.logger.debug("% received keepalive pong") else: @@ -804,7 +804,7 @@ def recv_events(self) -> None: Run this method in a thread as long as the connection is alive. - ``recv_events()`` exits immediately when the ``self.socket`` is closed. + ``recv_events()`` exits immediately when ``self.socket`` is closed. """ try: @@ -979,6 +979,7 @@ def send_context( # Minor layering violation: we assume that the connection # will be closing soon if it isn't in the expected state. wait_for_close = True + # TODO: calculate close deadline if not set? raise_close_exc = True # To avoid a deadlock, release the connection lock by exiting the diff --git a/src/websockets/trio/__init__.py b/src/websockets/trio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/src/websockets/trio/connection.py b/src/websockets/trio/connection.py new file mode 100644 index 000000000..2a77749b5 --- /dev/null +++ b/src/websockets/trio/connection.py @@ -0,0 +1,1114 @@ +from __future__ import annotations + +import contextlib +import logging +import random +import struct +import uuid +from collections.abc import AsyncIterable, AsyncIterator, Iterable, Mapping +from types import TracebackType +from typing import Any, Literal, overload + +import trio + +from ..asyncio.compatibility import ( + TimeoutError, + aiter, + anext, +) +from ..exceptions import ( + ConcurrencyError, + ConnectionClosed, + ConnectionClosedOK, + ProtocolError, +) +from ..frames import DATA_OPCODES, BytesLike, CloseCode, Frame, Opcode +from ..http11 import Request, Response +from ..protocol import CLOSED, OPEN, Event, Protocol, State +from ..typing import Data, LoggerLike, Subprotocol +from .messages import Assembler + + +__all__ = ["Connection"] + + +class Connection: + """ + :mod:`trio` implementation of a WebSocket connection. + + :class:`Connection` provides APIs shared between WebSocket servers and + clients. + + You shouldn't use it directly. Instead, use + :class:`~websockets.trio.client.ClientConnection` or + :class:`~websockets.trio.server.ServerConnection`. + + """ + + def __init__( + self, + nursery: trio.Nursery, + stream: trio.abc.Stream, + protocol: Protocol, + *, + ping_interval: float | None = 20, + ping_timeout: float | None = 20, + close_timeout: float | None = 10, + max_queue: int | None | tuple[int | None, int | None] = 16, + ) -> None: + self.nursery = nursery + self.stream = stream + self.protocol = protocol + self.ping_interval = ping_interval + self.ping_timeout = ping_timeout + self.close_timeout = close_timeout + self.max_queue: tuple[int | None, int | None] + if isinstance(max_queue, int) or max_queue is None: + self.max_queue = (max_queue, None) + else: + self.max_queue = max_queue + + # Inject reference to this instance in the protocol's logger. + self.protocol.logger = logging.LoggerAdapter( + self.protocol.logger, + {"websocket": self}, + ) + + # Copy attributes from the protocol for convenience. + self.id: uuid.UUID = self.protocol.id + """Unique identifier of the connection. Useful in logs.""" + self.logger: LoggerLike = self.protocol.logger + """Logger for this connection.""" + self.debug = self.protocol.debug + + # HTTP handshake request and response. + self.request: Request | None = None + """Opening handshake request.""" + self.response: Response | None = None + """Opening handshake response.""" + + # Lock stopping reads when the assembler buffer is full. + self.recv_flow_control = trio.Lock() + + # Assembler turning frames into messages and serializing reads. + self.recv_messages = Assembler( + *self.max_queue, + pause=self.recv_flow_control.acquire_nowait, + resume=self.recv_flow_control.release, + ) + + # Deadline for the closing handshake. + self.close_deadline: float | None = None + + # Protect sending fragmented messages. + self.send_in_progress: trio.Event | None = None + + # Mapping of ping IDs to pong waiters, in chronological order. + self.pending_pings: dict[bytes, tuple[trio.Event, float, bool]] = {} + + self.latency: float = 0 + """ + Latency of the connection, in seconds. + + Latency is defined as the round-trip time of the connection. It is + measured by sending a Ping frame and waiting for a matching Pong frame. + Before the first measurement, :attr:`latency` is ``0``. + + By default, websockets enables a :ref:`keepalive ` mechanism + that sends Ping frames automatically at regular intervals. You can also + send Ping frames and measure latency with :meth:`ping`. + """ + + # Exception raised while reading from the connection, to be chained to + # ConnectionClosed in order to show why the TCP connection dropped. + self.recv_exc: BaseException | None = None + + # Start recv_events only after all attributes are initialized. + self.nursery.start_soon(self.recv_events) + + # Completed when the TCP connection is closed and the WebSocket + # connection state becomes CLOSED. + self.stream_closed: trio.Event = trio.Event() + + # Public attributes + + @property + def local_address(self) -> Any: + """ + Local address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getsockname`. + + """ + if isinstance(self.stream, trio.SSLStream): # pragma: no cover + stream = self.stream.transport_stream + else: + stream = self.stream + if isinstance(stream, trio.SocketStream): + return stream.socket.getsockname() + else: # pragma: no cover + raise NotImplementedError + + @property + def remote_address(self) -> Any: + """ + Remote address of the connection. + + For IPv4 connections, this is a ``(host, port)`` tuple. + + The format of the address depends on the address family. + See :meth:`~socket.socket.getpeername`. + + """ + if isinstance(self.stream, trio.SSLStream): # pragma: no cover + stream = self.stream.transport_stream + else: + stream = self.stream + if isinstance(stream, trio.SocketStream): + return stream.socket.getpeername() + else: # pragma: no cover + raise NotImplementedError + + @property + def state(self) -> State: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should call :meth:`~recv` or + :meth:`send` and handle :exc:`~websockets.exceptions.ConnectionClosed` + exceptions. + + """ + return self.protocol.state + + @property + def subprotocol(self) -> Subprotocol | None: + """ + Subprotocol negotiated during the opening handshake. + + :obj:`None` if no subprotocol was negotiated. + + """ + return self.protocol.subprotocol + + @property + def close_code(self) -> int | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_code + + @property + def close_reason(self) -> str | None: + """ + State of the WebSocket connection, defined in :rfc:`6455`. + + This attribute is provided for completeness. Typical applications + shouldn't check its value. Instead, they should inspect attributes + of :exc:`~websockets.exceptions.ConnectionClosed` exceptions. + + """ + return self.protocol.close_reason + + # Public methods + + async def __aenter__(self) -> Connection: + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + if exc_type is None: + await self.close() + else: + await self.close(CloseCode.INTERNAL_ERROR) + + async def __aiter__(self) -> AsyncIterator[Data]: + """ + Iterate on incoming messages. + + The iterator calls :meth:`recv` and yields messages asynchronously in an + infinite loop. + + It exits when the connection is closed normally. It raises a + :exc:`~websockets.exceptions.ConnectionClosedError` exception after a + protocol error or a network failure. + + """ + try: + while True: + yield await self.recv() + except ConnectionClosedOK: + return + + @overload + async def recv(self, decode: Literal[True]) -> str: ... + + @overload + async def recv(self, decode: Literal[False]) -> bytes: ... + + @overload + async def recv(self, decode: bool | None = None) -> Data: ... + + async def recv(self, decode: bool | None = None) -> Data: + """ + Receive the next message. + + When the connection is closed, :meth:`recv` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it raises + :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal closure + and :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. This is how you detect the end of the + message stream. + + Canceling :meth:`recv` is safe. There's no risk of losing data. The next + invocation of :meth:`recv` will return the next message. + + This makes it possible to enforce a timeout by wrapping :meth:`recv` in + :func:`~trio.move_on_after` or :func:`~trio.fail_after`. + + When the message is fragmented, :meth:`recv` waits until all fragments + are received, reassembles them, and returns the whole message. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + A string (:class:`str`) for a Text_ frame or a bytestring + (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames and + return a bytestring (:class:`bytes`). This improves performance + when decoding isn't needed, for example if the message contains + JSON and you're using a JSON library that expects a bytestring. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return a string (:class:`str`). This may be useful for + servers that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + return await self.recv_messages.get(decode) + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await self.stream_closed.wait() + raise self.protocol.close_exc from self.recv_exc + + @overload + def recv_streaming(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def recv_streaming(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def recv_streaming(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Receive the next message frame by frame. + + This method is designed for receiving fragmented messages. It returns an + asynchronous iterator that yields each fragment as it is received. This + iterator must be fully consumed. Else, future calls to :meth:`recv` or + :meth:`recv_streaming` will raise + :exc:`~websockets.exceptions.ConcurrencyError`, making the connection + unusable. + + :meth:`recv_streaming` raises the same exceptions as :meth:`recv`. + + Canceling :meth:`recv_streaming` before receiving the first frame is + safe. Canceling it after receiving one or more frames leaves the + iterator in a partially consumed state, making the connection unusable. + Instead, you should close the connection with :meth:`close`. + + Args: + decode: Set this flag to override the default behavior of returning + :class:`str` or :class:`bytes`. See below for details. + + Returns: + An iterator of strings (:class:`str`) for a Text_ frame or + bytestrings (:class:`bytes`) for a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``decode`` argument: + + * Set ``decode=False`` to disable UTF-8 decoding of Text_ frames + and return bytestrings (:class:`bytes`). This may be useful to + optimize performance when decoding isn't needed. + * Set ``decode=True`` to force UTF-8 decoding of Binary_ frames + and return strings (:class:`str`). This is useful for servers + that send binary frames instead of text frames. + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If two coroutines call :meth:`recv` or + :meth:`recv_streaming` concurrently. + + """ + try: + async for frame in self.recv_messages.get_iter(decode): + yield frame + return + except EOFError: + pass + # fallthrough + except ConcurrencyError: + raise ConcurrencyError( + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming" + ) from None + except UnicodeDecodeError as exc: + async with self.send_context(): + self.protocol.fail( + CloseCode.INVALID_DATA, + f"{exc.reason} at position {exc.start}", + ) + # fallthrough + + # Wait for the protocol state to be CLOSED before accessing close_exc. + await self.stream_closed.wait() + raise self.protocol.close_exc from self.recv_exc + + async def send( + self, + message: Data | Iterable[Data] | AsyncIterable[Data], + text: bool | None = None, + ) -> None: + """ + Send a message. + + A string (:class:`str`) is sent as a Text_ frame. A bytestring or + bytes-like object (:class:`bytes`, :class:`bytearray`, or + :class:`memoryview`) is sent as a Binary_ frame. + + .. _Text: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + .. _Binary: https://datatracker.ietf.org/doc/html/rfc6455#section-5.6 + + You may override this behavior with the ``text`` argument: + + * Set ``text=True`` to send a bytestring or bytes-like object + (:class:`bytes`, :class:`bytearray`, or :class:`memoryview`) as a + Text_ frame. This improves performance when the message is already + UTF-8 encoded, for example if the message contains JSON and you're + using a JSON library that produces a bytestring. + * Set ``text=False`` to send a string (:class:`str`) in a Binary_ + frame. This may be useful for servers that expect binary frames + instead of text frames. + + :meth:`send` also accepts an iterable or an asynchronous iterable of + strings, bytestrings, or bytes-like objects to enable fragmentation_. + Each item is treated as a message fragment and sent in its own frame. + All items must be of the same type, or else :meth:`send` will raise a + :exc:`TypeError` and the connection will be closed. + + .. _fragmentation: https://datatracker.ietf.org/doc/html/rfc6455#section-5.4 + + :meth:`send` rejects dict-like objects because this is often an error. + (If you really want to send the keys of a dict-like object as fragments, + call its :meth:`~dict.keys` method and pass the result to :meth:`send`.) + + Canceling :meth:`send` is discouraged. Instead, you should close the + connection with :meth:`close`. Indeed, there are only two situations + where :meth:`send` may yield control to the event loop and then get + canceled; in both cases, :meth:`close` has the same effect and is + more clear: + + 1. The write buffer is full. If you don't want to wait until enough + data is sent, your only alternative is to close the connection. + :meth:`close` will likely time out then abort the TCP connection. + 2. ``message`` is an asynchronous iterator that yields control. + Stopping in the middle of a fragmented message will cause a + protocol error and the connection will be closed. + + When the connection is closed, :meth:`send` raises + :exc:`~websockets.exceptions.ConnectionClosed`. Specifically, it + raises :exc:`~websockets.exceptions.ConnectionClosedOK` after a normal + connection closure and + :exc:`~websockets.exceptions.ConnectionClosedError` after a protocol + error or a network failure. + + Args: + message: Message to send. + + Raises: + ConnectionClosed: When the connection is closed. + TypeError: If ``message`` doesn't have a supported type. + + """ + # While sending a fragmented message, prevent sending other messages + # until all fragments are sent. + while self.send_in_progress is not None: + await self.send_in_progress.wait() + + # Unfragmented message -- this case must be handled first because + # strings and bytes-like objects are iterable. + + if isinstance(message, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(message.encode()) + else: + self.protocol.send_text(message.encode()) + + elif isinstance(message, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(message) + else: + self.protocol.send_binary(message) + + # Catch a common mistake -- passing a dict to send(). + + elif isinstance(message, Mapping): + raise TypeError("data is a dict-like object") + + # Fragmented message -- regular iterator. + + elif isinstance(message, Iterable): + chunks = iter(message) + try: + chunk = next(chunks) + except StopIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = trio.Event() + try: + # First fragment. + if isinstance(chunk, str): + async with self.send_context(): + if text is False: + self.protocol.send_binary(chunk.encode(), fin=False) + else: + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + async with self.send_context(): + if text is True: + self.protocol.send_text(chunk, fin=False) + else: + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("iterable must contain bytes or str") + + # Other fragments + for chunk in chunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set() + self.send_in_progress = None + + # Fragmented message -- async iterator. + + elif isinstance(message, AsyncIterable): + achunks = aiter(message) + try: + chunk = await anext(achunks) + except StopAsyncIteration: + return + + assert self.send_in_progress is None + self.send_in_progress = trio.Event() + try: + # First fragment. + if isinstance(chunk, str): + if text is False: + async with self.send_context(): + self.protocol.send_binary(chunk.encode(), fin=False) + else: + async with self.send_context(): + self.protocol.send_text(chunk.encode(), fin=False) + encode = True + elif isinstance(chunk, BytesLike): + if text is True: + async with self.send_context(): + self.protocol.send_text(chunk, fin=False) + else: + async with self.send_context(): + self.protocol.send_binary(chunk, fin=False) + encode = False + else: + raise TypeError("async iterable must contain bytes or str") + + # Other fragments + async for chunk in achunks: + if isinstance(chunk, str) and encode: + async with self.send_context(): + self.protocol.send_continuation(chunk.encode(), fin=False) + elif isinstance(chunk, BytesLike) and not encode: + async with self.send_context(): + self.protocol.send_continuation(chunk, fin=False) + else: + raise TypeError("async iterable must contain uniform types") + + # Final fragment. + async with self.send_context(): + self.protocol.send_continuation(b"", fin=True) + + except Exception: + # We're half-way through a fragmented message and we can't + # complete it. This makes the connection unusable. + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "error in fragmented message", + ) + raise + + finally: + self.send_in_progress.set() + self.send_in_progress = None + + else: + raise TypeError("data must be str, bytes, iterable, or async iterable") + + async def close(self, code: int = 1000, reason: str = "") -> None: + """ + Perform the closing handshake. + + :meth:`close` waits for the other end to complete the handshake and + for the TCP connection to terminate. + + :meth:`close` is idempotent: it doesn't do anything once the + connection is closed. + + Args: + code: WebSocket close code. + reason: WebSocket close reason. + + """ + try: + # The context manager takes care of waiting for the TCP connection + # to terminate after calling a method that sends a close frame. + async with self.send_context(): + if self.send_in_progress is not None: + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "close during fragmented message", + ) + else: + self.protocol.send_close(code, reason) + except ConnectionClosed: + # Ignore ConnectionClosed exceptions raised from send_context(). + # They mean that the connection is closed, which was the goal. + pass + + async def wait_closed(self) -> None: + """ + Wait until the connection is closed. + + :meth:`wait_closed` waits for the closing handshake to complete and for + the TCP connection to terminate. + + """ + await self.stream_closed.wait() + + async def ping( + self, data: Data | None = None, ack_on_close: bool = False + ) -> trio.Event: + """ + Send a Ping_. + + .. _Ping: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.2 + + A ping may serve as a keepalive or as a check that the remote endpoint + received all messages up to this point + + Args: + data: Payload of the ping. A :class:`str` will be encoded to UTF-8. + If ``data`` is :obj:`None`, the payload is four random bytes. + ack_on_close: when this option is :obj:`True`, the event will also + be set when the connection is closed. While this avoids getting + stuck waiting for a pong that will never arrive, it requires + checking that the state of the connection is still ``OPEN`` to + confirm that a pong was received, rather than the connection + being closed. + + Returns: + An event that will be set when the corresponding pong is received. + You can ignore it if you don't intend to wait. + + :: + + pong_received = await ws.ping() + # only if you want to wait for the corresponding pong + await pong_received.wait() + + Raises: + ConnectionClosed: When the connection is closed. + ConcurrencyError: If another ping was sent with the same data and + the corresponding pong wasn't received yet. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + elif data is not None: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + # Protect against duplicates if a payload is explicitly set. + if data in self.pending_pings: + raise ConcurrencyError("already waiting for a pong with the same data") + + # Generate a unique random payload otherwise. + while data is None or data in self.pending_pings: + data = struct.pack("!I", random.getrandbits(32)) + + pong_received = trio.Event() + self.pending_pings[data] = ( + pong_received, + trio.current_time(), + ack_on_close, + ) + self.protocol.send_ping(data) + return pong_received + + async def pong(self, data: Data = b"") -> None: + """ + Send a Pong_. + + .. _Pong: https://datatracker.ietf.org/doc/html/rfc6455#section-5.5.3 + + An unsolicited pong may serve as a unidirectional heartbeat. + + Args: + data: Payload of the pong. A :class:`str` will be encoded to UTF-8. + + Raises: + ConnectionClosed: When the connection is closed. + + """ + if isinstance(data, BytesLike): + data = bytes(data) + elif isinstance(data, str): + data = data.encode() + else: + raise TypeError("data must be str or bytes-like") + + async with self.send_context(): + self.protocol.send_pong(data) + + # Private methods + + def process_event(self, event: Event) -> None: + """ + Process one incoming event. + + This method is overridden in subclasses to handle the handshake. + + """ + assert isinstance(event, Frame) + if event.opcode in DATA_OPCODES: + self.recv_messages.put(event) + + if event.opcode is Opcode.PONG: + self.acknowledge_pings(bytes(event.data)) + + def acknowledge_pings(self, data: bytes) -> None: + """ + Acknowledge pings when receiving a pong. + + """ + # Ignore unsolicited pong. + if data not in self.pending_pings: + return + + pong_timestamp = trio.current_time() + + # Sending a pong for only the most recent ping is legal. + # Acknowledge all previous pings too in that case. + ping_id = None + ping_ids = [] + for ping_id, ( + pong_received, + ping_timestamp, + _ack_on_close, + ) in self.pending_pings.items(): + ping_ids.append(ping_id) + pong_received.set() + if ping_id == data: + self.latency = pong_timestamp - ping_timestamp + break + else: + raise AssertionError("solicited pong not found in pings") + + # Remove acknowledged pings from self.pending_pings. + for ping_id in ping_ids: + del self.pending_pings[ping_id] + + def acknowledge_pending_pings(self) -> None: + """ + Acknowledge pending pings when the connection is closed. + + """ + assert self.protocol.state is CLOSED + + for pong_received, _ping_timestamp, ack_on_close in self.pending_pings.values(): + if ack_on_close: + pong_received.set() + + self.pending_pings.clear() + + async def keepalive(self) -> None: + """ + Send a Ping frame and wait for a Pong frame at regular intervals. + + """ + assert self.ping_interval is not None + latency = 0.0 + try: + while True: + # If self.ping_timeout > latency > self.ping_interval, + # pings will be sent immediately after receiving pongs. + # The period will be longer than self.ping_interval. + with trio.move_on_after(self.ping_interval - latency): + await self.stream_closed.wait() + break + + try: + pong_received = await self.ping(ack_on_close=True) + except ConnectionClosed: + break + if self.debug: + self.logger.debug("% sent keepalive ping") + + if self.ping_timeout is not None: + with trio.move_on_after(self.ping_timeout) as cancel_scope: + await pong_received.wait() + self.logger.debug("% received keepalive pong") + if cancel_scope.cancelled_caught: + if self.debug: + self.logger.debug("- timed out waiting for keepalive pong") + async with self.send_context(): + self.protocol.fail( + CloseCode.INTERNAL_ERROR, + "keepalive ping timeout", + ) + break + except Exception: + self.logger.error("keepalive ping failed", exc_info=True) + + def start_keepalive(self) -> None: + """ + Run :meth:`keepalive` in a task, unless keepalive is disabled. + + """ + if self.ping_interval is not None: + self.nursery.start_soon(self.keepalive) + + async def recv_events(self) -> None: + """ + Read incoming data from the stream and process events. + + Run this method in a task as long as the connection is alive. + + ``recv_events()`` exits immediately when ``self.stream`` is closed. + + """ + try: + while True: + try: + data = await self.stream.receive_some() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while receiving data", + exc_info=True, + ) + # When the closing handshake is initiated by our side, + # recv() may block until send_context() closes the stream. + # In that case, send_context() already set recv_exc. + # Calling set_recv_exc() avoids overwriting it. + self.set_recv_exc(exc) + break + + if data == b"": + break + + # Feed incoming data to the protocol. + self.protocol.receive_data(data) + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # Write outgoing data to the socket. + try: + await self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug( + "! error while sending data", + exc_info=True, + ) + # Similarly to the above, avoid overriding an exception + # set by send_context(), in case of a race condition + # i.e. send_context() closes the transport after recv() + # returns above but before send_data() calls send(). + self.set_recv_exc(exc) + break + + if self.protocol.close_expected(): + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = ( + trio.current_time() + self.close_timeout + ) + + # If self.send_data raised an exception, then events are lost. + # Given that automatic responses write small amounts of data, + # this should be uncommon, so we don't handle the edge case. + + for event in events: + # This isn't expected to raise an exception. + self.process_event(event) + + # Breaking out of the while True: ... loop means that we believe + # that the socket doesn't work anymore. + # Feed the end of the data stream to the protocol. + self.protocol.receive_eof() + + # This isn't expected to raise an exception. + events = self.protocol.events_received() + + # There is no error handling because send_data() can only write + # the end of the data stream here and it handles errors itself. + await self.send_data() + + # This code path is triggered when receiving an HTTP response + # without a Content-Length header. This is the only case where + # reading until EOF generates an event; all other events have + # a known length. Ignore for coverage measurement because tests + # are in test_client.py rather than test_connection.py. + for event in events: # pragma: no cover + # This isn't expected to raise an exception. + self.process_event(event) + + except Exception as exc: + # This branch should never run. It's a safety net in case of bugs. + self.logger.error("unexpected internal error", exc_info=True) + self.set_recv_exc(exc) + finally: + # This isn't expected to raise an exception. + await self.close_stream() + + @contextlib.asynccontextmanager + async def send_context( + self, + *, + expected_state: State = OPEN, # CONNECTING during the opening handshake + ) -> AsyncIterator[None]: + """ + Create a context for writing to the connection from user code. + + On entry, :meth:`send_context` checks that the connection is open; on + exit, it writes outgoing data to the socket:: + + async with self.send_context(): + self.protocol.send_text(message.encode()) + + When the connection isn't open on entry, when the connection is expected + to close on exit, or when an unexpected error happens, terminating the + connection, :meth:`send_context` waits until the connection is closed + then raises :exc:`~websockets.exceptions.ConnectionClosed`. + + """ + # Should we wait until the connection is closed? + wait_for_close = False + # Should we close the transport and raise ConnectionClosed? + raise_close_exc = False + # What exception should we chain ConnectionClosed to? + original_exc: BaseException | None = None + + if self.protocol.state is expected_state: + # Let the caller interact with the protocol. + try: + yield + except (ProtocolError, ConcurrencyError): + # The protocol state wasn't changed. Exit immediately. + raise + except Exception as exc: + self.logger.error("unexpected internal error", exc_info=True) + # This branch should never run. It's a safety net in case of + # bugs. Since we don't know what happened, we will close the + # connection and raise the exception to the caller. + wait_for_close = False + raise_close_exc = True + original_exc = exc + else: + # Check if the connection is expected to close soon. + if self.protocol.close_expected(): + wait_for_close = True + # If the connection is expected to close soon, set the + # close deadline based on the close timeout. + # Since we tested earlier that protocol.state was OPEN + # (or CONNECTING), self.close_deadline is still None. + if self.close_timeout is not None: + assert self.close_deadline is None + self.close_deadline = trio.current_time() + self.close_timeout + # Write outgoing data to the socket and enforce flow control. + try: + await self.send_data() + except Exception as exc: + if self.debug: + self.logger.debug("! error while sending data", exc_info=True) + # While the only expected exception here is OSError, + # other exceptions would be treated identically. + wait_for_close = False + raise_close_exc = True + original_exc = exc + + else: # self.protocol.state is not expected_state + # Minor layering violation: we assume that the connection + # will be closing soon if it isn't in the expected state. + wait_for_close = True + # Calculate close_deadline if it wasn't set yet. + if self.close_timeout is not None: + if self.close_deadline is None: + self.close_deadline = trio.current_time() + self.close_timeout + raise_close_exc = True + + # If the connection is expected to close soon and the close timeout + # elapses, close the socket to terminate the connection. + if wait_for_close: + if self.close_deadline is not None: + with trio.move_on_at(self.close_deadline) as cancel_scope: + await self.stream_closed.wait() + if cancel_scope.cancelled_caught: + # There's no risk to overwrite another error because + # original_exc is never set when wait_for_close is True. + assert original_exc is None + original_exc = TimeoutError("timed out while closing connection") + # Set recv_exc before closing the transport in order to get + # proper exception reporting. + raise_close_exc = True + self.set_recv_exc(original_exc) + else: + await self.stream_closed.wait() + + # If an error occurred, close the transport to terminate the connection and + # raise an exception. + if raise_close_exc: + await self.close_stream() + raise self.protocol.close_exc from original_exc + + async def send_data(self) -> None: + """ + Send outgoing data. + + """ + for data in self.protocol.data_to_send(): + if data: + await self.stream.send_all(data) + else: + # Half-close the TCP connection when possible i.e. no TLS. + if isinstance(self.stream, trio.abc.HalfCloseableStream): + if self.debug: + self.logger.debug("x half-closing TCP connection") + try: + await self.stream.send_eof() + except Exception: # pragma: no cover + pass + # Else, close the TCP connection. + else: # pragma: no cover + if self.debug: + self.logger.debug("x closing TCP connection") + await self.stream.aclose() + + def set_recv_exc(self, exc: BaseException | None) -> None: + """ + Set recv_exc, if not set yet. + + """ + if self.recv_exc is None: + self.recv_exc = exc + + async def close_stream(self) -> None: + """ + Shutdown and close stream. Close message assembler. + + Calling close_stream() guarantees that recv_events() terminates. Indeed, + recv_events() may block only on stream.recv() or on recv_messages.put(). + + """ + # Close the stream. + await self.stream.aclose() + + # Calling protocol.receive_eof() is safe because it's idempotent. + # This guarantees that the protocol state becomes CLOSED. + self.protocol.receive_eof() + assert self.protocol.state is CLOSED + + # Abort recv() with a ConnectionClosed exception. + self.recv_messages.close() + + # Acknowledge pings sent with the ack_on_close option. + self.acknowledge_pending_pings() + + # Unblock coroutines waiting on self.stream_closed. + self.stream_closed.set() diff --git a/src/websockets/trio/messages.py b/src/websockets/trio/messages.py new file mode 100644 index 000000000..65f9759ef --- /dev/null +++ b/src/websockets/trio/messages.py @@ -0,0 +1,285 @@ +from __future__ import annotations + +import codecs +import math +from collections.abc import AsyncIterator +from typing import Any, Callable, Literal, TypeVar, overload + +import trio + +from ..exceptions import ConcurrencyError +from ..frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from ..typing import Data + + +__all__ = ["Assembler"] + +UTF8Decoder = codecs.getincrementaldecoder("utf-8") + +T = TypeVar("T") + + +class Assembler: + """ + Assemble messages from frames. + + :class:`Assembler` expects only data frames. The stream of frames must + respect the protocol; if it doesn't, the behavior is undefined. + + Args: + pause: Called when the buffer of frames goes above the high water mark; + should pause reading from the network. + resume: Called when the buffer of frames goes below the low water mark; + should resume reading from the network. + + """ + + def __init__( + self, + high: int | None = None, + low: int | None = None, + pause: Callable[[], Any] = lambda: None, + resume: Callable[[], Any] = lambda: None, + ) -> None: + # Queue of incoming frames. + self.send_frames: trio.MemorySendChannel[Frame] + self.recv_frames: trio.MemoryReceiveChannel[Frame] + self.send_frames, self.recv_frames = trio.open_memory_channel(math.inf) + + # We cannot put a hard limit on the size of the queue because a single + # call to Protocol.data_received() could produce thousands of frames, + # which must be buffered. Instead, we pause reading when the buffer goes + # above the high limit and we resume when it goes under the low limit. + if high is not None and low is None: + low = high // 4 + if high is None and low is not None: + high = low * 4 + if high is not None and low is not None: + if low < 0: + raise ValueError("low must be positive or equal to zero") + if high < low: + raise ValueError("high must be greater than or equal to low") + self.high, self.low = high, low + self.pause = pause + self.resume = resume + self.paused = False + + # This flag prevents concurrent calls to get() by user code. + self.get_in_progress = False + + # This flag marks the end of the connection. + self.closed = False + + @overload + async def get(self, decode: Literal[True]) -> str: ... + + @overload + async def get(self, decode: Literal[False]) -> bytes: ... + + @overload + async def get(self, decode: bool | None = None) -> Data: ... + + async def get(self, decode: bool | None = None) -> Data: + """ + Read the next message. + + :meth:`get` returns a single :class:`str` or :class:`bytes`. + + If the message is fragmented, :meth:`get` waits until the last frame is + received, then it reassembles the message and returns it. To receive + messages frame by frame, use :meth:`get_iter` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get() fetches a complete message or is canceled. + + try: + # First frame + try: + frame = await self.recv_frames.receive() + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + frames = [frame] + + # Following frames, for fragmented messages + while not frame.fin: + try: + frame = await self.recv_frames.receive() + except trio.Cancelled: + # Put frames already received back into the queue + # so that future calls to get() can return them. + assert not self.send_frames._state.receive_tasks, ( + "no task should be waiting on receive()" + ) + assert not self.send_frames._state.data, "queue should be empty" + for frame in frames: + self.send_frames.send_nowait(frame) + raise + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_CONT + frames.append(frame) + + finally: + self.get_in_progress = False + + data = b"".join(frame.data for frame in frames) + if decode: + return data.decode() + else: + return data + + @overload + def get_iter(self, decode: Literal[True]) -> AsyncIterator[str]: ... + + @overload + def get_iter(self, decode: Literal[False]) -> AsyncIterator[bytes]: ... + + @overload + def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: ... + + async def get_iter(self, decode: bool | None = None) -> AsyncIterator[Data]: + """ + Stream the next message. + + Iterating the return value of :meth:`get_iter` asynchronously yields a + :class:`str` or :class:`bytes` for each frame in the message. + + The iterator must be fully consumed before calling :meth:`get_iter` or + :meth:`get` again. Else, :exc:`ConcurrencyError` is raised. + + This method only makes sense for fragmented messages. If messages aren't + fragmented, use :meth:`get` instead. + + Args: + decode: :obj:`False` disables UTF-8 decoding of text frames and + returns :class:`bytes`. :obj:`True` forces UTF-8 decoding of + binary frames and returns :class:`str`. + + Raises: + EOFError: If the stream of frames has ended. + UnicodeDecodeError: If a text frame contains invalid UTF-8. + ConcurrencyError: If two coroutines run :meth:`get` or + :meth:`get_iter` concurrently. + + """ + if self.get_in_progress: + raise ConcurrencyError("get() or get_iter() is already running") + self.get_in_progress = True + + # Locking with get_in_progress prevents concurrent execution + # until get_iter() fetches a complete message or is canceled. + + # If get_iter() raises an exception e.g. in decoder.decode(), + # get_in_progress remains set and the connection becomes unusable. + + # First frame + try: + frame = await self.recv_frames.receive() + except trio.Cancelled: + self.get_in_progress = False + raise + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_TEXT or frame.opcode is OP_BINARY + if decode is None: + decode = frame.opcode is OP_TEXT + if decode: + decoder = UTF8Decoder() + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + # Following frames, for fragmented messages + while not frame.fin: + # We cannot handle trio.Cancelled because we don't buffer + # previous fragments — we're streaming them. Canceling get_iter() + # here will leave the assembler in a stuck state. Future calls to + # get() or get_iter() will raise ConcurrencyError. + try: + frame = await self.recv_frames.receive() + except trio.EndOfChannel: + raise EOFError("stream of frames ended") + self.maybe_resume() + assert frame.opcode is OP_CONT + if decode: + yield decoder.decode(frame.data, frame.fin) + else: + yield frame.data + + self.get_in_progress = False + + def put(self, frame: Frame) -> None: + """ + Add ``frame`` to the next message. + + Raises: + EOFError: If the stream of frames has ended. + + """ + if self.closed: + raise EOFError("stream of frames ended") + + self.send_frames.send_nowait(frame) + self.maybe_pause() + + def maybe_pause(self) -> None: + """Pause the writer if queue is above the high water mark.""" + # Skip if flow control is disabled + if self.high is None: + return + + # Bypass the statistics() method for performance reasons. + # Check for "> high" to support high = 0 + if len(self.send_frames._state.data) > self.high and not self.paused: + self.paused = True + self.pause() + + def maybe_resume(self) -> None: + """Resume the writer if queue is below the low water mark.""" + # Skip if flow control is disabled + if self.low is None: + return + + # Bypass the statistics() method for performance reasons. + # Check for "<= low" to support low = 0 + if len(self.send_frames._state.data) <= self.low and self.paused: + self.paused = False + self.resume() + + def close(self) -> None: + """ + End the stream of frames. + + Calling :meth:`close` concurrently with :meth:`get`, :meth:`get_iter`, + or :meth:`put` is safe. They will raise :exc:`trio.EndOfChannel`. + + """ + if self.closed: + return + + self.closed = True + + # Unblock get() or get_iter(). + self.send_frames.close() diff --git a/tests/asyncio/connection.py b/tests/asyncio/connection.py index ad1c121bf..854b9bb99 100644 --- a/tests/asyncio/connection.py +++ b/tests/asyncio/connection.py @@ -21,7 +21,7 @@ def delay_frames_sent(self, delay): """ Add a delay before sending frames. - This can result in out-of-order writes, which is unrealistic. + Misuse can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write is None @@ -36,7 +36,7 @@ def delay_eof_sent(self, delay): """ Add a delay before sending EOF. - This can result in out-of-order writes, which is unrealistic. + Misuse can result in out-of-order writes, which is unrealistic. """ assert self.transport.delay_write_eof is None @@ -83,9 +83,9 @@ class InterceptingTransport: This is coupled to the implementation, which relies on these two methods. - Since ``write()`` and ``write_eof()`` are not coroutines, this effect is - achieved by scheduling writes at a later time, after the methods return. - This can easily result in out-of-order writes, which is unrealistic. + Since ``write()`` and ``write_eof()`` are synchronous, we can only schedule + writes at a later time, after they return. This is unrealistic and can lead + to out-of-order writes if tests aren't written carefully. """ @@ -101,15 +101,13 @@ def __getattr__(self, name): return getattr(self.transport, name) def write(self, data): - if not self.drop_write: - if self.delay_write is not None: - self.loop.call_later(self.delay_write, self.transport.write, data) - else: - self.transport.write(data) + if self.delay_write is not None: + self.loop.call_later(self.delay_write, self.transport.write, data) + elif not self.drop_write: + self.transport.write(data) def write_eof(self): - if not self.drop_write_eof: - if self.delay_write_eof is not None: - self.loop.call_later(self.delay_write_eof, self.transport.write_eof) - else: - self.transport.write_eof() + if self.delay_write_eof is not None: + self.loop.call_later(self.delay_write_eof, self.transport.write_eof) + elif not self.drop_write_eof: + self.transport.write_eof() diff --git a/tests/asyncio/test_connection.py b/tests/asyncio/test_connection.py index 668f55cbd..29450a043 100644 --- a/tests/asyncio/test_connection.py +++ b/tests/asyncio/test_connection.py @@ -33,13 +33,13 @@ class ClientConnectionTests(AssertNoLogsMixin, unittest.IsolatedAsyncioTestCase) REMOTE = SERVER async def asyncSetUp(self): - loop = asyncio.get_running_loop() + self.loop = asyncio.get_running_loop() socket_, remote_socket = socket.socketpair() - self.transport, self.connection = await loop.create_connection( + self.transport, self.connection = await self.loop.create_connection( lambda: Connection(Protocol(self.LOCAL), close_timeout=2 * MS), sock=socket_, ) - self.remote_transport, self.remote_connection = await loop.create_connection( + _remote_transport, self.remote_connection = await self.loop.create_connection( lambda: InterceptingConnection(RecordingProtocol(self.REMOTE)), sock=remote_socket, ) @@ -125,41 +125,41 @@ async def test_exit_with_exception(self): async def test_aiter_text(self): """__aiter__ yields text messages.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") + self.assertEqual(await anext(iterator), "😀") await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") + self.assertEqual(await anext(iterator), "😀") async def test_aiter_binary(self): """__aiter__ yields binary messages.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") async def test_aiter_mixed(self): """__aiter__ yields a mix of text and binary messages.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.send("😀") - self.assertEqual(await anext(aiterator), "😀") + self.assertEqual(await anext(iterator), "😀") await self.remote_connection.send(b"\x01\x02\xfe\xff") - self.assertEqual(await anext(aiterator), b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") async def test_aiter_connection_closed_ok(self): """__aiter__ terminates after a normal closure.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.close() with self.assertRaises(StopAsyncIteration): - await anext(aiterator) + await anext(iterator) async def test_aiter_connection_closed_error(self): """__aiter__ raises ConnectionClosedError after an error.""" - aiterator = aiter(self.connection) + iterator = aiter(self.connection) await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) with self.assertRaises(ConnectionClosedError): - await anext(aiterator) + await anext(iterator) # Test recv. @@ -245,7 +245,7 @@ async def test_recv_during_recv_streaming(self): ) async def test_recv_cancellation_before_receiving(self): - """recv can be canceled before receiving a frame.""" + """recv can be canceled before receiving a message.""" recv_task = asyncio.create_task(self.connection.recv()) await asyncio.sleep(0) # let the event loop start recv_task @@ -257,11 +257,8 @@ async def test_recv_cancellation_before_receiving(self): self.assertEqual(await self.connection.recv(), "😀") async def test_recv_cancellation_while_receiving(self): - """recv cannot be canceled after receiving a frame.""" - recv_task = asyncio.create_task(self.connection.recv()) - await asyncio.sleep(0) # let the event loop start recv_task - - gate = asyncio.get_running_loop().create_future() + """recv can be canceled while receiving a fragmented message.""" + gate = self.loop.create_future() async def fragments(): yield "⏳" @@ -269,13 +266,16 @@ async def fragments(): yield "⌛️" asyncio.create_task(self.remote_connection.send(fragments())) - await asyncio.sleep(MS) + + recv_task = asyncio.create_task(self.connection.recv()) + await asyncio.sleep(0) # let the event loop start recv_task recv_task.cancel() await asyncio.sleep(0) # let the event loop cancel recv_task - # Running recv again receives the complete message. gate.set_result(None) + + # Running recv again receives the complete message. self.assertEqual(await self.connection.recv(), "⏳⌛️") # Test recv_streaming. @@ -360,8 +360,7 @@ async def test_recv_streaming_during_recv(self): self.addCleanup(recv_task.cancel) with self.assertRaises(ConcurrencyError) as raised: - async for _ in self.connection.recv_streaming(): - self.fail("did not raise") + await alist(self.connection.recv_streaming()) self.assertEqual( str(raised.exception), "cannot call recv_streaming while another coroutine " @@ -377,8 +376,7 @@ async def test_recv_streaming_during_recv_streaming(self): self.addCleanup(recv_streaming_task.cancel) with self.assertRaises(ConcurrencyError) as raised: - async for _ in self.connection.recv_streaming(): - self.fail("did not raise") + await alist(self.connection.recv_streaming()) self.assertEqual( str(raised.exception), r"cannot call recv_streaming while another coroutine " @@ -409,7 +407,7 @@ async def test_recv_streaming_cancellation_while_receiving(self): ) await asyncio.sleep(0) # let the event loop start recv_streaming_task - gate = asyncio.get_running_loop().create_future() + gate = self.loop.create_future() async def fragments(): yield "⏳" @@ -423,6 +421,7 @@ async def fragments(): await asyncio.sleep(0) # let the event loop cancel recv_streaming_task gate.set_result(None) + # Running recv_streaming again fails. with self.assertRaises(ConcurrencyError): await alist(self.connection.recv_streaming()) @@ -555,7 +554,7 @@ async def test_send_connection_closed_error(self): async def test_send_while_send_blocked(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an Iterable. self.connection.pause_writing() asyncio.create_task(self.connection.send(["⏳", "⌛️"])) @@ -580,7 +579,7 @@ async def test_send_while_send_blocked(self): async def test_send_while_send_async_blocked(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an AsyncIterable. self.connection.pause_writing() @@ -610,9 +609,9 @@ async def fragments(): async def test_send_during_send_async(self): """send waits for a previous call to send to complete.""" - # This test fails if the guard with fragmented_send_waiter is removed + # This test fails if the guard with send_in_progress is removed # from send() in the case when message is an AsyncIterable. - gate = asyncio.get_running_loop().create_future() + gate = self.loop.create_future() async def fragments(): yield "⏳" @@ -709,8 +708,14 @@ async def test_close_explicit_code_reason(self): async def test_close_waits_for_close_frame(self): """close waits for a close frame (then EOF) before returning.""" + t0 = self.loop.time() async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -724,8 +729,14 @@ async def test_close_waits_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = self.loop.time() async with self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -738,8 +749,14 @@ async def test_close_no_timeout_waits_for_close_frame(self): """close without timeout waits for a close frame (then EOF) before returning.""" self.connection.close_timeout = None + t0 = self.loop.time() async with self.delay_frames_rcvd(MS), self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -755,8 +772,14 @@ async def test_close_no_timeout_waits_for_connection_closed(self): self.connection.close_timeout = None + t0 = self.loop.time() async with self.delay_eof_rcvd(MS): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -767,8 +790,14 @@ async def test_close_no_timeout_waits_for_connection_closed(self): async def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" + t0 = self.loop.time() async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() @@ -782,8 +811,14 @@ async def test_close_timeout_waiting_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = self.loop.time() async with self.drop_eof_rcvd(): await self.connection.close() + t1 = self.loop.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedOK) as raised: await self.connection.recv() @@ -799,13 +834,9 @@ async def test_close_preserves_queued_messages(self): await self.connection.close() self.assertEqual(await self.connection.recv(), "😀") - with self.assertRaises(ConnectionClosedOK) as raised: + with self.assertRaises(ConnectionClosedOK): await self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") - self.assertIsNone(exc.__cause__) - async def test_close_idempotency(self): """close does nothing if the connection is already closed.""" await self.connection.close() @@ -816,11 +847,15 @@ async def test_close_idempotency(self): async def test_close_during_recv(self): """close aborts recv when called concurrently with recv.""" - recv_task = asyncio.create_task(self.connection.recv()) - await asyncio.sleep(MS) - await self.connection.close() + + async def closer(): + await asyncio.sleep(MS) + await self.connection.close() + + asyncio.create_task(closer()) + with self.assertRaises(ConnectionClosedOK) as raised: - await recv_task + await self.connection.recv() exc = raised.exception self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") @@ -828,23 +863,24 @@ async def test_close_during_recv(self): async def test_close_during_send(self): """close fails the connection when called concurrently with send.""" - gate = asyncio.get_running_loop().create_future() + close_gate = self.loop.create_future() + exit_gate = self.loop.create_future() + + async def closer(): + await close_gate + await self.connection.close() + exit_gate.set_result(None) async def fragments(): yield "⏳" - await gate + close_gate.set_result(None) + await exit_gate yield "⌛️" - send_task = asyncio.create_task(self.connection.send(fragments())) - await asyncio.sleep(MS) - - asyncio.create_task(self.connection.close()) - await asyncio.sleep(MS) - - gate.set_result(None) + asyncio.create_task(closer()) with self.assertRaises(ConnectionClosedError) as raised: - await send_task + await self.connection.send(fragments()) exc = raised.exception self.assertEqual( @@ -886,54 +922,54 @@ async def test_ping_explicit_binary(self): async def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.remote_connection.pong("this") async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_canceled_ping(self): """ping is acknowledged by a pong with the same payload after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") - pong_waiter.cancel() + pong_received = await self.connection.ping("this") + pong_received.cancel() await self.remote_connection.pong("this") with self.assertRaises(asyncio.CancelledError): - await pong_waiter + await pong_received async def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.remote_connection.pong("that") with self.assertRaises(TimeoutError): async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for a later ping.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") + pong_received = await self.connection.ping("this") await self.connection.ping("that") await self.remote_connection.pong("that") async with asyncio_timeout(MS): - await pong_waiter + await pong_received async def test_acknowledge_previous_canceled_ping(self): """ping is acknowledged by a pong for a later ping after being canceled.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("this") - pong_waiter_2 = await self.connection.ping("that") - pong_waiter.cancel() + pong_received = await self.connection.ping("this") + pong_received_2 = await self.connection.ping("that") + pong_received.cancel() await self.remote_connection.pong("that") async with asyncio_timeout(MS): - await pong_waiter_2 + await pong_received_2 with self.assertRaises(asyncio.CancelledError): - await pong_waiter + await pong_received async def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" async with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = await self.connection.ping("idem") + pong_received = await self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: await self.connection.ping("idem") @@ -944,7 +980,7 @@ async def test_ping_duplicate_payload(self): await self.remote_connection.pong("idem") async with asyncio_timeout(MS): - await pong_waiter + await pong_received await self.connection.ping("idem") # doesn't raise an exception @@ -1034,6 +1070,7 @@ async def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() + self.assertFalse(self.connection.keepalive_task.done()) await asyncio.sleep(MS) await self.connection.close() self.assertTrue(self.connection.keepalive_task.done()) @@ -1062,9 +1099,9 @@ async def test_keepalive_reports_errors(self): await asyncio.sleep(2 * MS) # Exiting the context manager sleeps for 1 ms. # 3 ms: inject a fault: raise an exception in the pending pong waiter. - pong_waiter = next(iter(self.connection.pong_waiters.values()))[0] + pong_received = next(iter(self.connection.pending_pings.values()))[0] with self.assertLogs("websockets", logging.ERROR) as logs: - pong_waiter.set_exception(Exception("BOOM")) + pong_received.set_exception(Exception("BOOM")) await asyncio.sleep(0) self.assertEqual( [record.getMessage() for record in logs.records], @@ -1079,20 +1116,28 @@ async def test_keepalive_reports_errors(self): async def test_close_timeout(self): """close_timeout parameter configures close timeout.""" - connection = Connection(Protocol(self.LOCAL), close_timeout=42 * MS) + connection = Connection( + Protocol(self.LOCAL), + close_timeout=42 * MS, + ) self.assertEqual(connection.close_timeout, 42 * MS) async def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=4) - transport = Mock() - connection.connection_made(transport) + connection = Connection( + Protocol(self.LOCAL), + max_queue=4, + ) + connection.connection_made(Mock(spec=asyncio.Transport)) self.assertEqual(connection.recv_messages.high, 4) async def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" - connection = Connection(Protocol(self.LOCAL), max_queue=None) - transport = Mock() + connection = Connection( + Protocol(self.LOCAL), + max_queue=None, + ) + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, None) self.assertEqual(connection.recv_messages.low, None) @@ -1103,7 +1148,7 @@ async def test_max_queue_tuple(self): Protocol(self.LOCAL), max_queue=(4, 2), ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) self.assertEqual(connection.recv_messages.high, 4) self.assertEqual(connection.recv_messages.low, 2) @@ -1114,7 +1159,7 @@ async def test_write_limit(self): Protocol(self.LOCAL), write_limit=4096, ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, None) @@ -1124,7 +1169,7 @@ async def test_write_limits(self): Protocol(self.LOCAL), write_limit=(4096, 2048), ) - transport = Mock() + transport = Mock(spec=asyncio.Transport) connection.connection_made(transport) transport.set_write_buffer_limits.assert_called_once_with(4096, 2048) @@ -1138,13 +1183,13 @@ async def test_logger(self): """Connection has a logger attribute.""" self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) - @patch("asyncio.BaseTransport.get_extra_info", return_value=("sock", 1234)) + @patch("asyncio.Transport.get_extra_info", return_value=("sock", 1234)) async def test_local_address(self, get_extra_info): """Connection provides a local_address attribute.""" self.assertEqual(self.connection.local_address, ("sock", 1234)) get_extra_info.assert_called_with("sockname") - @patch("asyncio.BaseTransport.get_extra_info", return_value=("peer", 1234)) + @patch("asyncio.Transport.get_extra_info", return_value=("peer", 1234)) async def test_remote_address(self, get_extra_info): """Connection provides a remote_address attribute.""" self.assertEqual(self.connection.remote_address, ("peer", 1234)) @@ -1181,27 +1226,27 @@ async def test_writing_in_data_received_fails(self): # Inject a fault by shutting down the transport for writing — but not by # closing it because that would terminate the connection. self.transport.write_eof() + # Receive a ping. Responding with a pong will fail. await self.remote_connection.ping() # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() - cause = raised.exception.__cause__ - self.assertEqual(str(cause), "Cannot call write() after write_eof()") - self.assertIsInstance(cause, RuntimeError) + + self.assertIsInstance(raised.exception.__cause__, RuntimeError) async def test_writing_in_send_context_fails(self): """Error when sending outgoing frame is correctly reported.""" # Inject a fault by shutting down the transport for writing — but not by # closing it because that would terminate the connection. self.transport.write_eof() + # Sending a pong will fail. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: await self.connection.pong() - cause = raised.exception.__cause__ - self.assertEqual(str(cause), "Cannot call write() after write_eof()") - self.assertIsInstance(cause, RuntimeError) + + self.assertIsInstance(raised.exception.__cause__, RuntimeError) # Test safety nets — catching all exceptions in case of bugs. @@ -1216,9 +1261,7 @@ async def test_unexpected_failure_in_data_received(self, events_received): with self.assertRaises(ConnectionClosedError) as raised: await self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) # Inject a fault in a random call in send_context(). # This test is tightly coupled to the implementation. @@ -1230,9 +1273,7 @@ async def test_unexpected_failure_in_send_context(self, send_text): with self.assertRaises(ConnectionClosedError) as raised: await self.connection.send("😀") - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) # Test broadcast. @@ -1303,7 +1344,7 @@ async def test_broadcast_skips_closing_connection(self): async def test_broadcast_skips_connection_with_send_blocked(self): """broadcast logs a warning when a connection is blocked in send.""" - gate = asyncio.get_running_loop().create_future() + gate = self.loop.create_future() async def fragments(): yield "⏳" @@ -1330,7 +1371,7 @@ async def fragments(): ) async def test_broadcast_reports_connection_with_send_blocked(self): """broadcast raises exceptions for connections blocked in send.""" - gate = asyncio.get_running_loop().create_future() + gate = self.loop.create_future() async def fragments(): yield "⏳" diff --git a/tests/asyncio/test_messages.py b/tests/asyncio/test_messages.py index a90788d02..340aa00a8 100644 --- a/tests/asyncio/test_messages.py +++ b/tests/asyncio/test_messages.py @@ -267,6 +267,7 @@ async def test_get_iter_fragmented_text_message_not_received_yet(self): self.assertEqual(await anext(iterator), "f") self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(await anext(iterator), "é") + await iterator.aclose() async def test_get_iter_fragmented_binary_message_not_received_yet(self): """get_iter yields a fragmented binary message when it is received.""" @@ -277,6 +278,7 @@ async def test_get_iter_fragmented_binary_message_not_received_yet(self): self.assertEqual(await anext(iterator), b"e") self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(await anext(iterator), b"a") + await iterator.aclose() async def test_get_iter_fragmented_text_message_being_received(self): """get_iter yields a fragmented text message that is partially received.""" @@ -287,6 +289,7 @@ async def test_get_iter_fragmented_text_message_being_received(self): self.assertEqual(await anext(iterator), "f") self.assembler.put(Frame(OP_CONT, b"\xa9")) self.assertEqual(await anext(iterator), "é") + await iterator.aclose() async def test_get_iter_fragmented_binary_message_being_received(self): """get_iter yields a fragmented binary message that is partially received.""" @@ -297,6 +300,7 @@ async def test_get_iter_fragmented_binary_message_being_received(self): self.assertEqual(await anext(iterator), b"e") self.assembler.put(Frame(OP_CONT, b"a")) self.assertEqual(await anext(iterator), b"a") + await iterator.aclose() async def test_get_iter_encoded_text_message(self): """get_iter yields a text message without UTF-8 decoding.""" @@ -334,6 +338,8 @@ async def test_get_iter_resumes_reading(self): await anext(iterator) self.resume.assert_called_once_with() + await iterator.aclose() + async def test_get_iter_does_not_resume_reading(self): """get_iter does not resume reading when the low-water mark is unset.""" self.assembler.low = None @@ -345,6 +351,7 @@ async def test_get_iter_does_not_resume_reading(self): await anext(iterator) await anext(iterator) await anext(iterator) + await iterator.aclose() self.resume.assert_not_called() @@ -467,7 +474,7 @@ async def test_get_iter_queued_fragmented_message_after_close(self): self.assertEqual(fragments, [b"t", b"e", b"a"]) async def test_get_partially_queued_fragmented_message_after_close(self): - """get raises EOF on a partial fragmented message after close is called.""" + """get raises EOFError on a partial fragmented message after close is called.""" self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) self.assembler.put(Frame(OP_CONT, b"e", fin=False)) self.assembler.close() diff --git a/tests/sync/test_connection.py b/tests/sync/test_connection.py index a5aee35bb..157aa2056 100644 --- a/tests/sync/test_connection.py +++ b/tests/sync/test_connection.py @@ -6,7 +6,7 @@ import time import unittest import uuid -from unittest.mock import patch +from unittest.mock import Mock, patch from websockets.exceptions import ( ConcurrencyError, @@ -489,8 +489,14 @@ def test_close_explicit_code_reason(self): def test_close_waits_for_close_frame(self): """close waits for a close frame (then EOF) before returning.""" + t0 = time.time() with self.delay_frames_rcvd(MS): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -504,8 +510,14 @@ def test_close_waits_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = time.time() with self.delay_eof_rcvd(MS): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -516,8 +528,14 @@ def test_close_waits_for_connection_closed(self): def test_close_timeout_waiting_for_close_frame(self): """close times out if no close frame is received.""" + t0 = time.time() with self.drop_frames_rcvd(), self.drop_eof_rcvd(): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() @@ -531,8 +549,14 @@ def test_close_timeout_waiting_for_connection_closed(self): if self.LOCAL is SERVER: self.skipTest("only relevant on the client-side") + t0 = time.time() with self.drop_eof_rcvd(): self.connection.close() + t1 = time.time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) with self.assertRaises(ConnectionClosedOK) as raised: self.connection.recv() @@ -548,13 +572,9 @@ def test_close_preserves_queued_messages(self): self.connection.close() self.assertEqual(self.connection.recv(), "😀") - with self.assertRaises(ConnectionClosedOK) as raised: + with self.assertRaises(ConnectionClosedOK): self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") - self.assertIsNone(exc.__cause__) - def test_close_idempotency(self): """close does nothing if the connection is already closed.""" self.connection.close() @@ -622,10 +642,10 @@ def closer(): exit_gate.set() def fragments(): - yield "😀" + yield "⏳" close_gate.set() exit_gate.wait() - yield "😀" + yield "⌛️" close_thread = threading.Thread(target=closer) close_thread.start() @@ -665,38 +685,38 @@ def test_ping_explicit_binary(self): def test_acknowledge_ping(self): """ping is acknowledged by a pong with the same payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.remote_connection.pong("this") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) def test_acknowledge_ping_non_matching_pong(self): """ping isn't acknowledged by a pong with a different payload.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.remote_connection.pong("that") - self.assertFalse(pong_waiter.wait(MS)) + self.assertFalse(pong_received.wait(MS)) def test_acknowledge_previous_ping(self): """ping is acknowledged by a pong for as a later ping.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("this") + pong_received = self.connection.ping("this") self.connection.ping("that") self.remote_connection.pong("that") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) def test_acknowledge_ping_on_close(self): """ping with ack_on_close is acknowledged when the connection is closed.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter_ack_on_close = self.connection.ping("this", ack_on_close=True) - pong_waiter = self.connection.ping("that") + pong_received_ack_on_close = self.connection.ping("this", ack_on_close=True) + pong_received = self.connection.ping("that") self.connection.close() - self.assertTrue(pong_waiter_ack_on_close.wait(MS)) - self.assertFalse(pong_waiter.wait(MS)) + self.assertTrue(pong_received_ack_on_close.wait(MS)) + self.assertFalse(pong_received.wait(MS)) def test_ping_duplicate_payload(self): """ping rejects the same payload until receiving the pong.""" with self.drop_frames_rcvd(): # drop automatic response to ping - pong_waiter = self.connection.ping("idem") + pong_received = self.connection.ping("idem") with self.assertRaises(ConcurrencyError) as raised: self.connection.ping("idem") @@ -706,7 +726,7 @@ def test_ping_duplicate_payload(self): ) self.remote_connection.pong("idem") - self.assertTrue(pong_waiter.wait(MS)) + self.assertTrue(pong_received.wait(MS)) self.connection.ping("idem") # doesn't raise an exception @@ -742,7 +762,7 @@ def test_pong_unsupported_type(self): @patch("random.getrandbits", return_value=1918987876) def test_keepalive(self, getrandbits): """keepalive sends pings at ping_interval and measures latency.""" - self.connection.ping_interval = 4 * MS + self.connection.ping_interval = 3 * MS self.connection.start_keepalive() self.assertIsNotNone(self.connection.keepalive_thread) self.assertEqual(self.connection.latency, 0) @@ -796,6 +816,7 @@ def test_keepalive_terminates_while_sleeping(self): """keepalive task terminates while waiting to send a ping.""" self.connection.ping_interval = 3 * MS self.connection.start_keepalive() + self.assertTrue(self.connection.keepalive_thread.is_alive()) time.sleep(MS) self.connection.close() self.connection.keepalive_thread.join(MS) @@ -803,8 +824,9 @@ def test_keepalive_terminates_while_sleeping(self): def test_keepalive_terminates_when_sending_ping_fails(self): """keepalive task terminates when sending a ping fails.""" - self.connection.ping_interval = 1 * MS + self.connection.ping_interval = MS self.connection.start_keepalive() + self.assertTrue(self.connection.keepalive_thread.is_alive()) with self.drop_eof_rcvd(), self.drop_frames_rcvd(): self.connection.close() self.assertFalse(self.connection.keepalive_thread.is_alive()) @@ -827,14 +849,13 @@ def test_keepalive_terminates_while_waiting_for_pong(self): def test_keepalive_reports_errors(self): """keepalive reports unexpected errors in logs.""" self.connection.ping_interval = 2 * MS - with self.drop_frames_rcvd(): - self.connection.start_keepalive() - # 2 ms: keepalive() sends a ping frame. - # 2.x ms: a pong frame is dropped. - with self.assertLogs("websockets", logging.ERROR) as logs: - with patch("threading.Event.wait", side_effect=Exception("BOOM")): - time.sleep(3 * MS) - # Exiting the context manager sleeps for 1 ms. + self.connection.start_keepalive() + # Inject a fault when waiting to receive a pong. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("threading.Event.wait", side_effect=Exception("BOOM")): + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + time.sleep(3 * MS) self.assertEqual( [record.getMessage() for record in logs.records], ["keepalive ping failed"], @@ -848,11 +869,8 @@ def test_keepalive_reports_errors(self): def test_close_timeout(self): """close_timeout parameter configures close timeout.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), close_timeout=42 * MS, ) @@ -860,11 +878,8 @@ def test_close_timeout(self): def test_max_queue(self): """max_queue configures high-water mark of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=4, ) @@ -872,11 +887,8 @@ def test_max_queue(self): def test_max_queue_none(self): """max_queue disables high-water mark of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=None, ) @@ -885,11 +897,8 @@ def test_max_queue_none(self): def test_max_queue_tuple(self): """max_queue configures high-water and low-water marks of frames buffer.""" - socket_, remote_socket = socket.socketpair() - self.addCleanup(socket_.close) - self.addCleanup(remote_socket.close) connection = Connection( - socket_, + Mock(spec=socket.socket), Protocol(self.LOCAL), max_queue=(4, 2), ) @@ -960,11 +969,13 @@ def test_writing_in_recv_events_fails(self): # Inject a fault by shutting down the socket for writing — but not by # closing it because that would terminate the connection. self.connection.socket.shutdown(socket.SHUT_WR) + # Receive a ping. Responding with a pong will fail. self.remote_connection.ping() # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() + self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) def test_writing_in_send_context_fails(self): @@ -972,10 +983,12 @@ def test_writing_in_send_context_fails(self): # Inject a fault by shutting down the socket for writing — but not by # closing it because that would terminate the connection. self.connection.socket.shutdown(socket.SHUT_WR) + # Sending a pong will fail. # The connection closed exception reports the injected fault. with self.assertRaises(ConnectionClosedError) as raised: self.connection.pong() + self.assertIsInstance(raised.exception.__cause__, BrokenPipeError) # Test safety nets — catching all exceptions in case of bugs. @@ -991,9 +1004,7 @@ def test_unexpected_failure_in_recv_events(self, events_received): with self.assertRaises(ConnectionClosedError) as raised: self.connection.recv() - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) # Inject a fault in a random call in send_context(). # This test is tightly coupled to the implementation. @@ -1005,9 +1016,7 @@ def test_unexpected_failure_in_send_context(self, send_text): with self.assertRaises(ConnectionClosedError) as raised: self.connection.send("😀") - exc = raised.exception - self.assertEqual(str(exc), "no close frame received or sent") - self.assertIsInstance(exc.__cause__, AssertionError) + self.assertIsInstance(raised.exception.__cause__, AssertionError) class ServerConnectionTests(ClientConnectionTests): diff --git a/tests/trio/__init__.py b/tests/trio/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/trio/connection.py b/tests/trio/connection.py new file mode 100644 index 000000000..2a7f2aa07 --- /dev/null +++ b/tests/trio/connection.py @@ -0,0 +1,116 @@ +import contextlib + +import trio + +from websockets.trio.connection import Connection + + +class InterceptingConnection(Connection): + """ + Connection subclass that can intercept outgoing packets. + + By interfacing with this connection, we simulate network conditions + affecting what the component being tested receives during a test. + + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stream = InterceptingStream(self.stream) + + @contextlib.contextmanager + def delay_frames_sent(self, delay): + """ + Add a delay before sending frames. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.stream.delay_send_all is None + self.stream.delay_send_all = delay + try: + yield + finally: + self.stream.delay_send_all = None + + @contextlib.contextmanager + def delay_eof_sent(self, delay): + """ + Add a delay before sending EOF. + + Delays cumulate: they're added before every frame or before EOF. + + """ + assert self.stream.delay_send_eof is None + self.stream.delay_send_eof = delay + try: + yield + finally: + self.stream.delay_send_eof = None + + @contextlib.contextmanager + def drop_frames_sent(self): + """ + Prevent frames from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.stream.drop_send_all + self.stream.drop_send_all = True + try: + yield + finally: + self.stream.drop_send_all = False + + @contextlib.contextmanager + def drop_eof_sent(self): + """ + Prevent EOF from being sent. + + Since TCP is reliable, sending frames or EOF afterwards is unrealistic. + + """ + assert not self.stream.drop_send_eof + self.stream.drop_send_eof = True + try: + yield + finally: + self.stream.drop_send_eof = False + + +class InterceptingStream: + """ + Stream wrapper that intercepts calls to ``send_all()`` and ``send_eof()``. + + This is coupled to the implementation, which relies on these two methods. + + """ + + # We cannot delay EOF with trio's virtual streams because close_hook is + # synchronous. Adopt the same approach as in the other implementations. + + def __init__(self, stream): + self.stream = stream + self.delay_send_all = None + self.delay_send_eof = None + self.drop_send_all = False + self.drop_send_eof = False + + def __getattr__(self, name): + return getattr(self.stream, name) + + async def send_all(self, data): + if self.delay_send_all is not None: + await trio.sleep(self.delay_send_all) + if not self.drop_send_all: + await self.stream.send_all(data) + + async def send_eof(self): + if self.delay_send_eof is not None: + await trio.sleep(self.delay_send_eof) + if not self.drop_send_eof: + await self.stream.send_eof() + + +trio.abc.HalfCloseableStream.register(InterceptingStream) diff --git a/tests/trio/test_connection.py b/tests/trio/test_connection.py new file mode 100644 index 000000000..8c613c700 --- /dev/null +++ b/tests/trio/test_connection.py @@ -0,0 +1,1253 @@ +import contextlib +import logging +import uuid +from unittest.mock import patch + +import trio.testing + +from websockets.asyncio.compatibility import TimeoutError, aiter, anext +from websockets.exceptions import ( + ConcurrencyError, + ConnectionClosedError, + ConnectionClosedOK, +) +from websockets.frames import CloseCode, Frame, Opcode +from websockets.protocol import CLIENT, SERVER, Protocol, State +from websockets.trio.connection import * + +from ..asyncio.utils import alist +from ..protocol import RecordingProtocol +from ..utils import MS, AssertNoLogsMixin +from .connection import InterceptingConnection +from .utils import IsolatedTrioTestCase + + +# Connection implements symmetrical behavior between clients and servers. +# All tests run on the client side and the server side to validate this. + + +class ClientConnectionTests(AssertNoLogsMixin, IsolatedTrioTestCase): + LOCAL = CLIENT + REMOTE = SERVER + + async def asyncSetUp(self): + stream, remote_stream = trio.testing.memory_stream_pair() + protocol = Protocol(self.LOCAL) + remote_protocol = RecordingProtocol(self.REMOTE) + self.connection = Connection( + self.nursery, stream, protocol, close_timeout=2 * MS + ) + self.remote_connection = InterceptingConnection( + self.nursery, remote_stream, remote_protocol + ) + + async def asyncTearDown(self): + await self.remote_connection.close() + await self.connection.close() + + # Test helpers built upon RecordingProtocol and InterceptingConnection. + + async def assertFrameSent(self, frame): + """Check that a single frame was sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), [frame]) + + async def assertFramesSent(self, frames): + """Check that several frames were sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), frames) + + async def assertNoFrameSent(self): + """Check that no frame was sent.""" + await trio.testing.wait_all_tasks_blocked() + self.assertEqual(self.remote_connection.protocol.get_frames_rcvd(), []) + + @contextlib.asynccontextmanager + async def delay_frames_rcvd(self, delay): + """Delay frames before they're received by the connection.""" + with self.remote_connection.delay_frames_sent(delay): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def delay_eof_rcvd(self, delay): + """Delay EOF before it's received by the connection.""" + with self.remote_connection.delay_eof_sent(delay): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def drop_frames_rcvd(self): + """Drop frames before they're received by the connection.""" + with self.remote_connection.drop_frames_sent(): + yield + await trio.testing.wait_all_tasks_blocked() + + @contextlib.asynccontextmanager + async def drop_eof_rcvd(self): + """Drop EOF before it's received by the connection.""" + with self.remote_connection.drop_eof_sent(): + yield + await trio.testing.wait_all_tasks_blocked() + + # Test __aenter__ and __aexit__. + + async def test_aenter(self): + """__aenter__ returns the connection itself.""" + async with self.connection as connection: + self.assertIs(connection, self.connection) + + async def test_aexit(self): + """__aexit__ closes the connection with code 1000.""" + async with self.connection: + await self.assertNoFrameSent() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_exit_with_exception(self): + """__exit__ with an exception closes the connection with code 1011.""" + with self.assertRaises(RuntimeError): + async with self.connection: + raise RuntimeError + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xf3")) + + # Test __aiter__. + + async def test_aiter_text(self): + """__aiter__ yields text messages.""" + iterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await iterator.aclose() + + async def test_aiter_binary(self): + """__aiter__ yields binary messages.""" + iterator = aiter(self.connection) + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await iterator.aclose() + + async def test_aiter_mixed(self): + """__aiter__ yields a mix of text and binary messages.""" + iterator = aiter(self.connection) + await self.remote_connection.send("😀") + self.assertEqual(await anext(iterator), "😀") + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await anext(iterator), b"\x01\x02\xfe\xff") + await iterator.aclose() + + async def test_aiter_connection_closed_ok(self): + """__aiter__ terminates after a normal closure.""" + iterator = aiter(self.connection) + await self.remote_connection.close() + with self.assertRaises(StopAsyncIteration): + await anext(iterator) + await iterator.aclose() + + async def test_aiter_connection_closed_error(self): + """__aiter__ raises ConnectionClosedError after an error.""" + iterator = aiter(self.connection) + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await anext(iterator) + await iterator.aclose() + + # Test recv. + + async def test_recv_text(self): + """recv receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_binary(self): + """recv receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_text_as_bytes(self): + """recv receives a text message as bytes.""" + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(decode=False), "😀".encode()) + + async def test_recv_binary_as_text(self): + """recv receives a binary message as a str.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual(await self.connection.recv(decode=True), "😀") + + async def test_recv_fragmented_text(self): + """recv receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual(await self.connection.recv(), "😀😀") + + async def test_recv_fragmented_binary(self): + """recv receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + self.assertEqual(await self.connection.recv(), b"\x01\x02\xfe\xff") + + async def test_recv_connection_closed_ok(self): + """recv raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_recv_connection_closed_error(self): + """recv raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + + async def test_recv_non_utf8_text(self): + """recv receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await self.connection.recv() + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + + async def test_recv_during_recv(self): + """recv raises ConcurrencyError when called concurrently.""" + self.nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + await self.remote_connection.send("") + + async def test_recv_during_recv_streaming(self): + """recv raises ConcurrencyError when called concurrently with recv_streaming.""" + self.nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.recv() + self.assertEqual( + str(raised.exception), + "cannot call recv while another coroutine " + "is already running recv or recv_streaming", + ) + + await self.remote_connection.send("") + + async def test_recv_cancellation_before_receiving(self): + """recv can be canceled before receiving a message.""" + with trio.move_on_after(MS): + await self.connection.recv() + + # Running recv again receives the next message. + await self.remote_connection.send("😀") + self.assertEqual(await self.connection.recv(), "😀") + + async def test_recv_cancellation_while_receiving(self): + """recv can be canceled while receiving a fragmented message.""" + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + async def send_fragments(): + await self.remote_connection.send(fragments()) + + self.nursery.start_soon(send_fragments) + await trio.testing.wait_all_tasks_blocked() + + with trio.move_on_after(MS): + await self.connection.recv() + + gate.set() + + # Running recv again receives the complete message. + self.assertEqual(await self.connection.recv(), "⏳⌛️") + + # Test recv_streaming. + + async def test_recv_streaming_text(self): + """recv_streaming receives a text message.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀"], + ) + + async def test_recv_streaming_binary(self): + """recv_streaming receives a binary message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff") + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02\xfe\xff"], + ) + + async def test_recv_streaming_text_as_bytes(self): + """recv_streaming receives a text message as bytes.""" + await self.remote_connection.send("😀") + self.assertEqual( + await alist(self.connection.recv_streaming(decode=False)), + ["😀".encode()], + ) + + async def test_recv_streaming_binary_as_str(self): + """recv_streaming receives a binary message as a str.""" + await self.remote_connection.send("😀".encode()) + self.assertEqual( + await alist(self.connection.recv_streaming(decode=True)), + ["😀"], + ) + + async def test_recv_streaming_fragmented_text(self): + """recv_streaming receives a fragmented text message.""" + await self.remote_connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_fragmented_binary(self): + """recv_streaming receives a fragmented binary message.""" + await self.remote_connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_recv_streaming_connection_closed_ok(self): + """recv_streaming raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_connection_closed_error(self): + """recv_streaming raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + async for _ in self.connection.recv_streaming(): + self.fail("did not raise") + + async def test_recv_streaming_non_utf8_text(self): + """recv_streaming receives a non-UTF-8 text message.""" + await self.remote_connection.send(b"\x01\x02\xfe\xff", text=True) + with self.assertRaises(ConnectionClosedError): + await alist(self.connection.recv_streaming()) + await self.assertFrameSent( + Frame(Opcode.CLOSE, b"\x03\xefinvalid start byte at position 2") + ) + + async def test_recv_streaming_during_recv(self): + """recv_streaming raises ConcurrencyError when called concurrently with recv.""" + self.nursery.start_soon(self.connection.recv) + await trio.testing.wait_all_tasks_blocked() + + with self.assertRaises(ConcurrencyError) as raised: + await alist(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), + "cannot call recv_streaming while another coroutine " + "is already running recv or recv_streaming", + ) + + await self.remote_connection.send("") + + async def test_recv_streaming_during_recv_streaming(self): + """recv_streaming raises ConcurrencyError when called concurrently.""" + self.nursery.start_soon(alist, self.connection.recv_streaming()) + await trio.testing.wait_all_tasks_blocked() + + with self.assertRaises(ConcurrencyError) as raised: + await alist(self.connection.recv_streaming()) + self.assertEqual( + str(raised.exception), + r"cannot call recv_streaming while another coroutine " + r"is already running recv or recv_streaming", + ) + + await self.remote_connection.send("") + + async def test_recv_streaming_cancellation_before_receiving(self): + """recv_streaming can be canceled before receiving a message.""" + with trio.move_on_after(MS): + await alist(self.connection.recv_streaming()) + + # Running recv_streaming again receives the next message. + await self.remote_connection.send(["😀", "😀"]) + self.assertEqual( + await alist(self.connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_recv_streaming_cancellation_while_receiving(self): + """recv_streaming cannot be canceled while receiving a fragmented message.""" + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + async def send_fragments(): + iterator = fragments() + with self.assertRaises(ConnectionClosedError): + await self.remote_connection.send(iterator) + await iterator.aclose() + + self.nursery.start_soon(send_fragments) + await trio.testing.wait_all_tasks_blocked() + + with trio.move_on_after(MS): + await alist(self.connection.recv_streaming()) + + gate.set() + + # Running recv_streaming again fails. + with self.assertRaises(ConcurrencyError): + await alist(self.connection.recv_streaming()) + + # Test send. + + async def test_send_text(self): + """send sends a text message.""" + await self.connection.send("😀") + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_binary(self): + """send sends a binary message.""" + await self.connection.send(b"\x01\x02\xfe\xff") + self.assertEqual(await self.remote_connection.recv(), b"\x01\x02\xfe\xff") + + async def test_send_binary_from_str(self): + """send sends a binary message from a str.""" + await self.connection.send("😀", text=False) + self.assertEqual(await self.remote_connection.recv(), "😀".encode()) + + async def test_send_text_from_bytes(self): + """send sends a text message from bytes.""" + await self.connection.send("😀".encode(), text=True) + self.assertEqual(await self.remote_connection.recv(), "😀") + + async def test_send_fragmented_text(self): + """send sends a fragmented text message.""" + await self.connection.send(["😀", "😀"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_fragmented_binary(self): + """send sends a fragmented binary message.""" + await self.connection.send([b"\x01\x02", b"\xfe\xff"]) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str.""" + await self.connection.send(["😀", "😀"], text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes.""" + await self.connection.send(["😀".encode(), "😀".encode()], text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_text(self): + """send sends a fragmented text message asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_async_fragmented_binary(self): + """send sends a fragmented binary message asynchronously.""" + + async def fragments(): + yield b"\x01\x02" + yield b"\xfe\xff" + + await self.connection.send(fragments()) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + [b"\x01\x02", b"\xfe\xff", b""], + ) + + async def test_send_async_fragmented_binary_from_str(self): + """send sends a fragmented binary message from a str asynchronously.""" + + async def fragments(): + yield "😀" + yield "😀" + + await self.connection.send(fragments(), text=False) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀".encode(), "😀".encode(), b""], + ) + + async def test_send_async_fragmented_text_from_bytes(self): + """send sends a fragmented text message from bytes asynchronously.""" + + async def fragments(): + yield "😀".encode() + yield "😀".encode() + + await self.connection.send(fragments(), text=True) + # websockets sends an trailing empty fragment. That's an implementation detail. + self.assertEqual( + await alist(self.remote_connection.recv_streaming()), + ["😀", "😀", ""], + ) + + async def test_send_connection_closed_ok(self): + """send raises ConnectionClosedOK after a normal closure.""" + await self.remote_connection.close() + with self.assertRaises(ConnectionClosedOK): + await self.connection.send("😀") + + async def test_send_connection_closed_error(self): + """send raises ConnectionClosedError after an error.""" + await self.remote_connection.close(code=CloseCode.INTERNAL_ERROR) + with self.assertRaises(ConnectionClosedError): + await self.connection.send("😀") + + async def test_send_during_send_async(self): + """send waits for a previous call to send to complete.""" + # This test fails if the guard with send_in_progress is removed + # from send() in the case when message is an AsyncIterable. + gate = trio.Event() + + async def fragments(): + yield "⏳" + await gate.wait() + yield "⌛️" + + async def send_fragments(): + await self.connection.send(fragments()) + + self.nursery.start_soon(send_fragments) + await trio.testing.wait_all_tasks_blocked() + await self.assertFrameSent( + Frame(Opcode.TEXT, "⏳".encode(), fin=False), + ) + + self.nursery.start_soon(self.connection.send, "✅") + await trio.testing.wait_all_tasks_blocked() + await self.assertNoFrameSent() + + gate.set() + await trio.testing.wait_all_tasks_blocked() + await self.assertFramesSent( + [ + Frame(Opcode.CONT, "⌛️".encode(), fin=False), + Frame(Opcode.CONT, b"", fin=True), + Frame(Opcode.TEXT, "✅".encode()), + ] + ) + + async def test_send_empty_iterable(self): + """send does nothing when called with an empty iterable.""" + await self.connection.send([]) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + with self.assertRaises(TypeError): + await self.connection.send(["😀", b"\xfe\xff"]) + + async def test_send_unsupported_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send([None]) + + async def test_send_empty_async_iterable(self): + """send does nothing when called with an empty async iterable.""" + + async def fragments(): + return + yield # pragma: no cover + + await self.connection.send(fragments()) + await self.connection.close() + self.assertEqual(await alist(self.remote_connection), []) + + async def test_send_mixed_async_iterable(self): + """send raises TypeError when called with an iterable of inconsistent types.""" + + async def fragments(): + yield "😀" + yield b"\xfe\xff" + + iterator = fragments() + with self.assertRaises(TypeError): + await self.connection.send(iterator) + await iterator.aclose() + + async def test_send_unsupported_async_iterable(self): + """send raises TypeError when called with an iterable of unsupported type.""" + + async def fragments(): + yield None + + iterator = fragments() + with self.assertRaises(TypeError): + await self.connection.send(iterator) + await iterator.aclose() + + async def test_send_dict(self): + """send raises TypeError when called with a dict.""" + with self.assertRaises(TypeError): + await self.connection.send({"type": "object"}) + + async def test_send_unsupported_type(self): + """send raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.send(None) + + # Test close. + + async def test_close(self): + """close sends a close frame.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + async def test_close_explicit_code_reason(self): + """close sends a close frame with a given code and reason.""" + await self.connection.close(CloseCode.GOING_AWAY, "bye!") + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe9bye!")) + + async def test_close_waits_for_close_frame(self): + """close waits for a close frame (then EOF) before returning.""" + t0 = trio.current_time() + async with self.delay_frames_rcvd(MS): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_waits_for_connection_closed(self): + """close waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + t0 = trio.current_time() + async with self.delay_eof_rcvd(MS): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_close_frame(self): + """close without timeout waits for a close frame (then EOF) before returning.""" + self.connection.close_timeout = None + + t0 = trio.current_time() + async with self.delay_frames_rcvd(MS): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_no_timeout_waits_for_connection_closed(self): + """close without timeout waits for EOF before returning.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + self.connection.close_timeout = None + + t0 = trio.current_time() + async with self.delay_eof_rcvd(MS): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_timeout_waiting_for_close_frame(self): + """close times out if no close frame is received.""" + t0 = trio.current_time() + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.ABNORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) + + exc = self.connection.protocol.close_exc + self.assertEqual(str(exc), "sent 1000 (OK); no close frame received") + # TODO + # self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_timeout_waiting_for_connection_closed(self): + """close times out if EOF isn't received.""" + if self.LOCAL is SERVER: + self.skipTest("only relevant on the client-side") + + t0 = trio.current_time() + async with self.drop_eof_rcvd(): + await self.connection.close() + t1 = trio.current_time() + + self.assertEqual(self.connection.state, State.CLOSED) + self.assertEqual(self.connection.close_code, CloseCode.NORMAL_CLOSURE) + self.assertGreater(t1 - t0, 2 * MS) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsInstance(exc.__cause__, TimeoutError) + + async def test_close_preserves_queued_messages(self): + """close preserves messages buffered in the assembler.""" + await self.remote_connection.send("😀") + await self.connection.close() + + self.assertEqual(await self.connection.recv(), "😀") + with self.assertRaises(ConnectionClosedOK): + await self.connection.recv() + + async def test_close_idempotency(self): + """close does nothing if the connection is already closed.""" + await self.connection.close() + await self.assertFrameSent(Frame(Opcode.CLOSE, b"\x03\xe8")) + + await self.connection.close() + await self.assertNoFrameSent() + + async def test_close_during_recv(self): + """close aborts recv when called concurrently with recv.""" + + async def closer(): + await trio.sleep(MS) + await self.connection.close() + + self.nursery.start_soon(closer) + + with self.assertRaises(ConnectionClosedOK) as raised: + await self.connection.recv() + + exc = raised.exception + self.assertEqual(str(exc), "sent 1000 (OK); then received 1000 (OK)") + self.assertIsNone(exc.__cause__) + + async def test_close_during_send(self): + """close fails the connection when called concurrently with send.""" + close_gate = trio.Event() + exit_gate = trio.Event() + + async def closer(): + await close_gate.wait() + await trio.testing.wait_all_tasks_blocked() + await self.connection.close() + exit_gate.set() + + async def fragments(): + yield "⏳" + close_gate.set() + await exit_gate.wait() + yield "⌛️" + + self.nursery.start_soon(closer) + + iterator = fragments() + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send(iterator) + await iterator.aclose() + + exc = raised.exception + self.assertEqual( + str(exc), + "sent 1011 (internal error) close during fragmented message; " + "no close frame received", + ) + self.assertIsNone(exc.__cause__) + + # Test wait_closed. + + async def test_wait_closed(self): + """wait_closed waits for the connection to close.""" + closed = trio.Event() + + async def closer(): + await self.connection.wait_closed() + closed.set() + + self.nursery.start_soon(closer) + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(closed.is_set()) + + await self.connection.close() + await trio.testing.wait_all_tasks_blocked() + self.assertTrue(closed.is_set()) + + # Test ping. + + @patch("random.getrandbits", return_value=1918987876) + async def test_ping(self, getrandbits): + """ping sends a ping frame with a random payload.""" + await self.connection.ping() + getrandbits.assert_called_once_with(32) + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + + async def test_ping_explicit_text(self): + """ping sends a ping frame with a payload provided as text.""" + await self.connection.ping("ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_ping_explicit_binary(self): + """ping sends a ping frame with a payload provided as binary.""" + await self.connection.ping(b"ping") + await self.assertFrameSent(Frame(Opcode.PING, b"ping")) + + async def test_acknowledge_ping(self): + """ping is acknowledged by a pong with the same payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.remote_connection.pong("this") + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_ping_non_matching_pong(self): + """ping isn't acknowledged by a pong with a different payload.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.remote_connection.pong("that") + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_previous_ping(self): + """ping is acknowledged by a pong for a later ping.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("this") + await self.connection.ping("that") + await self.remote_connection.pong("that") + with trio.fail_after(MS): + await pong_received.wait() + + async def test_acknowledge_ping_on_close(self): + """ping with ack_on_close is acknowledged when the connection is closed.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received_ack_on_close = await self.connection.ping( + "this", ack_on_close=True + ) + pong_received = await self.connection.ping("that") + await self.connection.close() + with trio.fail_after(MS): + await pong_received_ack_on_close.wait() + with self.assertRaises(trio.TooSlowError): + with trio.fail_after(MS): + await pong_received.wait() + + async def test_ping_duplicate_payload(self): + """ping rejects the same payload until receiving the pong.""" + async with self.drop_frames_rcvd(): # drop automatic response to ping + pong_received = await self.connection.ping("idem") + + with self.assertRaises(ConcurrencyError) as raised: + await self.connection.ping("idem") + self.assertEqual( + str(raised.exception), + "already waiting for a pong with the same data", + ) + + await self.remote_connection.pong("idem") + with trio.fail_after(MS): + await pong_received.wait() + + await self.connection.ping("idem") # doesn't raise an exception + + async def test_ping_unsupported_type(self): + """ping raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.ping([]) + + # Test pong. + + async def test_pong(self): + """pong sends a pong frame.""" + await self.connection.pong() + await self.assertFrameSent(Frame(Opcode.PONG, b"")) + + async def test_pong_explicit_text(self): + """pong sends a pong frame with a payload provided as text.""" + await self.connection.pong("pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_explicit_binary(self): + """pong sends a pong frame with a payload provided as binary.""" + await self.connection.pong(b"pong") + await self.assertFrameSent(Frame(Opcode.PONG, b"pong")) + + async def test_pong_unsupported_type(self): + """pong raises TypeError when called with an unsupported type.""" + with self.assertRaises(TypeError): + await self.connection.pong([]) + + # Test keepalive. + + def keepalive_task_is_running(self): + return any( + task.name == "websockets.trio.connection.Connection.keepalive" + for task in self.nursery.child_tasks + ) + + @patch("random.getrandbits", return_value=1918987876) + async def test_keepalive(self, getrandbits): + """keepalive sends pings at ping_interval and measures latency.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + self.assertEqual(self.connection.latency, 0) + # 3 ms: keepalive() sends a ping frame. + # 3.x ms: a pong frame is received. + await trio.sleep(4 * MS) + # 4 ms: check that the ping frame was sent. + await self.assertFrameSent(Frame(Opcode.PING, b"rand")) + self.assertGreater(self.connection.latency, 0) + self.assertLess(self.connection.latency, MS) + + async def test_disable_keepalive(self): + """keepalive is disabled when ping_interval is None.""" + self.connection.ping_interval = None + self.connection.start_keepalive() + self.assertFalse(self.keepalive_task_is_running()) + + @patch("random.getrandbits", return_value=1918987876) + async def test_keepalive_times_out(self, getrandbits): + """keepalive closes the connection if ping_timeout elapses.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = 2 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. + await trio.sleep(5 * MS) + # 6 ms: no pong frame is received; the connection is closed. + await trio.sleep(2 * MS) + # 7 ms: check that the connection is closed. + self.assertEqual(self.connection.state, State.CLOSED) + + @patch("random.getrandbits", return_value=1918987876) + async def test_keepalive_ignores_timeout(self, getrandbits): + """keepalive ignores timeouts if ping_timeout isn't set.""" + self.connection.ping_interval = 4 * MS + self.connection.ping_timeout = None + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 4 ms: keepalive() sends a ping frame. + # 4.x ms: a pong frame is dropped. + await trio.sleep(5 * MS) + # 6 ms: no pong frame is received; the connection remains open. + await trio.sleep(2 * MS) + # 7 ms: check that the connection is still open. + self.assertEqual(self.connection.state, State.OPEN) + + async def test_keepalive_terminates_while_sleeping(self): + """keepalive task terminates while waiting to send a ping.""" + self.connection.ping_interval = 3 * MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + await trio.testing.wait_all_tasks_blocked() + await self.connection.close() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_terminates_when_sending_ping_fails(self): + """keepalive task terminates when sending a ping fails.""" + self.connection.ping_interval = MS + self.connection.start_keepalive() + self.assertTrue(self.keepalive_task_is_running()) + async with self.drop_eof_rcvd(), self.drop_frames_rcvd(): + await self.connection.close() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_terminates_while_waiting_for_pong(self): + """keepalive task terminates while waiting to receive a pong.""" + self.connection.ping_interval = MS + self.connection.ping_timeout = 3 * MS + async with self.drop_frames_rcvd(): + self.connection.start_keepalive() + # 1 ms: keepalive() sends a ping frame. + # 1.x ms: a pong frame is dropped. + await trio.sleep(2 * MS) + # 2 ms: close the connection before ping_timeout elapses. + await self.connection.close() + await trio.testing.wait_all_tasks_blocked() + self.assertFalse(self.keepalive_task_is_running()) + + async def test_keepalive_reports_errors(self): + """keepalive reports unexpected errors in logs.""" + self.connection.ping_interval = 2 * MS + self.connection.start_keepalive() + # Inject a fault when waiting to receive a pong. + with self.assertLogs("websockets", logging.ERROR) as logs: + with patch("trio.Event.wait", side_effect=Exception("BOOM")): + # 2 ms: keepalive() sends a ping frame. + # 2.x ms: a pong frame is dropped. + await trio.sleep(3 * MS) + self.assertEqual( + [record.getMessage() for record in logs.records], + ["keepalive ping failed"], + ) + self.assertEqual( + [str(record.exc_info[1]) for record in logs.records], + ["BOOM"], + ) + + # Test parameters. + + async def test_close_timeout(self): + """close_timeout parameter configures close timeout.""" + stream, remote_stream = trio.testing.memory_stream_pair() + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + close_timeout=42, + ) + self.assertEqual(connection.close_timeout, 42) + await remote_stream.aclose() + + async def test_max_queue(self): + """max_queue configures high-water mark of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=4, + ) + self.assertEqual(connection.recv_messages.high, 4) + await remote_stream.aclose() + + async def test_max_queue_none(self): + """max_queue disables high-water mark of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=None, + ) + self.assertEqual(connection.recv_messages.high, None) + self.assertEqual(connection.recv_messages.low, None) + await remote_stream.aclose() + + async def test_max_queue_tuple(self): + """max_queue configures high-water and low-water marks of frames buffer.""" + stream, remote_stream = trio.testing.memory_stream_pair() + connection = Connection( + self.nursery, + stream, + Protocol(self.LOCAL), + max_queue=(4, 2), + ) + self.assertEqual(connection.recv_messages.high, 4) + self.assertEqual(connection.recv_messages.low, 2) + await remote_stream.aclose() + + # Test attributes. + + async def test_id(self): + """Connection has an id attribute.""" + self.assertIsInstance(self.connection.id, uuid.UUID) + + async def test_logger(self): + """Connection has a logger attribute.""" + self.assertIsInstance(self.connection.logger, logging.LoggerAdapter) + + @contextlib.asynccontextmanager + async def get_server_and_client_streams(self): + listeners = await trio.open_tcp_listeners(0, host="127.0.0.1") + assert len(listeners) == 1 + listener = listeners[0] + client_stream = await trio.testing.open_stream_to_socket_listener(listener) + client_port = client_stream.socket.getsockname()[1] + server_stream = await listener.accept() + server_port = listener.socket.getsockname()[1] + try: + yield client_stream, server_stream, client_port, server_port + finally: + await server_stream.aclose() + await client_stream.aclose() + await listener.aclose() + + async def test_local_address(self): + """Connection provides a local_address attribute.""" + async with self.get_server_and_client_streams() as ( + client_stream, + server_stream, + client_port, + server_port, + ): + stream = {CLIENT: client_stream, SERVER: server_stream}[self.LOCAL] + port = {CLIENT: client_port, SERVER: server_port}[self.LOCAL] + connection = Connection(self.nursery, stream, Protocol(self.LOCAL)) + self.assertEqual(connection.local_address, ("127.0.0.1", port)) + + async def test_remote_address(self): + """Connection provides a remote_address attribute.""" + async with self.get_server_and_client_streams() as ( + client_stream, + server_stream, + client_port, + server_port, + ): + stream = {CLIENT: client_stream, SERVER: server_stream}[self.LOCAL] + remote_port = {CLIENT: server_port, SERVER: client_port}[self.LOCAL] + connection = Connection(self.nursery, stream, Protocol(self.LOCAL)) + self.assertEqual(connection.remote_address, ("127.0.0.1", remote_port)) + + async def test_state(self): + """Connection has a state attribute.""" + self.assertIs(self.connection.state, State.OPEN) + + async def test_request(self): + """Connection has a request attribute.""" + self.assertIsNone(self.connection.request) + + async def test_response(self): + """Connection has a response attribute.""" + self.assertIsNone(self.connection.response) + + async def test_subprotocol(self): + """Connection has a subprotocol attribute.""" + self.assertIsNone(self.connection.subprotocol) + + async def test_close_code(self): + """Connection has a close_code attribute.""" + self.assertIsNone(self.connection.close_code) + + async def test_close_reason(self): + """Connection has a close_reason attribute.""" + self.assertIsNone(self.connection.close_reason) + + # Test reporting of network errors. + + async def test_writing_in_recv_events_fails(self): + """Error when responding to incoming frames is correctly reported.""" + # Inject a fault by shutting down the stream for writing — but not the + # stream for reading because that would terminate the connection. + self.connection.stream.send_stream.close() + + # Receive a ping. Responding with a pong will fail. + await self.remote_connection.ping() + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + self.assertIsInstance(raised.exception.__cause__, trio.ClosedResourceError) + + async def test_writing_in_send_context_fails(self): + """Error when sending outgoing frame is correctly reported.""" + # Inject a fault by shutting down the stream for writing — but not the + # stream for reading because that would terminate the connection. + self.connection.stream.send_stream.close() + + # Sending a pong will fail. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.pong() + + self.assertIsInstance(raised.exception.__cause__, trio.ClosedResourceError) + + # Test safety nets — catching all exceptions in case of bugs. + + # Inject a fault in a random call in recv_events(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.events_received", side_effect=AssertionError) + async def test_unexpected_failure_in_recv_events(self, events_received): + """Unexpected internal error in recv_events() is correctly reported.""" + # Receive a message to trigger the fault. + await self.remote_connection.send("😀") + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.recv() + + self.assertIsInstance(raised.exception.__cause__, AssertionError) + + # Inject a fault in a random call in send_context(). + # This test is tightly coupled to the implementation. + @patch("websockets.protocol.Protocol.send_text", side_effect=AssertionError) + async def test_unexpected_failure_in_send_context(self, send_text): + """Unexpected internal error in send_context() is correctly reported.""" + # Send a message to trigger the fault. + # The connection closed exception reports the injected fault. + with self.assertRaises(ConnectionClosedError) as raised: + await self.connection.send("😀") + + self.assertIsInstance(raised.exception.__cause__, AssertionError) + + +class ServerConnectionTests(ClientConnectionTests): + LOCAL = SERVER + REMOTE = CLIENT diff --git a/tests/trio/test_messages.py b/tests/trio/test_messages.py new file mode 100644 index 000000000..838b52bdb --- /dev/null +++ b/tests/trio/test_messages.py @@ -0,0 +1,633 @@ +import unittest +import unittest.mock + +import trio.testing + +from websockets.asyncio.compatibility import aiter, anext +from websockets.exceptions import ConcurrencyError +from websockets.frames import OP_BINARY, OP_CONT, OP_TEXT, Frame +from websockets.trio.messages import * + +from ..asyncio.utils import alist +from ..utils import MS +from .utils import IsolatedTrioTestCase + + +class AssemblerTests(IsolatedTrioTestCase): + def setUp(self): + self.pause = unittest.mock.Mock() + self.resume = unittest.mock.Mock() + self.assembler = Assembler(high=2, low=1, pause=self.pause, resume=self.resume) + + # Test get + + async def test_get_text_message_already_received(self): + """get returns a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_binary_message_already_received(self): + """get returns a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_text_message_not_received_yet(self): + """get returns a text message when it is received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + self.assertEqual(message, "café") + + async def test_get_binary_message_not_received_yet(self): + """get returns a binary message when it is received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"tea")) + + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_already_received(self): + """get reassembles a fragmented a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_already_received(self): + """get reassembles a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_not_received_yet(self): + """get reassembles a fragmented text message when it is received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_not_received_yet(self): + """get reassembles a fragmented binary message when it is received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(message, b"tea") + + async def test_get_fragmented_text_message_being_received(self): + """get reassembles a fragmented text message that is partially received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + self.assertEqual(message, "café") + + async def test_get_fragmented_binary_message_being_received(self): + """get reassembles a fragmented binary message that is partially received.""" + message = None + + async def get_task(): + nonlocal message + message = await self.assembler.get() + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.assertEqual(message, b"tea") + + async def test_get_encoded_text_message(self): + """get returns a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + message = await self.assembler.get(decode=False) + self.assertEqual(message, b"caf\xc3\xa9") + + async def test_get_decoded_binary_message(self): + """get returns a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + message = await self.assembler.get(decode=True) + self.assertEqual(message, "tea") + + async def test_get_resumes_reading(self): + """get resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + + # queue is above the low-water mark + await self.assembler.get() + self.resume.assert_not_called() + + # queue is at the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await self.assembler.get() + self.resume.assert_called_once_with() + + async def test_get_does_not_resume_reading(self): + """get does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"more caf\xc3\xa9")) + self.assembler.put(Frame(OP_TEXT, b"water")) + await self.assembler.get() + await self.assembler.get() + await self.assembler.get() + + self.resume.assert_not_called() + + async def test_cancel_get_before_first_frame(self): + """get can be canceled safely before reading the first frame.""" + + async def get_task(): + await self.assembler.get() + + with trio.move_on_after(MS) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + self.assertTrue(cancel_scope.cancelled_caught) + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_cancel_get_after_first_frame(self): + """get can be canceled safely after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + async def get_task(): + await self.assembler.get() + + with trio.move_on_after(MS) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + self.assertTrue(cancel_scope.cancelled_caught) + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + message = await self.assembler.get() + self.assertEqual(message, "café") + + # Test get_iter + + async def test_get_iter_text_message_already_received(self): + """get_iter yields a text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_already_received(self): + """get_iter yields a binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"tea")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_text_message_not_received_yet(self): + """get_iter yields a text message when it is received.""" + fragments = None + + async def get_iter_task(): + nonlocal fragments + fragments = await alist(self.assembler.get_iter()) + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + self.assertEqual(fragments, ["café"]) + + async def test_get_iter_binary_message_not_received_yet(self): + """get_iter yields a binary message when it is received.""" + fragments = None + + async def get_iter_task(): + nonlocal fragments + fragments = await alist(self.assembler.get_iter()) + + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + await trio.testing.wait_all_tasks_blocked() + self.assembler.put(Frame(OP_BINARY, b"tea")) + + self.assertEqual(fragments, [b"tea"]) + + async def test_get_iter_fragmented_text_message_already_received(self): + """get_iter yields a fragmented text message that is already received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["ca", "f", "é"]) + + async def test_get_iter_fragmented_binary_message_already_received(self): + """get_iter yields a fragmented binary message that is already received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_iter_fragmented_text_message_not_received_yet(self): + """get_iter yields a fragmented text message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + await iterator.aclose() + + async def test_get_iter_fragmented_binary_message_not_received_yet(self): + """get_iter yields a fragmented binary message when it is received.""" + iterator = aiter(self.assembler.get_iter()) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + await iterator.aclose() + + async def test_get_iter_fragmented_text_message_being_received(self): + """get_iter yields a fragmented text message that is partially received.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), "ca") + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assertEqual(await anext(iterator), "f") + self.assembler.put(Frame(OP_CONT, b"\xa9")) + self.assertEqual(await anext(iterator), "é") + await iterator.aclose() + + async def test_get_iter_fragmented_binary_message_being_received(self): + """get_iter yields a fragmented binary message that is partially received.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + iterator = aiter(self.assembler.get_iter()) + self.assertEqual(await anext(iterator), b"t") + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assertEqual(await anext(iterator), b"e") + self.assembler.put(Frame(OP_CONT, b"a")) + self.assertEqual(await anext(iterator), b"a") + await iterator.aclose() + + async def test_get_iter_encoded_text_message(self): + """get_iter yields a text message without UTF-8 decoding.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + fragments = await alist(self.assembler.get_iter(decode=False)) + self.assertEqual(fragments, [b"ca", b"f\xc3", b"\xa9"]) + + async def test_get_iter_decoded_binary_message(self): + """get_iter yields a binary message with UTF-8 decoding.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + fragments = await alist(self.assembler.get_iter(decode=True)) + self.assertEqual(fragments, ["t", "e", "a"]) + + async def test_get_iter_resumes_reading(self): + """get_iter resumes reading when queue goes below the low-water mark.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + iterator = aiter(self.assembler.get_iter()) + + # queue is above the low-water mark + await anext(iterator) + self.resume.assert_not_called() + + # queue is at the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + # queue is below the low-water mark + await anext(iterator) + self.resume.assert_called_once_with() + + await iterator.aclose() + + async def test_get_iter_does_not_resume_reading(self): + """get_iter does not resume reading when the low-water mark is unset.""" + self.assembler.low = None + + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + iterator = aiter(self.assembler.get_iter()) + await anext(iterator) + await anext(iterator) + await anext(iterator) + await iterator.aclose() + + self.resume.assert_not_called() + + async def test_cancel_get_iter_before_first_frame(self): + """get_iter can be canceled safely before reading the first frame.""" + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.move_on_after(MS) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + self.assertTrue(cancel_scope.cancelled_caught) + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_cancel_get_iter_after_first_frame(self): + """get_iter cannot be canceled after reading the first frame.""" + self.assembler.put(Frame(OP_TEXT, b"ca", fin=False)) + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.move_on_after(MS) as cancel_scope: + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + self.assertTrue(cancel_scope.cancelled_caught) + + self.assembler.put(Frame(OP_CONT, b"f\xc3", fin=False)) + self.assembler.put(Frame(OP_CONT, b"\xa9")) + + with self.assertRaises(ConcurrencyError): + await alist(self.assembler.get_iter()) + + # Test put + + async def test_put_pauses_reading(self): + """put pauses reading when queue goes above the high-water mark.""" + # queue is below the high-water mark + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.pause.assert_not_called() + + # queue is at the high-water mark + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.pause.assert_called_once_with() + + # queue is above the high-water mark + self.assembler.put(Frame(OP_CONT, b"a")) + self.pause.assert_called_once_with() + + async def test_put_does_not_pause_reading(self): + """put does not pause reading when the high-water mark is unset.""" + self.assembler.high = None + + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + + self.pause.assert_not_called() + + # Test termination + + async def test_get_fails_when_interrupted_by_close(self): + """get raises EOFError when close is called.""" + + async def closer(): + self.assembler.close() + + async with trio.open_nursery() as nursery: + nursery.start_soon(closer) + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_when_interrupted_by_close(self): + """get_iter raises EOFError when close is called.""" + + async def closer(): + self.assembler.close() + + async with trio.open_nursery() as nursery: + nursery.start_soon(closer) + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_fails_after_close(self): + """get raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_fails_after_close(self): + """get_iter raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + async for _ in self.assembler.get_iter(): + self.fail("no fragment expected") + + async def test_get_queued_message_after_close(self): + """get returns a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, "café") + + async def test_get_iter_queued_message_after_close(self): + """get_iter yields a message after close is called.""" + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, ["café"]) + + async def test_get_queued_fragmented_message_after_close(self): + """get reassembles a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + self.assembler.close() + message = await self.assembler.get() + self.assertEqual(message, b"tea") + + async def test_get_iter_queued_fragmented_message_after_close(self): + """get_iter yields a fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.put(Frame(OP_CONT, b"a")) + self.assembler.close() + fragments = await alist(self.assembler.get_iter()) + self.assertEqual(fragments, [b"t", b"e", b"a"]) + + async def test_get_partially_queued_fragmented_message_after_close(self): + """get raises EOFError on a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + with self.assertRaises(EOFError): + await self.assembler.get() + + async def test_get_iter_partially_queued_fragmented_message_after_close(self): + """get_iter yields a partial fragmented message after close is called.""" + self.assembler.put(Frame(OP_BINARY, b"t", fin=False)) + self.assembler.put(Frame(OP_CONT, b"e", fin=False)) + self.assembler.close() + fragments = [] + with self.assertRaises(EOFError): + async for fragment in self.assembler.get_iter(): + fragments.append(fragment) + self.assertEqual(fragments, [b"t", b"e"]) + + async def test_put_fails_after_close(self): + """put raises EOFError after close is called.""" + self.assembler.close() + with self.assertRaises(EOFError): + self.assembler.put(Frame(OP_TEXT, b"caf\xc3\xa9")) + + async def test_close_is_idempotent(self): + """close can be called multiple times safely.""" + self.assembler.close() + self.assembler.close() + + # Test (non-)concurrency + + async def test_get_fails_when_get_is_running(self): + """get cannot be called concurrently.""" + + async def get_task(): + await self.assembler.get() + + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + nursery.start_soon(get_task) + + async def test_get_fails_when_get_iter_is_running(self): + """get cannot be called concurrently with get_iter.""" + + async def get_task(): + await alist(self.assembler.get_iter()) + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + nursery.start_soon(get_task) + + async def test_get_iter_fails_when_get_is_running(self): + """get_iter cannot be called concurrently with get.""" + + async def get_task(): + await alist(self.assembler.get_iter()) + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(get_task) + nursery.start_soon(get_iter_task) + + async def test_get_iter_fails_when_get_iter_is_running(self): + """get_iter cannot be called concurrently.""" + + async def get_iter_task(): + await alist(self.assembler.get_iter()) + + with trio.testing.RaisesGroup(ConcurrencyError): + async with trio.open_nursery() as nursery: + nursery.start_soon(get_iter_task) + nursery.start_soon(get_iter_task) + + # Test setting limits + + async def test_set_high_water_mark(self): + """high sets the high-water and low-water marks.""" + assembler = Assembler(high=10) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 2) + + async def test_set_low_water_mark(self): + """low sets the low-water and high-water marks.""" + assembler = Assembler(low=5) + self.assertEqual(assembler.low, 5) + self.assertEqual(assembler.high, 20) + + async def test_set_high_and_low_water_marks(self): + """high and low set the high-water and low-water marks.""" + assembler = Assembler(high=10, low=5) + self.assertEqual(assembler.high, 10) + self.assertEqual(assembler.low, 5) + + async def test_unset_high_and_low_water_marks(self): + """High-water and low-water marks are unset.""" + assembler = Assembler() + self.assertEqual(assembler.high, None) + self.assertEqual(assembler.low, None) + + async def test_set_invalid_high_water_mark(self): + """high must be a non-negative integer.""" + with self.assertRaises(ValueError): + Assembler(high=-1) + + async def test_set_invalid_low_water_mark(self): + """low must be higher than high.""" + with self.assertRaises(ValueError): + Assembler(low=10, high=5) diff --git a/tests/trio/utils.py b/tests/trio/utils.py new file mode 100644 index 000000000..4686a74e6 --- /dev/null +++ b/tests/trio/utils.py @@ -0,0 +1,58 @@ +import asyncio +import functools +import sys +import unittest + +import trio.testing + + +if sys.version_info[:2] < (3, 11): # pragma: no cover + from exceptiongroup import ExceptionGroup + + +class IsolatedTrioTestCase(unittest.TestCase): + """ + Wrap test coroutines with :func:`trio.testing.trio_test` automatically. + + Also initializes a nursery for each test and adds :meth:`asyncSetUp` and + :meth:`asyncTearDown`, similar to :class:`unittest.IsolatedAsyncioTestCase`. + + """ + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + for name in unittest.defaultTestLoader.getTestCaseNames(cls): + test = getattr(cls, name) + if getattr(test, "converted_to_trio", False): + return + assert asyncio.iscoroutinefunction(test) + setattr(cls, name, cls.convert_to_trio(test)) + + @staticmethod + def convert_to_trio(test): + @trio.testing.trio_test + @functools.wraps(test) + async def new_test(self, *args, **kwargs): + try: + # Provide a nursery so it's easy to start tasks. + async with trio.open_nursery() as self.nursery: + await self.asyncSetUp() + try: + return await test(self, *args, **kwargs) + finally: + await self.asyncTearDown() + except ExceptionGroup as exc_group: + # Unwrap exceptions like unittest.SkipTest. + if len(exc_group.exceptions) == 1: + raise exc_group.exceptions[0] + else: # pragma: no cover + raise + + new_test.converted_to_trio = True + return new_test + + async def asyncSetUp(self): + pass + + async def asyncTearDown(self): + pass diff --git a/tox.ini b/tox.ini index 9450e9714..7f7d4e101 100644 --- a/tox.ini +++ b/tox.ini @@ -17,6 +17,7 @@ pass_env = deps = py311,py312,py313,coverage,maxi_cov: mitmproxy py311,py312,py313,coverage,maxi_cov: python-socks[asyncio] + trio werkzeug [testenv:coverage] @@ -48,4 +49,5 @@ commands = deps = mypy python-socks + trio werkzeug pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy