feat: add comprehensive API error handling classes
Added new exception classes for API error handling including APIError, AuthenticationError, InvalidRequestError, APIConnectionError, RateLimitError, and ServerError to improve error handling and debugging capabilities in the SecureCompletionClient. These changes enhance the robustness of the API client by providing specific error types for different failure scenarios.
This commit is contained in:
parent
165f023513
commit
39c03fb975
3 changed files with 145 additions and 34 deletions
|
|
@ -9,11 +9,44 @@ from cryptography.hazmat.primitives.kdf.hkdf import HKDF
|
||||||
# Setup module logger
|
# Setup module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class SecurityError(Exception):
|
class SecurityError(Exception):
|
||||||
"""Raised when a security violation is detected."""
|
"""Raised when a security violation is detected."""
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
class APIError(Exception):
|
||||||
|
"""Base class for all API-related errors."""
|
||||||
|
def __init__(self, message: str, status_code: Optional[int] = None, error_details: Optional[Dict[str, Any]] = None):
|
||||||
|
self.message = message
|
||||||
|
self.status_code = status_code
|
||||||
|
self.error_details = error_details
|
||||||
|
super().__init__(message)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.message
|
||||||
|
|
||||||
|
class AuthenticationError(APIError):
|
||||||
|
"""Raised when authentication fails (e.g., invalid API key)."""
|
||||||
|
def __init__(self, message: str, status_code: int = 401, error_details: Optional[Dict[str, Any]] = None):
|
||||||
|
super().__init__(message, status_code, error_details)
|
||||||
|
|
||||||
|
class InvalidRequestError(APIError):
|
||||||
|
"""Raised when the request is invalid (HTTP 400)."""
|
||||||
|
def __init__(self, message: str, status_code: int = 400, error_details: Optional[Dict[str, Any]] = None):
|
||||||
|
super().__init__(message, status_code, error_details)
|
||||||
|
|
||||||
|
class APIConnectionError(Exception):
|
||||||
|
"""Raised when there's a connection error."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
class RateLimitError(APIError):
|
||||||
|
"""Raised when rate limit is exceeded (HTTP 429)."""
|
||||||
|
def __init__(self, message: str, status_code: int = 429, error_details: Optional[Dict[str, Any]] = None):
|
||||||
|
super().__init__(message, status_code, error_details)
|
||||||
|
|
||||||
|
class ServerError(APIError):
|
||||||
|
"""Raised when the server returns an error (HTTP 500)."""
|
||||||
|
def __init__(self, message: str, status_code: int = 500, error_details: Optional[Dict[str, Any]] = None):
|
||||||
|
super().__init__(message, status_code, error_details)
|
||||||
|
|
||||||
class SecureCompletionClient:
|
class SecureCompletionClient:
|
||||||
"""
|
"""
|
||||||
|
|
@ -39,7 +72,7 @@ class SecureCompletionClient:
|
||||||
self.public_key_pem = None
|
self.public_key_pem = None
|
||||||
self.key_size = 4096 # RSA key size
|
self.key_size = 4096 # RSA key size
|
||||||
self.allow_http = allow_http # Store for use in fetch_server_public_key
|
self.allow_http = allow_http # Store for use in fetch_server_public_key
|
||||||
|
|
||||||
# Validate HTTPS for security
|
# Validate HTTPS for security
|
||||||
if not self.router_url.startswith("https://"):
|
if not self.router_url.startswith("https://"):
|
||||||
if not allow_http:
|
if not allow_http:
|
||||||
|
|
@ -145,7 +178,7 @@ class SecureCompletionClient:
|
||||||
if password:
|
if password:
|
||||||
password_options.append(password.encode('utf-8'))
|
password_options.append(password.encode('utf-8'))
|
||||||
password_options.append(None) # Try without password
|
password_options.append(None) # Try without password
|
||||||
|
|
||||||
last_error = None
|
last_error = None
|
||||||
for pwd in password_options:
|
for pwd in password_options:
|
||||||
try:
|
try:
|
||||||
|
|
@ -177,19 +210,19 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
# Validate loaded key
|
# Validate loaded key
|
||||||
self._validate_rsa_key(self.private_key, "private")
|
self._validate_rsa_key(self.private_key, "private")
|
||||||
|
|
||||||
logger.debug("Keys loaded successfully")
|
logger.debug("Keys loaded successfully")
|
||||||
|
|
||||||
async def fetch_server_public_key(self) -> str:
|
async def fetch_server_public_key(self) -> str:
|
||||||
"""
|
"""
|
||||||
Fetch the server's public key from the /pki/public_key endpoint.
|
Fetch the server's public key from the /pki/public_key endpoint.
|
||||||
|
|
||||||
Uses HTTPS with certificate verification to prevent MITM attacks.
|
Uses HTTPS with certificate verification to prevent MITM attacks.
|
||||||
HTTP is only allowed if explicitly enabled via allow_http parameter.
|
HTTP is only allowed if explicitly enabled via allow_http parameter.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Server's public key as PEM string
|
Server's public key as PEM string
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
SecurityError: If HTTPS is not used and HTTP is not explicitly allowed
|
SecurityError: If HTTPS is not used and HTTP is not explicitly allowed
|
||||||
ConnectionError: If connection fails
|
ConnectionError: If connection fails
|
||||||
|
|
@ -209,11 +242,11 @@ class SecureCompletionClient:
|
||||||
logger.warning("Fetching key over HTTP (local development mode)")
|
logger.warning("Fetching key over HTTP (local development mode)")
|
||||||
|
|
||||||
url = f"{self.router_url}/pki/public_key"
|
url = f"{self.router_url}/pki/public_key"
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Use HTTPS verification only for HTTPS URLs
|
# Use HTTPS verification only for HTTPS URLs
|
||||||
verify_ssl = self.router_url.startswith("https://")
|
verify_ssl = self.router_url.startswith("https://")
|
||||||
|
|
||||||
async with httpx.AsyncClient(
|
async with httpx.AsyncClient(
|
||||||
timeout=60.0,
|
timeout=60.0,
|
||||||
verify=verify_ssl, # Verify SSL/TLS certificates for HTTPS
|
verify=verify_ssl, # Verify SSL/TLS certificates for HTTPS
|
||||||
|
|
@ -222,7 +255,7 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
if response.status_code == 200:
|
if response.status_code == 200:
|
||||||
server_public_key = response.text
|
server_public_key = response.text
|
||||||
|
|
||||||
# Validate it's a valid PEM key
|
# Validate it's a valid PEM key
|
||||||
try:
|
try:
|
||||||
serialization.load_pem_public_key(
|
serialization.load_pem_public_key(
|
||||||
|
|
@ -231,7 +264,7 @@ class SecureCompletionClient:
|
||||||
)
|
)
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Server returned invalid public key format")
|
raise ValueError("Server returned invalid public key format")
|
||||||
|
|
||||||
if verify_ssl:
|
if verify_ssl:
|
||||||
logger.debug("Server's public key fetched securely over HTTPS")
|
logger.debug("Server's public key fetched securely over HTTPS")
|
||||||
else:
|
else:
|
||||||
|
|
@ -239,7 +272,7 @@ class SecureCompletionClient:
|
||||||
return server_public_key
|
return server_public_key
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Failed to fetch server's public key: HTTP {response.status_code}")
|
raise ValueError(f"Failed to fetch server's public key: HTTP {response.status_code}")
|
||||||
|
|
||||||
except httpx.ConnectError as e:
|
except httpx.ConnectError as e:
|
||||||
raise ConnectionError(f"Failed to connect to server: {e}")
|
raise ConnectionError(f"Failed to connect to server: {e}")
|
||||||
except httpx.TimeoutException:
|
except httpx.TimeoutException:
|
||||||
|
|
@ -270,19 +303,19 @@ class SecureCompletionClient:
|
||||||
# Validate payload
|
# Validate payload
|
||||||
if not isinstance(payload, dict):
|
if not isinstance(payload, dict):
|
||||||
raise ValueError("Payload must be a dictionary")
|
raise ValueError("Payload must be a dictionary")
|
||||||
|
|
||||||
if not payload:
|
if not payload:
|
||||||
raise ValueError("Payload cannot be empty")
|
raise ValueError("Payload cannot be empty")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Serialize payload to JSON
|
# Serialize payload to JSON
|
||||||
payload_json = json.dumps(payload).encode('utf-8')
|
payload_json = json.dumps(payload).encode('utf-8')
|
||||||
|
|
||||||
# Validate payload size (prevent DoS)
|
# Validate payload size (prevent DoS)
|
||||||
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB limit
|
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB limit
|
||||||
if len(payload_json) > MAX_PAYLOAD_SIZE:
|
if len(payload_json) > MAX_PAYLOAD_SIZE:
|
||||||
raise ValueError(f"Payload too large: {len(payload_json)} bytes (max: {MAX_PAYLOAD_SIZE})")
|
raise ValueError(f"Payload too large: {len(payload_json)} bytes (max: {MAX_PAYLOAD_SIZE})")
|
||||||
|
|
||||||
logger.debug("Payload size: %d bytes", len(payload_json))
|
logger.debug("Payload size: %d bytes", len(payload_json))
|
||||||
|
|
||||||
# Generate cryptographically secure random AES key
|
# Generate cryptographically secure random AES key
|
||||||
|
|
@ -354,17 +387,17 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decrypted response dictionary
|
Decrypted response dictionary
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If response format is invalid
|
ValueError: If response format is invalid
|
||||||
SecurityError: If decryption fails or integrity check fails
|
SecurityError: If decryption fails or integrity check fails
|
||||||
"""
|
"""
|
||||||
logger.info("Decrypting response...")
|
logger.info("Decrypting response...")
|
||||||
|
|
||||||
# Validate input
|
# Validate input
|
||||||
if not encrypted_response:
|
if not encrypted_response:
|
||||||
raise ValueError("Empty encrypted response")
|
raise ValueError("Empty encrypted response")
|
||||||
|
|
||||||
if not isinstance(encrypted_response, bytes):
|
if not isinstance(encrypted_response, bytes):
|
||||||
raise ValueError("Encrypted response must be bytes")
|
raise ValueError("Encrypted response must be bytes")
|
||||||
|
|
||||||
|
|
@ -381,11 +414,11 @@ class SecureCompletionClient:
|
||||||
missing_fields = [f for f in required_fields if f not in package]
|
missing_fields = [f for f in required_fields if f not in package]
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}")
|
raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}")
|
||||||
|
|
||||||
# Validate encrypted_payload structure
|
# Validate encrypted_payload structure
|
||||||
if not isinstance(package["encrypted_payload"], dict):
|
if not isinstance(package["encrypted_payload"], dict):
|
||||||
raise ValueError("Invalid encrypted_payload: must be a dictionary")
|
raise ValueError("Invalid encrypted_payload: must be a dictionary")
|
||||||
|
|
||||||
payload_required = ["ciphertext", "nonce", "tag"]
|
payload_required = ["ciphertext", "nonce", "tag"]
|
||||||
missing_payload_fields = [f for f in payload_required if f not in package["encrypted_payload"]]
|
missing_payload_fields = [f for f in payload_required if f not in package["encrypted_payload"]]
|
||||||
if missing_payload_fields:
|
if missing_payload_fields:
|
||||||
|
|
@ -449,6 +482,13 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Decrypted response from the LLM
|
Decrypted response from the LLM
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
AuthenticationError: If API key is invalid or missing (HTTP 401)
|
||||||
|
InvalidRequestError: If the request is invalid (HTTP 400)
|
||||||
|
APIError: For other HTTP errors
|
||||||
|
APIConnectionError: If connection fails
|
||||||
|
SecurityError: If encryption/decryption fails
|
||||||
"""
|
"""
|
||||||
logger.info("Sending secure chat completion request...")
|
logger.info("Sending secure chat completion request...")
|
||||||
|
|
||||||
|
|
@ -487,35 +527,88 @@ class SecureCompletionClient:
|
||||||
return decrypted_response
|
return decrypted_response
|
||||||
|
|
||||||
elif response.status_code == 400:
|
elif response.status_code == 400:
|
||||||
error = response.json()
|
# Bad request
|
||||||
raise ValueError(f"Bad request: {error.get('detail', 'Unknown error')}")
|
try:
|
||||||
|
error = response.json()
|
||||||
|
raise InvalidRequestError(
|
||||||
|
f"Bad request: {error.get('detail', 'Unknown error')}",
|
||||||
|
status_code=400,
|
||||||
|
error_details=error
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
raise InvalidRequestError("Bad request: Invalid response format")
|
||||||
|
|
||||||
|
elif response.status_code == 401:
|
||||||
|
# Unauthorized - authentication failed
|
||||||
|
try:
|
||||||
|
error = response.json()
|
||||||
|
error_message = error.get('detail', 'Invalid API key or authentication failed')
|
||||||
|
raise AuthenticationError(
|
||||||
|
error_message,
|
||||||
|
status_code=401,
|
||||||
|
error_details=error
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
raise AuthenticationError("Invalid API key or authentication failed")
|
||||||
|
|
||||||
elif response.status_code == 404:
|
elif response.status_code == 404:
|
||||||
error = response.json()
|
# Endpoint not found
|
||||||
raise ValueError(f"Endpoint not found: {error.get('detail', 'Secure inference not enabled')}")
|
try:
|
||||||
|
error = response.json()
|
||||||
|
raise APIError(
|
||||||
|
f"Endpoint not found: {error.get('detail', 'Secure inference not enabled')}",
|
||||||
|
status_code=404,
|
||||||
|
error_details=error
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
raise APIError("Endpoint not found: Secure inference not enabled")
|
||||||
|
|
||||||
|
elif response.status_code == 429:
|
||||||
|
# Rate limit exceeded
|
||||||
|
try:
|
||||||
|
error = response.json()
|
||||||
|
raise RateLimitError(
|
||||||
|
f"Rate limit exceeded: {error.get('detail', 'Too many requests')}",
|
||||||
|
status_code=429,
|
||||||
|
error_details=error
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
raise RateLimitError("Rate limit exceeded: Too many requests")
|
||||||
|
|
||||||
elif response.status_code == 500:
|
elif response.status_code == 500:
|
||||||
error = response.json()
|
# Server error
|
||||||
raise ValueError(f"Server error: {error.get('detail', 'Internal server error')}")
|
try:
|
||||||
|
error = response.json()
|
||||||
|
raise ServerError(
|
||||||
|
f"Server error: {error.get('detail', 'Internal server error')}",
|
||||||
|
status_code=500,
|
||||||
|
error_details=error
|
||||||
|
)
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
raise ServerError("Server error: Internal server error")
|
||||||
|
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unexpected status code: {response.status_code}")
|
# Unexpected status code
|
||||||
|
raise APIError(
|
||||||
|
f"Unexpected status code: {response.status_code}",
|
||||||
|
status_code=response.status_code
|
||||||
|
)
|
||||||
|
|
||||||
except httpx.NetworkError as e:
|
except httpx.NetworkError as e:
|
||||||
raise ConnectionError(f"Failed to connect to router: {e}")
|
raise APIConnectionError(f"Failed to connect to router: {e}")
|
||||||
except (ValueError, SecurityError, ConnectionError):
|
except (SecurityError, APIError, AuthenticationError, InvalidRequestError, RateLimitError, ServerError, APIConnectionError):
|
||||||
raise # Re-raise known exceptions
|
raise # Re-raise known exceptions
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(f"Request failed: {e}")
|
raise Exception(f"Request failed: {e}")
|
||||||
|
|
||||||
def _validate_rsa_key(self, key, key_type: str = "private") -> None:
|
def _validate_rsa_key(self, key, key_type: str = "private") -> None:
|
||||||
"""
|
"""
|
||||||
Validate that a key is a valid RSA key with appropriate size.
|
Validate that a key is a valid RSA key with appropriate size.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
key: The key to validate
|
key: The key to validate
|
||||||
key_type: "private" or "public"
|
key_type: "private" or "public"
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: If key is invalid
|
ValueError: If key is invalid
|
||||||
"""
|
"""
|
||||||
|
|
@ -527,12 +620,12 @@ class SecureCompletionClient:
|
||||||
if not isinstance(key, rsa.RSAPublicKey):
|
if not isinstance(key, rsa.RSAPublicKey):
|
||||||
raise ValueError("Invalid public key: not an RSA public key")
|
raise ValueError("Invalid public key: not an RSA public key")
|
||||||
key_size = key.key_size
|
key_size = key.key_size
|
||||||
|
|
||||||
MIN_KEY_SIZE = 2048
|
MIN_KEY_SIZE = 2048
|
||||||
if key_size < MIN_KEY_SIZE:
|
if key_size < MIN_KEY_SIZE:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Key size {key_size} is too small. "
|
f"Key size {key_size} is too small. "
|
||||||
f"Minimum recommended size is {MIN_KEY_SIZE} bits."
|
f"Minimum recommended size is {MIN_KEY_SIZE} bits."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug("Valid %d-bit RSA %s key", key_size, key_type)
|
logger.debug("Valid %d-bit RSA %s key", key_size, key_type)
|
||||||
|
|
|
||||||
|
|
@ -5,6 +5,24 @@ OpenAI-compatible secure chat client with end-to-end encryption.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from .nomyo import SecureChatCompletion
|
from .nomyo import SecureChatCompletion
|
||||||
|
from .SecureCompletionClient import (
|
||||||
|
APIError,
|
||||||
|
AuthenticationError,
|
||||||
|
InvalidRequestError,
|
||||||
|
APIConnectionError,
|
||||||
|
RateLimitError,
|
||||||
|
ServerError
|
||||||
|
)
|
||||||
|
|
||||||
|
__all__ = [
|
||||||
|
'SecureChatCompletion',
|
||||||
|
'APIError',
|
||||||
|
'AuthenticationError',
|
||||||
|
'InvalidRequestError',
|
||||||
|
'APIConnectionError',
|
||||||
|
'RateLimitError',
|
||||||
|
'ServerError'
|
||||||
|
]
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.0"
|
||||||
__author__ = "NOMYO AI"
|
__author__ = "NOMYO AI"
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,6 @@
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from .SecureCompletionClient import SecureCompletionClient
|
from .SecureCompletionClient import SecureCompletionClient, APIError, AuthenticationError, InvalidRequestError, APIConnectionError, RateLimitError, ServerError
|
||||||
|
|
||||||
class SecureChatCompletion:
|
class SecureChatCompletion:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue