fix:
- default api url - removed pickle for serialization - plaintext del'd asap - payload not tainted with kwargs anymore
This commit is contained in:
parent
f33c6e3434
commit
14f841a0bf
3 changed files with 95 additions and 94 deletions
|
|
@ -77,7 +77,7 @@ class SecureCompletionClient:
|
||||||
- Response parsing
|
- Response parsing
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, router_url: str = "https://api.nomyo.ai:12434", allow_http: bool = False):
|
def __init__(self, router_url: str = "https://api.nomyo.ai:12435", allow_http: bool = False):
|
||||||
"""
|
"""
|
||||||
Initialize the secure completion client.
|
Initialize the secure completion client.
|
||||||
|
|
||||||
|
|
@ -106,37 +106,37 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
def _protect_private_key(self) -> None:
|
def _protect_private_key(self) -> None:
|
||||||
"""
|
"""
|
||||||
Attempt to lock private key in memory (best effort).
|
Best-effort attempt to prevent key pages from being swapped to disk.
|
||||||
|
|
||||||
Note: Due to Python's memory management and the cryptography library's
|
Note: The cryptography library uses OpenSSL's own memory allocator for
|
||||||
internal handling of key material, this provides limited protection.
|
the actual key material, which cannot be directly locked from Python.
|
||||||
The main benefit is defense-in-depth and signaling security intent.
|
This method exports the key to a DER bytearray, locks that page, then
|
||||||
|
immediately zeros and discards the copy. It does not protect OpenSSL's
|
||||||
|
internal representation, but serves as a defense-in-depth measure.
|
||||||
|
|
||||||
For maximum security:
|
For maximum security:
|
||||||
- Use password-protected key files
|
- Use password-protected key files
|
||||||
- Rotate keys regularly
|
- Rotate keys regularly
|
||||||
- Store keys outside the project directory in production
|
- Store keys outside the project directory in production
|
||||||
"""
|
"""
|
||||||
if not _SECURE_MEMORY_AVAILABLE or not self.private_key:
|
if not _SECURE_MEMORY_AVAILABLE or not self.private_key:
|
||||||
return
|
return
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Attempt to lock the key object in memory
|
key_der = bytearray(self.private_key.private_bytes(
|
||||||
# Note: This is best-effort as the cryptography library
|
encoding=serialization.Encoding.DER,
|
||||||
# maintains its own internal key material
|
format=serialization.PrivateFormat.PKCS8,
|
||||||
import pickle
|
encryption_algorithm=serialization.NoEncryption()
|
||||||
key_data = bytearray(pickle.dumps(self.private_key))
|
))
|
||||||
secure_memory = _get_secure_memory()
|
secure_memory = _get_secure_memory()
|
||||||
locked = secure_memory.lock_memory(key_data)
|
locked = secure_memory.lock_memory(key_der)
|
||||||
if locked:
|
logger.debug("Private key page lock: %s", "success" if locked else "unavailable")
|
||||||
logger.debug("Private key locked in memory (best effort)")
|
secure_memory.zero_memory(key_der)
|
||||||
else:
|
|
||||||
logger.debug("Could not lock private key in memory")
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.debug(f"Private key protection unavailable: {e}")
|
logger.debug("Private key protection unavailable: %s", e)
|
||||||
|
|
||||||
|
|
||||||
async def generate_keys(self, save_to_file: bool = False, key_dir: str = "client_keys", password: Optional[str] = None) -> None:
|
def generate_keys(self, save_to_file: bool = False, key_dir: str = "client_keys", password: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Generate RSA key pair for secure communication.
|
Generate RSA key pair for secure communication.
|
||||||
|
|
||||||
|
|
@ -211,7 +211,7 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
logger.debug("Keys saved to %s/", key_dir)
|
logger.debug("Keys saved to %s/", key_dir)
|
||||||
|
|
||||||
async def load_keys(self, private_key_path: str, public_key_path: Optional[str] = None, password: Optional[str] = None) -> None:
|
def load_keys(self, private_key_path: str, public_key_path: Optional[str] = None, password: Optional[str] = None) -> None:
|
||||||
"""
|
"""
|
||||||
Load RSA keys from files.
|
Load RSA keys from files.
|
||||||
|
|
||||||
|
|
@ -226,27 +226,16 @@ class SecureCompletionClient:
|
||||||
with open(private_key_path, "rb") as f:
|
with open(private_key_path, "rb") as f:
|
||||||
private_pem = f.read()
|
private_pem = f.read()
|
||||||
|
|
||||||
# Try different password options
|
password_bytes = password.encode('utf-8') if password else None
|
||||||
password_options = []
|
try:
|
||||||
if password:
|
self.private_key = serialization.load_pem_private_key(
|
||||||
password_options.append(password.encode('utf-8'))
|
private_pem,
|
||||||
password_options.append(None) # Try without password
|
password=password_bytes,
|
||||||
|
backend=default_backend()
|
||||||
last_error = None
|
)
|
||||||
for pwd in password_options:
|
logger.debug("Private key loaded %s", 'with password' if password_bytes else 'without password')
|
||||||
try:
|
except Exception as e:
|
||||||
self.private_key = serialization.load_pem_private_key(
|
raise ValueError(f"Failed to load private key: {e}")
|
||||||
private_pem,
|
|
||||||
password=pwd,
|
|
||||||
backend=default_backend()
|
|
||||||
)
|
|
||||||
logger.debug("Private key loaded %s", 'with password' if pwd else 'without password')
|
|
||||||
break
|
|
||||||
except Exception as e:
|
|
||||||
last_error = e
|
|
||||||
continue
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Failed to load private key. Tried all password options. Error: {last_error}")
|
|
||||||
|
|
||||||
# Get public key
|
# Get public key
|
||||||
public_key = self.private_key.public_key()
|
public_key = self.private_key.public_key()
|
||||||
|
|
@ -550,7 +539,9 @@ class SecureCompletionClient:
|
||||||
if missing_payload_fields:
|
if missing_payload_fields:
|
||||||
raise ValueError(f"Missing fields in encrypted_payload: {', '.join(missing_payload_fields)}")
|
raise ValueError(f"Missing fields in encrypted_payload: {', '.join(missing_payload_fields)}")
|
||||||
|
|
||||||
# Decrypt with proper error handling
|
# Decrypt with proper error handling — keep crypto errors opaque (timing attacks)
|
||||||
|
plaintext_json: Optional[str] = None
|
||||||
|
plaintext_size: int = 0
|
||||||
try:
|
try:
|
||||||
# Decrypt AES key with private key
|
# Decrypt AES key with private key
|
||||||
encrypted_aes_key = base64.b64decode(package["encrypted_aes_key"])
|
encrypted_aes_key = base64.b64decode(package["encrypted_aes_key"])
|
||||||
|
|
@ -565,9 +556,7 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
# Use secure memory to protect AES key and decrypted plaintext
|
# Use secure memory to protect AES key and decrypted plaintext
|
||||||
if _SECURE_MEMORY_AVAILABLE:
|
if _SECURE_MEMORY_AVAILABLE:
|
||||||
# Protect AES key in memory
|
|
||||||
with secure_bytearray(aes_key) as protected_aes_key:
|
with secure_bytearray(aes_key) as protected_aes_key:
|
||||||
# Decrypt payload with AES-GCM using Cipher API
|
|
||||||
ciphertext = base64.b64decode(package["encrypted_payload"]["ciphertext"])
|
ciphertext = base64.b64decode(package["encrypted_payload"]["ciphertext"])
|
||||||
nonce = base64.b64decode(package["encrypted_payload"]["nonce"])
|
nonce = base64.b64decode(package["encrypted_payload"]["nonce"])
|
||||||
tag = base64.b64decode(package["encrypted_payload"]["tag"])
|
tag = base64.b64decode(package["encrypted_payload"]["tag"])
|
||||||
|
|
@ -578,20 +567,15 @@ class SecureCompletionClient:
|
||||||
backend=default_backend()
|
backend=default_backend()
|
||||||
)
|
)
|
||||||
decryptor = cipher.decryptor()
|
decryptor = cipher.decryptor()
|
||||||
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
plaintext_bytes = decryptor.update(ciphertext) + decryptor.finalize()
|
||||||
|
plaintext_size = len(plaintext_bytes)
|
||||||
# Protect decrypted plaintext in memory
|
|
||||||
with secure_bytearray(plaintext) as protected_plaintext:
|
|
||||||
# Parse decrypted response
|
|
||||||
response = json.loads(bytes(protected_plaintext.data).decode('utf-8'))
|
|
||||||
# Plaintext automatically zeroed here
|
|
||||||
|
|
||||||
|
with secure_bytearray(plaintext_bytes) as protected_plaintext:
|
||||||
|
plaintext_json = bytes(protected_plaintext.data).decode('utf-8')
|
||||||
|
del plaintext_bytes # drop immutable bytes ref; secure copy already zeroed
|
||||||
# AES key automatically zeroed here
|
# AES key automatically zeroed here
|
||||||
else:
|
else:
|
||||||
# Fallback if secure memory not available
|
|
||||||
logger.warning("Secure memory not available, using standard decryption")
|
logger.warning("Secure memory not available, using standard decryption")
|
||||||
|
|
||||||
# Decrypt payload with AES-GCM using Cipher API
|
|
||||||
ciphertext = base64.b64decode(package["encrypted_payload"]["ciphertext"])
|
ciphertext = base64.b64decode(package["encrypted_payload"]["ciphertext"])
|
||||||
nonce = base64.b64decode(package["encrypted_payload"]["nonce"])
|
nonce = base64.b64decode(package["encrypted_payload"]["nonce"])
|
||||||
tag = base64.b64decode(package["encrypted_payload"]["tag"])
|
tag = base64.b64decode(package["encrypted_payload"]["tag"])
|
||||||
|
|
@ -602,15 +586,21 @@ class SecureCompletionClient:
|
||||||
backend=default_backend()
|
backend=default_backend()
|
||||||
)
|
)
|
||||||
decryptor = cipher.decryptor()
|
decryptor = cipher.decryptor()
|
||||||
plaintext = decryptor.update(ciphertext) + decryptor.finalize()
|
plaintext_bytes = decryptor.update(ciphertext) + decryptor.finalize()
|
||||||
|
plaintext_size = len(plaintext_bytes)
|
||||||
# Parse decrypted response
|
plaintext_json = plaintext_bytes.decode('utf-8')
|
||||||
response = json.loads(plaintext.decode('utf-8'))
|
del plaintext_bytes
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
# Don't leak specific decryption errors (timing attacks)
|
# Don't leak specific decryption errors (timing attacks)
|
||||||
raise SecurityError("Decryption failed: integrity check or authentication failed")
|
raise SecurityError("Decryption failed: integrity check or authentication failed")
|
||||||
|
|
||||||
|
# Parse JSON outside the crypto exception handler so format errors aren't hidden
|
||||||
|
try:
|
||||||
|
response = json.loads(plaintext_json)
|
||||||
|
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||||
|
raise ValueError(f"Decrypted response is not valid JSON: {e}")
|
||||||
|
|
||||||
# Add metadata for debugging
|
# Add metadata for debugging
|
||||||
if "_metadata" not in response:
|
if "_metadata" not in response:
|
||||||
response["_metadata"] = {}
|
response["_metadata"] = {}
|
||||||
|
|
@ -622,7 +612,7 @@ class SecureCompletionClient:
|
||||||
})
|
})
|
||||||
|
|
||||||
logger.debug("Response decrypted successfully")
|
logger.debug("Response decrypted successfully")
|
||||||
logger.debug("Response size: %d bytes", len(plaintext))
|
logger.debug("Response size: %d bytes", plaintext_size)
|
||||||
|
|
||||||
return response
|
return response
|
||||||
|
|
||||||
|
|
@ -788,13 +778,15 @@ class SecureCompletionClient:
|
||||||
|
|
||||||
else:
|
else:
|
||||||
# Unexpected status code
|
# Unexpected status code
|
||||||
unexp_detail = response.json()
|
try:
|
||||||
if not isinstance(unexp_detail, dict):
|
unexp_detail = response.json()
|
||||||
unexp_detail = {"detail": "unknown"}
|
if not isinstance(unexp_detail, dict):
|
||||||
if isinstance(unexp_detail, dict) and "detail" not in unexp_detail.keys():
|
unexp_detail = {"detail": "unknown"}
|
||||||
unexp_detail["detail"] = "unknown"
|
detail_msg = unexp_detail.get("detail", "unknown")
|
||||||
|
except (json.JSONDecodeError, ValueError):
|
||||||
|
detail_msg = "unknown"
|
||||||
raise APIError(
|
raise APIError(
|
||||||
f"Unexpected status code: {response.status_code} {unexp_detail['detail']}",
|
f"Unexpected status code: {response.status_code} {detail_msg}",
|
||||||
status_code=response.status_code
|
status_code=response.status_code
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -47,6 +47,6 @@ __all__ = [
|
||||||
'SecureBuffer'
|
'SecureBuffer'
|
||||||
]
|
]
|
||||||
|
|
||||||
__version__ = "0.1.0"
|
__version__ = "0.1.1"
|
||||||
__author__ = "NOMYO AI"
|
__author__ = "NOMYO AI"
|
||||||
__license__ = "Apache-2.0"
|
__license__ = "Apache-2.0"
|
||||||
|
|
|
||||||
|
|
@ -1,10 +1,11 @@
|
||||||
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from typing import Dict, Any, List, Optional
|
from typing import Dict, Any, List, Optional
|
||||||
from .SecureCompletionClient import SecureCompletionClient, APIError, AuthenticationError, InvalidRequestError, APIConnectionError, RateLimitError, ServerError
|
from .SecureCompletionClient import SecureCompletionClient
|
||||||
|
|
||||||
# Import secure memory module for configuration
|
# Import secure memory module for configuration
|
||||||
try:
|
try:
|
||||||
from .SecureMemory import get_memory_protection_info, disable_secure_memory, enable_secure_memory
|
from .SecureMemory import disable_secure_memory, enable_secure_memory
|
||||||
_SECURE_MEMORY_AVAILABLE = True
|
_SECURE_MEMORY_AVAILABLE = True
|
||||||
except ImportError:
|
except ImportError:
|
||||||
_SECURE_MEMORY_AVAILABLE = False
|
_SECURE_MEMORY_AVAILABLE = False
|
||||||
|
|
@ -27,7 +28,7 @@ class SecureChatCompletion:
|
||||||
Usage:
|
Usage:
|
||||||
```python
|
```python
|
||||||
# Create a client instance
|
# Create a client instance
|
||||||
client = SecureChatCompletion(base_url="http://api.nomyo.ai:12434")
|
client = SecureChatCompletion(base_url="https://api.nomyo.ai:12435")
|
||||||
|
|
||||||
# Simple chat completion
|
# Simple chat completion
|
||||||
response = await client.create(
|
response = await client.create(
|
||||||
|
|
@ -50,7 +51,7 @@ class SecureChatCompletion:
|
||||||
```
|
```
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, base_url: str = "https://api.nomyo.ai", allow_http: bool = False, api_key: Optional[str] = None, secure_memory: bool = True):
|
def __init__(self, base_url: str = "https://api.nomyo.ai:12435", allow_http: bool = False, api_key: Optional[str] = None, secure_memory: bool = True, key_dir: Optional[str] = None):
|
||||||
"""
|
"""
|
||||||
Initialize the secure chat completion client.
|
Initialize the secure chat completion client.
|
||||||
|
|
||||||
|
|
@ -64,11 +65,14 @@ class SecureChatCompletion:
|
||||||
When enabled, prevents plaintext payloads from being swapped to disk
|
When enabled, prevents plaintext payloads from being swapped to disk
|
||||||
and guarantees memory is zeroed after encryption.
|
and guarantees memory is zeroed after encryption.
|
||||||
Set to False for testing or when security is not required.
|
Set to False for testing or when security is not required.
|
||||||
|
key_dir: Directory to load/save RSA keys. If None, ephemeral keys are
|
||||||
|
generated in memory for this session only.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
self.client = SecureCompletionClient(router_url=base_url, allow_http=allow_http)
|
self.client = SecureCompletionClient(router_url=base_url, allow_http=allow_http)
|
||||||
self._keys_initialized = False
|
self._keys_initialized = False
|
||||||
self.api_key = api_key
|
self.api_key = api_key
|
||||||
|
self._key_dir = key_dir
|
||||||
|
self._secure_memory_enabled = secure_memory
|
||||||
|
|
||||||
# Configure secure memory if available
|
# Configure secure memory if available
|
||||||
if _SECURE_MEMORY_AVAILABLE:
|
if _SECURE_MEMORY_AVAILABLE:
|
||||||
|
|
@ -85,17 +89,22 @@ class SecureChatCompletion:
|
||||||
stacklevel=2
|
stacklevel=2
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _ensure_keys(self):
|
def _ensure_keys(self):
|
||||||
"""Ensure keys are loaded or generated."""
|
"""Ensure keys are loaded or generated."""
|
||||||
if not self._keys_initialized:
|
if self._keys_initialized:
|
||||||
# Try to load existing keys
|
return
|
||||||
|
if self._key_dir is not None:
|
||||||
|
private_key_path = os.path.join(self._key_dir, "private_key.pem")
|
||||||
|
public_key_path = os.path.join(self._key_dir, "public_key.pem")
|
||||||
try:
|
try:
|
||||||
await self.client.load_keys("client_keys/private_key.pem", "client_keys/public_key.pem")
|
self.client.load_keys(private_key_path, public_key_path)
|
||||||
self._keys_initialized = True
|
self._keys_initialized = True
|
||||||
|
return
|
||||||
except Exception:
|
except Exception:
|
||||||
# Generate new keys if loading fails
|
self.client.generate_keys(save_to_file=True, key_dir=self._key_dir)
|
||||||
await self.client.generate_keys()
|
else:
|
||||||
self._keys_initialized = True
|
self.client.generate_keys()
|
||||||
|
self._keys_initialized = True
|
||||||
|
|
||||||
async def create(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
async def create(self, model: str, messages: List[Dict[str, Any]], **kwargs) -> Dict[str, Any]:
|
||||||
"""
|
"""
|
||||||
|
|
@ -157,36 +166,36 @@ class SecureChatCompletion:
|
||||||
ConnectionError: If the connection to the router fails.
|
ConnectionError: If the connection to the router fails.
|
||||||
Exception: For other errors during the request.
|
Exception: For other errors during the request.
|
||||||
"""
|
"""
|
||||||
# Extract base_url if provided (OpenAI compatibility)
|
# Extract non-payload kwargs before building the payload dict
|
||||||
base_url = kwargs.pop("base_url", None)
|
base_url = kwargs.pop("base_url", None)
|
||||||
|
|
||||||
# Extract security_tier if provided
|
|
||||||
security_tier = kwargs.pop("security_tier", None)
|
security_tier = kwargs.pop("security_tier", None)
|
||||||
|
api_key_override = kwargs.pop("api_key", None)
|
||||||
|
|
||||||
# Use the instance's client unless base_url is explicitly overridden
|
# Use the instance's client unless base_url is explicitly overridden
|
||||||
if base_url is not None:
|
if base_url is not None:
|
||||||
# Create a temporary client with overridden base_url
|
temp_client = type(self)(
|
||||||
temp_client = type(self)(base_url=base_url)
|
base_url=base_url,
|
||||||
|
allow_http=self.client.allow_http,
|
||||||
|
api_key=self.api_key,
|
||||||
|
secure_memory=self._secure_memory_enabled,
|
||||||
|
key_dir=self._key_dir,
|
||||||
|
)
|
||||||
instance = temp_client
|
instance = temp_client
|
||||||
else:
|
else:
|
||||||
# Use the instance's existing client
|
|
||||||
instance = self
|
instance = self
|
||||||
|
|
||||||
# Ensure keys are available
|
# Ensure keys are available (synchronous)
|
||||||
await instance._ensure_keys()
|
instance._ensure_keys()
|
||||||
|
|
||||||
# Prepare payload in OpenAI format
|
# Build payload — api_key is intentionally excluded (sent as Bearer header)
|
||||||
payload = {
|
payload = {
|
||||||
"model": model,
|
"model": model,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
**kwargs
|
**kwargs
|
||||||
}
|
}
|
||||||
|
|
||||||
# Generate a unique payload ID
|
payload_id = str(uuid.uuid4())
|
||||||
payload_id = f"{uuid.uuid4()}"
|
request_api_key = api_key_override if api_key_override is not None else instance.api_key
|
||||||
|
|
||||||
# Use instance's api_key if not overridden in kwargs
|
|
||||||
request_api_key = kwargs.pop("api_key", instance.api_key)
|
|
||||||
|
|
||||||
# Send secure request with security tier
|
# Send secure request with security tier
|
||||||
response = await instance.client.send_secure_request(payload, payload_id, request_api_key, security_tier)
|
response = await instance.client.send_secure_request(payload, payload_id, request_api_key, security_tier)
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue