feat: standardize LLM rate-limiting and exception handling (#835)

- HTTP 429 translates to TooManyRequests (retryable)
- HTTP 503 translates to LlmError
This commit is contained in:
Het Patel 2026-04-21 20:45:11 +05:30 committed by GitHub
parent e7efb673ef
commit 0ef49ab6ae
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 170 additions and 38 deletions

View file

@ -10,7 +10,7 @@ import pytest
from unittest.mock import AsyncMock, MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
from unittest import IsolatedAsyncioTestCase from unittest import IsolatedAsyncioTestCase
from trustgraph.exceptions import TooManyRequests from trustgraph.exceptions import TooManyRequests, LlmError
class TestAzureServerless429(IsolatedAsyncioTestCase): class TestAzureServerless429(IsolatedAsyncioTestCase):
@ -77,6 +77,24 @@ class TestOpenAIRateLimit(IsolatedAsyncioTestCase):
with pytest.raises(TooManyRequests): with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt") await proc.generate_content("sys", "prompt")
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_503_raises_llm_error(self, _llm, _async, mock_cls):
from openai import InternalServerError
from trustgraph.model.text_completion.openai.llm import Processor
mock_client = MagicMock()
mock_cls.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_client.chat.completions.create.side_effect = InternalServerError(
"service unavailable", response=MagicMock(), body=None
)
with pytest.raises(LlmError):
await proc.generate_content("sys", "prompt")
class TestClaudeRateLimit(IsolatedAsyncioTestCase): class TestClaudeRateLimit(IsolatedAsyncioTestCase):
"""Claude/Anthropic: anthropic.RateLimitError → TooManyRequests""" """Claude/Anthropic: anthropic.RateLimitError → TooManyRequests"""
@ -103,32 +121,120 @@ class TestClaudeRateLimit(IsolatedAsyncioTestCase):
await proc.generate_content("sys", "prompt") await proc.generate_content("sys", "prompt")
class TestMistralRateLimit(IsolatedAsyncioTestCase):
"""Mistral: models.SDKError (429/503) → TooManyRequests/LlmError"""
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.model.text_completion.mistral.llm.models')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_429_raises_too_many_requests(self, _llm, _async, mock_models, mock_cls):
from trustgraph.model.text_completion.mistral.llm import Processor
mock_client = MagicMock()
mock_cls.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
# Define a mock exception class
mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 429})
mock_client.chat.complete.side_effect = mock_models.SDKError()
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
@patch('trustgraph.model.text_completion.mistral.llm.Mistral')
@patch('trustgraph.model.text_completion.mistral.llm.models')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_503_raises_llm_error(self, _llm, _async, mock_models, mock_cls):
from trustgraph.model.text_completion.mistral.llm import Processor
mock_client = MagicMock()
mock_cls.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
)
mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 503})
mock_client.chat.complete.side_effect = mock_models.SDKError()
with pytest.raises(LlmError):
await proc.generate_content("sys", "prompt")
class TestCohereRateLimit(IsolatedAsyncioTestCase): class TestCohereRateLimit(IsolatedAsyncioTestCase):
"""Cohere: cohere.TooManyRequestsError → TooManyRequests""" """Cohere: cohere.errors (429/503) → TooManyRequests/LlmError"""
@patch('trustgraph.model.text_completion.cohere.llm.cohere') @patch('trustgraph.model.text_completion.cohere.llm.cohere')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None) @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None) @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere): async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere):
from trustgraph.model.text_completion.cohere.llm import Processor from trustgraph.model.text_completion.cohere.llm import Processor
import trustgraph.model.text_completion.cohere.llm as cohere_llm
mock_client = MagicMock() mock_client = MagicMock()
mock_cohere.Client.return_value = mock_client mock_cohere.Client.return_value = mock_client
proc = Processor( proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t", api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
) )
ErrorCls = type("TooManyRequestsError", (Exception,), {})
with patch.object(cohere_llm, 'TooManyRequestsError', ErrorCls):
mock_client.chat.side_effect = ErrorCls()
with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt")
mock_cohere.TooManyRequestsError = type( @patch('trustgraph.model.text_completion.cohere.llm.cohere')
"TooManyRequestsError", (Exception,), {} @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
) @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
mock_client.chat.side_effect = mock_cohere.TooManyRequestsError( async def test_503_raises_llm_error(self, _llm, _async, mock_cohere):
"rate limited" from trustgraph.model.text_completion.cohere.llm import Processor
import trustgraph.model.text_completion.cohere.llm as cohere_llm
mock_client = MagicMock()
mock_cohere.Client.return_value = mock_client
proc = Processor(
api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
) )
ErrorCls = type("ServiceUnavailableError", (Exception,), {})
with patch.object(cohere_llm, 'ServiceUnavailableError', ErrorCls):
mock_client.chat.side_effect = ErrorCls()
with pytest.raises(LlmError):
await proc.generate_content("sys", "prompt")
class TestVllmRateLimit(IsolatedAsyncioTestCase):
"""vLLM: HTTP 429/503 → TooManyRequests/LlmError"""
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_429_raises_too_many_requests(self, _llm, _async, mock_session):
from trustgraph.model.text_completion.vllm.llm import Processor
proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t")
mock_resp = AsyncMock()
mock_resp.status = 429
mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp
with pytest.raises(TooManyRequests): with pytest.raises(TooManyRequests):
await proc.generate_content("sys", "prompt") await proc.generate_content("sys", "prompt")
@patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
@patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
async def test_503_raises_llm_error(self, _llm, _async, mock_session):
from trustgraph.model.text_completion.vllm.llm import Processor
proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t")
mock_resp = AsyncMock()
mock_resp.status = 503
mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp
with pytest.raises(LlmError):
await proc.generate_content("sys", "prompt")
class TestClientSideRateLimitTranslation: class TestClientSideRateLimitTranslation:
"""Client base class: error type 'too-many-requests' → TooManyRequests""" """Client base class: error type 'too-many-requests' → TooManyRequests"""

View file

@ -5,6 +5,7 @@ Input is prompt, output is response.
""" """
import cohere import cohere
from cohere.errors import TooManyRequestsError, ServiceUnavailableError
from prometheus_client import Histogram from prometheus_client import Histogram
import os import os
import logging import logging
@ -12,7 +13,7 @@ import logging
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion" default_ident = "text-completion"
@ -84,13 +85,14 @@ class Processor(LlmService):
return resp return resp
# FIXME: Wrong exception, don't know what this LLM throws except TooManyRequestsError:
# for a rate limit
except cohere.TooManyRequestsError:
# Leave rate limit retries to the base handler # Leave rate limit retries to the base handler
raise TooManyRequests() raise TooManyRequests()
except ServiceUnavailableError:
# Treat 503 as a retryable LlmError
raise LlmError()
except Exception as e: except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable # Apart from rate limits, treat all exceptions as unrecoverable
@ -152,10 +154,14 @@ class Processor(LlmService):
logger.debug("Streaming complete") logger.debug("Streaming complete")
except cohere.TooManyRequestsError: except TooManyRequestsError:
logger.warning("Rate limit exceeded during streaming") logger.warning("Rate limit exceeded during streaming")
raise TooManyRequests() raise TooManyRequests()
except ServiceUnavailableError:
logger.warning("Service unavailable during streaming")
raise LlmError()
except Exception as e: except Exception as e:
logger.error(f"Cohere streaming exception ({type(e).__name__}): {e}", exc_info=True) logger.error(f"Cohere streaming exception ({type(e).__name__}): {e}", exc_info=True)
raise e raise e

View file

