From eb24d0c60e63aee14eebaafbe82a7c66eb100f64 Mon Sep 17 00:00:00 2001 From: Jacob Molz Date: Wed, 27 May 2026 08:16:33 -0400 Subject: [PATCH] Merge pull request #955 from jmolz/fix-socket-client-bare-excepts fix: avoid socket client bare excepts --- tests/unit/test_python_api_client.py | 77 +++++++++++++++++++ .../trustgraph/api/socket_client.py | 11 +-- 2 files changed, 83 insertions(+), 5 deletions(-) diff --git a/tests/unit/test_python_api_client.py b/tests/unit/test_python_api_client.py index 0b6709fb..1fea0ee6 100644 --- a/tests/unit/test_python_api_client.py +++ b/tests/unit/test_python_api_client.py @@ -8,6 +8,7 @@ import pytest from unittest.mock import Mock, patch, MagicMock, call import json +from trustgraph.api.socket_client import SocketClient from trustgraph.api import ( Api, Triple, @@ -222,6 +223,82 @@ class TestSocketClient: for method in expected_methods: assert hasattr(flow_instance, method), f"Missing method: {method}" + def test_socket_client_close_does_not_swallow_base_exceptions(self): + """Test close cleanup does not suppress process-level interrupts.""" + + class InterruptingLoop: + def is_closed(self): + return False + + def run_until_complete(self, awaitable): + if hasattr(awaitable, "close"): + awaitable.close() + raise SystemExit("stop") + + socket = SocketClient(url="http://test/", timeout=60, token=None) + socket._loop = InterruptingLoop() + + with pytest.raises(SystemExit): + socket.close() + + @pytest.mark.parametrize( + ("generator_method", "async_method"), + [ + ("_streaming_generator", "_send_request_async_streaming"), + ("_streaming_generator_raw", "_send_request_async_streaming_raw"), + ], + ) + def test_socket_client_streaming_cleanup_does_not_swallow_base_exceptions( + self, generator_method, async_method + ): + """Test streaming cleanup does not suppress process-level interrupts.""" + + class FakeAsyncGenerator: + def __anext__(self): + return "next" + + def aclose(self): + return "close" + + class InterruptingLoop: + def run_until_complete(self, awaitable): + if awaitable == "next": + raise StopAsyncIteration + if awaitable == "close": + raise SystemExit("stop") + raise AssertionError(f"unexpected awaitable: {awaitable!r}") + + socket = SocketClient(url="http://test/", timeout=60, token=None) + setattr(socket, async_method, lambda *args, **kwargs: FakeAsyncGenerator()) + generator = getattr(socket, generator_method)( + "agent", "default", {}, InterruptingLoop() + ) + + with pytest.raises(SystemExit): + next(generator) + + @pytest.mark.asyncio + async def test_socket_client_reader_does_not_swallow_base_exceptions(self): + """Test reader error fanout does not suppress process-level interrupts.""" + + class FailingSocket: + def __aiter__(self): + return self + + async def __anext__(self): + raise ValueError("reader failed") + + class InterruptingQueue: + async def put(self, message): + raise SystemExit("stop") + + socket = SocketClient(url="http://test/", timeout=60, token=None) + socket._socket = FailingSocket() + socket._pending = {"req-1": InterruptingQueue()} + + with pytest.raises(SystemExit): + await socket._reader() + class TestBulkClient: """Test bulk operations client""" diff --git a/trustgraph-base/trustgraph/api/socket_client.py b/trustgraph-base/trustgraph/api/socket_client.py index 9874c8af..6eeb95ff 100644 --- a/trustgraph-base/trustgraph/api/socket_client.py +++ b/trustgraph-base/trustgraph/api/socket_client.py @@ -11,6 +11,7 @@ multiplexes requests by ID. import json import asyncio import websockets +from websockets.exceptions import ConnectionClosed from typing import Optional, Dict, Any, Iterator, Union, List from threading import Lock @@ -191,13 +192,13 @@ class SocketClient: if request_id and request_id in self._pending: await self._pending[request_id].put(response) - except websockets.exceptions.ConnectionClosed: + except ConnectionClosed: pass except Exception as e: for queue in self._pending.values(): try: await queue.put({"error": str(e)}) - except: + except Exception: pass finally: self._connected = False @@ -250,7 +251,7 @@ class SocketClient: finally: try: loop.run_until_complete(async_gen.aclose()) - except: + except Exception: pass def _streaming_generator_raw( @@ -273,7 +274,7 @@ class SocketClient: finally: try: loop.run_until_complete(async_gen.aclose()) - except: + except Exception: pass async def _send_request_async_streaming_raw( @@ -542,7 +543,7 @@ class SocketClient: if self._loop and not self._loop.is_closed(): try: self._loop.run_until_complete(self._close_async()) - except: + except Exception: pass async def _close_async(self):