Skip to content

Commit 38bc13d

Browse files
blink1073 and NoahStapp
authored
PYTHON-5212 [v4.12] Do not hold Topology lock while resetting pool (#2307)
Co-authored-by: Noah Stapp <noah.stapp@mongodb.com>
1 parent c6671e2 commit 38bc13d

File tree

10 files changed

+236
-25
lines changed

10 files changed

+236
-25
lines changed

doc/changelog.rst

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ Version 4.12.1 is a bug fix release.
99
- Fixed a bug that could raise ``UnboundLocalError`` when creating asynchronous connections over SSL.
1010
- Fixed a bug causing SRV hostname validation to fail when resolver and resolved hostnames are identical with three domain levels.
1111
- Fixed a bug that caused direct use of ``pymongo.uri_parser`` to raise an ``AttributeError``.
12+
- Fixed a bug where clients created with connect=False and a "mongodb+srv://" connection string
13+
could cause public ``pymongo.MongoClient`` and ``pymongo.AsyncMongoClient`` attributes (topology_description,
14+
nodes, address, primary, secondaries, arbiters) to incorrectly return a Database, leading to type
15+
errors such as: "NotImplementedError: Database objects do not implement truth value testing or bool()".
16+
- Fixed a bug where MongoDB cluster topology changes could cause asynchronous operations to take much longer to complete
17+
due to holding the Topology lock while closing stale connections.
1218

1319
Issues Resolved
1420
...............

pymongo/asynchronous/pool.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import collections
1819
import contextlib
1920
import logging
@@ -860,8 +861,14 @@ async def _reset(
860861
# PoolClosedEvent but that reset() SHOULD close sockets *after*
861862
# publishing the PoolClearedEvent.
862863
if close:
863-
for conn in sockets:
864-
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
864+
if not _IS_SYNC:
865+
await asyncio.gather(
866+
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
867+
return_exceptions=True,
868+
)
869+
else:
870+
for conn in sockets:
871+
await conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
865872
if self.enabled_for_cmap:
866873
assert listeners is not None
867874
listeners.publish_pool_closed(self.address)
@@ -891,8 +898,14 @@ async def _reset(
891898
serverPort=self.address[1],
892899
serviceId=service_id,
893900
)
894-
for conn in sockets:
895-
await conn.close_conn(ConnectionClosedReason.STALE)
901+
if not _IS_SYNC:
902+
await asyncio.gather(
903+
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
904+
return_exceptions=True,
905+
)
906+
else:
907+
for conn in sockets:
908+
await conn.close_conn(ConnectionClosedReason.STALE)
896909

897910
async def update_is_writable(self, is_writable: Optional[bool]) -> None:
898911
"""Updates the is_writable attribute on all sockets currently in the
@@ -938,8 +951,14 @@ async def remove_stale_sockets(self, reference_generation: int) -> None:
938951
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
939952
):
940953
close_conns.append(self.conns.pop())
941-
for conn in close_conns:
942-
await conn.close_conn(ConnectionClosedReason.IDLE)
954+
if not _IS_SYNC:
955+
await asyncio.gather(
956+
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
957+
return_exceptions=True,
958+
)
959+
else:
960+
for conn in close_conns:
961+
await conn.close_conn(ConnectionClosedReason.IDLE)
943962

944963
while True:
945964
async with self.size_cond:

pymongo/asynchronous/topology.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -529,12 +529,6 @@ async def _process_change(
529529
if not _IS_SYNC:
530530
self._monitor_tasks.append(self._srv_monitor)
531531

532-
# Clear the pool from a failed heartbeat.
533-
if reset_pool:
534-
server = self._servers.get(server_description.address)
535-
if server:
536-
await server.pool.reset(interrupt_connections=interrupt_connections)
537-
538532
# Wake anything waiting in select_servers().
539533
self._condition.notify_all()
540534

@@ -557,6 +551,11 @@ async def on_change(
557551
# that didn't include this server.
558552
if self._opened and self._description.has_server(server_description.address):
559553
await self._process_change(server_description, reset_pool, interrupt_connections)
554+
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
555+
if reset_pool:
556+
server = self._servers.get(server_description.address)
557+
if server:
558+
await server.pool.reset(interrupt_connections=interrupt_connections)
560559

561560
async def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
562561
"""Process a new seedlist on an opened topology.

pymongo/synchronous/pool.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
from __future__ import annotations
1616

17+
import asyncio
1718
import collections
1819
import contextlib
1920
import logging
@@ -858,8 +859,14 @@ def _reset(
858859
# PoolClosedEvent but that reset() SHOULD close sockets *after*
859860
# publishing the PoolClearedEvent.
860861
if close:
861-
for conn in sockets:
862-
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
862+
if not _IS_SYNC:
863+
asyncio.gather(
864+
*[conn.close_conn(ConnectionClosedReason.POOL_CLOSED) for conn in sockets],
865+
return_exceptions=True,
866+
)
867+
else:
868+
for conn in sockets:
869+
conn.close_conn(ConnectionClosedReason.POOL_CLOSED)
863870
if self.enabled_for_cmap:
864871
assert listeners is not None
865872
listeners.publish_pool_closed(self.address)
@@ -889,8 +896,14 @@ def _reset(
889896
serverPort=self.address[1],
890897
serviceId=service_id,
891898
)
892-
for conn in sockets:
893-
conn.close_conn(ConnectionClosedReason.STALE)
899+
if not _IS_SYNC:
900+
asyncio.gather(
901+
*[conn.close_conn(ConnectionClosedReason.STALE) for conn in sockets],
902+
return_exceptions=True,
903+
)
904+
else:
905+
for conn in sockets:
906+
conn.close_conn(ConnectionClosedReason.STALE)
894907

895908
def update_is_writable(self, is_writable: Optional[bool]) -> None:
896909
"""Updates the is_writable attribute on all sockets currently in the
@@ -934,8 +947,14 @@ def remove_stale_sockets(self, reference_generation: int) -> None:
934947
and self.conns[-1].idle_time_seconds() > self.opts.max_idle_time_seconds
935948
):
936949
close_conns.append(self.conns.pop())
937-
for conn in close_conns:
938-
conn.close_conn(ConnectionClosedReason.IDLE)
950+
if not _IS_SYNC:
951+
asyncio.gather(
952+
*[conn.close_conn(ConnectionClosedReason.IDLE) for conn in close_conns],
953+
return_exceptions=True,
954+
)
955+
else:
956+
for conn in close_conns:
957+
conn.close_conn(ConnectionClosedReason.IDLE)
939958

940959
while True:
941960
with self.size_cond:

pymongo/synchronous/topology.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -529,12 +529,6 @@ def _process_change(
529529
if not _IS_SYNC:
530530
self._monitor_tasks.append(self._srv_monitor)
531531

532-
# Clear the pool from a failed heartbeat.
533-
if reset_pool:
534-
server = self._servers.get(server_description.address)
535-
if server:
536-
server.pool.reset(interrupt_connections=interrupt_connections)
537-
538532
# Wake anything waiting in select_servers().
539533
self._condition.notify_all()
540534

@@ -557,6 +551,11 @@ def on_change(
557551
# that didn't include this server.
558552
if self._opened and self._description.has_server(server_description.address):
559553
self._process_change(server_description, reset_pool, interrupt_connections)
554+
# Clear the pool from a failed heartbeat, done outside the lock to avoid blocking on connection close.
555+
if reset_pool:
556+
server = self._servers.get(server_description.address)
557+
if server:
558+
server.pool.reset(interrupt_connections=interrupt_connections)
560559

561560
def _process_srv_update(self, seedlist: list[tuple[str, Any]]) -> None:
562561
"""Process a new seedlist on an opened topology.

test/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -823,6 +823,14 @@ def require_sync(self, func):
823823
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
824824
)
825825

826+
def require_async(self, func):
827+
"""Run a test only if using the asynchronous API.""" # unasync: off
828+
return self._require(
829+
lambda: not _IS_SYNC,
830+
"This test only works with the asynchronous API", # unasync: off
831+
func=func,
832+
)
833+
826834
def mongos_seeds(self):
827835
return ",".join("{}:{}".format(*address) for address in self.mongoses)
828836

test/asynchronous/__init__.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -825,6 +825,14 @@ def require_sync(self, func):
825825
lambda: _IS_SYNC, "This test only works with the synchronous API", func=func
826826
)
827827

828+
def require_async(self, func):
829+
"""Run a test only if using the asynchronous API.""" # unasync: off
830+
return self._require(
831+
lambda: not _IS_SYNC,
832+
"This test only works with the asynchronous API", # unasync: off
833+
func=func,
834+
)
835+
828836
def mongos_seeds(self):
829837
return ",".join("{}:{}".format(*address) for address in self.mongoses)
830838

test/asynchronous/test_discovery_and_monitoring.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,15 @@
2020
import socketserver
2121
import sys
2222
import threading
23+
import time
2324
from asyncio import StreamReader, StreamWriter
2425
from pathlib import Path
2526
from test.asynchronous.helpers import ConcurrentRunner
2627

28+
from pymongo.asynchronous.pool import AsyncConnection
29+
from pymongo.operations import _Op
30+
from pymongo.server_selectors import writable_server_selector
31+
2732
sys.path[0:0] = [""]
2833

2934
from test.asynchronous import (
@@ -370,6 +375,74 @@ async def test_pool_unpause(self):
370375
await listener.async_wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
371376
await listener.async_wait_for_event(monitoring.PoolReadyEvent, 1)
372377

378+
@async_client_context.require_failCommand_appName
379+
@async_client_context.require_test_commands
380+
@async_client_context.require_async
381+
async def test_connection_close_does_not_block_other_operations(self):
382+
listener = CMAPHeartbeatListener()
383+
client = await self.async_single_client(
384+
appName="SDAMConnectionCloseTest",
385+
event_listeners=[listener],
386+
heartbeatFrequencyMS=500,
387+
minPoolSize=10,
388+
)
389+
server = await (await client._get_topology()).select_server(
390+
writable_server_selector, _Op.TEST
391+
)
392+
await async_wait_until(
393+
lambda: len(server._pool.conns) == 10,
394+
"pool initialized with 10 connections",
395+
)
396+
397+
await client.db.test.insert_one({"x": 1})
398+
close_delay = 0.1
399+
latencies = []
400+
should_exit = []
401+
402+
async def run_task():
403+
while True:
404+
start_time = time.monotonic()
405+
await client.db.test.find_one({})
406+
elapsed = time.monotonic() - start_time
407+
latencies.append(elapsed)
408+
if should_exit:
409+
break
410+
await asyncio.sleep(0.001)
411+
412+
task = ConcurrentRunner(target=run_task)
413+
await task.start()
414+
original_close = AsyncConnection.close_conn
415+
try:
416+
# Artificially delay the close operation to simulate a slow close
417+
async def mock_close(self, reason):
418+
await asyncio.sleep(close_delay)
419+
await original_close(self, reason)
420+
421+
AsyncConnection.close_conn = mock_close
422+
423+
fail_hello = {
424+
"mode": {"times": 4},
425+
"data": {
426+
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
427+
"errorCode": 91,
428+
"appName": "SDAMConnectionCloseTest",
429+
},
430+
}
431+
async with self.fail_point(fail_hello):
432+
# Wait for server heartbeat to fail
433+
await listener.async_wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
434+
# Wait until all idle connections are closed to simulate real-world conditions
435+
await listener.async_wait_for_event(monitoring.ConnectionClosedEvent, 10)
436+
# Wait for one more find to complete after the pool has been reset, then shutdown the task
437+
n = len(latencies)
438+
await async_wait_until(lambda: len(latencies) >= n + 1, "run one more find")
439+
should_exit.append(True)
440+
await task.join()
441+
# No operation latency should significantly exceed close_delay
442+
self.assertLessEqual(max(latencies), close_delay * 5.0)
443+
finally:
444+
AsyncConnection.close_conn = original_close
445+
373446

374447
class TestServerMonitoringMode(AsyncIntegrationTest):
375448
@async_client_context.require_no_serverless

test/test_discovery_and_monitoring.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,10 +20,15 @@
2020
import socketserver
2121
import sys
2222
import threading
23+
import time
2324
from asyncio import StreamReader, StreamWriter
2425
from pathlib import Path
2526
from test.helpers import ConcurrentRunner
2627

28+
from pymongo.operations import _Op
29+
from pymongo.server_selectors import writable_server_selector
30+
from pymongo.synchronous.pool import Connection
31+
2732
sys.path[0:0] = [""]
2833

2934
from test import (
@@ -370,6 +375,72 @@ def test_pool_unpause(self):
370375
listener.wait_for_event(monitoring.ServerHeartbeatSucceededEvent, 1)
371376
listener.wait_for_event(monitoring.PoolReadyEvent, 1)
372377

378+
@client_context.require_failCommand_appName
379+
@client_context.require_test_commands
380+
@client_context.require_async
381+
def test_connection_close_does_not_block_other_operations(self):
382+
listener = CMAPHeartbeatListener()
383+
client = self.single_client(
384+
appName="SDAMConnectionCloseTest",
385+
event_listeners=[listener],
386+
heartbeatFrequencyMS=500,
387+
minPoolSize=10,
388+
)
389+
server = (client._get_topology()).select_server(writable_server_selector, _Op.TEST)
390+
wait_until(
391+
lambda: len(server._pool.conns) == 10,
392+
"pool initialized with 10 connections",
393+
)
394+
395+
client.db.test.insert_one({"x": 1})
396+
close_delay = 0.1
397+
latencies = []
398+
should_exit = []
399+
400+
def run_task():
401+
while True:
402+
start_time = time.monotonic()
403+
client.db.test.find_one({})
404+
elapsed = time.monotonic() - start_time
405+
latencies.append(elapsed)
406+
if should_exit:
407+
break
408+
time.sleep(0.001)
409+
410+
task = ConcurrentRunner(target=run_task)
411+
task.start()
412+
original_close = Connection.close_conn
413+
try:
414+
# Artificially delay the close operation to simulate a slow close
415+
def mock_close(self, reason):
416+
time.sleep(close_delay)
417+
original_close(self, reason)
418+
419+
Connection.close_conn = mock_close
420+
421+
fail_hello = {
422+
"mode": {"times": 4},
423+
"data": {
424+
"failCommands": [HelloCompat.LEGACY_CMD, "hello"],
425+
"errorCode": 91,
426+
"appName": "SDAMConnectionCloseTest",
427+
},
428+
}
429+
with self.fail_point(fail_hello):
430+
# Wait for server heartbeat to fail
431+
listener.wait_for_event(monitoring.ServerHeartbeatFailedEvent, 1)
432+
# Wait until all idle connections are closed to simulate real-world conditions
433+
listener.wait_for_event(monitoring.ConnectionClosedEvent, 10)
434+
# Wait for one more find to complete after the pool has been reset, then shutdown the task
435+
n = len(latencies)
436+
wait_until(lambda: len(latencies) >= n + 1, "run one more find")
437+
should_exit.append(True)
438+
task.join()
439+
# No operation latency should significantly exceed close_delay
440+
self.assertLessEqual(max(latencies), close_delay * 5.0)
441+
finally:
442+
Connection.close_conn = original_close
443+
373444

374445
class TestServerMonitoringMode(IntegrationTest):
375446
@client_context.require_no_serverless

0 commit comments

Comments
 (0)
pFad - Phonifier reborn

Pfad - The Proxy pFad of © 2024 Garber Painting. All rights reserved.

Note: This service is not intended for secure transactions such as banking, social media, email, or purchasing. Use at your own risk. We assume no liability whatsoever for broken pages.


Alternative Proxies:

Alternative Proxy

pFad Proxy

pFad v3 Proxy

pFad v4 Proxy