diff --git a/nomyo/SecureCompletionClient.py b/nomyo/SecureCompletionClient.py index 7bbe0fe..f4c14eb 100644 --- a/nomyo/SecureCompletionClient.py +++ b/nomyo/SecureCompletionClient.py @@ -333,10 +333,11 @@ class SecureCompletionClient: except Exception: raise ValueError("Failed to fetch server's public key") - async def _do_encrypt(self, payload_bytes: bytes, aes_key: bytes) -> bytes: + async def _do_encrypt(self, payload_bytes: Union[bytes, bytearray], aes_key: Union[bytes, bytearray]) -> bytes: """ Core AES-256-GCM + RSA-OAEP encryption. Caller is responsible for memory protection of payload_bytes and aes_key before calling this. + Accepts bytearray to avoid creating an unzeroed immutable bytes copy. """ nonce = secrets.token_bytes(12) # 96-bit nonce for GCM cipher = Cipher( @@ -405,8 +406,8 @@ class SecureCompletionClient: raise ValueError("Payload cannot be empty") try: - # Serialize payload to JSON - payload_json = json.dumps(payload).encode('utf-8') + # Serialize payload to JSON as bytearray so SecureBuffer can zero the original + payload_json = bytearray(json.dumps(payload).encode('utf-8')) # Validate payload size (prevent DoS) MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB limit @@ -415,14 +416,14 @@ class SecureCompletionClient: logger.debug("Payload size: %d bytes", len(payload_json)) - aes_key = secrets.token_bytes(32) # 256-bit key + aes_key = bytearray(secrets.token_bytes(32)) # 256-bit key as bytearray try: if self._use_secure_memory: with secure_bytearray(payload_json) as protected_payload: with secure_bytearray(aes_key) as protected_aes_key: return await self._do_encrypt( - bytes(protected_payload.data), - bytes(protected_aes_key.data) + protected_payload.data, + protected_aes_key.data ) else: logger.warning("Secure memory not available, using standard encryption") @@ -476,6 +477,20 @@ class SecureCompletionClient: if missing_fields: raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}") + # Validate version and algorithm to prevent downgrade attacks + SUPPORTED_VERSION = "1.0" + SUPPORTED_ALGORITHM = "hybrid-aes256-rsa4096" + if package["version"] != SUPPORTED_VERSION: + raise ValueError( + f"Unsupported protocol version: '{package['version']}'. " + f"Expected: '{SUPPORTED_VERSION}'" + ) + if package["algorithm"] != SUPPORTED_ALGORITHM: + raise ValueError( + f"Unsupported encryption algorithm: '{package['algorithm']}'. " + f"Expected: '{SUPPORTED_ALGORITHM}'" + ) + # Validate encrypted_payload structure if not isinstance(package["encrypted_payload"], dict): raise ValueError("Invalid encrypted_payload: must be a dictionary") @@ -485,9 +500,13 @@ class SecureCompletionClient: if missing_payload_fields: raise ValueError(f"Missing fields in encrypted_payload: {', '.join(missing_payload_fields)}") + # Guard: private key must be initialized before attempting decryption + if self.private_key is None: + raise SecurityError("Private key not initialized. Call generate_keys() or load_keys() first.") + # Decrypt with proper error handling — keep crypto errors opaque (timing attacks) - plaintext_json: Optional[str] = None plaintext_size: int = 0 + response: Optional[Dict[str, Any]] = None try: # Decrypt AES key with private key encrypted_aes_key = base64.b64decode(package["encrypted_aes_key"]) @@ -508,7 +527,7 @@ class SecureCompletionClient: tag = base64.b64decode(package["encrypted_payload"]["tag"]) cipher = Cipher( - algorithms.AES(bytes(protected_aes_key.data)), + algorithms.AES(protected_aes_key.data), modes.GCM(nonce, tag), backend=default_backend() ) @@ -517,12 +536,14 @@ class SecureCompletionClient: plaintext_size = len(plaintext_bytes) with secure_bytearray(plaintext_bytes) as protected_plaintext: - # NOTE: plaintext_json is a Python str (immutable) and cannot be - # securely zeroed. The bytearray source is zeroed by the context - # manager, but the str object will persist until GC. This is a - # known limitation of Python's memory model. - plaintext_json = bytes(protected_plaintext.data).decode('utf-8') - del plaintext_bytes # drop immutable bytes ref; secure copy already zeroed + # Parse directly from bytearray — json.loads accepts bytearray + # (Python 3.6+), avoiding an immutable bytes/str copy that cannot + # be zeroed. The bytearray is zeroed by the context manager on exit. + try: + response = json.loads(protected_plaintext.data) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Decrypted response is not valid JSON: {e}") + del plaintext_bytes # AES key automatically zeroed here else: logger.warning("Secure memory not available, using standard decryption") @@ -538,19 +559,18 @@ class SecureCompletionClient: decryptor = cipher.decryptor() plaintext_bytes = decryptor.update(ciphertext) + decryptor.finalize() plaintext_size = len(plaintext_bytes) - plaintext_json = plaintext_bytes.decode('utf-8') + try: + response = json.loads(plaintext_bytes) + except (json.JSONDecodeError, UnicodeDecodeError) as e: + raise ValueError(f"Decrypted response is not valid JSON: {e}") del plaintext_bytes + except ValueError: + raise # Re-raise JSON parse errors without masking as SecurityError except Exception: # Don't leak specific decryption errors (timing attacks) raise SecurityError("Decryption failed: integrity check or authentication failed") - # Parse JSON outside the crypto exception handler so format errors aren't hidden - try: - response = json.loads(plaintext_json) - except (json.JSONDecodeError, UnicodeDecodeError) as e: - raise ValueError(f"Decrypted response is not valid JSON: {e}") - # Add metadata for debugging if "_metadata" not in response: response["_metadata"] = {} @@ -744,8 +764,9 @@ class SecureCompletionClient: raise APIConnectionError(f"Failed to connect to router: {e}") except (SecurityError, APIError, AuthenticationError, InvalidRequestError, ForbiddenError, RateLimitError, ServerError, ServiceUnavailableError, APIConnectionError): raise # Re-raise known exceptions - except Exception as e: - raise Exception(f"Request failed: {e}") + except Exception: + logger.exception("Unexpected error in send_secure_request") + raise APIConnectionError("Request failed due to an unexpected error") def _validate_rsa_key(self, key, key_type: str = "private") -> None: """