Finish functionality
This commit is contained in:
parent
b6af1c9792
commit
89d5282b0f
9 changed files with 583 additions and 133 deletions
|
|
@ -1,6 +1,7 @@
|
||||||
package ai.nomyo;
|
package ai.nomyo;
|
||||||
|
|
||||||
import com.google.gson.Gson;
|
import com.google.gson.Gson;
|
||||||
|
import com.google.gson.GsonBuilder;
|
||||||
import com.google.gson.annotations.SerializedName;
|
import com.google.gson.annotations.SerializedName;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
import lombok.Setter;
|
import lombok.Setter;
|
||||||
|
|
@ -12,6 +13,8 @@ import lombok.Setter;
|
||||||
@Getter
|
@Getter
|
||||||
public class EncryptedRequest {
|
public class EncryptedRequest {
|
||||||
|
|
||||||
|
private static final Gson GSON = new GsonBuilder().create();
|
||||||
|
|
||||||
// Getters and Setters
|
// Getters and Setters
|
||||||
@SerializedName("version")
|
@SerializedName("version")
|
||||||
private String version;
|
private String version;
|
||||||
|
|
@ -55,7 +58,7 @@ public class EncryptedRequest {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public String toJson() {
|
public String toJson() {
|
||||||
return new Gson().toJson(this);
|
return GSON.toJson(this);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,10 @@
|
||||||
package ai.nomyo;
|
package ai.nomyo;
|
||||||
|
|
||||||
|
import ai.nomyo.errors.APIConnectionError;
|
||||||
|
import ai.nomyo.errors.SecurityError;
|
||||||
|
|
||||||
|
import java.util.List;
|
||||||
|
import java.util.Map;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -8,17 +13,22 @@ import java.util.concurrent.ExecutionException;
|
||||||
public class Main {
|
public class Main {
|
||||||
|
|
||||||
static void main() {
|
static void main() {
|
||||||
SecureCompletionClient secureCompletionClient = new SecureCompletionClient();
|
SecureChatCompletion secureChatCompletion = new SecureChatCompletion( Constants.DEFAULT_BASE_URL, "NOMYO_AI_E2EE_INFERENCE");
|
||||||
//secureCompletionClient.generateKeys(true, "client_keys", "pokemon");
|
List<Map<String, Object>> messages = List.of(
|
||||||
//secureCompletionClient.loadKeys("client_keys/private_key.pem", "pokemon");
|
Map.of("role", "user", "content", "Hello! How are you today?")
|
||||||
|
);
|
||||||
|
|
||||||
|
Map<String, Object> kwargs = Map.of(
|
||||||
|
"security_tier", "standard",
|
||||||
|
"temperature", 0.7
|
||||||
|
);
|
||||||
|
|
||||||
try {
|
var response = secureChatCompletion.create(
|
||||||
System.out.println(secureCompletionClient.fetchServerPublicKey().get());
|
"Qwen/Qwen3-0.6B",
|
||||||
} catch (InterruptedException | ExecutionException e) {
|
messages,
|
||||||
throw new RuntimeException(e);
|
kwargs);
|
||||||
}
|
|
||||||
|
|
||||||
|
System.out.println(response.toString());
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -3,8 +3,12 @@ package ai.nomyo;
|
||||||
import ai.nomyo.errors.*;
|
import ai.nomyo.errors.*;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
|
||||||
|
import java.util.ArrayList;
|
||||||
|
import java.util.HashMap;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
|
import java.util.UUID;
|
||||||
|
import java.util.concurrent.ExecutionException;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* High-level OpenAI-compatible entrypoint with automatic hybrid encryption (AES-256-GCM + RSA-4096).
|
* High-level OpenAI-compatible entrypoint with automatic hybrid encryption (AES-256-GCM + RSA-4096).
|
||||||
|
|
@ -23,6 +27,14 @@ public class SecureChatCompletion {
|
||||||
this(Constants.DEFAULT_BASE_URL, false, null, true, null, Constants.DEFAULT_MAX_RETRIES);
|
this(Constants.DEFAULT_BASE_URL, false, null, true, null, Constants.DEFAULT_MAX_RETRIES);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public SecureChatCompletion(String baseUrl, String apiKey) {
|
||||||
|
this(baseUrl, false, apiKey, true, null, Constants.DEFAULT_MAX_RETRIES);
|
||||||
|
}
|
||||||
|
|
||||||
|
public SecureChatCompletion(String baseUrl) {
|
||||||
|
this(baseUrl, false, null, true, null, Constants.DEFAULT_MAX_RETRIES);
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* @param baseUrl NOMYO Router base URL (HTTPS enforced unless {@code allowHttp})
|
* @param baseUrl NOMYO Router base URL (HTTPS enforced unless {@code allowHttp})
|
||||||
* @param allowHttp permit {@code http://} URLs (development only)
|
* @param allowHttp permit {@code http://} URLs (development only)
|
||||||
|
|
@ -57,15 +69,85 @@ public class SecureChatCompletion {
|
||||||
* @throws ServiceUnavailableError HTTP 503
|
* @throws ServiceUnavailableError HTTP 503
|
||||||
* @throws APIError other errors
|
* @throws APIError other errors
|
||||||
*/
|
*/
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
public Map<String, Object> create(String model, List<Map<String, Object>> messages, Map<String, Object> kwargs) {
|
public Map<String, Object> create(String model, List<Map<String, Object>> messages, Map<String, Object> kwargs) {
|
||||||
// Build payload from model, messages, and kwargs
|
// Validate required parameters
|
||||||
// Validate stream is false
|
if (model == null || model.isEmpty()) {
|
||||||
// Validate securityTier if provided
|
throw new IllegalArgumentException("model is required");
|
||||||
// Use per-call api_key override if provided, else instance apiKey
|
}
|
||||||
// Create temp client if baseUrl override provided
|
if (messages == null || messages.isEmpty()) {
|
||||||
|
throw new IllegalArgumentException("messages is required and cannot be empty");
|
||||||
|
}
|
||||||
|
|
||||||
|
// Build payload
|
||||||
|
Map<String, Object> payload = new HashMap<>();
|
||||||
|
payload.put("model", model);
|
||||||
|
payload.put("messages", messages);
|
||||||
|
|
||||||
|
// Add kwargs
|
||||||
|
if (kwargs != null) {
|
||||||
|
// Check for stream parameter
|
||||||
|
if (kwargs.containsKey("stream")) {
|
||||||
|
Object streamValue = kwargs.get("stream");
|
||||||
|
boolean stream = streamValue instanceof Boolean ? (Boolean) streamValue : Boolean.parseBoolean(streamValue.toString());
|
||||||
|
if (stream) {
|
||||||
|
throw new IllegalArgumentException("Streaming is not supported");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check for security_tier
|
||||||
|
if (kwargs.containsKey("security_tier")) {
|
||||||
|
Object tier = kwargs.get("security_tier");
|
||||||
|
if (tier != null && !Constants.VALID_SECURITY_TIERS.contains(tier.toString())) {
|
||||||
|
throw new IllegalArgumentException(
|
||||||
|
"Invalid security_tier: '" + tier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
payload.putAll(kwargs);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine API key (per-call override or instance key)
|
||||||
|
String apiKey = this.apiKey;
|
||||||
|
if (kwargs != null && kwargs.containsKey("api_key")) {
|
||||||
|
Object key = kwargs.get("api_key");
|
||||||
|
if (key != null) {
|
||||||
|
apiKey = key.toString();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Determine security tier
|
||||||
|
String securityTier = null;
|
||||||
|
if (kwargs != null && kwargs.containsKey("security_tier")) {
|
||||||
|
securityTier = kwargs.get("security_tier").toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate payload ID
|
||||||
|
String payloadId = UUID.randomUUID().toString();
|
||||||
|
|
||||||
// Send secure request
|
// Send secure request
|
||||||
// Return decrypted response map
|
try {
|
||||||
return null;
|
return client.sendSecureRequest(payload, payloadId, apiKey, securityTier).get();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
throw new RuntimeException("Request interrupted", e);
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
Throwable cause = e.getCause();
|
||||||
|
switch (cause) {
|
||||||
|
case SecurityError securityError -> throw new RuntimeException(cause);
|
||||||
|
case InvalidRequestError invalidRequestError -> throw new RuntimeException(cause);
|
||||||
|
case AuthenticationError authenticationError -> throw new RuntimeException(cause);
|
||||||
|
case ForbiddenError forbiddenError -> throw new RuntimeException(cause);
|
||||||
|
case RateLimitError rateLimitError -> throw new RuntimeException(cause);
|
||||||
|
case ServerError serverError -> throw new RuntimeException(cause);
|
||||||
|
case ServiceUnavailableError serviceUnavailableError -> throw new RuntimeException(cause);
|
||||||
|
case APIError apiError -> throw new RuntimeException(cause);
|
||||||
|
case APIConnectionError apiConnectionError -> throw new RuntimeException(cause);
|
||||||
|
case SecureCompletionClient.ValueError valueError -> throw new IllegalArgumentException(cause);
|
||||||
|
default ->
|
||||||
|
throw new RuntimeException("Request failed: " + cause.getMessage(), cause);
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
|
||||||
|
|
@ -4,6 +4,7 @@ import ai.nomyo.errors.*;
|
||||||
import ai.nomyo.util.PEMConverter;
|
import ai.nomyo.util.PEMConverter;
|
||||||
import ai.nomyo.util.Pass2Key;
|
import ai.nomyo.util.Pass2Key;
|
||||||
import lombok.Getter;
|
import lombok.Getter;
|
||||||
|
import lombok.Setter;
|
||||||
|
|
||||||
import javax.crypto.*;
|
import javax.crypto.*;
|
||||||
import javax.crypto.spec.GCMParameterSpec;
|
import javax.crypto.spec.GCMParameterSpec;
|
||||||
|
|
@ -27,6 +28,10 @@ import java.util.*;
|
||||||
import java.util.concurrent.CompletableFuture;
|
import java.util.concurrent.CompletableFuture;
|
||||||
import java.util.concurrent.CompletionException;
|
import java.util.concurrent.CompletionException;
|
||||||
import java.util.concurrent.ExecutionException;
|
import java.util.concurrent.ExecutionException;
|
||||||
|
import com.google.gson.Gson;
|
||||||
|
import com.google.gson.JsonObject;
|
||||||
|
import com.google.gson.JsonParser;
|
||||||
|
|
||||||
import java.util.concurrent.locks.ReentrantLock;
|
import java.util.concurrent.locks.ReentrantLock;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -113,7 +118,7 @@ public class SecureCompletionClient {
|
||||||
/**
|
/**
|
||||||
* Generates a 4096-bit RSA key pair (exponent 65537). Saves to disk if {@code saveToFile}.
|
* Generates a 4096-bit RSA key pair (exponent 65537). Saves to disk if {@code saveToFile}.
|
||||||
*/
|
*/
|
||||||
public void generateKeys(boolean saveToFile, String keyDir, String password) {
|
public void generateKeys(boolean saveToFile, String keyDir, String password) throws SecurityError {
|
||||||
try {
|
try {
|
||||||
KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
|
KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
|
||||||
generator.initialize(new RSAKeyGenParameterSpec(Constants.RSA_KEY_SIZE, BigInteger.valueOf(Constants.RSA_PUBLIC_EXPONENT)));
|
generator.initialize(new RSAKeyGenParameterSpec(Constants.RSA_KEY_SIZE, BigInteger.valueOf(Constants.RSA_PUBLIC_EXPONENT)));
|
||||||
|
|
@ -145,8 +150,8 @@ public class SecureCompletionClient {
|
||||||
try {
|
try {
|
||||||
privatePem = Pass2Key.encrypt("AES/GCM/NoPadding", privatePem, password);
|
privatePem = Pass2Key.encrypt("AES/GCM/NoPadding", privatePem, password);
|
||||||
} catch (NoSuchPaddingException | IllegalBlockSizeException | BadPaddingException |
|
} catch (NoSuchPaddingException | IllegalBlockSizeException | BadPaddingException |
|
||||||
InvalidKeyException e) {
|
InvalidKeyException | SecurityError e) {
|
||||||
throw new RuntimeException(e);
|
throw new SecurityError("Failed to encrypt private key with password: " + e.getMessage(), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
writer.write(privatePem);
|
writer.write(privatePem);
|
||||||
|
|
@ -167,10 +172,12 @@ public class SecureCompletionClient {
|
||||||
this.privateKey = pair.getPrivate();
|
this.privateKey = pair.getPrivate();
|
||||||
this.publicPemKey = publicPem;
|
this.publicPemKey = publicPem;
|
||||||
|
|
||||||
|
} catch (SecurityError e) {
|
||||||
|
throw e;
|
||||||
} catch (NoSuchAlgorithmException e) {
|
} catch (NoSuchAlgorithmException e) {
|
||||||
throw new RuntimeException("RSA not available: " + e.getMessage(), e);
|
throw new SecurityError("RSA algorithm not available: " + e.getMessage(), e);
|
||||||
} catch (InvalidAlgorithmParameterException e) {
|
} catch (InvalidAlgorithmParameterException e) {
|
||||||
throw new RuntimeException(e);
|
throw new SecurityError("Invalid RSA key generation parameters: " + e.getMessage(), e);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("Failed to save keys: " + e.getMessage(), e);
|
throw new RuntimeException("Failed to save keys: " + e.getMessage(), e);
|
||||||
}
|
}
|
||||||
|
|
@ -179,7 +186,7 @@ public class SecureCompletionClient {
|
||||||
/**
|
/**
|
||||||
* Generates a 4096-bit RSA key pair and saves to the default directory.
|
* Generates a 4096-bit RSA key pair and saves to the default directory.
|
||||||
*/
|
*/
|
||||||
public void generateKeys(boolean saveToFile) {
|
public void generateKeys(boolean saveToFile) throws SecurityError {
|
||||||
generateKeys(saveToFile, Constants.DEFAULT_KEY_DIR, null);
|
generateKeys(saveToFile, Constants.DEFAULT_KEY_DIR, null);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -190,11 +197,12 @@ public class SecureCompletionClient {
|
||||||
* @param privateKeyPath private key PEM path
|
* @param privateKeyPath private key PEM path
|
||||||
* @param publicPemKeyPath optional public key PEM path
|
* @param publicPemKeyPath optional public key PEM path
|
||||||
* @param password optional password for encrypted private key
|
* @param password optional password for encrypted private key
|
||||||
|
* @throws SecurityError if key file not found, unreadable, or decryption fails
|
||||||
*/
|
*/
|
||||||
public void loadKeys(String privateKeyPath, String publicPemKeyPath, String password) {
|
public void loadKeys(String privateKeyPath, String publicPemKeyPath, String password) throws SecurityError {
|
||||||
Path keyPath = Path.of(privateKeyPath);
|
Path keyPath = Path.of(privateKeyPath);
|
||||||
if (!Files.exists(keyPath)) {
|
if (!Files.exists(keyPath)) {
|
||||||
throw new RuntimeException("Private key file not found: " + privateKeyPath);
|
throw new SecurityError("Private key file not found: " + privateKeyPath);
|
||||||
}
|
}
|
||||||
|
|
||||||
String keyContent;
|
String keyContent;
|
||||||
|
|
@ -202,35 +210,36 @@ public class SecureCompletionClient {
|
||||||
try {
|
try {
|
||||||
keyContent = readFileContent(privateKeyPath);
|
keyContent = readFileContent(privateKeyPath);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("Failed to read private key file: " + e.getMessage(), e);
|
throw new SecurityError("Failed to read private key file: " + e.getMessage(), e);
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
keyContent = Pass2Key.decrypt("AES/GCM/NoPadding", keyContent, password);
|
keyContent = Pass2Key.decrypt("AES/GCM/NoPadding", keyContent, password);
|
||||||
} catch (NoSuchPaddingException | NoSuchAlgorithmException | BadPaddingException |
|
} catch (NoSuchPaddingException | NoSuchAlgorithmException | BadPaddingException |
|
||||||
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException e) {
|
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException | SecurityError e) {
|
||||||
System.out.println("Wrong password!");
|
throw new SecurityError("Failed to decrypt private key with provided password: " + e.getMessage(), e);
|
||||||
return;
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
try {
|
try {
|
||||||
keyContent = readFileContent(privateKeyPath);
|
keyContent = readFileContent(privateKeyPath);
|
||||||
} catch (IOException e) {
|
} catch (IOException e) {
|
||||||
throw new RuntimeException("Failed to read private key file: " + e.getMessage(), e);
|
throw new SecurityError("Failed to read private key file: " + e.getMessage(), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
try {
|
try {
|
||||||
this.privateKey = Pass2Key.convertStringToPrivateKey(keyContent);
|
this.privateKey = Pass2Key.convertStringToPrivateKey(keyContent);
|
||||||
} catch (Exception e) {
|
} catch (Exception e) {
|
||||||
throw new RuntimeException("Failed to load private key: " + e.getMessage(), e);
|
throw new SecurityError("Failed to load private key: " + e.getMessage(), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Loads RSA private key from disk, deriving public key.
|
* Loads RSA private key from disk, deriving public key.
|
||||||
|
*
|
||||||
|
* @throws SecurityError if key file not found, unreadable, or decryption fails
|
||||||
*/
|
*/
|
||||||
public void loadKeys(String privateKeyPath, String password) {
|
public void loadKeys(String privateKeyPath, String password) throws SecurityError {
|
||||||
loadKeys(privateKeyPath, null, password);
|
loadKeys(privateKeyPath, null, password);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -250,12 +259,12 @@ public class SecureCompletionClient {
|
||||||
URI url;
|
URI url;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
url = new URI(this.routerUrl + "/pki/public_key");
|
url = new URI(this.routerUrl + Constants.PKI_PUBLIC_KEY_PATH);
|
||||||
} catch (URISyntaxException e) {
|
} catch (URISyntaxException e) {
|
||||||
return CompletableFuture.failedFuture(new CompletionException("Invalid URI: " + e.getMessage(), e));
|
return CompletableFuture.failedFuture(new CompletionException(new APIConnectionError("Invalid URI: " + e.getMessage(), e)));
|
||||||
}
|
}
|
||||||
|
|
||||||
HttpRequest request = HttpRequest.newBuilder(url).timeout(Duration.of(60, ChronoUnit.SECONDS)).GET().build();
|
HttpRequest request = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).GET().build();
|
||||||
|
|
||||||
return this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> {
|
return this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> {
|
||||||
if (response.statusCode() != 200) {
|
if (response.statusCode() != 200) {
|
||||||
|
|
@ -264,7 +273,7 @@ public class SecureCompletionClient {
|
||||||
return response.body();
|
return response.body();
|
||||||
}).thenApply(body -> {
|
}).thenApply(body -> {
|
||||||
if (!PEMConverter.validatePEM(body)) {
|
if (!PEMConverter.validatePEM(body)) {
|
||||||
throw new CompletionException(new InvalidKeyException("PEM key had invalid format"));
|
throw new CompletionException(new SecurityError("Server returned invalid PEM key format (possible MITM attack)"));
|
||||||
}
|
}
|
||||||
return body;
|
return body;
|
||||||
});
|
});
|
||||||
|
|
@ -278,7 +287,37 @@ public class SecureCompletionClient {
|
||||||
* @throws SecurityError if encryption fails or keys not loaded
|
* @throws SecurityError if encryption fails or keys not loaded
|
||||||
*/
|
*/
|
||||||
public CompletableFuture<byte[]> encryptPayload(Map<String, Object> payload) {
|
public CompletableFuture<byte[]> encryptPayload(Map<String, Object> payload) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
return CompletableFuture.supplyAsync(() -> {
|
||||||
|
try {
|
||||||
|
ensureKeys(null);
|
||||||
|
} catch (SecurityError e) {
|
||||||
|
throw new CompletionException(new SecurityError("Failed to ensure keys are initialized: " + e.getMessage(), e));
|
||||||
|
}
|
||||||
|
|
||||||
|
if (this.privateKey == null) {
|
||||||
|
throw new CompletionException(new SecurityError("Private key not available for encryption"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Generate AES key
|
||||||
|
KeyGenerator keyGen;
|
||||||
|
|
||||||
|
try {
|
||||||
|
keyGen = KeyGenerator.getInstance("AES");
|
||||||
|
keyGen.init(Constants.AES_KEY_SIZE * 8);
|
||||||
|
} catch (NoSuchAlgorithmException e) {
|
||||||
|
throw new CompletionException(new SecurityError("AES key generation not available: " + e.getMessage(), e));
|
||||||
|
}
|
||||||
|
|
||||||
|
Key aesKey = keyGen.generateKey();
|
||||||
|
|
||||||
|
// Serialize payload to JSON
|
||||||
|
Gson gson = new Gson();
|
||||||
|
String payloadJson = gson.toJson(payload);
|
||||||
|
byte[] payloadBytes = payloadJson.getBytes(StandardCharsets.UTF_8);
|
||||||
|
|
||||||
|
// Encrypt
|
||||||
|
return doEncrypt(payloadBytes, aesKey).join();
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -287,16 +326,16 @@ public class SecureCompletionClient {
|
||||||
public CompletableFuture<byte[]> doEncrypt(byte[] payloadBytes, Key aesKey) {
|
public CompletableFuture<byte[]> doEncrypt(byte[] payloadBytes, Key aesKey) {
|
||||||
return CompletableFuture.supplyAsync(() -> {
|
return CompletableFuture.supplyAsync(() -> {
|
||||||
SecureRandom random = new SecureRandom();
|
SecureRandom random = new SecureRandom();
|
||||||
byte[] nonce = new byte[12];
|
byte[] nonce = new byte[Constants.GCM_NONCE_SIZE];
|
||||||
random.nextBytes(nonce);
|
random.nextBytes(nonce);
|
||||||
|
|
||||||
Cipher cipher = null;
|
Cipher cipher = null;
|
||||||
try {
|
try {
|
||||||
cipher = Cipher.getInstance("AES/GCM/NoPadding");
|
cipher = Cipher.getInstance("AES/GCM/NoPadding");
|
||||||
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(128, nonce));
|
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(Constants.GCM_TAG_SIZE * Byte.SIZE, nonce));
|
||||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException |
|
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException |
|
||||||
InvalidKeyException e) {
|
InvalidKeyException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(new SecurityError("AES-GCM cipher initialization failed: " + e.getMessage(), e));
|
||||||
}
|
}
|
||||||
|
|
||||||
byte[] ciphertext;
|
byte[] ciphertext;
|
||||||
|
|
@ -304,66 +343,63 @@ public class SecureCompletionClient {
|
||||||
try {
|
try {
|
||||||
ciphertext = cipher.doFinal(payloadBytes);
|
ciphertext = cipher.doFinal(payloadBytes);
|
||||||
} catch (IllegalBlockSizeException | BadPaddingException e) {
|
} catch (IllegalBlockSizeException | BadPaddingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(new SecurityError("AES-GCM encryption failed: " + e.getMessage(), e));
|
||||||
}
|
}
|
||||||
|
|
||||||
String serverPEM;
|
String serverPEM;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
serverPEM = fetchServerPublicKey().get();
|
serverPEM = fetchServerPublicKey().get();
|
||||||
} catch (InterruptedException | ExecutionException e) {
|
} catch (InterruptedException e) {
|
||||||
throw new RuntimeException(e);
|
Thread.currentThread().interrupt();
|
||||||
|
throw new RuntimeException(new SecurityError("Encryption interrupted while fetching server public key", e));
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
Throwable cause = e.getCause();
|
||||||
|
if (cause instanceof SecurityError) {
|
||||||
|
throw new RuntimeException((SecurityError) cause);
|
||||||
|
}
|
||||||
|
throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause));
|
||||||
}
|
}
|
||||||
|
|
||||||
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM).getBytes());
|
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM));
|
||||||
|
|
||||||
PublicKey serverPublicKey;
|
PublicKey serverPublicKey;
|
||||||
|
|
||||||
try {
|
try {
|
||||||
serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec);
|
serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec);
|
||||||
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
|
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(new SecurityError("RSA key factory failed to parse server public key: " + e.getMessage(), e));
|
||||||
}
|
}
|
||||||
|
|
||||||
Cipher rsa;
|
Cipher rsa;
|
||||||
|
|
||||||
byte[] enryptedAESKey = aesKey.getEncoded();
|
|
||||||
|
|
||||||
try {
|
try {
|
||||||
rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
|
rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
|
||||||
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey);
|
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey);
|
||||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) {
|
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e));
|
||||||
}
|
}
|
||||||
|
byte[] encryptedAESKey;
|
||||||
try {
|
try {
|
||||||
rsa.doFinal(enryptedAESKey);
|
encryptedAESKey = rsa.doFinal(aesKey.getEncoded());
|
||||||
} catch (IllegalBlockSizeException | BadPaddingException e) {
|
} catch (IllegalBlockSizeException | BadPaddingException e) {
|
||||||
throw new RuntimeException(e);
|
throw new RuntimeException(new SecurityError("RSA-OAEP key wrapping failed: " + e.getMessage(), e));
|
||||||
}
|
}
|
||||||
|
|
||||||
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - (128 / Byte.SIZE), ciphertext.length);
|
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length);
|
||||||
|
|
||||||
EncryptedRequest request = new EncryptedRequest();
|
EncryptedRequest request = new EncryptedRequest();
|
||||||
|
|
||||||
request.setVersion("1.0");
|
request.setVersion(Constants.PROTOCOL_VERSION);
|
||||||
request.setAlgorithm("hybrid-aes256-rsa4096");
|
request.setAlgorithm(Constants.HYBRID_ALGORITHM);
|
||||||
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(ciphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
|
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(ciphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
|
||||||
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(enryptedAESKey));
|
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(encryptedAESKey));
|
||||||
request.setKeyAlgorithm("RSA-OAEP-SHA256");
|
request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM);
|
||||||
request.setPayloadAlgorithm("AES-256-GCM");
|
request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM);
|
||||||
|
|
||||||
return request.toJson().getBytes(StandardCharsets.UTF_8);
|
return request.toJson().getBytes(StandardCharsets.UTF_8);
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* Decrypts server response.
|
|
||||||
*/
|
|
||||||
public CompletableFuture<Map<String, Object>> decryptResponse(byte[] encryptedResponse, String payloadId) {
|
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* encrypt → POST {routerUrl}/v1/chat/secure_completion → retry → decrypt → return.
|
* encrypt → POST {routerUrl}/v1/chat/secure_completion → retry → decrypt → return.
|
||||||
* <p>Headers: Content-Type=octet-stream, X-Payload-ID, X-Public-Key, Authorization (Bearer), X-Security-Tier.
|
* <p>Headers: Content-Type=octet-stream, X-Payload-ID, X-Public-Key, Authorization (Bearer), X-Security-Tier.
|
||||||
|
|
@ -388,7 +424,210 @@ public class SecureCompletionClient {
|
||||||
* @throws APIError other errors
|
* @throws APIError other errors
|
||||||
*/
|
*/
|
||||||
public CompletableFuture<Map<String, Object>> sendSecureRequest(Map<String, Object> payload, String payloadId, String apiKey, String securityTier) {
|
public CompletableFuture<Map<String, Object>> sendSecureRequest(Map<String, Object> payload, String payloadId, String apiKey, String securityTier) {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
return CompletableFuture.supplyAsync(() -> {
|
||||||
|
// Validate security tier if provided
|
||||||
|
if (securityTier != null && !Constants.VALID_SECURITY_TIERS.contains(securityTier)) {
|
||||||
|
throw new CompletionException(new ValueError(
|
||||||
|
"Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS));
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
ensureKeys(null);
|
||||||
|
} catch (SecurityError e) {
|
||||||
|
throw new CompletionException(new SecurityError("Failed to ensure keys: " + e.getMessage(), e));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 1: Encrypt payload
|
||||||
|
byte[] encryptedPayload;
|
||||||
|
try {
|
||||||
|
encryptedPayload = encryptPayload(payload).get();
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
throw new CompletionException(new APIConnectionError("Encryption interrupted"));
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
Throwable cause = e.getCause();
|
||||||
|
if (cause instanceof SecurityError) {
|
||||||
|
throw new CompletionException((SecurityError) cause);
|
||||||
|
}
|
||||||
|
throw new CompletionException(new SecurityError("Encryption failed: " + cause.getMessage(), cause));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Step 2: Prepare headers
|
||||||
|
URI url;
|
||||||
|
try {
|
||||||
|
url = new URI(this.routerUrl + Constants.SECURE_COMPLETION_PATH);
|
||||||
|
} catch (URISyntaxException e) {
|
||||||
|
throw new CompletionException(new APIConnectionError("Invalid URL: " + e.getMessage(), e));
|
||||||
|
}
|
||||||
|
|
||||||
|
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url)
|
||||||
|
.timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS))
|
||||||
|
.header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM)
|
||||||
|
.header(Constants.HEADER_PAYLOAD_ID, payloadId)
|
||||||
|
.header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey))
|
||||||
|
.POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload));
|
||||||
|
|
||||||
|
if (apiKey != null && !apiKey.isEmpty()) {
|
||||||
|
requestBuilder.header("Authorization", Constants.AUTHORIZATION_BEARER_PREFIX + apiKey);
|
||||||
|
}
|
||||||
|
if (securityTier != null) {
|
||||||
|
requestBuilder.header(Constants.HEADER_SECURITY_TIER, securityTier);
|
||||||
|
}
|
||||||
|
|
||||||
|
HttpRequest request = requestBuilder.build();
|
||||||
|
|
||||||
|
// Step 3: Send request with retry
|
||||||
|
Exception lastExc = new APIConnectionError("Request failed");
|
||||||
|
int retryableCodes = 0;
|
||||||
|
for (int attempt = 0; attempt <= this.maxRetries; attempt++) {
|
||||||
|
if (attempt > 0) {
|
||||||
|
long delay = (long) Math.pow(2, attempt - 1);
|
||||||
|
try {
|
||||||
|
Thread.sleep(delay * 1000);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
throw new CompletionException(new APIConnectionError("Retry interrupted"));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
HttpResponse<byte[]> response = this.httpClient.send(request, HttpResponse.BodyHandlers.ofByteArray());
|
||||||
|
int statusCode = response.statusCode();
|
||||||
|
|
||||||
|
if (statusCode == 200) {
|
||||||
|
// Step 4: Decrypt response
|
||||||
|
return decryptResponse(response.body(), payloadId).get();
|
||||||
|
} else if (statusCode == 400) {
|
||||||
|
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||||
|
Map<String, Object> errorDetails = Map.of();
|
||||||
|
try {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||||
|
if (parsed != null) errorDetails = parsed;
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Unknown error";
|
||||||
|
throw new CompletionException(new InvalidRequestError("Bad request: " + detail, 400, errorDetails));
|
||||||
|
} else if (statusCode == 401) {
|
||||||
|
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||||
|
Map<String, Object> errorDetails = Map.of();
|
||||||
|
try {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||||
|
if (parsed != null) errorDetails = parsed;
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Invalid API key or authentication failed";
|
||||||
|
throw new CompletionException(new AuthenticationError(detail, 401, errorDetails));
|
||||||
|
} else if (statusCode == 403) {
|
||||||
|
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||||
|
Map<String, Object> errorDetails = Map.of();
|
||||||
|
try {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||||
|
if (parsed != null) errorDetails = parsed;
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Model not allowed for the requested security tier";
|
||||||
|
throw new CompletionException(new ForbiddenError("Forbidden: " + detail, 403, errorDetails));
|
||||||
|
} else if (statusCode == 404) {
|
||||||
|
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||||
|
Map<String, Object> errorDetails = Map.of();
|
||||||
|
try {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||||
|
if (parsed != null) errorDetails = parsed;
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Secure inference not enabled";
|
||||||
|
throw new CompletionException(new APIError("Endpoint not found: " + detail, 404, errorDetails));
|
||||||
|
} else if (Constants.RETRYABLE_STATUS_CODES.contains(statusCode)) {
|
||||||
|
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||||
|
Map<String, Object> error = Map.of();
|
||||||
|
String detailMsg = "unknown";
|
||||||
|
try {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||||
|
if (parsed != null) {
|
||||||
|
error = parsed;
|
||||||
|
if (error.containsKey("detail")) detailMsg = error.get("detail").toString();
|
||||||
|
}
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
|
||||||
|
if (statusCode == 429) {
|
||||||
|
lastExc = new RateLimitError("Rate limit exceeded: " + detailMsg, 429, error);
|
||||||
|
} else if (statusCode == 500) {
|
||||||
|
lastExc = new ServerError("Server error: " + detailMsg, 500, error);
|
||||||
|
} else if (statusCode == 503) {
|
||||||
|
lastExc = new ServiceUnavailableError("Service unavailable: " + detailMsg, 503, error);
|
||||||
|
} else {
|
||||||
|
lastExc = new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, error);
|
||||||
|
}
|
||||||
|
|
||||||
|
if (attempt < this.maxRetries) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
throw new CompletionException(lastExc);
|
||||||
|
} else {
|
||||||
|
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||||
|
Map<String, Object> errorDetails = Map.of();
|
||||||
|
String detailMsg = "unknown";
|
||||||
|
try {
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||||
|
if (parsed != null) {
|
||||||
|
errorDetails = parsed;
|
||||||
|
if (errorDetails.containsKey("detail")) detailMsg = errorDetails.get("detail").toString();
|
||||||
|
}
|
||||||
|
} catch (Exception ignored) {
|
||||||
|
}
|
||||||
|
throw new CompletionException(new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, errorDetails));
|
||||||
|
}
|
||||||
|
} catch (CompletionException e) {
|
||||||
|
Throwable cause = e.getCause();
|
||||||
|
if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError ||
|
||||||
|
cause instanceof ForbiddenError || cause instanceof SecurityError) {
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
if (cause instanceof RateLimitError || cause instanceof ServerError || cause instanceof ServiceUnavailableError) {
|
||||||
|
lastExc = (Exception) cause;
|
||||||
|
if (attempt < this.maxRetries) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
throw new CompletionException(lastExc);
|
||||||
|
}
|
||||||
|
if (cause instanceof APIError apiError) {
|
||||||
|
if (Constants.RETRYABLE_STATUS_CODES.contains(apiError.getStatusCode())) {
|
||||||
|
lastExc = apiError;
|
||||||
|
if (attempt < this.maxRetries) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
throw new CompletionException(lastExc);
|
||||||
|
}
|
||||||
|
throw e;
|
||||||
|
}
|
||||||
|
lastExc = new APIConnectionError("Failed to connect to router: " + e.getMessage());
|
||||||
|
if (attempt < this.maxRetries) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
throw new CompletionException(lastExc);
|
||||||
|
} catch (InterruptedException e) {
|
||||||
|
Thread.currentThread().interrupt();
|
||||||
|
throw new CompletionException(new APIConnectionError("Request interrupted"));
|
||||||
|
} catch (ExecutionException e) {
|
||||||
|
throw new CompletionException(new APIConnectionError("Request failed: " + e.getMessage(), e));
|
||||||
|
} catch (IOException e) {
|
||||||
|
lastExc = new APIConnectionError("IO error: " + e.getMessage(), e);
|
||||||
|
if (attempt < this.maxRetries) {
|
||||||
|
continue;
|
||||||
|
}
|
||||||
|
throw new CompletionException(lastExc);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
throw new CompletionException(lastExc);
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -410,7 +649,7 @@ public class SecureCompletionClient {
|
||||||
*
|
*
|
||||||
* @param keyDir key directory or {@code null} for ephemeral
|
* @param keyDir key directory or {@code null} for ephemeral
|
||||||
*/
|
*/
|
||||||
public void ensureKeys(String keyDir) {
|
public void ensureKeys(String keyDir) throws SecurityError {
|
||||||
if (keysInitialized) return;
|
if (keysInitialized) return;
|
||||||
keyInitLock.lock();
|
keyInitLock.lock();
|
||||||
try {
|
try {
|
||||||
|
|
@ -458,21 +697,19 @@ public class SecureCompletionClient {
|
||||||
* Maps HTTP status code to exception (200→null).
|
* Maps HTTP status code to exception (200→null).
|
||||||
*/
|
*/
|
||||||
public Exception mapHttpStatus(int statusCode, String responseBody) {
|
public Exception mapHttpStatus(int statusCode, String responseBody) {
|
||||||
|
String message = responseBody != null ? responseBody : "no body";
|
||||||
|
Map<String, Object> errorDetails = responseBody != null ? Map.of("response", responseBody) : Map.of();
|
||||||
return switch (statusCode) {
|
return switch (statusCode) {
|
||||||
case 200 -> null;
|
case 200 -> null;
|
||||||
case 400 ->
|
case 400 -> new InvalidRequestError("Invalid request: " + message, statusCode, errorDetails);
|
||||||
new InvalidRequestError("Invalid request: " + (responseBody != null ? responseBody : "no body"));
|
case 401 -> new AuthenticationError("Authentication failed: " + message, statusCode, errorDetails);
|
||||||
case 401 ->
|
case 403 -> new ForbiddenError("Access forbidden: " + message, statusCode, errorDetails);
|
||||||
new AuthenticationError("Authentication failed: " + (responseBody != null ? responseBody : "no body"));
|
case 404 -> new APIError("Not found: " + message, statusCode, errorDetails);
|
||||||
case 403 -> new ForbiddenError("Access forbidden: " + (responseBody != null ? responseBody : "no body"));
|
case 429 -> new RateLimitError("Rate limit exceeded: " + message, statusCode, errorDetails);
|
||||||
case 404 -> new APIError("Not found: " + (responseBody != null ? responseBody : "no body"));
|
case 500 -> new ServerError("Internal server error: " + message, statusCode, errorDetails);
|
||||||
case 429 -> new RateLimitError("Rate limit exceeded: " + (responseBody != null ? responseBody : "no body"));
|
case 503 -> new ServiceUnavailableError("Service unavailable: " + message, statusCode, errorDetails);
|
||||||
case 500 -> new ServerError("Internal server error: " + (responseBody != null ? responseBody : "no body"));
|
case 502, 504 -> new APIError("Gateway error: " + message, statusCode, errorDetails);
|
||||||
case 503 ->
|
default -> new APIError("Unexpected status " + statusCode + ": " + message, statusCode, errorDetails);
|
||||||
new ServiceUnavailableError("Service unavailable: " + (responseBody != null ? responseBody : "no body"));
|
|
||||||
case 502, 504 -> new APIError("Gateway error: " + (responseBody != null ? responseBody : "no body"));
|
|
||||||
default ->
|
|
||||||
new APIError("Unexpected status " + statusCode + ": " + (responseBody != null ? responseBody : "no body"));
|
|
||||||
};
|
};
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -484,9 +721,138 @@ public class SecureCompletionClient {
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Delegates to resource cleanup (stub).
|
* Closes the HTTP client and clears keys from memory.
|
||||||
*/
|
*/
|
||||||
public void close() {
|
public void close() {
|
||||||
throw new UnsupportedOperationException("Not yet implemented");
|
this.httpClient.close();
|
||||||
|
this.privateKey = null;
|
||||||
|
this.publicPemKey = null;
|
||||||
|
this.keysInitialized = false;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Decrypts server response.
|
||||||
|
*/
|
||||||
|
public CompletableFuture<Map<String, Object>> decryptResponse(byte[] encryptedResponse, String payloadId) {
|
||||||
|
return CompletableFuture.supplyAsync(() -> {
|
||||||
|
if (encryptedResponse == null || encryptedResponse.length == 0) {
|
||||||
|
throw new CompletionException(new ValueError("Empty encrypted response"));
|
||||||
|
}
|
||||||
|
|
||||||
|
String jsonResponse;
|
||||||
|
try {
|
||||||
|
jsonResponse = new String(encryptedResponse, StandardCharsets.UTF_8);
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new CompletionException(new ValueError("Invalid encrypted package format: not valid UTF-8"));
|
||||||
|
}
|
||||||
|
|
||||||
|
Gson gson = new Gson();
|
||||||
|
JsonObject packageJson;
|
||||||
|
try {
|
||||||
|
packageJson = JsonParser.parseString(jsonResponse).getAsJsonObject();
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new CompletionException(new ValueError("Invalid encrypted package format: malformed JSON"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate required fields
|
||||||
|
String[] requiredFields = {"version", "algorithm", "encrypted_payload", "encrypted_aes_key"};
|
||||||
|
for (String field : requiredFields) {
|
||||||
|
if (!packageJson.has(field)) {
|
||||||
|
throw new CompletionException(new ValueError("Missing required fields in encrypted package: " + field));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate version and algorithm
|
||||||
|
String version = packageJson.get("version").getAsString();
|
||||||
|
String algorithm = packageJson.get("algorithm").getAsString();
|
||||||
|
|
||||||
|
if (!Constants.PROTOCOL_VERSION.equals(version)) {
|
||||||
|
throw new CompletionException(new ValueError(
|
||||||
|
"Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'"));
|
||||||
|
}
|
||||||
|
if (!Constants.HYBRID_ALGORITHM.equals(algorithm)) {
|
||||||
|
throw new CompletionException(new ValueError(
|
||||||
|
"Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'"));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Validate encrypted_payload structure
|
||||||
|
JsonObject encryptedPayload = packageJson.get("encrypted_payload").getAsJsonObject();
|
||||||
|
String[] payloadRequired = {"ciphertext", "nonce", "tag"};
|
||||||
|
for (String field : payloadRequired) {
|
||||||
|
if (!encryptedPayload.has(field)) {
|
||||||
|
throw new CompletionException(new ValueError("Missing fields in encrypted_payload: " + field));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Guard: private key must be initialized
|
||||||
|
if (this.privateKey == null) {
|
||||||
|
throw new CompletionException(new SecurityError("Private key not initialized. Call generateKeys() or loadKeys() first."));
|
||||||
|
}
|
||||||
|
|
||||||
|
try {
|
||||||
|
// Decrypt AES key with private key
|
||||||
|
byte[] encryptedAESKey = Base64.getDecoder().decode(packageJson.get("encrypted_aes_key").getAsString());
|
||||||
|
|
||||||
|
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
|
||||||
|
rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
|
||||||
|
byte[] aesKeyBytes = rsaCipher.doFinal(encryptedAESKey);
|
||||||
|
SecretKeySpec aesKey = new SecretKeySpec(aesKeyBytes, "AES");
|
||||||
|
|
||||||
|
// Decrypt payload
|
||||||
|
byte[] ciphertext = Base64.getDecoder().decode(encryptedPayload.get("ciphertext").getAsString());
|
||||||
|
byte[] nonce = Base64.getDecoder().decode(encryptedPayload.get("nonce").getAsString());
|
||||||
|
byte[] tag = Base64.getDecoder().decode(encryptedPayload.get("tag").getAsString());
|
||||||
|
|
||||||
|
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
|
||||||
|
aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
|
||||||
|
|
||||||
|
// Combine ciphertext (without tag) and tag for decryption
|
||||||
|
byte[] ciphertextWithTag = new byte[ciphertext.length + tag.length];
|
||||||
|
System.arraycopy(ciphertext, 0, ciphertextWithTag, 0, ciphertext.length);
|
||||||
|
System.arraycopy(tag, 0, ciphertextWithTag, ciphertext.length, tag.length);
|
||||||
|
|
||||||
|
byte[] plaintextBytes = aesCipher.doFinal(ciphertextWithTag);
|
||||||
|
|
||||||
|
// Parse JSON response
|
||||||
|
Map<String, Object> response;
|
||||||
|
try {
|
||||||
|
Object parsed = gson.fromJson(new String(plaintextBytes, StandardCharsets.UTF_8), Object.class);
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> resultMap = (Map<String, Object>) parsed;
|
||||||
|
response = resultMap != null ? resultMap : new HashMap<>();
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new CompletionException(new ValueError("Decrypted response is not valid JSON: " + e.getMessage()));
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add metadata
|
||||||
|
if (!response.containsKey("_metadata")) {
|
||||||
|
response.put("_metadata", new HashMap<String, Object>());
|
||||||
|
}
|
||||||
|
@SuppressWarnings("unchecked")
|
||||||
|
Map<String, Object> metadata = (Map<String, Object>) response.get("_metadata");
|
||||||
|
metadata.put("payload_id", payloadId);
|
||||||
|
metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null);
|
||||||
|
metadata.put("is_encrypted", true);
|
||||||
|
metadata.put("encryption_algorithm", Constants.HYBRID_ALGORITHM);
|
||||||
|
|
||||||
|
return response;
|
||||||
|
|
||||||
|
} catch (Exception e) {
|
||||||
|
throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed"));
|
||||||
|
}
|
||||||
|
});
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Error class for invalid argument/value errors (maps to Python ValueError).
|
||||||
|
*/
|
||||||
|
public static class ValueError extends Exception {
|
||||||
|
public ValueError(String message) {
|
||||||
|
super(message);
|
||||||
|
}
|
||||||
|
|
||||||
|
public ValueError(String message, Throwable cause) {
|
||||||
|
super(message, cause);
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -54,22 +54,6 @@ public final class SecureMemory {
|
||||||
return secureByteArray(data, true);
|
return secureByteArray(data, true);
|
||||||
}
|
}
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated Use {@link #secureByteArray(byte[])} instead.
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public static SecureBuffer secureBytes(byte[] data, boolean lock) {
|
|
||||||
return new SecureBuffer(data, lock);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* @deprecated Use {@link #secureByteArray(byte[])} instead.
|
|
||||||
*/
|
|
||||||
@Deprecated
|
|
||||||
public static SecureBuffer secureBytes(byte[] data) {
|
|
||||||
return secureBytes(data, true);
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Returns protection capabilities: enabled, protection_level, has_memory_locking, has_secure_zeroing, supports_full_protection, page_size.
|
* Returns protection capabilities: enabled, protection_level, has_memory_locking, has_secure_zeroing, supports_full_protection, page_size.
|
||||||
*/
|
*/
|
||||||
|
|
|
||||||
|
|
@ -1,6 +1,5 @@
|
||||||
package ai.nomyo.util;
|
package ai.nomyo.util;
|
||||||
|
|
||||||
import java.util.Arrays;
|
|
||||||
import java.util.Base64;
|
import java.util.Base64;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
|
|
@ -25,10 +24,14 @@ public class PEMConverter {
|
||||||
return publicKeyFormatted.toString();
|
return publicKeyFormatted.toString();
|
||||||
}
|
}
|
||||||
|
|
||||||
public static String fromPEM(String pem) {
|
public static byte[] fromPEM(String pem) {
|
||||||
pem = pem.replaceAll("^-----BEGIN\\s+PRIVATE\\s+KEY-----|^------END\\s+PUBLIC\\s+KEY-----\n", "");
|
pem = pem.replace("-----BEGIN PRIVATE KEY-----", "")
|
||||||
|
.replace("-----BEGIN PUBLIC KEY-----", "")
|
||||||
|
.replace("-----END PRIVATE KEY-----", "")
|
||||||
|
.replace("-----END PUBLIC KEY-----", "")
|
||||||
|
.replaceAll("\\s+", "");
|
||||||
|
|
||||||
return Arrays.toString(Base64.getDecoder().decode(pem));
|
return Base64.getDecoder().decode(pem);
|
||||||
}
|
}
|
||||||
|
|
||||||
public static boolean validatePEM(String keyIn) {
|
public static boolean validatePEM(String keyIn) {
|
||||||
|
|
|
||||||
|
|
@ -1,5 +1,6 @@
|
||||||
package ai.nomyo.util;
|
package ai.nomyo.util;
|
||||||
|
|
||||||
|
import ai.nomyo.errors.SecurityError;
|
||||||
import javax.crypto.BadPaddingException;
|
import javax.crypto.BadPaddingException;
|
||||||
import javax.crypto.Cipher;
|
import javax.crypto.Cipher;
|
||||||
import javax.crypto.IllegalBlockSizeException;
|
import javax.crypto.IllegalBlockSizeException;
|
||||||
|
|
@ -39,7 +40,7 @@ public final class Pass2Key {
|
||||||
* @param password the password used to derive the encryption key
|
* @param password the password used to derive the encryption key
|
||||||
* @return base64-encoded ciphertext including salt and IV
|
* @return base64-encoded ciphertext including salt and IV
|
||||||
*/
|
*/
|
||||||
public static String encrypt(String algorithm, String input, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException {
|
public static String encrypt(String algorithm, String input, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, SecurityError {
|
||||||
|
|
||||||
byte[] salt = generateRandomBytes(SALT_LENGTH);
|
byte[] salt = generateRandomBytes(SALT_LENGTH);
|
||||||
SecretKey key = deriveKey(password, salt);
|
SecretKey key = deriveKey(password, salt);
|
||||||
|
|
@ -66,7 +67,7 @@ public final class Pass2Key {
|
||||||
* @param password the password used to derive the decryption key
|
* @param password the password used to derive the decryption key
|
||||||
* @return the decrypted plaintext
|
* @return the decrypted plaintext
|
||||||
*/
|
*/
|
||||||
public static String decrypt(String algorithm, String cipherText, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException {
|
public static String decrypt(String algorithm, String cipherText, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, SecurityError {
|
||||||
|
|
||||||
byte[] decoded = Base64.getDecoder().decode(cipherText);
|
byte[] decoded = Base64.getDecoder().decode(cipherText);
|
||||||
|
|
||||||
|
|
@ -86,13 +87,13 @@ public final class Pass2Key {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private static SecretKey deriveKey(String password, byte[] salt) {
|
private static SecretKey deriveKey(String password, byte[] salt) throws SecurityError {
|
||||||
try {
|
try {
|
||||||
SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256");
|
SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256");
|
||||||
KeySpec spec = new PBEKeySpec(password.toCharArray(), salt, ITERATION_COUNT, 256);
|
KeySpec spec = new PBEKeySpec(password.toCharArray(), salt, ITERATION_COUNT, 256);
|
||||||
return new SecretKeySpec(factory.generateSecret(spec).getEncoded(), "AES");
|
return new SecretKeySpec(factory.generateSecret(spec).getEncoded(), "AES");
|
||||||
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
|
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
|
||||||
throw new RuntimeException("Could not derive key: " + e.getMessage());
|
throw new SecurityError("Could not derive key: " + e.getMessage(), e);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,7 @@ class SecureCompletionClientE2ETest {
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("E2E: Generate keys, save to disk, load in new client, validate")
|
@DisplayName("E2E: Generate keys, save to disk, load in new client, validate")
|
||||||
void e2e_fullLifecycle_generateSaveLoadValidate() {
|
void e2e_fullLifecycle_generateSaveLoadValidate() throws Exception {
|
||||||
// Step 1: Generate keys and save to disk
|
// Step 1: Generate keys and save to disk
|
||||||
SecureCompletionClient generateClient = new SecureCompletionClient(BASE_URL, false, true, 2);
|
SecureCompletionClient generateClient = new SecureCompletionClient(BASE_URL, false, true, 2);
|
||||||
generateClient.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
generateClient.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
||||||
|
|
@ -77,7 +77,7 @@ class SecureCompletionClientE2ETest {
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("E2E: Generate plaintext keys, load, and validate")
|
@DisplayName("E2E: Generate plaintext keys, load, and validate")
|
||||||
void e2e_plaintextKeys_generateLoadValidate() {
|
void e2e_plaintextKeys_generateLoadValidate() throws Exception {
|
||||||
// Generate plaintext keys (no password)
|
// Generate plaintext keys (no password)
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(true, keyDir.getAbsolutePath(), null);
|
client.generateKeys(true, keyDir.getAbsolutePath(), null);
|
||||||
|
|
@ -296,7 +296,7 @@ class SecureCompletionClientE2ETest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("E2E: Encrypted key file is unreadable without password")
|
@DisplayName("E2E: Encrypted key file throws SecurityError without correct password")
|
||||||
void e2e_encryptedKey_unreadableWithoutPassword() throws Exception {
|
void e2e_encryptedKey_unreadableWithoutPassword() throws Exception {
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
||||||
|
|
@ -308,17 +308,19 @@ class SecureCompletionClientE2ETest {
|
||||||
assertFalse(encryptedContent.contains("BEGIN PRIVATE KEY"),
|
assertFalse(encryptedContent.contains("BEGIN PRIVATE KEY"),
|
||||||
"Encrypted file should not contain PEM header");
|
"Encrypted file should not contain PEM header");
|
||||||
|
|
||||||
// Try loading with wrong password - should not throw
|
// Try loading with wrong password - should throw SecurityError
|
||||||
SecureCompletionClient loadClient = new SecureCompletionClient();
|
SecureCompletionClient loadClient = new SecureCompletionClient();
|
||||||
assertDoesNotThrow(() ->
|
SecurityError error = assertThrows(SecurityError.class, () ->
|
||||||
loadClient.loadKeys(privateKeyFile.getAbsolutePath(), null, "wrong-password"),
|
loadClient.loadKeys(privateKeyFile.getAbsolutePath(), null, "wrong-password"),
|
||||||
"Wrong password should be handled gracefully"
|
"Wrong password should throw SecurityError"
|
||||||
);
|
);
|
||||||
|
assertTrue(error.getMessage().contains("decrypt") || error.getMessage().contains("password"),
|
||||||
|
"Error message should mention decryption or password");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("E2E: Generate keys without saving produces in-memory keys")
|
@DisplayName("E2E: Generate keys without saving produces in-memory keys")
|
||||||
void e2e_generateKeys_noSave_producesInMemoryKeys() {
|
void e2e_generateKeys_noSave_producesInMemoryKeys() throws SecurityError {
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(false);
|
client.generateKeys(false);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -28,7 +28,7 @@ class SecureCompletionClientTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("generateKeys should create 4096-bit RSA key pair")
|
@DisplayName("generateKeys should create 4096-bit RSA key pair")
|
||||||
void generateKeys_shouldCreateValidKeyPair() {
|
void generateKeys_shouldCreateValidKeyPair() throws SecurityError {
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(false);
|
client.generateKeys(false);
|
||||||
|
|
||||||
|
|
@ -58,7 +58,7 @@ class SecureCompletionClientTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("generateKeys should create unique keys on each call")
|
@DisplayName("generateKeys should create unique keys on each call")
|
||||||
void generateKeys_shouldProduceUniqueKeys() {
|
void generateKeys_shouldProduceUniqueKeys() throws SecurityError {
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(false);
|
client.generateKeys(false);
|
||||||
PrivateKey firstKey = client.getPrivateKey();
|
PrivateKey firstKey = client.getPrivateKey();
|
||||||
|
|
@ -76,7 +76,7 @@ class SecureCompletionClientTest {
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("generateKeys with saveToFile=true should create key files")
|
@DisplayName("generateKeys with saveToFile=true should create key files")
|
||||||
void generateKeys_withSaveToFile_shouldCreateKeyFiles(@TempDir Path tempDir) {
|
void generateKeys_withSaveToFile_shouldCreateKeyFiles(@TempDir Path tempDir) throws SecurityError {
|
||||||
File keyDir = tempDir.toFile();
|
File keyDir = tempDir.toFile();
|
||||||
|
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
|
|
@ -112,7 +112,7 @@ class SecureCompletionClientTest {
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("generateKeys should not overwrite existing key files")
|
@DisplayName("generateKeys should not overwrite existing key files")
|
||||||
void generateKeys_shouldNotOverwriteExistingKeys(@TempDir Path tempDir) {
|
void generateKeys_shouldNotOverwriteExistingKeys(@TempDir Path tempDir) throws SecurityError {
|
||||||
File keyDir = tempDir.toFile();
|
File keyDir = tempDir.toFile();
|
||||||
|
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
|
|
@ -129,10 +129,10 @@ class SecureCompletionClientTest {
|
||||||
|
|
||||||
// ── Key Loading Tests ─────────────────────────────────────────────
|
// ── Key Loading Tests ─────────────────────────────────────────────
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("loadKeys should load plaintext private key from file")
|
@DisplayName("loadKeys should load plaintext private key from file")
|
||||||
void loadKeys_plaintext_shouldLoadPrivateKey(@TempDir Path tempDir) {
|
void loadKeys_plaintext_shouldLoadPrivateKey(@TempDir Path tempDir) throws Exception {
|
||||||
File keyDir = tempDir.toFile();
|
File keyDir = tempDir.toFile();
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(true, keyDir.getAbsolutePath(), null);
|
client.generateKeys(true, keyDir.getAbsolutePath(), null);
|
||||||
|
|
@ -151,10 +151,10 @@ class SecureCompletionClientTest {
|
||||||
"Loaded key should have same size as original");
|
"Loaded key should have same size as original");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("loadKeys should load encrypted private key with correct password")
|
@DisplayName("loadKeys should load encrypted private key with correct password")
|
||||||
void loadKeys_encrypted_correctPassword_shouldLoadPrivateKey(@TempDir Path tempDir) {
|
void loadKeys_encrypted_correctPassword_shouldLoadPrivateKey(@TempDir Path tempDir) throws Exception {
|
||||||
File keyDir = tempDir.toFile();
|
File keyDir = tempDir.toFile();
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
||||||
|
|
@ -176,30 +176,33 @@ class SecureCompletionClientTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("loadKeys should handle wrong password gracefully")
|
@DisplayName("loadKeys should throw SecurityError for wrong password")
|
||||||
void loadKeys_encrypted_wrongPassword_shouldHandleGracefully(@TempDir Path tempDir) {
|
void loadKeys_encrypted_wrongPassword_shouldThrowSecurityError(@TempDir Path tempDir) throws SecurityError {
|
||||||
File keyDir = tempDir.toFile();
|
File keyDir = tempDir.toFile();
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
||||||
|
|
||||||
SecureCompletionClient loadClient = new SecureCompletionClient();
|
SecureCompletionClient loadClient = new SecureCompletionClient();
|
||||||
|
|
||||||
assertDoesNotThrow(() ->
|
SecurityError error = assertThrows(SecurityError.class, () ->
|
||||||
loadClient.loadKeys(
|
loadClient.loadKeys(
|
||||||
new File(keyDir, Constants.DEFAULT_PRIVATE_KEY_FILE).getAbsolutePath(),
|
new File(keyDir, Constants.DEFAULT_PRIVATE_KEY_FILE).getAbsolutePath(),
|
||||||
null,
|
null,
|
||||||
"wrong-password"
|
"wrong-password"
|
||||||
),
|
),
|
||||||
"Wrong password should not throw exception"
|
"Wrong password should throw SecurityError"
|
||||||
);
|
);
|
||||||
|
|
||||||
|
assertTrue(error.getMessage().contains("decrypt") || error.getMessage().contains("password"),
|
||||||
|
"Error message should mention decryption or password");
|
||||||
}
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("loadKeys should throw exception for non-existent file")
|
@DisplayName("loadKeys should throw exception for non-existent file")
|
||||||
void loadKeys_nonExistentFile_shouldThrowException() {
|
void loadKeys_nonExistentFile_shouldThrowException() {
|
||||||
SecureCompletionClient loadClient = new SecureCompletionClient();
|
SecureCompletionClient loadClient = new SecureCompletionClient();
|
||||||
|
|
||||||
RuntimeException error = assertThrows(RuntimeException.class, () ->
|
SecurityError error = assertThrows(SecurityError.class, () ->
|
||||||
loadClient.loadKeys("/non/existent/path/private_key.pem", null, null));
|
loadClient.loadKeys("/non/existent/path/private_key.pem", null, null));
|
||||||
|
|
||||||
assertTrue(error.getMessage().contains("not found"),
|
assertTrue(error.getMessage().contains("not found"),
|
||||||
|
|
@ -210,7 +213,7 @@ class SecureCompletionClientTest {
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
@DisplayName("validateRsaKey should accept valid 4096-bit key")
|
@DisplayName("validateRsaKey should accept valid 4096-bit key")
|
||||||
void validateRsaKey_validKey_shouldPass() {
|
void validateRsaKey_validKey_shouldPass() throws SecurityError {
|
||||||
SecureCompletionClient client = new SecureCompletionClient();
|
SecureCompletionClient client = new SecureCompletionClient();
|
||||||
client.generateKeys(false);
|
client.generateKeys(false);
|
||||||
PrivateKey key = client.getPrivateKey();
|
PrivateKey key = client.getPrivateKey();
|
||||||
|
|
@ -264,7 +267,7 @@ class SecureCompletionClientTest {
|
||||||
@Test
|
@Test
|
||||||
@Execution(ExecutionMode.SAME_THREAD)
|
@Execution(ExecutionMode.SAME_THREAD)
|
||||||
@DisplayName("Full roundtrip: generate, save, load should produce same key")
|
@DisplayName("Full roundtrip: generate, save, load should produce same key")
|
||||||
void roundtrip_generateSaveLoad_shouldProduceSameKey(@TempDir Path tempDir) {
|
void roundtrip_generateSaveLoad_shouldProduceSameKey(@TempDir Path tempDir) throws Exception {
|
||||||
File keyDir = tempDir.toFile();
|
File keyDir = tempDir.toFile();
|
||||||
|
|
||||||
// Generate and save
|
// Generate and save
|
||||||
|
|
@ -407,10 +410,6 @@ class SecureCompletionClientTest {
|
||||||
tempClient.generateKeys(false);
|
tempClient.generateKeys(false);
|
||||||
PrivateKey originalKey = tempClient.getPrivateKey();
|
PrivateKey originalKey = tempClient.getPrivateKey();
|
||||||
|
|
||||||
String pem = "-----BEGIN PRIVATE KEY-----\n " +
|
|
||||||
originalKey.getEncoded().length + "lines\n" +
|
|
||||||
"-----END PRIVATE KEY-----";
|
|
||||||
|
|
||||||
String formattedPem = ai.nomyo.util.PEMConverter.toPEM(originalKey.getEncoded(), true);
|
String formattedPem = ai.nomyo.util.PEMConverter.toPEM(originalKey.getEncoded(), true);
|
||||||
String pemWithWhitespace = formattedPem.replace("\n", "\n ");
|
String pemWithWhitespace = formattedPem.replace("\n", "\n ");
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue