Add new tests

This commit is contained in:
Oracle 2026-04-26 19:06:38 +02:00
parent 89d5282b0f
commit 675418f411
Signed by: Oracle
SSH key fingerprint: SHA256:x4/RtnjUyuHkdvmwNDsWSfcfF1V5PNr3OpriZqOvCX8
9 changed files with 1057 additions and 13 deletions

View file

@ -1,11 +1,7 @@
package ai.nomyo;
import ai.nomyo.errors.APIConnectionError;
import ai.nomyo.errors.SecurityError;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
/**
* Entry point loads RSA keys and validates key length.

View file

@ -3,7 +3,6 @@ package ai.nomyo;
import ai.nomyo.errors.*;
import lombok.Getter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
@ -69,7 +68,7 @@ public class SecureChatCompletion {
* @throws ServiceUnavailableError HTTP 503
* @throws APIError other errors
*/
@SuppressWarnings("unchecked")
@SuppressWarnings({"JavadocDeclaration"})
public Map<String, Object> create(String model, List<Map<String, Object>> messages, Map<String, Object> kwargs) {
// Validate required parameters
if (model == null || model.isEmpty()) {

View file

@ -4,7 +4,6 @@ import ai.nomyo.errors.*;
import ai.nomyo.util.PEMConverter;
import ai.nomyo.util.Pass2Key;
import lombok.Getter;
import lombok.Setter;
import javax.crypto.*;
import javax.crypto.spec.GCMParameterSpec;
@ -23,7 +22,6 @@ import java.nio.file.attribute.PosixFilePermissions;
import java.security.*;
import java.security.spec.*;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
@ -172,8 +170,6 @@ public class SecureCompletionClient {
this.privateKey = pair.getPrivate();
this.publicPemKey = publicPem;
} catch (SecurityError e) {
throw e;
} catch (NoSuchAlgorithmException e) {
throw new SecurityError("RSA algorithm not available: " + e.getMessage(), e);
} catch (InvalidAlgorithmParameterException e) {
@ -329,7 +325,7 @@ public class SecureCompletionClient {
byte[] nonce = new byte[Constants.GCM_NONCE_SIZE];
random.nextBytes(nonce);
Cipher cipher = null;
Cipher cipher;
try {
cipher = Cipher.getInstance("AES/GCM/NoPadding");
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(Constants.GCM_TAG_SIZE * Byte.SIZE, nonce));
@ -356,7 +352,7 @@ public class SecureCompletionClient {
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof SecurityError) {
throw new RuntimeException((SecurityError) cause);
throw new RuntimeException(cause);
}
throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause));
}
@ -423,6 +419,7 @@ public class SecureCompletionClient {
* @throws ServiceUnavailableError HTTP 503
* @throws APIError other errors
*/
@SuppressWarnings("JavadocDeclaration")
public CompletableFuture<Map<String, Object>> sendSecureRequest(Map<String, Object> payload, String payloadId, String apiKey, String securityTier) {
return CompletableFuture.supplyAsync(() -> {
// Validate security tier if provided
@ -447,7 +444,7 @@ public class SecureCompletionClient {
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof SecurityError) {
throw new CompletionException((SecurityError) cause);
throw new CompletionException(cause);
}
throw new CompletionException(new SecurityError("Encryption failed: " + cause.getMessage(), cause));
}

View file

@ -0,0 +1,65 @@
package ai.nomyo;
import ai.nomyo.errors.SecurityError;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import static org.junit.jupiter.api.Assertions.*;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@Execution(ExecutionMode.CONCURRENT)
class CloseTest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("close should clear private key from memory")
void close_shouldClearPrivateKey() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
assertNotNull(client.getPrivateKey(), "Private key should exist before close");
client.close();
assertNull(client.getPrivateKey(), "Private key should be null after close");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("close should clear public key from memory")
void close_shouldClearPublicPemKey() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
assertNotNull(client.getPublicPemKey(), "Public PEM should exist before close");
client.close();
assertNull(client.getPublicPemKey(), "Public PEM should be null after close");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("close should reset keysInitialized flag")
void close_shouldResetKeysInitialized() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
client.close();
assertDoesNotThrow(() -> client.ensureKeys(null),
"ensureKeys should work after close");
assertNotNull(client.getPrivateKey(), "Private key should be regenerated after close + ensureKeys");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("close should be safe to call multiple times")
void close_multipleCalls_shouldNotThrow() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
assertDoesNotThrow(() -> client.close(), "First close should not throw");
assertDoesNotThrow(() -> client.close(), "Second close should not throw");
assertDoesNotThrow(() -> client.close(), "Third close should not throw");
}
}

View file

@ -0,0 +1,349 @@
package ai.nomyo;
import ai.nomyo.errors.SecurityError;
import ai.nomyo.util.PEMConverter;
import com.google.gson.JsonObject;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import javax.crypto.Cipher;
import javax.crypto.KeyGenerator;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import java.nio.charset.StandardCharsets;
import java.security.*;
import java.util.Base64;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutionException;
import java.security.spec.X509EncodedKeySpec;
import static org.junit.jupiter.api.Assertions.*;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@Execution(ExecutionMode.CONCURRENT)
class DecryptResponseTest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should decrypt valid encrypted package")
void decryptResponse_validPackage_shouldReturnDecryptedMap() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
PrivateKey privateKey = client.getPrivateKey();
String plaintext = "{\"content\":\"Hello, world!\",\"role\":\"assistant\"}";
byte[] plaintextBytes = plaintext.getBytes(StandardCharsets.UTF_8);
KeyGenerator keyGen = KeyGenerator.getInstance("AES");
keyGen.init(Constants.AES_KEY_SIZE * 8);
SecretKey aesKey = keyGen.generateKey();
SecureRandom random = new SecureRandom();
byte[] nonce = new byte[Constants.GCM_NONCE_SIZE];
random.nextBytes(nonce);
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
aesCipher.init(Cipher.ENCRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
byte[] ciphertextWithTag = aesCipher.doFinal(plaintextBytes);
byte[] ciphertext = java.util.Arrays.copyOfRange(ciphertextWithTag, 0, ciphertextWithTag.length - Constants.GCM_TAG_SIZE);
byte[] tag = java.util.Arrays.copyOfRange(ciphertextWithTag, ciphertextWithTag.length - Constants.GCM_TAG_SIZE, ciphertextWithTag.length);
String publicPem = client.getPublicPemKey();
byte[] pubKeyBytes = PEMConverter.fromPEM(publicPem);
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(pubKeyBytes);
PublicKey serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec);
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsaCipher.init(Cipher.ENCRYPT_MODE, serverPublicKey);
byte[] encryptedAESKey = rsaCipher.doFinal(aesKey.getEncoded());
JsonObject packageJson = new JsonObject();
packageJson.addProperty("version", Constants.PROTOCOL_VERSION);
packageJson.addProperty("algorithm", Constants.HYBRID_ALGORITHM);
packageJson.addProperty("processed_at", "2024-01-01T00:00:00Z");
JsonObject encryptedPayload = new JsonObject();
encryptedPayload.addProperty("ciphertext", Base64.getEncoder().encodeToString(ciphertext));
encryptedPayload.addProperty("nonce", Base64.getEncoder().encodeToString(nonce));
encryptedPayload.addProperty("tag", Base64.getEncoder().encodeToString(tag));
packageJson.add("encrypted_payload", encryptedPayload);
packageJson.addProperty("encrypted_aes_key", Base64.getEncoder().encodeToString(encryptedAESKey));
byte[] encryptedResponse = packageJson.toString().getBytes(StandardCharsets.UTF_8);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(encryptedResponse, "test-payload-id");
Map<String, Object> result = future.get();
assertNotNull(result, "Result should not be null");
assertEquals("Hello, world!", result.get("content"), "Content should match");
assertEquals("assistant", result.get("role"), "Role should match");
assertNotNull(result.get("_metadata"), "Metadata should be present");
@SuppressWarnings("unchecked")
Map<String, Object> metadata = (Map<String, Object>) result.get("_metadata");
assertEquals("test-payload-id", metadata.get("payload_id"));
assertEquals(true, metadata.get("is_encrypted"));
assertEquals(Constants.HYBRID_ALGORITHM, metadata.get("encryption_algorithm"));
assertEquals("2024-01-01T00:00:00Z", metadata.get("processed_at"));
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should handle package without processed_at")
void decryptResponse_missingProcessedAt_shouldSetNull() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
PrivateKey privateKey = client.getPrivateKey();
String plaintext = "{\"response\":\"ok\"}";
byte[] plaintextBytes = plaintext.getBytes(StandardCharsets.UTF_8);
KeyGenerator keyGen = KeyGenerator.getInstance("AES");
keyGen.init(Constants.AES_KEY_SIZE * 8);
SecretKey aesKey = keyGen.generateKey();
byte[] nonce = new byte[Constants.GCM_NONCE_SIZE];
new SecureRandom().nextBytes(nonce);
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
aesCipher.init(Cipher.ENCRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
byte[] ciphertextWithTag = aesCipher.doFinal(plaintextBytes);
byte[] ciphertext = java.util.Arrays.copyOfRange(ciphertextWithTag, 0, ciphertextWithTag.length - Constants.GCM_TAG_SIZE);
byte[] tag = java.util.Arrays.copyOfRange(ciphertextWithTag, ciphertextWithTag.length - Constants.GCM_TAG_SIZE, ciphertextWithTag.length);
String publicPem = client.getPublicPemKey();
byte[] pubKeyBytes = PEMConverter.fromPEM(publicPem);
PublicKey serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(new X509EncodedKeySpec(pubKeyBytes));
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsaCipher.init(Cipher.ENCRYPT_MODE, serverPublicKey);
byte[] encryptedAESKey = rsaCipher.doFinal(aesKey.getEncoded());
JsonObject packageJson = new JsonObject();
packageJson.addProperty("version", Constants.PROTOCOL_VERSION);
packageJson.addProperty("algorithm", Constants.HYBRID_ALGORITHM);
JsonObject encryptedPayload = new JsonObject();
encryptedPayload.addProperty("ciphertext", Base64.getEncoder().encodeToString(ciphertext));
encryptedPayload.addProperty("nonce", Base64.getEncoder().encodeToString(nonce));
encryptedPayload.addProperty("tag", Base64.getEncoder().encodeToString(tag));
packageJson.add("encrypted_payload", encryptedPayload);
packageJson.addProperty("encrypted_aes_key", Base64.getEncoder().encodeToString(encryptedAESKey));
byte[] encryptedResponse = packageJson.toString().getBytes(StandardCharsets.UTF_8);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(encryptedResponse, "payload-2");
Map<String, Object> result = future.get();
assertNotNull(result);
@SuppressWarnings("unchecked")
Map<String, Object> metadata = (Map<String, Object>) result.get("_metadata");
assertNull(metadata.get("processed_at"), "processed_at should be null when not present");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should throw ValueError for empty response")
void decryptResponse_emptyResponse_shouldThrowValueError() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(new byte[0], "test-id");
ExecutionException error = assertThrows(ExecutionException.class, future::get);
assertTrue(error.getCause() instanceof SecureCompletionClient.ValueError,
"Should throw ValueError for empty response");
assertTrue(error.getCause().getMessage().contains("Empty"),
"Error message should mention empty");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should throw ValueError for null response")
void decryptResponse_nullResponse_shouldThrowValueError() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(null, "test-id");
ExecutionException error = assertThrows(ExecutionException.class, future::get);
assertTrue(error.getCause() instanceof SecureCompletionClient.ValueError,
"Should throw ValueError for null response");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should throw ValueError for malformed JSON")
void decryptResponse_malformedJson_shouldThrowValueError() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
byte[] invalidJson = "not valid json at all".getBytes(StandardCharsets.UTF_8);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(invalidJson, "test-id");
ExecutionException error = assertThrows(ExecutionException.class, future::get);
assertTrue(error.getCause() instanceof SecureCompletionClient.ValueError,
"Should throw ValueError for malformed JSON");
assertTrue(error.getCause().getMessage().contains("malformed JSON") || error.getCause().getMessage().contains("JSON"),
"Error message should mention JSON");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should throw ValueError for missing required fields")
void decryptResponse_missingField_shouldThrowValueError() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
JsonObject packageJson = new JsonObject();
packageJson.addProperty("version", Constants.PROTOCOL_VERSION);
packageJson.addProperty("extra_field", "value");
byte[] encryptedResponse = packageJson.toString().getBytes(StandardCharsets.UTF_8);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(encryptedResponse, "test-id");
ExecutionException error = assertThrows(ExecutionException.class, future::get);
assertTrue(error.getCause() instanceof SecureCompletionClient.ValueError,
"Should throw ValueError for missing fields");
assertTrue(error.getCause().getMessage().contains("Missing required fields"),
"Error message should mention missing fields");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should throw ValueError for wrong protocol version")
void decryptResponse_wrongVersion_shouldThrowValueError() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
JsonObject packageJson = new JsonObject();
packageJson.addProperty("version", "9.9");
packageJson.addProperty("algorithm", Constants.HYBRID_ALGORITHM);
packageJson.addProperty("encrypted_aes_key", "dGVzdA==");
JsonObject encryptedPayload = new JsonObject();
encryptedPayload.addProperty("ciphertext", "dGVzdA==");
encryptedPayload.addProperty("nonce", "dGVzdA==");
encryptedPayload.addProperty("tag", "dGVzdA==");
packageJson.add("encrypted_payload", encryptedPayload);
byte[] encryptedResponse = packageJson.toString().getBytes(StandardCharsets.UTF_8);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(encryptedResponse, "test-id");
ExecutionException error = assertThrows(ExecutionException.class, future::get);
assertTrue(error.getCause() instanceof SecureCompletionClient.ValueError,
"Should throw ValueError for wrong version");
assertTrue(error.getCause().getMessage().contains("Unsupported protocol version"),
"Error message should mention unsupported version");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should throw ValueError for wrong algorithm")
void decryptResponse_wrongAlgorithm_shouldThrowValueError() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
JsonObject packageJson = new JsonObject();
packageJson.addProperty("version", Constants.PROTOCOL_VERSION);
packageJson.addProperty("algorithm", "wrong-algorithm");
packageJson.addProperty("encrypted_aes_key", "dGVzdA==");
JsonObject encryptedPayload = new JsonObject();
encryptedPayload.addProperty("ciphertext", "dGVzdA==");
encryptedPayload.addProperty("nonce", "dGVzdA==");
encryptedPayload.addProperty("tag", "dGVzdA==");
packageJson.add("encrypted_payload", encryptedPayload);
byte[] encryptedResponse = packageJson.toString().getBytes(StandardCharsets.UTF_8);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(encryptedResponse, "test-id");
ExecutionException error = assertThrows(ExecutionException.class, future::get);
assertTrue(error.getCause() instanceof SecureCompletionClient.ValueError,
"Should throw ValueError for wrong algorithm");
assertTrue(error.getCause().getMessage().contains("Unsupported encryption algorithm"),
"Error message should mention unsupported algorithm");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should throw SecurityError when private key not initialized")
void decryptResponse_noPrivateKey_shouldThrowSecurityError() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
JsonObject packageJson = new JsonObject();
packageJson.addProperty("version", Constants.PROTOCOL_VERSION);
packageJson.addProperty("algorithm", Constants.HYBRID_ALGORITHM);
packageJson.addProperty("encrypted_aes_key", "dGVzdA==");
JsonObject encryptedPayload = new JsonObject();
encryptedPayload.addProperty("ciphertext", "dGVzdA==");
encryptedPayload.addProperty("nonce", "dGVzdA==");
encryptedPayload.addProperty("tag", "dGVzdA==");
packageJson.add("encrypted_payload", encryptedPayload);
byte[] encryptedResponse = packageJson.toString().getBytes(StandardCharsets.UTF_8);
CompletableFuture<Map<String, Object>> future = client.decryptResponse(encryptedResponse, "test-id");
ExecutionException error = assertThrows(ExecutionException.class, future::get);
assertTrue(error.getCause() instanceof SecurityError,
"Should throw SecurityError when no private key");
assertTrue(error.getCause().getMessage().contains("Private key not initialized"),
"Error message should mention private key not initialized");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("decryptResponse should throw SecurityError for wrong private key")
void decryptResponse_wrongPrivateKey_shouldThrowSecurityError() throws Exception {
SecureCompletionClient client1 = new SecureCompletionClient();
client1.generateKeys(false);
SecureCompletionClient client2 = new SecureCompletionClient();
client2.generateKeys(false);
String plaintext = "{\"data\":\"secret\"}";
byte[] plaintextBytes = plaintext.getBytes(StandardCharsets.UTF_8);
KeyGenerator keyGen = KeyGenerator.getInstance("AES");
keyGen.init(Constants.AES_KEY_SIZE * 8);
SecretKey aesKey = keyGen.generateKey();
byte[] nonce = new byte[Constants.GCM_NONCE_SIZE];
new SecureRandom().nextBytes(nonce);
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
aesCipher.init(Cipher.ENCRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
byte[] ciphertextWithTag = aesCipher.doFinal(plaintextBytes);
byte[] ciphertext = java.util.Arrays.copyOfRange(ciphertextWithTag, 0, ciphertextWithTag.length - Constants.GCM_TAG_SIZE);
byte[] tag = java.util.Arrays.copyOfRange(ciphertextWithTag, ciphertextWithTag.length - Constants.GCM_TAG_SIZE, ciphertextWithTag.length);
String publicPem = client1.getPublicPemKey();
byte[] pubKeyBytes = PEMConverter.fromPEM(publicPem);
PublicKey serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(new X509EncodedKeySpec(pubKeyBytes));
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsaCipher.init(Cipher.ENCRYPT_MODE, serverPublicKey);
byte[] encryptedAESKey = rsaCipher.doFinal(aesKey.getEncoded());
JsonObject packageJson = new JsonObject();
packageJson.addProperty("version", Constants.PROTOCOL_VERSION);
packageJson.addProperty("algorithm", Constants.HYBRID_ALGORITHM);
JsonObject encryptedPayload = new JsonObject();
encryptedPayload.addProperty("ciphertext", Base64.getEncoder().encodeToString(ciphertext));
encryptedPayload.addProperty("nonce", Base64.getEncoder().encodeToString(nonce));
encryptedPayload.addProperty("tag", Base64.getEncoder().encodeToString(tag));
packageJson.add("encrypted_payload", encryptedPayload);
packageJson.addProperty("encrypted_aes_key", Base64.getEncoder().encodeToString(encryptedAESKey));
byte[] encryptedResponse = packageJson.toString().getBytes(StandardCharsets.UTF_8);
CompletableFuture<Map<String, Object>> future = client2.decryptResponse(encryptedResponse, "test-id");
ExecutionException error = assertThrows(ExecutionException.class, future::get);
assertTrue(error.getCause() instanceof SecurityError,
"Should throw SecurityError for wrong private key");
}
}

View file

@ -0,0 +1,99 @@
package ai.nomyo;
import ai.nomyo.errors.SecurityError;
import org.junit.jupiter.api.*;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import java.io.File;
import java.security.PrivateKey;
import static org.junit.jupiter.api.Assertions.*;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
@Execution(ExecutionMode.CONCURRENT)
class EnsureKeysTest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("ensureKeys should generate keys when not initialized")
void ensureKeys_notInitialized_shouldGenerateKeys() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
assertNull(client.getPrivateKey(), "Private key should be null initially");
client.ensureKeys(null);
assertNotNull(client.getPrivateKey(), "Private key should be set");
assertNotNull(client.getPublicPemKey(), "Public PEM key should be set");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("ensureKeys should not re-generate keys when already initialized")
void ensureKeys_alreadyInitialized_shouldNotRegenerate() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.ensureKeys(null);
PrivateKey firstKey = client.getPrivateKey();
client.ensureKeys(null);
assertNotNull(client.getPrivateKey(), "Private key should still be set");
assertArrayEquals(firstKey.getEncoded(), client.getPrivateKey().getEncoded(),
"Same key should be retained after second ensureKeys");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("ensureKeys with keyDir should save keys to default directory")
void ensureKeys_withKeyDir_shouldSaveToDefaultDir() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.ensureKeys("/tmp/nomyo_test_keys");
assertNotNull(client.getPrivateKey());
File keyDir = new File(Constants.DEFAULT_KEY_DIR);
File privateKeyFile = new File(keyDir, Constants.DEFAULT_PRIVATE_KEY_FILE);
File publicKeyFile = new File(keyDir, Constants.DEFAULT_PUBLIC_KEY_FILE);
assertTrue(privateKeyFile.exists(), "Private key file should be created in default dir");
assertTrue(publicKeyFile.exists(), "Public key file should be created in default dir");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("ensureKeys should be thread-safe with concurrent calls")
void ensureKeys_concurrentCalls_shouldBeThreadSafe() throws SecurityError, InterruptedException {
SecureCompletionClient client = new SecureCompletionClient();
Thread[] threads = new Thread[5];
PrivateKey[] results = new PrivateKey[5];
Exception[] errors = new Exception[5];
for (int i = 0; i < 5; i++) {
final int idx = i;
threads[i] = new Thread(() -> {
try {
client.ensureKeys(null);
results[idx] = client.getPrivateKey();
} catch (SecurityError e) {
errors[idx] = e;
}
});
threads[i].start();
}
for (Thread t : threads) {
t.join();
}
for (int i = 0; i < 5; i++) {
assertNull(errors[i], "No errors should occur in thread " + i);
assertNotNull(results[i], "Thread " + i + " should have a result");
}
for (int i = 1; i < 5; i++) {
assertArrayEquals(results[0].getEncoded(), results[i].getEncoded(),
"All threads should see the same key");
}
}
}

View file

@ -0,0 +1,285 @@
package ai.nomyo;
import ai.nomyo.errors.SecurityError;
import org.junit.jupiter.api.*;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import org.junit.jupiter.api.parallel.Execution;
import org.junit.jupiter.api.parallel.ExecutionMode;
import static org.junit.jupiter.api.Assertions.*;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
class SecureChatCompletionTest {
@Test
@DisplayName("create should throw IllegalArgumentException for null model")
void create_nullModel_shouldThrow() {
SecureChatCompletion chat = new SecureChatCompletion();
IllegalArgumentException error = assertThrows(IllegalArgumentException.class,
() -> chat.create(null, List.of(Map.of("role", "user", "content", "hi"))));
assertTrue(error.getMessage().contains("model"),
"Error should mention model");
}
@Test
@DisplayName("create should throw IllegalArgumentException for empty model")
void create_emptyModel_shouldThrow() {
SecureChatCompletion chat = new SecureChatCompletion();
IllegalArgumentException error = assertThrows(IllegalArgumentException.class,
() -> chat.create("", List.of(Map.of("role", "user", "content", "hi"))));
assertTrue(error.getMessage().contains("model"),
"Error should mention model");
}
@Test
@DisplayName("create should throw IllegalArgumentException for null messages")
void create_nullMessages_shouldThrow() {
SecureChatCompletion chat = new SecureChatCompletion();
IllegalArgumentException error = assertThrows(IllegalArgumentException.class,
() -> chat.create("gpt-4", null));
assertTrue(error.getMessage().contains("messages"),
"Error should mention messages");
}
@Test
@DisplayName("create should throw IllegalArgumentException for empty messages")
void create_emptyMessages_shouldThrow() {
SecureChatCompletion chat = new SecureChatCompletion();
IllegalArgumentException error = assertThrows(IllegalArgumentException.class,
() -> chat.create("gpt-4", List.of()));
assertTrue(error.getMessage().contains("messages"),
"Error should mention messages");
}
@Test
@DisplayName("create should throw IllegalArgumentException for stream=true")
void create_streamTrue_shouldThrow() {
SecureChatCompletion chat = new SecureChatCompletion();
Map<String, Object> kwargs = Map.of("stream", true);
IllegalArgumentException error = assertThrows(IllegalArgumentException.class,
() -> chat.create("gpt-4", List.of(Map.of("role", "user", "content", "hi")), kwargs));
assertTrue(error.getMessage().contains("Streaming"),
"Error should mention streaming");
}
@Test
@DisplayName("create should throw IllegalArgumentException for stream=\"true\"")
void create_streamStringTrue_shouldThrow() {
SecureChatCompletion chat = new SecureChatCompletion();
Map<String, Object> kwargs = Map.of("stream", "true");
IllegalArgumentException error = assertThrows(IllegalArgumentException.class,
() -> chat.create("gpt-4", List.of(Map.of("role", "user", "content", "hi")), kwargs));
assertTrue(error.getMessage().contains("Streaming"),
"Error should mention streaming");
}
@Test
@DisplayName("create should throw IllegalArgumentException for invalid security_tier")
void create_invalidSecurityTier_shouldThrow() {
SecureChatCompletion chat = new SecureChatCompletion();
Map<String, Object> kwargs = Map.of("security_tier", "invalid-tier");
IllegalArgumentException error = assertThrows(IllegalArgumentException.class,
() -> chat.create("gpt-4", List.of(Map.of("role", "user", "content", "hi")), kwargs));
assertTrue(error.getMessage().contains("Invalid security_tier"),
"Error should mention invalid security_tier");
assertTrue(error.getMessage().contains("standard") || error.getMessage().contains("high") || error.getMessage().contains("maximum"),
"Error should list valid tiers");
}
@Test
@DisplayName("create should accept valid security_tier values")
void create_validSecurityTier_shouldPassValidation() {
SecureChatCompletion chat = new SecureChatCompletion();
for (String tier : new String[]{"standard", "high", "maximum"}) {
Map<String, Object> kwargs = Map.of("security_tier", tier);
assertThrows(ExecutionException.class, () -> {
try {
chat.create("gpt-4", List.of(Map.of("role", "user", "content", "hi")), kwargs);
} catch (RuntimeException e) {
throw e.getCause() instanceof ExecutionException
? (ExecutionException) e.getCause()
: new ExecutionException(e.getCause());
}
}, "Valid security_tier should not throw IllegalArgumentException");
}
}
@Test
@DisplayName("create should pass kwargs through to payload")
void create_shouldPassKwargsThrough() {
SecureChatCompletion chat = new SecureChatCompletion();
Map<String, Object> kwargs = Map.of("temperature", 0.7, "max_tokens", 100);
assertThrows(ExecutionException.class, () -> {
try {
chat.create("gpt-4", List.of(Map.of("role", "user", "content", "hi")), kwargs);
} catch (RuntimeException e) {
Throwable cause = e.getCause();
if (cause instanceof ExecutionException) {
throw (ExecutionException) cause;
}
throw new ExecutionException(cause);
}
}, "Valid kwargs should pass validation");
}
@Test
@DisplayName("create should use kwargs api_key override over instance key")
void create_kwargsApiKey_shouldOverrideInstanceKey() {
SecureChatCompletion chat = new SecureChatCompletion("https://api.nomyo.ai", "instance-key");
Map<String, Object> kwargs = Map.of("api_key", "override-key");
assertThrows(ExecutionException.class, () -> {
try {
chat.create("gpt-4", List.of(Map.of("role", "user", "content", "hi")), kwargs);
} catch (RuntimeException e) {
Throwable cause = e.getCause();
if (cause instanceof ExecutionException) {
throw (ExecutionException) cause;
}
throw new ExecutionException(cause);
}
}, "Override api_key should pass validation");
}
@Test
@DisplayName("convenience create should delegate to full create")
void create_convenience_shouldDelegate() {
SecureChatCompletion chat = new SecureChatCompletion();
assertThrows(ExecutionException.class, () -> {
try {
chat.create("gpt-4", List.of(Map.of("role", "user", "content", "hi")));
} catch (RuntimeException e) {
Throwable cause = e.getCause();
if (cause instanceof ExecutionException) {
throw (ExecutionException) cause;
}
throw new ExecutionException(cause);
}
}, "Convenience create should delegate to full create");
}
@Test
@DisplayName("acreate should delegate to create")
void acreate_shouldDelegateToCreate() {
SecureChatCompletion chat = new SecureChatCompletion();
assertThrows(IllegalArgumentException.class,
() -> chat.acreate(null, List.of(Map.of("role", "user", "content", "hi"))));
}
@Test
@DisplayName("acreate convenience should delegate to create")
void acreate_convenience_shouldDelegate() {
SecureChatCompletion chat = new SecureChatCompletion();
assertThrows(ExecutionException.class, () -> {
try {
chat.acreate("gpt-4", List.of(Map.of("role", "user", "content", "hi")));
} catch (RuntimeException e) {
Throwable cause = e.getCause();
if (cause instanceof ExecutionException) {
throw (ExecutionException) cause;
}
throw new ExecutionException(cause);
}
}, "Convenience acreate should delegate to create");
}
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("close should delegate to client close")
void close_shouldDelegateToClient() throws SecurityError {
SecureChatCompletion chat = new SecureChatCompletion();
chat.getClient().generateKeys(false);
assertNotNull(chat.getClient().getPrivateKey(), "Client should have private key before close");
chat.close();
assertNull(chat.getClient().getPrivateKey(), "Client private key should be null after close");
}
@Test
@DisplayName("close should be safe to call multiple times")
void chatCompletion_close_multipleCalls_shouldNotThrow() {
SecureChatCompletion chat = new SecureChatCompletion();
assertDoesNotThrow(() -> chat.close(), "First close should not throw");
assertDoesNotThrow(() -> chat.close(), "Second close should not throw");
}
@Test
@DisplayName("SecureChatCompletion default constructor should use default base URL")
void chatCompletion_defaultConstructor_shouldUseDefaultUrl() {
SecureChatCompletion chat = new SecureChatCompletion();
assertEquals(Constants.DEFAULT_BASE_URL, chat.getClient().getRouterUrl());
assertNull(chat.getApiKey());
assertNull(chat.getKeyDir());
assertTrue(chat.getClient().isUseSecureMemory());
assertEquals(Constants.DEFAULT_MAX_RETRIES, chat.getClient().getMaxRetries());
}
@Test
@DisplayName("SecureChatCompletion baseUrl constructor should set url")
void chatCompletion_baseUrlConstructor_shouldSetUrl() {
SecureChatCompletion chat = new SecureChatCompletion("https://custom.api.com");
assertEquals("https://custom.api.com", chat.getClient().getRouterUrl());
assertNull(chat.getApiKey());
}
@Test
@DisplayName("SecureChatCompletion baseUrl+apiKey constructor should set both")
void chatCompletion_baseUrlApiKeyConstructor_shouldSetBoth() {
SecureChatCompletion chat = new SecureChatCompletion("https://api.nomyo.ai", "my-api-key");
assertEquals("https://api.nomyo.ai", chat.getClient().getRouterUrl());
assertEquals("my-api-key", chat.getApiKey());
}
@Test
@DisplayName("SecureChatCompletion full constructor should set all params")
void chatCompletion_fullConstructor_shouldSetAll() {
SecureChatCompletion chat = new SecureChatCompletion(
"https://custom.api.com",
true,
"test-key",
false,
"/tmp/keys",
10
);
assertEquals("https://custom.api.com", chat.getClient().getRouterUrl());
assertTrue(chat.getClient().isAllowHttp());
assertEquals("test-key", chat.getApiKey());
assertFalse(chat.getClient().isUseSecureMemory());
assertEquals("/tmp/keys", chat.getKeyDir());
assertEquals(10, chat.getClient().getMaxRetries());
}
}

View file

@ -0,0 +1,165 @@
package ai.nomyo;
import org.junit.jupiter.api.*;
import java.util.Map;
import static org.junit.jupiter.api.Assertions.*;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
class SecureMemoryTest {
@BeforeEach
void resetSecureMemoryState() {
SecureMemory.setSecureMemoryEnabled(true);
}
@Test
@DisplayName("getMemoryProtectionInfo should return enabled=true by default")
void getMemoryProtectionInfo_default_shouldBeEnabled() {
Map<String, Object> info = SecureMemory.getMemoryProtectionInfo();
assertEquals(true, info.get("enabled"), "Memory protection should be enabled by default");
}
@Test
@DisplayName("getMemoryProtectionInfo should report has_secure_zeroing=true")
void getMemoryProtectionInfo_shouldReportZeroing() {
Map<String, Object> info = SecureMemory.getMemoryProtectionInfo();
assertEquals(true, info.get("has_secure_zeroing"), "Should report secure zeroing available");
}
@Test
@DisplayName("getMemoryProtectionInfo should report has_memory_locking=false")
void getMemoryProtectionInfo_shouldReportNoLocking() {
Map<String, Object> info = SecureMemory.getMemoryProtectionInfo();
assertEquals(false, info.get("has_memory_locking"), "Should report memory locking unavailable");
}
@Test
@DisplayName("getMemoryProtectionInfo should report protection_level=zeroing_only")
void getMemoryProtectionInfo_protectionLevel_shouldBeZeroingOnly() {
Map<String, Object> info = SecureMemory.getMemoryProtectionInfo();
assertEquals("zeroing_only", info.get("protection_level"),
"Protection level should be zeroing_only (no memory locking)");
}
@Test
@DisplayName("getMemoryProtectionInfo should return page_size from Constants")
void getMemoryProtectionInfo_pageSize_shouldMatchConstants() {
Map<String, Object> info = SecureMemory.getMemoryProtectionInfo();
assertEquals(Constants.PAGE_SIZE, info.get("page_size"),
"Page size should match Constants.PAGE_SIZE");
}
@Test
@DisplayName("getMemoryProtectionInfo should report supports_full_protection=false without locking")
void getMemoryProtectionInfo_fullProtection_shouldBeFalse() {
Map<String, Object> info = SecureMemory.getMemoryProtectionInfo();
assertEquals(false, info.get("supports_full_protection"),
"Full protection should be false without memory locking");
}
@Test
@DisplayName("getMemoryProtectionInfo should report supports_full_protection=false when disabled")
void getMemoryProtectionInfo_disabled_shouldNotSupportFull() {
SecureMemory.setSecureMemoryEnabled(false);
try {
Map<String, Object> info = SecureMemory.getMemoryProtectionInfo();
assertEquals(false, info.get("supports_full_protection"),
"Full protection should be false when disabled");
assertEquals(false, info.get("enabled"), "Enabled should be false");
} finally {
SecureMemory.setSecureMemoryEnabled(true);
}
}
@Test
@DisplayName("secureByteArray should create a SecureBuffer with data")
void secureByteArray_shouldCreateBuffer() {
byte[] data = new byte[]{1, 2, 3, 4, 5};
try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data)) {
assertNotNull(buffer.getData(), "Data segment should not be null");
assertEquals(5, buffer.getSize(), "Size should match input");
assertNotEquals(0, buffer.getAddress(), "Address should not be zero");
}
}
@Test
@DisplayName("secureByteArray should handle null data gracefully")
void secureByteArray_nullData_shouldHandleGracefully() {
try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(null)) {
assertNotNull(buffer, "Buffer should not be null even with null data");
assertEquals(0, buffer.getSize(), "Size should be 0 for null data");
}
}
@Test
@DisplayName("SecureBuffer zero should clear all bytes")
void secureBuffer_zero_shouldClearBytes() {
byte[] data = new byte[]{(byte)0xFF, (byte)0xFF, (byte)0xFF, (byte)0xFF};
SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data);
buffer.zero();
assertDoesNotThrow(() -> buffer.zero(), "Zeroing should not throw");
}
@Test
@DisplayName("SecureBuffer close should zero and unlock")
void secureBuffer_close_shouldZeroAndUnlock() {
byte[] data = new byte[]{1, 2, 3};
assertDoesNotThrow(() -> {
try (SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data)) {
assertNotNull(buffer.getData());
}
}, "Close via try-with-resources should not throw");
}
@Test
@DisplayName("SecureBuffer close should be idempotent")
void secureBuffer_close_idempotent() {
byte[] data = new byte[]{1, 2, 3};
SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data);
assertDoesNotThrow(() -> buffer.close(), "First close should not throw");
assertDoesNotThrow(() -> buffer.close(), "Second close should not throw");
}
@Test
@DisplayName("SecureBuffer lock should return false (not supported)")
void secureBuffer_lock_shouldReturnFalse() {
byte[] data = new byte[]{1, 2, 3};
SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, true);
assertFalse(buffer.lock(), "Lock should return false (not supported)");
}
@Test
@DisplayName("SecureBuffer unlock should return false")
void secureBuffer_unlock_shouldReturnFalse() {
byte[] data = new byte[]{1, 2, 3};
SecureMemory.SecureBuffer buffer = SecureMemory.secureByteArray(data, true);
assertFalse(buffer.unlock(), "Unlock should return false");
}
@Test
@DisplayName("HAS_MEMORY_LOCKING should be false")
void hasMemoryLocking_shouldBeFalse() {
assertFalse(SecureMemory.isHAS_MEMORY_LOCKING(), "HAS_MEMORY_LOCKING should be false");
}
@Test
@DisplayName("HAS_SECURE_ZEROING should be true")
void hasSecureZeroing_shouldBeTrue() {
assertTrue(SecureMemory.isHAS_SECURE_ZEROING(), "HAS_SECURE_ZEROING should be true");
}
}

