Merge pull request #955 from jmolz/fix-socket-client-bare-excepts

fix: avoid socket client bare excepts
This commit is contained in:
Jacob Molz 2026-05-27 08:16:33 -04:00 committed by GitHub
parent 00dd7a4e14
commit eb24d0c60e
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
2 changed files with 83 additions and 5 deletions

View file

@ -8,6 +8,7 @@ import pytest
from unittest.mock import Mock, patch, MagicMock, call
import json
from trustgraph.api.socket_client import SocketClient
from trustgraph.api import (
Api,
Triple,
@ -222,6 +223,82 @@ class TestSocketClient:
for method in expected_methods:
assert hasattr(flow_instance, method), f"Missing method: {method}"
def test_socket_client_close_does_not_swallow_base_exceptions(self):
"""Test close cleanup does not suppress process-level interrupts."""
class InterruptingLoop:
def is_closed(self):
return False
def run_until_complete(self, awaitable):
if hasattr(awaitable, "close"):
awaitable.close()
raise SystemExit("stop")
socket = SocketClient(url="http://test/", timeout=60, token=None)
socket._loop = InterruptingLoop()
with pytest.raises(SystemExit):
socket.close()
@pytest.mark.parametrize(
("generator_method", "async_method"),
[
("_streaming_generator", "_send_request_async_streaming"),
("_streaming_generator_raw", "_send_request_async_streaming_raw"),
],
)
def test_socket_client_streaming_cleanup_does_not_swallow_base_exceptions(
self, generator_method, async_method
):
"""Test streaming cleanup does not suppress process-level interrupts."""
class FakeAsyncGenerator:
def __anext__(self):
return "next"
def aclose(self):
return "close"
class InterruptingLoop:
def run_until_complete(self, awaitable):
if awaitable == "next":
raise StopAsyncIteration
if awaitable == "close":
raise SystemExit("stop")
raise AssertionError(f"unexpected awaitable: {awaitable!r}")
socket = SocketClient(url="http://test/", timeout=60, token=None)
setattr(socket, async_method, lambda *args, **kwargs: FakeAsyncGenerator())
generator = getattr(socket, generator_method)(
"agent", "default", {}, InterruptingLoop()
)
with pytest.raises(SystemExit):
next(generator)
@pytest.mark.asyncio
async def test_socket_client_reader_does_not_swallow_base_exceptions(self):
"""Test reader error fanout does not suppress process-level interrupts."""
class FailingSocket:
def __aiter__(self):
return self
async def __anext__(self):
raise ValueError("reader failed")
class InterruptingQueue:
async def put(self, message):
raise SystemExit("stop")
socket = SocketClient(url="http://test/", timeout=60, token=None)
socket._socket = FailingSocket()
socket._pending = {"req-1": InterruptingQueue()}
with pytest.raises(SystemExit):
await socket._reader()
class TestBulkClient:
"""Test bulk operations client"""

View file

@ -11,6 +11,7 @@ multiplexes requests by ID.
import json
import asyncio
import websockets
from websockets.exceptions import ConnectionClosed
from typing import Optional, Dict, Any, Iterator, Union, List
from threading import Lock
@ -191,13 +192,13 @@ class SocketClient:
if request_id and request_id in self._pending:
await self._pending[request_id].put(response)
except websockets.exceptions.ConnectionClosed:
except ConnectionClosed:
pass
except Exception as e:
for queue in self._pending.values():
try:
await queue.put({"error": str(e)})
except:
except Exception:
pass
finally:
self._connected = False
@ -250,7 +251,7 @@ class SocketClient:
finally:
try:
loop.run_until_complete(async_gen.aclose())
except:
except Exception:
pass
def _streaming_generator_raw(
@ -273,7 +274,7 @@ class SocketClient:
finally:
try:
loop.run_until_complete(async_gen.aclose())
except:
except Exception:
pass
async def _send_request_async_streaming_raw(
@ -542,7 +543,7 @@ class SocketClient:
if self._loop and not self._loop.is_closed():
try:
self._loop.run_until_complete(self._close_async())
except:
except Exception:
pass
async def _close_async(self):