Finish functionality

This commit is contained in:
Oracle 2026-04-26 18:21:05 +02:00
parent b6af1c9792
commit 89d5282b0f
Signed by: Oracle
SSH key fingerprint: SHA256:x4/RtnjUyuHkdvmwNDsWSfcfF1V5PNr3OpriZqOvCX8
9 changed files with 583 additions and 133 deletions

View file

@ -1,6 +1,7 @@
package ai.nomyo;
import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.annotations.SerializedName;
import lombok.Getter;
import lombok.Setter;
@ -12,6 +13,8 @@ import lombok.Setter;
@Getter
public class EncryptedRequest {
private static final Gson GSON = new GsonBuilder().create();
// Getters and Setters
@SerializedName("version")
private String version;
@ -55,7 +58,7 @@ public class EncryptedRequest {
}
}
public String toJson() {
return new Gson().toJson(this);
public String toJson() {
return GSON.toJson(this);
}
}

View file

@ -1,5 +1,10 @@
package ai.nomyo;
import ai.nomyo.errors.APIConnectionError;
import ai.nomyo.errors.SecurityError;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ExecutionException;
/**
@ -8,17 +13,22 @@ import java.util.concurrent.ExecutionException;
public class Main {
static void main() {
SecureCompletionClient secureCompletionClient = new SecureCompletionClient();
//secureCompletionClient.generateKeys(true, "client_keys", "pokemon");
//secureCompletionClient.loadKeys("client_keys/private_key.pem", "pokemon");
SecureChatCompletion secureChatCompletion = new SecureChatCompletion( Constants.DEFAULT_BASE_URL, "NOMYO_AI_E2EE_INFERENCE");
List<Map<String, Object>> messages = List.of(
Map.of("role", "user", "content", "Hello! How are you today?")
);
Map<String, Object> kwargs = Map.of(
"security_tier", "standard",
"temperature", 0.7
);
try {
System.out.println(secureCompletionClient.fetchServerPublicKey().get());
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
}
var response = secureChatCompletion.create(
"Qwen/Qwen3-0.6B",
messages,
kwargs);
System.out.println(response.toString());
}
}

View file

@ -3,8 +3,12 @@ package ai.nomyo;
import ai.nomyo.errors.*;
import lombok.Getter;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.ExecutionException;
/**
* High-level OpenAI-compatible entrypoint with automatic hybrid encryption (AES-256-GCM + RSA-4096).
@ -23,6 +27,14 @@ public class SecureChatCompletion {
this(Constants.DEFAULT_BASE_URL, false, null, true, null, Constants.DEFAULT_MAX_RETRIES);
}
public SecureChatCompletion(String baseUrl, String apiKey) {
this(baseUrl, false, apiKey, true, null, Constants.DEFAULT_MAX_RETRIES);
}
public SecureChatCompletion(String baseUrl) {
this(baseUrl, false, null, true, null, Constants.DEFAULT_MAX_RETRIES);
}
/**
* @param baseUrl NOMYO Router base URL (HTTPS enforced unless {@code allowHttp})
* @param allowHttp permit {@code http://} URLs (development only)
@ -57,15 +69,85 @@ public class SecureChatCompletion {
* @throws ServiceUnavailableError HTTP 503
* @throws APIError other errors
*/
@SuppressWarnings("unchecked")
public Map<String, Object> create(String model, List<Map<String, Object>> messages, Map<String, Object> kwargs) {
// Build payload from model, messages, and kwargs
// Validate stream is false
// Validate securityTier if provided
// Use per-call api_key override if provided, else instance apiKey
// Create temp client if baseUrl override provided
// Validate required parameters
if (model == null || model.isEmpty()) {
throw new IllegalArgumentException("model is required");
}
if (messages == null || messages.isEmpty()) {
throw new IllegalArgumentException("messages is required and cannot be empty");
}
// Build payload
Map<String, Object> payload = new HashMap<>();
payload.put("model", model);
payload.put("messages", messages);
// Add kwargs
if (kwargs != null) {
// Check for stream parameter
if (kwargs.containsKey("stream")) {
Object streamValue = kwargs.get("stream");
boolean stream = streamValue instanceof Boolean ? (Boolean) streamValue : Boolean.parseBoolean(streamValue.toString());
if (stream) {
throw new IllegalArgumentException("Streaming is not supported");
}
}
// Check for security_tier
if (kwargs.containsKey("security_tier")) {
Object tier = kwargs.get("security_tier");
if (tier != null && !Constants.VALID_SECURITY_TIERS.contains(tier.toString())) {
throw new IllegalArgumentException(
"Invalid security_tier: '" + tier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS);
}
}
payload.putAll(kwargs);
}
// Determine API key (per-call override or instance key)
String apiKey = this.apiKey;
if (kwargs != null && kwargs.containsKey("api_key")) {
Object key = kwargs.get("api_key");
if (key != null) {
apiKey = key.toString();
}
}
// Determine security tier
String securityTier = null;
if (kwargs != null && kwargs.containsKey("security_tier")) {
securityTier = kwargs.get("security_tier").toString();
}
// Generate payload ID
String payloadId = UUID.randomUUID().toString();
// Send secure request
// Return decrypted response map
return null;
try {
return client.sendSecureRequest(payload, payloadId, apiKey, securityTier).get();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException("Request interrupted", e);
} catch (ExecutionException e) {
Throwable cause = e.getCause();
switch (cause) {
case SecurityError securityError -> throw new RuntimeException(cause);
case InvalidRequestError invalidRequestError -> throw new RuntimeException(cause);
case AuthenticationError authenticationError -> throw new RuntimeException(cause);
case ForbiddenError forbiddenError -> throw new RuntimeException(cause);
case RateLimitError rateLimitError -> throw new RuntimeException(cause);
case ServerError serverError -> throw new RuntimeException(cause);
case ServiceUnavailableError serviceUnavailableError -> throw new RuntimeException(cause);
case APIError apiError -> throw new RuntimeException(cause);
case APIConnectionError apiConnectionError -> throw new RuntimeException(cause);
case SecureCompletionClient.ValueError valueError -> throw new IllegalArgumentException(cause);
default ->
throw new RuntimeException("Request failed: " + cause.getMessage(), cause);
}
}
}
/**

View file

@ -4,6 +4,7 @@ import ai.nomyo.errors.*;
import ai.nomyo.util.PEMConverter;
import ai.nomyo.util.Pass2Key;
import lombok.Getter;
import lombok.Setter;
import javax.crypto.*;
import javax.crypto.spec.GCMParameterSpec;
@ -27,6 +28,10 @@ import java.util.*;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutionException;
import com.google.gson.Gson;
import com.google.gson.JsonObject;
import com.google.gson.JsonParser;
import java.util.concurrent.locks.ReentrantLock;
/**
@ -113,7 +118,7 @@ public class SecureCompletionClient {
/**
* Generates a 4096-bit RSA key pair (exponent 65537). Saves to disk if {@code saveToFile}.
*/
public void generateKeys(boolean saveToFile, String keyDir, String password) {
public void generateKeys(boolean saveToFile, String keyDir, String password) throws SecurityError {
try {
KeyPairGenerator generator = KeyPairGenerator.getInstance("RSA");
generator.initialize(new RSAKeyGenParameterSpec(Constants.RSA_KEY_SIZE, BigInteger.valueOf(Constants.RSA_PUBLIC_EXPONENT)));
@ -145,8 +150,8 @@ public class SecureCompletionClient {
try {
privatePem = Pass2Key.encrypt("AES/GCM/NoPadding", privatePem, password);
} catch (NoSuchPaddingException | IllegalBlockSizeException | BadPaddingException |
InvalidKeyException e) {
throw new RuntimeException(e);
InvalidKeyException | SecurityError e) {
throw new SecurityError("Failed to encrypt private key with password: " + e.getMessage(), e);
}
}
writer.write(privatePem);
@ -167,10 +172,12 @@ public class SecureCompletionClient {
this.privateKey = pair.getPrivate();
this.publicPemKey = publicPem;
} catch (SecurityError e) {
throw e;
} catch (NoSuchAlgorithmException e) {
throw new RuntimeException("RSA not available: " + e.getMessage(), e);
throw new SecurityError("RSA algorithm not available: " + e.getMessage(), e);
} catch (InvalidAlgorithmParameterException e) {
throw new RuntimeException(e);
throw new SecurityError("Invalid RSA key generation parameters: " + e.getMessage(), e);
} catch (IOException e) {
throw new RuntimeException("Failed to save keys: " + e.getMessage(), e);
}
@ -179,7 +186,7 @@ public class SecureCompletionClient {
/**
* Generates a 4096-bit RSA key pair and saves to the default directory.
*/
public void generateKeys(boolean saveToFile) {
public void generateKeys(boolean saveToFile) throws SecurityError {
generateKeys(saveToFile, Constants.DEFAULT_KEY_DIR, null);
}
@ -190,11 +197,12 @@ public class SecureCompletionClient {
* @param privateKeyPath private key PEM path
* @param publicPemKeyPath optional public key PEM path
* @param password optional password for encrypted private key
* @throws SecurityError if key file not found, unreadable, or decryption fails
*/
public void loadKeys(String privateKeyPath, String publicPemKeyPath, String password) {
public void loadKeys(String privateKeyPath, String publicPemKeyPath, String password) throws SecurityError {
Path keyPath = Path.of(privateKeyPath);
if (!Files.exists(keyPath)) {
throw new RuntimeException("Private key file not found: " + privateKeyPath);
throw new SecurityError("Private key file not found: " + privateKeyPath);
}
String keyContent;
@ -202,35 +210,36 @@ public class SecureCompletionClient {
try {
keyContent = readFileContent(privateKeyPath);
} catch (IOException e) {
throw new RuntimeException("Failed to read private key file: " + e.getMessage(), e);
throw new SecurityError("Failed to read private key file: " + e.getMessage(), e);
}
try {
keyContent = Pass2Key.decrypt("AES/GCM/NoPadding", keyContent, password);
} catch (NoSuchPaddingException | NoSuchAlgorithmException | BadPaddingException |
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException e) {
System.out.println("Wrong password!");
return;
IllegalBlockSizeException | InvalidAlgorithmParameterException | InvalidKeyException | SecurityError e) {
throw new SecurityError("Failed to decrypt private key with provided password: " + e.getMessage(), e);
}
} else {
try {
keyContent = readFileContent(privateKeyPath);
} catch (IOException e) {
throw new RuntimeException("Failed to read private key file: " + e.getMessage(), e);
throw new SecurityError("Failed to read private key file: " + e.getMessage(), e);
}
}
try {
this.privateKey = Pass2Key.convertStringToPrivateKey(keyContent);
} catch (Exception e) {
throw new RuntimeException("Failed to load private key: " + e.getMessage(), e);
throw new SecurityError("Failed to load private key: " + e.getMessage(), e);
}
}
/**
* Loads RSA private key from disk, deriving public key.
*
* @throws SecurityError if key file not found, unreadable, or decryption fails
*/
public void loadKeys(String privateKeyPath, String password) {
public void loadKeys(String privateKeyPath, String password) throws SecurityError {
loadKeys(privateKeyPath, null, password);
}
@ -250,12 +259,12 @@ public class SecureCompletionClient {
URI url;
try {
url = new URI(this.routerUrl + "/pki/public_key");
url = new URI(this.routerUrl + Constants.PKI_PUBLIC_KEY_PATH);
} catch (URISyntaxException e) {
return CompletableFuture.failedFuture(new CompletionException("Invalid URI: " + e.getMessage(), e));
return CompletableFuture.failedFuture(new CompletionException(new APIConnectionError("Invalid URI: " + e.getMessage(), e)));
}
HttpRequest request = HttpRequest.newBuilder(url).timeout(Duration.of(60, ChronoUnit.SECONDS)).GET().build();
HttpRequest request = HttpRequest.newBuilder(url).timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS)).GET().build();
return this.httpClient.sendAsync(request, HttpResponse.BodyHandlers.ofString()).thenApply(response -> {
if (response.statusCode() != 200) {
@ -264,7 +273,7 @@ public class SecureCompletionClient {
return response.body();
}).thenApply(body -> {
if (!PEMConverter.validatePEM(body)) {
throw new CompletionException(new InvalidKeyException("PEM key had invalid format"));
throw new CompletionException(new SecurityError("Server returned invalid PEM key format (possible MITM attack)"));
}
return body;
});
@ -278,7 +287,37 @@ public class SecureCompletionClient {
* @throws SecurityError if encryption fails or keys not loaded
*/
public CompletableFuture<byte[]> encryptPayload(Map<String, Object> payload) {
throw new UnsupportedOperationException("Not yet implemented");
return CompletableFuture.supplyAsync(() -> {
try {
ensureKeys(null);
} catch (SecurityError e) {
throw new CompletionException(new SecurityError("Failed to ensure keys are initialized: " + e.getMessage(), e));
}
if (this.privateKey == null) {
throw new CompletionException(new SecurityError("Private key not available for encryption"));
}
// Generate AES key
KeyGenerator keyGen;
try {
keyGen = KeyGenerator.getInstance("AES");
keyGen.init(Constants.AES_KEY_SIZE * 8);
} catch (NoSuchAlgorithmException e) {
throw new CompletionException(new SecurityError("AES key generation not available: " + e.getMessage(), e));
}
Key aesKey = keyGen.generateKey();
// Serialize payload to JSON
Gson gson = new Gson();
String payloadJson = gson.toJson(payload);
byte[] payloadBytes = payloadJson.getBytes(StandardCharsets.UTF_8);
// Encrypt
return doEncrypt(payloadBytes, aesKey).join();
});
}
/**
@ -287,16 +326,16 @@ public class SecureCompletionClient {
public CompletableFuture<byte[]> doEncrypt(byte[] payloadBytes, Key aesKey) {
return CompletableFuture.supplyAsync(() -> {
SecureRandom random = new SecureRandom();
byte[] nonce = new byte[12];
byte[] nonce = new byte[Constants.GCM_NONCE_SIZE];
random.nextBytes(nonce);
Cipher cipher = null;
try {
cipher = Cipher.getInstance("AES/GCM/NoPadding");
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(128, nonce));
cipher.init(Cipher.ENCRYPT_MODE, new SecretKeySpec(aesKey.getEncoded(), "AES"), new GCMParameterSpec(Constants.GCM_TAG_SIZE * Byte.SIZE, nonce));
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidAlgorithmParameterException |
InvalidKeyException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("AES-GCM cipher initialization failed: " + e.getMessage(), e));
}
byte[] ciphertext;
@ -304,66 +343,63 @@ public class SecureCompletionClient {
try {
ciphertext = cipher.doFinal(payloadBytes);
} catch (IllegalBlockSizeException | BadPaddingException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("AES-GCM encryption failed: " + e.getMessage(), e));
}
String serverPEM;
try {
serverPEM = fetchServerPublicKey().get();
} catch (InterruptedException | ExecutionException e) {
throw new RuntimeException(e);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new RuntimeException(new SecurityError("Encryption interrupted while fetching server public key", e));
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof SecurityError) {
throw new RuntimeException((SecurityError) cause);
}
throw new RuntimeException(new SecurityError("Failed to fetch server public key: " + cause.getMessage(), cause));
}
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM).getBytes());
X509EncodedKeySpec keySpec = new X509EncodedKeySpec(PEMConverter.fromPEM(serverPEM));
PublicKey serverPublicKey;
try {
serverPublicKey = KeyFactory.getInstance("RSA").generatePublic(keySpec);
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("RSA key factory failed to parse server public key: " + e.getMessage(), e));
}
Cipher rsa;
byte[] enryptedAESKey = aesKey.getEncoded();
try {
rsa = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsa.init(Cipher.ENCRYPT_MODE, serverPublicKey);
} catch (NoSuchAlgorithmException | NoSuchPaddingException | InvalidKeyException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("RSA-OAEP cipher initialization failed: " + e.getMessage(), e));
}
byte[] encryptedAESKey;
try {
rsa.doFinal(enryptedAESKey);
encryptedAESKey = rsa.doFinal(aesKey.getEncoded());
} catch (IllegalBlockSizeException | BadPaddingException e) {
throw new RuntimeException(e);
throw new RuntimeException(new SecurityError("RSA-OAEP key wrapping failed: " + e.getMessage(), e));
}
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - (128 / Byte.SIZE), ciphertext.length);
byte[] tag = Arrays.copyOfRange(ciphertext, ciphertext.length - Constants.GCM_TAG_SIZE, ciphertext.length);
EncryptedRequest request = new EncryptedRequest();
request.setVersion("1.0");
request.setAlgorithm("hybrid-aes256-rsa4096");
request.setVersion(Constants.PROTOCOL_VERSION);
request.setAlgorithm(Constants.HYBRID_ALGORITHM);
request.setEncryptedPayload(new EncryptedRequest.EncryptedPayload(Base64.getEncoder().encodeToString(ciphertext), Base64.getEncoder().encodeToString(nonce), Base64.getEncoder().encodeToString(tag)));
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(enryptedAESKey));
request.setKeyAlgorithm("RSA-OAEP-SHA256");
request.setPayloadAlgorithm("AES-256-GCM");
request.setEncryptedAESKey(Base64.getEncoder().encodeToString(encryptedAESKey));
request.setKeyAlgorithm(Constants.KEY_WRAP_ALGORITHM);
request.setPayloadAlgorithm(Constants.PAYLOAD_ALGORITHM);
return request.toJson().getBytes(StandardCharsets.UTF_8);
});
}
/**
* Decrypts server response.
*/
public CompletableFuture<Map<String, Object>> decryptResponse(byte[] encryptedResponse, String payloadId) {
throw new UnsupportedOperationException("Not yet implemented");
}
/**
* encrypt POST {routerUrl}/v1/chat/secure_completion retry decrypt return.
* <p>Headers: Content-Type=octet-stream, X-Payload-ID, X-Public-Key, Authorization (Bearer), X-Security-Tier.
@ -388,7 +424,210 @@ public class SecureCompletionClient {
* @throws APIError other errors
*/
public CompletableFuture<Map<String, Object>> sendSecureRequest(Map<String, Object> payload, String payloadId, String apiKey, String securityTier) {
throw new UnsupportedOperationException("Not yet implemented");
return CompletableFuture.supplyAsync(() -> {
// Validate security tier if provided
if (securityTier != null && !Constants.VALID_SECURITY_TIERS.contains(securityTier)) {
throw new CompletionException(new ValueError(
"Invalid security_tier: '" + securityTier + "'. Must be one of: " + Constants.VALID_SECURITY_TIERS));
}
try {
ensureKeys(null);
} catch (SecurityError e) {
throw new CompletionException(new SecurityError("Failed to ensure keys: " + e.getMessage(), e));
}
// Step 1: Encrypt payload
byte[] encryptedPayload;
try {
encryptedPayload = encryptPayload(payload).get();
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new CompletionException(new APIConnectionError("Encryption interrupted"));
} catch (ExecutionException e) {
Throwable cause = e.getCause();
if (cause instanceof SecurityError) {
throw new CompletionException((SecurityError) cause);
}
throw new CompletionException(new SecurityError("Encryption failed: " + cause.getMessage(), cause));
}
// Step 2: Prepare headers
URI url;
try {
url = new URI(this.routerUrl + Constants.SECURE_COMPLETION_PATH);
} catch (URISyntaxException e) {
throw new CompletionException(new APIConnectionError("Invalid URL: " + e.getMessage(), e));
}
HttpRequest.Builder requestBuilder = HttpRequest.newBuilder(url)
.timeout(Duration.ofSeconds(Constants.DEFAULT_TIMEOUT_SECONDS))
.header("Content-Type", Constants.CONTENT_TYPE_OCTET_STREAM)
.header(Constants.HEADER_PAYLOAD_ID, payloadId)
.header(Constants.HEADER_PUBLIC_KEY, urlEncodePublicKey(this.publicPemKey))
.POST(HttpRequest.BodyPublishers.ofByteArray(encryptedPayload));
if (apiKey != null && !apiKey.isEmpty()) {
requestBuilder.header("Authorization", Constants.AUTHORIZATION_BEARER_PREFIX + apiKey);
}
if (securityTier != null) {
requestBuilder.header(Constants.HEADER_SECURITY_TIER, securityTier);
}
HttpRequest request = requestBuilder.build();
// Step 3: Send request with retry
Exception lastExc = new APIConnectionError("Request failed");
int retryableCodes = 0;
for (int attempt = 0; attempt <= this.maxRetries; attempt++) {
if (attempt > 0) {
long delay = (long) Math.pow(2, attempt - 1);
try {
Thread.sleep(delay * 1000);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new CompletionException(new APIConnectionError("Retry interrupted"));
}
}
try {
HttpResponse<byte[]> response = this.httpClient.send(request, HttpResponse.BodyHandlers.ofByteArray());
int statusCode = response.statusCode();
if (statusCode == 200) {
// Step 4: Decrypt response
return decryptResponse(response.body(), payloadId).get();
} else if (statusCode == 400) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Unknown error";
throw new CompletionException(new InvalidRequestError("Bad request: " + detail, 400, errorDetails));
} else if (statusCode == 401) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Invalid API key or authentication failed";
throw new CompletionException(new AuthenticationError(detail, 401, errorDetails));
} else if (statusCode == 403) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Model not allowed for the requested security tier";
throw new CompletionException(new ForbiddenError("Forbidden: " + detail, 403, errorDetails));
} else if (statusCode == 404) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) errorDetails = parsed;
} catch (Exception ignored) {
}
String detail = errorDetails.containsKey("detail") ? errorDetails.get("detail").toString() : "Secure inference not enabled";
throw new CompletionException(new APIError("Endpoint not found: " + detail, 404, errorDetails));
} else if (Constants.RETRYABLE_STATUS_CODES.contains(statusCode)) {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> error = Map.of();
String detailMsg = "unknown";
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) {
error = parsed;
if (error.containsKey("detail")) detailMsg = error.get("detail").toString();
}
} catch (Exception ignored) {
}
if (statusCode == 429) {
lastExc = new RateLimitError("Rate limit exceeded: " + detailMsg, 429, error);
} else if (statusCode == 500) {
lastExc = new ServerError("Server error: " + detailMsg, 500, error);
} else if (statusCode == 503) {
lastExc = new ServiceUnavailableError("Service unavailable: " + detailMsg, 503, error);
} else {
lastExc = new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, error);
}
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
} else {
String body = new String(response.body(), StandardCharsets.UTF_8);
Map<String, Object> errorDetails = Map.of();
String detailMsg = "unknown";
try {
@SuppressWarnings("unchecked")
Map<String, Object> parsed = (Map<String, Object>) new Gson().fromJson(body, Object.class);
if (parsed != null) {
errorDetails = parsed;
if (errorDetails.containsKey("detail")) detailMsg = errorDetails.get("detail").toString();
}
} catch (Exception ignored) {
}
throw new CompletionException(new APIError("Unexpected status code: " + statusCode + " " + detailMsg, statusCode, errorDetails));
}
} catch (CompletionException e) {
Throwable cause = e.getCause();
if (cause instanceof InvalidRequestError || cause instanceof AuthenticationError ||
cause instanceof ForbiddenError || cause instanceof SecurityError) {
throw e;
}
if (cause instanceof RateLimitError || cause instanceof ServerError || cause instanceof ServiceUnavailableError) {
lastExc = (Exception) cause;
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
}
if (cause instanceof APIError apiError) {
if (Constants.RETRYABLE_STATUS_CODES.contains(apiError.getStatusCode())) {
lastExc = apiError;
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
}
throw e;
}
lastExc = new APIConnectionError("Failed to connect to router: " + e.getMessage());
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
} catch (InterruptedException e) {
Thread.currentThread().interrupt();
throw new CompletionException(new APIConnectionError("Request interrupted"));
} catch (ExecutionException e) {
throw new CompletionException(new APIConnectionError("Request failed: " + e.getMessage(), e));
} catch (IOException e) {
lastExc = new APIConnectionError("IO error: " + e.getMessage(), e);
if (attempt < this.maxRetries) {
continue;
}
throw new CompletionException(lastExc);
}
}
throw new CompletionException(lastExc);
});
}
/**
@ -410,7 +649,7 @@ public class SecureCompletionClient {
*
* @param keyDir key directory or {@code null} for ephemeral
*/
public void ensureKeys(String keyDir) {
public void ensureKeys(String keyDir) throws SecurityError {
if (keysInitialized) return;
keyInitLock.lock();
try {
@ -458,21 +697,19 @@ public class SecureCompletionClient {
* Maps HTTP status code to exception (200null).
*/
public Exception mapHttpStatus(int statusCode, String responseBody) {
String message = responseBody != null ? responseBody : "no body";
Map<String, Object> errorDetails = responseBody != null ? Map.of("response", responseBody) : Map.of();
return switch (statusCode) {
case 200 -> null;
case 400 ->
new InvalidRequestError("Invalid request: " + (responseBody != null ? responseBody : "no body"));
case 401 ->
new AuthenticationError("Authentication failed: " + (responseBody != null ? responseBody : "no body"));
case 403 -> new ForbiddenError("Access forbidden: " + (responseBody != null ? responseBody : "no body"));
case 404 -> new APIError("Not found: " + (responseBody != null ? responseBody : "no body"));
case 429 -> new RateLimitError("Rate limit exceeded: " + (responseBody != null ? responseBody : "no body"));
case 500 -> new ServerError("Internal server error: " + (responseBody != null ? responseBody : "no body"));
case 503 ->
new ServiceUnavailableError("Service unavailable: " + (responseBody != null ? responseBody : "no body"));
case 502, 504 -> new APIError("Gateway error: " + (responseBody != null ? responseBody : "no body"));
default ->
new APIError("Unexpected status " + statusCode + ": " + (responseBody != null ? responseBody : "no body"));
case 400 -> new InvalidRequestError("Invalid request: " + message, statusCode, errorDetails);
case 401 -> new AuthenticationError("Authentication failed: " + message, statusCode, errorDetails);
case 403 -> new ForbiddenError("Access forbidden: " + message, statusCode, errorDetails);
case 404 -> new APIError("Not found: " + message, statusCode, errorDetails);
case 429 -> new RateLimitError("Rate limit exceeded: " + message, statusCode, errorDetails);
case 500 -> new ServerError("Internal server error: " + message, statusCode, errorDetails);
case 503 -> new ServiceUnavailableError("Service unavailable: " + message, statusCode, errorDetails);
case 502, 504 -> new APIError("Gateway error: " + message, statusCode, errorDetails);
default -> new APIError("Unexpected status " + statusCode + ": " + message, statusCode, errorDetails);
};
}
@ -484,9 +721,138 @@ public class SecureCompletionClient {
}
/**
* Delegates to resource cleanup (stub).
* Closes the HTTP client and clears keys from memory.
*/
public void close() {
throw new UnsupportedOperationException("Not yet implemented");
this.httpClient.close();
this.privateKey = null;
this.publicPemKey = null;
this.keysInitialized = false;
}
/**
* Decrypts server response.
*/
public CompletableFuture<Map<String, Object>> decryptResponse(byte[] encryptedResponse, String payloadId) {
return CompletableFuture.supplyAsync(() -> {
if (encryptedResponse == null || encryptedResponse.length == 0) {
throw new CompletionException(new ValueError("Empty encrypted response"));
}
String jsonResponse;
try {
jsonResponse = new String(encryptedResponse, StandardCharsets.UTF_8);
} catch (Exception e) {
throw new CompletionException(new ValueError("Invalid encrypted package format: not valid UTF-8"));
}
Gson gson = new Gson();
JsonObject packageJson;
try {
packageJson = JsonParser.parseString(jsonResponse).getAsJsonObject();
} catch (Exception e) {
throw new CompletionException(new ValueError("Invalid encrypted package format: malformed JSON"));
}
// Validate required fields
String[] requiredFields = {"version", "algorithm", "encrypted_payload", "encrypted_aes_key"};
for (String field : requiredFields) {
if (!packageJson.has(field)) {
throw new CompletionException(new ValueError("Missing required fields in encrypted package: " + field));
}
}
// Validate version and algorithm
String version = packageJson.get("version").getAsString();
String algorithm = packageJson.get("algorithm").getAsString();
if (!Constants.PROTOCOL_VERSION.equals(version)) {
throw new CompletionException(new ValueError(
"Unsupported protocol version: '" + version + "'. Expected: '" + Constants.PROTOCOL_VERSION + "'"));
}
if (!Constants.HYBRID_ALGORITHM.equals(algorithm)) {
throw new CompletionException(new ValueError(
"Unsupported encryption algorithm: '" + algorithm + "'. Expected: '" + Constants.HYBRID_ALGORITHM + "'"));
}
// Validate encrypted_payload structure
JsonObject encryptedPayload = packageJson.get("encrypted_payload").getAsJsonObject();
String[] payloadRequired = {"ciphertext", "nonce", "tag"};
for (String field : payloadRequired) {
if (!encryptedPayload.has(field)) {
throw new CompletionException(new ValueError("Missing fields in encrypted_payload: " + field));
}
}
// Guard: private key must be initialized
if (this.privateKey == null) {
throw new CompletionException(new SecurityError("Private key not initialized. Call generateKeys() or loadKeys() first."));
}
try {
// Decrypt AES key with private key
byte[] encryptedAESKey = Base64.getDecoder().decode(packageJson.get("encrypted_aes_key").getAsString());
Cipher rsaCipher = Cipher.getInstance("RSA/ECB/OAEPWithSHA-256AndMGF1Padding");
rsaCipher.init(Cipher.DECRYPT_MODE, this.privateKey);
byte[] aesKeyBytes = rsaCipher.doFinal(encryptedAESKey);
SecretKeySpec aesKey = new SecretKeySpec(aesKeyBytes, "AES");
// Decrypt payload
byte[] ciphertext = Base64.getDecoder().decode(encryptedPayload.get("ciphertext").getAsString());
byte[] nonce = Base64.getDecoder().decode(encryptedPayload.get("nonce").getAsString());
byte[] tag = Base64.getDecoder().decode(encryptedPayload.get("tag").getAsString());
Cipher aesCipher = Cipher.getInstance("AES/GCM/NoPadding");
aesCipher.init(Cipher.DECRYPT_MODE, aesKey, new GCMParameterSpec(Constants.GCM_TAG_SIZE * 8, nonce));
// Combine ciphertext (without tag) and tag for decryption
byte[] ciphertextWithTag = new byte[ciphertext.length + tag.length];
System.arraycopy(ciphertext, 0, ciphertextWithTag, 0, ciphertext.length);
System.arraycopy(tag, 0, ciphertextWithTag, ciphertext.length, tag.length);
byte[] plaintextBytes = aesCipher.doFinal(ciphertextWithTag);
// Parse JSON response
Map<String, Object> response;
try {
Object parsed = gson.fromJson(new String(plaintextBytes, StandardCharsets.UTF_8), Object.class);
@SuppressWarnings("unchecked")
Map<String, Object> resultMap = (Map<String, Object>) parsed;
response = resultMap != null ? resultMap : new HashMap<>();
} catch (Exception e) {
throw new CompletionException(new ValueError("Decrypted response is not valid JSON: " + e.getMessage()));
}
// Add metadata
if (!response.containsKey("_metadata")) {
response.put("_metadata", new HashMap<String, Object>());
}
@SuppressWarnings("unchecked")
Map<String, Object> metadata = (Map<String, Object>) response.get("_metadata");
metadata.put("payload_id", payloadId);
metadata.put("processed_at", packageJson.has("processed_at") ? packageJson.get("processed_at").getAsString() : null);
metadata.put("is_encrypted", true);
metadata.put("encryption_algorithm", Constants.HYBRID_ALGORITHM);
return response;
} catch (Exception e) {
throw new CompletionException(new SecurityError("Decryption failed: integrity check or authentication failed"));
}
});
}
/**
* Error class for invalid argument/value errors (maps to Python ValueError).
*/
public static class ValueError extends Exception {
public ValueError(String message) {
super(message);
}
public ValueError(String message, Throwable cause) {
super(message, cause);
}
}
}

View file

@ -54,22 +54,6 @@ public final class SecureMemory {
return secureByteArray(data, true);
}
/**
* @deprecated Use {@link #secureByteArray(byte[])} instead.
*/
@Deprecated
public static SecureBuffer secureBytes(byte[] data, boolean lock) {
return new SecureBuffer(data, lock);
}
/**
* @deprecated Use {@link #secureByteArray(byte[])} instead.
*/
@Deprecated
public static SecureBuffer secureBytes(byte[] data) {
return secureBytes(data, true);
}
/**
* Returns protection capabilities: enabled, protection_level, has_memory_locking, has_secure_zeroing, supports_full_protection, page_size.
*/

View file

@ -1,6 +1,5 @@
package ai.nomyo.util;
import java.util.Arrays;
import java.util.Base64;
/**
@ -25,10 +24,14 @@ public class PEMConverter {
return publicKeyFormatted.toString();
}
public static String fromPEM(String pem) {
pem = pem.replaceAll("^-----BEGIN\\s+PRIVATE\\s+KEY-----|^------END\\s+PUBLIC\\s+KEY-----\n", "");
public static byte[] fromPEM(String pem) {
pem = pem.replace("-----BEGIN PRIVATE KEY-----", "")
.replace("-----BEGIN PUBLIC KEY-----", "")
.replace("-----END PRIVATE KEY-----", "")
.replace("-----END PUBLIC KEY-----", "")
.replaceAll("\\s+", "");
return Arrays.toString(Base64.getDecoder().decode(pem));
return Base64.getDecoder().decode(pem);
}
public static boolean validatePEM(String keyIn) {

View file

@ -1,5 +1,6 @@
package ai.nomyo.util;
import ai.nomyo.errors.SecurityError;
import javax.crypto.BadPaddingException;
import javax.crypto.Cipher;
import javax.crypto.IllegalBlockSizeException;
@ -39,7 +40,7 @@ public final class Pass2Key {
* @param password the password used to derive the encryption key
* @return base64-encoded ciphertext including salt and IV
*/
public static String encrypt(String algorithm, String input, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException {
public static String encrypt(String algorithm, String input, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, SecurityError {
byte[] salt = generateRandomBytes(SALT_LENGTH);
SecretKey key = deriveKey(password, salt);
@ -66,7 +67,7 @@ public final class Pass2Key {
* @param password the password used to derive the decryption key
* @return the decrypted plaintext
*/
public static String decrypt(String algorithm, String cipherText, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException {
public static String decrypt(String algorithm, String cipherText, String password) throws NoSuchPaddingException, NoSuchAlgorithmException, InvalidAlgorithmParameterException, InvalidKeyException, BadPaddingException, IllegalBlockSizeException, SecurityError {
byte[] decoded = Base64.getDecoder().decode(cipherText);
@ -86,13 +87,13 @@ public final class Pass2Key {
}
}
private static SecretKey deriveKey(String password, byte[] salt) {
private static SecretKey deriveKey(String password, byte[] salt) throws SecurityError {
try {
SecretKeyFactory factory = SecretKeyFactory.getInstance("PBKDF2WithHmacSHA256");
KeySpec spec = new PBEKeySpec(password.toCharArray(), salt, ITERATION_COUNT, 256);
return new SecretKeySpec(factory.generateSecret(spec).getEncoded(), "AES");
} catch (InvalidKeySpecException | NoSuchAlgorithmException e) {
throw new RuntimeException("Could not derive key: " + e.getMessage());
throw new SecurityError("Could not derive key: " + e.getMessage(), e);
}
}

View file

@ -36,7 +36,7 @@ class SecureCompletionClientE2ETest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("E2E: Generate keys, save to disk, load in new client, validate")
void e2e_fullLifecycle_generateSaveLoadValidate() {
void e2e_fullLifecycle_generateSaveLoadValidate() throws Exception {
// Step 1: Generate keys and save to disk
SecureCompletionClient generateClient = new SecureCompletionClient(BASE_URL, false, true, 2);
generateClient.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
@ -77,7 +77,7 @@ class SecureCompletionClientE2ETest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("E2E: Generate plaintext keys, load, and validate")
void e2e_plaintextKeys_generateLoadValidate() {
void e2e_plaintextKeys_generateLoadValidate() throws Exception {
// Generate plaintext keys (no password)
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(true, keyDir.getAbsolutePath(), null);
@ -296,7 +296,7 @@ class SecureCompletionClientE2ETest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("E2E: Encrypted key file is unreadable without password")
@DisplayName("E2E: Encrypted key file throws SecurityError without correct password")
void e2e_encryptedKey_unreadableWithoutPassword() throws Exception {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
@ -308,17 +308,19 @@ class SecureCompletionClientE2ETest {
assertFalse(encryptedContent.contains("BEGIN PRIVATE KEY"),
"Encrypted file should not contain PEM header");
// Try loading with wrong password - should not throw
// Try loading with wrong password - should throw SecurityError
SecureCompletionClient loadClient = new SecureCompletionClient();
assertDoesNotThrow(() ->
SecurityError error = assertThrows(SecurityError.class, () ->
loadClient.loadKeys(privateKeyFile.getAbsolutePath(), null, "wrong-password"),
"Wrong password should be handled gracefully"
"Wrong password should throw SecurityError"
);
assertTrue(error.getMessage().contains("decrypt") || error.getMessage().contains("password"),
"Error message should mention decryption or password");
}
@Test
@DisplayName("E2E: Generate keys without saving produces in-memory keys")
void e2e_generateKeys_noSave_producesInMemoryKeys() {
void e2e_generateKeys_noSave_producesInMemoryKeys() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);

View file

@ -28,7 +28,7 @@ class SecureCompletionClientTest {
@Test
@DisplayName("generateKeys should create 4096-bit RSA key pair")
void generateKeys_shouldCreateValidKeyPair() {
void generateKeys_shouldCreateValidKeyPair() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
@ -58,7 +58,7 @@ class SecureCompletionClientTest {
@Test
@DisplayName("generateKeys should create unique keys on each call")
void generateKeys_shouldProduceUniqueKeys() {
void generateKeys_shouldProduceUniqueKeys() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
PrivateKey firstKey = client.getPrivateKey();
@ -76,7 +76,7 @@ class SecureCompletionClientTest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("generateKeys with saveToFile=true should create key files")
void generateKeys_withSaveToFile_shouldCreateKeyFiles(@TempDir Path tempDir) {
void generateKeys_withSaveToFile_shouldCreateKeyFiles(@TempDir Path tempDir) throws SecurityError {
File keyDir = tempDir.toFile();
SecureCompletionClient client = new SecureCompletionClient();
@ -112,7 +112,7 @@ class SecureCompletionClientTest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("generateKeys should not overwrite existing key files")
void generateKeys_shouldNotOverwriteExistingKeys(@TempDir Path tempDir) {
void generateKeys_shouldNotOverwriteExistingKeys(@TempDir Path tempDir) throws SecurityError {
File keyDir = tempDir.toFile();
SecureCompletionClient client = new SecureCompletionClient();
@ -129,10 +129,10 @@ class SecureCompletionClientTest {
// Key Loading Tests
@Test
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("loadKeys should load plaintext private key from file")
void loadKeys_plaintext_shouldLoadPrivateKey(@TempDir Path tempDir) {
void loadKeys_plaintext_shouldLoadPrivateKey(@TempDir Path tempDir) throws Exception {
File keyDir = tempDir.toFile();
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(true, keyDir.getAbsolutePath(), null);
@ -151,10 +151,10 @@ class SecureCompletionClientTest {
"Loaded key should have same size as original");
}
@Test
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("loadKeys should load encrypted private key with correct password")
void loadKeys_encrypted_correctPassword_shouldLoadPrivateKey(@TempDir Path tempDir) {
void loadKeys_encrypted_correctPassword_shouldLoadPrivateKey(@TempDir Path tempDir) throws Exception {
File keyDir = tempDir.toFile();
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
@ -176,30 +176,33 @@ class SecureCompletionClientTest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("loadKeys should handle wrong password gracefully")
void loadKeys_encrypted_wrongPassword_shouldHandleGracefully(@TempDir Path tempDir) {
@DisplayName("loadKeys should throw SecurityError for wrong password")
void loadKeys_encrypted_wrongPassword_shouldThrowSecurityError(@TempDir Path tempDir) throws SecurityError {
File keyDir = tempDir.toFile();
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(true, keyDir.getAbsolutePath(), TEST_PASSWORD);
SecureCompletionClient loadClient = new SecureCompletionClient();
assertDoesNotThrow(() ->
SecurityError error = assertThrows(SecurityError.class, () ->
loadClient.loadKeys(
new File(keyDir, Constants.DEFAULT_PRIVATE_KEY_FILE).getAbsolutePath(),
null,
"wrong-password"
),
"Wrong password should not throw exception"
"Wrong password should throw SecurityError"
);
assertTrue(error.getMessage().contains("decrypt") || error.getMessage().contains("password"),
"Error message should mention decryption or password");
}
@Test
@DisplayName("loadKeys should throw exception for non-existent file")
void loadKeys_nonExistentFile_shouldThrowException() {
void loadKeys_nonExistentFile_shouldThrowException() {
SecureCompletionClient loadClient = new SecureCompletionClient();
RuntimeException error = assertThrows(RuntimeException.class, () ->
SecurityError error = assertThrows(SecurityError.class, () ->
loadClient.loadKeys("/non/existent/path/private_key.pem", null, null));
assertTrue(error.getMessage().contains("not found"),
@ -210,7 +213,7 @@ class SecureCompletionClientTest {
@Test
@DisplayName("validateRsaKey should accept valid 4096-bit key")
void validateRsaKey_validKey_shouldPass() {
void validateRsaKey_validKey_shouldPass() throws SecurityError {
SecureCompletionClient client = new SecureCompletionClient();
client.generateKeys(false);
PrivateKey key = client.getPrivateKey();
@ -264,7 +267,7 @@ class SecureCompletionClientTest {
@Test
@Execution(ExecutionMode.SAME_THREAD)
@DisplayName("Full roundtrip: generate, save, load should produce same key")
void roundtrip_generateSaveLoad_shouldProduceSameKey(@TempDir Path tempDir) {
void roundtrip_generateSaveLoad_shouldProduceSameKey(@TempDir Path tempDir) throws Exception {
File keyDir = tempDir.toFile();
// Generate and save
@ -407,10 +410,6 @@ class SecureCompletionClientTest {
tempClient.generateKeys(false);
PrivateKey originalKey = tempClient.getPrivateKey();
String pem = "-----BEGIN PRIVATE KEY-----\n " +
originalKey.getEncoded().length + "lines\n" +
"-----END PRIVATE KEY-----";
String formattedPem = ai.nomyo.util.PEMConverter.toPEM(originalKey.getEncoded(), true);
String pemWithWhitespace = formattedPem.replace("\n", "\n ");