feat: add comprehensive API error handling classes

Added new exception classes for API error handling including APIError, AuthenticationError, InvalidRequestError, APIConnectionError, RateLimitError, and ServerError to improve error handling and debugging capabilities in the SecureCompletionClient. These changes enhance the robustness of the API client by providing specific error types for different failure scenarios.
This commit is contained in:
Alpha Nerd 2025-12-18 10:22:03 +01:00
parent 165f023513
commit 39c03fb975
3 changed files with 145 additions and 34 deletions

View file

@ -9,11 +9,44 @@ from cryptography.hazmat.primitives.kdf.hkdf import HKDF
# Setup module logger
logger = logging.getLogger(__name__)
class SecurityError(Exception):
"""Raised when a security violation is detected."""
pass
class APIError(Exception):
"""Base class for all API-related errors."""
def __init__(self, message: str, status_code: Optional[int] = None, error_details: Optional[Dict[str, Any]] = None):
self.message = message
self.status_code = status_code
self.error_details = error_details
super().__init__(message)
def __str__(self):
return self.message
class AuthenticationError(APIError):
"""Raised when authentication fails (e.g., invalid API key)."""
def __init__(self, message: str, status_code: int = 401, error_details: Optional[Dict[str, Any]] = None):
super().__init__(message, status_code, error_details)
class InvalidRequestError(APIError):
"""Raised when the request is invalid (HTTP 400)."""
def __init__(self, message: str, status_code: int = 400, error_details: Optional[Dict[str, Any]] = None):
super().__init__(message, status_code, error_details)
class APIConnectionError(Exception):
"""Raised when there's a connection error."""
pass
class RateLimitError(APIError):
"""Raised when rate limit is exceeded (HTTP 429)."""
def __init__(self, message: str, status_code: int = 429, error_details: Optional[Dict[str, Any]] = None):
super().__init__(message, status_code, error_details)
class ServerError(APIError):
"""Raised when the server returns an error (HTTP 500)."""
def __init__(self, message: str, status_code: int = 500, error_details: Optional[Dict[str, Any]] = None):
super().__init__(message, status_code, error_details)
class SecureCompletionClient:
"""
@ -39,7 +72,7 @@ class SecureCompletionClient:
self.public_key_pem = None
self.key_size = 4096 # RSA key size
self.allow_http = allow_http # Store for use in fetch_server_public_key
# Validate HTTPS for security
if not self.router_url.startswith("https://"):
if not allow_http:
@ -145,7 +178,7 @@ class SecureCompletionClient:
if password:
password_options.append(password.encode('utf-8'))
password_options.append(None) # Try without password
last_error = None
for pwd in password_options:
try:
@ -177,19 +210,19 @@ class SecureCompletionClient:
# Validate loaded key
self._validate_rsa_key(self.private_key, "private")
logger.debug("Keys loaded successfully")
async def fetch_server_public_key(self) -> str:
"""
Fetch the server's public key from the /pki/public_key endpoint.
Uses HTTPS with certificate verification to prevent MITM attacks.
HTTP is only allowed if explicitly enabled via allow_http parameter.
Returns:
Server's public key as PEM string
Raises:
SecurityError: If HTTPS is not used and HTTP is not explicitly allowed
ConnectionError: If connection fails
@ -209,11 +242,11 @@ class SecureCompletionClient:
logger.warning("Fetching key over HTTP (local development mode)")
url = f"{self.router_url}/pki/public_key"
try:
# Use HTTPS verification only for HTTPS URLs
verify_ssl = self.router_url.startswith("https://")
async with httpx.AsyncClient(
timeout=60.0,
verify=verify_ssl, # Verify SSL/TLS certificates for HTTPS
@ -222,7 +255,7 @@ class SecureCompletionClient:
if response.status_code == 200:
server_public_key = response.text
# Validate it's a valid PEM key
try:
serialization.load_pem_public_key(
@ -231,7 +264,7 @@ class SecureCompletionClient:
)
except Exception:
raise ValueError("Server returned invalid public key format")
if verify_ssl:
logger.debug("Server's public key fetched securely over HTTPS")
else:
@ -239,7 +272,7 @@ class SecureCompletionClient:
return server_public_key
else:
raise ValueError(f"Failed to fetch server's public key: HTTP {response.status_code}")
except httpx.ConnectError as e:
raise ConnectionError(f"Failed to connect to server: {e}")
except httpx.TimeoutException:
@ -270,19 +303,19 @@ class SecureCompletionClient:
# Validate payload
if not isinstance(payload, dict):
raise ValueError("Payload must be a dictionary")
if not payload:
raise ValueError("Payload cannot be empty")
try:
# Serialize payload to JSON
payload_json = json.dumps(payload).encode('utf-8')
# Validate payload size (prevent DoS)
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB limit
if len(payload_json) > MAX_PAYLOAD_SIZE:
raise ValueError(f"Payload too large: {len(payload_json)} bytes (max: {MAX_PAYLOAD_SIZE})")
logger.debug("Payload size: %d bytes", len(payload_json))
# Generate cryptographically secure random AES key
@ -354,17 +387,17 @@ class SecureCompletionClient:
Returns:
Decrypted response dictionary
Raises:
ValueError: If response format is invalid
SecurityError: If decryption fails or integrity check fails
"""
logger.info("Decrypting response...")
# Validate input
if not encrypted_response:
raise ValueError("Empty encrypted response")
if not isinstance(encrypted_response, bytes):
raise ValueError("Encrypted response must be bytes")
@ -381,11 +414,11 @@ class SecureCompletionClient:
missing_fields = [f for f in required_fields if f not in package]
if missing_fields:
raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}")
# Validate encrypted_payload structure
if not isinstance(package["encrypted_payload"], dict):
raise ValueError("Invalid encrypted_payload: must be a dictionary")
payload_required = ["ciphertext", "nonce", "tag"]
missing_payload_fields = [f for f in payload_required if f not in package["encrypted_payload"]]
if missing_payload_fields:
@ -449,6 +482,13 @@ class SecureCompletionClient:
Returns:
Decrypted response from the LLM
Raises:
AuthenticationError: If API key is invalid or missing (HTTP 401)
InvalidRequestError: If the request is invalid (HTTP 400)
APIError: For other HTTP errors
APIConnectionError: If connection fails
SecurityError: If encryption/decryption fails
"""
logger.info("Sending secure chat completion request...")
@ -487,35 +527,88 @@ class SecureCompletionClient:
return decrypted_response
elif response.status_code == 400:
error = response.json()
raise ValueError(f"Bad request: {error.get('detail', 'Unknown error')}")
# Bad request
try:
error = response.json()
raise InvalidRequestError(
f"Bad request: {error.get('detail', 'Unknown error')}",
status_code=400,
error_details=error
)
except (json.JSONDecodeError, ValueError):
raise InvalidRequestError("Bad request: Invalid response format")
elif response.status_code == 401:
# Unauthorized - authentication failed
try:
error = response.json()
error_message = error.get('detail', 'Invalid API key or authentication failed')
raise AuthenticationError(
error_message,
status_code=401,
error_details=error
)
except (json.JSONDecodeError, ValueError):
raise AuthenticationError("Invalid API key or authentication failed")
elif response.status_code == 404:
error = response.json()
raise ValueError(f"Endpoint not found: {error.get('detail', 'Secure inference not enabled')}")
# Endpoint not found
try:
error = response.json()
raise APIError(
f"Endpoint not found: {error.get('detail', 'Secure inference not enabled')}",
status_code=404,
error_details=error
)
except (json.JSONDecodeError, ValueError):
raise APIError("Endpoint not found: Secure inference not enabled")
elif response.status_code == 429:
# Rate limit exceeded
try:
error = response.json()
raise RateLimitError(
f"Rate limit exceeded: {error.get('detail', 'Too many requests')}",
status_code=429,
error_details=error
)
except (json.JSONDecodeError, ValueError):
raise RateLimitError("Rate limit exceeded: Too many requests")
elif response.status_code == 500:
error = response.json()
raise ValueError(f"Server error: {error.get('detail', 'Internal server error')}")
# Server error
try:
error = response.json()
raise ServerError(
f"Server error: {error.get('detail', 'Internal server error')}",
status_code=500,
error_details=error
)
except (json.JSONDecodeError, ValueError):
raise ServerError("Server error: Internal server error")
else:
raise ValueError(f"Unexpected status code: {response.status_code}")
# Unexpected status code
raise APIError(
f"Unexpected status code: {response.status_code}",
status_code=response.status_code
)
except httpx.NetworkError as e:
raise ConnectionError(f"Failed to connect to router: {e}")
except (ValueError, SecurityError, ConnectionError):
raise APIConnectionError(f"Failed to connect to router: {e}")
except (SecurityError, APIError, AuthenticationError, InvalidRequestError, RateLimitError, ServerError, APIConnectionError):
raise # Re-raise known exceptions
except Exception as e:
raise Exception(f"Request failed: {e}")
def _validate_rsa_key(self, key, key_type: str = "private") -> None:
"""
Validate that a key is a valid RSA key with appropriate size.
Args:
key: The key to validate
key_type: "private" or "public"
Raises:
ValueError: If key is invalid
"""
@ -527,12 +620,12 @@ class SecureCompletionClient:
if not isinstance(key, rsa.RSAPublicKey):
raise ValueError("Invalid public key: not an RSA public key")
key_size = key.key_size
MIN_KEY_SIZE = 2048
if key_size < MIN_KEY_SIZE:
raise ValueError(
f"Key size {key_size} is too small. "
f"Minimum recommended size is {MIN_KEY_SIZE} bits."
)
logger.debug("Valid %d-bit RSA %s key", key_size, key_type)

View file

@ -5,6 +5,24 @@ OpenAI-compatible secure chat client with end-to-end encryption.
"""
from .nomyo import SecureChatCompletion
from .SecureCompletionClient import (
APIError,
AuthenticationError,
InvalidRequestError,
APIConnectionError,
RateLimitError,
ServerError
)
__all__ = [
'SecureChatCompletion',
'APIError',
'AuthenticationError',
'InvalidRequestError',
'APIConnectionError',
'RateLimitError',
'ServerError'
]
__version__ = "0.1.0"
__author__ = "NOMYO AI"

View file

@ -1,6 +1,6 @@
import uuid
from typing import Dict, Any, List, Optional
from .SecureCompletionClient import SecureCompletionClient
from .SecureCompletionClient import SecureCompletionClient, APIError, AuthenticationError, InvalidRequestError, APIConnectionError, RateLimitError, ServerError
class SecureChatCompletion:
"""