Merge pull request #955 from jmolz/fix-socket-client-bare-excepts

fix: avoid socket client bare excepts
This commit is contained in:
Jacob Molz 2026-05-27 08:16:33 -04:00 committed by GitHub
parent 00dd7a4e14
commit eb24d0c60e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 83 additions and 5 deletions

View file

@ -8,6 +8,7 @@ import pytest
from unittest.mock import Mock, patch, MagicMock, call from unittest.mock import Mock, patch, MagicMock, call
import json import json
from trustgraph.api.socket_client import SocketClient
from trustgraph.api import ( from trustgraph.api import (
Api, Api,
Triple, Triple,
@ -222,6 +223,82 @@ class TestSocketClient:
for method in expected_methods: for method in expected_methods:
assert hasattr(flow_instance, method), f"Missing method: {method}" assert hasattr(flow_instance, method), f"Missing method: {method}"
def test_socket_client_close_does_not_swallow_base_exceptions(self):
"""Test close cleanup does not suppress process-level interrupts."""
class InterruptingLoop:
def is_closed(self):
return False
def run_until_complete(self, awaitable):
if hasattr(awaitable, "close"):
awaitable.close()
raise SystemExit("stop")
socket = SocketClient(url="http://test/", timeout=60, token=None)
socket._loop = InterruptingLoop()
with pytest.raises(SystemExit):
socket.close()
@pytest.mark.parametrize(
("generator_method", "async_method"),
[
("_streaming_generator", "_send_request_async_streaming"),
("_streaming_generator_raw", "_send_request_async_streaming_raw"),
],
)
def test_socket_client_streaming_cleanup_does_not_swallow_base_exceptions(
self, generator_method, async_method
):
"""Test streaming cleanup does not suppress process-level interrupts."""
class FakeAsyncGenerator:
def __anext__(self):
return "next"
def aclose(self):
return "close"
class InterruptingLoop:
def run_until_complete(self, awaitable):
if awaitable == "next":
raise StopAsyncIteration
if awaitable == "close":
raise SystemExit("stop")
raise AssertionError(f"unexpected awaitable: {awaitable!r}")
socket = SocketClient(url="http://test/", timeout=60, token=None)
setattr(socket, async_method, lambda *args, **kwargs: FakeAsyncGenerator())
generator = getattr(socket, generator_method)(
"agent", "default", {}, InterruptingLoop()
)
with pytest.raises(SystemExit):
next(generator)
@pytest.mark.asyncio
async def test_socket_client_reader_does_not_swallow_base_exceptions(self):
"""Test reader error fanout does not suppress process-level interrupts."""
class FailingSocket:
def __aiter__(self):
return self
async def __anext__(self):
raise ValueError("reader failed")
class InterruptingQueue:
async def put(self, message):
raise SystemExit("stop")
socket = SocketClient(url="http://test/", timeout=60, token=None)
socket._socket = FailingSocket()
socket._pending = {"req-1": InterruptingQueue()}
with pytest.raises(SystemExit):
await socket._reader()
class TestBulkClient: class TestBulkClient:
"""Test bulk operations client""" """Test bulk operations client"""

View file

@ -11,6 +11,7 @@ multiplexes requests by ID.
import json import json
import asyncio import asyncio
import websockets import websockets
from websockets.exceptions import ConnectionClosed
from typing import Optional, Dict, Any, Iterator, Union, List from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock from threading import Lock
@ -191,13 +192,13 @@ class SocketClient:
if request_id and request_id in self._pending: if request_id and request_id in self._pending:
await self._pending[request_id].put(response) await self._pending[request_id].put(response)
except websockets.exceptions.ConnectionClosed: except ConnectionClosed:
pass pass
except Exception as e: except Exception as e:
for queue in self._pending.values(): for queue in self._pending.values():
try: try:
await queue.put({"error": str(e)}) await queue.put({"error": str(e)})
except: except Exception:
pass pass
finally: finally:
self._connected = False self._connected = False
@ -250,7 +251,7 @@ class SocketClient:
finally: finally:
try: try:
loop.run_until_complete(async_gen.aclose()) loop.run_until_complete(async_gen.aclose())
except: except Exception:
pass pass
def _streaming_generator_raw( def _streaming_generator_raw(
@ -273,7 +274,7 @@ class SocketClient:
finally: finally:
try: try:
loop.run_until_complete(async_gen.aclose()) loop.run_until_complete(async_gen.aclose())
except: except Exception:
pass pass
async def _send_request_async_streaming_raw( async def _send_request_async_streaming_raw(
@ -542,7 +543,7 @@ class SocketClient:
if self._loop and not self._loop.is_closed(): if self._loop and not self._loop.is_closed():
try: try:
self._loop.run_until_complete(self._close_async()) self._loop.run_until_complete(self._close_async())
except: except Exception:
pass pass
async def _close_async(self): async def _close_async(self):