diff --git a/nomyo/SecureCompletionClient.py b/nomyo/SecureCompletionClient.py index 3d1e896..274fbc5 100644 --- a/nomyo/SecureCompletionClient.py +++ b/nomyo/SecureCompletionClient.py @@ -9,11 +9,44 @@ from cryptography.hazmat.primitives.kdf.hkdf import HKDF # Setup module logger logger = logging.getLogger(__name__) - class SecurityError(Exception): """Raised when a security violation is detected.""" pass +class APIError(Exception): + """Base class for all API-related errors.""" + def __init__(self, message: str, status_code: Optional[int] = None, error_details: Optional[Dict[str, Any]] = None): + self.message = message + self.status_code = status_code + self.error_details = error_details + super().__init__(message) + + def __str__(self): + return self.message + +class AuthenticationError(APIError): + """Raised when authentication fails (e.g., invalid API key).""" + def __init__(self, message: str, status_code: int = 401, error_details: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code, error_details) + +class InvalidRequestError(APIError): + """Raised when the request is invalid (HTTP 400).""" + def __init__(self, message: str, status_code: int = 400, error_details: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code, error_details) + +class APIConnectionError(Exception): + """Raised when there's a connection error.""" + pass + +class RateLimitError(APIError): + """Raised when rate limit is exceeded (HTTP 429).""" + def __init__(self, message: str, status_code: int = 429, error_details: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code, error_details) + +class ServerError(APIError): + """Raised when the server returns an error (HTTP 500).""" + def __init__(self, message: str, status_code: int = 500, error_details: Optional[Dict[str, Any]] = None): + super().__init__(message, status_code, error_details) class SecureCompletionClient: """ @@ -39,7 +72,7 @@ class SecureCompletionClient: self.public_key_pem = None self.key_size = 4096 # RSA key size self.allow_http = allow_http # Store for use in fetch_server_public_key - + # Validate HTTPS for security if not self.router_url.startswith("https://"): if not allow_http: @@ -145,7 +178,7 @@ class SecureCompletionClient: if password: password_options.append(password.encode('utf-8')) password_options.append(None) # Try without password - + last_error = None for pwd in password_options: try: @@ -177,19 +210,19 @@ class SecureCompletionClient: # Validate loaded key self._validate_rsa_key(self.private_key, "private") - + logger.debug("Keys loaded successfully") async def fetch_server_public_key(self) -> str: """ Fetch the server's public key from the /pki/public_key endpoint. - + Uses HTTPS with certificate verification to prevent MITM attacks. HTTP is only allowed if explicitly enabled via allow_http parameter. Returns: Server's public key as PEM string - + Raises: SecurityError: If HTTPS is not used and HTTP is not explicitly allowed ConnectionError: If connection fails @@ -209,11 +242,11 @@ class SecureCompletionClient: logger.warning("Fetching key over HTTP (local development mode)") url = f"{self.router_url}/pki/public_key" - + try: # Use HTTPS verification only for HTTPS URLs verify_ssl = self.router_url.startswith("https://") - + async with httpx.AsyncClient( timeout=60.0, verify=verify_ssl, # Verify SSL/TLS certificates for HTTPS @@ -222,7 +255,7 @@ class SecureCompletionClient: if response.status_code == 200: server_public_key = response.text - + # Validate it's a valid PEM key try: serialization.load_pem_public_key( @@ -231,7 +264,7 @@ class SecureCompletionClient: ) except Exception: raise ValueError("Server returned invalid public key format") - + if verify_ssl: logger.debug("Server's public key fetched securely over HTTPS") else: @@ -239,7 +272,7 @@ class SecureCompletionClient: return server_public_key else: raise ValueError(f"Failed to fetch server's public key: HTTP {response.status_code}") - + except httpx.ConnectError as e: raise ConnectionError(f"Failed to connect to server: {e}") except httpx.TimeoutException: @@ -270,19 +303,19 @@ class SecureCompletionClient: # Validate payload if not isinstance(payload, dict): raise ValueError("Payload must be a dictionary") - + if not payload: raise ValueError("Payload cannot be empty") try: # Serialize payload to JSON payload_json = json.dumps(payload).encode('utf-8') - + # Validate payload size (prevent DoS) MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB limit if len(payload_json) > MAX_PAYLOAD_SIZE: raise ValueError(f"Payload too large: {len(payload_json)} bytes (max: {MAX_PAYLOAD_SIZE})") - + logger.debug("Payload size: %d bytes", len(payload_json)) # Generate cryptographically secure random AES key @@ -354,17 +387,17 @@ class SecureCompletionClient: Returns: Decrypted response dictionary - + Raises: ValueError: If response format is invalid SecurityError: If decryption fails or integrity check fails """ logger.info("Decrypting response...") - + # Validate input if not encrypted_response: raise ValueError("Empty encrypted response") - + if not isinstance(encrypted_response, bytes): raise ValueError("Encrypted response must be bytes") @@ -381,11 +414,11 @@ class SecureCompletionClient: missing_fields = [f for f in required_fields if f not in package] if missing_fields: raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}") - + # Validate encrypted_payload structure if not isinstance(package["encrypted_payload"], dict): raise ValueError("Invalid encrypted_payload: must be a dictionary") - + payload_required = ["ciphertext", "nonce", "tag"] missing_payload_fields = [f for f in payload_required if f not in package["encrypted_payload"]] if missing_payload_fields: @@ -449,6 +482,13 @@ class SecureCompletionClient: Returns: Decrypted response from the LLM + + Raises: + AuthenticationError: If API key is invalid or missing (HTTP 401) + InvalidRequestError: If the request is invalid (HTTP 400) + APIError: For other HTTP errors + APIConnectionError: If connection fails + SecurityError: If encryption/decryption fails """ logger.info("Sending secure chat completion request...") @@ -487,35 +527,88 @@ class SecureCompletionClient: return decrypted_response elif response.status_code == 400: - error = response.json() - raise ValueError(f"Bad request: {error.get('detail', 'Unknown error')}") + # Bad request + try: + error = response.json() + raise InvalidRequestError( + f"Bad request: {error.get('detail', 'Unknown error')}", + status_code=400, + error_details=error + ) + except (json.JSONDecodeError, ValueError): + raise InvalidRequestError("Bad request: Invalid response format") + + elif response.status_code == 401: + # Unauthorized - authentication failed + try: + error = response.json() + error_message = error.get('detail', 'Invalid API key or authentication failed') + raise AuthenticationError( + error_message, + status_code=401, + error_details=error + ) + except (json.JSONDecodeError, ValueError): + raise AuthenticationError("Invalid API key or authentication failed") elif response.status_code == 404: - error = response.json() - raise ValueError(f"Endpoint not found: {error.get('detail', 'Secure inference not enabled')}") + # Endpoint not found + try: + error = response.json() + raise APIError( + f"Endpoint not found: {error.get('detail', 'Secure inference not enabled')}", + status_code=404, + error_details=error + ) + except (json.JSONDecodeError, ValueError): + raise APIError("Endpoint not found: Secure inference not enabled") + + elif response.status_code == 429: + # Rate limit exceeded + try: + error = response.json() + raise RateLimitError( + f"Rate limit exceeded: {error.get('detail', 'Too many requests')}", + status_code=429, + error_details=error + ) + except (json.JSONDecodeError, ValueError): + raise RateLimitError("Rate limit exceeded: Too many requests") elif response.status_code == 500: - error = response.json() - raise ValueError(f"Server error: {error.get('detail', 'Internal server error')}") + # Server error + try: + error = response.json() + raise ServerError( + f"Server error: {error.get('detail', 'Internal server error')}", + status_code=500, + error_details=error + ) + except (json.JSONDecodeError, ValueError): + raise ServerError("Server error: Internal server error") else: - raise ValueError(f"Unexpected status code: {response.status_code}") + # Unexpected status code + raise APIError( + f"Unexpected status code: {response.status_code}", + status_code=response.status_code + ) except httpx.NetworkError as e: - raise ConnectionError(f"Failed to connect to router: {e}") - except (ValueError, SecurityError, ConnectionError): + raise APIConnectionError(f"Failed to connect to router: {e}") + except (SecurityError, APIError, AuthenticationError, InvalidRequestError, RateLimitError, ServerError, APIConnectionError): raise # Re-raise known exceptions except Exception as e: raise Exception(f"Request failed: {e}") - + def _validate_rsa_key(self, key, key_type: str = "private") -> None: """ Validate that a key is a valid RSA key with appropriate size. - + Args: key: The key to validate key_type: "private" or "public" - + Raises: ValueError: If key is invalid """ @@ -527,12 +620,12 @@ class SecureCompletionClient: if not isinstance(key, rsa.RSAPublicKey): raise ValueError("Invalid public key: not an RSA public key") key_size = key.key_size - + MIN_KEY_SIZE = 2048 if key_size < MIN_KEY_SIZE: raise ValueError( f"Key size {key_size} is too small. " f"Minimum recommended size is {MIN_KEY_SIZE} bits." ) - + logger.debug("Valid %d-bit RSA %s key", key_size, key_type) diff --git a/nomyo/__init__.py b/nomyo/__init__.py index 190ca2e..968d1b0 100644 --- a/nomyo/__init__.py +++ b/nomyo/__init__.py @@ -5,6 +5,24 @@ OpenAI-compatible secure chat client with end-to-end encryption. """ from .nomyo import SecureChatCompletion +from .SecureCompletionClient import ( + APIError, + AuthenticationError, + InvalidRequestError, + APIConnectionError, + RateLimitError, + ServerError +) + +__all__ = [ + 'SecureChatCompletion', + 'APIError', + 'AuthenticationError', + 'InvalidRequestError', + 'APIConnectionError', + 'RateLimitError', + 'ServerError' +] __version__ = "0.1.0" __author__ = "NOMYO AI" diff --git a/nomyo/nomyo.py b/nomyo/nomyo.py index 1bc7dec..4cb6ebb 100644 --- a/nomyo/nomyo.py +++ b/nomyo/nomyo.py @@ -1,6 +1,6 @@ import uuid from typing import Dict, Any, List, Optional -from .SecureCompletionClient import SecureCompletionClient +from .SecureCompletionClient import SecureCompletionClient, APIError, AuthenticationError, InvalidRequestError, APIConnectionError, RateLimitError, ServerError class SecureChatCompletion: """