Finish functionality
This commit is contained in:
parent
b6af1c9792
commit
89d5282b0f
9 changed files with 583 additions and 133 deletions
|
|
@ -1,6 +1,7 @@
|
|||
package ai.nomyo;
|
||||
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.GsonBuilder;
|
||||
import com.google.gson.annotations.SerializedName;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
|
@ -12,6 +13,8 @@ import lombok.Setter;
|
|||
@Getter
|
||||
public class EncryptedRequest {
|
||||
|
||||
private static final Gson GSON = new GsonBuilder().create();
|
||||
|
||||
// Getters and Setters
|
||||
@SerializedName("version")
|
||||
private String version;
|
||||
|
|
@ -55,7 +58,7 @@ public class EncryptedRequest {
|
|||
}
|
||||
}
|
||||
|
||||
public String toJson() {
|
||||
return new Gson().toJson(this);
|
||||
public String toJson() {
|
||||
return GSON.toJson(this);
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,5 +1,10 @@
|
|||
package ai.nomyo;
|
||||
|
||||
import ai.nomyo.errors.APIConnectionError;
|
||||
import ai.nomyo.errors.SecurityError;
|
||||
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
|
|
@ -8,17 +13,22 @@ import java.util.concurrent.ExecutionException;
|
|||
public class Main {
|
||||
|
||||
static void main() {
|
||||
SecureCompletionClient secureCompletionClient = new SecureCompletionClient();
|
||||
//secureCompletionClient.generateKeys(true, "client_keys", "pokemon");
|
||||
//secureCompletionClient.loadKeys("client_keys/private_key.pem", "pokemon");
|
||||
SecureChatCompletion secureChatCompletion = new SecureChatCompletion( Constants.DEFAULT_BASE_URL, "NOMYO_AI_E2EE_INFERENCE");
|
||||
List<Map<String, Object>> messages = List.of(
|
||||
Map.of("role", "user", "content", "Hello! How are you today?")
|
||||
);
|
||||
|
||||
Map<String, Object> kwargs = Map.of(
|
||||
"security_tier", "standard",
|
||||
"temperature", 0.7
|
||||
);
|
||||
|
||||
try {
|
||||
System.out.println(secureCompletionClient.fetchServerPublicKey().get());
|
||||
} catch (InterruptedException | ExecutionException e) {
|
||||
throw new RuntimeException(e);
|
||||
}
|
||||
var response = secureChatCompletion.create(
|
||||
"Qwen/Qwen3-0.6B",
|
||||
messages,
|
||||
kwargs);
|
||||
|
||||
System.out.println(response.toString());
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
|||
|
|
@ -3,8 +3,12 @@ package ai.nomyo;
|
|||
import ai.nomyo.errors.*;
|
||||
import lombok.Getter;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.UUID;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
|
||||
/**
|
||||
* High-level OpenAI-compatible entrypoint with automatic hybrid encryption (AES-256-GCM + RSA-4096).
|
||||
|
|
@ -23,6 +27,14 @@ public class SecureChatCompletion {
|
|||
this(Constants.DEFAULT_BASE_URL, false, null, true, null, Constants.DEFAULT_MAX_RETRIES);
|
||||
}
|
||||
|
||||
public SecureChatCompletion(String baseUrl, String apiKey) {
|
||||
this(baseUrl, false, apiKey, true, null, Constants.DEFAULT_MAX_RETRIES);
|
||||
}
|
||||
|
||||
public SecureChatCompletion(String baseUrl) {
|
||||
this(baseUrl, false, null, true, null, Constants.DEFAULT_MAX_RETRIES);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param baseUrl NOMYO Router base URL (HTTPS enforced unless {@code allowHttp})
|
||||
* @param allowHttp permit {@code http://} URLs (development only)
|
||||
|
|
@ -57,15 +69,85 @@ public class SecureChatCompletion {
|
|||
* @throws ServiceUnavailableError HTTP 503
|
||||
* @throws APIError other errors
|
||||
*/
|
||||
@SuppressWarnings("unchecked")
|
||||
public Map<String, Object> create(String model, List<Map<String, Object>> messages, Map<String, Object> kwargs) {
|
||||
// Build payload from model, messages, and kwargs
|
||||
// Validate stream is false
|
||||
// Validate securityTier if provided
|
||||
// Use per-call api_key override if provided, else instance apiKey
|
||||
// Create temp client if baseUrl override provided
|
||||
// Validate required parameters
|
||||
if (model == null || model.isEmpty()) {
|
||||
throw new IllegalArgumentException("model is required");
|
||||
}
|
||||
if (messages == null || messages.isEmpty()) {
|
||||
throw new IllegalArgumentException("messages is required and cannot be empty");
|
||||
}
|
||||
|
||||
// Build payload
|
||||
Map<String, Object> payload = new HashMap<>();
|
||||
payload.put("model", model);
|
||||
payload.put("messages", messages);
|
||||
|
||||
// Add kwargs
|
||||
if (kwargs != null) {
|
||||
// Check for stream parameter
|
||||
if (kwargs.containsKey("stream")) {
|
||||
Object streamValue = kwargs.get("stream");
|
||||
boolean stream = streamValue instanceof Boolean ? (Boolean) streamValue : Boolean.parseBoolean(streamValue.toString());
|
||||
if (stream) {
|
||||
throw new IllegalArgumentException("Streaming is not supported");
|
||||
}
|
||||
}
|
||||
|
||||
// Check for security_tier
|
||||
if (kwargs.containsKey("security_tier")) {
|
||||
Object tier = kwargs.get("security_tier");
|
||||
if (tier != null && !Constants.VALID_SECURITY_TIERS.contains(tier.toString())) {
|
||||
throw new IllegalArgumentException(
|
||||
"Invalid security_tier: '" + tier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS);
|
||||
}
|
||||
}
|
||||
|
||||
payload.putAll(kwargs);
|
||||
}
|
||||
|
||||
// Determine API key (per-call override or instance key)
|
||||
String apiKey = this.apiKey;
|
||||
if (kwargs != null && kwargs.containsKey("api_key")) {
|
||||
Object key = kwargs.get("api_key");
|
||||
if (key != null) {
|
||||
apiKey = key.toString();
|
||||
}
|
||||
}
|
||||
|
||||
// Determine security tier
|
||||
String securityTier = null;
|
||||
if (kwargs != null && kwargs.containsKey("security_tier")) {
|
||||
securityTier = kwargs.get("security_tier").toString();
|
||||
}
|
||||
|
||||
// Generate payload ID
|
||||
String payloadId = UUID.randomUUID().toString();
|
||||
|
||||
// Send secure request
|
||||
// Return decrypted response map
|
||||
return null;
|
||||
try {
|
||||
return client.sendSecureRequest(payload, payloadId, apiKey, securityTier).get();
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new RuntimeException("Request interrupted", e);
|
||||
} catch (ExecutionException e) {
|
||||
Throwable cause = e.getCause();
|
||||
switch (cause) {
|
||||
case SecurityError securityError -> throw new RuntimeException(cause);
|
||||
case InvalidRequestError invalidRequestError -> throw new RuntimeException(cause);
|
||||
case AuthenticationError authenticationError -> throw new RuntimeException(cause);
|
||||
case ForbiddenError forbiddenError -> throw new RuntimeException(cause);
|
||||
case RateLimitError rateLimitError -> throw new RuntimeException(cause);
|
||||
case ServerError serverError -> throw new RuntimeException(cause);
|
||||
case ServiceUnavailableError serviceUnavailableError -> throw new RuntimeException(cause);
|
||||
case APIError apiError -> throw new RuntimeException(cause);
|
||||
case APIConnectionError apiConnectionError -> throw new RuntimeException(cause);
|
||||
case SecureCompletionClient.ValueError valueError -> throw new IllegalArgumentException(cause);
|
||||
default ->
|
||||
throw new RuntimeException("Request failed: " + cause.getMessage(), cause);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -4,6 +4,7 @@ import ai.nomyo.errors.*;
|
|||
import ai.nomyo.util.PEMConverter;
|
||||
import ai.nomyo.util.Pass2Key;
|
||||
import lombok.Getter;
|
||||
import lombok.Setter;
|
||||
|
||||
import javax.crypto.*;
|
||||
import javax.crypto.spec.GCMParameterSpec;
|
||||
|
|
@ -27,6 +28,10 @@ import java.util.*;
|
|||
import java.util.concurrent.CompletableFuture;
|
||||
import java.util.concurrent.CompletionException;
|
||||
import java.util.concurrent.ExecutionException;
|
||||
import com.google.gson.Gson;
|
||||
import com.google.gson.JsonObject;
|
||||
import com.google.gson.JsonParser;
|
||||
|
||||
import java.util.concurrent.locks.ReentrantLock;
|
||||
|
||||
/**
|
||||
|
|
@ -113,7 +118,7 @@ public class SecureCompletionClient {
|
|||
/**
|
||||
* Generates a 4096-bit RSA key pair (exponent 65537). Saves to disk if {@code saveToFile}.
|
||||
*/
|
||||
public void generateKeys(boolean saveToFile, String keyDir, String password) {
|
||||
public void generateKeys(boolean saveToFile, String keyDir, String password) throws SecurityError {
|
||||
try {
|
||||
KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
|
||||
generator.initialize(new RSAKeyGenParameterSpec(Constants.RSA_KEY_SIZE, BigInteger.valueOf(Constants.RSA_PUBLIC_EXPONENT)));
|
||||
|
|
@ -145,8 +150,8 @@ public class SecureCompletionClient {
|
|||
try {
|
||||
privatePem = Pass2Key.encrypt("AES/GCM/NoPadding", privatePem, password);
|
||||
} catch (NoSuchPaddingException | IllegalBlockSizeException | BadPaddingException |
|
||||
InvalidKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
InvalidKeyException | SecurityError e) {
|
||||
throw new SecurityError("Failed to encrypt private key with password: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
writer.write(privatePem);
|
||||
|
|
@ -167,10 +172,12 @@ public class SecureCompletionClient {
|
|||
this.privateKey = pair.getPrivate();
|
||||
this.publicPemKey = publicPem;
|
||||
|
||||
} catch (SecurityError e) {
|
||||
throw e;
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
throw new RuntimeException("RSA not available: " + e.getMessage(), e);
|
||||
throw new SecurityError("RSA algorithm not available: " + e.getMessage(), e);
|
||||
} catch (InvalidAlgorithmParameterException e) {
|
||||
throw new RuntimeException(e);
|
||||
throw new SecurityError("Invalid RSA key generation parameters: " + e.getMessage(), e);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to save keys: " + e.getMessage(), e);
|
||||
}
|
||||
|
|
@ -179,7 +186,7 @@ public class SecureCompletionClient {
|
|||
/**
|
||||
* Generates a 4096-bit RSA key pair and saves to the default directory.
|
||||
*/
|
||||
public void generateKeys(boolean saveToFile) {
|
||||
public void generateKeys(boolean saveToFile) throws SecurityError {
|
||||
generateKeys(saveToFile, Constants.DEFAULT_KEY_DIR, null);
|
||||
}
|
||||
|
||||
|
|
@ -190,11 +197,12 @@ public class SecureCompletionClient {
|
|||
* @param privateKeyPath private key PEM path
|
||||
* @param publicPemKeyPath optional public key PEM path
|
||||
* @param password optional password for encrypted private key
|
||||
* @throws SecurityError if key file not found, unreadable, or decryption fails
|
||||
*/
|
||||
public void loadKeys(String privateKeyPath, String publicPemKeyPath, String password) {
|
||||
public void loadKeys(String privateKeyPath, String publicPemKeyPath, String password) throws SecurityError {
|
||||
Path keyPath = Path.of(privateKeyPath);
|
||||
if (!Files.exists(keyPath)) {
|
||||
throw new RuntimeException("Private key file not found: " + privateKeyPath);
|
||||
throw new SecurityError("Private key file not found: " + privateKeyPath);
|
||||
}
|
||||
|
||||
String keyContent;
|
||||
|
|
@ -202,35 +210,36 @@ public class SecureCompletionClient {
|
|||
try {
|
||||
keyContent = readFileContent(privateKeyPath);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to read private key file: " + e.getMessage(), e);
|
||||
throw new SecurityError("Failed to read private key file: " + e.getMessage(), e);
|
||||
}
|
||||
|
||||
try {
|
||||
keyContent = Pass2Key.decrypt("AES/GCM/NoPadding", keyContent, password);
|
||||
} catch (NoSuchPaddingException | NoSuchAlgorithmException | BadPaddingException |
|
||||
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException e) {
|
||||
System.out.println("Wrong password!");
|
||||
return;
|
||||
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException | SecurityError e) {
|
||||
throw new SecurityError("Failed to decrypt private key with provided password: " + e.getMessage(), e);
|
||||
}
|
||||
} else {
|
||||
try {
|
||||
keyContent = readFileContent(privateKeyPath);
|
||||
} catch (IOException e) {
|
||||
throw new RuntimeException("Failed to read private key file: " + e.getMessage(), e);
|
||||
throw new SecurityError("Failed to read private key file: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
this.privateKey = Pass2Key.convertStringToPrivateKey(keyContent);
|
||||
} catch (Exception e) {
|
||||
throw new RuntimeException("Failed to load private key: " + e.getMessage(), e);
|
||||
throw new SecurityError("Failed to load private key: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Loads RSA private key from disk, deriving public key.
|
||||
*
|
||||
* @throws SecurityError if key file not found, unreadable, or decryption fails
|
||||
*/
|
||||
public void loadKeys(String privateKeyPath, String password) {
|
||||
public void loadKeys(String privateKeyPath, String password) throws SecurityError {
|
||||
loadKeys(privateKeyPath, null, password);
|
||||
}
|
||||
|
||||
|
|
@ -250,12 +259,12 @@ public class SecureCompletionClient {
|
|||
URI url;
|
||||
|
||||
try {
|
||||
url = new URI(this.routerUrl + "/pki/public_key");
|
||||
url = new URI(this.routerUrl + Constants.PKI_PUBLIC_KEY_PATH);
|
||||
} catch (URISyntaxException e) {
|
||||
return CompletableFuture.failedFuture(new CompletionException("Invalid URI: " + e.getMessage(), e));
|
||||
return CompletableFuture.failedFuture(new CompletionException(new APIConnectionError("Invalid URI: " + e.getMessage(), e)));
|
||||
}
|
||||
|
||||
HttpRequest request = HttpRequest.newBuilder(url).timeout(Duration.of(60, ChronoUnit.SECONDS)).GET().build();
|
||||
HttpRequest request = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).GET().build();
|
||||
|
||||
return this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> {
|
||||
if (response.statusCode() != 200) {
|
||||
|
|
@ -264,7 +273,7 @@ public class SecureCompletionClient {
|
|||
return response.body();
|
||||
}).thenApply(body -> {
|
||||
if (!PEMConverter.validatePEM(body)) {
|
||||
throw new CompletionException(new InvalidKeyException("PEM key had invalid format"));
|
||||
throw new CompletionException(new SecurityError("Server returned invalid PEM key format (possible MITM attack)"));
|
||||
}
|
||||
return body;
|
||||
});
|
||||
|
|
@ -278,7 +287,37 @@ public class SecureCompletionClient {
|
|||
* @throws SecurityError if encryption fails or keys not loaded
|
||||
*/
|
||||
public CompletableFuture<byte[]> encryptPayload(Map<String, Object> payload) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
return CompletableFuture.supplyAsync(() -> {
|
||||
try {
|
||||
ensureKeys(null);
|
||||
} catch (SecurityError e) {
|
||||
throw new CompletionException(new SecurityError("Failed to ensure keys are initialized: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
if (this.privateKey == null) {
|
||||
throw new CompletionException(new SecurityError("Private key not available for encryption"));
|
||||
}
|
||||
|
||||
// Generate AES key
|
||||
KeyGenerator keyGen;
|
||||
|
||||
try {
|
||||
keyGen = KeyGenerator.getInstance("AES");
|
||||
keyGen.init(Constants.AES_KEY_SIZE * 8);
|
||||
} catch (NoSuchAlgorithmException e) {
|
||||
throw new CompletionException(new SecurityError("AES key generation not available: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
Key aesKey = keyGen.generateKey();
|
||||
|
||||
// Serialize payload to JSON
|
||||
Gson gson = new Gson();
|
||||
String payloadJson = gson.toJson(payload);
|
||||
byte[] payloadBytes = payloadJson.getBytes(StandardCharsets.UTF_8);
|
||||
|
||||
// Encrypt
|
||||
return doEncrypt(payloadBytes, aesKey).join();
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -287,16 +326,16 @@ public class SecureCompletionClient {
|
|||
public CompletableFuture<byte[]> doEncrypt(byte[] payloadBytes, Key aesKey) {
|
||||
return CompletableFuture.supplyAsync(() -> {
|
||||
SecureRandom random = new SecureRandom();
|
||||
byte[] nonce = new byte[12];
|
||||
byte[] nonce = new byte[Constants.GCM_NONCE_SIZE];
|
||||
random.nextBytes(nonce);
|
||||
|
||||
Cipher cipher = null;
|
||||
try {
|
||||
cipher = Cipher.getInstance("AES/GCM/NoPadding");
|
||||
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(128, nonce));
|
||||
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(Constants.GCM_TAG_SIZE * Byte.SIZE, nonce));
|
||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException |
|
||||
InvalidKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
throw new RuntimeException(new SecurityError("AES-GCM cipher initialization failed: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
byte[] ciphertext;
|
||||
|
|
@ -304,66 +343,63 @@ public class SecureCompletionClient {
|
|||
try {
|
||||
ciphertext = cipher.doFinal(payloadBytes);
|
||||
} catch (IllegalBlockSizeException | BadPaddingException e) {
|
||||
throw new RuntimeException(e);
|
||||
throw new RuntimeException(new SecurityError("AES-GCM encryption failed: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
String serverPEM;
|
||||
|
||||
try {
|
||||
serverPEM = fetchServerPublicKey().get();
|
||||
} catch (InterruptedException | ExecutionException e) {
|
||||
throw new RuntimeException(e);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new RuntimeException(new SecurityError("Encryption interrupted while fetching server public key", e));
|
||||
} catch (ExecutionException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof SecurityError) {
|
||||
throw new RuntimeException((SecurityError) cause);
|
||||
}
|
||||
throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause));
|
||||
}
|
||||
|
||||
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM).getBytes());
|
||||
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM));
|
||||
|
||||
PublicKey serverPublicKey;
|
||||
|
||||
try {
|
||||
serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec);
|
||||
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
|
||||
throw new RuntimeException(e);
|
||||
throw new RuntimeException(new SecurityError("RSA key factory failed to parse server public key: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
Cipher rsa;
|
||||
|
||||
byte[] enryptedAESKey = aesKey.getEncoded();
|
||||
|
||||
try {
|
||||
rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
|
||||
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey);
|
||||
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) {
|
||||
throw new RuntimeException(e);
|
||||
throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
byte[] encryptedAESKey;
|
||||
try {
|
||||
rsa.doFinal(enryptedAESKey);
|
||||
encryptedAESKey = rsa.doFinal(aesKey.getEncoded());
|
||||
} catch (IllegalBlockSizeException | BadPaddingException e) {
|
||||
throw new RuntimeException(e);
|
||||
throw new RuntimeException(new SecurityError("RSA-OAEP key wrapping failed: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - (128 / Byte.SIZE), ciphertext.length);
|
||||
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length);
|
||||
|
||||
EncryptedRequest request = new EncryptedRequest();
|
||||
|
||||
request.setVersion("1.0");
|
||||
request.setAlgorithm("hybrid-aes256-rsa4096");
|
||||
request.setVersion(Constants.PROTOCOL_VERSION);
|
||||
request.setAlgorithm(Constants.HYBRID_ALGORITHM);
|
||||
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(ciphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
|
||||
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(enryptedAESKey));
|
||||
request.setKeyAlgorithm("RSA-OAEP-SHA256");
|
||||
request.setPayloadAlgorithm("AES-256-GCM");
|
||||
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(encryptedAESKey));
|
||||
request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM);
|
||||
request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM);
|
||||
|
||||
return request.toJson().getBytes(StandardCharsets.UTF_8);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypts server response.
|
||||
*/
|
||||
public CompletableFuture<Map<String, Object>> decryptResponse(byte[] encryptedResponse, String payloadId) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
}
|
||||
|
||||
/**
|
||||
* encrypt → POST {routerUrl}/v1/chat/secure_completion → retry → decrypt → return.
|
||||
* <p>Headers: Content-Type=octet-stream, X-Payload-ID, X-Public-Key, Authorization (Bearer), X-Security-Tier.
|
||||
|
|
@ -388,7 +424,210 @@ public class SecureCompletionClient {
|
|||
* @throws APIError other errors
|
||||
*/
|
||||
public CompletableFuture<Map<String, Object>> sendSecureRequest(Map<String, Object> payload, String payloadId, String apiKey, String securityTier) {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
return CompletableFuture.supplyAsync(() -> {
|
||||
// Validate security tier if provided
|
||||
if (securityTier != null && !Constants.VALID_SECURITY_TIERS.contains(securityTier)) {
|
||||
throw new CompletionException(new ValueError(
|
||||
"Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS));
|
||||
}
|
||||
|
||||
try {
|
||||
ensureKeys(null);
|
||||
} catch (SecurityError e) {
|
||||
throw new CompletionException(new SecurityError("Failed to ensure keys: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
// Step 1: Encrypt payload
|
||||
byte[] encryptedPayload;
|
||||
try {
|
||||
encryptedPayload = encryptPayload(payload).get();
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new CompletionException(new APIConnectionError("Encryption interrupted"));
|
||||
} catch (ExecutionException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof SecurityError) {
|
||||
throw new CompletionException((SecurityError) cause);
|
||||
}
|
||||
throw new CompletionException(new SecurityError("Encryption failed: " + cause.getMessage(), cause));
|
||||
}
|
||||
|
||||
// Step 2: Prepare headers
|
||||
URI url;
|
||||
try {
|
||||
url = new URI(this.routerUrl + Constants.SECURE_COMPLETION_PATH);
|
||||
} catch (URISyntaxException e) {
|
||||
throw new CompletionException(new APIConnectionError("Invalid URL: " + e.getMessage(), e));
|
||||
}
|
||||
|
||||
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url)
|
||||
.timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS))
|
||||
.header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM)
|
||||
.header(Constants.HEADER_PAYLOAD_ID, payloadId)
|
||||
.header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey))
|
||||
.POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload));
|
||||
|
||||
if (apiKey != null && !apiKey.isEmpty()) {
|
||||
requestBuilder.header("Authorization", Constants.AUTHORIZATION_BEARER_PREFIX + apiKey);
|
||||
}
|
||||
if (securityTier != null) {
|
||||
requestBuilder.header(Constants.HEADER_SECURITY_TIER, securityTier);
|
||||
}
|
||||
|
||||
HttpRequest request = requestBuilder.build();
|
||||
|
||||
// Step 3: Send request with retry
|
||||
Exception lastExc = new APIConnectionError("Request failed");
|
||||
int retryableCodes = 0;
|
||||
for (int attempt = 0; attempt <= this.maxRetries; attempt++) {
|
||||
if (attempt > 0) {
|
||||
long delay = (long) Math.pow(2, attempt - 1);
|
||||
try {
|
||||
Thread.sleep(delay * 1000);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new CompletionException(new APIConnectionError("Retry interrupted"));
|
||||
}
|
||||
}
|
||||
|
||||
try {
|
||||
HttpResponse<byte[]> response = this.httpClient.send(request, HttpResponse.BodyHandlers.ofByteArray());
|
||||
int statusCode = response.statusCode();
|
||||
|
||||
if (statusCode == 200) {
|
||||
// Step 4: Decrypt response
|
||||
return decryptResponse(response.body(), payloadId).get();
|
||||
} else if (statusCode == 400) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) errorDetails = parsed;
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Unknown error";
|
||||
throw new CompletionException(new InvalidRequestError("Bad request: " + detail, 400, errorDetails));
|
||||
} else if (statusCode == 401) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) errorDetails = parsed;
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Invalid API key or authentication failed";
|
||||
throw new CompletionException(new AuthenticationError(detail, 401, errorDetails));
|
||||
} else if (statusCode == 403) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) errorDetails = parsed;
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Model not allowed for the requested security tier";
|
||||
throw new CompletionException(new ForbiddenError("Forbidden: " + detail, 403, errorDetails));
|
||||
} else if (statusCode == 404) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) errorDetails = parsed;
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Secure inference not enabled";
|
||||
throw new CompletionException(new APIError("Endpoint not found: " + detail, 404, errorDetails));
|
||||
} else if (Constants.RETRYABLE_STATUS_CODES.contains(statusCode)) {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> error = Map.of();
|
||||
String detailMsg = "unknown";
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) {
|
||||
error = parsed;
|
||||
if (error.containsKey("detail")) detailMsg = error.get("detail").toString();
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
|
||||
if (statusCode == 429) {
|
||||
lastExc = new RateLimitError("Rate limit exceeded: " + detailMsg, 429, error);
|
||||
} else if (statusCode == 500) {
|
||||
lastExc = new ServerError("Server error: " + detailMsg, 500, error);
|
||||
} else if (statusCode == 503) {
|
||||
lastExc = new ServiceUnavailableError("Service unavailable: " + detailMsg, 503, error);
|
||||
} else {
|
||||
lastExc = new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, error);
|
||||
}
|
||||
|
||||
if (attempt < this.maxRetries) {
|
||||
continue;
|
||||
}
|
||||
throw new CompletionException(lastExc);
|
||||
} else {
|
||||
String body = new String(response.body(), StandardCharsets.UTF_8);
|
||||
Map<String, Object> errorDetails = Map.of();
|
||||
String detailMsg = "unknown";
|
||||
try {
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
|
||||
if (parsed != null) {
|
||||
errorDetails = parsed;
|
||||
if (errorDetails.containsKey("detail")) detailMsg = errorDetails.get("detail").toString();
|
||||
}
|
||||
} catch (Exception ignored) {
|
||||
}
|
||||
throw new CompletionException(new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, errorDetails));
|
||||
}
|
||||
} catch (CompletionException e) {
|
||||
Throwable cause = e.getCause();
|
||||
if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError ||
|
||||
cause instanceof ForbiddenError || cause instanceof SecurityError) {
|
||||
throw e;
|
||||
}
|
||||
if (cause instanceof RateLimitError || cause instanceof ServerError || cause instanceof ServiceUnavailableError) {
|
||||
lastExc = (Exception) cause;
|
||||
if (attempt < this.maxRetries) {
|
||||
continue;
|
||||
}
|
||||
throw new CompletionException(lastExc);
|
||||
}
|
||||
if (cause instanceof APIError apiError) {
|
||||
if (Constants.RETRYABLE_STATUS_CODES.contains(apiError.getStatusCode())) {
|
||||
lastExc = apiError;
|
||||
if (attempt < this.maxRetries) {
|
||||
continue;
|
||||
}
|
||||
throw new CompletionException(lastExc);
|
||||
}
|
||||
throw e;
|
||||
}
|
||||
lastExc = new APIConnectionError("Failed to connect to router: " + e.getMessage());
|
||||
if (attempt < this.maxRetries) {
|
||||
continue;
|
||||
}
|
||||
throw new CompletionException(lastExc);
|
||||
} catch (InterruptedException e) {
|
||||
Thread.currentThread().interrupt();
|
||||
throw new CompletionException(new APIConnectionError("Request interrupted"));
|
||||
} catch (ExecutionException e) {
|
||||
throw new CompletionException(new APIConnectionError("Request failed: " + e.getMessage(), e));
|
||||
} catch (IOException e) {
|
||||
lastExc = new APIConnectionError("IO error: " + e.getMessage(), e);
|
||||
if (attempt < this.maxRetries) {
|
||||
continue;
|
||||
}
|
||||
throw new CompletionException(lastExc);
|
||||
}
|
||||
}
|
||||
|
||||
throw new CompletionException(lastExc);
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -410,7 +649,7 @@ public class SecureCompletionClient {
|
|||
*
|
||||
* @param keyDir key directory or {@code null} for ephemeral
|
||||
*/
|
||||
public void ensureKeys(String keyDir) {
|
||||
public void ensureKeys(String keyDir) throws SecurityError {
|
||||
if (keysInitialized) return;
|
||||
keyInitLock.lock();
|
||||
try {
|
||||
|
|
@ -458,21 +697,19 @@ public class SecureCompletionClient {
|
|||
* Maps HTTP status code to exception (200→null).
|
||||
*/
|
||||
public Exception mapHttpStatus(int statusCode, String responseBody) {
|
||||
String message = responseBody != null ? responseBody : "no body";
|
||||
Map<String, Object> errorDetails = responseBody != null ? Map.of("response", responseBody) : Map.of();
|
||||
return switch (statusCode) {
|
||||
case 200 -> null;
|
||||
case 400 ->
|
||||
new InvalidRequestError("Invalid request: " + (responseBody != null ? responseBody : "no body"));
|
||||
case 401 ->
|
||||
new AuthenticationError("Authentication failed: " + (responseBody != null ? responseBody : "no body"));
|
||||
case 403 -> new ForbiddenError("Access forbidden: " + (responseBody != null ? responseBody : "no body"));
|
||||
case 404 -> new APIError("Not found: " + (responseBody != null ? responseBody : "no body"));
|
||||
case 429 -> new RateLimitError("Rate limit exceeded: " + (responseBody != null ? responseBody : "no body"));
|
||||
case 500 -> new ServerError("Internal server error: " + (responseBody != null ? responseBody : "no body"));
|
||||
case 503 ->
|
||||
new ServiceUnavailableError("Service unavailable: " + (responseBody != null ? responseBody : "no body"));
|
||||
case 502, 504 -> new APIError("Gateway error: " + (responseBody != null ? responseBody : "no body"));
|
||||
default ->
|
||||
new APIError("Unexpected status " + statusCode + ": " + (responseBody != null ? responseBody : "no body"));
|
||||
case 400 -> new InvalidRequestError("Invalid request: " + message, statusCode, errorDetails);
|
||||
case 401 -> new AuthenticationError("Authentication failed: " + message, statusCode, errorDetails);
|
||||
case 403 -> new ForbiddenError("Access forbidden: " + message, statusCode, errorDetails);
|
||||
case 404 -> new APIError("Not found: " + message, statusCode, errorDetails);
|
||||
case 429 -> new RateLimitError("Rate limit exceeded: " + message, statusCode, errorDetails);
|
||||
case 500 -> new ServerError("Internal server error: " + message, statusCode, errorDetails);
|
||||
case 503 -> new ServiceUnavailableError("Service unavailable: " + message, statusCode, errorDetails);
|
||||
case 502, 504 -> new APIError("Gateway error: " + message, statusCode, errorDetails);
|
||||
default -> new APIError("Unexpected status " + statusCode + ": " + message, statusCode, errorDetails);
|
||||
};
|
||||
}
|
||||
|
||||
|
|
@ -484,9 +721,138 @@ public class SecureCompletionClient {
|
|||
}
|
||||
|
||||
/**
|
||||
* Delegates to resource cleanup (stub).
|
||||
* Closes the HTTP client and clears keys from memory.
|
||||
*/
|
||||
public void close() {
|
||||
throw new UnsupportedOperationException("Not yet implemented");
|
||||
this.httpClient.close();
|
||||
this.privateKey = null;
|
||||
this.publicPemKey = null;
|
||||
this.keysInitialized = false;
|
||||
}
|
||||
|
||||
/**
|
||||
* Decrypts server response.
|
||||
*/
|
||||
public CompletableFuture<Map<String, Object>> decryptResponse(byte[] encryptedResponse, String payloadId) {
|
||||
return CompletableFuture.supplyAsync(() -> {
|
||||
if (encryptedResponse == null || encryptedResponse.length == 0) {
|
||||
throw new CompletionException(new ValueError("Empty encrypted response"));
|
||||
}
|
||||
|
||||
String jsonResponse;
|
||||
try {
|
||||
jsonResponse = new String(encryptedResponse, StandardCharsets.UTF_8);
|
||||
} catch (Exception e) {
|
||||
throw new CompletionException(new ValueError("Invalid encrypted package format: not valid UTF-8"));
|
||||
}
|
||||
|
||||
Gson gson = new Gson();
|
||||
JsonObject packageJson;
|
||||
try {
|
||||
packageJson = JsonParser.parseString(jsonResponse).getAsJsonObject();
|
||||
} catch (Exception e) {
|
||||
throw new CompletionException(new ValueError("Invalid encrypted package format: malformed JSON"));
|
||||
}
|
||||
|
||||
// Validate required fields
|
||||
String[] requiredFields = {"version", "algorithm", "encrypted_payload", "encrypted_aes_key"};
|
||||
for (String field : requiredFields) {
|
||||
if (!packageJson.has(field)) {
|
||||
throw new CompletionException(new ValueError("Missing required fields in encrypted package: " + field));
|
||||
}
|
||||
}
|
||||
|
||||
// Validate version and algorithm
|
||||
String version = packageJson.get("version").getAsString();
|
||||
String algorithm = packageJson.get("algorithm").getAsString();
|
||||
|
||||
if (!Constants.PROTOCOL_VERSION.equals(version)) {
|
||||
throw new CompletionException(new ValueError(
|
||||
"Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'"));
|
||||
}
|
||||
if (!Constants.HYBRID_ALGORITHM.equals(algorithm)) {
|
||||
throw new CompletionException(new ValueError(
|
||||
"Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'"));
|
||||
}
|
||||
|
||||
// Validate encrypted_payload structure
|
||||
JsonObject encryptedPayload = packageJson.get("encrypted_payload").getAsJsonObject();
|
||||
String[] payloadRequired = {"ciphertext", "nonce", "tag"};
|
||||
for (String field : payloadRequired) {
|
||||
if (!encryptedPayload.has(field)) {
|
||||
throw new CompletionException(new ValueError("Missing fields in encrypted_payload: " + field));
|
||||
}
|
||||
}
|
||||
|
||||
// Guard: private key must be initialized
|
||||
if (this.privateKey == null) {
|
||||
throw new CompletionException(new SecurityError("Private key not initialized. Call generateKeys() or loadKeys() first."));
|
||||
}
|
||||
|
||||
try {
|
||||
// Decrypt AES key with private key
|
||||
byte[] encryptedAESKey = Base64.getDecoder().decode(packageJson.get("encrypted_aes_key").getAsString());
|
||||
|
||||
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
|
||||
rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
|
||||
byte[] aesKeyBytes = rsaCipher.doFinal(encryptedAESKey);
|
||||
SecretKeySpec aesKey = new SecretKeySpec(aesKeyBytes, "AES");
|
||||
|
||||
// Decrypt payload
|
||||
byte[] ciphertext = Base64.getDecoder().decode(encryptedPayload.get("ciphertext").getAsString());
|
||||
byte[] nonce = Base64.getDecoder().decode(encryptedPayload.get("nonce").getAsString());
|
||||
byte[] tag = Base64.getDecoder().decode(encryptedPayload.get("tag").getAsString());
|
||||
|
||||
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
|
||||
aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
|
||||
|
||||
// Combine ciphertext (without tag) and tag for decryption
|
||||
byte[] ciphertextWithTag = new byte[ciphertext.length + tag.length];
|
||||
System.arraycopy(ciphertext, 0, ciphertextWithTag, 0, ciphertext.length);
|
||||
System.arraycopy(tag, 0, ciphertextWithTag, ciphertext.length, tag.length);
|
||||
|
||||
byte[] plaintextBytes = aesCipher.doFinal(ciphertextWithTag);
|
||||
|
||||
// Parse JSON response
|
||||
Map<String, Object> response;
|
||||
try {
|
||||
Object parsed = gson.fromJson(new String(plaintextBytes, StandardCharsets.UTF_8), Object.class);
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> resultMap = (Map<String, Object>) parsed;
|
||||
response = resultMap != null ? resultMap : new HashMap<>();
|
||||
} catch (Exception e) {
|
||||
throw new CompletionException(new ValueError("Decrypted response is not valid JSON: " + e.getMessage()));
|
||||
}
|
||||
|
||||
// Add metadata
|
||||
if (!response.containsKey("_metadata")) {
|
||||
response.put("_metadata", new HashMap<String, Object>());
|
||||
}
|
||||
@SuppressWarnings("unchecked")
|
||||
Map<String, Object> metadata = (Map<String, Object>) response.get("_metadata");
|
||||
metadata.put("payload_id", payloadId);
|
||||
metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null);
|
||||
metadata.put("is_encrypted", true);
|
||||
metadata.put("encryption_algorithm", Constants.HYBRID_ALGORITHM);
|
||||
|
||||
return response;
|
||||
|
||||
} catch (Exception e) {
|
||||
throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed"));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
/**
|
||||
* Error class for invalid argument/value errors (maps to Python ValueError).
|
||||
*/
|
||||
public static class ValueError extends Exception {
|
||||
public ValueError(String message) {
|
||||
super(message);
|
||||
}
|
||||
|
||||
public ValueError(String message, Throwable cause) {
|
||||
super(message, cause);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -54,22 +54,6 @@ public final class SecureMemory {
|
|||
return secureByteArray(data, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #secureByteArray(byte[])} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public static SecureBuffer secureBytes(byte[] data, boolean lock) {
|
||||
return new SecureBuffer(data, lock);
|
||||
}
|
||||
|
||||
/**
|
||||
* @deprecated Use {@link #secureByteArray(byte[])} instead.
|
||||
*/
|
||||
@Deprecated
|
||||
public static SecureBuffer secureBytes(byte[] data) {
|
||||
return secureBytes(data, true);
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns protection capabilities: enabled, protection_level, has_memory_locking, has_secure_zeroing, supports_full_protection, page_size.
|
||||
*/
|
||||
|
|
|
|||
|
|
@ -1,6 +1,5 @@
|
|||
package ai.nomyo.util;
|
||||
|
||||
import java.util.Arrays;
|
||||
import java.util.Base64;
|
||||
|
||||
/**
|
||||
|
|
@ -25,10 +24,14 @@ public class PEMConverter {
|
|||
return publicKeyFormatted.toString();
|
||||
}
|
||||
|
||||
public static String fromPEM(String pem) {
|
||||
pem = pem.replaceAll("^-----BEGIN\\s+PRIVATE\\s+KEY-----|^------END\\s+PUBLIC\\s+KEY-----\n", "");
|
||||
public static byte[] fromPEM(String pem) {
|
||||
pem = pem.replace("-----BEGIN PRIVATE KEY-----", "")
|
||||
.replace("-----BEGIN PUBLIC KEY-----", "")
|
||||
.replace("-----END PRIVATE KEY-----", "")
|
||||
.replace("-----END PUBLIC KEY-----", "")
|
||||
.replaceAll("\\s+", "");
|
||||
|
||||
return Arrays.toString(Base64.getDecoder().decode(pem));
|
||||
return Base64.getDecoder().decode(pem);
|
||||
}
|
||||
|
||||
public static boolean validatePEM(String keyIn) {
|
||||
|
|
|
|||
|
|
@ -1,5 +1,6 @@
|
|||
package ai.nomyo.util;
|
||||
|
||||
import ai.nomyo.errors.SecurityError;
|
||||
import javax.crypto.BadPaddingException;
|
||||
import javax.crypto.Cipher;
|
||||
import javax.crypto.IllegalBlockSizeException;
|
||||
|
|
@ -39,7 +40,7 @@ public final class Pass2Key {
|
|||
* @param password the password used to derive the encryption key
|
||||
* @return base64-encoded ciphertext including salt and IV
|
||||
*/
|
||||
public static String encrypt(String algorithm, String input, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException {
|
||||
public static String encrypt(String algorithm, String input, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, SecurityError {
|
||||
|
||||
byte[] salt = generateRandomBytes(SALT_LENGTH);
|
||||
SecretKey key = deriveKey(password, salt);
|
||||
|
|
@ -66,7 +67,7 @@ public final class Pass2Key {
|
|||
* @param password the password used to derive the decryption key
|
||||
* @return the decrypted plaintext
|
||||
*/
|
||||
public static String decrypt(String algorithm, String cipherText, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException {
|
||||
public static String decrypt(String algorithm, String cipherText, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, SecurityError {
|
||||
|
||||
byte[] decoded = Base64.getDecoder().decode(cipherText);
|
||||
|
||||
|
|
@ -86,13 +87,13 @@ public final class Pass2Key {
|
|||
}
|
||||
}
|
||||
|
||||
private static SecretKey deriveKey(String password, byte[] salt) {
|
||||
private static SecretKey deriveKey(String password, byte[] salt) throws SecurityError {
|
||||
try {
|
||||
SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256");
|
||||
KeySpec spec = new PBEKeySpec(password.toCharArray(), salt, ITERATION_COUNT, 256);
|
||||
return new SecretKeySpec(factory.generateSecret(spec).getEncoded(), "AES");
|
||||
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
|
||||
throw new RuntimeException("Could not derive key: " + e.getMessage());
|
||||
throw new SecurityError("Could not derive key: " + e.getMessage(), e);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ class SecureCompletionClientE2ETest {
|
|||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("E2E: Generate keys, save to disk, load in new client, validate")
|
||||
void e2e_fullLifecycle_generateSaveLoadValidate() {
|
||||
void e2e_fullLifecycle_generateSaveLoadValidate() throws Exception {
|
||||
// Step 1: Generate keys and save to disk
|
||||
SecureCompletionClient generateClient = new SecureCompletionClient(BASE_URL, false, true, 2);
|
||||
generateClient.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
||||
|
|
@ -77,7 +77,7 @@ class SecureCompletionClientE2ETest {
|
|||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("E2E: Generate plaintext keys, load, and validate")
|
||||
void e2e_plaintextKeys_generateLoadValidate() {
|
||||
void e2e_plaintextKeys_generateLoadValidate() throws Exception {
|
||||
// Generate plaintext keys (no password)
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(true, keyDir.getAbsolutePath(), null);
|
||||
|
|
@ -296,7 +296,7 @@ class SecureCompletionClientE2ETest {
|
|||
|
||||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("E2E: Encrypted key file is unreadable without password")
|
||||
@DisplayName("E2E: Encrypted key file throws SecurityError without correct password")
|
||||
void e2e_encryptedKey_unreadableWithoutPassword() throws Exception {
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
||||
|
|
@ -308,17 +308,19 @@ class SecureCompletionClientE2ETest {
|
|||
assertFalse(encryptedContent.contains("BEGIN PRIVATE KEY"),
|
||||
"Encrypted file should not contain PEM header");
|
||||
|
||||
// Try loading with wrong password - should not throw
|
||||
// Try loading with wrong password - should throw SecurityError
|
||||
SecureCompletionClient loadClient = new SecureCompletionClient();
|
||||
assertDoesNotThrow(() ->
|
||||
SecurityError error = assertThrows(SecurityError.class, () ->
|
||||
loadClient.loadKeys(privateKeyFile.getAbsolutePath(), null, "wrong-password"),
|
||||
"Wrong password should be handled gracefully"
|
||||
"Wrong password should throw SecurityError"
|
||||
);
|
||||
assertTrue(error.getMessage().contains("decrypt") || error.getMessage().contains("password"),
|
||||
"Error message should mention decryption or password");
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("E2E: Generate keys without saving produces in-memory keys")
|
||||
void e2e_generateKeys_noSave_producesInMemoryKeys() {
|
||||
void e2e_generateKeys_noSave_producesInMemoryKeys() throws SecurityError {
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(false);
|
||||
|
||||
|
|
|
|||
|
|
@ -28,7 +28,7 @@ class SecureCompletionClientTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("generateKeys should create 4096-bit RSA key pair")
|
||||
void generateKeys_shouldCreateValidKeyPair() {
|
||||
void generateKeys_shouldCreateValidKeyPair() throws SecurityError {
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(false);
|
||||
|
||||
|
|
@ -58,7 +58,7 @@ class SecureCompletionClientTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("generateKeys should create unique keys on each call")
|
||||
void generateKeys_shouldProduceUniqueKeys() {
|
||||
void generateKeys_shouldProduceUniqueKeys() throws SecurityError {
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(false);
|
||||
PrivateKey firstKey = client.getPrivateKey();
|
||||
|
|
@ -76,7 +76,7 @@ class SecureCompletionClientTest {
|
|||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("generateKeys with saveToFile=true should create key files")
|
||||
void generateKeys_withSaveToFile_shouldCreateKeyFiles(@TempDir Path tempDir) {
|
||||
void generateKeys_withSaveToFile_shouldCreateKeyFiles(@TempDir Path tempDir) throws SecurityError {
|
||||
File keyDir = tempDir.toFile();
|
||||
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
|
|
@ -112,7 +112,7 @@ class SecureCompletionClientTest {
|
|||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("generateKeys should not overwrite existing key files")
|
||||
void generateKeys_shouldNotOverwriteExistingKeys(@TempDir Path tempDir) {
|
||||
void generateKeys_shouldNotOverwriteExistingKeys(@TempDir Path tempDir) throws SecurityError {
|
||||
File keyDir = tempDir.toFile();
|
||||
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
|
|
@ -129,10 +129,10 @@ class SecureCompletionClientTest {
|
|||
|
||||
// ── Key Loading Tests ─────────────────────────────────────────────
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("loadKeys should load plaintext private key from file")
|
||||
void loadKeys_plaintext_shouldLoadPrivateKey(@TempDir Path tempDir) {
|
||||
void loadKeys_plaintext_shouldLoadPrivateKey(@TempDir Path tempDir) throws Exception {
|
||||
File keyDir = tempDir.toFile();
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(true, keyDir.getAbsolutePath(), null);
|
||||
|
|
@ -151,10 +151,10 @@ class SecureCompletionClientTest {
|
|||
"Loaded key should have same size as original");
|
||||
}
|
||||
|
||||
@Test
|
||||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("loadKeys should load encrypted private key with correct password")
|
||||
void loadKeys_encrypted_correctPassword_shouldLoadPrivateKey(@TempDir Path tempDir) {
|
||||
void loadKeys_encrypted_correctPassword_shouldLoadPrivateKey(@TempDir Path tempDir) throws Exception {
|
||||
File keyDir = tempDir.toFile();
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
||||
|
|
@ -176,30 +176,33 @@ class SecureCompletionClientTest {
|
|||
|
||||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("loadKeys should handle wrong password gracefully")
|
||||
void loadKeys_encrypted_wrongPassword_shouldHandleGracefully(@TempDir Path tempDir) {
|
||||
@DisplayName("loadKeys should throw SecurityError for wrong password")
|
||||
void loadKeys_encrypted_wrongPassword_shouldThrowSecurityError(@TempDir Path tempDir) throws SecurityError {
|
||||
File keyDir = tempDir.toFile();
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
|
||||
|
||||
SecureCompletionClient loadClient = new SecureCompletionClient();
|
||||
|
||||
assertDoesNotThrow(() ->
|
||||
SecurityError error = assertThrows(SecurityError.class, () ->
|
||||
loadClient.loadKeys(
|
||||
new File(keyDir, Constants.DEFAULT_PRIVATE_KEY_FILE).getAbsolutePath(),
|
||||
null,
|
||||
"wrong-password"
|
||||
),
|
||||
"Wrong password should not throw exception"
|
||||
"Wrong password should throw SecurityError"
|
||||
);
|
||||
|
||||
assertTrue(error.getMessage().contains("decrypt") || error.getMessage().contains("password"),
|
||||
"Error message should mention decryption or password");
|
||||
}
|
||||
|
||||
@Test
|
||||
@DisplayName("loadKeys should throw exception for non-existent file")
|
||||
void loadKeys_nonExistentFile_shouldThrowException() {
|
||||
void loadKeys_nonExistentFile_shouldThrowException() {
|
||||
SecureCompletionClient loadClient = new SecureCompletionClient();
|
||||
|
||||
RuntimeException error = assertThrows(RuntimeException.class, () ->
|
||||
SecurityError error = assertThrows(SecurityError.class, () ->
|
||||
loadClient.loadKeys("/non/existent/path/private_key.pem", null, null));
|
||||
|
||||
assertTrue(error.getMessage().contains("not found"),
|
||||
|
|
@ -210,7 +213,7 @@ class SecureCompletionClientTest {
|
|||
|
||||
@Test
|
||||
@DisplayName("validateRsaKey should accept valid 4096-bit key")
|
||||
void validateRsaKey_validKey_shouldPass() {
|
||||
void validateRsaKey_validKey_shouldPass() throws SecurityError {
|
||||
SecureCompletionClient client = new SecureCompletionClient();
|
||||
client.generateKeys(false);
|
||||
PrivateKey key = client.getPrivateKey();
|
||||
|
|
@ -264,7 +267,7 @@ class SecureCompletionClientTest {
|
|||
@Test
|
||||
@Execution(ExecutionMode.SAME_THREAD)
|
||||
@DisplayName("Full roundtrip: generate, save, load should produce same key")
|
||||
void roundtrip_generateSaveLoad_shouldProduceSameKey(@TempDir Path tempDir) {
|
||||
void roundtrip_generateSaveLoad_shouldProduceSameKey(@TempDir Path tempDir) throws Exception {
|
||||
File keyDir = tempDir.toFile();
|
||||
|
||||
// Generate and save
|
||||
|
|
@ -407,10 +410,6 @@ class SecureCompletionClientTest {
|
|||
tempClient.generateKeys(false);
|
||||
PrivateKey originalKey = tempClient.getPrivateKey();
|
||||
|
||||
String pem = "-----BEGIN PRIVATE KEY-----\n " +
|
||||
originalKey.getEncoded().length + "lines\n" +
|
||||
"-----END PRIVATE KEY-----";
|
||||
|
||||
String formattedPem = ai.nomyo.util.PEMConverter.toPEM(originalKey.getEncoded(), true);
|
||||
String pemWithWhitespace = formattedPem.replace("\n", "\n ");
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue