# Source file: nomyo/nomyo/SecureCompletionClient.py

import json, base64, urllib.parse, httpx, os, secrets, warnings, logging
2025-12-17 16:03:20 +01:00
from typing import Dict, Any, Optional
from cryptography.hazmat.primitives import serialization, hashes
from cryptography.hazmat.primitives.asymmetric import rsa, padding
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
# Logger shared by everything in this module.
logger = logging.getLogger(__name__)

# SecureMemory is treated as optional: when the import fails the module still
# loads and _SECURE_MEMORY_AVAILABLE records that the helpers are missing.
try:
    from .SecureMemory import secure_bytearray, _get_secure_memory
except ImportError:
    _SECURE_MEMORY_AVAILABLE = False
    logger.warning("SecureMemory module not available, falling back to standard memory handling")
else:
    _SECURE_MEMORY_AVAILABLE = True
2025-12-17 16:24:28 +01:00
class SecurityError(Exception):
    """Signals a security violation (e.g. a refused plain-HTTP key fetch or a failed encryption/decryption)."""
class APIError(Exception):
    """
    Common base for errors derived from an HTTP response of the API.

    Attributes:
        message: Human-readable description; also what ``str(error)`` returns.
        status_code: HTTP status code associated with the error, if any.
        error_details: Parsed error body associated with the error, if any.
    """

    def __init__(self, message: str, status_code: Optional[int] = None, error_details: Optional[Dict[str, Any]] = None):
        super().__init__(message)
        self.error_details = error_details
        self.status_code = status_code
        self.message = message

    def __str__(self):
        # Only the message is rendered; status code and details stay on the attributes.
        return self.message
class AuthenticationError(APIError):
    """Authentication was rejected, e.g. an invalid API key; the status code defaults to 401."""

    def __init__(self, message: str, status_code: int = 401, error_details: Optional[Dict[str, Any]] = None):
        super().__init__(message, status_code=status_code, error_details=error_details)
class InvalidRequestError(APIError):
    """The request was rejected as invalid; the status code defaults to 400."""

    def __init__(self, message: str, status_code: int = 400, error_details: Optional[Dict[str, Any]] = None):
        super().__init__(message, status_code=status_code, error_details=error_details)
class APIConnectionError(Exception):
    """
    A connection-level failure occurred while talking to the router.

    Note: this derives directly from ``Exception``, not from ``APIError``.
    """
class RateLimitError(APIError):
    """The rate limit was exceeded; the status code defaults to 429."""

    def __init__(self, message: str, status_code: int = 429, error_details: Optional[Dict[str, Any]] = None):
        super().__init__(message, status_code=status_code, error_details=error_details)
class ForbiddenError(APIError):
    """Access was forbidden, e.g. the model is not allowed for the requested security tier; the status code defaults to 403."""

    def __init__(self, message: str, status_code: int = 403, error_details: Optional[Dict[str, Any]] = None):
        super().__init__(message, status_code=status_code, error_details=error_details)
class ServerError(APIError):
    """The server reported an internal error; the status code defaults to 500."""

    def __init__(self, message: str, status_code: int = 500, error_details: Optional[Dict[str, Any]] = None):
        super().__init__(message, status_code=status_code, error_details=error_details)
2025-12-17 16:24:28 +01:00
class ServiceUnavailableError(APIError):
    """The inference backend is unavailable; the status code defaults to 503."""

    def __init__(self, message: str, status_code: int = 503, error_details: Optional[Dict[str, Any]] = None):
        super().__init__(message, status_code=status_code, error_details=error_details)
2025-12-17 16:03:20 +01:00
class SecureCompletionClient:
    """
    Client for the /v1/chat/secure_completion endpoint.

    Handles:
    - Key generation and management
    - Hybrid encryption/decryption
    - API communication
    - Response parsing
    """

    def __init__(self, router_url: str = "https://api.nomyo.ai", allow_http: bool = False, secure_memory: bool = True):
        """
        Initialize the secure completion client.

        No keys are created here; call generate_keys() or load_keys() before
        sending requests.

        Args:
            router_url: Base URL of the NOMYO Router (must use HTTPS for production)
            allow_http: Allow HTTP connections (ONLY for local development, never in production)
            secure_memory: Whether to use secure memory operations for this instance.
                Only takes effect when the SecureMemory module could be imported.
        """
        self.router_url = router_url.rstrip('/')
        self.private_key = None
        self.public_key_pem = None
        self.key_size = 4096  # RSA key size
        self.allow_http = allow_http  # Store for use in fetch_server_public_key
        self._use_secure_memory = _SECURE_MEMORY_AVAILABLE and secure_memory

        # Validate HTTPS for security. A non-HTTPS URL is never rejected here:
        # it is either warned about (allow_http) or refused later, at request time.
        if not self.router_url.startswith("https://"):
            if allow_http:
                warnings.warn(
                    "⚠️ WARNING: Using HTTP instead of HTTPS. "
                    "This is INSECURE and should only be used for local development. "
                    "Man-in-the-middle attacks are possible!",
                    UserWarning,
                    stacklevel=2
                )
            else:
                logger.warning(
                    "Non-HTTPS URL detected with allow_http=False. "
                    "Requests will be rejected at runtime by fetch_server_public_key()."
                )

    def _protect_private_key(self) -> None:
        """
        Best-effort attempt to prevent key pages from being swapped to disk.

        Note: The cryptography library uses OpenSSL's own memory allocator for
        the actual key material, which cannot be directly locked from Python.
        This method exports the key to a DER bytearray, locks that page, then
        immediately zeros and discards the copy. It does not protect OpenSSL's
        internal representation, but serves as a defense-in-depth measure.

        Failures are logged at debug level and never raised. The temporary DER
        copy is wiped even when locking fails.

        For maximum security:
        - Use password-protected key files
        - Rotate keys regularly
        - Store keys outside the project directory in production
        """
        if not self._use_secure_memory or not self.private_key:
            return

        key_der = None
        try:
            key_der = bytearray(self.private_key.private_bytes(
                encoding=serialization.Encoding.DER,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            ))
            secure_memory = _get_secure_memory()
            locked = secure_memory.lock_memory(key_der)
            logger.debug("Private key page lock: %s", "success" if locked else "unavailable")
            secure_memory.zero_memory(key_der)
        except Exception as e:
            logger.debug("Private key protection unavailable: %s", e)
        finally:
            # Wipe the plaintext copy on every path, including when the secure
            # memory helpers raised before zero_memory() ran. A same-length
            # slice assignment overwrites in place without resizing the buffer.
            if key_der is not None:
                key_der[:] = bytes(len(key_der))

    def generate_keys(self, save_to_file: bool = False, key_dir: str = "client_keys", password: Optional[str] = None) -> None:
        """
        Generate RSA key pair for secure communication.

        The new private key is stored on ``self.private_key`` and the PEM-encoded
        public key on ``self.public_key_pem``.

        Args:
            save_to_file: Whether to save keys to files
            key_dir: Directory to save keys (if save_to_file is True)
            password: Optional password to encrypt private key (recommended for production)
        """
        logger.info("Generating RSA key pair...")

        # Generate private key
        self.private_key = rsa.generate_private_key(
            public_exponent=65537,
            key_size=self.key_size,
            backend=default_backend()
        )

        # Derive the public key and serialize it to PEM format
        public_key = self.private_key.public_key()
        self.public_key_pem = public_key.public_bytes(
            encoding=serialization.Encoding.PEM,
            format=serialization.PublicFormat.SubjectPublicKeyInfo
        ).decode('utf-8')
        logger.debug("Generated %d-bit RSA key pair", self.key_size)

        # Attempt to protect private key in memory (best effort)
        self._protect_private_key()

        if not save_to_file:
            return

        os.makedirs(key_dir, exist_ok=True)

        if password:
            # Encrypt private key with user-provided password
            private_pem = self.private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.BestAvailableEncryption(password.encode('utf-8'))
            )
            logger.debug("Private key encrypted with password")
        else:
            # Save unencrypted for convenience (not recommended for production)
            private_pem = self.private_key.private_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PrivateFormat.PKCS8,
                encryption_algorithm=serialization.NoEncryption()
            )
            logger.warning("Private key saved UNENCRYPTED (not recommended for production)")

        # Create the private key file owner-only from the start, so the key is
        # never on disk with wider default permissions. O_BINARY only exists on
        # Windows, where it prevents newline translation on the raw descriptor.
        private_key_path = os.path.join(key_dir, "private_key.pem")
        open_flags = os.O_WRONLY | os.O_CREAT | os.O_TRUNC | getattr(os, "O_BINARY", 0)
        fd = os.open(private_key_path, open_flags, 0o600)
        with os.fdopen(fd, "wb") as f:
            f.write(private_pem)
        try:
            # The creation mode is ignored for a file that already existed,
            # so tighten the permissions explicitly as well.
            os.chmod(private_key_path, 0o600)  # Only owner can read/write
            logger.debug("Private key permissions set to 600 (owner-only access)")
        except Exception as e:
            logger.warning("Could not set private key permissions: %s", e)

        # Save public key (always unencrypted; world-readable is acceptable)
        public_key_path = os.path.join(key_dir, "public_key.pem")
        with open(public_key_path, "w") as f:
            f.write(self.public_key_pem)
        try:
            os.chmod(public_key_path, 0o644)  # Owner read/write, group/others read
            logger.debug("Public key permissions set to 644")
        except Exception as e:
            logger.warning("Could not set public key permissions: %s", e)

        logger.debug("Keys saved to %s/", key_dir)

    def load_keys(self, private_key_path: str, public_key_path: Optional[str] = None, password: Optional[str] = None) -> None:
        """
        Load RSA keys from files.

        The private key is validated before anything is stored on the instance,
        so a rejected key never replaces the currently loaded one.

        Args:
            private_key_path: Path to private key file
            public_key_path: Path to public key file (optional, derived from private key if not provided)
            password: Optional password for encrypted private key

        Raises:
            ValueError: If the private key cannot be parsed, is not an RSA key,
                or is smaller than the minimum size.
            OSError: If a key file cannot be read.
        """
        logger.info("Loading keys from files...")

        # Load private key
        with open(private_key_path, "rb") as f:
            private_pem = f.read()

        password_bytes = password.encode('utf-8') if password else None
        try:
            private_key = serialization.load_pem_private_key(
                private_pem,
                password=password_bytes,
                backend=default_backend()
            )
        except Exception as e:
            raise ValueError(f"Failed to load private key: {e}") from e
        logger.debug("Private key loaded %s", 'with password' if password_bytes else 'without password')

        # Validate before touching instance state
        self._validate_rsa_key(private_key, "private")

        # Load public key from file if provided, otherwise derive from private key
        if public_key_path:
            with open(public_key_path, "r") as f:
                public_key_pem = f.read().strip()
        else:
            public_key_pem = private_key.public_key().public_bytes(
                encoding=serialization.Encoding.PEM,
                format=serialization.PublicFormat.SubjectPublicKeyInfo
            ).decode('utf-8')

        self.private_key = private_key
        self.public_key_pem = public_key_pem

        # Attempt to protect private key in memory (best effort)
        self._protect_private_key()
        logger.debug("Keys loaded successfully")

    async def fetch_server_public_key(self) -> str:
        """
        Fetch the server's public key from the /pki/public_key endpoint.

        Uses HTTPS with certificate verification to prevent MITM attacks.
        HTTP is only allowed if explicitly enabled via allow_http parameter.
        The returned key is checked to be a PEM-encoded RSA public key of at
        least the minimum size enforced by _validate_rsa_key().

        Returns:
            Server's public key as PEM string

        Raises:
            SecurityError: If HTTPS is not used and HTTP is not explicitly allowed
            ConnectionError: If connection fails or times out
            ValueError: If the response is not HTTP 200, the key is invalid, or
                any other error occurs while fetching
        """
        logger.info("Fetching server's public key...")

        # Security check: Ensure HTTPS is used unless HTTP explicitly allowed
        verify_ssl = self.router_url.startswith("https://")
        if not verify_ssl:
            if not self.allow_http:
                raise SecurityError(
                    "Server public key must be fetched over HTTPS to prevent MITM attacks. "
                    "For local development, initialize with allow_http=True: "
                    "SecureChatCompletion(base_url='http://localhost:12434', allow_http=True)"
                )
            logger.warning("Fetching key over HTTP (local development mode)")

        url = f"{self.router_url}/pki/public_key"
        try:
            async with httpx.AsyncClient(
                timeout=60.0,
                verify=verify_ssl,  # Verify SSL/TLS certificates for HTTPS
            ) as client:
                response = await client.get(url)

                if response.status_code != 200:
                    raise ValueError(f"Failed to fetch server's public key: HTTP {response.status_code}")

                server_public_key = response.text

                # Validate it's a valid PEM key
                try:
                    parsed_key = serialization.load_pem_public_key(
                        server_public_key.encode('utf-8'),
                        backend=default_backend()
                    )
                except Exception as e:
                    raise ValueError("Server returned invalid public key format") from e

                # The AES key is wrapped with RSA-OAEP under this key, so refuse
                # non-RSA or undersized keys instead of failing (or weakening
                # the exchange) later during encryption.
                self._validate_rsa_key(parsed_key, "public")

                if verify_ssl:
                    logger.debug("Server's public key fetched securely over HTTPS")
                else:
                    logger.warning("Server's public key fetched over HTTP (INSECURE)")
                return server_public_key

        except httpx.ConnectError as e:
            raise ConnectionError(f"Failed to connect to server: {e}") from e
        except httpx.TimeoutException as e:
            raise ConnectionError("Connection to server timed out") from e
        except SecurityError:
            raise  # Re-raise security errors
        except ValueError:
            raise  # Re-raise validation errors
        except Exception as e:
            raise ValueError("Failed to fetch server's public key") from e

    async def _do_encrypt(self, payload_bytes: bytes, aes_key: bytes) -> bytes:
        """
        Core AES-256-GCM + RSA-OAEP encryption. Caller is responsible for
        memory protection of payload_bytes and aes_key before calling this.

        The payload is encrypted with AES-GCM under a fresh random nonce, and
        the AES key is wrapped with the server's public key, which is fetched
        on every call.

        Args:
            payload_bytes: Serialized plaintext payload
            aes_key: AES key used for the payload (32 bytes for AES-256)

        Returns:
            UTF-8 encoded JSON package containing the base64-encoded ciphertext,
            nonce, tag and wrapped AES key, plus algorithm identifiers.
        """
        nonce = secrets.token_bytes(12)  # 96-bit nonce for GCM
        cipher = Cipher(
            algorithms.AES(aes_key),
            modes.GCM(nonce),
            backend=default_backend()
        )
        encryptor = cipher.encryptor()
        ciphertext = encryptor.update(payload_bytes) + encryptor.finalize()
        tag = encryptor.tag

        server_public_key_pem = await self.fetch_server_public_key()
        server_public_key = serialization.load_pem_public_key(
            server_public_key_pem.encode('utf-8'),
            backend=default_backend()
        )
        encrypted_aes_key = server_public_key.encrypt(
            aes_key,
            padding.OAEP(
                mgf=padding.MGF1(algorithm=hashes.SHA256()),
                algorithm=hashes.SHA256(),
                label=None
            )
        )

        encrypted_package = {
            "version": "1.0",
            "algorithm": "hybrid-aes256-rsa4096",
            "encrypted_payload": {
                "ciphertext": base64.b64encode(ciphertext).decode('utf-8'),
                "nonce": base64.b64encode(nonce).decode('utf-8'),
                "tag": base64.b64encode(tag).decode('utf-8')
            },
            "encrypted_aes_key": base64.b64encode(encrypted_aes_key).decode('utf-8'),
            "key_algorithm": "RSA-OAEP-SHA256",
            "payload_algorithm": "AES-256-GCM"
        }
        package_json = json.dumps(encrypted_package).encode('utf-8')
        logger.debug("Encrypted package size: %d bytes", len(package_json))
        return package_json

    async def encrypt_payload(self, payload: Dict[str, Any]) -> bytes:
        """
        Encrypt a payload using hybrid encryption (AES-256-GCM + RSA-OAEP).

        When secure memory is enabled for this instance, the serialized payload
        and the AES key are wrapped in secure_bytearray context managers while
        the encryption runs.

        Args:
            payload: Dictionary containing the chat completion request

        Returns:
            Encrypted payload as bytes

        Raises:
            ValueError: If payload is not a non-empty dict or is too large
            SecurityError: If encryption fails. Any other exception raised while
                encrypting, including a ConnectionError from fetching the server
                key, is reported as this error.
        """
        logger.info("Encrypting payload...")

        # Validate payload
        if not isinstance(payload, dict):
            raise ValueError("Payload must be a dictionary")
        if not payload:
            raise ValueError("Payload cannot be empty")

        try:
            # Serialize payload to JSON
            payload_json = json.dumps(payload).encode('utf-8')

            # Validate payload size (prevent DoS)
            MAX_PAYLOAD_SIZE = 10 * 1024 * 1024  # 10MB limit
            if len(payload_json) > MAX_PAYLOAD_SIZE:
                raise ValueError(f"Payload too large: {len(payload_json)} bytes (max: {MAX_PAYLOAD_SIZE})")
            logger.debug("Payload size: %d bytes", len(payload_json))

            aes_key = secrets.token_bytes(32)  # 256-bit key
            try:
                if self._use_secure_memory:
                    with secure_bytearray(payload_json) as protected_payload:
                        with secure_bytearray(aes_key) as protected_aes_key:
                            return await self._do_encrypt(
                                bytes(protected_payload.data),
                                bytes(protected_aes_key.data)
                            )
                else:
                    logger.warning("Secure memory not available, using standard encryption")
                    return await self._do_encrypt(payload_json, aes_key)
            finally:
                # Drops the local reference only; immutable bytes cannot be zeroed.
                del aes_key

        except ValueError:
            raise  # Re-raise validation errors
        except SecurityError:
            raise  # Re-raise security errors
        except Exception:
            # Don't leak internal details
            raise SecurityError("Encryption operation failed")

    @staticmethod
    def _aes_gcm_decrypt(aes_key: bytes, ciphertext: bytes, nonce: bytes, tag: bytes) -> bytes:
        """Decrypt and authenticate an AES-GCM ciphertext, returning the plaintext bytes."""
        decryptor = Cipher(
            algorithms.AES(aes_key),
            modes.GCM(nonce, tag),
            backend=default_backend()
        ).decryptor()
        return decryptor.update(ciphertext) + decryptor.finalize()

    async def decrypt_response(self, encrypted_response: bytes, payload_id: str) -> Dict[str, Any]:
        """
        Decrypt a response from the secure endpoint.

        Args:
            encrypted_response: Encrypted response bytes (a UTF-8 JSON package)
            payload_id: Payload ID, recorded in the returned ``_metadata``

        Returns:
            Decrypted response dictionary, with ``_metadata`` extended by
            ``payload_id``, ``processed_at``, ``is_encrypted`` and
            ``encryption_algorithm``.

        Raises:
            ValueError: If the package or the decrypted response is malformed
                (empty, not bytes, not UTF-8/JSON, not a JSON object, or
                missing required fields)
            SecurityError: If no private key is loaded, or decryption or the
                integrity check fails
        """
        logger.info("Decrypting response...")

        # Validate input
        if not encrypted_response:
            raise ValueError("Empty encrypted response")
        if not isinstance(encrypted_response, bytes):
            raise ValueError("Encrypted response must be bytes")

        # Parse encrypted package
        try:
            package = json.loads(encrypted_response.decode('utf-8'))
        except json.JSONDecodeError as e:
            raise ValueError("Invalid encrypted package format: malformed JSON") from e
        except UnicodeDecodeError as e:
            raise ValueError("Invalid encrypted package format: not valid UTF-8") from e

        # A JSON list or string would pass the membership checks below and then
        # fail with TypeError on key access, so insist on an object up front.
        if not isinstance(package, dict):
            raise ValueError("Invalid encrypted package format: expected a JSON object")

        # Validate package structure
        required_fields = ["version", "algorithm", "encrypted_payload", "encrypted_aes_key"]
        missing_fields = [f for f in required_fields if f not in package]
        if missing_fields:
            raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}")

        # Validate encrypted_payload structure
        encrypted_payload = package["encrypted_payload"]
        if not isinstance(encrypted_payload, dict):
            raise ValueError("Invalid encrypted_payload: must be a dictionary")
        payload_required = ["ciphertext", "nonce", "tag"]
        missing_payload_fields = [f for f in payload_required if f not in encrypted_payload]
        if missing_payload_fields:
            raise ValueError(f"Missing fields in encrypted_payload: {', '.join(missing_payload_fields)}")

        if self.private_key is None:
            raise SecurityError(
                "Decryption failed: no private key loaded (call generate_keys() or load_keys() first)"
            )

        # Decrypt with proper error handling — keep crypto errors opaque (timing attacks)
        plaintext_json: Optional[str] = None
        plaintext_size: int = 0
        try:
            encrypted_aes_key = base64.b64decode(package["encrypted_aes_key"])
            ciphertext = base64.b64decode(encrypted_payload["ciphertext"])
            nonce = base64.b64decode(encrypted_payload["nonce"])
            tag = base64.b64decode(encrypted_payload["tag"])

            # Decrypt AES key with private key
            aes_key = self.private_key.decrypt(
                encrypted_aes_key,
                padding.OAEP(
                    mgf=padding.MGF1(algorithm=hashes.SHA256()),
                    algorithm=hashes.SHA256(),
                    label=None
                )
            )

            # Use secure memory to protect AES key and decrypted plaintext
            if self._use_secure_memory:
                with secure_bytearray(aes_key) as protected_aes_key:
                    plaintext_bytes = self._aes_gcm_decrypt(
                        bytes(protected_aes_key.data), ciphertext, nonce, tag
                    )
                    plaintext_size = len(plaintext_bytes)
                    with secure_bytearray(plaintext_bytes) as protected_plaintext:
                        # NOTE: plaintext_json is a Python str (immutable) and cannot be
                        # securely zeroed; it persists until garbage collected. This is
                        # a known limitation of Python's memory model.
                        plaintext_json = bytes(protected_plaintext.data).decode('utf-8')
                    del plaintext_bytes  # drop the immutable bytes reference
            else:
                logger.warning("Secure memory not available, using standard decryption")
                plaintext_bytes = self._aes_gcm_decrypt(aes_key, ciphertext, nonce, tag)
                plaintext_size = len(plaintext_bytes)
                plaintext_json = plaintext_bytes.decode('utf-8')
                del plaintext_bytes

        except Exception:
            # Don't leak specific decryption errors (timing attacks)
            raise SecurityError("Decryption failed: integrity check or authentication failed")

        # Parse JSON outside the crypto exception handler so format errors aren't hidden
        try:
            response = json.loads(plaintext_json)
        except (json.JSONDecodeError, UnicodeDecodeError) as e:
            raise ValueError(f"Decrypted response is not valid JSON: {e}") from e

        # The return contract is a dict, and metadata is attached by key below.
        if not isinstance(response, dict):
            raise ValueError("Decrypted response is not a JSON object")

        # Add metadata for debugging; replace a non-dict "_metadata" rather than crash on it
        metadata = response.get("_metadata")
        if not isinstance(metadata, dict):
            metadata = {}
            response["_metadata"] = metadata
        metadata.update({
            "payload_id": payload_id,
            "processed_at": package.get("processed_at"),
            "is_encrypted": True,
            "encryption_algorithm": package["algorithm"]
        })

        logger.debug("Response decrypted successfully")
        logger.debug("Response size: %d bytes", plaintext_size)
        return response

    @staticmethod
    def _build_api_error(response) -> Exception:
        """
        Translate a non-200 router response into the matching API exception.

        The error body is used only when it is a JSON object; a missing,
        malformed or non-object body falls back to a fixed message per status.
        """
        status = response.status_code
        try:
            body = response.json()
        except ValueError:  # json.JSONDecodeError is a ValueError subclass
            body = None
        if not isinstance(body, dict):
            body = None

        # status -> (exception class, message prefix, default detail, fallback message)
        known_errors = {
            400: (InvalidRequestError, "Bad request: ", "Unknown error",
                  "Bad request: Invalid response format"),
            401: (AuthenticationError, "", "Invalid API key or authentication failed",
                  "Invalid API key or authentication failed"),
            403: (ForbiddenError, "Forbidden: ", "Model not allowed for the requested security tier",
                  "Forbidden: Model not allowed for the requested security tier"),
            404: (APIError, "Endpoint not found: ", "Secure inference not enabled",
                  "Endpoint not found: Secure inference not enabled"),
            429: (RateLimitError, "Rate limit exceeded: ", "Too many requests",
                  "Rate limit exceeded: Too many requests"),
            500: (ServerError, "Server error: ", "Internal server error",
                  "Server error: Internal server error"),
            503: (ServiceUnavailableError, "Service unavailable: ", "Inference backend is unavailable",
                  "Service unavailable: Inference backend is unavailable"),
        }

        if status not in known_errors:
            detail_msg = body.get("detail", "unknown") if body is not None else "unknown"
            return APIError(
                f"Unexpected status code: {status} {detail_msg}",
                status_code=status
            )

        error_cls, prefix, default_detail, fallback_message = known_errors[status]
        if body is None:
            return error_cls(fallback_message, status_code=status)
        return error_cls(
            f"{prefix}{body.get('detail', default_detail)}",
            status_code=status,
            error_details=body
        )

    async def send_secure_request(self, payload: Dict[str, Any], payload_id: str, api_key: Optional[str] = None, security_tier: Optional[str] = None) -> Dict[str, Any]:
        """
        Send a secure chat completion request to the router.

        Args:
            payload: Chat completion request payload
            payload_id: Unique identifier for this request
            api_key: Optional API key for bearer authentication
            security_tier: Optional security tier for routing ("standard", "high", or "maximum").
                Controls hardware preference:
                - "standard": general secure inference
                - "high": sensitive business data
                - "maximum": maximum isolation (PHI, classified data)
                If not specified, server uses default based on model name mapping.

        Returns:
            Decrypted response from the LLM

        Raises:
            AuthenticationError: If API key is invalid or missing (HTTP 401)
            InvalidRequestError: If the request is invalid (HTTP 400)
            ForbiddenError: If access is forbidden (HTTP 403)
            RateLimitError: If the rate limit is exceeded (HTTP 429)
            ServerError: If the server reports an error (HTTP 500)
            ServiceUnavailableError: If the backend is unavailable (HTTP 503)
            APIError: For other HTTP errors
            APIConnectionError: If the connection to the router fails or times out
            SecurityError: If no client key pair is loaded, or encryption/decryption fails
            ValueError: If security_tier or payload is invalid
            Exception: For any other failure while sending or decoding the response
        """
        logger.info("Sending secure chat completion request...")

        # Validate security tier if provided
        if security_tier is not None:
            valid_tiers = ["standard", "high", "maximum"]
            if security_tier not in valid_tiers:
                raise ValueError(
                    f"Invalid security_tier: '{security_tier}'. "
                    f"Must be one of: {', '.join(valid_tiers)}"
                )

        # The server needs our public key to encrypt its answer; fail clearly
        # before doing any network work if no key pair is available.
        if not self.public_key_pem:
            raise SecurityError(
                "No client key pair available (call generate_keys() or load_keys() first)"
            )

        # Step 1: Encrypt the payload
        encrypted_payload = await self.encrypt_payload(payload)

        # Step 2: Prepare headers
        headers = {
            "X-Payload-ID": payload_id,
            "X-Public-Key": urllib.parse.quote(self.public_key_pem),
            "Content-Type": "application/octet-stream"
        }
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"
        if security_tier:
            headers["X-Security-Tier"] = security_tier

        # Step 3: Send request to router
        url = f"{self.router_url}/v1/chat/secure_completion"
        logger.debug("Target URL: %s", url)

        try:
            async with httpx.AsyncClient(timeout=60.0) as client:
                response = await client.post(
                    url,
                    headers=headers,
                    content=encrypted_payload
                )
                logger.debug("HTTP Status: %d", response.status_code)

                if response.status_code == 200:
                    # Step 4: Decrypt the response
                    return await self.decrypt_response(response.content, payload_id)

                raise self._build_api_error(response)

        except httpx.TimeoutException as e:
            raise APIConnectionError("Connection to router timed out") from e
        except httpx.NetworkError as e:
            raise APIConnectionError(f"Failed to connect to router: {e}") from e
        except (SecurityError, APIError, APIConnectionError):
            raise  # Re-raise known exceptions (APIError covers all its subclasses)
        except Exception as e:
            raise Exception(f"Request failed: {e}") from e

    def _validate_rsa_key(self, key, key_type: str = "private") -> None:
        """
        Validate that a key is a valid RSA key with appropriate size.

        Args:
            key: The key to validate
            key_type: "private" to require an RSA private key; any other value
                requires an RSA public key

        Raises:
            ValueError: If the key has the wrong type or is smaller than 2048 bits
        """
        if key_type == "private":
            if not isinstance(key, rsa.RSAPrivateKey):
                raise ValueError("Invalid private key: not an RSA private key")
        elif not isinstance(key, rsa.RSAPublicKey):
            raise ValueError("Invalid public key: not an RSA public key")

        key_size = key.key_size
        MIN_KEY_SIZE = 2048
        if key_size < MIN_KEY_SIZE:
            raise ValueError(
                f"Key size {key_size} is too small. "
                f"Minimum recommended size is {MIN_KEY_SIZE} bits."
            )
        logger.debug("Valid %d-bit RSA %s key", key_size, key_type)