fix: after code review

This commit is contained in:
Alpha Nerd 2026-04-01 17:32:52 +02:00
parent 0d88de3bef
commit ec3f3a64cc
Signed by: alpha-nerd
SSH key fingerprint: SHA256:QkkAgVoYi9TQ0UKPkiKSfnerZy2h4qhi3SVPXJmBN+M
4 changed files with 122 additions and 162 deletions

View file

@ -93,7 +93,7 @@ class SecureCompletionClient:
# Validate HTTPS for security
if not self.router_url.startswith("https://"):
if not allow_http:
if allow_http:
warnings.warn(
"⚠️ WARNING: Using HTTP instead of HTTPS. "
"This is INSECURE and should only be used for local development. "
@ -102,7 +102,10 @@ class SecureCompletionClient:
stacklevel=2
)
else:
logger.warning("HTTP mode enabled for local development (INSECURE)")
logger.warning(
"Non-HTTPS URL detected with allow_http=False. "
"Requests will be rejected at runtime by fetch_server_public_key()."
)
def _protect_private_key(self) -> None:
"""
@ -326,8 +329,53 @@ class SecureCompletionClient:
raise # Re-raise security errors
except ValueError:
raise # Re-raise validation errors
except Exception as e:
raise ValueError(f"Failed to fetch server's public key: {e}")
except Exception:
raise ValueError("Failed to fetch server's public key")
async def _do_encrypt(self, payload_bytes: bytes, aes_key: bytes) -> bytes:
"""
Core AES-256-GCM + RSA-OAEP encryption. Caller is responsible for
memory protection of payload_bytes and aes_key before calling this.
"""
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
cipher = Cipher(
algorithms.AES(aes_key),
modes.GCM(nonce),
backend=default_backend()
)
encryptor = cipher.encryptor()
ciphertext = encryptor.update(payload_bytes) + encryptor.finalize()
tag = encryptor.tag
server_public_key_pem = await self.fetch_server_public_key()
server_public_key = serialization.load_pem_public_key(
server_public_key_pem.encode('utf-8'),
backend=default_backend()
)
encrypted_aes_key = server_public_key.encrypt(
aes_key,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)
encrypted_package = {
"version": "1.0",
"algorithm": "hybrid-aes256-rsa4096",
"encrypted_payload": {
"ciphertext": base64.b64encode(ciphertext).decode('utf-8'),
"nonce": base64.b64encode(nonce).decode('utf-8'),
"tag": base64.b64encode(tag).decode('utf-8')
},
"encrypted_aes_key": base64.b64encode(encrypted_aes_key).decode('utf-8'),
"key_algorithm": "RSA-OAEP-SHA256",
"payload_algorithm": "AES-256-GCM"
}
package_json = json.dumps(encrypted_package).encode('utf-8')
logger.debug("Encrypted package size: %d bytes", len(package_json))
return package_json
async def encrypt_payload(self, payload: Dict[str, Any]) -> bytes:
"""
@ -366,129 +414,26 @@ class SecureCompletionClient:
logger.debug("Payload size: %d bytes", len(payload_json))
# Use secure memory context to protect plaintext payload
if _SECURE_MEMORY_AVAILABLE:
with secure_bytearray(payload_json) as protected_payload:
# Generate cryptographically secure random AES key
aes_key = secrets.token_bytes(32) # 256-bit key
try:
# Protect AES key in memory
aes_key = secrets.token_bytes(32) # 256-bit key
try:
if _SECURE_MEMORY_AVAILABLE:
with secure_bytearray(payload_json) as protected_payload:
with secure_bytearray(aes_key) as protected_aes_key:
# Encrypt payload with AES-GCM using Cipher API
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
cipher = Cipher(
algorithms.AES(bytes(protected_aes_key.data)),
modes.GCM(nonce),
backend=default_backend()
return await self._do_encrypt(
bytes(protected_payload.data),
bytes(protected_aes_key.data)
)
encryptor = cipher.encryptor()
ciphertext = encryptor.update(bytes(protected_payload.data)) + encryptor.finalize()
tag = encryptor.tag
# Fetch server's public key for encrypting the AES key
server_public_key_pem = await self.fetch_server_public_key()
# Encrypt AES key with server's RSA-OAEP
server_public_key = serialization.load_pem_public_key(
server_public_key_pem.encode('utf-8'),
backend=default_backend()
)
encrypted_aes_key = server_public_key.encrypt(
bytes(protected_aes_key.data),
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)
# Create encrypted package
encrypted_package = {
"version": "1.0",
"algorithm": "hybrid-aes256-rsa4096",
"encrypted_payload": {
"ciphertext": base64.b64encode(ciphertext).decode('utf-8'),
"nonce": base64.b64encode(nonce).decode('utf-8'),
"tag": base64.b64encode(tag).decode('utf-8')
},
"encrypted_aes_key": base64.b64encode(encrypted_aes_key).decode('utf-8'),
"key_algorithm": "RSA-OAEP-SHA256",
"payload_algorithm": "AES-256-GCM"
}
# Serialize package to JSON and return as bytes
package_json = json.dumps(encrypted_package).encode('utf-8')
logger.debug("Encrypted package size: %d bytes", len(package_json))
return package_json
finally:
# Explicitly clear the AES key reference
del aes_key
else:
# Fallback to standard encryption if secure memory not available
logger.warning("Secure memory not available, using standard encryption")
# Generate cryptographically secure random AES key
aes_key = secrets.token_bytes(32) # 256-bit key
try:
# Encrypt payload with AES-GCM using Cipher API (matching server implementation)
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
cipher = Cipher(
algorithms.AES(aes_key),
modes.GCM(nonce),
backend=default_backend()
)
encryptor = cipher.encryptor()
ciphertext = encryptor.update(payload_json) + encryptor.finalize()
tag = encryptor.tag
# Fetch server's public key for encrypting the AES key
server_public_key_pem = await self.fetch_server_public_key()
# Encrypt AES key with server's RSA-OAEP
server_public_key = serialization.load_pem_public_key(
server_public_key_pem.encode('utf-8'),
backend=default_backend()
)
encrypted_aes_key = server_public_key.encrypt(
aes_key,
padding.OAEP(
mgf=padding.MGF1(algorithm=hashes.SHA256()),
algorithm=hashes.SHA256(),
label=None
)
)
# Create encrypted package
encrypted_package = {
"version": "1.0",
"algorithm": "hybrid-aes256-rsa4096",
"encrypted_payload": {
"ciphertext": base64.b64encode(ciphertext).decode('utf-8'),
"nonce": base64.b64encode(nonce).decode('utf-8'),
"tag": base64.b64encode(tag).decode('utf-8')
},
"encrypted_aes_key": base64.b64encode(encrypted_aes_key).decode('utf-8'),
"key_algorithm": "RSA-OAEP-SHA256",
"payload_algorithm": "AES-256-GCM"
}
# Serialize package to JSON and return as bytes
package_json = json.dumps(encrypted_package).encode('utf-8')
logger.debug("Encrypted package size: %d bytes", len(package_json))
return package_json
finally:
# Explicitly clear the AES key reference
del aes_key
else:
logger.warning("Secure memory not available, using standard encryption")
return await self._do_encrypt(payload_json, aes_key)
finally:
del aes_key
except ValueError:
raise # Re-raise validation errors
except SecurityError:
raise # Re-raise security errors
except Exception as e:
except Exception:
# Don't leak internal details
raise SecurityError("Encryption operation failed")
@ -571,6 +516,10 @@ class SecureCompletionClient:
plaintext_size = len(plaintext_bytes)
with secure_bytearray(plaintext_bytes) as protected_plaintext:
# NOTE: plaintext_json is a Python str (immutable) and cannot be
# securely zeroed. The bytearray source is zeroed by the context
# manager, but the str object will persist until GC. This is a
# known limitation of Python's memory model.
plaintext_json = bytes(protected_plaintext.data).decode('utf-8')
del plaintext_bytes # drop immutable bytes ref; secure copy already zeroed
# AES key automatically zeroed here

View file

@ -24,6 +24,7 @@ Python's immutable bytes objects cannot be securely zeroed in place.
import ctypes
import logging
import sys
import threading
from contextlib import contextmanager
from enum import Enum
from typing import Optional, Union
@ -230,7 +231,7 @@ class SecureMemory:
def _init_windows(self):
"""Initialize Windows-specific functions (VirtualLock + RtlZeroMemory)"""
try:
kernel32 = ctypes.windll.kernel32
kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
# Get page size
class SYSTEM_INFO(ctypes.Structure):
@ -429,8 +430,6 @@ class SecureMemory:
logger.debug(f"Memory lock failed: {e}")
return False
return False
def _unlock_memory_at(self, addr: int, size: int) -> bool:
"""
Unlock memory at a specific address.
@ -472,8 +471,6 @@ class SecureMemory:
logger.debug(f"Memory unlock failed: {e}")
return False
return False
def _zero_memory_at(self, addr: int, size: int) -> None:
"""
Securely zero memory at a specific address.
@ -655,13 +652,16 @@ class SecureMemory:
# Global secure memory instance
_secure_memory: Optional[SecureMemory] = None
_secure_memory_lock = threading.Lock()
def _get_secure_memory() -> SecureMemory:
"""Get or create the global SecureMemory instance."""
global _secure_memory
if _secure_memory is None:
_secure_memory = SecureMemory()
with _secure_memory_lock:
if _secure_memory is None:
_secure_memory = SecureMemory()
return _secure_memory
@ -763,7 +763,8 @@ def disable_secure_memory() -> None:
This is useful for testing or when security is not required.
"""
global _secure_memory
_secure_memory = SecureMemory(enable=False)
with _secure_memory_lock:
_secure_memory = SecureMemory(enable=False)
logger.info("Secure memory operations disabled globally")
@ -774,5 +775,6 @@ def enable_secure_memory() -> None:
This reinitializes the secure memory handler with security enabled.
"""
global _secure_memory
_secure_memory = SecureMemory(enable=True)
with _secure_memory_lock:
_secure_memory = SecureMemory(enable=True)
logger.info("Secure memory operations re-enabled globally")

View file

@ -6,6 +6,7 @@ OpenAI-compatible secure chat client with end-to-end encryption.
from .nomyo import SecureChatCompletion
from .SecureCompletionClient import (
SecurityError,
APIError,
AuthenticationError,
InvalidRequestError,
@ -16,7 +17,20 @@ from .SecureCompletionClient import (
ServiceUnavailableError
)
# Import secure memory module if available
__all__ = [
'SecureChatCompletion',
'SecurityError',
'APIError',
'AuthenticationError',
'InvalidRequestError',
'APIConnectionError',
'ForbiddenError',
'RateLimitError',
'ServerError',
'ServiceUnavailableError',
]
# Import secure memory module if available; extend __all__ only for what was imported
try:
from .SecureMemory import (
get_memory_protection_info,
@ -26,27 +40,17 @@ try:
secure_bytes, # Deprecated, use secure_bytearray instead
SecureBuffer
)
__all__ += [
'get_memory_protection_info',
'disable_secure_memory',
'enable_secure_memory',
'secure_bytearray',
'secure_bytes',
'SecureBuffer',
]
except ImportError:
pass
__all__ = [
'SecureChatCompletion',
'APIError',
'AuthenticationError',
'InvalidRequestError',
'APIConnectionError',
'ForbiddenError',
'RateLimitError',
'ServerError',
'ServiceUnavailableError',
'get_memory_protection_info',
'disable_secure_memory',
'enable_secure_memory',
'secure_bytearray',
'secure_bytes', # Deprecated, use secure_bytearray instead
'SecureBuffer'
]
__version__ = "0.1.1"
__author__ = "NOMYO AI"
__license__ = "Apache-2.0"

View file

@ -1,3 +1,4 @@
import asyncio
import os
import uuid
from typing import Dict, Any, List, Optional
@ -70,6 +71,7 @@ class SecureChatCompletion:
"""
self.client = SecureCompletionClient(router_url=base_url, allow_http=allow_http)
self._keys_initialized = False
self._keys_lock = asyncio.Lock()
self.api_key = api_key
self._key_dir = key_dir
self._secure_memory_enabled = secure_memory
@ -89,22 +91,25 @@ class SecureChatCompletion:
stacklevel=2
)
def _ensure_keys(self):
"""Ensure keys are loaded or generated."""
async def _ensure_keys(self):
"""Ensure keys are loaded or generated (concurrency-safe)."""
if self._keys_initialized:
return
if self._key_dir is not None:
private_key_path = os.path.join(self._key_dir, "private_key.pem")
public_key_path = os.path.join(self._key_dir, "public_key.pem")
try:
self.client.load_keys(private_key_path, public_key_path)
self._keys_initialized = True
async with self._keys_lock:
if self._keys_initialized: # double-check after acquiring lock
return
except Exception:
self.client.generate_keys(save_to_file=True, key_dir=self._key_dir)
else:
self.client.generate_keys()
self._keys_initialized = True
if self._key_dir is not None:
private_key_path = os.path.join(self._key_dir, "private_key.pem")
public_key_path = os.path.join(self._key_dir, "public_key.pem")
try:
self.client.load_keys(private_key_path, public_key_path)
self._keys_initialized = True
return
except Exception:
self.client.generate_keys(save_to_file=True, key_dir=self._key_dir)
else:
self.client.generate_keys()
self._keys_initialized = True
async def create(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
"""
@ -215,8 +220,8 @@ class SecureChatCompletion:
else:
instance = self
# Ensure keys are available (synchronous)
instance._ensure_keys()
# Ensure keys are available
await instance._ensure_keys()
# Build payload — api_key is intentionally excluded (sent as Bearer header)
payload = {