mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
feat: standardize LLM rate-limiting and exception handling (#835)
- HTTP 429 translates to TooManyRequests (retryable) - HTTP 503 translates to LlmError
This commit is contained in:
parent
e7efb673ef
commit
0ef49ab6ae
5 changed files with 170 additions and 38 deletions
|
|
@ -10,7 +10,7 @@ import pytest
|
|||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
from unittest import IsolatedAsyncioTestCase
|
||||
|
||||
from trustgraph.exceptions import TooManyRequests
|
||||
from trustgraph.exceptions import TooManyRequests, LlmError
|
||||
|
||||
|
||||
class TestAzureServerless429(IsolatedAsyncioTestCase):
|
||||
|
|
@ -77,6 +77,24 @@ class TestOpenAIRateLimit(IsolatedAsyncioTestCase):
|
|||
with pytest.raises(TooManyRequests):
|
||||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
|
||||
async def test_503_raises_llm_error(self, _llm, _async, mock_cls):
|
||||
from openai import InternalServerError
|
||||
from trustgraph.model.text_completion.openai.llm import Processor
|
||||
mock_client = MagicMock()
|
||||
mock_cls.return_value = mock_client
|
||||
proc = Processor(
|
||||
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
|
||||
)
|
||||
mock_client.chat.completions.create.side_effect = InternalServerError(
|
||||
"service unavailable", response=MagicMock(), body=None
|
||||
)
|
||||
|
||||
with pytest.raises(LlmError):
|
||||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
|
||||
class TestClaudeRateLimit(IsolatedAsyncioTestCase):
|
||||
"""Claude/Anthropic: anthropic.RateLimitError → TooManyRequests"""
|
||||
|
|
@ -103,32 +121,120 @@ class TestClaudeRateLimit(IsolatedAsyncioTestCase):
|
|||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
|
||||
class TestMistralRateLimit(IsolatedAsyncioTestCase):
|
||||
"""Mistral: models.SDKError (429/503) → TooManyRequests/LlmError"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.models')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
|
||||
async def test_429_raises_too_many_requests(self, _llm, _async, mock_models, mock_cls):
|
||||
from trustgraph.model.text_completion.mistral.llm import Processor
|
||||
mock_client = MagicMock()
|
||||
mock_cls.return_value = mock_client
|
||||
proc = Processor(
|
||||
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
|
||||
)
|
||||
|
||||
# Define a mock exception class
|
||||
mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 429})
|
||||
mock_client.chat.complete.side_effect = mock_models.SDKError()
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
|
||||
@patch('trustgraph.model.text_completion.mistral.llm.models')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
|
||||
async def test_503_raises_llm_error(self, _llm, _async, mock_models, mock_cls):
|
||||
from trustgraph.model.text_completion.mistral.llm import Processor
|
||||
mock_client = MagicMock()
|
||||
mock_cls.return_value = mock_client
|
||||
proc = Processor(
|
||||
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
|
||||
)
|
||||
|
||||
mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 503})
|
||||
mock_client.chat.complete.side_effect = mock_models.SDKError()
|
||||
|
||||
with pytest.raises(LlmError):
|
||||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
|
||||
class TestCohereRateLimit(IsolatedAsyncioTestCase):
|
||||
"""Cohere: cohere.TooManyRequestsError → TooManyRequests"""
|
||||
"""Cohere: cohere.errors (429/503) → TooManyRequests/LlmError"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
|
||||
async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere):
|
||||
from trustgraph.model.text_completion.cohere.llm import Processor
|
||||
|
||||
import trustgraph.model.text_completion.cohere.llm as cohere_llm
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_cohere.Client.return_value = mock_client
|
||||
|
||||
proc = Processor(
|
||||
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
|
||||
)
|
||||
|
||||
ErrorCls = type("TooManyRequestsError", (Exception,), {})
|
||||
with patch.object(cohere_llm, 'TooManyRequestsError', ErrorCls):
|
||||
mock_client.chat.side_effect = ErrorCls()
|
||||
with pytest.raises(TooManyRequests):
|
||||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
mock_cohere.TooManyRequestsError = type(
|
||||
"TooManyRequestsError", (Exception,), {}
|
||||
)
|
||||
mock_client.chat.side_effect = mock_cohere.TooManyRequestsError(
|
||||
"rate limited"
|
||||
@patch('trustgraph.model.text_completion.cohere.llm.cohere')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
|
||||
async def test_503_raises_llm_error(self, _llm, _async, mock_cohere):
|
||||
from trustgraph.model.text_completion.cohere.llm import Processor
|
||||
import trustgraph.model.text_completion.cohere.llm as cohere_llm
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_cohere.Client.return_value = mock_client
|
||||
proc = Processor(
|
||||
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
|
||||
)
|
||||
|
||||
ErrorCls = type("ServiceUnavailableError", (Exception,), {})
|
||||
with patch.object(cohere_llm, 'ServiceUnavailableError', ErrorCls):
|
||||
mock_client.chat.side_effect = ErrorCls()
|
||||
with pytest.raises(LlmError):
|
||||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
|
||||
class TestVllmRateLimit(IsolatedAsyncioTestCase):
|
||||
"""vLLM: HTTP 429/503 → TooManyRequests/LlmError"""
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
|
||||
async def test_429_raises_too_many_requests(self, _llm, _async, mock_session):
|
||||
from trustgraph.model.text_completion.vllm.llm import Processor
|
||||
proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t")
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 429
|
||||
mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp
|
||||
|
||||
with pytest.raises(TooManyRequests):
|
||||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
|
||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
|
||||
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
|
||||
async def test_503_raises_llm_error(self, _llm, _async, mock_session):
|
||||
from trustgraph.model.text_completion.vllm.llm import Processor
|
||||
proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t")
|
||||
|
||||
mock_resp = AsyncMock()
|
||||
mock_resp.status = 503
|
||||
mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp
|
||||
|
||||
with pytest.raises(LlmError):
|
||||
await proc.generate_content("sys", "prompt")
|
||||
|
||||
|
||||
class TestClientSideRateLimitTranslation:
|
||||
"""Client base class: error type 'too-many-requests' → TooManyRequests"""
|
||||
|
|
|
|||
|
|
@ -5,6 +5,7 @@ Input is prompt, output is response.
|
|||
"""
|
||||
|
||||
import cohere
|
||||
from cohere.errors import TooManyRequestsError, ServiceUnavailableError
|
||||
from prometheus_client import Histogram
|
||||
import os
|
||||
import logging
|
||||
|
|
@ -12,7 +13,7 @@ import logging
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .... exceptions import TooManyRequests
|
||||
from .... exceptions import TooManyRequests, LlmError
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
|
||||
default_ident = "text-completion"
|
||||
|
|
@ -84,13 +85,14 @@ class Processor(LlmService):
|
|||
|
||||
return resp
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except cohere.TooManyRequestsError:
|
||||
|
||||
except TooManyRequestsError:
|
||||
# Leave rate limit retries to the base handler
|
||||
raise TooManyRequests()
|
||||
|
||||
except ServiceUnavailableError:
|
||||
# Treat 503 as a retryable LlmError
|
||||
raise LlmError()
|
||||
|
||||
except Exception as e:
|
||||
|
||||
# Apart from rate limits, treat all exceptions as unrecoverable
|
||||
|
|
@ -152,10 +154,14 @@ class Processor(LlmService):
|
|||
|
||||
logger.debug("Streaming complete")
|
||||
|
||||
except cohere.TooManyRequestsError:
|
||||
except TooManyRequestsError:
|
||||
logger.warning("Rate limit exceeded during streaming")
|
||||
raise TooManyRequests()
|
||||
|
||||
except ServiceUnavailableError:
|
||||
logger.warning("Service unavailable during streaming")
|
||||
raise LlmError()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Cohere streaming exception ({type(e).__name__}): {e}", exc_info=True)
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -4,14 +4,14 @@ Simple LLM service, performs text prompt completion using Mistral.
|
|||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
from mistralai import Mistral
|
||||
from mistralai import Mistral, models
|
||||
import os
|
||||
import logging
|
||||
|
||||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .... exceptions import TooManyRequests
|
||||
from .... exceptions import TooManyRequests, LlmError
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
|
||||
default_ident = "text-completion"
|
||||
|
|
@ -100,18 +100,14 @@ class Processor(LlmService):
|
|||
|
||||
return resp
|
||||
|
||||
# FIXME: Wrong exception. The MistralAI library has retry logic
|
||||
# so retry-able errors are retried transparently. It means we
|
||||
# don't get rate limit events.
|
||||
|
||||
# We could choose to turn off retry and handle all that here
|
||||
# or subclass BackoffStrategy to keep the retry logic, but
|
||||
# get the events out.
|
||||
|
||||
# except Mistral.RateLimitError:
|
||||
|
||||
# # Leave rate limit retries to the base handler
|
||||
# raise TooManyRequests()
|
||||
except models.SDKError as e:
|
||||
if e.status_code == 429:
|
||||
# Leave rate limit retries to the base handler
|
||||
raise TooManyRequests()
|
||||
elif e.status_code == 503:
|
||||
# Treat 503 as a retryable LlmError
|
||||
raise LlmError()
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
@ -185,8 +181,13 @@ class Processor(LlmService):
|
|||
|
||||
logger.debug("Streaming complete")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Mistral streaming exception ({type(e).__name__}): {e}", exc_info=True)
|
||||
except models.SDKError as e:
|
||||
if e.status_code == 429:
|
||||
logger.warning("Hit rate limit during streaming")
|
||||
raise TooManyRequests()
|
||||
elif e.status_code == 503:
|
||||
logger.warning("Hit internal server error during streaming")
|
||||
raise LlmError()
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
|
|
|
|||
|
|
@ -4,11 +4,11 @@ Simple LLM service, performs text prompt completion using OpenAI.
|
|||
Input is prompt, output is response.
|
||||
"""
|
||||
|
||||
from openai import OpenAI, RateLimitError
|
||||
from openai import OpenAI, RateLimitError, InternalServerError
|
||||
import os
|
||||
import logging
|
||||
|
||||
from .... exceptions import TooManyRequests
|
||||
from .... exceptions import TooManyRequests, LlmError
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
|
||||
# Module logger
|
||||
|
|
@ -104,13 +104,14 @@ class Processor(LlmService):
|
|||
|
||||
return resp
|
||||
|
||||
# FIXME: Wrong exception, don't know what this LLM throws
|
||||
# for a rate limit
|
||||
except RateLimitError:
|
||||
|
||||
# Leave rate limit retries to the base handler
|
||||
raise TooManyRequests()
|
||||
|
||||
except InternalServerError:
|
||||
# Treat 503 as a retryable LlmError
|
||||
raise LlmError()
|
||||
|
||||
except Exception as e:
|
||||
|
||||
# Apart from rate limits, treat all exceptions as unrecoverable
|
||||
|
|
@ -191,6 +192,10 @@ class Processor(LlmService):
|
|||
logger.warning("Hit rate limit during streaming")
|
||||
raise TooManyRequests()
|
||||
|
||||
except InternalServerError:
|
||||
logger.warning("Hit internal server error during streaming")
|
||||
raise LlmError()
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"OpenAI streaming exception ({type(e).__name__}): {e}", exc_info=True)
|
||||
raise e
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ import logging
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from .... exceptions import TooManyRequests
|
||||
from .... exceptions import TooManyRequests, LlmError
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
|
||||
default_ident = "text-completion"
|
||||
|
|
@ -83,6 +83,10 @@ class Processor(LlmService):
|
|||
json=request,
|
||||
) as response:
|
||||
|
||||
if response.status == 429:
|
||||
raise TooManyRequests()
|
||||
if response.status == 503:
|
||||
raise LlmError()
|
||||
if response.status != 200:
|
||||
raise RuntimeError("Bad status: " + str(response.status))
|
||||
|
||||
|
|
@ -104,7 +108,13 @@ class Processor(LlmService):
|
|||
|
||||
return resp
|
||||
|
||||
# FIXME: Assuming vLLM won't produce rate limits?
|
||||
except TooManyRequests:
|
||||
# Leave rate limit retries to the base handler
|
||||
raise TooManyRequests()
|
||||
|
||||
except LlmError:
|
||||
# Treat 503 as a retryable LlmError
|
||||
raise LlmError()
|
||||
|
||||
except Exception as e:
|
||||
|
||||
|
|
@ -150,6 +160,10 @@ class Processor(LlmService):
|
|||
json=request,
|
||||
) as response:
|
||||
|
||||
if response.status == 429:
|
||||
raise TooManyRequests()
|
||||
if response.status == 503:
|
||||
raise LlmError()
|
||||
if response.status != 200:
|
||||
raise RuntimeError("Bad status: " + str(response.status))
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue