fix: after code review
This commit is contained in:
parent
0d88de3bef
commit
ec3f3a64cc
4 changed files with 122 additions and 162 deletions
|
|
@ -93,7 +93,7 @@ class SecureCompletionClient:
|
|||
|
||||
# Validate HTTPS for security
|
||||
if not self.router_url.startswith("https://"):
|
||||
if not allow_http:
|
||||
if allow_http:
|
||||
warnings.warn(
|
||||
"⚠️ WARNING: Using HTTP instead of HTTPS. "
|
||||
"This is INSECURE and should only be used for local development. "
|
||||
|
|
@ -102,7 +102,10 @@ class SecureCompletionClient:
|
|||
stacklevel=2
|
||||
)
|
||||
else:
|
||||
logger.warning("HTTP mode enabled for local development (INSECURE)")
|
||||
logger.warning(
|
||||
"Non-HTTPS URL detected with allow_http=False. "
|
||||
"Requests will be rejected at runtime by fetch_server_public_key()."
|
||||
)
|
||||
|
||||
def _protect_private_key(self) -> None:
|
||||
"""
|
||||
|
|
@ -326,8 +329,53 @@ class SecureCompletionClient:
|
|||
raise # Re-raise security errors
|
||||
except ValueError:
|
||||
raise # Re-raise validation errors
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to fetch server's public key: {e}")
|
||||
except Exception:
|
||||
raise ValueError("Failed to fetch server's public key")
|
||||
|
||||
async def _do_encrypt(self, payload_bytes: bytes, aes_key: bytes) -> bytes:
|
||||
"""
|
||||
Core AES-256-GCM + RSA-OAEP encryption. Caller is responsible for
|
||||
memory protection of payload_bytes and aes_key before calling this.
|
||||
"""
|
||||
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
|
||||
cipher = Cipher(
|
||||
algorithms.AES(aes_key),
|
||||
modes.GCM(nonce),
|
||||
backend=default_backend()
|
||||
)
|
||||
encryptor = cipher.encryptor()
|
||||
ciphertext = encryptor.update(payload_bytes) + encryptor.finalize()
|
||||
tag = encryptor.tag
|
||||
|
||||
server_public_key_pem = await self.fetch_server_public_key()
|
||||
server_public_key = serialization.load_pem_public_key(
|
||||
server_public_key_pem.encode('utf-8'),
|
||||
backend=default_backend()
|
||||
)
|
||||
encrypted_aes_key = server_public_key.encrypt(
|
||||
aes_key,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
|
||||
encrypted_package = {
|
||||
"version": "1.0",
|
||||
"algorithm": "hybrid-aes256-rsa4096",
|
||||
"encrypted_payload": {
|
||||
"ciphertext": base64.b64encode(ciphertext).decode('utf-8'),
|
||||
"nonce": base64.b64encode(nonce).decode('utf-8'),
|
||||
"tag": base64.b64encode(tag).decode('utf-8')
|
||||
},
|
||||
"encrypted_aes_key": base64.b64encode(encrypted_aes_key).decode('utf-8'),
|
||||
"key_algorithm": "RSA-OAEP-SHA256",
|
||||
"payload_algorithm": "AES-256-GCM"
|
||||
}
|
||||
package_json = json.dumps(encrypted_package).encode('utf-8')
|
||||
logger.debug("Encrypted package size: %d bytes", len(package_json))
|
||||
return package_json
|
||||
|
||||
async def encrypt_payload(self, payload: Dict[str, Any]) -> bytes:
|
||||
"""
|
||||
|
|
@ -366,129 +414,26 @@ class SecureCompletionClient:
|
|||
|
||||
logger.debug("Payload size: %d bytes", len(payload_json))
|
||||
|
||||
# Use secure memory context to protect plaintext payload
|
||||
if _SECURE_MEMORY_AVAILABLE:
|
||||
with secure_bytearray(payload_json) as protected_payload:
|
||||
# Generate cryptographically secure random AES key
|
||||
aes_key = secrets.token_bytes(32) # 256-bit key
|
||||
|
||||
try:
|
||||
# Protect AES key in memory
|
||||
aes_key = secrets.token_bytes(32) # 256-bit key
|
||||
try:
|
||||
if _SECURE_MEMORY_AVAILABLE:
|
||||
with secure_bytearray(payload_json) as protected_payload:
|
||||
with secure_bytearray(aes_key) as protected_aes_key:
|
||||
# Encrypt payload with AES-GCM using Cipher API
|
||||
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
|
||||
cipher = Cipher(
|
||||
algorithms.AES(bytes(protected_aes_key.data)),
|
||||
modes.GCM(nonce),
|
||||
backend=default_backend()
|
||||
return await self._do_encrypt(
|
||||
bytes(protected_payload.data),
|
||||
bytes(protected_aes_key.data)
|
||||
)
|
||||
encryptor = cipher.encryptor()
|
||||
ciphertext = encryptor.update(bytes(protected_payload.data)) + encryptor.finalize()
|
||||
tag = encryptor.tag
|
||||
|
||||
# Fetch server's public key for encrypting the AES key
|
||||
server_public_key_pem = await self.fetch_server_public_key()
|
||||
|
||||
# Encrypt AES key with server's RSA-OAEP
|
||||
server_public_key = serialization.load_pem_public_key(
|
||||
server_public_key_pem.encode('utf-8'),
|
||||
backend=default_backend()
|
||||
)
|
||||
encrypted_aes_key = server_public_key.encrypt(
|
||||
bytes(protected_aes_key.data),
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
|
||||
# Create encrypted package
|
||||
encrypted_package = {
|
||||
"version": "1.0",
|
||||
"algorithm": "hybrid-aes256-rsa4096",
|
||||
"encrypted_payload": {
|
||||
"ciphertext": base64.b64encode(ciphertext).decode('utf-8'),
|
||||
"nonce": base64.b64encode(nonce).decode('utf-8'),
|
||||
"tag": base64.b64encode(tag).decode('utf-8')
|
||||
},
|
||||
"encrypted_aes_key": base64.b64encode(encrypted_aes_key).decode('utf-8'),
|
||||
"key_algorithm": "RSA-OAEP-SHA256",
|
||||
"payload_algorithm": "AES-256-GCM"
|
||||
}
|
||||
|
||||
# Serialize package to JSON and return as bytes
|
||||
package_json = json.dumps(encrypted_package).encode('utf-8')
|
||||
logger.debug("Encrypted package size: %d bytes", len(package_json))
|
||||
|
||||
return package_json
|
||||
finally:
|
||||
# Explicitly clear the AES key reference
|
||||
del aes_key
|
||||
else:
|
||||
# Fallback to standard encryption if secure memory not available
|
||||
logger.warning("Secure memory not available, using standard encryption")
|
||||
|
||||
# Generate cryptographically secure random AES key
|
||||
aes_key = secrets.token_bytes(32) # 256-bit key
|
||||
|
||||
try:
|
||||
# Encrypt payload with AES-GCM using Cipher API (matching server implementation)
|
||||
nonce = secrets.token_bytes(12) # 96-bit nonce for GCM
|
||||
cipher = Cipher(
|
||||
algorithms.AES(aes_key),
|
||||
modes.GCM(nonce),
|
||||
backend=default_backend()
|
||||
)
|
||||
encryptor = cipher.encryptor()
|
||||
ciphertext = encryptor.update(payload_json) + encryptor.finalize()
|
||||
tag = encryptor.tag
|
||||
|
||||
# Fetch server's public key for encrypting the AES key
|
||||
server_public_key_pem = await self.fetch_server_public_key()
|
||||
|
||||
# Encrypt AES key with server's RSA-OAEP
|
||||
server_public_key = serialization.load_pem_public_key(
|
||||
server_public_key_pem.encode('utf-8'),
|
||||
backend=default_backend()
|
||||
)
|
||||
encrypted_aes_key = server_public_key.encrypt(
|
||||
aes_key,
|
||||
padding.OAEP(
|
||||
mgf=padding.MGF1(algorithm=hashes.SHA256()),
|
||||
algorithm=hashes.SHA256(),
|
||||
label=None
|
||||
)
|
||||
)
|
||||
|
||||
# Create encrypted package
|
||||
encrypted_package = {
|
||||
"version": "1.0",
|
||||
"algorithm": "hybrid-aes256-rsa4096",
|
||||
"encrypted_payload": {
|
||||
"ciphertext": base64.b64encode(ciphertext).decode('utf-8'),
|
||||
"nonce": base64.b64encode(nonce).decode('utf-8'),
|
||||
"tag": base64.b64encode(tag).decode('utf-8')
|
||||
},
|
||||
"encrypted_aes_key": base64.b64encode(encrypted_aes_key).decode('utf-8'),
|
||||
"key_algorithm": "RSA-OAEP-SHA256",
|
||||
"payload_algorithm": "AES-256-GCM"
|
||||
}
|
||||
|
||||
# Serialize package to JSON and return as bytes
|
||||
package_json = json.dumps(encrypted_package).encode('utf-8')
|
||||
logger.debug("Encrypted package size: %d bytes", len(package_json))
|
||||
|
||||
return package_json
|
||||
finally:
|
||||
# Explicitly clear the AES key reference
|
||||
del aes_key
|
||||
else:
|
||||
logger.warning("Secure memory not available, using standard encryption")
|
||||
return await self._do_encrypt(payload_json, aes_key)
|
||||
finally:
|
||||
del aes_key
|
||||
|
||||
except ValueError:
|
||||
raise # Re-raise validation errors
|
||||
except SecurityError:
|
||||
raise # Re-raise security errors
|
||||
except Exception as e:
|
||||
except Exception:
|
||||
# Don't leak internal details
|
||||
raise SecurityError("Encryption operation failed")
|
||||
|
||||
|
|
@ -571,6 +516,10 @@ class SecureCompletionClient:
|
|||
plaintext_size = len(plaintext_bytes)
|
||||
|
||||
with secure_bytearray(plaintext_bytes) as protected_plaintext:
|
||||
# NOTE: plaintext_json is a Python str (immutable) and cannot be
|
||||
# securely zeroed. The bytearray source is zeroed by the context
|
||||
# manager, but the str object will persist until GC. This is a
|
||||
# known limitation of Python's memory model.
|
||||
plaintext_json = bytes(protected_plaintext.data).decode('utf-8')
|
||||
del plaintext_bytes # drop immutable bytes ref; secure copy already zeroed
|
||||
# AES key automatically zeroed here
|
||||
|
|
|
|||
|
|
@ -24,6 +24,7 @@ Python's immutable bytes objects cannot be securely zeroed in place.
|
|||
import ctypes
|
||||
import logging
|
||||
import sys
|
||||
import threading
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum
|
||||
from typing import Optional, Union
|
||||
|
|
@ -230,7 +231,7 @@ class SecureMemory:
|
|||
def _init_windows(self):
|
||||
"""Initialize Windows-specific functions (VirtualLock + RtlZeroMemory)"""
|
||||
try:
|
||||
kernel32 = ctypes.windll.kernel32
|
||||
kernel32 = ctypes.WinDLL('kernel32', use_last_error=True)
|
||||
|
||||
# Get page size
|
||||
class SYSTEM_INFO(ctypes.Structure):
|
||||
|
|
@ -429,8 +430,6 @@ class SecureMemory:
|
|||
logger.debug(f"Memory lock failed: {e}")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def _unlock_memory_at(self, addr: int, size: int) -> bool:
|
||||
"""
|
||||
Unlock memory at a specific address.
|
||||
|
|
@ -472,8 +471,6 @@ class SecureMemory:
|
|||
logger.debug(f"Memory unlock failed: {e}")
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def _zero_memory_at(self, addr: int, size: int) -> None:
|
||||
"""
|
||||
Securely zero memory at a specific address.
|
||||
|
|
@ -655,13 +652,16 @@ class SecureMemory:
|
|||
|
||||
# Global secure memory instance
|
||||
_secure_memory: Optional[SecureMemory] = None
|
||||
_secure_memory_lock = threading.Lock()
|
||||
|
||||
|
||||
def _get_secure_memory() -> SecureMemory:
|
||||
"""Get or create the global SecureMemory instance."""
|
||||
global _secure_memory
|
||||
if _secure_memory is None:
|
||||
_secure_memory = SecureMemory()
|
||||
with _secure_memory_lock:
|
||||
if _secure_memory is None:
|
||||
_secure_memory = SecureMemory()
|
||||
return _secure_memory
|
||||
|
||||
|
||||
|
|
@ -763,7 +763,8 @@ def disable_secure_memory() -> None:
|
|||
This is useful for testing or when security is not required.
|
||||
"""
|
||||
global _secure_memory
|
||||
_secure_memory = SecureMemory(enable=False)
|
||||
with _secure_memory_lock:
|
||||
_secure_memory = SecureMemory(enable=False)
|
||||
logger.info("Secure memory operations disabled globally")
|
||||
|
||||
|
||||
|
|
@ -774,5 +775,6 @@ def enable_secure_memory() -> None:
|
|||
This reinitializes the secure memory handler with security enabled.
|
||||
"""
|
||||
global _secure_memory
|
||||
_secure_memory = SecureMemory(enable=True)
|
||||
with _secure_memory_lock:
|
||||
_secure_memory = SecureMemory(enable=True)
|
||||
logger.info("Secure memory operations re-enabled globally")
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ OpenAI-compatible secure chat client with end-to-end encryption.
|
|||
|
||||
from .nomyo import SecureChatCompletion
|
||||
from .SecureCompletionClient import (
|
||||
SecurityError,
|
||||
APIError,
|
||||
AuthenticationError,
|
||||
InvalidRequestError,
|
||||
|
|
@ -16,7 +17,20 @@ from .SecureCompletionClient import (
|
|||
ServiceUnavailableError
|
||||
)
|
||||
|
||||
# Import secure memory module if available
|
||||
__all__ = [
|
||||
'SecureChatCompletion',
|
||||
'SecurityError',
|
||||
'APIError',
|
||||
'AuthenticationError',
|
||||
'InvalidRequestError',
|
||||
'APIConnectionError',
|
||||
'ForbiddenError',
|
||||
'RateLimitError',
|
||||
'ServerError',
|
||||
'ServiceUnavailableError',
|
||||
]
|
||||
|
||||
# Import secure memory module if available; extend __all__ only for what was imported
|
||||
try:
|
||||
from .SecureMemory import (
|
||||
get_memory_protection_info,
|
||||
|
|
@ -26,27 +40,17 @@ try:
|
|||
secure_bytes, # Deprecated, use secure_bytearray instead
|
||||
SecureBuffer
|
||||
)
|
||||
__all__ += [
|
||||
'get_memory_protection_info',
|
||||
'disable_secure_memory',
|
||||
'enable_secure_memory',
|
||||
'secure_bytearray',
|
||||
'secure_bytes',
|
||||
'SecureBuffer',
|
||||
]
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
__all__ = [
|
||||
'SecureChatCompletion',
|
||||
'APIError',
|
||||
'AuthenticationError',
|
||||
'InvalidRequestError',
|
||||
'APIConnectionError',
|
||||
'ForbiddenError',
|
||||
'RateLimitError',
|
||||
'ServerError',
|
||||
'ServiceUnavailableError',
|
||||
'get_memory_protection_info',
|
||||
'disable_secure_memory',
|
||||
'enable_secure_memory',
|
||||
'secure_bytearray',
|
||||
'secure_bytes', # Deprecated, use secure_bytearray instead
|
||||
'SecureBuffer'
|
||||
]
|
||||
|
||||
__version__ = "0.1.1"
|
||||
__author__ = "NOMYO AI"
|
||||
__license__ = "Apache-2.0"
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from typing import Dict, Any, List, Optional
|
||||
|
|
@ -70,6 +71,7 @@ class SecureChatCompletion:
|
|||
"""
|
||||
self.client = SecureCompletionClient(router_url=base_url, allow_http=allow_http)
|
||||
self._keys_initialized = False
|
||||
self._keys_lock = asyncio.Lock()
|
||||
self.api_key = api_key
|
||||
self._key_dir = key_dir
|
||||
self._secure_memory_enabled = secure_memory
|
||||
|
|
@ -89,22 +91,25 @@ class SecureChatCompletion:
|
|||
stacklevel=2
|
||||
)
|
||||
|
||||
def _ensure_keys(self):
|
||||
"""Ensure keys are loaded or generated."""
|
||||
async def _ensure_keys(self):
|
||||
"""Ensure keys are loaded or generated (concurrency-safe)."""
|
||||
if self._keys_initialized:
|
||||
return
|
||||
if self._key_dir is not None:
|
||||
private_key_path = os.path.join(self._key_dir, "private_key.pem")
|
||||
public_key_path = os.path.join(self._key_dir, "public_key.pem")
|
||||
try:
|
||||
self.client.load_keys(private_key_path, public_key_path)
|
||||
self._keys_initialized = True
|
||||
async with self._keys_lock:
|
||||
if self._keys_initialized: # double-check after acquiring lock
|
||||
return
|
||||
except Exception:
|
||||
self.client.generate_keys(save_to_file=True, key_dir=self._key_dir)
|
||||
else:
|
||||
self.client.generate_keys()
|
||||
self._keys_initialized = True
|
||||
if self._key_dir is not None:
|
||||
private_key_path = os.path.join(self._key_dir, "private_key.pem")
|
||||
public_key_path = os.path.join(self._key_dir, "public_key.pem")
|
||||
try:
|
||||
self.client.load_keys(private_key_path, public_key_path)
|
||||
self._keys_initialized = True
|
||||
return
|
||||
except Exception:
|
||||
self.client.generate_keys(save_to_file=True, key_dir=self._key_dir)
|
||||
else:
|
||||
self.client.generate_keys()
|
||||
self._keys_initialized = True
|
||||
|
||||
async def create(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
||||
"""
|
||||
|
|
@ -215,8 +220,8 @@ class SecureChatCompletion:
|
|||
else:
|
||||
instance = self
|
||||
|
||||
# Ensure keys are available (synchronous)
|
||||
instance._ensure_keys()
|
||||
# Ensure keys are available
|
||||
await instance._ensure_keys()
|
||||
|
||||
# Build payload — api_key is intentionally excluded (sent as Bearer header)
|
||||
payload = {
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue