fix: switch to mutable types to allow zeroing where possible

This commit is contained in:
Alpha Nerd 2026-04-12 17:47:44 +02:00
parent 4102f06bae
commit 5e8e4443a9
Signed by: alpha-nerd
SSH key fingerprint: SHA256:QkkAgVoYi9TQ0UKPkiKSfnerZy2h4qhi3SVPXJmBN+M

View file

@ -333,10 +333,11 @@ class SecureCompletionClient:
except Exception:
raise ValueError("Failed to fetch server's public key")
async def _do_encrypt(self, payload_bytes: bytes, aes_key: bytes) -> bytes:
async def _do_encrypt(self, payload_bytes: Union[bytes, bytearray], aes_key: Union[bytes, bytearray]) -> bytes:
"""
Core AES-256-GCM + RSA-OAEP encryption. Caller is responsible for
memory protection of payload_bytes and aes_key before calling this.
Accepts bytearray to avoid creating an unzeroed immutable bytes copy.
"""
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
cipher = Cipher(
@ -405,8 +406,8 @@ class SecureCompletionClient:
raise ValueError("Payload cannot be empty")
try:
# Serialize payload to JSON
payload_json = json.dumps(payload).encode('utf-8')
# Serialize payload to JSON as bytearray so SecureBuffer can zero the original
payload_json = bytearray(json.dumps(payload).encode('utf-8'))
# Validate payload size (prevent DoS)
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB limit
@ -415,14 +416,14 @@ class SecureCompletionClient:
logger.debug("Payload size: %d bytes", len(payload_json))
aes_key = secrets.token_bytes(32) # 256-bit key
aes_key = bytearray(secrets.token_bytes(32)) # 256-bit key as bytearray
try:
if self._use_secure_memory:
with secure_bytearray(payload_json) as protected_payload:
with secure_bytearray(aes_key) as protected_aes_key:
return await self._do_encrypt(
bytes(protected_payload.data),
bytes(protected_aes_key.data)
protected_payload.data,
protected_aes_key.data
)
else:
logger.warning("Secure memory not available, using standard encryption")
@ -476,6 +477,20 @@ class SecureCompletionClient:
if missing_fields:
raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}")
# Validate version and algorithm to prevent downgrade attacks
SUPPORTED_VERSION = "1.0"
SUPPORTED_ALGORITHM = "hybrid-aes256-rsa4096"
if package["version"] != SUPPORTED_VERSION:
raise ValueError(
f"Unsupported protocol version: '{package['version']}'. "
f"Expected: '{SUPPORTED_VERSION}'"
)
if package["algorithm"] != SUPPORTED_ALGORITHM:
raise ValueError(
f"Unsupported encryption algorithm: '{package['algorithm']}'. "
f"Expected: '{SUPPORTED_ALGORITHM}'"
)
# Validate encrypted_payload structure
if not isinstance(package["encrypted_payload"], dict):
raise ValueError("Invalid encrypted_payload: must be a dictionary")
@ -485,9 +500,13 @@ class SecureCompletionClient:
if missing_payload_fields:
raise ValueError(f"Missing fields in encrypted_payload: {', '.join(missing_payload_fields)}")
# Guard: private key must be initialized before attempting decryption
if self.private_key is None:
raise SecurityError("Private key not initialized. Call generate_keys() or load_keys() first.")
# Decrypt with proper error handling — keep crypto errors opaque (timing attacks)
plaintext_json: Optional[str] = None
plaintext_size: int = 0
response: Optional[Dict[str, Any]] = None
try:
# Decrypt AES key with private key
encrypted_aes_key = base64.b64decode(package["encrypted_aes_key"])
@ -508,7 +527,7 @@ class SecureCompletionClient:
tag = base64.b64decode(package["encrypted_payload"]["tag"])
cipher = Cipher(
algorithms.AES(bytes(protected_aes_key.data)),
algorithms.AES(protected_aes_key.data),
modes.GCM(nonce, tag),
backend=default_backend()
)
@ -517,12 +536,14 @@ class SecureCompletionClient:
plaintext_size = len(plaintext_bytes)
with secure_bytearray(plaintext_bytes) as protected_plaintext:
# NOTE: plaintext_json is a Python str (immutable) and cannot be
# securely zeroed. The bytearray source is zeroed by the context
# manager, but the str object will persist until GC. This is a
# known limitation of Python's memory model.
plaintext_json = bytes(protected_plaintext.data).decode('utf-8')
del plaintext_bytes # drop immutable bytes ref; secure copy already zeroed
# Parse directly from bytearray — json.loads accepts bytearray
# (Python 3.6+), avoiding an immutable bytes/str copy that cannot
# be zeroed. The bytearray is zeroed by the context manager on exit.
try:
response = json.loads(protected_plaintext.data)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
raise ValueError(f"Decrypted response is not valid JSON: {e}")
del plaintext_bytes
# AES key automatically zeroed here
else:
logger.warning("Secure memory not available, using standard decryption")
@ -538,19 +559,18 @@ class SecureCompletionClient:
decryptor = cipher.decryptor()
plaintext_bytes = decryptor.update(ciphertext) + decryptor.finalize()
plaintext_size = len(plaintext_bytes)
plaintext_json = plaintext_bytes.decode('utf-8')
try:
response = json.loads(plaintext_bytes)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
raise ValueError(f"Decrypted response is not valid JSON: {e}")
del plaintext_bytes
except ValueError:
raise # Re-raise JSON parse errors without masking as SecurityError
except Exception:
# Don't leak specific decryption errors (timing attacks)
raise SecurityError("Decryption failed: integrity check or authentication failed")
# Parse JSON outside the crypto exception handler so format errors aren't hidden
try:
response = json.loads(plaintext_json)
except (json.JSONDecodeError, UnicodeDecodeError) as e:
raise ValueError(f"Decrypted response is not valid JSON: {e}")
# Add metadata for debugging
if "_metadata" not in response:
response["_metadata"] = {}
@ -744,8 +764,9 @@ class SecureCompletionClient:
raise APIConnectionError(f"Failed to connect to router: {e}")
except (SecurityError, APIError, AuthenticationError, InvalidRequestError, ForbiddenError, RateLimitError, ServerError, ServiceUnavailableError, APIConnectionError):
raise # Re-raise known exceptions
except Exception as e:
raise Exception(f"Request failed: {e}")
except Exception:
logger.exception("Unexpected error in send_secure_request")
raise APIConnectionError("Request failed due to an unexpected error")
def _validate_rsa_key(self, key, key_type: str = "private") -> None:
"""