feat: standardize LLM rate-limiting and exception handling (#835)

- HTTP 429 translates to TooManyRequests (retryable)
- HTTP 503 translates to LlmError
This commit is contained in:
Het Patel 2026-04-21 20:45:11 +05:30 committed by GitHub
parent e7efb673ef
commit 0ef49ab6ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 170 additions and 38 deletions

View file

@@ -10,7 +10,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase
from trustgraph.exceptions import TooManyRequests
from trustgraph.exceptions import TooManyRequests, LlmError
class TestAzureServerless429(IsolatedAsyncioTestCase):
@@ -77,6 +77,24 @@ class TestOpenAIRateLimit(IsolatedAsyncioTestCase):
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_503_raises_llm_error(self, _llm, _async, mock_cls):
from openai import InternalServerError
from trustgraph.model.text_completion.openai.llm import Processor
mock_client = MagicMock()
mock_cls.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_client.chat.completions.create.side_effect = InternalServerError(
"service unavailable", response=MagicMock(), body=None
)
with pytest.raises(LlmError):
await proc.generate_content("sys", "prompt")
class TestClaudeRateLimit(IsolatedAsyncioTestCase):
"""Claude/Anthropic: anthropic.RateLimitError → TooManyRequests"""
@@ -103,32 +121,120 @@ class TestClaudeRateLimit(IsolatedAsyncioTestCase):
await proc.generate_content("sys", "prompt")
class TestMistralRateLimit(IsolatedAsyncioTestCase):
"""Mistral: models.SDKError (429/503) → TooManyRequests/LlmError"""
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.model.text_completion.mistral.llm.models')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_429_raises_too_many_requests(self, _llm, _async, mock_models, mock_cls):
from trustgraph.model.text_completion.mistral.llm import Processor
mock_client = MagicMock()
mock_cls.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
# Define a mock exception class
mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 429})
mock_client.chat.complete.side_effect = mock_models.SDKError()
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.model.text_completion.mistral.llm.models')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_503_raises_llm_error(self, _llm, _async, mock_models, mock_cls):
from trustgraph.model.text_completion.mistral.llm import Processor
mock_client = MagicMock()
mock_cls.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 503})
mock_client.chat.complete.side_effect = mock_models.SDKError()
with pytest.raises(LlmError):
await proc.generate_content("sys", "prompt")
class TestCohereRateLimit(IsolatedAsyncioTestCase):
"""Cohere: cohere.TooManyRequestsError → TooManyRequests"""
"""Cohere: cohere.errors (429/503) → TooManyRequests/LlmError"""
@patch('trustgraph.model.text_completion.cohere.llm.cohere')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere):
from trustgraph.model.text_completion.cohere.llm import Processor
import trustgraph.model.text_completion.cohere.llm as cohere_llm
mock_client = MagicMock()
mock_cohere.Client.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
ErrorCls = type("TooManyRequestsError", (Exception,), {})
with patch.object(cohere_llm, 'TooManyRequestsError', ErrorCls):
mock_client.chat.side_effect = ErrorCls()
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
mock_cohere.TooManyRequestsError = type(
"TooManyRequestsError", (Exception,), {}
)
mock_client.chat.side_effect = mock_cohere.TooManyRequestsError(
"rate limited"
@patch('trustgraph.model.text_completion.cohere.llm.cohere')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_503_raises_llm_error(self, _llm, _async, mock_cohere):
from trustgraph.model.text_completion.cohere.llm import Processor
import trustgraph.model.text_completion.cohere.llm as cohere_llm
mock_client = MagicMock()
mock_cohere.Client.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
ErrorCls = type("ServiceUnavailableError", (Exception,), {})
with patch.object(cohere_llm, 'ServiceUnavailableError', ErrorCls):
mock_client.chat.side_effect = ErrorCls()
with pytest.raises(LlmError):
await proc.generate_content("sys", "prompt")
class TestVllmRateLimit(IsolatedAsyncioTestCase):
"""vLLM: HTTP 429/503 → TooManyRequests/LlmError"""
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_429_raises_too_many_requests(self, _llm, _async, mock_session):
from trustgraph.model.text_completion.vllm.llm import Processor
proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t")
mock_resp = AsyncMock()
mock_resp.status = 429
mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_503_raises_llm_error(self, _llm, _async, mock_session):
from trustgraph.model.text_completion.vllm.llm import Processor
proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t")
mock_resp = AsyncMock()
mock_resp.status = 503
mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp
with pytest.raises(LlmError):
await proc.generate_content("sys", "prompt")
class TestClientSideRateLimitTranslation:
"""Client base class: error type 'too-many-requests' → TooManyRequests"""

View file

@@ -5,6 +5,7 @@ Input is prompt, output is response.
"""
import cohere
from cohere.errors import TooManyRequestsError, ServiceUnavailableError
from prometheus_client import Histogram
import os
import logging
@@ -12,7 +13,7 @@ import logging
# Module logger
logger = logging.getLogger(__name__)
from .... exceptions import TooManyRequests
from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion"
@@ -84,13 +85,14 @@ class Processor(LlmService):
return resp
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except cohere.TooManyRequestsError:
except TooManyRequestsError:
# Leave rate limit retries to the base handler
raise TooManyRequests()
except ServiceUnavailableError:
# Treat 503 as a retryable LlmError
raise LlmError()
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
@@ -152,10 +154,14 @@ class Processor(LlmService):
logger.debug("Streaming complete")
except cohere.TooManyRequestsError:
except TooManyRequestsError:
logger.warning("Rate limit exceeded during streaming")
raise TooManyRequests()
except ServiceUnavailableError:
logger.warning("Service unavailable during streaming")
raise LlmError()
except Exception as e:
logger.error(f"Cohere streaming exception ({type(e).__name__}): {e}", exc_info=True)
raise e

View file

@@ -4,14 +4,14 @@ Simple LLM service, performs text prompt completion using Mistral.
Input is prompt, output is response.
"""
from mistralai import Mistral
from mistralai import Mistral, models
import os
import logging
# Module logger
logger = logging.getLogger(__name__)
from .... exceptions import TooManyRequests
from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion"
@@ -100,18 +100,14 @@ class Processor(LlmService):
return resp
# FIXME: Wrong exception. The MistralAI library has retry logic
# so retry-able errors are retried transparently. It means we
# don't get rate limit events.
# We could choose to turn off retry and handle all that here
# or subclass BackoffStrategy to keep the retry logic, but
# get the events out.
# except Mistral.RateLimitError:
# # Leave rate limit retries to the base handler
# raise TooManyRequests()
except models.SDKError as e:
if e.status_code == 429:
# Leave rate limit retries to the base handler
raise TooManyRequests()
elif e.status_code == 503:
# Treat 503 as a retryable LlmError
raise LlmError()
raise e
except Exception as e:
@@ -185,8 +181,13 @@ class Processor(LlmService):
logger.debug("Streaming complete")
except Exception as e:
logger.error(f"Mistral streaming exception ({type(e).__name__}): {e}", exc_info=True)
except models.SDKError as e:
if e.status_code == 429:
logger.warning("Hit rate limit during streaming")
raise TooManyRequests()
elif e.status_code == 503:
logger.warning("Hit internal server error during streaming")
raise LlmError()
raise e
@staticmethod

View file

@@ -4,11 +4,11 @@ Simple LLM service, performs text prompt completion using OpenAI.
Input is prompt, output is response.
"""
from openai import OpenAI, RateLimitError
from openai import OpenAI, RateLimitError, InternalServerError
import os
import logging
from .... exceptions import TooManyRequests
from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk
# Module logger
@@ -104,13 +104,14 @@ class Processor(LlmService):
return resp
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except RateLimitError:
# Leave rate limit retries to the base handler
raise TooManyRequests()
except InternalServerError:
# Treat 503 as a retryable LlmError
raise LlmError()
except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable
@@ -191,6 +192,10 @@ class Processor(LlmService):
logger.warning("Hit rate limit during streaming")
raise TooManyRequests()
except InternalServerError:
logger.warning("Hit internal server error during streaming")
raise LlmError()
except Exception as e:
logger.error(f"OpenAI streaming exception ({type(e).__name__}): {e}", exc_info=True)
raise e

View file

@@ -11,7 +11,7 @@ import logging
# Module logger
logger = logging.getLogger(__name__)
from .... exceptions import TooManyRequests
from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion"
@@ -83,6 +83,10 @@ class Processor(LlmService):
json=request,
) as response:
if response.status == 429:
raise TooManyRequests()
if response.status == 503:
raise LlmError()
if response.status != 200:
raise RuntimeError("Bad status: " + str(response.status))
@@ -104,7 +108,13 @@ class Processor(LlmService):
return resp
# FIXME: Assuming vLLM won't produce rate limits?
except TooManyRequests:
# Leave rate limit retries to the base handler
raise TooManyRequests()
except LlmError:
# Treat 503 as a retryable LlmError
raise LlmError()
except Exception as e:
@@ -150,6 +160,10 @@ class Processor(LlmService):
json=request,
) as response:
if response.status == 429:
raise TooManyRequests()
if response.status == 503:
raise LlmError()
if response.status != 200:
raise RuntimeError("Bad status: " + str(response.status))