View file

@ -0,0 +1,89 @@
package ai.nomyo;
import ai.nomyo.util.Splitter;
import org.junit.jupiter.api.*;
import java.util.List;
import static org.junit.jupiter.api.Assertions.*;
@TestMethodOrder(MethodOrderer.OrderAnnotation.class)
class SplitterTest {
@Test
@DisplayName("fixedLengthString should split string into equal parts")
void fixedLengthString_equalParts_shouldSplitCorrectly() {
List<String> result = Splitter.fixedLengthString(5, "1234567890");
assertEquals(2, result.size(), "Should have 2 parts");
assertEquals("12345", result.get(0), "First part should be '12345'");
assertEquals("67890", result.get(1), "Second part should be '67890'");
}
@Test
@DisplayName("fixedLengthString should handle last part being shorter")
void fixedLengthString_lastPartShorter_shouldWork() {
List<String> result = Splitter.fixedLengthString(5, "123456789");
assertEquals(2, result.size(), "Should have 2 parts");
assertEquals("12345", result.get(0), "First part should be '12345'");
assertEquals("6789", result.get(1), "Second part should be '6789'");
}
@Test
@DisplayName("fixedLengthString with length equal to string should return single part")
void fixedLengthString_lengthEqualsString_shouldReturnSingle() {
List<String> result = Splitter.fixedLengthString(10, "1234567890");
assertEquals(1, result.size(), "Should have 1 part");
assertEquals("1234567890", result.get(0), "Single part should be the full string");
}
@Test
@DisplayName("fixedLengthString with length greater than string should return single part")
void fixedLengthString_lengthGreaterThanString_shouldReturnSingle() {
List<String> result = Splitter.fixedLengthString(100, "hello");
assertEquals(1, result.size(), "Should have 1 part");
assertEquals("hello", result.get(0), "Single part should be 'hello'");
}
@Test
@DisplayName("fixedLengthString with length 1 should split into single characters")
void fixedLengthString_lengthOne_shouldSplitChars() {
List<String> result = Splitter.fixedLengthString(1, "abc");
assertEquals(3, result.size(), "Should have 3 parts");
assertEquals("a", result.get(0));
assertEquals("b", result.get(1));
assertEquals("c", result.get(2));
}
@Test
@DisplayName("fixedLengthString with empty string should return empty list")
void fixedLengthString_emptyString_shouldReturnEmpty() {
List<String> result = Splitter.fixedLengthString(5, "");
assertEquals(0, result.size(), "Should return empty list");
}
@Test
@DisplayName("fixedLengthString should handle string with special characters")
void fixedLengthString_specialChars_shouldSplitCorrectly() {
List<String> result = Splitter.fixedLengthString(3, "a!@b#c$");
assertEquals(3, result.size(), "Should have 3 parts");
assertEquals("a!@", result.get(0));
assertEquals("b#c", result.get(1));
assertEquals("$", result.get(2));
}
@Test
@DisplayName("fixedLengthString should handle unicode characters")
void fixedLengthString_unicode_shouldSplitCorrectly() {
List<String> result = Splitter.fixedLengthString(2, "ab\u00e9\u00fc");
assertEquals(2, result.size(), "Should have 2 parts");
assertEquals("ab", result.get(0));
}
}