From 8a4cc723627625c8c67769cf1b21681beb5b3239 Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Sun, 20 Oct 2024 21:37:19 +0200 Subject: [PATCH 01/35] Update README.rst --- README.rst | 76 ++++++++++++++++++++++++------------------------------ 1 file changed, 34 insertions(+), 42 deletions(-) diff --git a/README.rst b/README.rst index 17a7801c..94b5b88c 100644 --- a/README.rst +++ b/README.rst @@ -284,52 +284,44 @@ Example: .. code-block:: python - class AioHttpEntryTestCase(TestCase): - @mocketize - def test_http_session(self): - url = 'http://httpbin.org/ip' - body = "asd" * 100 - Entry.single_register(Entry.GET, url, body=body, status=404) - Entry.single_register(Entry.POST, url, body=body*2, status=201) + # `aiohttp` creates SSLContext instances at import-time + # that's why Mocket would get stuck when dealing with HTTP + # Importing the module while Mocket is in control (inside a + # decorated test function or using its context manager would + # be enough for making it work), the alternative is using a + # custom TCPConnector which always return a FakeSSLContext + # from Mocket like this example is showing. + import aiohttp + import pytest - async def main(l): - async with aiohttp.ClientSession( - loop=l, timeout=aiohttp.ClientTimeout(total=3) - ) as session: - async with session.get(url) as get_response: - assert get_response.status == 404 - assert await get_response.text() == body + from mocket import async_mocketize + from mocket.mockhttp import Entry + from mocket.plugins.aiohttp_connector import MocketTCPConnector - async with session.post(url, data=body * 6) as post_response: - assert post_response.status == 201 - assert await post_response.text() == body * 2 - loop = asyncio.new_event_loop() - loop.run_until_complete(main(loop)) + @pytest.mark.asyncio + @async_mocketize + async def test_aiohttp(): + """ + The alternative to using the custom `connector` would be importing + `aiohttp` when Mocket is already in control (inside the decorated test). + """ + + url = "https://bar.foo/" + data = {"message": "Hello"} + + Entry.single_register( + Entry.GET, + url, + body=json.dumps(data), + headers={"content-type": "application/json"}, + ) - # or again with a unittest.IsolatedAsyncioTestCase - from mocket.async_mocket import async_mocketize - - class AioHttpEntryTestCase(IsolatedAsyncioTestCase): - @async_mocketize - async def test_http_session(self): - url = 'http://httpbin.org/ip' - body = "asd" * 100 - Entry.single_register(Entry.GET, url, body=body, status=404) - Entry.single_register(Entry.POST, url, body=body * 2, status=201) - - async with aiohttp.ClientSession( - timeout=aiohttp.ClientTimeout(total=3) - ) as session: - async with session.get(url) as get_response: - assert get_response.status == 404 - assert await get_response.text() == body - - async with session.post(url, data=body * 6) as post_response: - assert post_response.status == 201 - assert await post_response.text() == body * 2 - assert Mocket.last_request().method == 'POST' - assert Mocket.last_request().body == body * 6 + async with aiohttp.ClientSession( + timeout=aiohttp.ClientTimeout(total=3), connector=MocketTCPConnector() + ) as session, session.get(url) as response: + response = await response.json() + assert response == data Works well with others From 643083756377aea15d63b7f99cfe421583153a70 Mon Sep 17 00:00:00 2001 From: Wilhelm Klopp Date: Sun, 3 Nov 2024 09:08:59 +0000 Subject: [PATCH 02/35] Build pure python wheel (#260) --- Makefile | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index 62c52dc7..7ba0210e 100644 --- a/Makefile +++ b/Makefile @@ -39,8 +39,8 @@ safetest: export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; make test publish: clean install-test-requirements - uv run python3 -m build --sdist . - uv run twine upload --repository mocket dist/*.tar.gz + uv run python3 -m build --sdist --wheel . + uv run twine upload --repository mocket dist/ clean: rm -rf .coverage *.egg-info dist/ requirements.txt uv.lock || true From e6c0b9ef66927287af452cdd426f0bbbd453c69a Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Sun, 10 Nov 2024 09:16:17 +0100 Subject: [PATCH 03/35] Update main.yml --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c4b8987f..c4481efc 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -33,7 +33,7 @@ jobs: cache: 'pip' cache-dependency-path: | pyproject.toml - - uses: hoverkraft-tech/compose-action@v2.0.0 + - uses: hoverkraft-tech/compose-action@v2.0.2 with: compose-file: "./docker-compose.yml" down-flags: "--remove-orphans" From 3bf9686ce9e55a4d43a3e6c38789a46dd04b1769 Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Sun, 10 Nov 2024 21:47:00 +0100 Subject: [PATCH 04/35] Update README.rst --- README.rst | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.rst b/README.rst index 94b5b88c..e68cbfd3 100644 --- a/README.rst +++ b/README.rst @@ -24,7 +24,7 @@ A socket mock framework Outside GitHub ============== -Mocket packages are available for `Arch Linux`_, `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, and of course you can **pip install** it from `PyPI`_. +Mocket packages are available for `Arch Linux`_, `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, and of course from `PyPI`_. .. _`Arch Linux`: https://archlinux.org/packages/extra/any/python-mocket/ .. _`openSUSE`: https://software.opensuse.org/search?baseproject=ALL&q=mocket From 7224addd86a6c363c170520c12a705f3e5892329 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sat, 16 Nov 2024 19:49:50 +0100 Subject: [PATCH 05/35] refactor: make imports absolute --- mocket/__init__.py | 4 ++-- mocket/async_mocket.py | 4 ++-- mocket/mocket.py | 10 ++++++++-- mocket/mockhttp.py | 4 ++-- mocket/mockredis.py | 10 ++++++++-- mocket/utils.py | 6 +++--- 6 files changed, 25 insertions(+), 13 deletions(-) diff --git a/mocket/__init__.py b/mocket/__init__.py index fb0434e9..f72917a0 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,5 +1,5 @@ -from .async_mocket import async_mocketize -from .mocket import FakeSSLContext, Mocket, MocketEntry, Mocketizer, mocketize +from mocket.async_mocket import async_mocketize +from mocket.mocket import FakeSSLContext, Mocket, MocketEntry, Mocketizer, mocketize __all__ = ( "async_mocketize", diff --git a/mocket/async_mocket.py b/mocket/async_mocket.py index 2970e0f4..c0f77253 100644 --- a/mocket/async_mocket.py +++ b/mocket/async_mocket.py @@ -1,5 +1,5 @@ -from .mocket import Mocketizer -from .utils import get_mocketize +from mocket.mocket import Mocketizer +from mocket.utils import get_mocketize async def wrapper( diff --git a/mocket/mocket.py b/mocket/mocket.py index dcdab533..aa2e29ad 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -23,8 +23,14 @@ urllib3_wrap_socket = None -from .compat import basestring, byte_type, decode_from_bytes, encode_to_bytes, text_type -from .utils import ( +from mocket.compat import ( + basestring, + byte_type, + decode_from_bytes, + encode_to_bytes, + text_type, +) +from mocket.utils import ( MocketMode, MocketSocketCore, get_mocketize, diff --git a/mocket/mockhttp.py b/mocket/mockhttp.py index 5058328d..beb312c0 100644 --- a/mocket/mockhttp.py +++ b/mocket/mockhttp.py @@ -7,8 +7,8 @@ from h11 import SERVER, Connection, Data from h11 import Request as H11Request -from .compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes -from .mocket import Mocket, MocketEntry +from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes +from mocket.mocket import Mocket, MocketEntry STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} CRLF = "\r\n" diff --git a/mocket/mockredis.py b/mocket/mockredis.py index 1a0c51e2..6ae4ef39 100644 --- a/mocket/mockredis.py +++ b/mocket/mockredis.py @@ -1,7 +1,13 @@ from itertools import chain -from .compat import byte_type, decode_from_bytes, encode_to_bytes, shsplit, text_type -from .mocket import Mocket, MocketEntry +from mocket.compat import ( + byte_type, + decode_from_bytes, + encode_to_bytes, + shsplit, + text_type, +) +from mocket.mocket import Mocket, MocketEntry class Request: diff --git a/mocket/utils.py b/mocket/utils.py index 9efd6ad9..35cfcea8 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -6,8 +6,8 @@ import ssl from typing import TYPE_CHECKING, Any, Callable, ClassVar -from .compat import decode_from_bytes, encode_to_bytes -from .exceptions import StrictMocketException +from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.exceptions import StrictMocketException if TYPE_CHECKING: # pragma: no cover from typing import NoReturn @@ -83,7 +83,7 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool: @staticmethod def raise_not_allowed() -> NoReturn: - from .mocket import Mocket + from mocket.mocket import Mocket current_entries = [ (location, "\n ".join(map(str, entries))) From 1cf09ec3a64c7e9b1952ef9fb9f788177dcf24b0 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 13:41:15 +0100 Subject: [PATCH 06/35] refactor: remove old compat text_type, byte_type, basestring --- mocket/compat.py | 12 ++++-------- mocket/mocket.py | 26 +++++++++----------------- mocket/mockredis.py | 12 +++++------- mocket/plugins/httpretty/__init__.py | 6 +++--- 4 files changed, 21 insertions(+), 35 deletions(-) diff --git a/mocket/compat.py b/mocket/compat.py index 276ae0f0..1ac2fc89 100644 --- a/mocket/compat.py +++ b/mocket/compat.py @@ -9,21 +9,17 @@ ENCODING: Final[str] = os.getenv("MOCKET_ENCODING", "utf-8") -text_type = str -byte_type = bytes -basestring = (str,) - def encode_to_bytes(s: str | bytes, encoding: str = ENCODING) -> bytes: - if isinstance(s, text_type): + if isinstance(s, str): s = s.encode(encoding) - return byte_type(s) + return bytes(s) def decode_from_bytes(s: str | bytes, encoding: str = ENCODING) -> str: - if isinstance(s, byte_type): + if isinstance(s, bytes): s = codecs.decode(s, encoding, "ignore") - return text_type(s) + return str(s) def shsplit(s: str | bytes) -> list[str]: diff --git a/mocket/mocket.py b/mocket/mocket.py index aa2e29ad..81a42bfb 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -23,13 +23,7 @@ urllib3_wrap_socket = None -from mocket.compat import ( - basestring, - byte_type, - decode_from_bytes, - encode_to_bytes, - text_type, -) +from mocket.compat import decode_from_bytes, encode_to_bytes from mocket.utils import ( MocketMode, MocketSocketCore, @@ -323,7 +317,7 @@ def true_sendall(self, data, *args, **kwargs): # make request unique again req_signature = _hash_request(hasher, req) # port should be always a string - port = text_type(self._port) + port = str(self._port) # prepare responses dictionary responses = {} @@ -433,7 +427,7 @@ class Mocket: _address = (None, None) _entries = collections.defaultdict(list) _requests = [] - _namespace = text_type(id(_entries)) + _namespace = str(id(_entries)) _truesocket_recording_dir = None @classmethod @@ -524,7 +518,7 @@ def enable(namespace=None, truesocket_recording_dir=None): socket.socketpair = socket.__dict__["socketpair"] = socketpair ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: byte_type( + socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes( "\x7f\x00\x00\x01", "utf-8" ) urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( @@ -598,13 +592,13 @@ def assert_fail_if_entries_not_served(cls): class MocketEntry: - class Response(byte_type): + class Response(bytes): @property def data(self): return self response_index = 0 - request_cls = byte_type + request_cls = bytes response_cls = Response responses = None _served = None @@ -613,9 +607,7 @@ def __init__(self, location, responses): self._served = False self.location = location - if not isinstance(responses, collections_abc.Iterable) or isinstance( - responses, basestring - ): + if not isinstance(responses, collections_abc.Iterable): responses = [responses] if not responses: @@ -624,7 +616,7 @@ def __init__(self, location, responses): self.responses = [] for r in responses: if not isinstance(r, BaseException) and not getattr(r, "data", False): - if isinstance(r, text_type): + if isinstance(r, str): r = encode_to_bytes(r) r = self.response_cls(r) self.responses.append(r) @@ -664,7 +656,7 @@ def __init__( ): self.instance = instance self.truesocket_recording_dir = truesocket_recording_dir - self.namespace = namespace or text_type(id(self)) + self.namespace = namespace or str(id(self)) MocketMode().STRICT = strict_mode if strict_mode: MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] diff --git a/mocket/mockredis.py b/mocket/mockredis.py index 6ae4ef39..4ed69e1f 100644 --- a/mocket/mockredis.py +++ b/mocket/mockredis.py @@ -1,11 +1,9 @@ from itertools import chain from mocket.compat import ( - byte_type, decode_from_bytes, encode_to_bytes, shsplit, - text_type, ) from mocket.mocket import Mocket, MocketEntry @@ -20,7 +18,7 @@ def __init__(self, data=None): self.data = Redisizer.redisize(data or OK) -class Redisizer(byte_type): +class Redisizer(bytes): @staticmethod def tokens(iterable): iterable = [encode_to_bytes(x) for x in iterable] @@ -36,15 +34,15 @@ def get_conversion(t): Redisizer.tokens(list(chain(*tuple(x.items())))) ), int: lambda x: f":{x}".encode(), - text_type: lambda x: "${}\r\n{}".format( - len(x.encode("utf-8")), x - ).encode("utf-8"), + str: lambda x: "${}\r\n{}".format(len(x.encode("utf-8")), x).encode( + "utf-8" + ), list: lambda x: b"\r\n".join(Redisizer.tokens(x)), }[t] if isinstance(data, Redisizer): return data - if isinstance(data, byte_type): + if isinstance(data, bytes): data = decode_from_bytes(data) return Redisizer(get_conversion(data.__class__)(data) + b"\r\n") diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index 9d61ae2e..d5e41e30 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -1,6 +1,6 @@ from mocket import Mocket, mocketize from mocket.async_mocket import async_mocketize -from mocket.compat import ENCODING, byte_type, text_type +from mocket.compat import ENCODING from mocket.mockhttp import Entry as MocketHttpEntry from mocket.mockhttp import Request as MocketHttpRequest from mocket.mockhttp import Response as MocketHttpResponse @@ -129,6 +129,6 @@ def __getattr__(self, name): "HEAD", "PATCH", "register_uri", - "text_type", - "byte_type", + "str", + "bytes", ) From dccdd3bb30947c484ab9a0cda520a71186a9e453 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 15:36:11 +0100 Subject: [PATCH 07/35] refactor: move MocketMode from mocket.utils to mocket.mode --- mocket/mocket.py | 2 +- mocket/mode.py | 46 ++++++++++++++++++++++++++++++++++++++++++++++ mocket/utils.py | 46 +++------------------------------------------- 3 files changed, 50 insertions(+), 44 deletions(-) create mode 100644 mocket/mode.py diff --git a/mocket/mocket.py b/mocket/mocket.py index 81a42bfb..3918de14 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -24,8 +24,8 @@ from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.mode import MocketMode from mocket.utils import ( - MocketMode, MocketSocketCore, get_mocketize, hexdump, diff --git a/mocket/mode.py b/mocket/mode.py new file mode 100644 index 00000000..3c0638e5 --- /dev/null +++ b/mocket/mode.py @@ -0,0 +1,46 @@ +from __future__ import annotations + +from typing import TYPE_CHECKING, Any, ClassVar + +from mocket.exceptions import StrictMocketException + +if TYPE_CHECKING: # pragma: no cover + from typing import NoReturn + + +class MocketMode: + __shared_state: ClassVar[dict[str, Any]] = {} + STRICT: ClassVar = None + STRICT_ALLOWED: ClassVar = None + + def __init__(self) -> None: + self.__dict__ = self.__shared_state + + def is_allowed(self, location: str | tuple[str, int]) -> bool: + """ + Checks if (`host`, `port`) or at least `host` + are allowed locations to perform real `socket` calls + """ + if not self.STRICT: + return True + + host_allowed = False + if isinstance(location, tuple): + host_allowed = location[0] in self.STRICT_ALLOWED + return host_allowed or location in self.STRICT_ALLOWED + + @staticmethod + def raise_not_allowed() -> NoReturn: + from mocket.mocket import Mocket + + current_entries = [ + (location, "\n ".join(map(str, entries))) + for location, entries in Mocket._entries.items() + ] + formatted_entries = "\n".join( + [f" {location}:\n {entries}" for location, entries in current_entries] + ) + raise StrictMocketException( + "Mocket tried to use the real `socket` module while STRICT mode was active.\n" + f"Registered entries:\n{formatted_entries}" + ) diff --git a/mocket/utils.py b/mocket/utils.py index 35cfcea8..5f065bfa 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -4,14 +4,12 @@ import io import os import ssl -from typing import TYPE_CHECKING, Any, Callable, ClassVar +from typing import Callable from mocket.compat import decode_from_bytes, encode_to_bytes -from mocket.exceptions import StrictMocketException - -if TYPE_CHECKING: # pragma: no cover - from typing import NoReturn +# NOTE this is here for backwards-compat to keep old import-paths working +from mocket.mode import MocketMode as MocketMode SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 @@ -58,41 +56,3 @@ def get_mocketize(wrapper_: Callable) -> Callable: wrapper_, kwsyntax=True, ) - - -class MocketMode: - __shared_state: ClassVar[dict[str, Any]] = {} - STRICT: ClassVar = None - STRICT_ALLOWED: ClassVar = None - - def __init__(self) -> None: - self.__dict__ = self.__shared_state - - def is_allowed(self, location: str | tuple[str, int]) -> bool: - """ - Checks if (`host`, `port`) or at least `host` - are allowed locations to perform real `socket` calls - """ - if not self.STRICT: - return True - - host_allowed = False - if isinstance(location, tuple): - host_allowed = location[0] in self.STRICT_ALLOWED - return host_allowed or location in self.STRICT_ALLOWED - - @staticmethod - def raise_not_allowed() -> NoReturn: - from mocket.mocket import Mocket - - current_entries = [ - (location, "\n ".join(map(str, entries))) - for location, entries in Mocket._entries.items() - ] - formatted_entries = "\n".join( - [f" {location}:\n {entries}" for location, entries in current_entries] - ) - raise StrictMocketException( - "Mocket tried to use the real `socket` module while STRICT mode was active.\n" - f"Registered entries:\n{formatted_entries}" - ) From 2e9b640564f1db78caffad91b223a324d22bfd5d Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 15:46:13 +0100 Subject: [PATCH 08/35] refactor: move Mocketizer and mocketize from mocket.mocket to mocket.mocketizer --- mocket/__init__.py | 3 +- mocket/async_mocket.py | 2 +- mocket/mocket.py | 92 --------------------------------------- mocket/mocketizer.py | 97 ++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 100 insertions(+), 94 deletions(-) create mode 100644 mocket/mocketizer.py diff --git a/mocket/__init__.py b/mocket/__init__.py index f72917a0..b8f1e032 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,5 +1,6 @@ from mocket.async_mocket import async_mocketize -from mocket.mocket import FakeSSLContext, Mocket, MocketEntry, Mocketizer, mocketize +from mocket.mocket import FakeSSLContext, Mocket, MocketEntry +from mocket.mocketizer import Mocketizer, mocketize __all__ = ( "async_mocketize", diff --git a/mocket/async_mocket.py b/mocket/async_mocket.py index c0f77253..709d225f 100644 --- a/mocket/async_mocket.py +++ b/mocket/async_mocket.py @@ -1,4 +1,4 @@ -from mocket.mocket import Mocketizer +from mocket.mocketizer import Mocketizer from mocket.utils import get_mocketize diff --git a/mocket/mocket.py b/mocket/mocket.py index 3918de14..f420c27d 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -27,7 +27,6 @@ from mocket.mode import MocketMode from mocket.utils import ( MocketSocketCore, - get_mocketize, hexdump, hexload, ) @@ -643,94 +642,3 @@ def get_response(self): raise response return response.data - - -class Mocketizer: - def __init__( - self, - instance=None, - namespace=None, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - ): - self.instance = instance - self.truesocket_recording_dir = truesocket_recording_dir - self.namespace = namespace or str(id(self)) - MocketMode().STRICT = strict_mode - if strict_mode: - MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] - elif strict_mode_allowed: - raise ValueError( - "Allowed locations are only accepted when STRICT mode is active." - ) - - def enter(self): - Mocket.enable( - namespace=self.namespace, - truesocket_recording_dir=self.truesocket_recording_dir, - ) - if self.instance: - self.check_and_call("mocketize_setup") - - def __enter__(self): - self.enter() - return self - - def exit(self): - if self.instance: - self.check_and_call("mocketize_teardown") - Mocket.disable() - - def __exit__(self, type, value, tb): - self.exit() - - async def __aenter__(self, *args, **kwargs): - self.enter() - return self - - async def __aexit__(self, *args, **kwargs): - self.exit() - - def check_and_call(self, method_name): - method = getattr(self.instance, method_name, None) - if callable(method): - method() - - @staticmethod - def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): - instance = args[0] if args else None - namespace = None - if truesocket_recording_dir: - namespace = ".".join( - ( - instance.__class__.__module__, - instance.__class__.__name__, - test.__name__, - ) - ) - - return Mocketizer( - instance, - namespace=namespace, - truesocket_recording_dir=truesocket_recording_dir, - strict_mode=strict_mode, - strict_mode_allowed=strict_mode_allowed, - ) - - -def wrapper( - test, - truesocket_recording_dir=None, - strict_mode=False, - strict_mode_allowed=None, - *args, - **kwargs, -): - with Mocketizer.factory( - test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args - ): - return test(*args, **kwargs) - - -mocketize = get_mocketize(wrapper_=wrapper) diff --git a/mocket/mocketizer.py b/mocket/mocketizer.py new file mode 100644 index 00000000..5a988c77 --- /dev/null +++ b/mocket/mocketizer.py @@ -0,0 +1,97 @@ +from mocket.mode import MocketMode +from mocket.utils import get_mocketize + + +class Mocketizer: + def __init__( + self, + instance=None, + namespace=None, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + ): + self.instance = instance + self.truesocket_recording_dir = truesocket_recording_dir + self.namespace = namespace or str(id(self)) + MocketMode().STRICT = strict_mode + if strict_mode: + MocketMode().STRICT_ALLOWED = strict_mode_allowed or [] + elif strict_mode_allowed: + raise ValueError( + "Allowed locations are only accepted when STRICT mode is active." + ) + + def enter(self): + from mocket import Mocket + + Mocket.enable( + namespace=self.namespace, + truesocket_recording_dir=self.truesocket_recording_dir, + ) + if self.instance: + self.check_and_call("mocketize_setup") + + def __enter__(self): + self.enter() + return self + + def exit(self): + if self.instance: + self.check_and_call("mocketize_teardown") + from mocket import Mocket + + Mocket.disable() + + def __exit__(self, type, value, tb): + self.exit() + + async def __aenter__(self, *args, **kwargs): + self.enter() + return self + + async def __aexit__(self, *args, **kwargs): + self.exit() + + def check_and_call(self, method_name): + method = getattr(self.instance, method_name, None) + if callable(method): + method() + + @staticmethod + def factory(test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args): + instance = args[0] if args else None + namespace = None + if truesocket_recording_dir: + namespace = ".".join( + ( + instance.__class__.__module__, + instance.__class__.__name__, + test.__name__, + ) + ) + + return Mocketizer( + instance, + namespace=namespace, + truesocket_recording_dir=truesocket_recording_dir, + strict_mode=strict_mode, + strict_mode_allowed=strict_mode_allowed, + ) + + +def wrapper( + test, + truesocket_recording_dir=None, + strict_mode=False, + strict_mode_allowed=None, + *args, + **kwargs, +): + with Mocketizer.factory( + test, truesocket_recording_dir, strict_mode, strict_mode_allowed, args + ): + return test(*args, **kwargs) + + +mocketize = get_mocketize(wrapper_=wrapper) From 1df405cb0901197093e35d8d5e109e95459f3c8d Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 16:04:28 +0100 Subject: [PATCH 09/35] refactor: move MocketEntry from mocket.mocket to mocket.entry --- mocket/__init__.py | 3 ++- mocket/entry.py | 59 +++++++++++++++++++++++++++++++++++++++++++++ mocket/mocket.py | 55 ------------------------------------------ mocket/mockhttp.py | 3 ++- mocket/mockredis.py | 3 ++- 5 files changed, 65 insertions(+), 58 deletions(-) create mode 100644 mocket/entry.py diff --git a/mocket/__init__.py b/mocket/__init__.py index b8f1e032..30ec55a7 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,5 +1,6 @@ from mocket.async_mocket import async_mocketize -from mocket.mocket import FakeSSLContext, Mocket, MocketEntry +from mocket.entry import MocketEntry +from mocket.mocket import FakeSSLContext, Mocket from mocket.mocketizer import Mocketizer, mocketize __all__ = ( diff --git a/mocket/entry.py b/mocket/entry.py new file mode 100644 index 00000000..8fa28bc7 --- /dev/null +++ b/mocket/entry.py @@ -0,0 +1,59 @@ +import collections.abc + +from mocket.compat import encode_to_bytes + + +class MocketEntry: + class Response(bytes): + @property + def data(self): + return self + + response_index = 0 + request_cls = bytes + response_cls = Response + responses = None + _served = None + + def __init__(self, location, responses): + self._served = False + self.location = location + + if not isinstance(responses, collections.abc.Iterable): + responses = [responses] + + if not responses: + self.responses = [self.response_cls(encode_to_bytes(""))] + else: + self.responses = [] + for r in responses: + if not isinstance(r, BaseException) and not getattr(r, "data", False): + if isinstance(r, str): + r = encode_to_bytes(r) + r = self.response_cls(r) + self.responses.append(r) + + def __repr__(self): + return f"{self.__class__.__name__}(location={self.location})" + + @staticmethod + def can_handle(data): + return True + + def collect(self, data): + from mocket import Mocket + + req = self.request_cls(data) + Mocket.collect(req) + + def get_response(self): + response = self.responses[self.response_index] + if self.response_index < len(self.responses) - 1: + self.response_index += 1 + + self._served = True + + if isinstance(response, BaseException): + raise response + + return response.data diff --git a/mocket/mocket.py b/mocket/mocket.py index f420c27d..e9bb27e3 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,5 +1,4 @@ import collections -import collections.abc as collections_abc import contextlib import errno import hashlib @@ -588,57 +587,3 @@ def assert_fail_if_entries_not_served(cls): """Mocket checks that all entries have been served at least once.""" if not all(entry._served for entry in itertools.chain(*cls._entries.values())): raise AssertionError("Some Mocket entries have not been served") - - -class MocketEntry: - class Response(bytes): - @property - def data(self): - return self - - response_index = 0 - request_cls = bytes - response_cls = Response - responses = None - _served = None - - def __init__(self, location, responses): - self._served = False - self.location = location - - if not isinstance(responses, collections_abc.Iterable): - responses = [responses] - - if not responses: - self.responses = [self.response_cls(encode_to_bytes(""))] - else: - self.responses = [] - for r in responses: - if not isinstance(r, BaseException) and not getattr(r, "data", False): - if isinstance(r, str): - r = encode_to_bytes(r) - r = self.response_cls(r) - self.responses.append(r) - - def __repr__(self): - return f"{self.__class__.__name__}(location={self.location})" - - @staticmethod - def can_handle(data): - return True - - def collect(self, data): - req = self.request_cls(data) - Mocket.collect(req) - - def get_response(self): - response = self.responses[self.response_index] - if self.response_index < len(self.responses) - 1: - self.response_index += 1 - - self._served = True - - if isinstance(response, BaseException): - raise response - - return response.data diff --git a/mocket/mockhttp.py b/mocket/mockhttp.py index beb312c0..245a11af 100644 --- a/mocket/mockhttp.py +++ b/mocket/mockhttp.py @@ -8,7 +8,8 @@ from h11 import Request as H11Request from mocket.compat import ENCODING, decode_from_bytes, do_the_magic, encode_to_bytes -from mocket.mocket import Mocket, MocketEntry +from mocket.entry import MocketEntry +from mocket.mocket import Mocket STATUS = {k: v[0] for k, v in BaseHTTPRequestHandler.responses.items()} CRLF = "\r\n" diff --git a/mocket/mockredis.py b/mocket/mockredis.py index 4ed69e1f..fc386e2d 100644 --- a/mocket/mockredis.py +++ b/mocket/mockredis.py @@ -5,7 +5,8 @@ encode_to_bytes, shsplit, ) -from mocket.mocket import Mocket, MocketEntry +from mocket.entry import MocketEntry +from mocket.mocket import Mocket class Request: From 207778acd16368012eaced3b65eb4324477f5b1b Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 16:12:03 +0100 Subject: [PATCH 10/35] refactor: move SocketMocketCore from mocket.utils to mocket.io --- mocket/io.py | 17 +++++++++++++++++ mocket/mocket.py | 7 ++----- mocket/utils.py | 20 +++----------------- 3 files changed, 22 insertions(+), 22 deletions(-) create mode 100644 mocket/io.py diff --git a/mocket/io.py b/mocket/io.py new file mode 100644 index 00000000..45bb8272 --- /dev/null +++ b/mocket/io.py @@ -0,0 +1,17 @@ +import io +import os + + +class MocketSocketCore(io.BytesIO): + def __init__(self, address) -> None: + self._address = address + super().__init__() + + def write(self, content): + from mocket import Mocket + + super().write(content) + + _, w_fd = Mocket.get_pair(self._address) + if w_fd: + os.write(w_fd, content) diff --git a/mocket/mocket.py b/mocket/mocket.py index e9bb27e3..fb7ec8a0 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -23,12 +23,9 @@ from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.io import MocketSocketCore from mocket.mode import MocketMode -from mocket.utils import ( - MocketSocketCore, - hexdump, - hexload, -) +from mocket.utils import hexdump, hexload xxh32 = None try: diff --git a/mocket/utils.py b/mocket/utils.py index 5f065bfa..f94b78f7 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,34 +1,20 @@ from __future__ import annotations import binascii -import io -import os import ssl from typing import Callable from mocket.compat import decode_from_bytes, encode_to_bytes +# NOTE this is here for backwards-compat to keep old import-paths working +from mocket.io import MocketSocketCore as MocketSocketCore + # NOTE this is here for backwards-compat to keep old import-paths working from mocket.mode import MocketMode as MocketMode SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 -class MocketSocketCore(io.BytesIO): - def __init__(self, address) -> None: - self._address = address - super().__init__() - - def write(self, content): - from mocket import Mocket - - super().write(content) - - _, w_fd = Mocket.get_pair(self._address) - if w_fd: - os.write(w_fd, content) - - def hexdump(binary_string: bytes) -> str: r""" >>> hexdump(b"bar foobar foo") == decode_from_bytes(encode_to_bytes("62 61 72 20 66 6F 6F 62 61 72 20 66 6F 6F")) From 012df1387282d74b7bf286c9de7583e9b1ff8e25 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 16:15:17 +0100 Subject: [PATCH 11/35] refactor: move FakeSSLContext from mocket.mocket to mocket.ssl --- mocket/__init__.py | 3 ++- mocket/mocket.py | 57 +--------------------------------------------- mocket/ssl.py | 56 +++++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 59 insertions(+), 57 deletions(-) create mode 100644 mocket/ssl.py diff --git a/mocket/__init__.py b/mocket/__init__.py index 30ec55a7..d64cb11d 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,7 +1,8 @@ from mocket.async_mocket import async_mocketize from mocket.entry import MocketEntry -from mocket.mocket import FakeSSLContext, Mocket +from mocket.mocket import Mocket from mocket.mocketizer import Mocketizer, mocketize +from mocket.ssl import FakeSSLContext __all__ = ( "async_mocketize", diff --git a/mocket/mocket.py b/mocket/mocket.py index fb7ec8a0..8f791ea3 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -25,6 +25,7 @@ from mocket.compat import decode_from_bytes, encode_to_bytes from mocket.io import MocketSocketCore from mocket.mode import MocketMode +from mocket.ssl import FakeSSLContext from mocket.utils import hexdump, hexload xxh32 = None @@ -59,62 +60,6 @@ true_urllib3_match_hostname = urllib3_match_hostname -class SuperFakeSSLContext: - """For Python 3.6 and newer.""" - - class FakeSetter(int): - def __set__(self, *args): - pass - - minimum_version = FakeSetter() - options = FakeSetter() - verify_mode = FakeSetter() - verify_flags = FakeSetter() - - -class FakeSSLContext(SuperFakeSSLContext): - DUMMY_METHODS = ( - "load_default_certs", - "load_verify_locations", - "set_alpn_protocols", - "set_ciphers", - "set_default_verify_paths", - ) - sock = None - post_handshake_auth = None - _check_hostname = False - - @property - def check_hostname(self): - return self._check_hostname - - @check_hostname.setter - def check_hostname(self, _): - self._check_hostname = False - - def __init__(self, *args, **kwargs): - self._set_dummy_methods() - - def _set_dummy_methods(self): - def dummy_method(*args, **kwargs): - pass - - for m in self.DUMMY_METHODS: - setattr(self, m, dummy_method) - - @staticmethod - def wrap_socket(sock, *args, **kwargs): - sock.kwargs = kwargs - sock._secure_socket = True - return sock - - @staticmethod - def wrap_bio(incoming, outcoming, *args, **kwargs): - ssl_obj = MocketSocket() - ssl_obj._host = kwargs["server_hostname"] - return ssl_obj - - def create_connection(address, timeout=None, source_address=None): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: diff --git a/mocket/ssl.py b/mocket/ssl.py new file mode 100644 index 00000000..2e367f16 --- /dev/null +++ b/mocket/ssl.py @@ -0,0 +1,56 @@ +class SuperFakeSSLContext: + """For Python 3.6 and newer.""" + + class FakeSetter(int): + def __set__(self, *args): + pass + + minimum_version = FakeSetter() + options = FakeSetter() + verify_mode = FakeSetter() + verify_flags = FakeSetter() + + +class FakeSSLContext(SuperFakeSSLContext): + DUMMY_METHODS = ( + "load_default_certs", + "load_verify_locations", + "set_alpn_protocols", + "set_ciphers", + "set_default_verify_paths", + ) + sock = None + post_handshake_auth = None + _check_hostname = False + + @property + def check_hostname(self): + return self._check_hostname + + @check_hostname.setter + def check_hostname(self, _): + self._check_hostname = False + + def __init__(self, *args, **kwargs): + self._set_dummy_methods() + + def _set_dummy_methods(self): + def dummy_method(*args, **kwargs): + pass + + for m in self.DUMMY_METHODS: + setattr(self, m, dummy_method) + + @staticmethod + def wrap_socket(sock, *args, **kwargs): + sock.kwargs = kwargs + sock._secure_socket = True + return sock + + @staticmethod + def wrap_bio(incoming, outcoming, *args, **kwargs): + from mocket.mocket import MocketSocket + + ssl_obj = MocketSocket() + ssl_obj._host = kwargs["server_hostname"] + return ssl_obj From 14513def5d44e47c845618c5b4f39ba77fd1174a Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 16:25:15 +0100 Subject: [PATCH 12/35] refactor: move MocketSocket from mocket.mocket to mocket.socket --- mocket/mocket.py | 322 +------------------------------------------ mocket/socket.py | 346 +++++++++++++++++++++++++++++++++++++++++++++++ mocket/ssl.py | 2 +- 3 files changed, 348 insertions(+), 322 deletions(-) create mode 100644 mocket/socket.py diff --git a/mocket/mocket.py b/mocket/mocket.py index 8f791ea3..6bb0e566 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,15 +1,8 @@ import collections -import contextlib -import errno -import hashlib import itertools -import json import os -import select import socket import ssl -from datetime import datetime, timedelta -from json.decoder import JSONDecodeError from typing import Optional, Tuple import urllib3 @@ -22,19 +15,8 @@ urllib3_wrap_socket = None -from mocket.compat import decode_from_bytes, encode_to_bytes -from mocket.io import MocketSocketCore -from mocket.mode import MocketMode +from mocket.socket import MocketSocket, create_connection, socketpair from mocket.ssl import FakeSSLContext -from mocket.utils import hexdump, hexload - -xxh32 = None -try: - from xxhash import xxh32 -except ImportError: # pragma: no cover - with contextlib.suppress(ImportError): - from xxhash_cffi import xxh32 -hasher = xxh32 or hashlib.md5 try: # pragma: no cover from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 @@ -60,308 +42,6 @@ true_urllib3_match_hostname = urllib3_match_hostname -def create_connection(address, timeout=None, source_address=None): - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) - if timeout: - s.settimeout(timeout) - s.connect(address) - return s - - -def socketpair(*args, **kwargs): - """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" - import _socket - - return _socket.socketpair(*args, **kwargs) - - -def _hash_request(h, req): - return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() - - -class MocketSocket: - timeout = None - _fd = None - family = None - type = None - proto = None - _host = None - _port = None - _address = None - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION - _mode = None - _bufsize = None - _secure_socket = False - _did_handshake = False - _sent_non_empty_bytes = False - _io = None - - def __init__( - self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs - ): - self.true_socket = true_socket(family, type, proto) - self._buflen = 65536 - self._entry = None - self.family = int(family) - self.type = int(type) - self.proto = int(proto) - self._truesocket_recording_dir = None - self.kwargs = kwargs - - def __str__(self): - return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" - - def __enter__(self): - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.close() - - @property - def io(self): - if self._io is None: - self._io = MocketSocketCore((self._host, self._port)) - return self._io - - def fileno(self): - address = (self._host, self._port) - r_fd, _ = Mocket.get_pair(address) - if not r_fd: - r_fd, w_fd = os.pipe() - Mocket.set_pair(address, (r_fd, w_fd)) - return r_fd - - def gettimeout(self): - return self.timeout - - def setsockopt(self, family, type, proto): - self.family = family - self.type = type - self.proto = proto - - if self.true_socket: - self.true_socket.setsockopt(family, type, proto) - - def settimeout(self, timeout): - self.timeout = timeout - - @staticmethod - def getsockopt(level, optname, buflen=None): - return socket.SOCK_STREAM - - def do_handshake(self): - self._did_handshake = True - - def getpeername(self): - return self._address - - def setblocking(self, block): - self.settimeout(None) if block else self.settimeout(0.0) - - def getblocking(self): - return self.gettimeout() is None - - def getsockname(self): - return socket.gethostbyname(self._address[0]), self._address[1] - - def getpeercert(self, *args, **kwargs): - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", f"*.{self._host}"), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", f"*.{self._host}"),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", f"*.{self._host}"),), - ), - } - - def unwrap(self): - return self - - def write(self, data): - return self.send(encode_to_bytes(data)) - - def connect(self, address): - self._address = self._host, self._port = address - Mocket._address = address - - def makefile(self, mode="r", bufsize=-1): - self._mode = mode - self._bufsize = bufsize - return self.io - - def get_entry(self, data): - return Mocket.get_entry(self._host, self._port, data) - - def sendall(self, data, entry=None, *args, **kwargs): - if entry is None: - entry = self.get_entry(data) - - if entry: - consume_response = entry.collect(data) - response = entry.get_response() if consume_response is not False else None - else: - response = self.true_sendall(data, *args, **kwargs) - - if response is not None: - self.io.seek(0) - self.io.write(response) - self.io.truncate() - self.io.seek(0) - - def read(self, buffersize): - rv = self.io.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv - - def recv_into(self, buffer, buffersize=None, flags=None): - if hasattr(buffer, "write"): - return buffer.write(self.read(buffersize)) - # buffer is a memoryview - data = self.read(buffersize) - if data: - buffer[: len(data)] = data - return len(data) - - def recv(self, buffersize, flags=None): - r_fd, _ = Mocket.get_pair((self._host, self._port)) - if r_fd: - return os.read(r_fd, buffersize) - data = self.read(buffersize) - if data: - return data - # used by Redis mock - exc = BlockingIOError() - exc.errno = errno.EWOULDBLOCK - exc.args = (0,) - raise exc - - def true_sendall(self, data, *args, **kwargs): - if not MocketMode().is_allowed((self._host, self._port)): - MocketMode.raise_not_allowed() - - req = decode_from_bytes(data) - # make request unique again - req_signature = _hash_request(hasher, req) - # port should be always a string - port = str(self._port) - - # prepare responses dictionary - responses = {} - - if Mocket.get_truesocket_recording_dir(): - path = os.path.join( - Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" - ) - # check if there's already a recorded session dumped to a JSON file - try: - with open(path) as f: - responses = json.load(f) - # if not, create a new dictionary - except (FileNotFoundError, JSONDecodeError): - pass - - try: - try: - response_dict = responses[self._host][port][req_signature] - except KeyError: - if hasher is not hashlib.md5: - # Fallback for backwards compatibility - req_signature = _hash_request(hashlib.md5, req) - response_dict = responses[self._host][port][req_signature] - else: - raise - except KeyError: - # preventing next KeyError exceptions - responses.setdefault(self._host, {}) - responses[self._host].setdefault(port, {}) - responses[self._host][port].setdefault(req_signature, {}) - response_dict = responses[self._host][port][req_signature] - - # try to get the response from the dictionary - try: - encoded_response = hexload(response_dict["response"]) - # if not available, call the real sendall - except KeyError: - host, port = self._host, self._port - host = true_gethostbyname(host) - - if isinstance(self.true_socket, true_socket) and self._secure_socket: - self.true_socket = true_urllib3_ssl_wrap_socket( - self.true_socket, - **self.kwargs, - ) - - with contextlib.suppress(OSError, ValueError): - # already connected - self.true_socket.connect((host, port)) - self.true_socket.sendall(data, *args, **kwargs) - encoded_response = b"" - # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 - while True: - more_to_read = select.select([self.true_socket], [], [], 0.1)[0] - if not more_to_read and encoded_response: - break - new_content = self.true_socket.recv(self._buflen) - if not new_content: - break - encoded_response += new_content - - # dump the resulting dictionary to a JSON file - if Mocket.get_truesocket_recording_dir(): - # update the dictionary with request and response lines - response_dict["request"] = req - response_dict["response"] = hexdump(encoded_response) - - with open(path, mode="w") as f: - f.write( - decode_from_bytes( - json.dumps(responses, indent=4, sort_keys=True) - ) - ) - - # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO - return encoded_response - - def send(self, data, *args, **kwargs): # pragma: no cover - entry = self.get_entry(data) - if not entry or (entry and self._entry != entry): - kwargs["entry"] = entry - self.sendall(data, *args, **kwargs) - else: - req = Mocket.last_request() - if hasattr(req, "add_data"): - req.add_data(data) - self._entry = entry - return len(data) - - def close(self): - if self.true_socket and not self.true_socket._closed: - self.true_socket.close() - self._fd = None - - def __getattr__(self, name): - """Do nothing catchall function, for methods like shutdown()""" - - def do_nothing(*args, **kwargs): - pass - - return do_nothing - - class Mocket: _socket_pairs = {} _address = (None, None) diff --git a/mocket/socket.py b/mocket/socket.py new file mode 100644 index 00000000..3a971af5 --- /dev/null +++ b/mocket/socket.py @@ -0,0 +1,346 @@ +import contextlib +import errno +import hashlib +import json +import os +import select +import socket +import ssl +from datetime import datetime, timedelta +from json.decoder import JSONDecodeError + +from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.io import MocketSocketCore +from mocket.mode import MocketMode +from mocket.utils import hexdump, hexload + +xxh32 = None +try: + from xxhash import xxh32 +except ImportError: # pragma: no cover + with contextlib.suppress(ImportError): + from xxhash_cffi import xxh32 +hasher = xxh32 or hashlib.md5 + + +def create_connection(address, timeout=None, source_address=None): + s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) + if timeout: + s.settimeout(timeout) + s.connect(address) + return s + + +def socketpair(*args, **kwargs): + """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" + import _socket + + return _socket.socketpair(*args, **kwargs) + + +def _hash_request(h, req): + return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() + + +class MocketSocket: + timeout = None + _fd = None + family = None + type = None + proto = None + _host = None + _port = None + _address = None + cipher = lambda s: ("ADH", "AES256", "SHA") + compression = lambda s: ssl.OP_NO_COMPRESSION + _mode = None + _bufsize = None + _secure_socket = False + _did_handshake = False + _sent_non_empty_bytes = False + _io = None + + def __init__( + self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs + ): + from mocket.mocket import true_socket + + self.true_socket = true_socket(family, type, proto) + self._buflen = 65536 + self._entry = None + self.family = int(family) + self.type = int(type) + self.proto = int(proto) + self._truesocket_recording_dir = None + self.kwargs = kwargs + + def __str__(self): + return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.close() + + @property + def io(self): + if self._io is None: + self._io = MocketSocketCore((self._host, self._port)) + return self._io + + def fileno(self): + from mocket.mocket import Mocket + + address = (self._host, self._port) + r_fd, _ = Mocket.get_pair(address) + if not r_fd: + r_fd, w_fd = os.pipe() + Mocket.set_pair(address, (r_fd, w_fd)) + return r_fd + + def gettimeout(self): + return self.timeout + + def setsockopt(self, family, type, proto): + self.family = family + self.type = type + self.proto = proto + + if self.true_socket: + self.true_socket.setsockopt(family, type, proto) + + def settimeout(self, timeout): + self.timeout = timeout + + @staticmethod + def getsockopt(level, optname, buflen=None): + return socket.SOCK_STREAM + + def do_handshake(self): + self._did_handshake = True + + def getpeername(self): + return self._address + + def setblocking(self, block): + self.settimeout(None) if block else self.settimeout(0.0) + + def getblocking(self): + return self.gettimeout() is None + + def getsockname(self): + return socket.gethostbyname(self._address[0]), self._address[1] + + def getpeercert(self, *args, **kwargs): + from mocket.mocket import Mocket + + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", f"*.{self._host}"), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", f"*.{self._host}"),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", f"*.{self._host}"),), + ), + } + + def unwrap(self): + return self + + def write(self, data): + return self.send(encode_to_bytes(data)) + + def connect(self, address): + from mocket.mocket import Mocket + + self._address = self._host, self._port = address + Mocket._address = address + + def makefile(self, mode="r", bufsize=-1): + self._mode = mode + self._bufsize = bufsize + return self.io + + def get_entry(self, data): + from mocket.mocket import Mocket + + return Mocket.get_entry(self._host, self._port, data) + + def sendall(self, data, entry=None, *args, **kwargs): + if entry is None: + entry = self.get_entry(data) + + if entry: + consume_response = entry.collect(data) + response = entry.get_response() if consume_response is not False else None + else: + response = self.true_sendall(data, *args, **kwargs) + + if response is not None: + self.io.seek(0) + self.io.write(response) + self.io.truncate() + self.io.seek(0) + + def read(self, buffersize): + rv = self.io.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def recv_into(self, buffer, buffersize=None, flags=None): + if hasattr(buffer, "write"): + return buffer.write(self.read(buffersize)) + # buffer is a memoryview + data = self.read(buffersize) + if data: + buffer[: len(data)] = data + return len(data) + + def recv(self, buffersize, flags=None): + from mocket.mocket import Mocket + + r_fd, _ = Mocket.get_pair((self._host, self._port)) + if r_fd: + return os.read(r_fd, buffersize) + data = self.read(buffersize) + if data: + return data + # used by Redis mock + exc = BlockingIOError() + exc.errno = errno.EWOULDBLOCK + exc.args = (0,) + raise exc + + def true_sendall(self, data, *args, **kwargs): + from mocket.mocket import ( + Mocket, + true_gethostbyname, + true_socket, + true_urllib3_ssl_wrap_socket, + ) + + if not MocketMode().is_allowed((self._host, self._port)): + MocketMode.raise_not_allowed() + + req = decode_from_bytes(data) + # make request unique again + req_signature = _hash_request(hasher, req) + # port should be always a string + port = str(self._port) + + # prepare responses dictionary + responses = {} + + if Mocket.get_truesocket_recording_dir(): + path = os.path.join( + Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" + ) + # check if there's already a recorded session dumped to a JSON file + try: + with open(path) as f: + responses = json.load(f) + # if not, create a new dictionary + except (FileNotFoundError, JSONDecodeError): + pass + + try: + try: + response_dict = responses[self._host][port][req_signature] + except KeyError: + if hasher is not hashlib.md5: + # Fallback for backwards compatibility + req_signature = _hash_request(hashlib.md5, req) + response_dict = responses[self._host][port][req_signature] + else: + raise + except KeyError: + # preventing next KeyError exceptions + responses.setdefault(self._host, {}) + responses[self._host].setdefault(port, {}) + responses[self._host][port].setdefault(req_signature, {}) + response_dict = responses[self._host][port][req_signature] + + # try to get the response from the dictionary + try: + encoded_response = hexload(response_dict["response"]) + # if not available, call the real sendall + except KeyError: + host, port = self._host, self._port + host = true_gethostbyname(host) + + if isinstance(self.true_socket, true_socket) and self._secure_socket: + self.true_socket = true_urllib3_ssl_wrap_socket( + self.true_socket, + **self.kwargs, + ) + + with contextlib.suppress(OSError, ValueError): + # already connected + self.true_socket.connect((host, port)) + self.true_socket.sendall(data, *args, **kwargs) + encoded_response = b"" + # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 + while True: + more_to_read = select.select([self.true_socket], [], [], 0.1)[0] + if not more_to_read and encoded_response: + break + new_content = self.true_socket.recv(self._buflen) + if not new_content: + break + encoded_response += new_content + + # dump the resulting dictionary to a JSON file + if Mocket.get_truesocket_recording_dir(): + # update the dictionary with request and response lines + response_dict["request"] = req + response_dict["response"] = hexdump(encoded_response) + + with open(path, mode="w") as f: + f.write( + decode_from_bytes( + json.dumps(responses, indent=4, sort_keys=True) + ) + ) + + # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO + return encoded_response + + def send(self, data, *args, **kwargs): # pragma: no cover + from mocket.mocket import Mocket + + entry = self.get_entry(data) + if not entry or (entry and self._entry != entry): + kwargs["entry"] = entry + self.sendall(data, *args, **kwargs) + else: + req = Mocket.last_request() + if hasattr(req, "add_data"): + req.add_data(data) + self._entry = entry + return len(data) + + def close(self): + if self.true_socket and not self.true_socket._closed: + self.true_socket.close() + self._fd = None + + def __getattr__(self, name): + """Do nothing catchall function, for methods like shutdown()""" + + def do_nothing(*args, **kwargs): + pass + + return do_nothing diff --git a/mocket/ssl.py b/mocket/ssl.py index 2e367f16..e4ae44cf 100644 --- a/mocket/ssl.py +++ b/mocket/ssl.py @@ -49,7 +49,7 @@ def wrap_socket(sock, *args, **kwargs): @staticmethod def wrap_bio(incoming, outcoming, *args, **kwargs): - from mocket.mocket import MocketSocket + from mocket.socket import MocketSocket ssl_obj = MocketSocket() ssl_obj._host = kwargs["server_hostname"] From 89055e8a54b41a0ff8d21dcd2d1774e1e95f8667 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 20:34:04 +0100 Subject: [PATCH 13/35] Refactor: introduce state object (#264) * refactor: move enable- and disable-functions from mocket.mocket to mocket.inject * refactor: Mocket - add typing and get rid of cyclic import --- mocket/entry.py | 3 +- mocket/inject.py | 128 +++++++++++++++++++ mocket/io.py | 4 +- mocket/mocket.py | 182 ++++++--------------------- mocket/mocketizer.py | 4 +- mocket/mode.py | 3 +- mocket/plugins/httpretty/__init__.py | 3 +- mocket/socket.py | 30 ++--- mocket/types.py | 5 + tests/test_socket.py | 2 +- 10 files changed, 186 insertions(+), 178 deletions(-) create mode 100644 mocket/inject.py create mode 100644 mocket/types.py diff --git a/mocket/entry.py b/mocket/entry.py index 8fa28bc7..9dbbf442 100644 --- a/mocket/entry.py +++ b/mocket/entry.py @@ -1,6 +1,7 @@ import collections.abc from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket class MocketEntry: @@ -41,8 +42,6 @@ def can_handle(data): return True def collect(self, data): - from mocket import Mocket - req = self.request_cls(data) Mocket.collect(req) diff --git a/mocket/inject.py b/mocket/inject.py new file mode 100644 index 00000000..cba0b40b --- /dev/null +++ b/mocket/inject.py @@ -0,0 +1,128 @@ +from __future__ import annotations + +import os +import socket +import ssl + +import urllib3 +from urllib3.connection import match_hostname as urllib3_match_hostname +from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket + +try: + from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket +except ImportError: + urllib3_wrap_socket = None + + +try: # pragma: no cover + from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 + + pyopenssl_override = True +except ImportError: + pyopenssl_override = False + +true_socket = socket.socket +true_create_connection = socket.create_connection +true_gethostbyname = socket.gethostbyname +true_gethostname = socket.gethostname +true_getaddrinfo = socket.getaddrinfo +true_socketpair = socket.socketpair +true_ssl_wrap_socket = getattr( + ssl, "wrap_socket", None +) # from Py3.12 it's only under SSLContext +true_ssl_socket = ssl.SSLSocket +true_ssl_context = ssl.SSLContext +true_inet_pton = socket.inet_pton +true_urllib3_wrap_socket = urllib3_wrap_socket +true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket +true_urllib3_match_hostname = urllib3_match_hostname + + +def enable( + namespace: str | None = None, + truesocket_recording_dir: str | None = None, +) -> None: + from mocket.mocket import Mocket + from mocket.socket import MocketSocket, create_connection, socketpair + from mocket.ssl import FakeSSLContext + + Mocket._namespace = namespace + Mocket._truesocket_recording_dir = truesocket_recording_dir + + if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): + # JSON dumps will be saved here + raise AssertionError + + socket.socket = socket.__dict__["socket"] = MocketSocket + socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket + socket.SocketType = socket.__dict__["SocketType"] = MocketSocket + socket.create_connection = socket.__dict__["create_connection"] = create_connection + socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" + socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1" + socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( + lambda host, port, family=None, socktype=None, proto=None, flags=None: [ + (2, 1, 6, "", (host, port)) + ] + ) + socket.socketpair = socket.__dict__["socketpair"] = socketpair + ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket + ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext + socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes( + "\x7f\x00\x00\x01", "utf-8" + ) + urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( + FakeSSLContext.wrap_socket + ) + urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ + "ssl_wrap_socket" + ] = FakeSSLContext.wrap_socket + urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( + FakeSSLContext.wrap_socket + ) + urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ + "ssl_wrap_socket" + ] = FakeSSLContext.wrap_socket + urllib3.connection.match_hostname = urllib3.connection.__dict__[ + "match_hostname" + ] = lambda *args: None + if pyopenssl_override: # pragma: no cover + # Take out the pyopenssl version - use the default implementation + extract_from_urllib3() + + +def disable() -> None: + from mocket.mocket import Mocket + + socket.socket = socket.__dict__["socket"] = true_socket + socket._socketobject = socket.__dict__["_socketobject"] = true_socket + socket.SocketType = socket.__dict__["SocketType"] = true_socket + socket.create_connection = socket.__dict__["create_connection"] = ( + true_create_connection + ) + socket.gethostname = socket.__dict__["gethostname"] = true_gethostname + socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname + socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo + socket.socketpair = socket.__dict__["socketpair"] = true_socketpair + if true_ssl_wrap_socket: + ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket + ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context + socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton + urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( + true_urllib3_wrap_socket + ) + urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ + "ssl_wrap_socket" + ] = true_urllib3_ssl_wrap_socket + urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( + true_urllib3_ssl_wrap_socket + ) + urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ + "ssl_wrap_socket" + ] = true_urllib3_ssl_wrap_socket + urllib3.connection.match_hostname = urllib3.connection.__dict__[ + "match_hostname" + ] = true_urllib3_match_hostname + Mocket.reset() + if pyopenssl_override: # pragma: no cover + # Put the pyopenssl version back in place + inject_into_urllib3() diff --git a/mocket/io.py b/mocket/io.py index 45bb8272..648b16dd 100644 --- a/mocket/io.py +++ b/mocket/io.py @@ -1,6 +1,8 @@ import io import os +from mocket.mocket import Mocket + class MocketSocketCore(io.BytesIO): def __init__(self, address) -> None: @@ -8,8 +10,6 @@ def __init__(self, address) -> None: super().__init__() def write(self, content): - from mocket import Mocket - super().write(content) _, w_fd = Mocket.get_pair(self._address) diff --git a/mocket/mocket.py b/mocket/mocket.py index 6bb0e566..3476902d 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -1,57 +1,33 @@ +from __future__ import annotations + import collections import itertools import os -import socket -import ssl -from typing import Optional, Tuple - -import urllib3 -from urllib3.connection import match_hostname as urllib3_match_hostname -from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket - -try: - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket -except ImportError: - urllib3_wrap_socket = None - - -from mocket.socket import MocketSocket, create_connection, socketpair -from mocket.ssl import FakeSSLContext - -try: # pragma: no cover - from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 - - pyopenssl_override = True -except ImportError: - pyopenssl_override = False - -true_socket = socket.socket -true_create_connection = socket.create_connection -true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_getaddrinfo = socket.getaddrinfo -true_socketpair = socket.socketpair -true_ssl_wrap_socket = getattr( - ssl, "wrap_socket", None -) # from Py3.12 it's only under SSLContext -true_ssl_socket = ssl.SSLSocket -true_ssl_context = ssl.SSLContext -true_inet_pton = socket.inet_pton -true_urllib3_wrap_socket = urllib3_wrap_socket -true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket -true_urllib3_match_hostname = urllib3_match_hostname +from typing import TYPE_CHECKING, ClassVar + +import mocket.inject + +# NOTE this is here for backwards-compat to keep old import-paths working +# from mocket.socket import MocketSocket as MocketSocket + +if TYPE_CHECKING: + from mocket.entry import MocketEntry + from mocket.types import Address class Mocket: - _socket_pairs = {} - _address = (None, None) - _entries = collections.defaultdict(list) - _requests = [] - _namespace = str(id(_entries)) - _truesocket_recording_dir = None + _socket_pairs: ClassVar[dict[Address, tuple[int, int]]] = {} + _address: ClassVar[Address] = (None, None) + _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) + _requests: ClassVar[list] = [] + _namespace: ClassVar[str] = str(id(_entries)) + _truesocket_recording_dir: ClassVar[str | None] = None + + enable = mocket.inject.enable + disable = mocket.inject.disable @classmethod - def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: + def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: """ Given the id() of the caller, return a pair of file descriptors as a tuple of two integers: (, ) @@ -59,7 +35,7 @@ def get_pair(cls, address: tuple) -> Tuple[Optional[int], Optional[int]]: return cls._socket_pairs.get(address, (None, None)) @classmethod - def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: + def set_pair(cls, address: Address, pair: tuple[int, int]) -> None: """ Store a pair of file descriptors under the key `id_` as a tuple of two integers: (, ) @@ -67,25 +43,26 @@ def set_pair(cls, address: tuple, pair: Tuple[int, int]) -> None: cls._socket_pairs[address] = pair @classmethod - def register(cls, *entries): + def register(cls, *entries: MocketEntry) -> None: for entry in entries: cls._entries[entry.location].append(entry) @classmethod - def get_entry(cls, host, port, data): - host = host or Mocket._address[0] - port = port or Mocket._address[1] + def get_entry(cls, host: str, port: int, data) -> MocketEntry | None: + host = host or cls._address[0] + port = port or cls._address[1] entries = cls._entries.get((host, port), []) for entry in entries: if entry.can_handle(data): return entry + return None @classmethod - def collect(cls, data): - cls.request_list().append(data) + def collect(cls, data) -> None: + cls._requests.append(data) @classmethod - def reset(cls): + def reset(cls) -> None: for r_fd, w_fd in cls._socket_pairs.values(): os.close(r_fd) os.close(w_fd) @@ -96,116 +73,31 @@ def reset(cls): @classmethod def last_request(cls): if cls.has_requests(): - return cls.request_list()[-1] + return cls._requests[-1] @classmethod def request_list(cls): return cls._requests @classmethod - def remove_last_request(cls): + def remove_last_request(cls) -> None: if cls.has_requests(): del cls._requests[-1] @classmethod - def has_requests(cls): + def has_requests(cls) -> bool: return bool(cls.request_list()) - @staticmethod - def enable(namespace=None, truesocket_recording_dir=None): - Mocket._namespace = namespace - Mocket._truesocket_recording_dir = truesocket_recording_dir - - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): - # JSON dumps will be saved here - raise AssertionError - - socket.socket = socket.__dict__["socket"] = MocketSocket - socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket - socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = ( - create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" - socket.gethostbyname = socket.__dict__["gethostbyname"] = ( - lambda host: "127.0.0.1" - ) - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( - lambda host, port, family=None, socktype=None, proto=None, flags=None: [ - (2, 1, 6, "", (host, port)) - ] - ) - socket.socketpair = socket.__dict__["socketpair"] = socketpair - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes( - "\x7f\x00\x00\x01", "utf-8" - ) - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - FakeSSLContext.wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - FakeSSLContext.wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = lambda *args: None - if pyopenssl_override: # pragma: no cover - # Take out the pyopenssl version - use the default implementation - extract_from_urllib3() - - @staticmethod - def disable(): - socket.socket = socket.__dict__["socket"] = true_socket - socket._socketobject = socket.__dict__["_socketobject"] = true_socket - socket.SocketType = socket.__dict__["SocketType"] = true_socket - socket.create_connection = socket.__dict__["create_connection"] = ( - true_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = true_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = true_socketpair - if true_ssl_wrap_socket: - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context - socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - true_urllib3_wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - true_urllib3_ssl_wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = true_urllib3_match_hostname - Mocket.reset() - if pyopenssl_override: # pragma: no cover - # Put the pyopenssl version back in place - inject_into_urllib3() - @classmethod - def get_namespace(cls): + def get_namespace(cls) -> str: return cls._namespace @classmethod - def get_truesocket_recording_dir(cls): + def get_truesocket_recording_dir(cls) -> str | None: return cls._truesocket_recording_dir @classmethod - def assert_fail_if_entries_not_served(cls): + def assert_fail_if_entries_not_served(cls) -> None: """Mocket checks that all entries have been served at least once.""" if not all(entry._served for entry in itertools.chain(*cls._entries.values())): raise AssertionError("Some Mocket entries have not been served") diff --git a/mocket/mocketizer.py b/mocket/mocketizer.py index 5a988c77..2bf2b9cd 100644 --- a/mocket/mocketizer.py +++ b/mocket/mocketizer.py @@ -1,3 +1,4 @@ +from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import get_mocketize @@ -23,8 +24,6 @@ def __init__( ) def enter(self): - from mocket import Mocket - Mocket.enable( namespace=self.namespace, truesocket_recording_dir=self.truesocket_recording_dir, @@ -39,7 +38,6 @@ def __enter__(self): def exit(self): if self.instance: self.check_and_call("mocketize_teardown") - from mocket import Mocket Mocket.disable() diff --git a/mocket/mode.py b/mocket/mode.py index 3c0638e5..e1da7955 100644 --- a/mocket/mode.py +++ b/mocket/mode.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, Any, ClassVar from mocket.exceptions import StrictMocketException +from mocket.mocket import Mocket if TYPE_CHECKING: # pragma: no cover from typing import NoReturn @@ -31,8 +32,6 @@ def is_allowed(self, location: str | tuple[str, int]) -> bool: @staticmethod def raise_not_allowed() -> NoReturn: - from mocket.mocket import Mocket - current_entries = [ (location, "\n ".join(map(str, entries))) for location, entries in Mocket._entries.items() diff --git a/mocket/plugins/httpretty/__init__.py b/mocket/plugins/httpretty/__init__.py index d5e41e30..fac61840 100644 --- a/mocket/plugins/httpretty/__init__.py +++ b/mocket/plugins/httpretty/__init__.py @@ -1,6 +1,7 @@ -from mocket import Mocket, mocketize +from mocket import mocketize from mocket.async_mocket import async_mocketize from mocket.compat import ENCODING +from mocket.mocket import Mocket from mocket.mockhttp import Entry as MocketHttpEntry from mocket.mockhttp import Request as MocketHttpRequest from mocket.mockhttp import Response as MocketHttpResponse diff --git a/mocket/socket.py b/mocket/socket.py index 3a971af5..e4be00b6 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -10,7 +10,13 @@ from json.decoder import JSONDecodeError from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.inject import ( + true_gethostbyname, + true_socket, + true_urllib3_ssl_wrap_socket, +) from mocket.io import MocketSocketCore +from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import hexdump, hexload @@ -63,8 +69,6 @@ class MocketSocket: def __init__( self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs ): - from mocket.mocket import true_socket - self.true_socket = true_socket(family, type, proto) self._buflen = 65536 self._entry = None @@ -90,8 +94,6 @@ def io(self): return self._io def fileno(self): - from mocket.mocket import Mocket - address = (self._host, self._port) r_fd, _ = Mocket.get_pair(address) if not r_fd: @@ -133,8 +135,6 @@ def getsockname(self): return socket.gethostbyname(self._address[0]), self._address[1] def getpeercert(self, *args, **kwargs): - from mocket.mocket import Mocket - if not (self._host and self._port): self._address = self._host, self._port = Mocket._address @@ -161,8 +161,6 @@ def write(self, data): return self.send(encode_to_bytes(data)) def connect(self, address): - from mocket.mocket import Mocket - self._address = self._host, self._port = address Mocket._address = address @@ -172,8 +170,6 @@ def makefile(self, mode="r", bufsize=-1): return self.io def get_entry(self, data): - from mocket.mocket import Mocket - return Mocket.get_entry(self._host, self._port, data) def sendall(self, data, entry=None, *args, **kwargs): @@ -210,8 +206,6 @@ def recv_into(self, buffer, buffersize=None, flags=None): return len(data) def recv(self, buffersize, flags=None): - from mocket.mocket import Mocket - r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) @@ -225,13 +219,6 @@ def recv(self, buffersize, flags=None): raise exc def true_sendall(self, data, *args, **kwargs): - from mocket.mocket import ( - Mocket, - true_gethostbyname, - true_socket, - true_urllib3_ssl_wrap_socket, - ) - if not MocketMode().is_allowed((self._host, self._port)): MocketMode.raise_not_allowed() @@ -246,7 +233,8 @@ def true_sendall(self, data, *args, **kwargs): if Mocket.get_truesocket_recording_dir(): path = os.path.join( - Mocket.get_truesocket_recording_dir(), Mocket.get_namespace() + ".json" + Mocket.get_truesocket_recording_dir(), + Mocket.get_namespace() + ".json", ) # check if there's already a recorded session dumped to a JSON file try: @@ -319,8 +307,6 @@ def true_sendall(self, data, *args, **kwargs): return encoded_response def send(self, data, *args, **kwargs): # pragma: no cover - from mocket.mocket import Mocket - entry = self.get_entry(data) if not entry or (entry and self._entry != entry): kwargs["entry"] = entry diff --git a/mocket/types.py b/mocket/types.py new file mode 100644 index 00000000..61b7a4d5 --- /dev/null +++ b/mocket/types.py @@ -0,0 +1,5 @@ +from __future__ import annotations + +from typing import Tuple + +Address = Tuple[str, int] diff --git a/tests/test_socket.py b/tests/test_socket.py index 8a6e65ad..112a9089 100644 --- a/tests/test_socket.py +++ b/tests/test_socket.py @@ -2,7 +2,7 @@ import pytest -from mocket.mocket import MocketSocket +from mocket.socket import MocketSocket @pytest.mark.parametrize("blocking", (False, True)) From 4dc38cafa50b499bf783998e55a87e8446b2dce5 Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 10:16:24 +0100 Subject: [PATCH 14/35] refactor: type SuperFakeSSLContext and FakeSSLContext --- mocket/ssl.py | 32 +++++++++++++++++++++----------- 1 file changed, 21 insertions(+), 11 deletions(-) diff --git a/mocket/ssl.py b/mocket/ssl.py index e4ae44cf..9d9d5d3b 100644 --- a/mocket/ssl.py +++ b/mocket/ssl.py @@ -1,8 +1,15 @@ +from __future__ import annotations + +from typing import Any + +from mocket.socket import MocketSocket + + class SuperFakeSSLContext: """For Python 3.6 and newer.""" class FakeSetter(int): - def __set__(self, *args): + def __set__(self, *args: Any) -> None: pass minimum_version = FakeSetter() @@ -24,33 +31,36 @@ class FakeSSLContext(SuperFakeSSLContext): _check_hostname = False @property - def check_hostname(self): + def check_hostname(self) -> bool: return self._check_hostname @check_hostname.setter - def check_hostname(self, _): + def check_hostname(self, _: bool) -> None: self._check_hostname = False - def __init__(self, *args, **kwargs): + def __init__(self, *args: Any, **kwargs: Any) -> None: self._set_dummy_methods() - def _set_dummy_methods(self): - def dummy_method(*args, **kwargs): + def _set_dummy_methods(self) -> None: + def dummy_method(*args: Any, **kwargs: Any) -> Any: pass for m in self.DUMMY_METHODS: setattr(self, m, dummy_method) @staticmethod - def wrap_socket(sock, *args, **kwargs): + def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket: sock.kwargs = kwargs sock._secure_socket = True return sock @staticmethod - def wrap_bio(incoming, outcoming, *args, **kwargs): - from mocket.socket import MocketSocket - + def wrap_bio( + incoming: Any, # _ssl.MemoryBIO + outgoing: Any, # _ssl.MemoryBIO + server_side: bool = False, + server_hostname: str | bytes | None = None, + ) -> MocketSocket: ssl_obj = MocketSocket() - ssl_obj._host = kwargs["server_hostname"] + ssl_obj._host = server_hostname return ssl_obj From ba68b9cd4ac1941b3c333a79d32936e9f2193aef Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 17:40:19 +0100 Subject: [PATCH 15/35] refactor: move FakeSSLContext from mocket.ssl to mocket.ssl.context --- mocket/__init__.py | 2 +- mocket/inject.py | 2 +- mocket/ssl/__init__.py | 0 mocket/{ssl.py => ssl/context.py} | 0 4 files changed, 2 insertions(+), 2 deletions(-) create mode 100644 mocket/ssl/__init__.py rename mocket/{ssl.py => ssl/context.py} (100%) diff --git a/mocket/__init__.py b/mocket/__init__.py index d64cb11d..58993a24 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -2,7 +2,7 @@ from mocket.entry import MocketEntry from mocket.mocket import Mocket from mocket.mocketizer import Mocketizer, mocketize -from mocket.ssl import FakeSSLContext +from mocket.ssl.context import FakeSSLContext __all__ = ( "async_mocketize", diff --git a/mocket/inject.py b/mocket/inject.py index cba0b40b..b39503ed 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -44,7 +44,7 @@ def enable( ) -> None: from mocket.mocket import Mocket from mocket.socket import MocketSocket, create_connection, socketpair - from mocket.ssl import FakeSSLContext + from mocket.ssl.context import FakeSSLContext Mocket._namespace = namespace Mocket._truesocket_recording_dir = truesocket_recording_dir diff --git a/mocket/ssl/__init__.py b/mocket/ssl/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/ssl.py b/mocket/ssl/context.py similarity index 100% rename from mocket/ssl.py rename to mocket/ssl/context.py From 942c33f379a1e0fc19122ecc9424ceeb6d270fef Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 18:15:48 +0100 Subject: [PATCH 16/35] refactor: move true_* from mocket.inject to mocket.socket and mocket.ssl.context --- mocket/inject.py | 71 +++++++++++++++++++++---------------------- mocket/socket.py | 69 ++++++++++++++++++++++++++++++++++++----- mocket/ssl/context.py | 3 ++ 3 files changed, 98 insertions(+), 45 deletions(-) diff --git a/mocket/inject.py b/mocket/inject.py index b39503ed..5909cb93 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -5,14 +5,6 @@ import ssl import urllib3 -from urllib3.connection import match_hostname as urllib3_match_hostname -from urllib3.util.ssl_ import ssl_wrap_socket as urllib3_ssl_wrap_socket - -try: - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket -except ImportError: - urllib3_wrap_socket = None - try: # pragma: no cover from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 @@ -21,29 +13,22 @@ except ImportError: pyopenssl_override = False -true_socket = socket.socket -true_create_connection = socket.create_connection -true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_getaddrinfo = socket.getaddrinfo -true_socketpair = socket.socketpair -true_ssl_wrap_socket = getattr( - ssl, "wrap_socket", None -) # from Py3.12 it's only under SSLContext -true_ssl_socket = ssl.SSLSocket -true_ssl_context = ssl.SSLContext -true_inet_pton = socket.inet_pton -true_urllib3_wrap_socket = urllib3_wrap_socket -true_urllib3_ssl_wrap_socket = urllib3_ssl_wrap_socket -true_urllib3_match_hostname = urllib3_match_hostname - def enable( namespace: str | None = None, truesocket_recording_dir: str | None = None, ) -> None: from mocket.mocket import Mocket - from mocket.socket import MocketSocket, create_connection, socketpair + from mocket.socket import ( + MocketSocket, + mock_create_connection, + mock_getaddrinfo, + mock_gethostbyname, + mock_gethostname, + mock_inet_pton, + mock_socketpair, + mock_urllib3_match_hostname, + ) from mocket.ssl.context import FakeSSLContext Mocket._namespace = namespace @@ -56,20 +41,16 @@ def enable( socket.socket = socket.__dict__["socket"] = MocketSocket socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = create_connection - socket.gethostname = socket.__dict__["gethostname"] = lambda: "localhost" - socket.gethostbyname = socket.__dict__["gethostbyname"] = lambda host: "127.0.0.1" - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = ( - lambda host, port, family=None, socktype=None, proto=None, flags=None: [ - (2, 1, 6, "", (host, port)) - ] + socket.create_connection = socket.__dict__["create_connection"] = ( + mock_create_connection ) - socket.socketpair = socket.__dict__["socketpair"] = socketpair + socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname + socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname + socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo + socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = lambda family, ip: bytes( - "\x7f\x00\x00\x01", "utf-8" - ) + socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( FakeSSLContext.wrap_socket ) @@ -84,7 +65,7 @@ def enable( ] = FakeSSLContext.wrap_socket urllib3.connection.match_hostname = urllib3.connection.__dict__[ "match_hostname" - ] = lambda *args: None + ] = mock_urllib3_match_hostname if pyopenssl_override: # pragma: no cover # Take out the pyopenssl version - use the default implementation extract_from_urllib3() @@ -92,6 +73,22 @@ def enable( def disable() -> None: from mocket.mocket import Mocket + from mocket.socket import ( + true_create_connection, + true_getaddrinfo, + true_gethostbyname, + true_gethostname, + true_inet_pton, + true_socket, + true_socketpair, + true_ssl_wrap_socket, + true_urllib3_match_hostname, + true_urllib3_ssl_wrap_socket, + true_urllib3_wrap_socket, + ) + from mocket.ssl.context import ( + true_ssl_context, + ) socket.socket = socket.__dict__["socket"] = true_socket socket._socketobject = socket.__dict__["_socketobject"] = true_socket diff --git a/mocket/socket.py b/mocket/socket.py index e4be00b6..ab711f06 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -1,3 +1,5 @@ +from __future__ import annotations + import contextlib import errno import hashlib @@ -8,18 +10,42 @@ import ssl from datetime import datetime, timedelta from json.decoder import JSONDecodeError +from typing import Any + +import urllib3.connection +import urllib3.util.ssl_ from mocket.compat import decode_from_bytes, encode_to_bytes -from mocket.inject import ( - true_gethostbyname, - true_socket, - true_urllib3_ssl_wrap_socket, -) from mocket.io import MocketSocketCore from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.utils import hexdump, hexload +true_create_connection = socket.create_connection +true_getaddrinfo = socket.getaddrinfo +true_gethostbyname = socket.gethostbyname +true_gethostname = socket.gethostname +true_inet_pton = socket.inet_pton +true_socket = socket.socket +true_socketpair = socket.socketpair +true_ssl_wrap_socket = None + +true_urllib3_match_hostname = urllib3.connection.match_hostname +true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket +true_urllib3_wrap_socket = None + +with contextlib.suppress(ImportError): + # from Py3.12 it's only under SSLContext + from ssl import wrap_socket as ssl_wrap_socket + + true_ssl_wrap_socket = ssl_wrap_socket + +with contextlib.suppress(ImportError): + from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket + + true_urllib3_wrap_socket = urllib3_wrap_socket + + xxh32 = None try: from xxhash import xxh32 @@ -29,7 +55,7 @@ hasher = xxh32 or hashlib.md5 -def create_connection(address, timeout=None, source_address=None): +def mock_create_connection(address, timeout=None, source_address=None): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: s.settimeout(timeout) @@ -37,13 +63,40 @@ def create_connection(address, timeout=None, source_address=None): return s -def socketpair(*args, **kwargs): +def mock_getaddrinfo( + host: str, + port: int, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, +) -> list[tuple[int, int, int, str, tuple[str, int]]]: + return [(2, 1, 6, "", (host, port))] + + +def mock_gethostbyname(hostname: str) -> str: + return "127.0.0.1" + + +def mock_gethostname() -> str: + return "localhost" + + +def mock_inet_pton(address_family: int, ip_string: str) -> bytes: + return bytes("\x7f\x00\x00\x01", "utf-8") + + +def mock_socketpair(*args, **kwargs): """Returns a real socketpair() used by asyncio loop for supporting calls made by fastapi and similar services.""" import _socket return _socket.socketpair(*args, **kwargs) +def mock_urllib3_match_hostname(*args: Any) -> None: + return None + + def _hash_request(h, req): return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() @@ -132,7 +185,7 @@ def getblocking(self): return self.gettimeout() is None def getsockname(self): - return socket.gethostbyname(self._address[0]), self._address[1] + return true_gethostbyname(self._address[0]), self._address[1] def getpeercert(self, *args, **kwargs): if not (self._host and self._port): diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 9d9d5d3b..a327fbef 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,9 +1,12 @@ from __future__ import annotations +import ssl from typing import Any from mocket.socket import MocketSocket +true_ssl_context = ssl.SSLContext + class SuperFakeSSLContext: """For Python 3.6 and newer.""" From cfcd85c642cfa3847a7af1b5b81c9052846aa146 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 18:59:58 +0100 Subject: [PATCH 17/35] refactor: type MocketSocket --- mocket/socket.py | 93 ++++++++++++++++++++++++++++++++---------------- mocket/types.py | 17 ++++++++- 2 files changed, 78 insertions(+), 32 deletions(-) diff --git a/mocket/socket.py b/mocket/socket.py index ab711f06..3743e6f2 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -10,15 +10,25 @@ import ssl from datetime import datetime, timedelta from json.decoder import JSONDecodeError -from typing import Any +from types import TracebackType +from typing import Any, Type import urllib3.connection import urllib3.util.ssl_ +from typing_extensions import Self from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.entry import MocketEntry from mocket.io import MocketSocketCore from mocket.mocket import Mocket from mocket.mode import MocketMode +from mocket.types import ( + Address, + ReadableBuffer, + WriteableBuffer, + _PeerCertRetDictType, + _RetAddress, +) from mocket.utils import hexdump, hexload true_create_connection = socket.create_connection @@ -120,8 +130,13 @@ class MocketSocket: _io = None def __init__( - self, family=socket.AF_INET, type=socket.SOCK_STREAM, proto=0, **kwargs - ): + self, + family: socket.AddressFamily | int = socket.AF_INET, + type: socket.SocketKind | int = socket.SOCK_STREAM, + proto: int = 0, + fileno: int | None = None, + **kwargs: Any, + ) -> None: self.true_socket = true_socket(family, type, proto) self._buflen = 65536 self._entry = None @@ -131,22 +146,27 @@ def __init__( self._truesocket_recording_dir = None self.kwargs = kwargs - def __str__(self): + def __str__(self) -> str: return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" - def __enter__(self): + def __enter__(self) -> Self: return self - def __exit__(self, exc_type, exc_val, exc_tb): + def __exit__( + self, + type_: Type[BaseException] | None, # noqa: UP006 + value: BaseException | None, + traceback: TracebackType | None, + ) -> None: self.close() @property - def io(self): + def io(self) -> MocketSocketCore: if self._io is None: self._io = MocketSocketCore((self._host, self._port)) return self._io - def fileno(self): + def fileno(self) -> int: address = (self._host, self._port) r_fd, _ = Mocket.get_pair(address) if not r_fd: @@ -154,10 +174,11 @@ def fileno(self): Mocket.set_pair(address, (r_fd, w_fd)) return r_fd - def gettimeout(self): + def gettimeout(self) -> float | None: return self.timeout - def setsockopt(self, family, type, proto): + # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None` + def setsockopt(self, family: int, type: int, proto: int) -> None: self.family = family self.type = type self.proto = proto @@ -165,29 +186,29 @@ def setsockopt(self, family, type, proto): if self.true_socket: self.true_socket.setsockopt(family, type, proto) - def settimeout(self, timeout): + def settimeout(self, timeout: float | None) -> None: self.timeout = timeout @staticmethod - def getsockopt(level, optname, buflen=None): + def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: return socket.SOCK_STREAM - def do_handshake(self): + def do_handshake(self) -> None: self._did_handshake = True - def getpeername(self): + def getpeername(self) -> _RetAddress: return self._address - def setblocking(self, block): + def setblocking(self, block: bool) -> None: self.settimeout(None) if block else self.settimeout(0.0) - def getblocking(self): + def getblocking(self) -> bool: return self.gettimeout() is None - def getsockname(self): + def getsockname(self) -> _RetAddress: return true_gethostbyname(self._address[0]), self._address[1] - def getpeercert(self, *args, **kwargs): + def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: if not (self._host and self._port): self._address = self._host, self._port = Mocket._address @@ -207,22 +228,22 @@ def getpeercert(self, *args, **kwargs): ), } - def unwrap(self): + def unwrap(self) -> MocketSocket: return self - def write(self, data): + def write(self, data: bytes) -> int | None: return self.send(encode_to_bytes(data)) - def connect(self, address): + def connect(self, address: Address) -> None: self._address = self._host, self._port = address Mocket._address = address - def makefile(self, mode="r", bufsize=-1): + def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore: self._mode = mode self._bufsize = bufsize return self.io - def get_entry(self, data): + def get_entry(self, data: bytes) -> MocketEntry | None: return Mocket.get_entry(self._host, self._port, data) def sendall(self, data, entry=None, *args, **kwargs): @@ -241,7 +262,7 @@ def sendall(self, data, entry=None, *args, **kwargs): self.io.truncate() self.io.seek(0) - def read(self, buffersize): + def read(self, buffersize: int | None = None) -> bytes: rv = self.io.read(buffersize) if rv: self._sent_non_empty_bytes = True @@ -249,7 +270,12 @@ def read(self, buffersize): raise ssl.SSLWantReadError("The operation did not complete (read)") return rv - def recv_into(self, buffer, buffersize=None, flags=None): + def recv_into( + self, + buffer: WriteableBuffer, + buffersize: int | None = None, + flags: int | None = None, + ) -> int: if hasattr(buffer, "write"): return buffer.write(self.read(buffersize)) # buffer is a memoryview @@ -258,7 +284,7 @@ def recv_into(self, buffer, buffersize=None, flags=None): buffer[: len(data)] = data return len(data) - def recv(self, buffersize, flags=None): + def recv(self, buffersize: int, flags: int | None = None) -> bytes: r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) @@ -271,7 +297,7 @@ def recv(self, buffersize, flags=None): exc.args = (0,) raise exc - def true_sendall(self, data, *args, **kwargs): + def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: if not MocketMode().is_allowed((self._host, self._port)): MocketMode.raise_not_allowed() @@ -359,7 +385,12 @@ def true_sendall(self, data, *args, **kwargs): # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO return encoded_response - def send(self, data, *args, **kwargs): # pragma: no cover + def send( + self, + data: ReadableBuffer, + *args: Any, + **kwargs: Any, + ) -> int: # pragma: no cover entry = self.get_entry(data) if not entry or (entry and self._entry != entry): kwargs["entry"] = entry @@ -371,15 +402,15 @@ def send(self, data, *args, **kwargs): # pragma: no cover self._entry = entry return len(data) - def close(self): + def close(self) -> None: if self.true_socket and not self.true_socket._closed: self.true_socket.close() self._fd = None - def __getattr__(self, name): + def __getattr__(self, name: str) -> Any: """Do nothing catchall function, for methods like shutdown()""" - def do_nothing(*args, **kwargs): + def do_nothing(*args: Any, **kwargs: Any) -> Any: pass return do_nothing diff --git a/mocket/types.py b/mocket/types.py index 61b7a4d5..562648c7 100644 --- a/mocket/types.py +++ b/mocket/types.py @@ -1,5 +1,20 @@ from __future__ import annotations -from typing import Tuple +from typing import Any, Dict, Tuple, Union + +from typing_extensions import Buffer, TypeAlias Address = Tuple[str, int] + +# adapted from typeshed/stdlib/_typeshed/__init__.pyi +WriteableBuffer: TypeAlias = Buffer +ReadableBuffer: TypeAlias = Buffer + +# from typeshed/stdlib/_socket.pyi +_Address: TypeAlias = Union[Tuple[Any, ...], str, ReadableBuffer] +_RetAddress: TypeAlias = Any + +# from typeshed/stdlib/ssl.pyi +_PCTRTT: TypeAlias = Tuple[Tuple[str, str], ...] +_PCTRTTT: TypeAlias = Tuple[_PCTRTT, ...] +_PeerCertRetDictType: TypeAlias = Dict[str, Union[str, _PCTRTTT, _PCTRTT]] From 9050127e34dcd121086e68ae657d05e51d414425 Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 19:02:08 +0100 Subject: [PATCH 18/35] refactor: remove unused instance-variables from MocketSocket --- mocket/socket.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/mocket/socket.py b/mocket/socket.py index 3743e6f2..c4b6a9a8 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -113,7 +113,6 @@ def _hash_request(h, req): class MocketSocket: timeout = None - _fd = None family = None type = None proto = None @@ -122,8 +121,6 @@ class MocketSocket: _address = None cipher = lambda s: ("ADH", "AES256", "SHA") compression = lambda s: ssl.OP_NO_COMPRESSION - _mode = None - _bufsize = None _secure_socket = False _did_handshake = False _sent_non_empty_bytes = False @@ -239,8 +236,6 @@ def connect(self, address: Address) -> None: Mocket._address = address def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore: - self._mode = mode - self._bufsize = bufsize return self.io def get_entry(self, data: bytes) -> MocketEntry | None: @@ -405,7 +400,6 @@ def send( def close(self) -> None: if self.true_socket and not self.true_socket._closed: self.true_socket.close() - self._fd = None def __getattr__(self, name: str) -> Any: """Do nothing catchall function, for methods like shutdown()""" From 1eb61cf55ea7a0445cfd7eee33a87b8fc936858c Mon Sep 17 00:00:00 2001 From: betaboon Date: Sun, 17 Nov 2024 19:13:12 +0100 Subject: [PATCH 19/35] refactor: MocketSocket - make instance-variables private and move into constructor --- mocket/socket.py | 81 +++++++++++++++++++++++++------------------ mocket/ssl/context.py | 2 +- 2 files changed, 48 insertions(+), 35 deletions(-) diff --git a/mocket/socket.py b/mocket/socket.py index c4b6a9a8..0b345572 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -112,19 +112,8 @@ def _hash_request(h, req): class MocketSocket: - timeout = None - family = None - type = None - proto = None - _host = None - _port = None - _address = None cipher = lambda s: ("ADH", "AES256", "SHA") compression = lambda s: ssl.OP_NO_COMPRESSION - _secure_socket = False - _did_handshake = False - _sent_non_empty_bytes = False - _io = None def __init__( self, @@ -134,14 +123,26 @@ def __init__( fileno: int | None = None, **kwargs: Any, ) -> None: - self.true_socket = true_socket(family, type, proto) + self._family = family + self._type = type + self._proto = proto + + self._kwargs = kwargs + self._true_socket = true_socket(family, type, proto) + self._buflen = 65536 + self._timeout: float | None = None + + self._secure_socket = False + self._did_handshake = False + self._sent_non_empty_bytes = False + + self._host = None + self._port = None + self._address = None + + self._io = None self._entry = None - self.family = int(family) - self.type = int(type) - self.proto = int(proto) - self._truesocket_recording_dir = None - self.kwargs = kwargs def __str__(self) -> str: return f"({self.__class__.__name__})(family={self.family} type={self.type} protocol={self.proto})" @@ -157,6 +158,18 @@ def __exit__( ) -> None: self.close() + @property + def family(self) -> int: + return self._family + + @property + def type(self) -> int: + return self._type + + @property + def proto(self) -> int: + return self._proto + @property def io(self) -> MocketSocketCore: if self._io is None: @@ -172,19 +185,19 @@ def fileno(self) -> int: return r_fd def gettimeout(self) -> float | None: - return self.timeout + return self._timeout # FIXME the arguments here seem wrong. they should be `level: int, optname: int, value: int | ReadableBuffer | None` def setsockopt(self, family: int, type: int, proto: int) -> None: - self.family = family - self.type = type - self.proto = proto + self._family = family + self._type = type + self._proto = proto - if self.true_socket: - self.true_socket.setsockopt(family, type, proto) + if self._true_socket: + self._true_socket.setsockopt(family, type, proto) def settimeout(self, timeout: float | None) -> None: - self.timeout = timeout + self._timeout = timeout @staticmethod def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: @@ -343,23 +356,23 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: host, port = self._host, self._port host = true_gethostbyname(host) - if isinstance(self.true_socket, true_socket) and self._secure_socket: - self.true_socket = true_urllib3_ssl_wrap_socket( - self.true_socket, - **self.kwargs, + if isinstance(self._true_socket, true_socket) and self._secure_socket: + self._true_socket = true_urllib3_ssl_wrap_socket( + self._true_socket, + **self._kwargs, ) with contextlib.suppress(OSError, ValueError): # already connected - self.true_socket.connect((host, port)) - self.true_socket.sendall(data, *args, **kwargs) + self._true_socket.connect((host, port)) + self._true_socket.sendall(data, *args, **kwargs) encoded_response = b"" # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 while True: - more_to_read = select.select([self.true_socket], [], [], 0.1)[0] + more_to_read = select.select([self._true_socket], [], [], 0.1)[0] if not more_to_read and encoded_response: break - new_content = self.true_socket.recv(self._buflen) + new_content = self._true_socket.recv(self._buflen) if not new_content: break encoded_response += new_content @@ -398,8 +411,8 @@ def send( return len(data) def close(self) -> None: - if self.true_socket and not self.true_socket._closed: - self.true_socket.close() + if self._true_socket and not self._true_socket._closed: + self._true_socket.close() def __getattr__(self, name: str) -> Any: """Do nothing catchall function, for methods like shutdown()""" diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index a327fbef..a830c1e7 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -53,7 +53,7 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any: @staticmethod def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket: - sock.kwargs = kwargs + sock._kwargs = kwargs sock._secure_socket = True return sock From 0eff8f1ec935124b0d6097ecc366d8e758220eda Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 09:29:23 +0100 Subject: [PATCH 20/35] refactor: move true-ssl-methods to mocket.ssl.context --- mocket/inject.py | 6 +++--- mocket/socket.py | 18 ++---------------- mocket/ssl/context.py | 18 ++++++++++++++++++ 3 files changed, 23 insertions(+), 19 deletions(-) diff --git a/mocket/inject.py b/mocket/inject.py index 5909cb93..b733dd3c 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -81,13 +81,13 @@ def disable() -> None: true_inet_pton, true_socket, true_socketpair, - true_ssl_wrap_socket, true_urllib3_match_hostname, - true_urllib3_ssl_wrap_socket, - true_urllib3_wrap_socket, ) from mocket.ssl.context import ( true_ssl_context, + true_ssl_wrap_socket, + true_urllib3_ssl_wrap_socket, + true_urllib3_wrap_socket, ) socket.socket = socket.__dict__["socket"] = true_socket diff --git a/mocket/socket.py b/mocket/socket.py index 0b345572..c3bed15f 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -14,7 +14,6 @@ from typing import Any, Type import urllib3.connection -import urllib3.util.ssl_ from typing_extensions import Self from mocket.compat import decode_from_bytes, encode_to_bytes @@ -38,22 +37,7 @@ true_inet_pton = socket.inet_pton true_socket = socket.socket true_socketpair = socket.socketpair -true_ssl_wrap_socket = None - true_urllib3_match_hostname = urllib3.connection.match_hostname -true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket -true_urllib3_wrap_socket = None - -with contextlib.suppress(ImportError): - # from Py3.12 it's only under SSLContext - from ssl import wrap_socket as ssl_wrap_socket - - true_ssl_wrap_socket = ssl_wrap_socket - -with contextlib.suppress(ImportError): - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket - - true_urllib3_wrap_socket = urllib3_wrap_socket xxh32 = None @@ -357,6 +341,8 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: host = true_gethostbyname(host) if isinstance(self._true_socket, true_socket) and self._secure_socket: + from mocket.ssl.context import true_urllib3_ssl_wrap_socket + self._true_socket = true_urllib3_ssl_wrap_socket( self._true_socket, **self._kwargs, diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index a830c1e7..fccf5db4 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,12 +1,30 @@ from __future__ import annotations +import contextlib import ssl from typing import Any +import urllib3.util.ssl_ + from mocket.socket import MocketSocket true_ssl_context = ssl.SSLContext +true_ssl_wrap_socket = None +true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket +true_urllib3_wrap_socket = None + +with contextlib.suppress(ImportError): + # from Py3.12 it's only under SSLContext + from ssl import wrap_socket as ssl_wrap_socket + + true_ssl_wrap_socket = ssl_wrap_socket + +with contextlib.suppress(ImportError): + from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket + + true_urllib3_wrap_socket = urllib3_wrap_socket + class SuperFakeSSLContext: """For Python 3.6 and newer.""" From 90eb5db6929f12793413ac3894b53fc175b269c2 Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 09:38:27 +0100 Subject: [PATCH 21/35] refactor: prepare for removal of read and write from MocketSocket --- mocket/socket.py | 10 +++++++--- tests/test_http.py | 12 ++++++------ 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/mocket/socket.py b/mocket/socket.py index c3bed15f..3cd68fe5 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -269,9 +269,13 @@ def recv_into( flags: int | None = None, ) -> int: if hasattr(buffer, "write"): - return buffer.write(self.read(buffersize)) + return buffer.write(self.recv(buffersize)) + # buffer is a memoryview - data = self.read(buffersize) + if buffersize is None: + buffersize = len(buffer) + + data = self.recv(buffersize) if data: buffer[: len(data)] = data return len(data) @@ -280,7 +284,7 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes: r_fd, _ = Mocket.get_pair((self._host, self._port)) if r_fd: return os.read(r_fd, buffersize) - data = self.read(buffersize) + data = self.io.read(buffersize) if data: return data # used by Redis mock diff --git a/tests/test_http.py b/tests/test_http.py index d516068b..afa31185 100644 --- a/tests/test_http.py +++ b/tests/test_http.py @@ -359,12 +359,12 @@ def test_sockets(self): sock = socket.socket(address[0], address[1], address[2]) sock.connect(address[-1]) - sock.write(f"{method} {path} HTTP/1.0\r\n") - sock.write(f"Host: {host}\r\n") - sock.write("Content-Type: application/json\r\n") - sock.write("Content-Length: %d\r\n" % len(data)) - sock.write("Connection: close\r\n\r\n") - sock.write(data) + sock.send(f"{method} {path} HTTP/1.0\r\n".encode()) + sock.send(f"Host: {host}\r\n".encode()) + sock.send(b"Content-Type: application/json\r\n") + sock.send(b"Content-Length: %d\r\n" % len(data)) + sock.send(b"Connection: close\r\n\r\n") + sock.send(data.encode()) sock.close() # Proof that worked. From 636951f2f9ea47139539b346c0e3bbc9067e86f0 Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 18 Nov 2024 10:10:41 +0100 Subject: [PATCH 22/35] refactor: split ssl-functionality of MocketSocket into MocketSSLSocket --- mocket/socket.py | 52 ------------------------------------ mocket/ssl/context.py | 29 +++++++++++++++----- mocket/ssl/socket.py | 62 +++++++++++++++++++++++++++++++++++++++++++ 3 files changed, 85 insertions(+), 58 deletions(-) create mode 100644 mocket/ssl/socket.py diff --git a/mocket/socket.py b/mocket/socket.py index 3cd68fe5..2ce74c09 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -96,9 +96,6 @@ def _hash_request(h, req): class MocketSocket: - cipher = lambda s: ("ADH", "AES256", "SHA") - compression = lambda s: ssl.OP_NO_COMPRESSION - def __init__( self, family: socket.AddressFamily | int = socket.AF_INET, @@ -117,10 +114,6 @@ def __init__( self._buflen = 65536 self._timeout: float | None = None - self._secure_socket = False - self._did_handshake = False - self._sent_non_empty_bytes = False - self._host = None self._port = None self._address = None @@ -187,9 +180,6 @@ def settimeout(self, timeout: float | None) -> None: def getsockopt(level: int, optname: int, buflen: int | None = None) -> int: return socket.SOCK_STREAM - def do_handshake(self) -> None: - self._did_handshake = True - def getpeername(self) -> _RetAddress: return self._address @@ -202,32 +192,6 @@ def getblocking(self) -> bool: def getsockname(self) -> _RetAddress: return true_gethostbyname(self._address[0]), self._address[1] - def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: - if not (self._host and self._port): - self._address = self._host, self._port = Mocket._address - - now = datetime.now() - shift = now + timedelta(days=30 * 12) - return { - "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), - "subjectAltName": ( - ("DNS", f"*.{self._host}"), - ("DNS", self._host), - ("DNS", "*"), - ), - "subject": ( - (("organizationName", f"*.{self._host}"),), - (("organizationalUnitName", "Domain Control Validated"),), - (("commonName", f"*.{self._host}"),), - ), - } - - def unwrap(self) -> MocketSocket: - return self - - def write(self, data: bytes) -> int | None: - return self.send(encode_to_bytes(data)) - def connect(self, address: Address) -> None: self._address = self._host, self._port = address Mocket._address = address @@ -254,14 +218,6 @@ def sendall(self, data, entry=None, *args, **kwargs): self.io.truncate() self.io.seek(0) - def read(self, buffersize: int | None = None) -> bytes: - rv = self.io.read(buffersize) - if rv: - self._sent_non_empty_bytes = True - if self._did_handshake and not self._sent_non_empty_bytes: - raise ssl.SSLWantReadError("The operation did not complete (read)") - return rv - def recv_into( self, buffer: WriteableBuffer, @@ -344,14 +300,6 @@ def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: host, port = self._host, self._port host = true_gethostbyname(host) - if isinstance(self._true_socket, true_socket) and self._secure_socket: - from mocket.ssl.context import true_urllib3_ssl_wrap_socket - - self._true_socket = true_urllib3_ssl_wrap_socket( - self._true_socket, - **self._kwargs, - ) - with contextlib.suppress(OSError, ValueError): # already connected self._true_socket.connect((host, port)) diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index fccf5db4..e5f60c0a 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -7,6 +7,7 @@ import urllib3.util.ssl_ from mocket.socket import MocketSocket +from mocket.ssl.socket import MocketSSLSocket true_ssl_context = ssl.SSLContext @@ -70,10 +71,26 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any: setattr(self, m, dummy_method) @staticmethod - def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSocket: - sock._kwargs = kwargs - sock._secure_socket = True - return sock + def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: + ssl_socket = MocketSSLSocket() + ssl_socket._original_socket = sock + + ssl_socket._true_socket = true_urllib3_ssl_wrap_socket( + sock._true_socket, + **kwargs, + ) + ssl_socket._kwargs = kwargs + + ssl_socket._timeout = sock._timeout + + ssl_socket._host = sock._host + ssl_socket._port = sock._port + ssl_socket._address = sock._address + + ssl_socket._io = sock._io + ssl_socket._entry = sock._entry + + return ssl_socket @staticmethod def wrap_bio( @@ -81,7 +98,7 @@ def wrap_bio( outgoing: Any, # _ssl.MemoryBIO server_side: bool = False, server_hostname: str | bytes | None = None, - ) -> MocketSocket: - ssl_obj = MocketSocket() + ) -> MocketSSLSocket: + ssl_obj = MocketSSLSocket() ssl_obj._host = server_hostname return ssl_obj diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py new file mode 100644 index 00000000..e50b7320 --- /dev/null +++ b/mocket/ssl/socket.py @@ -0,0 +1,62 @@ +from __future__ import annotations + +import ssl +from datetime import datetime, timedelta +from typing import Any + +from mocket.compat import encode_to_bytes +from mocket.mocket import Mocket +from mocket.socket import MocketSocket +from mocket.types import _PeerCertRetDictType + + +class MocketSSLSocket(MocketSocket): + def __init__(self, *args: Any, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + + self._did_handshake = False + self._sent_non_empty_bytes = False + self._original_socket: MocketSocket = self + + def read(self, buffersize: int | None = None) -> bytes: + rv = self.io.read(buffersize) + if rv: + self._sent_non_empty_bytes = True + if self._did_handshake and not self._sent_non_empty_bytes: + raise ssl.SSLWantReadError("The operation did not complete (read)") + return rv + + def write(self, data: bytes) -> int | None: + return self.send(encode_to_bytes(data)) + + def do_handshake(self) -> None: + self._did_handshake = True + + def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: + if not (self._host and self._port): + self._address = self._host, self._port = Mocket._address + + now = datetime.now() + shift = now + timedelta(days=30 * 12) + return { + "notAfter": shift.strftime("%b %d %H:%M:%S GMT"), + "subjectAltName": ( + ("DNS", f"*.{self._host}"), + ("DNS", self._host), + ("DNS", "*"), + ), + "subject": ( + (("organizationName", f"*.{self._host}"),), + (("organizationalUnitName", "Domain Control Validated"),), + (("commonName", f"*.{self._host}"),), + ), + } + + def ciper(self) -> tuple[str, str, str]: + return ("ADH", "AES256", "SHA") + + def compression(self) -> str | None: + return ssl.OP_NO_COMPRESSION + + def unwrap(self) -> MocketSocket: + return self._original_socket From 14478af8c638a3947bfc6eb9c5f41a74af23c95c Mon Sep 17 00:00:00 2001 From: betaboon Date: Wed, 20 Nov 2024 10:53:36 +0100 Subject: [PATCH 23/35] Refactor rename ssl classes (#266) * refactor: rename MocketSocketCore to MocketSocketIO * refactor: rename FakeSSLContext to MocketSSLContext --- mocket/__init__.py | 6 +++++- mocket/inject.py | 14 +++++++------- mocket/io.py | 2 +- mocket/plugins/aiohttp_connector.py | 6 +++--- mocket/socket.py | 8 ++++---- mocket/ssl/context.py | 4 ++-- mocket/utils.py | 2 +- 7 files changed, 23 insertions(+), 19 deletions(-) diff --git a/mocket/__init__.py b/mocket/__init__.py index 58993a24..53064434 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -2,7 +2,10 @@ from mocket.entry import MocketEntry from mocket.mocket import Mocket from mocket.mocketizer import Mocketizer, mocketize -from mocket.ssl.context import FakeSSLContext +from mocket.ssl.context import MocketSSLContext + +# NOTE this is here for backwards-compat to keep old import-paths working +from mocket.ssl.context import MocketSSLContext as FakeSSLContext __all__ = ( "async_mocketize", @@ -10,6 +13,7 @@ "Mocket", "MocketEntry", "Mocketizer", + "MocketSSLContext", "FakeSSLContext", ) diff --git a/mocket/inject.py b/mocket/inject.py index b733dd3c..35e9da01 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -29,7 +29,7 @@ def enable( mock_socketpair, mock_urllib3_match_hostname, ) - from mocket.ssl.context import FakeSSLContext + from mocket.ssl.context import MocketSSLContext Mocket._namespace = namespace Mocket._truesocket_recording_dir = truesocket_recording_dir @@ -48,21 +48,21 @@ def enable( socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = FakeSSLContext.wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = FakeSSLContext + ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket + ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - FakeSSLContext.wrap_socket + MocketSSLContext.wrap_socket ) urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket + ] = MocketSSLContext.wrap_socket urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - FakeSSLContext.wrap_socket + MocketSSLContext.wrap_socket ) urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ "ssl_wrap_socket" - ] = FakeSSLContext.wrap_socket + ] = MocketSSLContext.wrap_socket urllib3.connection.match_hostname = urllib3.connection.__dict__[ "match_hostname" ] = mock_urllib3_match_hostname diff --git a/mocket/io.py b/mocket/io.py index 648b16dd..0334410b 100644 --- a/mocket/io.py +++ b/mocket/io.py @@ -4,7 +4,7 @@ from mocket.mocket import Mocket -class MocketSocketCore(io.BytesIO): +class MocketSocketIO(io.BytesIO): def __init__(self, address) -> None: self._address = address super().__init__() diff --git a/mocket/plugins/aiohttp_connector.py b/mocket/plugins/aiohttp_connector.py index 353c3af7..cde5019a 100644 --- a/mocket/plugins/aiohttp_connector.py +++ b/mocket/plugins/aiohttp_connector.py @@ -1,6 +1,6 @@ import contextlib -from mocket import FakeSSLContext +from mocket import MocketSSLContext with contextlib.suppress(ModuleNotFoundError): from aiohttp import ClientRequest @@ -14,5 +14,5 @@ class MocketTCPConnector(TCPConnector): slightly patching the `ClientSession` while testing. """ - def _get_ssl_context(self, req: ClientRequest) -> FakeSSLContext: - return FakeSSLContext() + def _get_ssl_context(self, req: ClientRequest) -> MocketSSLContext: + return MocketSSLContext() diff --git a/mocket/socket.py b/mocket/socket.py index 2ce74c09..e79c86c8 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -18,7 +18,7 @@ from mocket.compat import decode_from_bytes, encode_to_bytes from mocket.entry import MocketEntry -from mocket.io import MocketSocketCore +from mocket.io import MocketSocketIO from mocket.mocket import Mocket from mocket.mode import MocketMode from mocket.types import ( @@ -148,9 +148,9 @@ def proto(self) -> int: return self._proto @property - def io(self) -> MocketSocketCore: + def io(self) -> MocketSocketIO: if self._io is None: - self._io = MocketSocketCore((self._host, self._port)) + self._io = MocketSocketIO((self._host, self._port)) return self._io def fileno(self) -> int: @@ -196,7 +196,7 @@ def connect(self, address: Address) -> None: self._address = self._host, self._port = address Mocket._address = address - def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketCore: + def makefile(self, mode: str = "r", bufsize: int = -1) -> MocketSocketIO: return self.io def get_entry(self, data: bytes) -> MocketEntry | None: diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index e5f60c0a..438faa10 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -27,7 +27,7 @@ true_urllib3_wrap_socket = urllib3_wrap_socket -class SuperFakeSSLContext: +class _MocketSSLContext: """For Python 3.6 and newer.""" class FakeSetter(int): @@ -40,7 +40,7 @@ def __set__(self, *args: Any) -> None: verify_flags = FakeSetter() -class FakeSSLContext(SuperFakeSSLContext): +class MocketSSLContext(_MocketSSLContext): DUMMY_METHODS = ( "load_default_certs", "load_verify_locations", diff --git a/mocket/utils.py b/mocket/utils.py index f94b78f7..52777687 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -7,7 +7,7 @@ from mocket.compat import decode_from_bytes, encode_to_bytes # NOTE this is here for backwards-compat to keep old import-paths working -from mocket.io import MocketSocketCore as MocketSocketCore +from mocket.io import MocketSocketIO as MocketSocketCore # NOTE this is here for backwards-compat to keep old import-paths working from mocket.mode import MocketMode as MocketMode From 0da27224ad800297c4d120e740e1ba263da0327a Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Wed, 20 Nov 2024 11:02:33 +0100 Subject: [PATCH 24/35] Changes from `ruff`. (#267) --- .pre-commit-config.yaml | 4 ++-- mocket/socket.py | 3 --- mocket/utils.py | 12 +++++++++++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index eae15d12..74e4cdae 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -5,7 +5,7 @@ repos: - id: forbid-crlf - id: remove-crlf - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.6.0 + rev: v5.0.0 hooks: - id: trailing-whitespace - id: end-of-file-fixer @@ -15,7 +15,7 @@ repos: exclude: helm/ args: [ --unsafe ] - repo: https://github.com/charliermarsh/ruff-pre-commit - rev: "v0.4.4" + rev: "v0.7.4" hooks: - id: ruff args: [--fix, --exit-non-zero-on-fix] diff --git a/mocket/socket.py b/mocket/socket.py index e79c86c8..03c6f7e5 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -7,8 +7,6 @@ import os import select import socket -import ssl -from datetime import datetime, timedelta from json.decoder import JSONDecodeError from types import TracebackType from typing import Any, Type @@ -25,7 +23,6 @@ Address, ReadableBuffer, WriteableBuffer, - _PeerCertRetDictType, _RetAddress, ) from mocket.utils import hexdump, hexload diff --git a/mocket/utils.py b/mocket/utils.py index 52777687..b9e2c259 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -10,7 +10,7 @@ from mocket.io import MocketSocketIO as MocketSocketCore # NOTE this is here for backwards-compat to keep old import-paths working -from mocket.mode import MocketMode as MocketMode +from mocket.mode import MocketMode SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 @@ -42,3 +42,13 @@ def get_mocketize(wrapper_: Callable) -> Callable: wrapper_, kwsyntax=True, ) + + +__all__ = ( + "MocketSocketCore", + "MocketMode", + "SSL_PROTOCOL", + "hexdump", + "hexload", + "get_mocketize", +) From a5b5e34b8f981e033bf9694f4fec53481d933ade Mon Sep 17 00:00:00 2001 From: betaboon Date: Mon, 25 Nov 2024 11:57:24 +0100 Subject: [PATCH 25/35] improve injection code, make backwards compat explicit, make ssl-api explicit (#268) * refactor: make injection code more readable and make backwards-compat explicit * refactor: move ssl socket-wrapping code to ssl/socket.py * refactor: convert MocketSSLContext.wrap_socket and wrap_bio to instance-methods * refactor: MocketSSLSocket use proper ssl-context instead of urllib3 --- mocket/inject.py | 149 +++++++++++++++++------------------------- mocket/socket.py | 11 ---- mocket/ssl/context.py | 60 +++++------------ mocket/ssl/socket.py | 32 +++++++++ mocket/urllib3.py | 20 ++++++ mocket/utils.py | 4 +- tests/test_mode.py | 2 +- 7 files changed, 132 insertions(+), 146 deletions(-) create mode 100644 mocket/urllib3.py diff --git a/mocket/inject.py b/mocket/inject.py index 35e9da01..469ab30b 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -1,24 +1,32 @@ from __future__ import annotations +import contextlib import os import socket import ssl +from types import ModuleType +from typing import Any import urllib3 -try: # pragma: no cover - from urllib3.contrib.pyopenssl import extract_from_urllib3, inject_into_urllib3 +_patches_restore: dict[tuple[ModuleType, str], Any] = {} - pyopenssl_override = True -except ImportError: - pyopenssl_override = False + +def _patch(module: ModuleType, name: str, patched_value: Any) -> None: + with contextlib.suppress(KeyError): + original_value, module.__dict__[name] = module.__dict__[name], patched_value + _patches_restore[(module, name)] = original_value + + +def _restore(module: ModuleType, name: str) -> None: + if original_value := _patches_restore.pop((module, name)): + module.__dict__[name] = original_value def enable( namespace: str | None = None, truesocket_recording_dir: str | None = None, ) -> None: - from mocket.mocket import Mocket from mocket.socket import ( MocketSocket, mock_create_connection, @@ -27,99 +35,62 @@ def enable( mock_gethostname, mock_inet_pton, mock_socketpair, - mock_urllib3_match_hostname, ) - from mocket.ssl.context import MocketSSLContext + from mocket.ssl.context import MocketSSLContext, mock_wrap_socket + from mocket.urllib3 import ( + mock_match_hostname as mock_urllib3_match_hostname, + ) + from mocket.urllib3 import ( + mock_ssl_wrap_socket as mock_urllib3_ssl_wrap_socket, + ) + + patches = { + # stdlib: socket + (socket, "socket"): MocketSocket, + (socket, "create_connection"): mock_create_connection, + (socket, "getaddrinfo"): mock_getaddrinfo, + (socket, "gethostbyname"): mock_gethostbyname, + (socket, "gethostname"): mock_gethostname, + (socket, "inet_pton"): mock_inet_pton, + (socket, "SocketType"): MocketSocket, + (socket, "socketpair"): mock_socketpair, + # stdlib: ssl + (ssl, "SSLContext"): MocketSSLContext, + (ssl, "wrap_socket"): mock_wrap_socket, # python < 3.12.0 + # urllib3 + (urllib3.connection, "match_hostname"): mock_urllib3_match_hostname, + (urllib3.connection, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util.ssl_, "ssl_wrap_socket"): mock_urllib3_ssl_wrap_socket, + (urllib3.util.ssl_, "wrap_socket"): mock_urllib3_ssl_wrap_socket, # urllib3 < 2 + } + + for (module, name), new_value in patches.items(): + _patch(module, name, new_value) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import extract_from_urllib3 + + extract_from_urllib3() + + from mocket.mocket import Mocket Mocket._namespace = namespace Mocket._truesocket_recording_dir = truesocket_recording_dir - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): # JSON dumps will be saved here raise AssertionError - socket.socket = socket.__dict__["socket"] = MocketSocket - socket._socketobject = socket.__dict__["_socketobject"] = MocketSocket - socket.SocketType = socket.__dict__["SocketType"] = MocketSocket - socket.create_connection = socket.__dict__["create_connection"] = ( - mock_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = mock_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = mock_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = mock_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = mock_socketpair - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = MocketSSLContext.wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = MocketSSLContext - socket.inet_pton = socket.__dict__["inet_pton"] = mock_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - MocketSSLContext.wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = MocketSSLContext.wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - MocketSSLContext.wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = MocketSSLContext.wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = mock_urllib3_match_hostname - if pyopenssl_override: # pragma: no cover - # Take out the pyopenssl version - use the default implementation - extract_from_urllib3() - def disable() -> None: + for module, name in list(_patches_restore.keys()): + _restore(module, name) + + with contextlib.suppress(ImportError): + from urllib3.contrib.pyopenssl import inject_into_urllib3 + + inject_into_urllib3() + from mocket.mocket import Mocket - from mocket.socket import ( - true_create_connection, - true_getaddrinfo, - true_gethostbyname, - true_gethostname, - true_inet_pton, - true_socket, - true_socketpair, - true_urllib3_match_hostname, - ) - from mocket.ssl.context import ( - true_ssl_context, - true_ssl_wrap_socket, - true_urllib3_ssl_wrap_socket, - true_urllib3_wrap_socket, - ) - socket.socket = socket.__dict__["socket"] = true_socket - socket._socketobject = socket.__dict__["_socketobject"] = true_socket - socket.SocketType = socket.__dict__["SocketType"] = true_socket - socket.create_connection = socket.__dict__["create_connection"] = ( - true_create_connection - ) - socket.gethostname = socket.__dict__["gethostname"] = true_gethostname - socket.gethostbyname = socket.__dict__["gethostbyname"] = true_gethostbyname - socket.getaddrinfo = socket.__dict__["getaddrinfo"] = true_getaddrinfo - socket.socketpair = socket.__dict__["socketpair"] = true_socketpair - if true_ssl_wrap_socket: - ssl.wrap_socket = ssl.__dict__["wrap_socket"] = true_ssl_wrap_socket - ssl.SSLContext = ssl.__dict__["SSLContext"] = true_ssl_context - socket.inet_pton = socket.__dict__["inet_pton"] = true_inet_pton - urllib3.util.ssl_.wrap_socket = urllib3.util.ssl_.__dict__["wrap_socket"] = ( - true_urllib3_wrap_socket - ) - urllib3.util.ssl_.ssl_wrap_socket = urllib3.util.ssl_.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.util.ssl_wrap_socket = urllib3.util.__dict__["ssl_wrap_socket"] = ( - true_urllib3_ssl_wrap_socket - ) - urllib3.connection.ssl_wrap_socket = urllib3.connection.__dict__[ - "ssl_wrap_socket" - ] = true_urllib3_ssl_wrap_socket - urllib3.connection.match_hostname = urllib3.connection.__dict__[ - "match_hostname" - ] = true_urllib3_match_hostname Mocket.reset() - if pyopenssl_override: # pragma: no cover - # Put the pyopenssl version back in place - inject_into_urllib3() diff --git a/mocket/socket.py b/mocket/socket.py index 03c6f7e5..9480d365 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -11,7 +11,6 @@ from types import TracebackType from typing import Any, Type -import urllib3.connection from typing_extensions import Self from mocket.compat import decode_from_bytes, encode_to_bytes @@ -27,14 +26,8 @@ ) from mocket.utils import hexdump, hexload -true_create_connection = socket.create_connection -true_getaddrinfo = socket.getaddrinfo true_gethostbyname = socket.gethostbyname -true_gethostname = socket.gethostname -true_inet_pton = socket.inet_pton true_socket = socket.socket -true_socketpair = socket.socketpair -true_urllib3_match_hostname = urllib3.connection.match_hostname xxh32 = None @@ -84,10 +77,6 @@ def mock_socketpair(*args, **kwargs): return _socket.socketpair(*args, **kwargs) -def mock_urllib3_match_hostname(*args: Any) -> None: - return None - - def _hash_request(h, req): return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() diff --git a/mocket/ssl/context.py b/mocket/ssl/context.py index 438faa10..6d5e7307 100644 --- a/mocket/ssl/context.py +++ b/mocket/ssl/context.py @@ -1,31 +1,10 @@ from __future__ import annotations -import contextlib -import ssl from typing import Any -import urllib3.util.ssl_ - from mocket.socket import MocketSocket from mocket.ssl.socket import MocketSSLSocket -true_ssl_context = ssl.SSLContext - -true_ssl_wrap_socket = None -true_urllib3_ssl_wrap_socket = urllib3.util.ssl_.ssl_wrap_socket -true_urllib3_wrap_socket = None - -with contextlib.suppress(ImportError): - # from Py3.12 it's only under SSLContext - from ssl import wrap_socket as ssl_wrap_socket - - true_ssl_wrap_socket = ssl_wrap_socket - -with contextlib.suppress(ImportError): - from urllib3.util.ssl_ import wrap_socket as urllib3_wrap_socket - - true_urllib3_wrap_socket = urllib3_wrap_socket - class _MocketSSLContext: """For Python 3.6 and newer.""" @@ -70,30 +49,16 @@ def dummy_method(*args: Any, **kwargs: Any) -> Any: for m in self.DUMMY_METHODS: setattr(self, m, dummy_method) - @staticmethod - def wrap_socket(sock: MocketSocket, *args: Any, **kwargs: Any) -> MocketSSLSocket: - ssl_socket = MocketSSLSocket() - ssl_socket._original_socket = sock - - ssl_socket._true_socket = true_urllib3_ssl_wrap_socket( - sock._true_socket, - **kwargs, - ) - ssl_socket._kwargs = kwargs - - ssl_socket._timeout = sock._timeout - - ssl_socket._host = sock._host - ssl_socket._port = sock._port - ssl_socket._address = sock._address - - ssl_socket._io = sock._io - ssl_socket._entry = sock._entry - - return ssl_socket + def wrap_socket( + self, + sock: MocketSocket, + *args: Any, + **kwargs: Any, + ) -> MocketSSLSocket: + return MocketSSLSocket._create(sock, *args, **kwargs) - @staticmethod def wrap_bio( + self, incoming: Any, # _ssl.MemoryBIO outgoing: Any, # _ssl.MemoryBIO server_side: bool = False, @@ -102,3 +67,12 @@ def wrap_bio( ssl_obj = MocketSSLSocket() ssl_obj._host = server_hostname return ssl_obj + + +def mock_wrap_socket( + sock: MocketSocket, + *args: Any, + **kwargs: Any, +) -> MocketSSLSocket: + context = MocketSSLContext() + return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py index e50b7320..f7f41761 100644 --- a/mocket/ssl/socket.py +++ b/mocket/ssl/socket.py @@ -60,3 +60,35 @@ def compression(self) -> str | None: def unwrap(self) -> MocketSocket: return self._original_socket + + @classmethod + def _create( + cls, + sock: MocketSocket, + ssl_context: ssl.SSLContext | None = None, + server_hostname: str | None = None, + *args: Any, + **kwargs: Any, + ) -> MocketSSLSocket: + ssl_socket = MocketSSLSocket() + ssl_socket._original_socket = sock + ssl_socket._true_socket = sock._true_socket + + if ssl_context: + ssl_socket._true_socket = ssl_context.wrap_socket( + sock=ssl_socket._true_socket, + server_hostname=server_hostname, + ) + + ssl_socket._kwargs = kwargs + + ssl_socket._timeout = sock._timeout + + ssl_socket._host = sock._host + ssl_socket._port = sock._port + ssl_socket._address = sock._address + + ssl_socket._io = sock._io + ssl_socket._entry = sock._entry + + return ssl_socket diff --git a/mocket/urllib3.py b/mocket/urllib3.py new file mode 100644 index 00000000..e89bc7b5 --- /dev/null +++ b/mocket/urllib3.py @@ -0,0 +1,20 @@ +from __future__ import annotations + +from typing import Any + +from mocket.socket import MocketSocket +from mocket.ssl.context import MocketSSLContext +from mocket.ssl.socket import MocketSSLSocket + + +def mock_match_hostname(*args: Any) -> None: + return None + + +def mock_ssl_wrap_socket( + sock: MocketSocket, + *args: Any, + **kwargs: Any, +) -> MocketSSLSocket: + context = MocketSSLContext() + return context.wrap_socket(sock, *args, **kwargs) diff --git a/mocket/utils.py b/mocket/utils.py index b9e2c259..59403954 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -45,10 +45,10 @@ def get_mocketize(wrapper_: Callable) -> Callable: __all__ = ( - "MocketSocketCore", "MocketMode", + "MocketSocketCore", "SSL_PROTOCOL", + "get_mocketize", "hexdump", "hexload", - "get_mocketize", ) diff --git a/tests/test_mode.py b/tests/test_mode.py index 2a764949..ea5905b0 100644 --- a/tests/test_mode.py +++ b/tests/test_mode.py @@ -4,7 +4,7 @@ from mocket import Mocketizer, mocketize from mocket.exceptions import StrictMocketException from mocket.mockhttp import Entry, Response -from mocket.utils import MocketMode +from mocket.mode import MocketMode @mocketize(strict_mode=True) From e529319e1502b367843ab6364795c91aa096be2e Mon Sep 17 00:00:00 2001 From: betaboon Date: Tue, 26 Nov 2024 11:09:19 +0100 Subject: [PATCH 26/35] Refactor introduce recording storage (#274) * refactor: separate injection and enable/disable logic * refactor: add class that handles request records --- .github/workflows/main.yml | 2 +- mocket/inject.py | 18 +---- mocket/mocket.py | 46 ++++++++++-- mocket/recording.py | 147 +++++++++++++++++++++++++++++++++++++ mocket/socket.py | 134 ++++++++++----------------------- mocket/utils.py | 13 +--- 6 files changed, 230 insertions(+), 130 deletions(-) create mode 100644 mocket/recording.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index c4481efc..cdb55fe0 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -22,7 +22,7 @@ jobs: runs-on: ubuntu-20.04 strategy: matrix: - python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', ' 3.13', 'pypy3.10'] + python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.10'] steps: - uses: actions/checkout@v4 diff --git a/mocket/inject.py b/mocket/inject.py index 469ab30b..866ee563 100644 --- a/mocket/inject.py +++ b/mocket/inject.py @@ -1,7 +1,6 @@ from __future__ import annotations import contextlib -import os import socket import ssl from types import ModuleType @@ -23,10 +22,7 @@ def _restore(module: ModuleType, name: str) -> None: module.__dict__[name] = original_value -def enable( - namespace: str | None = None, - truesocket_recording_dir: str | None = None, -) -> None: +def enable() -> None: from mocket.socket import ( MocketSocket, mock_create_connection, @@ -73,14 +69,6 @@ def enable( extract_from_urllib3() - from mocket.mocket import Mocket - - Mocket._namespace = namespace - Mocket._truesocket_recording_dir = truesocket_recording_dir - if truesocket_recording_dir and not os.path.isdir(truesocket_recording_dir): - # JSON dumps will be saved here - raise AssertionError - def disable() -> None: for module, name in list(_patches_restore.keys()): @@ -90,7 +78,3 @@ def disable() -> None: from urllib3.contrib.pyopenssl import inject_into_urllib3 inject_into_urllib3() - - from mocket.mocket import Mocket - - Mocket.reset() diff --git a/mocket/mocket.py b/mocket/mocket.py index 3476902d..a01a7b46 100644 --- a/mocket/mocket.py +++ b/mocket/mocket.py @@ -3,9 +3,11 @@ import collections import itertools import os +from pathlib import Path from typing import TYPE_CHECKING, ClassVar import mocket.inject +from mocket.recording import MocketRecordStorage # NOTE this is here for backwards-compat to keep old import-paths working # from mocket.socket import MocketSocket as MocketSocket @@ -20,11 +22,36 @@ class Mocket: _address: ClassVar[Address] = (None, None) _entries: ClassVar[dict[Address, list[MocketEntry]]] = collections.defaultdict(list) _requests: ClassVar[list] = [] - _namespace: ClassVar[str] = str(id(_entries)) - _truesocket_recording_dir: ClassVar[str | None] = None + _record_storage: ClassVar[MocketRecordStorage | None] = None - enable = mocket.inject.enable - disable = mocket.inject.disable + @classmethod + def enable( + cls, + namespace: str | None = None, + truesocket_recording_dir: str | None = None, + ) -> None: + if namespace is None: + namespace = str(id(cls._entries)) + + if truesocket_recording_dir is not None: + recording_dir = Path(truesocket_recording_dir) + + if not recording_dir.is_dir(): + # JSON dumps will be saved here + raise AssertionError + + cls._record_storage = MocketRecordStorage( + directory=recording_dir, + namespace=namespace, + ) + + mocket.inject.enable() + + @classmethod + def disable(cls) -> None: + cls.reset() + + mocket.inject.disable() @classmethod def get_pair(cls, address: Address) -> tuple[int, int] | tuple[None, None]: @@ -69,6 +96,7 @@ def reset(cls) -> None: cls._socket_pairs = {} cls._entries = collections.defaultdict(list) cls._requests = [] + cls._record_storage = None @classmethod def last_request(cls): @@ -89,12 +117,16 @@ def has_requests(cls) -> bool: return bool(cls.request_list()) @classmethod - def get_namespace(cls) -> str: - return cls._namespace + def get_namespace(cls) -> str | None: + if not cls._record_storage: + return None + return cls._record_storage.namespace @classmethod def get_truesocket_recording_dir(cls) -> str | None: - return cls._truesocket_recording_dir + if not cls._record_storage: + return None + return str(cls._record_storage.directory) @classmethod def assert_fail_if_entries_not_served(cls) -> None: diff --git a/mocket/recording.py b/mocket/recording.py new file mode 100644 index 00000000..97d2adbe --- /dev/null +++ b/mocket/recording.py @@ -0,0 +1,147 @@ +from __future__ import annotations + +import contextlib +import hashlib +import json +from collections import defaultdict +from dataclasses import dataclass +from pathlib import Path + +from mocket.compat import decode_from_bytes, encode_to_bytes +from mocket.types import Address +from mocket.utils import hexdump, hexload + +hash_function = hashlib.md5 + +with contextlib.suppress(ImportError): + from xxhash_cffi import xxh32 as xxhash_cffi_xxh32 + + hash_function = xxhash_cffi_xxh32 + +with contextlib.suppress(ImportError): + from xxhash import xxh32 as xxhash_xxh32 + + hash_function = xxhash_xxh32 + + +def _hash_prepare_request(data: bytes) -> bytes: + _data = decode_from_bytes(data) + return encode_to_bytes("".join(sorted(_data.split("\r\n")))) + + +def _hash_request(data: bytes) -> str: + _data = _hash_prepare_request(data) + return hash_function(_data).hexdigest() + + +def _hash_request_fallback(data: bytes) -> str: + _data = _hash_prepare_request(data) + return hashlib.md5(_data).hexdigest() + + +@dataclass +class MocketRecord: + host: str + port: int + request: bytes + response: bytes + + +class MocketRecordStorage: + def __init__(self, directory: Path, namespace: str) -> None: + self._directory = directory + self._namespace = namespace + self._records: defaultdict[Address, defaultdict[str, MocketRecord]] = ( + defaultdict(defaultdict) + ) + + self._load() + + @property + def directory(self) -> Path: + return self._directory + + @property + def namespace(self) -> str: + return self._namespace + + @property + def file(self) -> Path: + return self._directory / f"{self._namespace}.json" + + def _load(self) -> None: + if not self.file.exists(): + return + + json_data = self.file.read_text() + records = json.loads(json_data) + for host, port_signature_record in records.items(): + for port, signature_record in port_signature_record.items(): + for signature, record in signature_record.items(): + # NOTE backward-compat + try: + request_data = hexload(record["request"]) + except ValueError: + request_data = record["request"] + + self._records[(host, int(port))][signature] = MocketRecord( + host=host, + port=port, + request=request_data, + response=hexload(record["response"]), + ) + + def _save(self) -> None: + data: dict[str, dict[str, dict[str, dict[str, str]]]] = defaultdict( + lambda: defaultdict(defaultdict) + ) + for address, signature_record in self._records.items(): + host, port = address + for signature, record in signature_record.items(): + data[host][str(port)][signature] = dict( + request=decode_from_bytes(record.request), + response=hexdump(record.response), + ) + + json_data = json.dumps(data, indent=4, sort_keys=True) + self.file.parent.mkdir(exist_ok=True) + self.file.write_text(json_data) + + def get_records(self, address: Address) -> list[MocketRecord]: + return list(self._records[address].values()) + + def get_record(self, address: Address, request: bytes) -> MocketRecord | None: + # NOTE for backward-compat + request_signature_fallback = _hash_request_fallback(request) + if request_signature_fallback in self._records[address]: + return self._records[address].get(request_signature_fallback) + + request_signature = _hash_request(request) + if request_signature in self._records[address]: + return self._records[address][request_signature] + + return None + + def put_record( + self, + address: Address, + request: bytes, + response: bytes, + ) -> None: + host, port = address + record = MocketRecord( + host=host, + port=port, + request=request, + response=response, + ) + + # NOTE for backward-compat + request_signature_fallback = _hash_request_fallback(request) + if request_signature_fallback in self._records[address]: + self._records[address][request_signature_fallback] = record + return + + request_signature = _hash_request(request) + self._records[address][request_signature] = record + self._save() diff --git a/mocket/socket.py b/mocket/socket.py index 9480d365..3b1862e2 100644 --- a/mocket/socket.py +++ b/mocket/socket.py @@ -2,18 +2,14 @@ import contextlib import errno -import hashlib -import json import os import select import socket -from json.decoder import JSONDecodeError from types import TracebackType from typing import Any, Type from typing_extensions import Self -from mocket.compat import decode_from_bytes, encode_to_bytes from mocket.entry import MocketEntry from mocket.io import MocketSocketIO from mocket.mocket import Mocket @@ -24,21 +20,11 @@ WriteableBuffer, _RetAddress, ) -from mocket.utils import hexdump, hexload true_gethostbyname = socket.gethostbyname true_socket = socket.socket -xxh32 = None -try: - from xxhash import xxh32 -except ImportError: # pragma: no cover - with contextlib.suppress(ImportError): - from xxhash_cffi import xxh32 -hasher = xxh32 or hashlib.md5 - - def mock_create_connection(address, timeout=None, source_address=None): s = socket.socket(socket.AF_INET, socket.SOCK_STREAM, socket.IPPROTO_TCP) if timeout: @@ -77,10 +63,6 @@ def mock_socketpair(*args, **kwargs): return _socket.socketpair(*args, **kwargs) -def _hash_request(h, req): - return h(encode_to_bytes("".join(sorted(req.split("\r\n"))))).hexdigest() - - class MocketSocket: def __init__( self, @@ -235,87 +217,47 @@ def recv(self, buffersize: int, flags: int | None = None) -> bytes: exc.args = (0,) raise exc - def true_sendall(self, data: ReadableBuffer, *args: Any, **kwargs: Any) -> int: - if not MocketMode().is_allowed((self._host, self._port)): + def true_sendall(self, data: bytes, *args: Any, **kwargs: Any) -> bytes: + if not MocketMode().is_allowed(self._address): MocketMode.raise_not_allowed() - req = decode_from_bytes(data) - # make request unique again - req_signature = _hash_request(hasher, req) - # port should be always a string - port = str(self._port) - - # prepare responses dictionary - responses = {} - - if Mocket.get_truesocket_recording_dir(): - path = os.path.join( - Mocket.get_truesocket_recording_dir(), - Mocket.get_namespace() + ".json", + # try to get the response from recordings + if Mocket._record_storage: + record = Mocket._record_storage.get_record( + address=self._address, + request=data, ) - # check if there's already a recorded session dumped to a JSON file - try: - with open(path) as f: - responses = json.load(f) - # if not, create a new dictionary - except (FileNotFoundError, JSONDecodeError): - pass - - try: - try: - response_dict = responses[self._host][port][req_signature] - except KeyError: - if hasher is not hashlib.md5: - # Fallback for backwards compatibility - req_signature = _hash_request(hashlib.md5, req) - response_dict = responses[self._host][port][req_signature] - else: - raise - except KeyError: - # preventing next KeyError exceptions - responses.setdefault(self._host, {}) - responses[self._host].setdefault(port, {}) - responses[self._host][port].setdefault(req_signature, {}) - response_dict = responses[self._host][port][req_signature] - - # try to get the response from the dictionary - try: - encoded_response = hexload(response_dict["response"]) - # if not available, call the real sendall - except KeyError: - host, port = self._host, self._port - host = true_gethostbyname(host) - - with contextlib.suppress(OSError, ValueError): - # already connected - self._true_socket.connect((host, port)) - self._true_socket.sendall(data, *args, **kwargs) - encoded_response = b"" - # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 - while True: - more_to_read = select.select([self._true_socket], [], [], 0.1)[0] - if not more_to_read and encoded_response: - break - new_content = self._true_socket.recv(self._buflen) - if not new_content: - break - encoded_response += new_content - - # dump the resulting dictionary to a JSON file - if Mocket.get_truesocket_recording_dir(): - # update the dictionary with request and response lines - response_dict["request"] = req - response_dict["response"] = hexdump(encoded_response) - - with open(path, mode="w") as f: - f.write( - decode_from_bytes( - json.dumps(responses, indent=4, sort_keys=True) - ) - ) - - # response back to .sendall() which writes it to the Mocket socket and flush the BytesIO - return encoded_response + if record is not None: + return record.response + + host, port = self._address + host = true_gethostbyname(host) + + with contextlib.suppress(OSError, ValueError): + # already connected + self._true_socket.connect((host, port)) + + self._true_socket.sendall(data, *args, **kwargs) + response = b"" + # https://github.com/kennethreitz/requests/blob/master/tests/testserver/server.py#L12 + while True: + more_to_read = select.select([self._true_socket], [], [], 0.1)[0] + if not more_to_read and response: + break + new_content = self._true_socket.recv(self._buflen) + if not new_content: + break + response += new_content + + # store request+response in recordings + if Mocket._record_storage: + Mocket._record_storage.put_record( + address=self._address, + request=data, + response=response, + ) + + return response def send( self, diff --git a/mocket/utils.py b/mocket/utils.py index 59403954..31557a58 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -6,12 +6,6 @@ from mocket.compat import decode_from_bytes, encode_to_bytes -# NOTE this is here for backwards-compat to keep old import-paths working -from mocket.io import MocketSocketIO as MocketSocketCore - -# NOTE this is here for backwards-compat to keep old import-paths working -from mocket.mode import MocketMode - SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 @@ -30,7 +24,10 @@ def hexload(string: str) -> bytes: True """ string_no_spaces = "".join(string.split()) - return encode_to_bytes(binascii.unhexlify(string_no_spaces)) + try: + return encode_to_bytes(binascii.unhexlify(string_no_spaces)) + except binascii.Error as e: + raise ValueError from e def get_mocketize(wrapper_: Callable) -> Callable: @@ -45,8 +42,6 @@ def get_mocketize(wrapper_: Callable) -> Callable: __all__ = ( - "MocketMode", - "MocketSocketCore", "SSL_PROTOCOL", "get_mocketize", "hexdump", From 9cfad4cdb4fcec787114f0a9cb36afac274722fd Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Tue, 26 Nov 2024 12:20:12 +0100 Subject: [PATCH 27/35] Small cleanup (#275) * Small cleanup. --- mocket/ssl/socket.py | 5 +++-- mocket/utils.py | 4 ---- 2 files changed, 3 insertions(+), 6 deletions(-) diff --git a/mocket/ssl/socket.py b/mocket/ssl/socket.py index f7f41761..6dcd7817 100644 --- a/mocket/ssl/socket.py +++ b/mocket/ssl/socket.py @@ -2,6 +2,7 @@ import ssl from datetime import datetime, timedelta +from ssl import Options from typing import Any from mocket.compat import encode_to_bytes @@ -53,9 +54,9 @@ def getpeercert(self, binary_form: bool = False) -> _PeerCertRetDictType: } def ciper(self) -> tuple[str, str, str]: - return ("ADH", "AES256", "SHA") + return "ADH", "AES256", "SHA" - def compression(self) -> str | None: + def compression(self) -> Options: return ssl.OP_NO_COMPRESSION def unwrap(self) -> MocketSocket: diff --git a/mocket/utils.py b/mocket/utils.py index 31557a58..ab293776 100644 --- a/mocket/utils.py +++ b/mocket/utils.py @@ -1,13 +1,10 @@ from __future__ import annotations import binascii -import ssl from typing import Callable from mocket.compat import decode_from_bytes, encode_to_bytes -SSL_PROTOCOL = ssl.PROTOCOL_TLSv1_2 - def hexdump(binary_string: bytes) -> str: r""" @@ -42,7 +39,6 @@ def get_mocketize(wrapper_: Callable) -> Callable: __all__ = ( - "SSL_PROTOCOL", "get_mocketize", "hexdump", "hexload", From 895e299f174710bf1ea054ace4421e1dbd0aa3c1 Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Sat, 21 Dec 2024 19:21:30 +0100 Subject: [PATCH 28/35] Target `make safetest` got broken (#273) * Fix `Makefile`. * Only running tests that don't use real sockets got broken, fixed by adding CIENT SETINFO commands --- Makefile | 4 ++-- tests/test_mocket.py | 1 + tests/test_redis.py | 4 +++- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 7ba0210e..3ab33d87 100644 --- a/Makefile +++ b/Makefile @@ -31,12 +31,12 @@ types: test: types @echo "Running Python tests" uv pip uninstall pook || true - export VIRTUAL_ENV=.venv; .venv/bin/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- .venv/bin/pytest + .venv/bin/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- .venv/bin/pytest uv pip install pook && .venv/bin/pytest tests/test_pook.py && uv pip uninstall pook @echo "" safetest: - export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; make test + export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; .venv/bin/pytest publish: clean install-test-requirements uv run python3 -m build --sdist --wheel . diff --git a/tests/test_mocket.py b/tests/test_mocket.py index 8d09f170..8810a5b9 100644 --- a/tests/test_mocket.py +++ b/tests/test_mocket.py @@ -222,6 +222,7 @@ def test_patch( @pytest.mark.skipif(not psutil.POSIX, reason="Uses a POSIX-only API to test") +@pytest.mark.skipif('os.getenv("SKIP_TRUE_HTTP", False)') @pytest.mark.asyncio async def test_no_dangling_fds(): url = "http://httpbin.local/ip" diff --git a/tests/test_redis.py b/tests/test_redis.py index 50b9beac..fb6ec355 100644 --- a/tests/test_redis.py +++ b/tests/test_redis.py @@ -158,9 +158,11 @@ def setUp(self): self.rclient = redis.StrictRedis() def mocketize_setup(self): + Entry.register_response("CLIENT SETINFO LIB-NAME redis-py", OK) + Entry.register_response(f"CLIENT SETINFO LIB-VER {redis.__version__}", OK) Entry.register_response("FLUSHDB", OK) self.rclient.flushdb() - self.assertEqual(len(Mocket.request_list()), 1) + self.assertEqual(len(Mocket.request_list()), 3) Mocket.reset() @mocketize From 87f1dec8d513e6c47f842e3efd4802cfb2f6e4e2 Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Sat, 21 Dec 2024 23:56:29 +0100 Subject: [PATCH 29/35] Small refactor for `Makefile`. --- Makefile | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/Makefile b/Makefile index 3ab33d87..25af4e93 100644 --- a/Makefile +++ b/Makefile @@ -1,5 +1,7 @@ #!/usr/bin/make -f +VENV_PATH = .venv/bin + install-dev-requirements: curl -LsSf https://astral.sh/uv/install.sh | sh uv venv && uv pip install hatch @@ -25,18 +27,18 @@ develop: install-dev-requirements install-test-requirements types: @echo "Type checking Python files" - .venv/bin/mypy --pretty + $(VENV_PATH)/mypy --pretty @echo "" test: types @echo "Running Python tests" uv pip uninstall pook || true - .venv/bin/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- .venv/bin/pytest - uv pip install pook && .venv/bin/pytest tests/test_pook.py && uv pip uninstall pook + $(VENV_PATH)/wait-for-it --service httpbin.local:443 --service localhost:6379 --timeout 5 -- $(VENV_PATH)/pytest + uv pip install pook && $(VENV_PATH)/pytest tests/test_pook.py && uv pip uninstall pook @echo "" safetest: - export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; .venv/bin/pytest + export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; $(VENV_PATH)/pytest publish: clean install-test-requirements uv run python3 -m build --sdist --wheel . From 815a20ff1f7a0d5dae771d81f7856b0fdf9ff7df Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Tue, 24 Dec 2024 10:03:27 +0100 Subject: [PATCH 30/35] Releasing `beta` version after all the refactors. --- mocket/__init__.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mocket/__init__.py b/mocket/__init__.py index 53064434..faac03e3 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -17,4 +17,4 @@ "FakeSSLContext", ) -__version__ = "3.13.2" +__version__ = "3.13.3b1" From 9461bd035e2fa10b3d9732a81057bff419ddf6a4 Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Tue, 24 Dec 2024 12:08:00 +0100 Subject: [PATCH 31/35] Switching to build and publish through `uv`. --- Makefile | 4 ++-- pyproject.toml | 1 - 2 files changed, 2 insertions(+), 3 deletions(-) diff --git a/Makefile b/Makefile index 25af4e93..07d5d459 100644 --- a/Makefile +++ b/Makefile @@ -41,8 +41,8 @@ safetest: export SKIP_TRUE_REDIS=1; export SKIP_TRUE_HTTP=1; $(VENV_PATH)/pytest publish: clean install-test-requirements - uv run python3 -m build --sdist --wheel . - uv run twine upload --repository mocket dist/ + uv build --package mocket --sdist --wheel + uv publish clean: rm -rf .coverage *.egg-info dist/ requirements.txt uv.lock || true diff --git a/pyproject.toml b/pyproject.toml index 77d1f5d4..5418304e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -51,7 +51,6 @@ test = [ "httpx", "pipfile", "build", - "twine", "fastapi", "aiohttp", "wait-for-it", From db33579c850a1bc01de2ca7540fab227c08b3926 Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Thu, 13 Feb 2025 20:27:45 +0100 Subject: [PATCH 32/35] Update README.rst --- README.rst | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/README.rst b/README.rst index e68cbfd3..82bf6c37 100644 --- a/README.rst +++ b/README.rst @@ -24,9 +24,8 @@ A socket mock framework Outside GitHub ============== -Mocket packages are available for `Arch Linux`_, `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, and of course from `PyPI`_. +Mocket packages are available for `openSUSE`_, `NixOS`_, `ALT Linux`_, `NetBSD`_, and of course from `PyPI`_. -.. _`Arch Linux`: https://archlinux.org/packages/extra/any/python-mocket/ .. _`openSUSE`: https://software.opensuse.org/search?baseproject=ALL&q=mocket .. _`NixOS`: https://search.nixos.org/packages?query=mocket .. _`ALT Linux`: https://packages.altlinux.org/en/sisyphus/srpms/python3-module-mocket/ From e847dd6c7c52b7f4b05e2b51b3c9afe9494b6a44 Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Thu, 6 Mar 2025 12:27:35 +0100 Subject: [PATCH 33/35] Bump Ubuntu version to latest LTS --- .github/workflows/main.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index cdb55fe0..bb976ac9 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -19,7 +19,7 @@ concurrency: jobs: build: - runs-on: ubuntu-20.04 + runs-on: ubuntu-24.04 strategy: matrix: python-version: ['3.8', '3.9', '3.10', '3.11', '3.12', '3.13', 'pypy3.10'] From 71e47163c9390db0d74e24f32e4520fc82c91bb7 Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Sat, 22 Mar 2025 08:46:46 +0100 Subject: [PATCH 34/35] Moving ready-made mocks under `mocket.mocks` and decorators under `mocket.decorators`. (#278) --- mocket/__init__.py | 17 +++++++++++++++-- mocket/decorators/__init__.py | 0 mocket/{ => decorators}/async_mocket.py | 2 +- mocket/{ => decorators}/mocketizer.py | 0 mocket/mocks/__init__.py | 0 mocket/{ => mocks}/mockhttp.py | 0 mocket/{ => mocks}/mockredis.py | 0 7 files changed, 16 insertions(+), 3 deletions(-) create mode 100644 mocket/decorators/__init__.py rename mocket/{ => decorators}/async_mocket.py (89%) rename mocket/{ => decorators}/mocketizer.py (100%) create mode 100644 mocket/mocks/__init__.py rename mocket/{ => mocks}/mockhttp.py (100%) rename mocket/{ => mocks}/mockredis.py (100%) diff --git a/mocket/__init__.py b/mocket/__init__.py index faac03e3..2279bf19 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -1,12 +1,25 @@ -from mocket.async_mocket import async_mocketize +import importlib +import sys + +from mocket.decorators.async_mocket import async_mocketize +from mocket.decorators.mocketizer import Mocketizer, mocketize from mocket.entry import MocketEntry from mocket.mocket import Mocket -from mocket.mocketizer import Mocketizer, mocketize from mocket.ssl.context import MocketSSLContext # NOTE this is here for backwards-compat to keep old import-paths working from mocket.ssl.context import MocketSSLContext as FakeSSLContext +sys.modules["mocket.mockhttp"] = importlib.import_module("mocket.mocks.mockhttp") +sys.modules["mocket.mockredis"] = importlib.import_module("mocket.mocks.mockredis") +sys.modules["mocket.async_mocket"] = importlib.import_module( + "mocket.decorators.async_mocket" +) +sys.modules["mocket.mocketizer"] = importlib.import_module( + "mocket.decorators.mocketizer" +) + + __all__ = ( "async_mocketize", "mocketize", diff --git a/mocket/decorators/__init__.py b/mocket/decorators/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/async_mocket.py b/mocket/decorators/async_mocket.py similarity index 89% rename from mocket/async_mocket.py rename to mocket/decorators/async_mocket.py index 709d225f..40b763ae 100644 --- a/mocket/async_mocket.py +++ b/mocket/decorators/async_mocket.py @@ -1,4 +1,4 @@ -from mocket.mocketizer import Mocketizer +from mocket.decorators.mocketizer import Mocketizer from mocket.utils import get_mocketize diff --git a/mocket/mocketizer.py b/mocket/decorators/mocketizer.py similarity index 100% rename from mocket/mocketizer.py rename to mocket/decorators/mocketizer.py diff --git a/mocket/mocks/__init__.py b/mocket/mocks/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/mocket/mockhttp.py b/mocket/mocks/mockhttp.py similarity index 100% rename from mocket/mockhttp.py rename to mocket/mocks/mockhttp.py diff --git a/mocket/mockredis.py b/mocket/mocks/mockredis.py similarity index 100% rename from mocket/mockredis.py rename to mocket/mocks/mockredis.py From a0743f4497aff8b223e04d1ae8a7cb9293d0165e Mon Sep 17 00:00:00 2001 From: Giorgio Salluzzo Date: Sat, 22 Mar 2025 08:56:56 +0100 Subject: [PATCH 35/35] Bump version. --- mocket/__init__.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/mocket/__init__.py b/mocket/__init__.py index 2279bf19..c785bba5 100644 --- a/mocket/__init__.py +++ b/mocket/__init__.py @@ -7,7 +7,8 @@ from mocket.mocket import Mocket from mocket.ssl.context import MocketSSLContext -# NOTE this is here for backwards-compat to keep old import-paths working +# NOTE the following lines are here for backwards-compatibility, +# to keep old import-paths working from mocket.ssl.context import MocketSSLContext as FakeSSLContext sys.modules["mocket.mockhttp"] = importlib.import_module("mocket.mocks.mockhttp") @@ -30,4 +31,4 @@ "FakeSSLContext", ) -__version__ = "3.13.3b1" +__version__ = "3.13.3" pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy