From 0ef49ab6ae4df8aed5c50aca172f208b7361a0f2 Mon Sep 17 00:00:00 2001
From: Het Patel <102606191+CuriousHet@users.noreply.github.com>
Date: Tue, 21 Apr 2026 20:45:11 +0530
Subject: [PATCH] feat: standardize LLM rate-limiting and exception handling
 (#835)

- HTTP 429 translates to TooManyRequests (retryable)
- HTTP 503 translates to LlmError
---
 .../test_rate_limit_contract.py               | 124 ++++++++++++++++--
 .../model/text_completion/cohere/llm.py       |  18 ++-
 .../model/text_completion/mistral/llm.py      |  34 +++--
 .../model/text_completion/openai/llm.py       |  15 ++-
 .../model/text_completion/vllm/llm.py         |  18 ++-
 5 files changed, 173 insertions(+), 36 deletions(-)

diff --git a/tests/unit/test_text_completion/test_rate_limit_contract.py b/tests/unit/test_text_completion/test_rate_limit_contract.py
index c9df217b..9cf00b7c 100644
--- a/tests/unit/test_text_completion/test_rate_limit_contract.py
+++ b/tests/unit/test_text_completion/test_rate_limit_contract.py
@@ -10,7 +10,7 @@ import pytest
 from unittest.mock import AsyncMock, MagicMock, patch
 from unittest import IsolatedAsyncioTestCase
 
-from trustgraph.exceptions import TooManyRequests
+from trustgraph.exceptions import TooManyRequests, LlmError
 
 
 class TestAzureServerless429(IsolatedAsyncioTestCase):
@@ -77,6 +77,24 @@ class TestOpenAIRateLimit(IsolatedAsyncioTestCase):
         with pytest.raises(TooManyRequests):
             await proc.generate_content("sys", "prompt")
 
+    @patch('trustgraph.model.text_completion.openai.llm.OpenAI')
+    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
+    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
+    async def test_503_raises_llm_error(self, _llm, _async, mock_cls):
+        from openai import InternalServerError
+        from trustgraph.model.text_completion.openai.llm import Processor
+        mock_client = MagicMock()
+        mock_cls.return_value = mock_client
+        proc = Processor(
+            api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
+        )
+        mock_client.chat.completions.create.side_effect = InternalServerError(
+            "service unavailable", response=MagicMock(), body=None
+        )
+
+        with pytest.raises(LlmError):
+            await proc.generate_content("sys", "prompt")
+
 
 class TestClaudeRateLimit(IsolatedAsyncioTestCase):
     """Claude/Anthropic: anthropic.RateLimitError → TooManyRequests"""
@@ -103,32 +121,120 @@ class TestClaudeRateLimit(IsolatedAsyncioTestCase):
             await proc.generate_content("sys", "prompt")
 
 
+class TestMistralRateLimit(IsolatedAsyncioTestCase):
+    """Mistral: models.SDKError (429/503) → TooManyRequests/LlmError"""
+
+    @patch('trustgraph.model.text_completion.mistral.llm.Mistral')
+    @patch('trustgraph.model.text_completion.mistral.llm.models')
+    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
+    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
+    async def test_429_raises_too_many_requests(self, _llm, _async, mock_models, mock_cls):
+        from trustgraph.model.text_completion.mistral.llm import Processor
+        mock_client = MagicMock()
+        mock_cls.return_value = mock_client
+        proc = Processor(
+            api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
+        )
+
+        # Define a mock exception class
+        mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 429})
+        mock_client.chat.complete.side_effect = mock_models.SDKError()
+
+        with pytest.raises(TooManyRequests):
+            await proc.generate_content("sys", "prompt")
+
+    @patch('trustgraph.model.text_completion.mistral.llm.Mistral')
+    @patch('trustgraph.model.text_completion.mistral.llm.models')
+    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
+    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
+    async def test_503_raises_llm_error(self, _llm, _async, mock_models, mock_cls):
+        from trustgraph.model.text_completion.mistral.llm import Processor
+        mock_client = MagicMock()
+        mock_cls.return_value = mock_client
+        proc = Processor(
+            api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
+        )
+
+        mock_models.SDKError = type("SDKError", (Exception,), {"status_code": 503})
+        mock_client.chat.complete.side_effect = mock_models.SDKError()
+
+        with pytest.raises(LlmError):
+            await proc.generate_content("sys", "prompt")
+
+
 class TestCohereRateLimit(IsolatedAsyncioTestCase):
-    """Cohere: cohere.TooManyRequestsError → TooManyRequests"""
+    """Cohere: cohere.errors (429/503) → TooManyRequests/LlmError"""
 
     @patch('trustgraph.model.text_completion.cohere.llm.cohere')
     @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
     @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
     async def test_rate_limit_error_raises_too_many_requests(self, _llm, _async, mock_cohere):
         from trustgraph.model.text_completion.cohere.llm import Processor
-
+        import trustgraph.model.text_completion.cohere.llm as cohere_llm
+
         mock_client = MagicMock()
         mock_cohere.Client.return_value = mock_client
-
         proc = Processor(
             api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
         )
+
+        ErrorCls = type("TooManyRequestsError", (Exception,), {})
+        with patch.object(cohere_llm, 'TooManyRequestsError', ErrorCls):
+            mock_client.chat.side_effect = ErrorCls()
+            with pytest.raises(TooManyRequests):
+                await proc.generate_content("sys", "prompt")
 
-        mock_cohere.TooManyRequestsError = type(
-            "TooManyRequestsError", (Exception,), {}
-        )
-        mock_client.chat.side_effect = mock_cohere.TooManyRequestsError(
-            "rate limited"
+    @patch('trustgraph.model.text_completion.cohere.llm.cohere')
+    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
+    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
+    async def test_503_raises_llm_error(self, _llm, _async, mock_cohere):
+        from trustgraph.model.text_completion.cohere.llm import Processor
+        import trustgraph.model.text_completion.cohere.llm as cohere_llm
+
+        mock_client = MagicMock()
+        mock_cohere.Client.return_value = mock_client
+        proc = Processor(
+            api_key="k", concurrency=1, taskgroup=AsyncMock(), id="t",
         )
+
+        ErrorCls = type("ServiceUnavailableError", (Exception,), {})
+        with patch.object(cohere_llm, 'ServiceUnavailableError', ErrorCls):
+            mock_client.chat.side_effect = ErrorCls()
+            with pytest.raises(LlmError):
+                await proc.generate_content("sys", "prompt")
+
+
+class TestVllmRateLimit(IsolatedAsyncioTestCase):
+    """vLLM: HTTP 429/503 → TooManyRequests/LlmError"""
+
+    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
+    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
+    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
+    async def test_429_raises_too_many_requests(self, _llm, _async, mock_session):
+        from trustgraph.model.text_completion.vllm.llm import Processor
+        proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t")
+
+        mock_resp = AsyncMock()
+        mock_resp.status = 429
+        mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp
 
         with pytest.raises(TooManyRequests):
             await proc.generate_content("sys", "prompt")
 
+    @patch('trustgraph.model.text_completion.vllm.llm.aiohttp.ClientSession')
+    @patch('trustgraph.base.async_processor.AsyncProcessor.__init__', return_value=None)
+    @patch('trustgraph.base.llm_service.LlmService.__init__', return_value=None)
+    async def test_503_raises_llm_error(self, _llm, _async, mock_session):
+        from trustgraph.model.text_completion.vllm.llm import Processor
+        proc = Processor(concurrency=1, taskgroup=AsyncMock(), id="t")
+
+        mock_resp = AsyncMock()
+        mock_resp.status = 503
+        mock_session.return_value.post.return_value.__aenter__.return_value = mock_resp
+
+        with pytest.raises(LlmError):
+            await proc.generate_content("sys", "prompt")
+
 
 class TestClientSideRateLimitTranslation:
     """Client base class: error type 'too-many-requests' → TooManyRequests"""
diff --git a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
index 5093e556..4190cb98 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/cohere/llm.py
@@ -5,6 +5,7 @@ Input is prompt, output is response.
 """
 
 import cohere
+from cohere.errors import TooManyRequestsError, ServiceUnavailableError
 from prometheus_client import Histogram
 import os
 import logging
@@ -12,7 +13,7 @@ import logging
 # Module logger
 logger = logging.getLogger(__name__)
 
-from .... exceptions import TooManyRequests
+from .... exceptions import TooManyRequests, LlmError
 from .... base import LlmService, LlmResult, LlmChunk
 
 default_ident = "text-completion"
@@ -84,13 +85,14 @@ class Processor(LlmService):
 
             return resp
 
-        # FIXME: Wrong exception, don't know what this LLM throws
-        # for a rate limit
-        except cohere.TooManyRequestsError:
-
+        except TooManyRequestsError:
             # Leave rate limit retries to the base handler
             raise TooManyRequests()
 
+        except ServiceUnavailableError:
+            # Treat 503 as a retryable LlmError
+            raise LlmError()
+
         except Exception as e:
 
             # Apart from rate limits, treat all exceptions as unrecoverable
@@ -152,10 +154,14 @@ class Processor(LlmService):
 
             logger.debug("Streaming complete")
 
-        except cohere.TooManyRequestsError:
+        except TooManyRequestsError:
             logger.warning("Rate limit exceeded during streaming")
             raise TooManyRequests()
 
+        except ServiceUnavailableError:
+            logger.warning("Service unavailable during streaming")
+            raise LlmError()
+
         except Exception as e:
             logger.error(f"Cohere streaming exception ({type(e).__name__}): {e}", exc_info=True)
             raise e
diff --git a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
index fab41ecd..e53f6f6e 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/mistral/llm.py
@@ -4,14 +4,14 @@ Simple LLM service, performs text prompt completion using Mistral.
 Input is prompt, output is response.
 """
 
-from mistralai import Mistral
+from mistralai import Mistral, models
 import os
 import logging
 
 # Module logger
 logger = logging.getLogger(__name__)
 
-from .... exceptions import TooManyRequests
+from .... exceptions import TooManyRequests, LlmError
 from .... base import LlmService, LlmResult, LlmChunk
 
 default_ident = "text-completion"
@@ -100,18 +100,14 @@ class Processor(LlmService):
 
             return resp
 
-        # FIXME: Wrong exception.  The MistralAI library has retry logic
-        # so retry-able errors are retried transparently.  It means we
-        # don't get rate limit events.
-
-        # We could choose to turn off retry and handle all that here
-        # or subclass BackoffStrategy to keep the retry logic, but
-        # get the events out.
-
-#        except Mistral.RateLimitError:
-
-#            # Leave rate limit retries to the base handler
-#            raise TooManyRequests()
+        except models.SDKError as e:
+            if e.status_code == 429:
+                # Leave rate limit retries to the base handler
+                raise TooManyRequests()
+            elif e.status_code == 503:
+                # Treat 503 as a retryable LlmError
+                raise LlmError()
+            raise e
 
         except Exception as e:
 
@@ -185,8 +181,18 @@ class Processor(LlmService):
 
             logger.debug("Streaming complete")
 
+        except models.SDKError as e:
+            if e.status_code == 429:
+                logger.warning("Hit rate limit during streaming")
+                raise TooManyRequests()
+            elif e.status_code == 503:
+                logger.warning("Service unavailable during streaming")
+                raise LlmError()
+            logger.error(f"Mistral streaming exception ({type(e).__name__}): {e}", exc_info=True)
+            raise e
+
         except Exception as e:
             logger.error(f"Mistral streaming exception ({type(e).__name__}): {e}", exc_info=True)
             raise e
 
     @staticmethod
diff --git a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
index cdc8602a..0ee61521 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/openai/llm.py
@@ -4,11 +4,11 @@ Simple LLM service, performs text prompt completion using OpenAI.
 Input is prompt, output is response.
 """
 
-from openai import OpenAI, RateLimitError
+from openai import OpenAI, RateLimitError, InternalServerError
 import os
 import logging
 
-from .... exceptions import TooManyRequests
+from .... exceptions import TooManyRequests, LlmError
 from .... base import LlmService, LlmResult, LlmChunk
 
 # Module logger
@@ -104,13 +104,14 @@ class Processor(LlmService):
 
             return resp
 
-        # FIXME: Wrong exception, don't know what this LLM throws
-        # for a rate limit
         except RateLimitError:
-
             # Leave rate limit retries to the base handler
             raise TooManyRequests()
 
+        except InternalServerError:
+            # openai raises InternalServerError for any 5xx (503 included)
+            raise LlmError()
+
         except Exception as e:
 
             # Apart from rate limits, treat all exceptions as unrecoverable
@@ -191,6 +192,10 @@ class Processor(LlmService):
             logger.warning("Hit rate limit during streaming")
             raise TooManyRequests()
 
+        except InternalServerError:
+            logger.warning("Hit internal server error during streaming")
+            raise LlmError()
+
         except Exception as e:
             logger.error(f"OpenAI streaming exception ({type(e).__name__}): {e}", exc_info=True)
             raise e
diff --git a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
index 2dd4576e..7570fa40 100755
--- a/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
+++ b/trustgraph-flow/trustgraph/model/text_completion/vllm/llm.py
@@ -11,7 +11,7 @@ import logging
 # Module logger
 logger = logging.getLogger(__name__)
 
-from .... exceptions import TooManyRequests
+from .... exceptions import TooManyRequests, LlmError
 from .... base import LlmService, LlmResult, LlmChunk
 
 default_ident = "text-completion"
@@ -83,6 +83,10 @@ class Processor(LlmService):
                 json=request,
             ) as response:
 
+                if response.status == 429:
+                    raise TooManyRequests()
+                if response.status == 503:
+                    raise LlmError()
                 if response.status != 200:
                     raise RuntimeError("Bad status: " + str(response.status))
 
@@ -104,7 +108,13 @@ class Processor(LlmService):
 
             return resp
 
-        # FIXME: Assuming vLLM won't produce rate limits?
+        except TooManyRequests:
+            # Leave rate limit retries to the base handler
+            raise
+
+        except LlmError:
+            # Re-raise the 503 LlmError unchanged, keeping its traceback
+            raise
 
         except Exception as e:
 
@@ -150,6 +160,10 @@ class Processor(LlmService):
                 json=request,
             ) as response:
 
+                if response.status == 429:
+                    raise TooManyRequests()
+                if response.status == 503:
+                    raise LlmError()
                 if response.status != 200:
                     raise RuntimeError("Bad status: " + str(response.status))
 