@ -4,14 +4,14 @@ Simple LLM service, performs text prompt completion using Mistral.
Input is prompt, output is response. Input is prompt, output is response.
""" """
from mistralai import Mistral from mistralai import Mistral, models
import os import os
import logging import logging
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion" default_ident = "text-completion"
@ -100,18 +100,14 @@ class Processor(LlmService):
return resp return resp
# FIXME: Wrong exception. The MistralAI library has retry logic except models.SDKError as e:
# so retry-able errors are retried transparently. It means we if e.status_code == 429:
# don't get rate limit events. # Leave rate limit retries to the base handler
raise TooManyRequests()
# We could choose to turn off retry and handle all that here elif e.status_code == 503:
# or subclass BackoffStrategy to keep the retry logic, but # Treat 503 as a retryable LlmError
# get the events out. raise LlmError()
raise e
# except Mistral.RateLimitError:
# # Leave rate limit retries to the base handler
# raise TooManyRequests()
except Exception as e: except Exception as e:
@ -185,8 +181,13 @@ class Processor(LlmService):
logger.debug("Streaming complete") logger.debug("Streaming complete")
except Exception as e: except models.SDKError as e:
logger.error(f"Mistral streaming exception ({type(e).__name__}): {e}", exc_info=True) if e.status_code == 429:
logger.warning("Hit rate limit during streaming")
raise TooManyRequests()
elif e.status_code == 503:
logger.warning("Hit internal server error during streaming")
raise LlmError()
raise e raise e
@staticmethod @staticmethod

View file

@ -4,11 +4,11 @@ Simple LLM service, performs text prompt completion using OpenAI.
Input is prompt, output is response. Input is prompt, output is response.
""" """
from openai import OpenAI, RateLimitError from openai import OpenAI, RateLimitError, InternalServerError
import os import os
import logging import logging
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk from .... base import LlmService, LlmResult, LlmChunk
# Module logger # Module logger
@ -104,13 +104,14 @@ class Processor(LlmService):
return resp return resp
# FIXME: Wrong exception, don't know what this LLM throws
# for a rate limit
except RateLimitError: except RateLimitError:
# Leave rate limit retries to the base handler # Leave rate limit retries to the base handler
raise TooManyRequests() raise TooManyRequests()
except InternalServerError:
# Treat 503 as a retryable LlmError
raise LlmError()
except Exception as e: except Exception as e:
# Apart from rate limits, treat all exceptions as unrecoverable # Apart from rate limits, treat all exceptions as unrecoverable
@ -191,6 +192,10 @@ class Processor(LlmService):
logger.warning("Hit rate limit during streaming") logger.warning("Hit rate limit during streaming")
raise TooManyRequests() raise TooManyRequests()
except InternalServerError:
logger.warning("Hit internal server error during streaming")
raise LlmError()
except Exception as e: except Exception as e:
logger.error(f"OpenAI streaming exception ({type(e).__name__}): {e}", exc_info=True) logger.error(f"OpenAI streaming exception ({type(e).__name__}): {e}", exc_info=True)
raise e raise e

View file

@ -11,7 +11,7 @@ import logging
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
from .... exceptions import TooManyRequests from .... exceptions import TooManyRequests, LlmError
from .... base import LlmService, LlmResult, LlmChunk from .... base import LlmService, LlmResult, LlmChunk
default_ident = "text-completion" default_ident = "text-completion"
@ -83,6 +83,10 @@ class Processor(LlmService):
json=request, json=request,
) as response: ) as response:
if response.status == 429:
raise TooManyRequests()
if response.status == 503:
raise LlmError()
if response.status != 200: if response.status != 200:
raise RuntimeError("Bad status: " + str(response.status)) raise RuntimeError("Bad status: " + str(response.status))
@ -104,7 +108,13 @@ class Processor(LlmService):
return resp return resp
# FIXME: Assuming vLLM won't produce rate limits? except TooManyRequests:
# Leave rate limit retries to the base handler
raise TooManyRequests()
except LlmError:
# Treat 503 as a retryable LlmError
raise LlmError()
except Exception as e: except Exception as e:
@ -150,6 +160,10 @@ class Processor(LlmService):
json=request, json=request,
) as response: ) as response:
if response.status == 429:
raise TooManyRequests()
if response.status == 503:
raise LlmError()
if response.status != 200: if response.status != 200:
raise RuntimeError("Bad status: " + str(response.status)) raise RuntimeError("Bad status: " + str(response.status))