fix: switch to mutable types to allow zeroing where possible
This commit is contained in:
parent
4102f06bae
commit
5e8e4443a9
1 changed files with 44 additions and 23 deletions
|
|
@ -333,10 +333,11 @@ class SecureCompletionClient:
|
||||||
except Exception:
|
except Exception:
|
||||||
raise ValueError("Failed to fetch server's public key")
|
raise ValueError("Failed to fetch server's public key")
|
||||||
|
|
||||||
async def _do_encrypt(self, payload_bytes: bytes, aes_key: bytes) -> bytes:
|
async def _do_encrypt(self, payload_bytes: Union[bytes, bytearray], aes_key: Union[bytes, bytearray]) -> bytes:
|
||||||
"""
|
"""
|
||||||
Core AES-256-GCM + RSA-OAEP encryption. Caller is responsible for
|
Core AES-256-GCM + RSA-OAEP encryption. Caller is responsible for
|
||||||
memory protection of payload_bytes and aes_key before calling this.
|
memory protection of payload_bytes and aes_key before calling this.
|
||||||
|
Accepts bytearray to avoid creating an unzeroed immutable bytes copy.
|
||||||
"""
|
"""
|
||||||
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
|
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
|
||||||
cipher = Cipher(
|
cipher = Cipher(
|
||||||
|
|
@ -405,8 +406,8 @@ class SecureCompletionClient:
|
||||||
raise ValueError("Payload cannot be empty")
|
raise ValueError("Payload cannot be empty")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Serialize payload to JSON
|
# Serialize payload to JSON as bytearray so SecureBuffer can zero the original
|
||||||
payload_json = json.dumps(payload).encode('utf-8')
|
payload_json = bytearray(json.dumps(payload).encode('utf-8'))
|
||||||
|
|
||||||
# Validate payload size (prevent DoS)
|
# Validate payload size (prevent DoS)
|
||||||
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB limit
|
MAX_PAYLOAD_SIZE = 10 * 1024 * 1024 # 10MB limit
|
||||||
|
|
@ -415,14 +416,14 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
logger.debug("Payload size: %d bytes", len(payload_json))
|
logger.debug("Payload size: %d bytes", len(payload_json))
|
||||||
|
|
||||||
aes_key = secrets.token_bytes(32) # 256-bit key
|
aes_key = bytearray(secrets.token_bytes(32)) # 256-bit key as bytearray
|
||||||
try:
|
try:
|
||||||
if self._use_secure_memory:
|
if self._use_secure_memory:
|
||||||
with secure_bytearray(payload_json) as protected_payload:
|
with secure_bytearray(payload_json) as protected_payload:
|
||||||
with secure_bytearray(aes_key) as protected_aes_key:
|
with secure_bytearray(aes_key) as protected_aes_key:
|
||||||
return await self._do_encrypt(
|
return await self._do_encrypt(
|
||||||
bytes(protected_payload.data),
|
protected_payload.data,
|
||||||
bytes(protected_aes_key.data)
|
protected_aes_key.data
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
logger.warning("Secure memory not available, using standard encryption")
|
logger.warning("Secure memory not available, using standard encryption")
|
||||||
|
|
@ -476,6 +477,20 @@ class SecureCompletionClient:
|
||||||
if missing_fields:
|
if missing_fields:
|
||||||
raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}")
|
raise ValueError(f"Missing required fields in encrypted package: {', '.join(missing_fields)}")
|
||||||
|
|
||||||
|
# Validate version and algorithm to prevent downgrade attacks
|
||||||
|
SUPPORTED_VERSION = "1.0"
|
||||||
|
SUPPORTED_ALGORITHM = "hybrid-aes256-rsa4096"
|
||||||
|
if package["version"] != SUPPORTED_VERSION:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported protocol version: '{package['version']}'. "
|
||||||
|
f"Expected: '{SUPPORTED_VERSION}'"
|
||||||
|
)
|
||||||
|
if package["algorithm"] != SUPPORTED_ALGORITHM:
|
||||||
|
raise ValueError(
|
||||||
|
f"Unsupported encryption algorithm: '{package['algorithm']}'. "
|
||||||
|
f"Expected: '{SUPPORTED_ALGORITHM}'"
|
||||||
|
)
|
||||||
|
|
||||||
# Validate encrypted_payload structure
|
# Validate encrypted_payload structure
|
||||||
if not isinstance(package["encrypted_payload"], dict):
|
if not isinstance(package["encrypted_payload"], dict):
|
||||||
raise ValueError("Invalid encrypted_payload: must be a dictionary")
|
raise ValueError("Invalid encrypted_payload: must be a dictionary")
|
||||||
|
|
@ -485,9 +500,13 @@ class SecureCompletionClient:
|
||||||
if missing_payload_fields:
|
if missing_payload_fields:
|
||||||
raise ValueError(f"Missing fields in encrypted_payload: {', '.join(missing_payload_fields)}")
|
raise ValueError(f"Missing fields in encrypted_payload: {', '.join(missing_payload_fields)}")
|
||||||
|
|
||||||
|
# Guard: private key must be initialized before attempting decryption
|
||||||
|
if self.private_key is None:
|
||||||
|
raise SecurityError("Private key not initialized. Call generate_keys() or load_keys() first.")
|
||||||
|
|
||||||
# Decrypt with proper error handling — keep crypto errors opaque (timing attacks)
|
# Decrypt with proper error handling — keep crypto errors opaque (timing attacks)
|
||||||
plaintext_json: Optional[str] = None
|
|
||||||
plaintext_size: int = 0
|
plaintext_size: int = 0
|
||||||
|
response: Optional[Dict[str, Any]] = None
|
||||||
try:
|
try:
|
||||||
# Decrypt AES key with private key
|
# Decrypt AES key with private key
|
||||||
encrypted_aes_key = base64.b64decode(package["encrypted_aes_key"])
|
encrypted_aes_key = base64.b64decode(package["encrypted_aes_key"])
|
||||||
|
|
@ -508,7 +527,7 @@ class SecureCompletionClient:
|
||||||
tag = base64.b64decode(package["encrypted_payload"]["tag"])
|
tag = base64.b64decode(package["encrypted_payload"]["tag"])
|
||||||
|
|
||||||
cipher = Cipher(
|
cipher = Cipher(
|
||||||
algorithms.AES(bytes(protected_aes_key.data)),
|
algorithms.AES(protected_aes_key.data),
|
||||||
modes.GCM(nonce, tag),
|
modes.GCM(nonce, tag),
|
||||||
backend=default_backend()
|
backend=default_backend()
|
||||||
)
|
)
|
||||||
|
|
@ -517,12 +536,14 @@ class SecureCompletionClient:
|
||||||
plaintext_size = len(plaintext_bytes)
|
plaintext_size = len(plaintext_bytes)
|
||||||
|
|
||||||
with secure_bytearray(plaintext_bytes) as protected_plaintext:
|
with secure_bytearray(plaintext_bytes) as protected_plaintext:
|
||||||
# NOTE: plaintext_json is a Python str (immutable) and cannot be
|
# Parse directly from bytearray — json.loads accepts bytearray
|
||||||
# securely zeroed. The bytearray source is zeroed by the context
|
# (Python 3.6+), avoiding an immutable bytes/str copy that cannot
|
||||||
# manager, but the str object will persist until GC. This is a
|
# be zeroed. The bytearray is zeroed by the context manager on exit.
|
||||||
# known limitation of Python's memory model.
|
try:
|
||||||
plaintext_json = bytes(protected_plaintext.data).decode('utf-8')
|
response = json.loads(protected_plaintext.data)
|
||||||
del plaintext_bytes # drop immutable bytes ref; secure copy already zeroed
|
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||||
|
raise ValueError(f"Decrypted response is not valid JSON: {e}")
|
||||||
|
del plaintext_bytes
|
||||||
# AES key automatically zeroed here
|
# AES key automatically zeroed here
|
||||||
else:
|
else:
|
||||||
logger.warning("Secure memory not available, using standard decryption")
|
logger.warning("Secure memory not available, using standard decryption")
|
||||||
|
|
@ -538,19 +559,18 @@ class SecureCompletionClient:
|
||||||
decryptor = cipher.decryptor()
|
decryptor = cipher.decryptor()
|
||||||
plaintext_bytes = decryptor.update(ciphertext) + decryptor.finalize()
|
plaintext_bytes = decryptor.update(ciphertext) + decryptor.finalize()
|
||||||
plaintext_size = len(plaintext_bytes)
|
plaintext_size = len(plaintext_bytes)
|
||||||
plaintext_json = plaintext_bytes.decode('utf-8')
|
try:
|
||||||
|
response = json.loads(plaintext_bytes)
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||||
|
raise ValueError(f"Decrypted response is not valid JSON: {e}")
|
||||||
del plaintext_bytes
|
del plaintext_bytes
|
||||||
|
|
||||||
|
except ValueError:
|
||||||
|
raise # Re-raise JSON parse errors without masking as SecurityError
|
||||||
except Exception:
|
except Exception:
|
||||||
# Don't leak specific decryption errors (timing attacks)
|
# Don't leak specific decryption errors (timing attacks)
|
||||||
raise SecurityError("Decryption failed: integrity check or authentication failed")
|
raise SecurityError("Decryption failed: integrity check or authentication failed")
|
||||||
|
|
||||||
# Parse JSON outside the crypto exception handler so format errors aren't hidden
|
|
||||||
try:
|
|
||||||
response = json.loads(plaintext_json)
|
|
||||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
|
||||||
raise ValueError(f"Decrypted response is not valid JSON: {e}")
|
|
||||||
|
|
||||||
# Add metadata for debugging
|
# Add metadata for debugging
|
||||||
if "_metadata" not in response:
|
if "_metadata" not in response:
|
||||||
response["_metadata"] = {}
|
response["_metadata"] = {}
|
||||||
|
|
@ -744,8 +764,9 @@ class SecureCompletionClient:
|
||||||
raise APIConnectionError(f"Failed to connect to router: {e}")
|
raise APIConnectionError(f"Failed to connect to router: {e}")
|
||||||
except (SecurityError, APIError, AuthenticationError, InvalidRequestError, ForbiddenError, RateLimitError, ServerError, ServiceUnavailableError, APIConnectionError):
|
except (SecurityError, APIError, AuthenticationError, InvalidRequestError, ForbiddenError, RateLimitError, ServerError, ServiceUnavailableError, APIConnectionError):
|
||||||
raise # Re-raise known exceptions
|
raise # Re-raise known exceptions
|
||||||
except Exception as e:
|
except Exception:
|
||||||
raise Exception(f"Request failed: {e}")
|
logger.exception("Unexpected error in send_secure_request")
|
||||||
|
raise APIConnectionError("Request failed due to an unexpected error")
|
||||||
|
|
||||||
def _validate_rsa_key(self, key, key_type: str = "private") -> None:
|
def _validate_rsa_key(self, key, key_type: str = "private") -> None:
|
||||||
"""
|
"""
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue