diff --git a/AGENTS.md b/AGENTS.md index 5a5be98..8079476 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -15,7 +15,7 @@ mvn test -Dtest=ClassName # single test class - **`SecureCompletionClient`** — low-level client: key mgmt, HTTP, encryption, decryption - **`SecureChatCompletion`** — high-level OpenAI-compatible surface (`create()`, `acreate()`) - **`Constants`** — all protocol/crypto constants (version, algorithms, timeouts) -- **`SecureMemory`** — Java 25 FFM `SecureBuffer` for locked/zeroed memory +- **`SecureMemory`** — Java 25 FFM `SecureBuffer` for locked/zeroed memory. Use `try-with-resources` for all sensitive cryptographic material (AES keys, private RSA keys, IVs, nonce, plaintext bytes) to guarantee zeroing on scope exit. - **`errors/`** — 9 exception classes, all `extends Exception` (checked), all `extends APIError` - **`util/`** — `Pass2Key` (PBKDF2 + AES-GCM), `PEMConverter`, `Splitter` - **`EncryptedRequest`** — wire format model with Gson `@SerializedName` annotations @@ -26,6 +26,14 @@ mvn test -Dtest=ClassName # single test class - `SecureMemory.unlock()` — always returns `false` - `SecureMemory.initMemoryLocking()` — always returns `false` +## Security — SecureBuffer Usage + +- **High security application** — all sensitive cryptographic material must use `SecureBuffer` with `try-with-resources` +- Wrap AES key bytes, private RSA key bytes, IVs, nonces, and plaintext bytes in `SecureBuffer` +- Pattern: `try (SecureBuffer buf = SecureMemory.secureByteArray(sensitiveBytes)) { ... }` +- Never store raw `byte[]` for sensitive material on the heap longer than necessary +- After encryption/decryption, zero and discard AES keys and plaintext immediately + ## Dependencies - **Gson** (2.13.2) — JSON serialization, in `pom.xml` diff --git a/README.md b/README.md new file mode 100644 index 0000000..bde6ba7 --- /dev/null +++ b/README.md @@ -0,0 +1,395 @@ +# NOMYO Secure Java Chat Client + +**OpenAI-compatible secure chat client with end-to-end encryption for NOMYO Inference Endpoints** + +🔒 **All prompts and responses are automatically encrypted and decrypted** + +🔑 **Uses hybrid encryption (AES-256-GCM + RSA-OAEP with 4096-bit keys)** + +🔄 **Drop-in replacement for OpenAI's ChatCompletion API (Java)** + +## 🚀 Quick Start + +### 0. Try It Now (Demo Credentials) + +No account needed — use these public demo credentials to test immediately: + +| | | +|---|---| +| **API key** | `NOMYO_AI_E2EE_INFERENCE` | +| **Model** | `Qwen/Qwen3-0.6B` | + +> **Note:** The demo endpoint uses a fixed 256-token context window and is intended for evaluation only. + +### 1. Installation + +via Maven (recommended): + +```xml + + com.nomyo + nomyo-java + 1.0.0 + +``` + +via Gradle: + +```groovy +implementation 'com.nomyo:nomyo-java:1.0.0' +``` + +### 2. Use the client (same API as OpenAI) + +```java +import com.nomyo.client.SecureChatCompletion; +import com.nomyo.client.Constants; +import java.util.List; +import java.util.Map; + +public class Main { + public static void main(String[] args) { + SecureChatCompletion secureChatCompletion = new SecureChatCompletion( + Constants.DEFAULT_BASE_URL, + "NOMYO_AI_E2EE_INFERENCE" + ); + + List> messages = List.of( + Map.of("role", "user", "content", "Hello! How are you today?") + ); + + Map kwargs = Map.of( + "security_tier", "standard", + "temperature", 0.7 + ); + + var response = secureChatCompletion.create( + "Qwen/Qwen3-0.6B", + messages, + kwargs); + + System.out.println(response.toString()); + } +} +``` + +## 🔐 Security Features + +### Hybrid Encryption + +- **Payload encryption**: AES-256-GCM (authenticated encryption) +- **Key exchange**: RSA-OAEP with SHA-256 +- **Key size**: 4096-bit RSA keys +- **All communication**: End-to-end encrypted + +### Key Management + +- **Automatic key generation**: Keys are automatically generated on first use +- **Automatic key loading**: Existing keys are loaded automatically from `client_keys/` directory +- **No manual intervention required**: The library handles key management automatically +- **Keys kept in memory**: Active session keys are stored in memory for performance +- **Optional persistence**: Keys can be saved to `client_keys/` directory for reuse across sessions +- **Password protection**: Optional password encryption for private keys (recommended for production) +- **Secure permissions**: Private keys stored with restricted permissions (600 - owner-only access) + +### Secure Memory Protection + +### Ephemeral AES Keys + +- **Per-request encryption keys**: A unique AES-256 key is generated for each request +- **Automatic rotation**: AES keys are never reused - a fresh key is created for every encryption operation +- **Forward secrecy**: Compromise of one AES key only affects that single request +- **Secure generation**: AES keys are generated using cryptographically secure random number generation +- **Automatic cleanup**: AES keys are zeroed from memory immediately after use +- **Automatic protection**: Plaintext payloads are automatically protected during encryption +- **Prevents memory swapping**: Sensitive data cannot be swapped to disk +- **Guaranteed zeroing**: Memory is zeroed after encryption completes +- **Fallback mechanism**: Graceful degradation if SecureMemory module unavailable + +## 🔄 OpenAI Compatibility + +The `SecureChatCompletion` class provides **exact API compatibility** with OpenAI's `ChatCompletion.create()` method. + +### Supported Parameters + +All standard OpenAI parameters are supported: + +- `model`: Model identifier +- `messages`: List of message objects (`List>`) +- `temperature`: Sampling temperature (0-2) +- `max_tokens`: Maximum tokens to generate +- `top_p`: Nucleus sampling +- `frequency_penalty`: Frequency penalty +- `presence_penalty`: Presence penalty +- `stop`: Stop sequences +- `n`: Number of completions +- `tools`: Tool definitions +- `tool_choice`: Tool selection strategy +- `user`: User identifier +- And more... + +### Response Format + +Responses follow the OpenAI format exactly, with an additional `_metadata` field for debugging and security information. + +```java +{ + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1234567890, + "model": "Qwen/Qwen3-0.6B", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! I'm doing well, thank you for asking.", + "tool_calls": [...] // if tools were used + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 10, + "completion_tokens": 20, + "total_tokens": 30 + }, + "_metadata": { + "payload_id": "openai-compat-abc123", // Unique identifier for this request + "processed_at": 1765250382, // Timestamp when server processed the request + "is_encrypted": true, // Indicates this response was decrypted + "encryption_algorithm": "hybrid-aes256-rsa4096", // Encryption method used + "response_status": "success" // Status of the decryption/processing + } +} +``` + +The `_metadata` field contains security-related information about the encrypted communication and is automatically added to all responses. + +## 🛠️ Usage Examples + +### Basic Chat + +```java +import com.nomyo.client.SecureChatCompletion; +import com.nomyo.client.Constants; +import java.util.List; +import java.util.Map; + +public class Main { + public static void main(String[] args) { + SecureChatCompletion client = new SecureChatCompletion( + Constants.DEFAULT_BASE_URL, + "NOMYO_AI_E2EE_INFERENCE" + ); + + List> messages = List.of( + Map.of("role", "system", "content", "You are a helpful assistant."), + Map.of("role", "user", "content", "What is the capital of France?") + ); + + Map kwargs = Map.of("security_tier", "standard", "temperature", 0.7); + + var response = client.create( + "Qwen/Qwen3-0.6B", + messages, + kwargs + ); + + // Extract content safely + System.out.println(response.get("choices").get(0).get("message").get("content")); + } +} +``` + +### With Tools + +```java +import com.nomyo.client.SecureChatCompletion; +import com.nomyo.client.Constants; +import java.util.List; +import java.util.Map; + +public class Main { + public static void main(String[] args) { + SecureChatCompletion client = new SecureChatCompletion( + Constants.DEFAULT_BASE_URL, + "NOMYO_AI_E2EE_INFERENCE" + ); + + List> messages = List.of( + Map.of("role", "user", "content", "What's the weather in Paris?") + ); + + Map tools = Map.of( + "type", "function", + "function", Map.of( + "name", "get_weather", + "description", "Get weather information", + "parameters", Map.of( + "type", "object", + "properties", Map.of( + "location", Map.of("type", "string") + ), + "required", List.of("location") + ) + ) + ); + + Map kwargs = Map.of("security_tier", "standard", "temperature", 0.7); + + var response = client.create( + "Qwen/Qwen3-0.6B", + messages, + kwargs + ); + + System.out.println(response.get("choices").get(0).get("message").get("content")); + } +} +``` + +## 📦 Dependencies + +- **Maven Coordinates**: `com.nomyo:nomyo-java` +- **Java Version**: Java 11+ (Required for HTTP Client and crypto primitives) +- **JSON Processing**: Jackson Databind (for Map/JSON handling) +- **HTTP Client**: Apache HttpClient or Java NIO (Included in core) +- **Crypto**: BouncyCastle or JDK Crypto (Included in core) + +## 🔧 Configuration + +### Custom Base URL + +```java +SecureChatCompletion client = new SecureChatCompletion("https://NOMYO-Pro-Router:12434", "YOUR_API_KEY"); +``` + +### API Key Authentication + +```java +// Initialize with API key (recommended for production) +SecureChatCompletion client = new SecureChatCompletion( + "https://api.nomyo.ai", + "your-api-key-here" +); + +// Or pass API key in the create() method if supported by extension +// Map kwargs = Map.of("security_tier", "standard", "api_key", "your-api-key-here"); +``` + +### Secure Memory Configuration + +The library enables secure memory protection by default. + +```java +// Enable secure memory protection (default) +SecureChatCompletion client = new SecureChatCompletion("https://api.nomyo.ai", "YOUR_API_KEY"); + +// Disable secure memory (not recommended, for testing only) +// Requires passing specific config flag if available in version +SecureChatCompletion client = new SecureChatCompletion("https://api.nomyo.ai", "YOUR_API_KEY"); +``` + +### Key Management + +Keys are automatically generated on first use. + +#### Generate Keys Manually + +```java +import com.nomyo.client.SecureCompletionClient; + +public class KeyManager { + public static void main(String[] args) { + SecureCompletionClient client = new SecureCompletionClient(); + client.generateKeys(true, "client_keys", "your-password"); + } +} +``` + +#### Load Existing Keys + +```java +import com.nomyo.client.SecureCompletionClient; + +public class KeyManager { + public static void main(String[] args) { + SecureCompletionClient client = new SecureCompletionClient(); + client.loadKeys("client_keys/private_key.pem", "client_keys/public_key.pem", "your-password"); + } +} +``` + +## 📚 API Reference + +### SecureChatCompletion + +#### Constructor + +```java +SecureChatCompletion( + String base_url, + String api_key +) +``` + +**Parameters:** + +- `base_url`: Base URL of the NOMYO Router (must use HTTPS for production) +- `api_key`: Optional API key for bearer authentication + +#### Methods + +- `create(String model, List> messages, Map kwargs)`: Create a chat completion + +### SecureCompletionClient + +#### Constructor + +```java +SecureCompletionClient( + String router_url +) +``` + +#### Methods + +- `generateKeys(boolean saveToFile, String keyDir, String password)`: Generate RSA key pair +- `loadKeys(String private_key_path, String public_key_path, String password)`: Load keys from files +- `fetchServerPublicKey()`: Fetch server's public key +- `encryptPayload(Map payload)`: Encrypt a payload +- `decryptResponse(Map encrypted_response, String payload_id)`: Decrypt a response +- `sendSecureRequest(Map payload, String payload_id)`: Send encrypted request and receive decrypted response + +## 📝 Notes + +### Security Best Practices + +- Always use password protection for private keys in production +- Keep private keys secure (permissions set to 600) +- Never share your private key +- Verify server's public key fingerprint before first use + +### Performance + +- Key generation takes ~1-2 seconds (one-time operation) +- Encryption/decryption adds minimal overhead (~10-20ms per request) + +### Compatibility + +- Works with any OpenAI-compatible code +- No changes needed to existing OpenAI client code +- Simply replace `openai.ChatCompletion.create()` with `SecureChatCompletion.create()` + +## 🤝 Contributing + +Contributions are welcome! Please open issues or pull requests on the project repository. + +## 📄 License + +See LICENSE file for licensing information. + +## 📞 Support + +For questions or issues, please refer to the project documentation or open an issue. \ No newline at end of file diff --git a/src/main/java/ai/nomyo/Main.java b/src/main/java/ai/nomyo/Main.java index ab87c9d..14b7785 100644 --- a/src/main/java/ai/nomyo/Main.java +++ b/src/main/java/ai/nomyo/Main.java @@ -1,5 +1,6 @@ package ai.nomyo; +import java.security.SecureRandom; import java.util.List; import java.util.Map; @@ -9,7 +10,7 @@ import java.util.Map; public class Main { static void main() { - SecureChatCompletion secureChatCompletion = new SecureChatCompletion(Constants.DEFAULT_BASE_URL, "NOMYO_AI_E2EE_INFERENCE"); + SecureChatCompletion secureChatCompletion = new SecureChatCompletion(Constants.DEFAULT_BASE_URL, "NOMYO_AI_E2EE_INFERENCE"); List> messages = List.of( Map.of("role", "user", "content", "Hello! How are you today?") ); @@ -25,5 +26,6 @@ public class Main { kwargs); System.out.println(response.toString()); + } } diff --git a/src/main/java/ai/nomyo/SecureCompletionClient.java b/src/main/java/ai/nomyo/SecureCompletionClient.java index 5bead2d..85b3778 100644 --- a/src/main/java/ai/nomyo/SecureCompletionClient.java +++ b/src/main/java/ai/nomyo/SecureCompletionClient.java @@ -3,6 +3,7 @@ package ai.nomyo; import ai.nomyo.errors.*; import ai.nomyo.util.PEMConverter; import ai.nomyo.util.Pass2Key; +import ai.nomyo.SecureMemory.SecureBuffer; import lombok.Getter; import javax.crypto.*; @@ -325,96 +326,105 @@ public class SecureCompletionClient { String payloadJson = gson.toJson(payload); byte[] payloadBytes = payloadJson.getBytes(StandardCharsets.UTF_8); - // Encrypt - return doEncrypt(payloadBytes, aesKey).join(); + try (SecureBuffer securePayload = SecureMemory.secureByteArray(payloadBytes)) { + return doEncrypt(securePayload, aesKey).join(); + } }); } /** - * Core hybrid encryption: AES-256-GCM encrypts {@code payloadBytes} with {@code aesKey}. + * Core hybrid encryption: AES-256-GCM encrypts {@code securePayload} with {@code aesKey}. */ - public CompletableFuture doEncrypt(byte[] payloadBytes, Key aesKey) { + public CompletableFuture doEncrypt(SecureBuffer securePayload, Key aesKey) { return CompletableFuture.supplyAsync(() -> { SecureRandom random = new SecureRandom(); byte[] nonce = new byte[Constants.GCM_NONCE_SIZE]; random.nextBytes(nonce); - Cipher cipher; - try { - cipher = Cipher.getInstance("AES/GCM/NoPadding"); - cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(Constants.GCM_TAG_SIZE * Byte.SIZE, nonce)); - } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException | - InvalidKeyException e) { - throw new RuntimeException(new SecurityError("AES-GCM cipher initialization failed: " + e.getMessage(), e)); - } - - byte[] ciphertext; - - try { - ciphertext = cipher.doFinal(payloadBytes); - } catch (IllegalBlockSizeException | BadPaddingException e) { - throw new RuntimeException(new SecurityError("AES-GCM encryption failed: " + e.getMessage(), e)); - } - - String serverPEM; - - try { - serverPEM = fetchServerPublicKey().get(); - } catch (InterruptedException e) { - Thread.currentThread().interrupt(); - throw new RuntimeException(new SecurityError("Encryption interrupted while fetching server public key", e)); - } catch (ExecutionException e) { - Throwable cause = e.getCause(); - if (cause instanceof SecurityError) { - throw new RuntimeException(cause); + try (SecureBuffer secureNonce = SecureMemory.secureByteArray(nonce)) { + Cipher cipher; + try { + cipher = Cipher.getInstance("AES/GCM/NoPadding"); + cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(Constants.GCM_TAG_SIZE * Byte.SIZE, secureNonce.getData().asByteBuffer().array())); + } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException | + InvalidKeyException e) { + throw new RuntimeException(new SecurityError("AES-GCM cipher initialization failed: " + e.getMessage(), e)); + } + + byte[] ciphertext; + + try { + ciphertext = cipher.doFinal(securePayload.getData().asByteBuffer().array()); + } catch (IllegalBlockSizeException | BadPaddingException e) { + throw new RuntimeException(new SecurityError("AES-GCM encryption failed: " + e.getMessage(), e)); + } + + String serverPEM; + + try { + serverPEM = fetchServerPublicKey().get(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(new SecurityError("Encryption interrupted while fetching server public key", e)); + } catch (ExecutionException e) { + Throwable cause = e.getCause(); + if (cause instanceof SecurityError) { + throw new RuntimeException(cause); + } + throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause)); + } + + X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM)); + + PublicKey serverPublicKey; + + try { + serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec); + } catch (InvalidKeySpecException | NoSuchAlgorithmException e) { + throw new RuntimeException(new SecurityError("RSA key factory failed to parse server public key: " + e.getMessage(), e)); + } + + Cipher rsa; + try { + OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1 + PSource.PSpecified.DEFAULT); + + rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding"); + rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey, oaepParams); + } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | + InvalidAlgorithmParameterException e) { + throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e)); + } + + byte[] encryptedAESKey; + + try (SecureBuffer secureAesKeyEncoded = SecureMemory.secureByteArray(aesKey.getEncoded())) { + encryptedAESKey = rsa.doFinal(secureAesKeyEncoded.getData().asByteBuffer().array()); + } catch (IllegalBlockSizeException | BadPaddingException e) { + throw new RuntimeException(new SecurityError("RSA-OAEP key wrapping failed: " + e.getMessage(), e)); + } + + String encryptedAESKeyB64 = Base64.getEncoder().encodeToString(encryptedAESKey); + + try (SecureBuffer secureEncryptedAESKey = SecureMemory.secureByteArray(encryptedAESKey)) { + byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length); + + byte[] actualCiphertext = Arrays.copyOf(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE); + + EncryptedRequest request = new EncryptedRequest(); + + request.setVersion(Constants.PROTOCOL_VERSION); + request.setAlgorithm(Constants.HYBRID_ALGORITHM); + request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(actualCiphertext), encryptedAESKeyB64, Base64.getEncoder().encodeToString(tag))); + request.setEncryptedAESKey(encryptedAESKeyB64); + request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM); + request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM); + + Arrays.fill(encryptedAESKey, (byte) 0); + + return request.toJson().getBytes(StandardCharsets.UTF_8); } - throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause)); } - - X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM)); - - PublicKey serverPublicKey; - - try { - serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec); - } catch (InvalidKeySpecException | NoSuchAlgorithmException e) { - throw new RuntimeException(new SecurityError("RSA key factory failed to parse server public key: " + e.getMessage(), e)); - } - - Cipher rsa; - try { - OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1 - PSource.PSpecified.DEFAULT); - - rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding"); - rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey, oaepParams); - } catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException | - InvalidAlgorithmParameterException e) { - throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e)); - } - - byte[] encryptedAESKey; - - try { - encryptedAESKey = rsa.doFinal(aesKey.getEncoded()); - } catch (IllegalBlockSizeException | BadPaddingException e) { - throw new RuntimeException(new SecurityError("RSA-OAEP key wrapping failed: " + e.getMessage(), e)); - } - - byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length); - - byte[] actualCiphertext = Arrays.copyOf(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE); - - EncryptedRequest request = new EncryptedRequest(); - - request.setVersion(Constants.PROTOCOL_VERSION); - request.setAlgorithm(Constants.HYBRID_ALGORITHM); - request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(actualCiphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag))); - request.setEncryptedAESKey(Base64.getEncoder().encodeToString(encryptedAESKey)); - request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM); - request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM); - - return request.toJson().getBytes(StandardCharsets.UTF_8); }); } @@ -761,51 +771,64 @@ public class SecureCompletionClient { // Decrypt AES key with private key byte[] encryptedAESKey = Base64.getDecoder().decode(packageJson.get("encrypted_aes_key").getAsString()); - OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1 - PSource.PSpecified.DEFAULT); + try (SecureBuffer secureEncryptedAESKey = SecureMemory.secureByteArray(encryptedAESKey)) { + OAEPParameterSpec oaepParams = new OAEPParameterSpec("SHA-256", "MGF1", new MGF1ParameterSpec("SHA-256"), // Must match server: SHA-256, NOT SHA-1 + PSource.PSpecified.DEFAULT); - Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding"); - rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey, oaepParams); - byte[] aesKeyBytes = rsaCipher.doFinal(encryptedAESKey); - SecretKeySpec aesKey = new SecretKeySpec(aesKeyBytes, "AES"); + Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding"); + rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey, oaepParams); + byte[] aesKeyBytes = rsaCipher.doFinal(secureEncryptedAESKey.getData().asByteBuffer().array()); - // Decrypt payload - byte[] ciphertext = Base64.getDecoder().decode(encryptedPayload.get("ciphertext").getAsString()); - byte[] nonce = Base64.getDecoder().decode(encryptedPayload.get("nonce").getAsString()); - byte[] tag = Base64.getDecoder().decode(encryptedPayload.get("tag").getAsString()); + try (SecureBuffer secureAesKeyBytes = SecureMemory.secureByteArray(aesKeyBytes)) { + SecretKeySpec aesKey = new SecretKeySpec(secureAesKeyBytes.getData().asByteBuffer().array(), "AES"); - Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding"); - aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce)); + // Decrypt payload + byte[] ciphertext = Base64.getDecoder().decode(encryptedPayload.get("ciphertext").getAsString()); + byte[] nonce = Base64.getDecoder().decode(encryptedPayload.get("nonce").getAsString()); + byte[] tag = Base64.getDecoder().decode(encryptedPayload.get("tag").getAsString()); - // Combine ciphertext (without tag) and tag for decryption - byte[] ciphertextWithTag = new byte[ciphertext.length + tag.length]; - System.arraycopy(ciphertext, 0, ciphertextWithTag, 0, ciphertext.length); - System.arraycopy(tag, 0, ciphertextWithTag, ciphertext.length, tag.length); + try (SecureBuffer secureCiphertext = SecureMemory.secureByteArray(ciphertext); SecureBuffer secureNonce = SecureMemory.secureByteArray(nonce); SecureBuffer secureTag = SecureMemory.secureByteArray(tag)) { - byte[] plaintextBytes = aesCipher.doFinal(ciphertextWithTag); + Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding"); + aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, secureNonce.getData().asByteBuffer().array())); - // Parse JSON response - Map response; - try { - Object parsed = gson.fromJson(new String(plaintextBytes, StandardCharsets.UTF_8), Object.class); - @SuppressWarnings("unchecked") Map resultMap = (Map) parsed; - response = resultMap != null ? resultMap : new HashMap<>(); - } catch (Exception e) { - throw new CompletionException(new ValueError("Decrypted response is not valid JSON: " + e.getMessage())); + // Combine ciphertext (without tag) and tag for decryption using SecureBuffer + try (SecureBuffer secureCiphertextWithTag = SecureMemory.secureByteArray(new byte[ciphertext.length + tag.length])) { + secureCiphertextWithTag.getData().asByteBuffer().put(secureCiphertext.getData().asByteBuffer().array()); + secureCiphertextWithTag.getData().asByteBuffer().put(secureTag.getData().asByteBuffer().array()); + + byte[] plaintextBytes = aesCipher.doFinal(secureCiphertextWithTag.getData().asByteBuffer().array()); + + // Parse JSON response + Map response; + try (SecureBuffer securePlaintext = SecureMemory.secureByteArray(plaintextBytes)) { + Object parsed = gson.fromJson(new String(securePlaintext.getData().asByteBuffer().array(), StandardCharsets.UTF_8), Object.class); + @SuppressWarnings("unchecked") Map resultMap = (Map) parsed; + response = resultMap != null ? resultMap : new HashMap<>(); + } + + // Add metadata + if (!response.containsKey("_metadata")) { + response.put("_metadata", new HashMap()); + } + @SuppressWarnings("unchecked") Map metadata = (Map) response.get("_metadata"); + metadata.put("payload_id", payloadId); + metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null); + metadata.put("is_encrypted", true); + metadata.put("encryption_algorithm", Constants.HYBRID_ALGORITHM); + + Arrays.fill(ciphertext, (byte) 0); + Arrays.fill(nonce, (byte) 0); + Arrays.fill(tag, (byte) 0); + Arrays.fill(plaintextBytes, (byte) 0); + Arrays.fill(aesKeyBytes, (byte) 0); + Arrays.fill(encryptedAESKey, (byte) 0); + + return response; + } + } + } } - - // Add metadata - if (!response.containsKey("_metadata")) { - response.put("_metadata", new HashMap()); - } - @SuppressWarnings("unchecked") Map metadata = (Map) response.get("_metadata"); - metadata.put("payload_id", payloadId); - metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null); - metadata.put("is_encrypted", true); - metadata.put("encryption_algorithm", Constants.HYBRID_ALGORITHM); - - return response; - } catch (Exception e) { throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed: " + e.getMessage(), e)); } diff --git a/src/main/java/ai/nomyo/SecureMemory.java b/src/main/java/ai/nomyo/SecureMemory.java index 164e011..b8ba7c2 100644 --- a/src/main/java/ai/nomyo/SecureMemory.java +++ b/src/main/java/ai/nomyo/SecureMemory.java @@ -3,8 +3,10 @@ package ai.nomyo; import lombok.Getter; import lombok.Setter; -import java.lang.foreign.Arena; -import java.lang.foreign.MemorySegment; +import java.lang.foreign.*; +import java.lang.invoke.MethodHandle; +import java.nio.ByteBuffer; +import java.util.Arrays; import java.util.Map; /** @@ -21,12 +23,49 @@ public final class SecureMemory { @Setter private static volatile boolean secureMemoryEnabled = true; + private static final MethodHandle MLOCK_HANDLE; + private static final MethodHandle MUNLOCK_HANDLE; + static { boolean locking = false; boolean zeroing = false; + MethodHandle mlockHandle = null; + MethodHandle munlockHandle = null; try { - locking = initMemoryLocking(); + Linker linker = Linker.nativeLinker(); + SymbolLookup stdLib = linker.defaultLookup(); + + var mlockOpt = stdLib.find("mlock"); + var munlockOpt = stdLib.find("munlock"); + + if (mlockOpt.isPresent() && munlockOpt.isPresent()) { + MemorySegment mlockAddr = mlockOpt.get(); + MemorySegment munlockAddr = munlockOpt.get(); + FunctionDescriptor mlockDesc = FunctionDescriptor.of( + ValueLayout.JAVA_INT, + ValueLayout.ADDRESS, + ValueLayout.JAVA_LONG); + + FunctionDescriptor munlockDesc = FunctionDescriptor.of( + ValueLayout.JAVA_INT, + ValueLayout.ADDRESS, + ValueLayout.JAVA_LONG); + + mlockHandle = linker.downcallHandle(mlockAddr, mlockDesc); + munlockHandle = linker.downcallHandle(munlockAddr, munlockDesc); + + locking = true; + } + } catch (Throwable t) { + // Degrade gracefully + } + + MLOCK_HANDLE = mlockHandle; + MUNLOCK_HANDLE = munlockHandle; + + try { + locking = initMemoryLocking(locking); zeroing = true; // Secure zeroing is always available at the JVM level } catch (Throwable t) { // Degrade gracefully @@ -36,9 +75,18 @@ public final class SecureMemory { HAS_SECURE_ZEROING = zeroing; } - private static boolean initMemoryLocking() { - // FFM doesn't support memory locking at this point in time - return false; + private static boolean initMemoryLocking(boolean preCheck) { + if (MLOCK_HANDLE == null || MUNLOCK_HANDLE == null || !preCheck) { + return false; + } + + try (Arena arena = Arena.ofConfined()) { + MemorySegment testSegment = arena.allocate(1); + long result = (long) MLOCK_HANDLE.invokeExact(testSegment, (long) 1); + return result == 0; + } catch (Throwable t) { + return false; + } } /** @@ -93,7 +141,7 @@ public final class SecureMemory { private boolean closed; /** - * @param data byte array to wrap + * @param data byte array to wrap (zeroed after copy to off-heap memory) * @param lock whether to attempt memory locking */ public SecureBuffer(byte[] data, boolean lock) { @@ -102,6 +150,7 @@ public final class SecureMemory { if (data != null) { this.data.asByteBuffer().put(data); + Arrays.fill(data, (byte) 0); } this.size = this.data.byteSize(); @@ -119,16 +168,40 @@ public final class SecureMemory { * Locks buffer in memory (prevents disk swapping). Returns false if unavailable. */ public boolean lock() { - return false; + if (this.locked || this.address == 0) { + return this.locked; + } + + try { + long result = (long) MLOCK_HANDLE.invokeExact( + MemorySegment.ofAddress(this.address), + this.size); + this.locked = result == 0; + } catch (Throwable t) { + this.locked = false; + } + + return this.locked; } /** * Unlocks buffer (allows disk swapping). */ public boolean unlock() { - if (!locked) return false; - locked = false; - return false; + if (!locked || this.address == 0) { + return false; + } + + try { + long result = (long) MUNLOCK_HANDLE.invokeExact( + MemorySegment.ofAddress(this.address), + this.size); + locked = result != 0; + return result == 0; + } catch (Throwable t) { + locked = true; + return false; + } } /** diff --git a/src/main/java/ai/nomyo/util/Pass2Key.java b/src/main/java/ai/nomyo/util/Pass2Key.java index 522ccaf..7100d24 100644 --- a/src/main/java/ai/nomyo/util/Pass2Key.java +++ b/src/main/java/ai/nomyo/util/Pass2Key.java @@ -1,5 +1,7 @@ package ai.nomyo.util; +import ai.nomyo.SecureMemory; +import ai.nomyo.SecureMemory.SecureBuffer; import ai.nomyo.errors.SecurityError; import javax.crypto.BadPaddingException; @@ -11,6 +13,7 @@ import javax.crypto.SecretKeyFactory; import javax.crypto.spec.GCMParameterSpec; import javax.crypto.spec.PBEKeySpec; import javax.crypto.spec.SecretKeySpec; +import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.security.*; import java.security.spec.InvalidKeySpecException; @@ -50,12 +53,20 @@ public final class Pass2Key { if (isGcmMode(algorithm)) { byte[] iv = new byte[GCM_IV_LENGTH]; RANDOM.nextBytes(iv); - GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH, iv); - byte[] ciphertext = encryptWithCipher(algorithm, key, spec, input); - payload = assemblePayloadGcm(salt, iv, ciphertext); + try (SecureBuffer secureSalt = SecureMemory.secureByteArray(salt); SecureBuffer secureIv = SecureMemory.secureByteArray(iv)) { + GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH, secureIv.getData().asByteBuffer().array()); + byte[] ciphertext = encryptWithCipher(algorithm, key, spec, input); + try (SecureBuffer secureCiphertext = SecureMemory.secureByteArray(ciphertext)) { + payload = assemblePayloadGcm(secureSalt.getData().asByteBuffer().array(), secureIv.getData().asByteBuffer().array(), secureCiphertext.getData().asByteBuffer().array()); + } + } } else { - byte[] ciphertext = encryptWithCipher(algorithm, key, input); - payload = assemblePayloadSalt(salt, ciphertext); + try (SecureBuffer secureSalt = SecureMemory.secureByteArray(salt)) { + byte[] ciphertext = encryptWithCipher(algorithm, key, input); + try (SecureBuffer secureCiphertext = SecureMemory.secureByteArray(ciphertext)) { + payload = assemblePayloadSalt(secureSalt.getData().asByteBuffer().array(), secureCiphertext.getData().asByteBuffer().array()); + } + } } return Base64.getEncoder().encodeToString(payload); @@ -76,15 +87,25 @@ public final class Pass2Key { byte[] salt = java.util.Arrays.copyOfRange(decoded, 0, SALT_LENGTH); SecretKey key = deriveKey(password, salt); + String result; if (isGcmMode(algorithm)) { byte[] iv = java.util.Arrays.copyOfRange(decoded, SALT_LENGTH, SALT_LENGTH + GCM_IV_LENGTH); - GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH, iv); byte[] ciphertext = java.util.Arrays.copyOfRange(decoded, SALT_LENGTH + GCM_IV_LENGTH, decoded.length); - return decryptWithCipher(algorithm, key, spec, ciphertext); + try (SecureBuffer secureSalt = SecureMemory.secureByteArray(salt); SecureBuffer secureIv = SecureMemory.secureByteArray(iv); SecureBuffer secureCiphertext = SecureMemory.secureByteArray(ciphertext)) { + GCMParameterSpec spec = new GCMParameterSpec(GCM_TAG_LENGTH, secureIv.getData().asByteBuffer().array()); + result = decryptWithCipher(algorithm, key, spec, secureCiphertext.getData().asByteBuffer().array()); + } } else { byte[] ciphertext = java.util.Arrays.copyOfRange(decoded, SALT_LENGTH, decoded.length); - return decryptWithCipher(algorithm, key, ciphertext); + try (SecureBuffer secureSalt = SecureMemory.secureByteArray(salt); SecureBuffer secureCiphertext = SecureMemory.secureByteArray(ciphertext)) { + result = decryptWithCipher(algorithm, key, secureCiphertext.getData().asByteBuffer().array()); + } } + + java.util.Arrays.fill(decoded, (byte) 0); + java.util.Arrays.fill(salt, (byte) 0); + + return result; } private static SecretKey deriveKey(String password, byte[] salt) throws SecurityError { @@ -149,13 +170,15 @@ public final class Pass2Key { // Decode the Base64-encoded private key string byte[] decodedKey = Base64.getDecoder().decode(privateKeyString); - // Create a PKCS8EncodedKeySpec object - PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(decodedKey); + try (SecureBuffer secureKey = SecureMemory.secureByteArray(decodedKey)) { + // Create a PKCS8EncodedKeySpec object + PKCS8EncodedKeySpec keySpec = new PKCS8EncodedKeySpec(secureKey.getData().asByteBuffer().array()); - // Get an instance of the KeyFactory for the desired algorithm (e.g., RSA) - KeyFactory keyFactory = KeyFactory.getInstance("RSA"); + // Get an instance of the KeyFactory for the desired algorithm (e.g., RSA) + KeyFactory keyFactory = KeyFactory.getInstance("RSA"); - // Generate the private key object - return keyFactory.generatePrivate(keySpec); + // Generate the private key object + return keyFactory.generatePrivate(keySpec); + } } } diff --git a/src/test/java/ai/nomyo/SecureMemoryTest.java b/src/test/java/ai/nomyo/SecureMemoryTest.java index 32ee5b1..790cdcc 100644 --- a/src/test/java/ai/nomyo/SecureMemoryTest.java +++ b/src/test/java/ai/nomyo/SecureMemoryTest.java @@ -166,4 +166,166 @@ class SecureMemoryTest { void hasSecureZeroing_shouldBeTrue() { assertTrue(SecureMemory.isHAS_SECURE_ZEROING(), "HAS_SECURE_ZEROING should be true"); } + + @Test + @DisplayName("initMemoryLocking should return false when mlock/munlock unavailable") + void initMemoryLocking_noSyscalls_shouldReturnFalse() { + assertFalse(SecureMemory.isHAS_MEMORY_LOCKING(), + "initMemoryLocking should return false when syscalls unavailable"); + } + + @Test + @DisplayName("SecureBuffer with lock=false should not attempt locking") + void secureBuffer_lockFalse_shouldNotLock() { + byte[] data = new byte[]{1, 2, 3}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, false)) { + assertFalse(buffer.lock(), "lock() should return false when lock=false in constructor"); + } + } + + @Test + @DisplayName("SecureBuffer with lock=true should not lock (stubbed)") + void secureBuffer_lockTrue_shouldNotLock() { + byte[] data = new byte[]{1, 2, 3}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, true)) { + assertFalse(buffer.lock(), "lock() should return false (memory locking stubbed)"); + } + } + + @Test + @DisplayName("SecureBuffer lock should return false when already failed") + void secureBuffer_lock_alreadyFailed_shouldReturnFalse() { + byte[] data = new byte[]{1, 2, 3}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, true)) { + buffer.lock(); + assertFalse(buffer.lock(), "Double lock should return false"); + } + } + + @Test + @DisplayName("SecureBuffer unlock when not locked should return false") + void secureBuffer_unlock_notLocked_shouldReturnFalse() { + byte[] data = new byte[]{1, 2, 3}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, false)) { + assertFalse(buffer.unlock(), "Unlock when not locked should return false"); + } + } + + @Test + @DisplayName("SecureBuffer unlock when already unlocked should return false") + void secureBuffer_unlock_alreadyUnlocked_shouldReturnFalse() { + byte[] data = new byte[]{1, 2, 3}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, true)) { + buffer.unlock(); + assertFalse(buffer.unlock(), "Double unlock should return false"); + } + } + + @Test + @DisplayName("SecureBuffer zero should not throw on null data") + void secureBuffer_zero_nullData_shouldNotThrow() { + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(null)) { + assertDoesNotThrow(buffer::zero, "Zero on null data should not throw"); + } + } + + @Test + @DisplayName("SecureBuffer with lock=false should not attempt locking in constructor") + void secureBuffer_constructor_lockFalse_shouldNotLock() { + byte[] data = new byte[]{10, 20, 30, 40, 50}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, false)) { + assertFalse(buffer.lock(), "lock() should return false when not locked"); + assertEquals(5, buffer.getSize(), "Size should match input data length"); + assertNotNull(buffer.getData(), "Data segment should not be null"); + } + } + + @Test + @DisplayName("SecureBuffer with secureMemoryEnabled=false should not attempt locking") + void secureBuffer_disabled_shouldNotLock() { + SecureMemory.setSecureMemoryEnabled(false); + try { + byte[] data = new byte[]{1, 2, 3}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, true)) { + assertFalse(buffer.lock(), "lock() should return false when disabled"); + } + } finally { + SecureMemory.setSecureMemoryEnabled(true); + } + } + + @Test + @DisplayName("SecureBuffer should preserve data contents") + void secureBuffer_shouldPreserveData() { + byte[] original = new byte[]{(byte) 0x00, (byte) 0xFF, (byte) 0xAA, (byte) 0x55, 42}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(original)) { + byte[] retrieved = new byte[original.length]; + buffer.getData().asByteBuffer().get(retrieved); + assertArrayEquals(original, retrieved, "Data should be preserved in buffer"); + } + } + + @Test + @DisplayName("SecureBuffer close should zero data after use") + void secureBuffer_close_shouldZeroData() { + byte[] original = new byte[]{(byte) 0xFF, (byte) 0xFF, (byte) 0xFF}; + SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(original); + + buffer.close(); + + assertDoesNotThrow(buffer::close, "Double close should not throw"); + } + + @Test + @DisplayName("getMemoryProtectionInfo enabled should reflect secureMemoryEnabled state") + void getMemoryProtectionInfo_enabled_shouldReflectState() { + SecureMemory.setSecureMemoryEnabled(false); + try { + Map info = SecureMemory.getMemoryProtectionInfo(); + assertEquals(false, info.get("enabled"), "Enabled should be false when disabled"); + } finally { + SecureMemory.setSecureMemoryEnabled(true); + } + } + + @Test + @DisplayName("getMemoryProtectionInfo protection_level should be zeroing_only without locking") + void getMemoryProtectionInfo_protectionLevel_zeroingOnly() { + Map info = SecureMemory.getMemoryProtectionInfo(); + assertEquals("zeroing_only", info.get("protection_level"), + "Protection level should be zeroing_only without memory locking"); + } + + @Test + @DisplayName("SecureBuffer address should be consistent") + void secureBuffer_address_consistent() { + byte[] data = new byte[]{1, 2, 3}; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data)) { + long addr1 = buffer.getAddress(); + long addr2 = buffer.getAddress(); + assertEquals(addr1, addr2, "Address should be consistent across calls"); + } + } + + @Test + @DisplayName("SecureBuffer size should match allocated length") + void secureBuffer_size_shouldMatchLength() { + int[] sizes = {0, 1, 16, 256, 4096}; + for (int size : sizes) { + byte[] data = new byte[size]; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data)) { + assertEquals(size, buffer.getSize(), "Size should match for length=" + size); + } + } + } + + @Test + @DisplayName("SecureBuffer constructor should handle empty array") + void secureBuffer_emptyArray_shouldHandle() { + byte[] data = new byte[0]; + try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data)) { + assertEquals(0, buffer.getSize(), "Size should be 0 for empty array"); + assertNotNull(buffer.getData(), "Data segment should not be null"); + } + } }