mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-30 09:45:13 +02:00
Merge pull request #955 from jmolz/fix-socket-client-bare-excepts
fix: avoid socket client bare excepts
This commit is contained in:
parent
00dd7a4e14
commit
eb24d0c60e
2 changed files with 83 additions and 5 deletions
|
|
@ -8,6 +8,7 @@ import pytest
|
||||||
from unittest.mock import Mock, patch, MagicMock, call
|
from unittest.mock import Mock, patch, MagicMock, call
|
||||||
import json
|
import json
|
||||||
|
|
||||||
|
from trustgraph.api.socket_client import SocketClient
|
||||||
from trustgraph.api import (
|
from trustgraph.api import (
|
||||||
Api,
|
Api,
|
||||||
Triple,
|
Triple,
|
||||||
|
|
@ -222,6 +223,82 @@ class TestSocketClient:
|
||||||
for method in expected_methods:
|
for method in expected_methods:
|
||||||
assert hasattr(flow_instance, method), f"Missing method: {method}"
|
assert hasattr(flow_instance, method), f"Missing method: {method}"
|
||||||
|
|
||||||
|
def test_socket_client_close_does_not_swallow_base_exceptions(self):
|
||||||
|
"""Test close cleanup does not suppress process-level interrupts."""
|
||||||
|
|
||||||
|
class InterruptingLoop:
|
||||||
|
def is_closed(self):
|
||||||
|
return False
|
||||||
|
|
||||||
|
def run_until_complete(self, awaitable):
|
||||||
|
if hasattr(awaitable, "close"):
|
||||||
|
awaitable.close()
|
||||||
|
raise SystemExit("stop")
|
||||||
|
|
||||||
|
socket = SocketClient(url="http://test/", timeout=60, token=None)
|
||||||
|
socket._loop = InterruptingLoop()
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
socket.close()
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
("generator_method", "async_method"),
|
||||||
|
[
|
||||||
|
("_streaming_generator", "_send_request_async_streaming"),
|
||||||
|
("_streaming_generator_raw", "_send_request_async_streaming_raw"),
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_socket_client_streaming_cleanup_does_not_swallow_base_exceptions(
|
||||||
|
self, generator_method, async_method
|
||||||
|
):
|
||||||
|
"""Test streaming cleanup does not suppress process-level interrupts."""
|
||||||
|
|
||||||
|
class FakeAsyncGenerator:
|
||||||
|
def __anext__(self):
|
||||||
|
return "next"
|
||||||
|
|
||||||
|
def aclose(self):
|
||||||
|
return "close"
|
||||||
|
|
||||||
|
class InterruptingLoop:
|
||||||
|
def run_until_complete(self, awaitable):
|
||||||
|
if awaitable == "next":
|
||||||
|
raise StopAsyncIteration
|
||||||
|
if awaitable == "close":
|
||||||
|
raise SystemExit("stop")
|
||||||
|
raise AssertionError(f"unexpected awaitable: {awaitable!r}")
|
||||||
|
|
||||||
|
socket = SocketClient(url="http://test/", timeout=60, token=None)
|
||||||
|
setattr(socket, async_method, lambda *args, **kwargs: FakeAsyncGenerator())
|
||||||
|
generator = getattr(socket, generator_method)(
|
||||||
|
"agent", "default", {}, InterruptingLoop()
|
||||||
|
)
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
next(generator)
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_socket_client_reader_does_not_swallow_base_exceptions(self):
|
||||||
|
"""Test reader error fanout does not suppress process-level interrupts."""
|
||||||
|
|
||||||
|
class FailingSocket:
|
||||||
|
def __aiter__(self):
|
||||||
|
return self
|
||||||
|
|
||||||
|
async def __anext__(self):
|
||||||
|
raise ValueError("reader failed")
|
||||||
|
|
||||||
|
class InterruptingQueue:
|
||||||
|
async def put(self, message):
|
||||||
|
raise SystemExit("stop")
|
||||||
|
|
||||||
|
socket = SocketClient(url="http://test/", timeout=60, token=None)
|
||||||
|
socket._socket = FailingSocket()
|
||||||
|
socket._pending = {"req-1": InterruptingQueue()}
|
||||||
|
|
||||||
|
with pytest.raises(SystemExit):
|
||||||
|
await socket._reader()
|
||||||
|
|
||||||
|
|
||||||
class TestBulkClient:
|
class TestBulkClient:
|
||||||
"""Test bulk operations client"""
|
"""Test bulk operations client"""
|
||||||
|
|
|
||||||
|
|
@ -11,6 +11,7 @@ multiplexes requests by ID.
|
||||||
import json
|
import json
|
||||||
import asyncio
|
import asyncio
|
||||||
import websockets
|
import websockets
|
||||||
|
from websockets.exceptions import ConnectionClosed
|
||||||
from typing import Optional, Dict, Any, Iterator, Union, List
|
from typing import Optional, Dict, Any, Iterator, Union, List
|
||||||
from threading import Lock
|
from threading import Lock
|
||||||
|
|
||||||
|
|
@ -191,13 +192,13 @@ class SocketClient:
|
||||||
if request_id and request_id in self._pending:
|
if request_id and request_id in self._pending:
|
||||||
await self._pending[request_id].put(response)
|
await self._pending[request_id].put(response)
|
||||||
|
|
||||||
except websockets.exceptions.ConnectionClosed:
|
except ConnectionClosed:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
for queue in self._pending.values():
|
for queue in self._pending.values():
|
||||||
try:
|
try:
|
||||||
await queue.put({"error": str(e)})
|
await queue.put({"error": str(e)})
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
finally:
|
finally:
|
||||||
self._connected = False
|
self._connected = False
|
||||||
|
|
@ -250,7 +251,7 @@ class SocketClient:
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(async_gen.aclose())
|
loop.run_until_complete(async_gen.aclose())
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _streaming_generator_raw(
|
def _streaming_generator_raw(
|
||||||
|
|
@ -273,7 +274,7 @@ class SocketClient:
|
||||||
finally:
|
finally:
|
||||||
try:
|
try:
|
||||||
loop.run_until_complete(async_gen.aclose())
|
loop.run_until_complete(async_gen.aclose())
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _send_request_async_streaming_raw(
|
async def _send_request_async_streaming_raw(
|
||||||
|
|
@ -542,7 +543,7 @@ class SocketClient:
|
||||||
if self._loop and not self._loop.is_closed():
|
if self._loop and not self._loop.is_closed():
|
||||||
try:
|
try:
|
||||||
self._loop.run_until_complete(self._close_async())
|
self._loop.run_until_complete(self._close_async())
|
||||||
except:
|
except Exception:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def _close_async(self):
|
async def _close_async(self):
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue