diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c4b8987f..bb976ac9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,10 +19,10 @@ concurrency: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', ' 3.13', 'pypy3.10'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.10'] steps: - uses: actions/checkout@v4 @@ -33,7 +33,7 @@ jobs: cache: 'pip' cache-dependency-path: | pyproject.toml - - uses: hoverkraft-tech/compose-action@v2.0.0 + - uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: "./docker-compose.yml" down-flags: "--remove-orphans" diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eae15d12..74e4cdae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: forbid-crlf - id: remove-crlf - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -15,7 +15,7 @@ repos: exclude: helm/ args: [ --unsafe ] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.4.4" + rev: "v0.7.4" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/Makefile b/Makefile index 62c52dc7..07d5d459 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,7 @@ #!/usr/bin/make -f +VENV_PATH = .venv/bin + install-dev-requirements: curl -LsSf https://astral.sh/uv/install.sh | sh uv venv && uv pip install hatch @@ -25,22 +27,22 @@ develop: install-dev-requirements install-test-requirements types: @echo "Type checking Python files" - .venv/bin/mypy --pretty + $(VENV_PATH)/mypy --pretty @echo "" test: types @echo "Running Python tests" uv pip uninstall pook || true - export VIRTUAL_ENV=.venv; .venv/bin/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- .venv/bin/pytest - uv pip install pook && .venv/bin/pytest tests/test_pook.py && uv pip uninstall pook + $(VENV_PATH)/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- $(VENV_PATH)/pytest + uv pip install pook && $(VENV_PATH)/pytest tests/test_pook.py && uv pip uninstall pook @echo "" safetest: - export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; make test + export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; $(VENV_PATH)/pytest publish: clean install-test-requirements - uv run python3 -m build --sdist . - uv run twine upload --repository mocket dist/*.tar.gz + uv build --package mocket --sdist --wheel + uv publish clean: rm -rf .coverage *.egg-info dist/ requirements.txt uv.lock || true diff --git a/README.rst b/README.rst index 17a7801c..82bf6c37 100644 --- a/README.rst +++ b/README.rst @@ -24,9 +24,8 @@ A socket mock framework Outside GitHub ============== -Mocket packages are available for `Arch Linux`_, `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, and of course you can **pip install** it from `PyPI`_. +Mocket packages are available for `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, and of course from `PyPI`_. -.. _`Arch Linux`: https://archlinux.org/packages/extra/any/python-mocket/ .. _`openSUSE`: https://software.opensuse.org/search?baseproject=ALL&q=mocket .. _`NixOS`: https://search.nixos.org/packages?query=mocket .. _`ALT Linux`: https://packages.altlinux.org/en/sisyphus/srpms/python3-module-mocket/ @@ -284,52 +283,44 @@ Example: .. code-block:: python - class AioHttpEntryTestCase(TestCase): - @mocketize - def test_http_session(self): - url = 'http://httpbin.org/ip' - body = "asd" * 100 - Entry.single_register(Entry.GET, url, body=body, status=404) - Entry.single_register(Entry.POST, url, body=body*2, status=201) + # `aiohttp` creates SSLContext instances at import-time + # that's why Mocket would get stuck when dealing with HTTP + # Importing the module while Mocket is in control (inside a + # decorated test function or using its context manager would + # be enough for making it work), the alternative is using a + # custom TCPConnector which always return a FakeSSLContext + # from Mocket like this example is showing. + import aiohttp + import pytest - async def main(l): - async with aiohttp.ClientSession( - loop=l, timeout=aiohttp.ClientTimeout(total=3) - ) as session: - async with session.get(url) as get_response: - assert get_response.status == 404 - assert await get_response.text() == body + from mocket import async_mocketize + from mocket.mockhttp import Entry + from mocket.plugins.aiohttp_connector import MocketTCPConnector - async with session.post(url, data=body * 6) as post_response: - assert post_response.status == 201 - assert await post_response.text() == body * 2 - loop = asyncio.new_event_loop() - loop.run_until_complete(main(loop)) + @pytest.mark.asyncio + @async_mocketize + async def test_aiohttp(): + """ + The alternative to using the custom `connector` would be importing + `aiohttp` when Mocket is already in control (inside the decorated test). + """ + + url = "https://bar.foo/" + data = {"message": "Hello"} + + Entry.single_register( + Entry.GET, + url, + body=json.dumps(data), + headers={"content-type": "application/json"}, + ) - # or again with a unittest.IsolatedAsyncioTestCase - from mocket.async_mocket import async_mocketize - - class AioHttpEntryTestCase(IsolatedAsyncioTestCase): - @async_mocketize - async def test_http_session(self): - url = 'http://httpbin.org/ip' - body = "asd" * 100 - Entry.single_register(Entry.GET, url, body=body, status=404) - Entry.single_register(Entry.POST, url, body=body * 2, status=201) - - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=3) - ) as session: - async with session.get(url) as get_response: - assert get_response.status == 404 - assert await get_response.text() == body - - async with session.post(url, data=body * 6) as post_response: - assert post_response.status == 201 - assert await post_response.text() == body * 2 - assert Mocket.last_request().method == 'POST' - assert Mocket.last_request().body == body * 6 + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=3), connector=MocketTCPConnector() + ) as session, session.get(url) as response: + response = await response.json() + assert response == data Works well with others diff --git a/mocket/__init__.py b/mocket/__init__.py index fb0434e9..c785bba5 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,5 +1,25 @@ -from .async_mocket import async_mocketize -from .mocket import FakeSSLContext, Mocket, MocketEntry, Mocketizer, mocketize +import importlib +import sys + +from mocket.decorators.async_mocket import async_mocketize +from mocket.decorators.mocketizer import Mocketizer, mocketize +from mocket.entry import MocketEntry +from mocket.mocket import Mocket +from mocket.ssl.context import MocketSSLContext + +# NOTE the following lines are here for backwards-compatibility, +# to keep old import-paths working +from mocket.ssl.context import MocketSSLContext as FakeSSLContext + +sys.modules["mocket.mockhttp"] = importlib.import_module("mocket.mocks.mockhttp") +sys.modules["mocket.mockredis"] = importlib.import_module("mocket.mocks.mockredis") +sys.modules["mocket.async_mocket"] = importlib.import_module( + "mocket.decorators.async_mocket" +) +sys.modules["mocket.mocketizer"] = importlib.import_module( + "mocket.decorators.mocketizer" +) + __all__ = ( "async_mocketize", @@ -7,7 +27,8 @@ "Mocket", "MocketEntry", "Mocketizer", + "MocketSSLContext", "FakeSSLContext", ) -__version__ = "3.13.2" +__version__ = "3.13.3" diff --git a/mocket/compat.py b/mocket/compat.py index 276ae0f0..1ac2fc89 100644 --- a/mocket/compat.py +++ b/mocket/compat.py @@ -9,21 +9,17 @@ ENCODING: Final[str] = os.getenv("MOCKET_ENCODING", "utf-8") -text_type = str -byte_type = bytes -basestring = (str,) - def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes: - if isinstance(s, text_type): + if isinstance(s, str): s = s.encode(encoding) - return byte_type(s) + return bytes(s) def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str: - if isinstance(s, byte_type): + if isinstance(s, bytes): s = codecs.decode(s, encoding, "ignore") - return text_type(s) + return str(s) def shsplit(s: str | bytes) -> list[str]: diff --git a/mocket/decorators/__init__.py b/mocket/decorators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/async_mocket.py b/mocket/decorators/async_mocket.py similarity index 81% rename from mocket/async_mocket.py rename to mocket/decorators/async_mocket.py index 2970e0f4..40b763ae 100644 --- a/mocket/async_mocket.py +++ b/mocket/decorators/async_mocket.py @@ -1,5 +1,5 @@ -from .mocket import Mocketizer -from .utils import get_mocketize +from mocket.decorators.mocketizer import Mocketizer +from mocket.utils import get_mocketize async def wrapper( diff --git a/mocket/decorators/mocketizer.py b/mocket/decorators/mocketizer.py new file mode 100644 index 00000000..2bf2b9cd --- /dev/null +++ b/mocket/decorators/mocketizer.py @@ -0,0 +1,95 @@ +from mocket.mocket import Mocket +from mocket.mode import MocketMode +from mocket.utils import get_mocketize + + +class Mocketizer: + def __init__( + self, + instance=None, + namespace=None, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + ): + self.instance = instance + self.truesocket_recording_dir = truesocket_recording_dir + self.namespace = namespace or str(id(self)) + MocketMode().STRICT = strict_mode + if strict_mode: + MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] + elif strict_mode_allowed: + raise ValueError( + "Allowed locations are only accepted when STRICT mode is active." + ) + + def enter(self): + Mocket.enable( + namespace=self.namespace, + truesocket_recording_dir=self.truesocket_recording_dir, + ) + if self.instance: + self.check_and_call("mocketize_setup") + + def __enter__(self): + self.enter() + return self + + def exit(self): + if self.instance: + self.check_and_call("mocketize_teardown") + + Mocket.disable() + + def __exit__(self, type, value, tb): + self.exit() + + async def __aenter__(self, *args, **kwargs): + self.enter() + return self + + async def __aexit__(self, *args, **kwargs): + self.exit() + + def check_and_call(self, method_name): + method = getattr(self.instance, method_name, None) + if callable(method): + method() + + @staticmethod + def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): + instance = args[0] if args else None + namespace = None + if truesocket_recording_dir: + namespace = ".".join( + ( + instance.__class__.__module__, + instance.__class__.__name__, + test.__name__, + ) + ) + + return Mocketizer( + instance, + namespace=namespace, + truesocket_recording_dir=truesocket_recording_dir, + strict_mode=strict_mode, + strict_mode_allowed=strict_mode_allowed, + ) + + +def wrapper( + test, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + *args, + **kwargs, +): + with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): + return test(*args, **kwargs) + + +mocketize = get_mocketize(wrapper_=wrapper) diff --git a/mocket/entry.py b/mocket/entry.py new file mode 100644 index 00000000..9dbbf442 --- /dev/null +++ b/mocket/entry.py @@ -0,0 +1,58 @@ +import collections.abc + +from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket + + +class MocketEntry: + class Response(bytes): + @property + def data(self): + return self + + response_index = 0 + request_cls = bytes + response_cls = Response + responses = None + _served = None + + def __init__(self, location, responses): + self._served = False + self.location = location + + if not isinstance(responses, collections.abc.Iterable): + responses = [responses] + + if not responses: + self.responses = [self.response_cls(encode_to_bytes(""))] + else: + self.responses = [] + for r in responses: + if not isinstance(r, BaseException) and not getattr(r, "data", False): + if isinstance(r, str): + r = encode_to_bytes(r) + r = self.response_cls(r) + self.responses.append(r) + + def __repr__(self): + return f"{self.__class__.__name__}(location={self.location})" + + @staticmethod + def can_handle(data): + return True + + def collect(self, data): + req = self.request_cls(data) + Mocket.collect(req) + + def get_response(self): + response = self.responses[self.response_index] + if self.response_index < len(self.responses) - 1: + self.response_index += 1 + + self._served = True + + if isinstance(response, BaseException): + raise response + + return response.data diff --git a/mocket/inject.py b/mocket/inject.py new file mode 100644 index 00000000..866ee563 --- /dev/null +++ b/mocket/inject.py @@ -0,0 +1,80 @@ +from __future__ import annotations + +import contextlib +import socket +import ssl +from types import ModuleType +from typing import Any + +import urllib3 + +_patches_restore: dict[tuple[ModuleType, str], Any] = {} + + +def _patch(module: ModuleType, name: str, patched_value: Any) -> None: + with contextlib.suppress(KeyError): + original_value, module.__dict__[name] = module.__dict__[name], patched_value + _patches_restore[(module, name)] = original_value + + +def _restore(module: ModuleType, name: str) -> None: + if original_value := _patches_restore.pop((module, name)): + module.__dict__[name] = original_value + + +def enable() -> None: + from mocket.socket import ( + MocketSocket, + mock_create_connection, + mock_getaddrinfo, + mock_gethostbyname, + mock_gethostname, + mock_inet_pton, + mock_socketpair, + ) + from mocket.ssl.context import MocketSSLContext, mock_wrap_socket + from mocket.urllib3 import ( + mock_match_hostname as mock_urllib3_match_hostname, + ) + from mocket.urllib3 import ( + mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket, + ) + + patches = { + # stdlib: socket + (socket, "socket"): MocketSocket, + (socket, "create_connection"): mock_create_connection, + (socket, "getaddrinfo"): mock_getaddrinfo, + (socket, "gethostbyname"): mock_gethostbyname, + (socket, "gethostname"): mock_gethostname, + (socket, "inet_pton"): mock_inet_pton, + (socket, "SocketType"): MocketSocket, + (socket, "socketpair"): mock_socketpair, + # stdlib: ssl + (ssl, "SSLContext"): MocketSSLContext, + (ssl, "wrap_socket"): mock_wrap_socket, # python < 3.12.0 + # urllib3 + (urllib3.connection, "match_hostname"): mock_urllib3_match_hostname, + (urllib3.connection, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util.ssl_, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util.ssl_, "wrap_socket"): mock_urllib3_ssl_wrap_socket, # urllib3 < 2 + } + + for (module, name), new_value in patches.items(): + _patch(module, name, new_value) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import extract_from_urllib3 + + extract_from_urllib3() + + +def disable() -> None: + for module, name in list(_patches_restore.keys()): + _restore(module, name) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import inject_into_urllib3 + + inject_into_urllib3() diff --git a/mocket/io.py b/mocket/io.py new file mode 100644 index 00000000..0334410b --- /dev/null +++ b/mocket/io.py @@ -0,0 +1,17 @@ +import io +import os + +from mocket.mocket import Mocket + + +class MocketSocketIO(io.BytesIO): + def __init__(self, address) -> None: + self._address = address + super().__init__() + + def write(self, content): + super().write(content) + + _, w_fd = Mocket.get_pair(self._address) + if w_fd: + os.write(w_fd, content) diff --git a/mocket/mocket.py b/mocket/mocket.py index dcdab533..a01a7b46 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,437 +1,60 @@ +from __future__ import annotations + import collections -import collections.abc as collections_abc -import contextlib -import errno -import hashlib import itertools -import json import os -import select -import socket -import ssl -from datetime import datetime, timedelta -from json.decoder import JSONDecodeError -from typing import Optional, Tuple - -import urllib3 -from urllib3.connection import match_hostname as urllib3_match_hostname -from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket - -try: - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket -except ImportError: - urllib3_wrap_socket = None - - -from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type -from .utils import ( - MocketMode, - MocketSocketCore, - get_mocketize, - hexdump, - hexload, -) - -xxh32 = None -try: - from xxhash import xxh32 -except ImportError: # pragma: no cover - with contextlib.suppress(ImportError): - from xxhash_cffi import xxh32 -hasher = xxh32 or hashlib.md5 - -try: # pragma: no cover - from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 - - pyopenssl_override = True -except ImportError: - pyopenssl_override = False - -true_socket = socket.socket -true_create_connection = socket.create_connection -true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_getaddrinfo = socket.getaddrinfo -true_socketpair = socket.socketpair -true_ssl_wrap_socket = getattr( - ssl, "wrap_socket", None -) # from Py3.12 it's only under SSLContext -true_ssl_socket = ssl.SSLSocket -true_ssl_context = ssl.SSLContext -true_inet_pton = socket.inet_pton -true_urllib3_wrap_socket = urllib3_wrap_socket -true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket -true_urllib3_match_hostname = urllib3_match_hostname - - -class SuperFakeSSLContext: - """For Python 3.6 and newer.""" - - class FakeSetter(int): - def __set__(self, *args): - pass - - minimum_version = FakeSetter() - options = FakeSetter() - verify_mode = FakeSetter() - verify_flags = FakeSetter() - - -class FakeSSLContext(SuperFakeSSLContext): - DUMMY_METHODS = ( - "load_default_certs", - "load_verify_locations", - "set_alpn_protocols", - "set_ciphers", - "set_default_verify_paths", - ) - sock = None - post_handshake_auth = None - _check_hostname = False - - @property - def check_hostname(self): - return self._check_hostname - - @check_hostname.setter - def check_hostname(self, _): - self._check_hostname = False - - def __init__(self, *args, **kwargs): - self._set_dummy_methods() - - def _set_dummy_methods(self): - def dummy_method(*args, **kwargs): - pass - - for m in self.DUMMY_METHODS: - setattr(self, m, dummy_method) - - @staticmethod - def wrap_socket(sock, *args, **kwargs): - sock.kwargs = kwargs - sock._secure_socket = True - return sock - - @staticmethod - def wrap_bio(incoming, outcoming, *args, **kwargs): - ssl_obj = MocketSocket() - ssl_obj._host = kwargs["server_hostname"] - return ssl_obj - - -def create_connection(address, timeout=None, source_address=None): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) - if timeout: - s.settimeout(timeout) - s.connect(address) - return s - - -def socketpair(*args, **kwargs): - """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" - import _socket - - return _socket.socketpair(*args, **kwargs) - - -def _hash_request(h, req): - return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() - - -class MocketSocket: - timeout = None - _fd = None - family = None - type = None - proto = None - _host = None - _port = None - _address = None - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION - _mode = None - _bufsize = None - _secure_socket = False - _did_handshake = False - _sent_non_empty_bytes = False - _io = None - - def __init__( - self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs - ): - self.true_socket = true_socket(family, type, proto) - self._buflen = 65536 - self._entry = None - self.family = int(family) - self.type = int(type) - self.proto = int(proto) - self._truesocket_recording_dir = None - self.kwargs = kwargs - - def __str__(self): - return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - @property - def io(self): - if self._io is None: - self._io = MocketSocketCore((self._host, self._port)) - return self._io - - def fileno(self): - address = (self._host, self._port) - r_fd, _ = Mocket.get_pair(address) - if not r_fd: - r_fd, w_fd = os.pipe() - Mocket.set_pair(address, (r_fd, w_fd)) - return r_fd - - def gettimeout(self): - return self.timeout - - def setsockopt(self, family, type, proto): - self.family = family - self.type = type - self.proto = proto - - if self.true_socket: - self.true_socket.setsockopt(family, type, proto) - - def settimeout(self, timeout): - self.timeout = timeout - - @staticmethod - def getsockopt(level, optname, buflen=None): - return socket.SOCK_STREAM - - def do_handshake(self): - self._did_handshake = True - - def getpeername(self): - return self._address - - def setblocking(self, block): - self.settimeout(None) if block else self.settimeout(0.0) - - def getblocking(self): - return self.gettimeout() is None - - def getsockname(self): - return socket.gethostbyname(self._address[0]), self._address[1] - - def getpeercert(self, *args, **kwargs): - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", f"*.{self._host}"), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", f"*.{self._host}"),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", f"*.{self._host}"),), - ), - } - - def unwrap(self): - return self - - def write(self, data): - return self.send(encode_to_bytes(data)) - - def connect(self, address): - self._address = self._host, self._port = address - Mocket._address = address - - def makefile(self, mode="r", bufsize=-1): - self._mode = mode - self._bufsize = bufsize - return self.io - - def get_entry(self, data): - return Mocket.get_entry(self._host, self._port, data) +from pathlib import Path +from typing import TYPE_CHECKING, ClassVar - def sendall(self, data, entry=None, *args, **kwargs): - if entry is None: - entry = self.get_entry(data) +import mocket.inject +from mocket.recording import MocketRecordStorage - if entry: - consume_response = entry.collect(data) - response = entry.get_response() if consume_response is not False else None - else: - response = self.true_sendall(data, *args, **kwargs) +# NOTE this is here for backwards-compat to keep old import-paths working +# from mocket.socket import MocketSocket as MocketSocket - if response is not None: - self.io.seek(0) - self.io.write(response) - self.io.truncate() - self.io.seek(0) +if TYPE_CHECKING: + from mocket.entry import MocketEntry + from mocket.types import Address - def read(self, buffersize): - rv = self.io.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv - def recv_into(self, buffer, buffersize=None, flags=None): - if hasattr(buffer, "write"): - return buffer.write(self.read(buffersize)) - # buffer is a memoryview - data = self.read(buffersize) - if data: - buffer[: len(data)] = data - return len(data) - - def recv(self, buffersize, flags=None): - r_fd, _ = Mocket.get_pair((self._host, self._port)) - if r_fd: - return os.read(r_fd, buffersize) - data = self.read(buffersize) - if data: - return data - # used by Redis mock - exc = BlockingIOError() - exc.errno = errno.EWOULDBLOCK - exc.args = (0,) - raise exc - - def true_sendall(self, data, *args, **kwargs): - if not MocketMode().is_allowed((self._host, self._port)): - MocketMode.raise_not_allowed() - - req = decode_from_bytes(data) - # make request unique again - req_signature = _hash_request(hasher, req) - # port should be always a string - port = text_type(self._port) - - # prepare responses dictionary - responses = {} +class Mocket: + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} + _address: ClassVar[Address] = (None, None) + _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) + _requests: ClassVar[list] = [] + _record_storage: ClassVar[MocketRecordStorage | None] = None - if Mocket.get_truesocket_recording_dir(): - path = os.path.join( - Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" + @classmethod + def enable( + cls, + namespace: str | None = None, + truesocket_recording_dir: str | None = None, + ) -> None: + if namespace is None: + namespace = str(id(cls._entries)) + + if truesocket_recording_dir is not None: + recording_dir = Path(truesocket_recording_dir) + + if not recording_dir.is_dir(): + # JSON dumps will be saved here + raise AssertionError + + cls._record_storage = MocketRecordStorage( + directory=recording_dir, + namespace=namespace, ) - # check if there's already a recorded session dumped to a JSON file - try: - with open(path) as f: - responses = json.load(f) - # if not, create a new dictionary - except (FileNotFoundError, JSONDecodeError): - pass - - try: - try: - response_dict = responses[self._host][port][req_signature] - except KeyError: - if hasher is not hashlib.md5: - # Fallback for backwards compatibility - req_signature = _hash_request(hashlib.md5, req) - response_dict = responses[self._host][port][req_signature] - else: - raise - except KeyError: - # preventing next KeyError exceptions - responses.setdefault(self._host, {}) - responses[self._host].setdefault(port, {}) - responses[self._host][port].setdefault(req_signature, {}) - response_dict = responses[self._host][port][req_signature] - - # try to get the response from the dictionary - try: - encoded_response = hexload(response_dict["response"]) - # if not available, call the real sendall - except KeyError: - host, port = self._host, self._port - host = true_gethostbyname(host) - - if isinstance(self.true_socket, true_socket) and self._secure_socket: - self.true_socket = true_urllib3_ssl_wrap_socket( - self.true_socket, - **self.kwargs, - ) - - with contextlib.suppress(OSError, ValueError): - # already connected - self.true_socket.connect((host, port)) - self.true_socket.sendall(data, *args, **kwargs) - encoded_response = b"" - # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 - while True: - more_to_read = select.select([self.true_socket], [], [], 0.1)[0] - if not more_to_read and encoded_response: - break - new_content = self.true_socket.recv(self._buflen) - if not new_content: - break - encoded_response += new_content - - # dump the resulting dictionary to a JSON file - if Mocket.get_truesocket_recording_dir(): - # update the dictionary with request and response lines - response_dict["request"] = req - response_dict["response"] = hexdump(encoded_response) - - with open(path, mode="w") as f: - f.write( - decode_from_bytes( - json.dumps(responses, indent=4, sort_keys=True) - ) - ) - - # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO - return encoded_response - - def send(self, data, *args, **kwargs): # pragma: no cover - entry = self.get_entry(data) - if not entry or (entry and self._entry != entry): - kwargs["entry"] = entry - self.sendall(data, *args, **kwargs) - else: - req = Mocket.last_request() - if hasattr(req, "add_data"): - req.add_data(data) - self._entry = entry - return len(data) - def close(self): - if self.true_socket and not self.true_socket._closed: - self.true_socket.close() - self._fd = None + mocket.inject.enable() - def __getattr__(self, name): - """Do nothing catchall function, for methods like shutdown()""" - - def do_nothing(*args, **kwargs): - pass - - return do_nothing + @classmethod + def disable(cls) -> None: + cls.reset() - -class Mocket: - _socket_pairs = {} - _address = (None, None) - _entries = collections.defaultdict(list) - _requests = [] - _namespace = text_type(id(_entries)) - _truesocket_recording_dir = None + mocket.inject.disable() @classmethod - def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: + def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: """ Given the id() of the caller, return a pair of file descriptors as a tuple of two integers: (, ) @@ -439,7 +62,7 @@ def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: return cls._socket_pairs.get(address, (None, None)) @classmethod - def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: + def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: """ Store a pair of file descriptors under the key `id_` as a tuple of two integers: (, ) @@ -447,292 +70,66 @@ def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: cls._socket_pairs[address] = pair @classmethod - def register(cls, *entries): + def register(cls, *entries: MocketEntry) -> None: for entry in entries: cls._entries[entry.location].append(entry) @classmethod - def get_entry(cls, host, port, data): - host = host or Mocket._address[0] - port = port or Mocket._address[1] + def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: + host = host or cls._address[0] + port = port or cls._address[1] entries = cls._entries.get((host, port), []) for entry in entries: if entry.can_handle(data): return entry + return None @classmethod - def collect(cls, data): - cls.request_list().append(data) + def collect(cls, data) -> None: + cls._requests.append(data) @classmethod - def reset(cls): + def reset(cls) -> None: for r_fd, w_fd in cls._socket_pairs.values(): os.close(r_fd) os.close(w_fd) cls._socket_pairs = {} cls._entries = collections.defaultdict(list) cls._requests = [] + cls._record_storage = None @classmethod def last_request(cls): if cls.has_requests(): - return cls.request_list()[-1] + return cls._requests[-1] @classmethod def request_list(cls): return cls._requests @classmethod - def remove_last_request(cls): + def remove_last_request(cls) -> None: if cls.has_requests(): del cls._requests[-1] @classmethod - def has_requests(cls): + def has_requests(cls) -> bool: return bool(cls.request_list()) - @staticmethod - def enable(namespace=None, truesocket_recording_dir=None): - Mocket._namespace = namespace - Mocket._truesocket_recording_dir = truesocket_recording_dir - - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): - # JSON dumps will be saved here - raise AssertionError - - socket.socket = socket.__dict__["socket"] = MocketSocket - socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket - socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = ( - create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" - socket.gethostbyname = socket.__dict__["gethostbyname"] = ( - lambda host: "127.0.0.1" - ) - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( - lambda host, port, family=None, socktype=None, proto=None, flags=None: [ - (2, 1, 6, "", (host, port)) - ] - ) - socket.socketpair = socket.__dict__["socketpair"] = socketpair - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type( - "\x7f\x00\x00\x01", "utf-8" - ) - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - FakeSSLContext.wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - FakeSSLContext.wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = lambda *args: None - if pyopenssl_override: # pragma: no cover - # Take out the pyopenssl version - use the default implementation - extract_from_urllib3() - - @staticmethod - def disable(): - socket.socket = socket.__dict__["socket"] = true_socket - socket._socketobject = socket.__dict__["_socketobject"] = true_socket - socket.SocketType = socket.__dict__["SocketType"] = true_socket - socket.create_connection = socket.__dict__["create_connection"] = ( - true_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = true_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = true_socketpair - if true_ssl_wrap_socket: - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context - socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - true_urllib3_wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - true_urllib3_ssl_wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = true_urllib3_match_hostname - Mocket.reset() - if pyopenssl_override: # pragma: no cover - # Put the pyopenssl version back in place - inject_into_urllib3() - @classmethod - def get_namespace(cls): - return cls._namespace + def get_namespace(cls) -> str | None: + if not cls._record_storage: + return None + return cls._record_storage.namespace @classmethod - def get_truesocket_recording_dir(cls): - return cls._truesocket_recording_dir + def get_truesocket_recording_dir(cls) -> str | None: + if not cls._record_storage: + return None + return str(cls._record_storage.directory) @classmethod - def assert_fail_if_entries_not_served(cls): + def assert_fail_if_entries_not_served(cls) -> None: """Mocket checks that all entries have been served at least once.""" if not all(entry._served for entry in itertools.chain(*cls._entries.values())): raise AssertionError("Some Mocket entries have not been served") - - -class MocketEntry: - class Response(byte_type): - @property - def data(self): - return self - - response_index = 0 - request_cls = byte_type - response_cls = Response - responses = None - _served = None - - def __init__(self, location, responses): - self._served = False - self.location = location - - if not isinstance(responses, collections_abc.Iterable) or isinstance( - responses, basestring - ): - responses = [responses] - - if not responses: - self.responses = [self.response_cls(encode_to_bytes(""))] - else: - self.responses = [] - for r in responses: - if not isinstance(r, BaseException) and not getattr(r, "data", False): - if isinstance(r, text_type): - r = encode_to_bytes(r) - r = self.response_cls(r) - self.responses.append(r) - - def __repr__(self): - return f"{self.__class__.__name__}(location={self.location})" - - @staticmethod - def can_handle(data): - return True - - def collect(self, data): - req = self.request_cls(data) - Mocket.collect(req) - - def get_response(self): - response = self.responses[self.response_index] - if self.response_index < len(self.responses) - 1: - self.response_index += 1 - - self._served = True - - if isinstance(response, BaseException): - raise response - - return response.data - - -class Mocketizer: - def __init__( - self, - instance=None, - namespace=None, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - ): - self.instance = instance - self.truesocket_recording_dir = truesocket_recording_dir - self.namespace = namespace or text_type(id(self)) - MocketMode().STRICT = strict_mode - if strict_mode: - MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] - elif strict_mode_allowed: - raise ValueError( - "Allowed locations are only accepted when STRICT mode is active." - ) - - def enter(self): - Mocket.enable( - namespace=self.namespace, - truesocket_recording_dir=self.truesocket_recording_dir, - ) - if self.instance: - self.check_and_call("mocketize_setup") - - def __enter__(self): - self.enter() - return self - - def exit(self): - if self.instance: - self.check_and_call("mocketize_teardown") - Mocket.disable() - - def __exit__(self, type, value, tb): - self.exit() - - async def __aenter__(self, *args, **kwargs): - self.enter() - return self - - async def __aexit__(self, *args, **kwargs): - self.exit() - - def check_and_call(self, method_name): - method = getattr(self.instance, method_name, None) - if callable(method): - method() - - @staticmethod - def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): - instance = args[0] if args else None - namespace = None - if truesocket_recording_dir: - namespace = ".".join( - ( - instance.__class__.__module__, - instance.__class__.__name__, - test.__name__, - ) - ) - - return Mocketizer( - instance, - namespace=namespace, - truesocket_recording_dir=truesocket_recording_dir, - strict_mode=strict_mode, - strict_mode_allowed=strict_mode_allowed, - ) - - -def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): - with Mocketizer.factory( - test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args - ): - return test(*args, **kwargs) - - -mocketize = get_mocketize(wrapper_=wrapper) diff --git a/mocket/mocks/__init__.py b/mocket/mocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/mockhttp.py b/mocket/mocks/mockhttp.py similarity index 98% rename from mocket/mockhttp.py rename to mocket/mocks/mockhttp.py index 5058328d..245a11af 100644 --- a/mocket/mockhttp.py +++ b/mocket/mocks/mockhttp.py @@ -7,8 +7,9 @@ from h11 import SERVER, Connection, Data from h11 import Request as H11Request -from .compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes -from .mocket import Mocket, MocketEntry +from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes +from mocket.entry import MocketEntry +from mocket.mocket import Mocket STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} CRLF = "\r\n" diff --git a/mocket/mockredis.py b/mocket/mocks/mockredis.py similarity index 86% rename from mocket/mockredis.py rename to mocket/mocks/mockredis.py index 1a0c51e2..fc386e2d 100644 --- a/mocket/mockredis.py +++ b/mocket/mocks/mockredis.py @@ -1,7 +1,12 @@ from itertools import chain -from .compat import byte_type, decode_from_bytes, encode_to_bytes, shsplit, text_type -from .mocket import Mocket, MocketEntry +from mocket.compat import ( + decode_from_bytes, + encode_to_bytes, + shsplit, +) +from mocket.entry import MocketEntry +from mocket.mocket import Mocket class Request: @@ -14,7 +19,7 @@ def __init__(self, data=None): self.data = Redisizer.redisize(data or OK) -class Redisizer(byte_type): +class Redisizer(bytes): @staticmethod def tokens(iterable): iterable = [encode_to_bytes(x) for x in iterable] @@ -30,15 +35,15 @@ def get_conversion(t): Redisizer.tokens(list(chain(*tuple(x.items())))) ), int: lambda x: f":{x}".encode(), - text_type: lambda x: "${}\r\n{}".format( - len(x.encode("utf-8")), x - ).encode("utf-8"), + str: lambda x: "${}\r\n{}".format(len(x.encode("utf-8")), x).encode( + "utf-8" + ), list: lambda x: b"\r\n".join(Redisizer.tokens(x)), }[t] if isinstance(data, Redisizer): return data - if isinstance(data, byte_type): + if isinstance(data, bytes): data = decode_from_bytes(data) return Redisizer(get_conversion(data.__class__)(data) + b"\r\n") diff --git a/mocket/mode.py b/mocket/mode.py new file mode 100644 index 00000000..e1da7955 --- /dev/null +++ b/mocket/mode.py @@ -0,0 +1,45 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +from mocket.exceptions import StrictMocketException +from mocket.mocket import Mocket + +if TYPE_CHECKING: # pragma: no cover + from typing import NoReturn + + +class MocketMode: + __shared_state: ClassVar[dict[str, Any]] = {} + STRICT: ClassVar = None + STRICT_ALLOWED: ClassVar = None + + def __init__(self) -> None: + self.__dict__ = self.__shared_state + + def is_allowed(self, location: str | tuple[str, int]) -> bool: + """ + Checks if (`host`, `port`) or at least `host` + are allowed locations to perform real `socket` calls + """ + if not self.STRICT: + return True + + host_allowed = False + if isinstance(location, tuple): + host_allowed = location[0] in self.STRICT_ALLOWED + return host_allowed or location in self.STRICT_ALLOWED + + @staticmethod + def raise_not_allowed() -> NoReturn: + current_entries = [ + (location, "\n ".join(map(str, entries))) + for location, entries in Mocket._entries.items() + ] + formatted_entries = "\n".join( + [f" {location}:\n {entries}" for location, entries in current_entries] + ) + raise StrictMocketException( + "Mocket tried to use the real `socket` module while STRICT mode was active.\n" + f"Registered entries:\n{formatted_entries}" + ) diff --git a/mocket/plugins/aiohttp_connector.py b/mocket/plugins/aiohttp_connector.py index 353c3af7..cde5019a 100644 --- a/mocket/plugins/aiohttp_connector.py +++ b/mocket/plugins/aiohttp_connector.py @@ -1,6 +1,6 @@ import contextlib -from mocket import FakeSSLContext +from mocket import MocketSSLContext with contextlib.suppress(ModuleNotFoundError): from aiohttp import ClientRequest @@ -14,5 +14,5 @@ class MocketTCPConnector(TCPConnector): slightly patching the `ClientSession` while testing. """ - def _get_ssl_context(self, req: ClientRequest) -> FakeSSLContext: - return FakeSSLContext() + def _get_ssl_context(self, req: ClientRequest) -> MocketSSLContext: + return MocketSSLContext() diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index 9d61ae2e..fac61840 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -1,6 +1,7 @@ -from mocket import Mocket, mocketize +from mocket import mocketize from mocket.async_mocket import async_mocketize -from mocket.compat import ENCODING, byte_type, text_type +from mocket.compat import ENCODING +from mocket.mocket import Mocket from mocket.mockhttp import Entry as MocketHttpEntry from mocket.mockhttp import Request as MocketHttpRequest from mocket.mockhttp import Response as MocketHttpResponse @@ -129,6 +130,6 @@ def __getattr__(self, name): "HEAD", "PATCH", "register_uri", - "text_type", - "byte_type", + "str", + "bytes", ) diff --git a/mocket/recording.py b/mocket/recording.py new file mode 100644 index 00000000..97d2adbe --- /dev/null +++ b/mocket/recording.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import contextlib +import hashlib +import json +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path + +from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.types import Address +from mocket.utils import hexdump, hexload + +hash_function = hashlib.md5 + +with contextlib.suppress(ImportError): + from xxhash_cffi import xxh32 as xxhash_cffi_xxh32 + + hash_function = xxhash_cffi_xxh32 + +with contextlib.suppress(ImportError): + from xxhash import xxh32 as xxhash_xxh32 + + hash_function = xxhash_xxh32 + + +def _hash_prepare_request(data: bytes) -> bytes: + _data = decode_from_bytes(data) + return encode_to_bytes("".join(sorted(_data.split("\r\n")))) + + +def _hash_request(data: bytes) -> str: + _data = _hash_prepare_request(data) + return hash_function(_data).hexdigest() + + +def _hash_request_fallback(data: bytes) -> str: + _data = _hash_prepare_request(data) + return hashlib.md5(_data).hexdigest() + + +@dataclass +class MocketRecord: + host: str + port: int + request: bytes + response: bytes + + +class MocketRecordStorage: + def __init__(self, directory: Path, namespace: str) -> None: + self._directory = directory + self._namespace = namespace + self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = ( + defaultdict(defaultdict) + ) + + self._load() + + @property + def directory(self) -> Path: + return self._directory + + @property + def namespace(self) -> str: + return self._namespace + + @property + def file(self) -> Path: + return self._directory / f"{self._namespace}.json" + + def _load(self) -> None: + if not self.file.exists(): + return + + json_data = self.file.read_text() + records = json.loads(json_data) + for host, port_signature_record in records.items(): + for port, signature_record in port_signature_record.items(): + for signature, record in signature_record.items(): + # NOTE backward-compat + try: + request_data = hexload(record["request"]) + except ValueError: + request_data = record["request"] + + self._records[(host, int(port))][signature] = MocketRecord( + host=host, + port=port, + request=request_data, + response=hexload(record["response"]), + ) + + def _save(self) -> None: + data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict( + lambda: defaultdict(defaultdict) + ) + for address, signature_record in self._records.items(): + host, port = address + for signature, record in signature_record.items(): + data[host][str(port)][signature] = dict( + request=decode_from_bytes(record.request), + response=hexdump(record.response), + ) + + json_data = json.dumps(data, indent=4, sort_keys=True) + self.file.parent.mkdir(exist_ok=True) + self.file.write_text(json_data) + + def get_records(self, address: Address) -> list[MocketRecord]: + return list(self._records[address].values()) + + def get_record(self, address: Address, request: bytes) -> MocketRecord | None: + # NOTE for backward-compat + request_signature_fallback = _hash_request_fallback(request) + if request_signature_fallback in self._records[address]: + return self._records[address].get(request_signature_fallback) + + request_signature = _hash_request(request) + if request_signature in self._records[address]: + return self._records[address][request_signature] + + return None + + def put_record( + self, + address: Address, + request: bytes, + response: bytes, + ) -> None: + host, port = address + record = MocketRecord( + host=host, + port=port, + request=request, + response=response, + ) + + # NOTE for backward-compat + request_signature_fallback = _hash_request_fallback(request) + if request_signature_fallback in self._records[address]: + self._records[address][request_signature_fallback] = record + return + + request_signature = _hash_request(request) + self._records[address][request_signature] = record + self._save() diff --git a/mocket/socket.py b/mocket/socket.py new file mode 100644 index 00000000..3b1862e2 --- /dev/null +++ b/mocket/socket.py @@ -0,0 +1,289 @@ +from __future__ import annotations + +import contextlib +import errno +import os +import select +import socket +from types import TracebackType +from typing import Any, Type + +from typing_extensions import Self + +from mocket.entry import MocketEntry +from mocket.io import MocketSocketIO +from mocket.mocket import Mocket +from mocket.mode import MocketMode +from mocket.types import ( + Address, + ReadableBuffer, + WriteableBuffer, + _RetAddress, +) + +true_gethostbyname = socket.gethostbyname +true_socket = socket.socket + + +def mock_create_connection(address, timeout=None, source_address=None): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) + if timeout: + s.settimeout(timeout) + s.connect(address) + return s + + +def mock_getaddrinfo( + host: str, + port: int, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[tuple[int, int, int, str, tuple[str, int]]]: + return [(2, 1, 6, "", (host, port))] + + +def mock_gethostbyname(hostname: str) -> str: + return "127.0.0.1" + + +def mock_gethostname() -> str: + return "localhost" + + +def mock_inet_pton(address_family: int, ip_string: str) -> bytes: + return bytes("\x7f\x00\x00\x01", "utf-8") + + +def mock_socketpair(*args, **kwargs): + """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" + import _socket + + return _socket.socketpair(*args, **kwargs) + + +class MocketSocket: + def __init__( + self, + family: socket.AddressFamily | int = socket.AF_INET, + type: socket.SocketKind | int = socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, + **kwargs: Any, + ) -> None: + self._family = family + self._type = type + self._proto = proto + + self._kwargs = kwargs + self._true_socket = true_socket(family, type, proto) + + self._buflen = 65536 + self._timeout: float | None = None + + self._host = None + self._port = None + self._address = None + + self._io = None + self._entry = None + + def __str__(self) -> str: + return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" + + def __enter__(self) -> Self: + return self + + def __exit__( + self, + type_: Type[BaseException] | None, # noqa: UP006 + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: + self.close() + + @property + def family(self) -> int: + return self._family + + @property + def type(self) -> int: + return self._type + + @property + def proto(self) -> int: + return self._proto + + @property + def io(self) -> MocketSocketIO: + if self._io is None: + self._io = MocketSocketIO((self._host, self._port)) + return self._io + + def fileno(self) -> int: + address = (self._host, self._port) + r_fd, _ = Mocket.get_pair(address) + if not r_fd: + r_fd, w_fd = os.pipe() + Mocket.set_pair(address, (r_fd, w_fd)) + return r_fd + + def gettimeout(self) -> float | None: + return self._timeout + + # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None` + def setsockopt(self, family: int, type: int, proto: int) -> None: + self._family = family + self._type = type + self._proto = proto + + if self._true_socket: + self._true_socket.setsockopt(family, type, proto) + + def settimeout(self, timeout: float | None) -> None: + self._timeout = timeout + + @staticmethod + def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: + return socket.SOCK_STREAM + + def getpeername(self) -> _RetAddress: + return self._address + + def setblocking(self, block: bool) -> None: + self.settimeout(None) if block else self.settimeout(0.0) + + def getblocking(self) -> bool: + return self.gettimeout() is None + + def getsockname(self) -> _RetAddress: + return true_gethostbyname(self._address[0]), self._address[1] + + def connect(self, address: Address) -> None: + self._address = self._host, self._port = address + Mocket._address = address + + def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO: + return self.io + + def get_entry(self, data: bytes) -> MocketEntry | None: + return Mocket.get_entry(self._host, self._port, data) + + def sendall(self, data, entry=None, *args, **kwargs): + if entry is None: + entry = self.get_entry(data) + + if entry: + consume_response = entry.collect(data) + response = entry.get_response() if consume_response is not False else None + else: + response = self.true_sendall(data, *args, **kwargs) + + if response is not None: + self.io.seek(0) + self.io.write(response) + self.io.truncate() + self.io.seek(0) + + def recv_into( + self, + buffer: WriteableBuffer, + buffersize: int | None = None, + flags: int | None = None, + ) -> int: + if hasattr(buffer, "write"): + return buffer.write(self.recv(buffersize)) + + # buffer is a memoryview + if buffersize is None: + buffersize = len(buffer) + + data = self.recv(buffersize) + if data: + buffer[: len(data)] = data + return len(data) + + def recv(self, buffersize: int, flags: int | None = None) -> bytes: + r_fd, _ = Mocket.get_pair((self._host, self._port)) + if r_fd: + return os.read(r_fd, buffersize) + data = self.io.read(buffersize) + if data: + return data + # used by Redis mock + exc = BlockingIOError() + exc.errno = errno.EWOULDBLOCK + exc.args = (0,) + raise exc + + def true_sendall(self, data: bytes, *args: Any, **kwargs: Any) -> bytes: + if not MocketMode().is_allowed(self._address): + MocketMode.raise_not_allowed() + + # try to get the response from recordings + if Mocket._record_storage: + record = Mocket._record_storage.get_record( + address=self._address, + request=data, + ) + if record is not None: + return record.response + + host, port = self._address + host = true_gethostbyname(host) + + with contextlib.suppress(OSError, ValueError): + # already connected + self._true_socket.connect((host, port)) + + self._true_socket.sendall(data, *args, **kwargs) + response = b"" + # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 + while True: + more_to_read = select.select([self._true_socket], [], [], 0.1)[0] + if not more_to_read and response: + break + new_content = self._true_socket.recv(self._buflen) + if not new_content: + break + response += new_content + + # store request+response in recordings + if Mocket._record_storage: + Mocket._record_storage.put_record( + address=self._address, + request=data, + response=response, + ) + + return response + + def send( + self, + data: ReadableBuffer, + *args: Any, + **kwargs: Any, + ) -> int: # pragma: no cover + entry = self.get_entry(data) + if not entry or (entry and self._entry != entry): + kwargs["entry"] = entry + self.sendall(data, *args, **kwargs) + else: + req = Mocket.last_request() + if hasattr(req, "add_data"): + req.add_data(data) + self._entry = entry + return len(data) + + def close(self) -> None: + if self._true_socket and not self._true_socket._closed: + self._true_socket.close() + + def __getattr__(self, name: str) -> Any: + """Do nothing catchall function, for methods like shutdown()""" + + def do_nothing(*args: Any, **kwargs: Any) -> Any: + pass + + return do_nothing diff --git a/mocket/ssl/__init__.py b/mocket/ssl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py new file mode 100644 index 00000000..6d5e7307 --- /dev/null +++ b/mocket/ssl/context.py @@ -0,0 +1,78 @@ +from __future__ import annotations + +from typing import Any + +from mocket.socket import MocketSocket +from mocket.ssl.socket import MocketSSLSocket + + +class _MocketSSLContext: + """For Python 3.6 and newer.""" + + class FakeSetter(int): + def __set__(self, *args: Any) -> None: + pass + + minimum_version = FakeSetter() + options = FakeSetter() + verify_mode = FakeSetter() + verify_flags = FakeSetter() + + +class MocketSSLContext(_MocketSSLContext): + DUMMY_METHODS = ( + "load_default_certs", + "load_verify_locations", + "set_alpn_protocols", + "set_ciphers", + "set_default_verify_paths", + ) + sock = None + post_handshake_auth = None + _check_hostname = False + + @property + def check_hostname(self) -> bool: + return self._check_hostname + + @check_hostname.setter + def check_hostname(self, _: bool) -> None: + self._check_hostname = False + + def __init__(self, *args: Any, **kwargs: Any) -> None: + self._set_dummy_methods() + + def _set_dummy_methods(self) -> None: + def dummy_method(*args: Any, **kwargs: Any) -> Any: + pass + + for m in self.DUMMY_METHODS: + setattr(self, m, dummy_method) + + def wrap_socket( + self, + sock: MocketSocket, + *args: Any, + **kwargs: Any, + ) -> MocketSSLSocket: + return MocketSSLSocket._create(sock, *args, **kwargs) + + def wrap_bio( + self, + incoming: Any, # _ssl.MemoryBIO + outgoing: Any, # _ssl.MemoryBIO + server_side: bool = False, + server_hostname: str | bytes | None = None, + ) -> MocketSSLSocket: + ssl_obj = MocketSSLSocket() + ssl_obj._host = server_hostname + return ssl_obj + + +def mock_wrap_socket( + sock: MocketSocket, + *args: Any, + **kwargs: Any, +) -> MocketSSLSocket: + context = MocketSSLContext() + return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py new file mode 100644 index 00000000..6dcd7817 --- /dev/null +++ b/mocket/ssl/socket.py @@ -0,0 +1,95 @@ +from __future__ import annotations + +import ssl +from datetime import datetime, timedelta +from ssl import Options +from typing import Any + +from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket +from mocket.socket import MocketSocket +from mocket.types import _PeerCertRetDictType + + +class MocketSSLSocket(MocketSocket): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self._did_handshake = False + self._sent_non_empty_bytes = False + self._original_socket: MocketSocket = self + + def read(self, buffersize: int | None = None) -> bytes: + rv = self.io.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def write(self, data: bytes) -> int | None: + return self.send(encode_to_bytes(data)) + + def do_handshake(self) -> None: + self._did_handshake = True + + def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", f"*.{self._host}"), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", f"*.{self._host}"),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", f"*.{self._host}"),), + ), + } + + def ciper(self) -> tuple[str, str, str]: + return "ADH", "AES256", "SHA" + + def compression(self) -> Options: + return ssl.OP_NO_COMPRESSION + + def unwrap(self) -> MocketSocket: + return self._original_socket + + @classmethod + def _create( + cls, + sock: MocketSocket, + ssl_context: ssl.SSLContext | None = None, + server_hostname: str | None = None, + *args: Any, + **kwargs: Any, + ) -> MocketSSLSocket: + ssl_socket = MocketSSLSocket() + ssl_socket._original_socket = sock + ssl_socket._true_socket = sock._true_socket + + if ssl_context: + ssl_socket._true_socket = ssl_context.wrap_socket( + sock=ssl_socket._true_socket, + server_hostname=server_hostname, + ) + + ssl_socket._kwargs = kwargs + + ssl_socket._timeout = sock._timeout + + ssl_socket._host = sock._host + ssl_socket._port = sock._port + ssl_socket._address = sock._address + + ssl_socket._io = sock._io + ssl_socket._entry = sock._entry + + return ssl_socket diff --git a/mocket/types.py b/mocket/types.py new file mode 100644 index 00000000..562648c7 --- /dev/null +++ b/mocket/types.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import Any, Dict, Tuple, Union + +from typing_extensions import Buffer, TypeAlias + +Address = Tuple[str, int] + +# adapted from typeshed/stdlib/_typeshed/__init__.pyi +WriteableBuffer: TypeAlias = Buffer +ReadableBuffer: TypeAlias = Buffer + +# from typeshed/stdlib/_socket.pyi +_Address: TypeAlias = Union[Tuple[Any, ...], str, ReadableBuffer] +_RetAddress: TypeAlias = Any + +# from typeshed/stdlib/ssl.pyi +_PCTRTT: TypeAlias = Tuple[Tuple[str, str], ...] +_PCTRTTT: TypeAlias = Tuple[_PCTRTT, ...] +_PeerCertRetDictType: TypeAlias = Dict[str, Union[str, _PCTRTTT, _PCTRTT]] diff --git a/mocket/urllib3.py b/mocket/urllib3.py new file mode 100644 index 00000000..e89bc7b5 --- /dev/null +++ b/mocket/urllib3.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import Any + +from mocket.socket import MocketSocket +from mocket.ssl.context import MocketSSLContext +from mocket.ssl.socket import MocketSSLSocket + + +def mock_match_hostname(*args: Any) -> None: + return None + + +def mock_ssl_wrap_socket( + sock: MocketSocket, + *args: Any, + **kwargs: Any, +) -> MocketSSLSocket: + context = MocketSSLContext() + return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/utils.py b/mocket/utils.py index 9efd6ad9..ab293776 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,34 +1,9 @@ from __future__ import annotations import binascii -import io -import os -import ssl -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import Callable -from .compat import decode_from_bytes, encode_to_bytes -from .exceptions import StrictMocketException - -if TYPE_CHECKING: # pragma: no cover - from typing import NoReturn - - -SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 - - -class MocketSocketCore(io.BytesIO): - def __init__(self, address) -> None: - self._address = address - super().__init__() - - def write(self, content): - from mocket import Mocket - - super().write(content) - - _, w_fd = Mocket.get_pair(self._address) - if w_fd: - os.write(w_fd, content) +from mocket.compat import decode_from_bytes, encode_to_bytes def hexdump(binary_string: bytes) -> str: @@ -46,7 +21,10 @@ def hexload(string: str) -> bytes: True """ string_no_spaces = "".join(string.split()) - return encode_to_bytes(binascii.unhexlify(string_no_spaces)) + try: + return encode_to_bytes(binascii.unhexlify(string_no_spaces)) + except binascii.Error as e: + raise ValueError from e def get_mocketize(wrapper_: Callable) -> Callable: @@ -60,39 +38,8 @@ def get_mocketize(wrapper_: Callable) -> Callable: ) -class MocketMode: - __shared_state: ClassVar[dict[str, Any]] = {} - STRICT: ClassVar = None - STRICT_ALLOWED: ClassVar = None - - def __init__(self) -> None: - self.__dict__ = self.__shared_state - - def is_allowed(self, location: str | tuple[str, int]) -> bool: - """ - Checks if (`host`, `port`) or at least `host` - are allowed locations to perform real `socket` calls - """ - if not self.STRICT: - return True - - host_allowed = False - if isinstance(location, tuple): - host_allowed = location[0] in self.STRICT_ALLOWED - return host_allowed or location in self.STRICT_ALLOWED - - @staticmethod - def raise_not_allowed() -> NoReturn: - from .mocket import Mocket - - current_entries = [ - (location, "\n ".join(map(str, entries))) - for location, entries in Mocket._entries.items() - ] - formatted_entries = "\n".join( - [f" {location}:\n {entries}" for location, entries in current_entries] - ) - raise StrictMocketException( - "Mocket tried to use the real `socket` module while STRICT mode was active.\n" - f"Registered entries:\n{formatted_entries}" - ) +__all__ = ( + "get_mocketize", + "hexdump", + "hexload", +) diff --git a/pyproject.toml b/pyproject.toml index 77d1f5d4..5418304e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ test = [ "httpx", "pipfile", "build", - "twine", "fastapi", "aiohttp", "wait-for-it", diff --git a/tests/test_http.py b/tests/test_http.py index d516068b..afa31185 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -359,12 +359,12 @@ def test_sockets(self): sock = socket.socket(address[0], address[1], address[2]) sock.connect(address[-1]) - sock.write(f"{method} {path} HTTP/1.0\r\n") - sock.write(f"Host: {host}\r\n") - sock.write("Content-Type: application/json\r\n") - sock.write("Content-Length: %d\r\n" % len(data)) - sock.write("Connection: close\r\n\r\n") - sock.write(data) + sock.send(f"{method} {path} HTTP/1.0\r\n".encode()) + sock.send(f"Host: {host}\r\n".encode()) + sock.send(b"Content-Type: application/json\r\n") + sock.send(b"Content-Length: %d\r\n" % len(data)) + sock.send(b"Connection: close\r\n\r\n") + sock.send(data.encode()) sock.close() # Proof that worked. diff --git a/tests/test_mocket.py b/tests/test_mocket.py index 8d09f170..8810a5b9 100644 --- a/tests/test_mocket.py +++ b/tests/test_mocket.py @@ -222,6 +222,7 @@ def test_patch( @pytest.mark.skipif(not psutil.POSIX, reason="Uses a POSIX-only API to test") +@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') @pytest.mark.asyncio async def test_no_dangling_fds(): url = "http://httpbin.local/ip" diff --git a/tests/test_mode.py b/tests/test_mode.py index 2a764949..ea5905b0 100644 --- a/tests/test_mode.py +++ b/tests/test_mode.py @@ -4,7 +4,7 @@ from mocket import Mocketizer, mocketize from mocket.exceptions import StrictMocketException from mocket.mockhttp import Entry, Response -from mocket.utils import MocketMode +from mocket.mode import MocketMode @mocketize(strict_mode=True) diff --git a/tests/test_redis.py b/tests/test_redis.py index 50b9beac..fb6ec355 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -158,9 +158,11 @@ def setUp(self): self.rclient = redis.StrictRedis() def mocketize_setup(self): + Entry.register_response("CLIENT SETINFO LIB-NAME redis-py", OK) + Entry.register_response(f"CLIENT SETINFO LIB-VER {redis.__version__}", OK) Entry.register_response("FLUSHDB", OK) self.rclient.flushdb() - self.assertEqual(len(Mocket.request_list()), 1) + self.assertEqual(len(Mocket.request_list()), 3) Mocket.reset() @mocketize diff --git a/tests/test_socket.py b/tests/test_socket.py index 8a6e65ad..112a9089 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -2,7 +2,7 @@ import pytest -from mocket.mocket import MocketSocket +from mocket.socket import MocketSocket @pytest.mark.parametrize("blocking", (False, True)) pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy