pub mod api; pub mod auth; pub mod config; pub mod graph_id; pub mod identity; pub mod policy; pub mod workload; pub use graph_id::GraphId; pub use identity::{AuthSource, GraphKey, ResolvedActor, Scope, TenantId}; use std::collections::{HashMap, HashSet}; use std::fs; use std::io; use std::io::Write; use std::path::PathBuf; use std::sync::Arc; use api::{ BranchCreateOutput, BranchCreateRequest, BranchDeleteOutput, BranchListOutput, BranchMergeOutput, BranchMergeRequest, ChangeOutput, ChangeRequest, CommitListOutput, CommitListQuery, ErrorCode, ErrorOutput, ExportRequest, HealthOutput, IngestOutput, IngestRequest, ReadOutput, ReadRequest, SchemaApplyOutput, SchemaApplyRequest, SchemaOutput, SnapshotQuery, ingest_output, schema_apply_output, snapshot_payload, }; pub use auth::{AWS_SECRET_ENV, EnvOrFileTokenSource, TokenSource, resolve_token_source}; use axum::body::{Body, Bytes}; use axum::extract::DefaultBodyLimit; use axum::extract::{Extension, Path, Query, Request, State}; use axum::http::StatusCode; use axum::http::header::{AUTHORIZATION, CONTENT_TYPE}; use axum::middleware::{self, Next}; use axum::response::{IntoResponse, Response}; use axum::routing::{delete, get, post}; use axum::{Json, Router}; use color_eyre::eyre::{Result, WrapErr, bail}; pub use config::{ AliasCommand, AliasConfig, CliDefaults, DEFAULT_CONFIG_FILE, OmnigraphConfig, PolicySettings, ProjectConfig, QueryDefaults, ReadOutputFormat, ServerDefaults, TableCellLayout, TargetConfig, load_config, }; use futures::stream; use omnigraph::db::{Omnigraph, ReadTarget}; use omnigraph::error::{ManifestConflictDetails, ManifestErrorKind, OmniError}; use omnigraph_compiler::json_params_to_param_map; use omnigraph_compiler::query::parser::parse_query; use omnigraph_compiler::{JsonParamMode, ParamMap}; pub use policy::{ PolicyAction, PolicyCompiler, PolicyConfig, PolicyDecision, PolicyEngine, PolicyExpectation, PolicyRequest, PolicyTestConfig, }; use serde_json::Value; use sha2::{Digest, Sha256}; use subtle::ConstantTimeEq; use tokio::net::TcpListener; use tokio::sync::mpsc; use tower_http::trace::TraceLayer; use tracing::{error, info, warn}; use tracing_subscriber::EnvFilter; use utoipa::OpenApi; use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme}; type BearerTokenHash = [u8; 32]; fn hash_bearer_token(token: &str) -> BearerTokenHash { let digest = Sha256::digest(token.as_bytes()); let mut out = [0u8; 32]; out.copy_from_slice(&digest); out } #[derive(OpenApi)] #[openapi( info( title = "Omnigraph API", description = "HTTP API for the Omnigraph graph database", ), paths( server_health, server_snapshot, server_read, server_export, server_change, server_schema_apply, server_schema_get, server_ingest, server_branch_list, server_branch_create, server_branch_delete, server_branch_merge, server_commit_list, server_commit_show, ), modifiers(&SecurityAddon), )] pub struct ApiDoc; struct SecurityAddon; impl utoipa::Modify for SecurityAddon { fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) { openapi .components .get_or_insert_with(Default::default) .add_security_scheme( "bearer_token", SecurityScheme::Http(Http::new(HttpAuthScheme::Bearer)), ); } } const DEFAULT_REQUEST_BODY_LIMIT_BYTES: usize = 1_048_576; const INGEST_REQUEST_BODY_LIMIT_BYTES: usize = 32 * 1024 * 1024; const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION"); const SERVER_SOURCE_VERSION: Option<&str> = option_env!("OMNIGRAPH_SOURCE_VERSION"); #[derive(Debug, Clone)] pub struct ServerConfig { pub uri: String, pub bind: String, pub policy_file: Option, /// Operator opt-in for fully-unauthenticated dev mode (MR-723). /// When neither bearer tokens nor a policy file are configured, /// `serve()` refuses to start unless this is true (set via /// `--unauthenticated` or `OMNIGRAPH_UNAUTHENTICATED=1`). The /// motivation is that "no tokens + no policy" looks like protection /// (no Cedar errors at boot) but is actually fully open — operators /// who set up auth and forgot the policy file would otherwise ship /// the illusion of protection. pub allow_unauthenticated: bool, } #[derive(Clone)] pub struct AppState { uri: String, /// PR 2 (MR-686): the engine is now `Arc` — no global /// write lock. Concurrent handlers call `&self` engine APIs /// directly. Per-(table, branch) write queues inside the engine /// serialize same-key writers; per-actor admission control on /// `workload` isolates noisy actors. engine: Arc, /// Per-actor admission control. See `workload::WorkloadController`. workload: Arc, bearer_tokens: Arc<[(BearerTokenHash, Arc)]>, policy_engine: Option>, } #[derive(Debug, Clone)] struct AuthenticatedActor(Arc); struct ExportStreamWriter { sender: mpsc::UnboundedSender>, } impl Write for ExportStreamWriter { fn write(&mut self, buf: &[u8]) -> io::Result { self.sender .send(Ok(Bytes::copy_from_slice(buf))) .map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "export stream closed"))?; Ok(buf.len()) } fn flush(&mut self) -> io::Result<()> { Ok(()) } } impl AuthenticatedActor { fn as_str(&self) -> &str { &self.0 } } #[derive(Debug)] pub struct ApiError { status: StatusCode, code: ErrorCode, message: String, merge_conflicts: Vec, manifest_conflict: Option, } impl AppState { pub fn new(uri: String, db: Omnigraph) -> Self { Self::new_with_bearer_tokens(uri, db, Vec::new()) } pub fn new_with_bearer_token(uri: String, db: Omnigraph, bearer_token: Option) -> Self { let bearer_tokens = normalize_bearer_token(bearer_token) .into_iter() .map(|token| ("default".to_string(), token)) .collect(); Self::new_with_bearer_tokens(uri, db, bearer_tokens) } pub fn new_with_bearer_tokens( uri: String, db: Omnigraph, bearer_tokens: Vec<(String, String)>, ) -> Self { Self::new_with_bearer_tokens_and_policy(uri, db, bearer_tokens, None) } pub fn new_with_bearer_tokens_and_policy( uri: String, db: Omnigraph, bearer_tokens: Vec<(String, String)>, policy_engine: Option, ) -> Self { let bearer_tokens: Vec<(BearerTokenHash, Arc)> = bearer_tokens .into_iter() .map(|(actor, token)| (hash_bearer_token(&token), Arc::::from(actor))) .collect(); let policy_engine: Option> = policy_engine.map(Arc::new); // MR-722 chassis: inject the policy checker into the engine so // `Omnigraph::apply_schema_as` (and PR #3's fan-out of the // remaining writers) gates at engine-layer too. HTTP-layer // `authorize_request` still fires first; the engine-layer gate // is the redundant-but-correct backstop, plus the only path // that protects SDK / embedded callers. PR #3 removes the HTTP // redundancy once we're confident the engine gate covers it. let db = if let Some(engine) = policy_engine.as_ref() { // Unsizing coercion: Arc → Arc. // Needs the explicit `as` cast — Rust 2024 doesn't infer it through // `Arc::clone`. let checker = Arc::clone(engine) as Arc; db.with_policy(checker) } else { db }; Self { uri, engine: Arc::new(db), workload: Arc::new(workload::WorkloadController::from_env()), bearer_tokens: Arc::from(bearer_tokens), policy_engine, } } /// Construct with a caller-provided [`workload::WorkloadController`]. /// Tests and benches use this to override per-actor caps without /// mutating global env vars (which is unsafe in Rust 2024 once the /// async runtime is up — `setenv` isn't thread-safe). pub fn new_with_workload( uri: String, db: Omnigraph, bearer_tokens: Vec<(String, String)>, workload: workload::WorkloadController, ) -> Self { let bearer_tokens: Vec<(BearerTokenHash, Arc)> = bearer_tokens .into_iter() .map(|(actor, token)| (hash_bearer_token(&token), Arc::::from(actor))) .collect(); Self { uri, engine: Arc::new(db), workload: Arc::new(workload), bearer_tokens: Arc::from(bearer_tokens), policy_engine: None, } } /// Install a `PolicyEngine` post-construction (MR-723). Used by /// integration tests that need to thread custom workload limits /// alongside a permit-all policy — the existing `new_with_*` and /// `new_with_workload` constructors don't compose. Production /// callers should use `open_with_bearer_tokens_and_policy` which /// installs the policy on both the HTTP state and the engine. pub fn with_policy_engine(mut self, engine: PolicyEngine) -> Self { self.policy_engine = Some(Arc::new(engine)); self } pub async fn open(uri: impl Into) -> Result { Self::open_with_bearer_token(uri, None).await } pub async fn open_with_bearer_token( uri: impl Into, bearer_token: Option, ) -> Result { let bearer_tokens = normalize_bearer_token(bearer_token) .into_iter() .map(|token| ("default".to_string(), token)) .collect(); Self::open_with_bearer_tokens(uri, bearer_tokens).await } pub async fn open_with_bearer_tokens( uri: impl Into, bearer_tokens: Vec<(String, String)>, ) -> Result { let uri = uri.into(); let db = Omnigraph::open(&uri).await?; Ok(Self::new_with_bearer_tokens(uri, db, bearer_tokens)) } pub async fn open_with_bearer_tokens_and_policy( uri: impl Into, bearer_tokens: Vec<(String, String)>, policy_file: Option<&PathBuf>, ) -> Result { let uri = uri.into(); let db = Omnigraph::open(&uri).await?; let policy_engine = match policy_file { Some(path) => Some(PolicyEngine::load(path, &uri)?), None => None, }; if policy_engine.is_some() && bearer_tokens.is_empty() { bail!("policy requires at least one configured bearer token actor"); } Ok(Self::new_with_bearer_tokens_and_policy( uri, db, bearer_tokens, policy_engine, )) } pub fn uri(&self) -> &str { &self.uri } fn requires_bearer_auth(&self) -> bool { !self.bearer_tokens.is_empty() || self.policy_engine.is_some() } fn authenticate_bearer_token(&self, provided_token: &str) -> Option> { // Hash the incoming token and compare against every stored digest in // constant time. Iterate all entries unconditionally so total work — // and therefore response timing — doesn't depend on which slot matches. let provided_hash = hash_bearer_token(provided_token); let mut matched: Option> = None; for (hash, actor) in self.bearer_tokens.iter() { if bool::from(hash.ct_eq(&provided_hash)) && matched.is_none() { matched = Some(Arc::clone(actor)); } } matched } fn policy_engine(&self) -> Option<&PolicyEngine> { self.policy_engine.as_deref() } } impl ApiError { pub fn unauthorized(message: impl Into) -> Self { Self { status: StatusCode::UNAUTHORIZED, code: ErrorCode::Unauthorized, message: message.into(), merge_conflicts: Vec::new(), manifest_conflict: None, } } pub fn forbidden(message: impl Into) -> Self { Self { status: StatusCode::FORBIDDEN, code: ErrorCode::Forbidden, message: message.into(), merge_conflicts: Vec::new(), manifest_conflict: None, } } pub fn bad_request(message: impl Into) -> Self { Self { status: StatusCode::BAD_REQUEST, code: ErrorCode::BadRequest, message: message.into(), merge_conflicts: Vec::new(), manifest_conflict: None, } } pub fn not_found(message: impl Into) -> Self { Self { status: StatusCode::NOT_FOUND, code: ErrorCode::NotFound, message: message.into(), merge_conflicts: Vec::new(), manifest_conflict: None, } } pub fn conflict(message: impl Into) -> Self { Self { status: StatusCode::CONFLICT, code: ErrorCode::Conflict, message: message.into(), merge_conflicts: Vec::new(), manifest_conflict: None, } } pub fn internal(message: impl Into) -> Self { Self { status: StatusCode::INTERNAL_SERVER_ERROR, code: ErrorCode::Internal, message: message.into(), merge_conflicts: Vec::new(), manifest_conflict: None, } } /// HTTP 429 Too Many Requests — actor exceeded their per-actor /// admission cap (count or byte budget). Clients should respect the /// `Retry-After` header. Mapped from `RejectReason::InFlightCountExceeded` /// and `RejectReason::ByteBudgetExceeded`. pub fn too_many_requests(message: impl Into) -> Self { Self { status: StatusCode::TOO_MANY_REQUESTS, code: ErrorCode::TooManyRequests, message: message.into(), merge_conflicts: Vec::new(), manifest_conflict: None, } } /// Convert a `WorkloadController` rejection into the matching /// `ApiError` variant. pub fn from_workload_reject(reject: workload::RejectReason) -> Self { match reject { workload::RejectReason::InFlightCountExceeded { .. } | workload::RejectReason::ByteBudgetExceeded { .. } => { Self::too_many_requests(reject.to_string()) } } } fn merge_conflict(conflicts: Vec) -> Self { Self { status: StatusCode::CONFLICT, code: ErrorCode::Conflict, message: summarize_merge_conflicts(&conflicts), merge_conflicts: conflicts, manifest_conflict: None, } } fn manifest_version_conflict(message: String, details: api::ManifestConflictOutput) -> Self { Self { status: StatusCode::CONFLICT, code: ErrorCode::Conflict, message, merge_conflicts: Vec::new(), manifest_conflict: Some(details), } } fn from_omni(err: OmniError) -> Self { match err { OmniError::Compiler(err) => Self::bad_request(err.to_string()), OmniError::DataFusion(message) => Self::bad_request(format!("query: {message}")), OmniError::Manifest(err) => match err.kind { ManifestErrorKind::BadRequest => Self::bad_request(err.message), ManifestErrorKind::NotFound => Self::not_found(err.message), ManifestErrorKind::Conflict => match err.details { Some(ManifestConflictDetails::ExpectedVersionMismatch { table_key, expected, actual, }) => Self::manifest_version_conflict( err.message, api::ManifestConflictOutput { table_key, expected, actual, }, ), _ => Self::conflict(err.message), }, ManifestErrorKind::Internal => Self::internal(err.message), }, OmniError::MergeConflicts(conflicts) => Self::merge_conflict( conflicts .iter() .map(api::MergeConflictOutput::from) .collect(), ), OmniError::Lance(message) => Self::internal(format!("storage: {message}")), OmniError::Io(err) => Self::internal(format!("io: {err}")), // Engine-layer policy enforcement (MR-722). All denials and // evaluation failures surface here as 403. The HTTP-layer // `authorize_request` already distinguishes 401 (missing // bearer) from 403 (policy denial), so by the time the // engine gate fires, the bearer is valid — any failure from // the engine is a policy outcome, not an auth one. OmniError::Policy(message) => Self::forbidden(message), } } } fn summarize_merge_conflicts(conflicts: &[api::MergeConflictOutput]) -> String { if conflicts.is_empty() { return "merge conflicts".to_string(); } let preview: Vec = conflicts .iter() .take(3) .map(|conflict| match conflict.row_id.as_deref() { Some(row_id) => format!( "{}:{} ({})", conflict.table_key, row_id, conflict.kind.as_str() ), None => format!("{} ({})", conflict.table_key, conflict.kind.as_str()), }) .collect(); let suffix = if conflicts.len() > preview.len() { format!("; and {} more", conflicts.len() - preview.len()) } else { String::new() }; format!("merge conflicts: {}{}", preview.join("; "), suffix) } /// Constant `Retry-After` value (seconds) emitted on 429 responses. const RETRY_AFTER_SECONDS: &str = "60"; impl IntoResponse for ApiError { fn into_response(self) -> Response { let mut headers = axum::http::HeaderMap::new(); if matches!(self.code, ErrorCode::TooManyRequests) { headers.insert( axum::http::header::RETRY_AFTER, axum::http::HeaderValue::from_static(RETRY_AFTER_SECONDS), ); } ( self.status, headers, Json(ErrorOutput { error: self.message, code: Some(self.code), merge_conflicts: self.merge_conflicts, manifest_conflict: self.manifest_conflict, }), ) .into_response() } } pub fn init_tracing() { let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info")); let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init(); } pub fn load_server_settings( config_path: Option<&PathBuf>, cli_uri: Option, cli_target: Option, cli_bind: Option, cli_allow_unauthenticated: bool, ) -> Result { let config = load_config(config_path)?; let uri = config.resolve_target_uri(cli_uri, cli_target.as_deref(), config.server_graph_name())?; let bind = cli_bind.unwrap_or_else(|| config.server_bind().to_string()); let policy_file = config.resolve_policy_file(); // Either `--unauthenticated` or `OMNIGRAPH_UNAUTHENTICATED=1` flips // this. Treat any non-empty, non-"0"/"false" string as truthy — // standard 12-factor "any value is true" reading of the env var. let env_unauth = std::env::var("OMNIGRAPH_UNAUTHENTICATED") .ok() .map(|v| { let trimmed = v.trim(); !trimmed.is_empty() && trimmed != "0" && !trimmed.eq_ignore_ascii_case("false") }) .unwrap_or(false); let allow_unauthenticated = cli_allow_unauthenticated || env_unauth; Ok(ServerConfig { uri, bind, policy_file, allow_unauthenticated, }) } /// MR-723 server runtime state, classified from the three-state matrix /// of (bearer tokens configured) × (policy file configured) at startup. /// /// * **Open** — neither tokens nor policy; requires explicit /// `allow_unauthenticated`. Effectively a "trust the network" dev /// mode. `serve()` refuses to start in this shape without the flag, /// so the only way to reach this state at runtime is via deliberate /// operator opt-in. /// * **DefaultDeny** — tokens configured but no policy file. The /// server requires a valid bearer token; once authenticated, every /// action except `Read` is denied with 403. Closes the "tokens but /// forgot the policy file" trap. /// * **PolicyEnabled** — policy file configured. Cedar evaluates every /// authenticated request. Tokens may also be configured (typical) or /// not (unusual but valid — every request fails 401 without a /// bearer, which is effectively "locked"). #[derive(Debug, Clone, Copy, Eq, PartialEq)] pub enum ServerRuntimeState { Open, DefaultDeny, PolicyEnabled, } /// Compute the [`ServerRuntimeState`] from the configured inputs. /// Pulled out as a pure function so the 3-state matrix is unit-testable /// without standing up the full server. pub fn classify_server_runtime_state( has_tokens: bool, has_policy: bool, allow_unauthenticated: bool, ) -> Result { match (has_tokens, has_policy, allow_unauthenticated) { (false, false, false) => bail!( "server has no bearer tokens and no policy file configured. This is a fully \ open server — pass `--unauthenticated` (or set OMNIGRAPH_UNAUTHENTICATED=1) \ if you actually want that, otherwise configure bearer tokens (see \ docs/user/server.md) and/or `policy.file` in omnigraph.yaml." ), (false, false, true) => Ok(ServerRuntimeState::Open), (true, false, _) => Ok(ServerRuntimeState::DefaultDeny), (_, true, _) => Ok(ServerRuntimeState::PolicyEnabled), } } pub fn build_app(state: AppState) -> Router { let protected = Router::new() .route("/snapshot", get(server_snapshot)) .route("/export", post(server_export)) .route("/read", post(server_read)) .route("/change", post(server_change)) .route("/schema", get(server_schema_get)) .route("/schema/apply", post(server_schema_apply)) .route( "/ingest", post(server_ingest).layer(DefaultBodyLimit::max(INGEST_REQUEST_BODY_LIMIT_BYTES)), ) .route( "/branches", get(server_branch_list).post(server_branch_create), ) .route("/branches/{branch}", delete(server_branch_delete)) .route("/branches/merge", post(server_branch_merge)) .route("/commits", get(server_commit_list)) .route("/commits/{commit_id}", get(server_commit_show)) .route_layer(middleware::from_fn_with_state( state.clone(), require_bearer_auth, )); Router::new() .route("/healthz", get(server_health)) .route("/openapi.json", get(server_openapi)) .merge(protected) .layer(DefaultBodyLimit::max(DEFAULT_REQUEST_BODY_LIMIT_BYTES)) .layer(TraceLayer::new_for_http()) .with_state(state) } pub async fn serve(config: ServerConfig) -> Result<()> { let token_source = resolve_token_source().await?; info!(source = token_source.name(), "loaded bearer token source"); let tokens = token_source.load().await?; let runtime_state = classify_server_runtime_state( !tokens.is_empty(), config.policy_file.is_some(), config.allow_unauthenticated, )?; match runtime_state { ServerRuntimeState::Open => warn!( "running with --unauthenticated: no bearer tokens, no policy file, all \ requests permitted. This is for local dev only — do not expose to a \ network you don't fully trust." ), ServerRuntimeState::DefaultDeny => warn!( "bearer tokens are configured but no policy file is set — running in \ default-deny mode (only `read` actions are permitted for authenticated \ actors). Configure `policy.file` in omnigraph.yaml to enable Cedar rules." ), ServerRuntimeState::PolicyEnabled => {} } let state = AppState::open_with_bearer_tokens_and_policy( config.uri.clone(), tokens, config.policy_file.as_ref(), ) .await?; let listener = TcpListener::bind(&config.bind).await?; info!(uri = %config.uri, bind = %config.bind, "serving omnigraph"); axum::serve(listener, build_app(state)) .with_graceful_shutdown(shutdown_signal()) .await?; Ok(()) } async fn shutdown_signal() { if let Err(err) = tokio::signal::ctrl_c().await { error!(error = %err, "failed to install ctrl-c handler"); return; } info!("shutdown signal received"); } #[utoipa::path( get, path = "/healthz", tag = "health", operation_id = "health", responses( (status = 200, description = "Server is healthy", body = HealthOutput), ), )] /// Liveness probe. /// /// Returns server status and version. Unauthenticated; safe to call from any /// caller. Use this to confirm the server is reachable before invoking other /// endpoints. async fn server_health() -> Json { Json(HealthOutput { status: "ok".to_string(), version: SERVER_VERSION.to_string(), source_version: SERVER_SOURCE_VERSION.map(str::to_string), }) } async fn server_openapi(State(state): State) -> Json { let mut doc = ApiDoc::openapi(); if !state.requires_bearer_auth() { strip_security(&mut doc); } Json(doc) } fn strip_security(doc: &mut utoipa::openapi::OpenApi) { if let Some(components) = doc.components.as_mut() { components.security_schemes.clear(); } for path_item in doc.paths.paths.values_mut() { for op in [ path_item.get.as_mut(), path_item.post.as_mut(), path_item.put.as_mut(), path_item.delete.as_mut(), path_item.options.as_mut(), path_item.head.as_mut(), path_item.patch.as_mut(), path_item.trace.as_mut(), ] .into_iter() .flatten() { op.security = None; } } } async fn require_bearer_auth( State(state): State, mut request: Request, next: Next, ) -> std::result::Result { if !state.requires_bearer_auth() { return Ok(next.run(request).await); } let Some(header) = request .headers() .get(AUTHORIZATION) .and_then(|value| value.to_str().ok()) else { return Err(ApiError::unauthorized("missing bearer token")); }; let Some(provided_token) = header.strip_prefix("Bearer ") else { return Err(ApiError::unauthorized("missing bearer token")); }; let Some(actor) = state.authenticate_bearer_token(provided_token) else { return Err(ApiError::unauthorized("invalid bearer token")); }; request.extensions_mut().insert(AuthenticatedActor(actor)); Ok(next.run(request).await) } fn log_policy_decision(actor_id: &str, request: &PolicyRequest, decision: &PolicyDecision) { info!( actor_id = actor_id, action = %request.action, branch = request.branch.as_deref().unwrap_or(""), target_branch = request.target_branch.as_deref().unwrap_or(""), allowed = decision.allowed, matched_rule_id = decision.matched_rule_id.as_deref().unwrap_or(""), "policy decision" ); } fn authorize_request( state: &AppState, actor: Option<&AuthenticatedActor>, mut request: PolicyRequest, ) -> std::result::Result<(), ApiError> { let Some(engine) = state.policy_engine() else { // MR-723 default-deny path. We're here when no PolicyEngine is // installed. Two startup-validated shapes can reach this: // // * **Open mode** (`--unauthenticated`): no tokens, no policy. // `require_bearer_auth` short-circuits before this is called, // but defense in depth — if a future change makes the // middleware call here for an unauthenticated request, we // want every action to remain Ok rather than 403. The // operator opted in. // * **DefaultDeny mode**: tokens configured but no policy. The // request went through bearer auth, so `actor` is Some and // identifies a known actor. Only `Read` is permitted; every // other action returns 403. This closes the "configured auth // but forgot the policy file" trap from MR-723. if actor.is_some() && request.action != PolicyAction::Read { return Err(ApiError::forbidden( "server runs in default-deny mode (bearer tokens configured but no \ policy file). Only `read` actions are permitted; configure \ `policy.file` in omnigraph.yaml to enable other actions.", )); } return Ok(()); }; let Some(actor) = actor else { return Err(ApiError::unauthorized("missing bearer token")); }; // SECURITY INVARIANT (MR-731): actor identity comes from the matched // bearer token, never from a client-supplied request header, query // parameter, or body field. This line is the single chokepoint where // the authoritative actor (resolved from the bearer match by // `require_bearer_auth`) overwrites whatever the handler put in the // PolicyRequest. Removing or weakening it lets clients spoof identity — // exactly the Supabase RLS footgun ("trusting raw_user_meta_data is // asking the attacker if they're an admin"). The principle is codified // in `docs/dev/invariants.md` Hard Invariant 11 ("clients cannot set // actor identity directly") and pinned by the regression test // `actor_id_resolves_from_bearer_token_ignoring_client_supplied_headers` // in `crates/omnigraph-server/tests/server.rs`. // // Side effect: also prevents an empty-string default at any handler // call site from ever reaching the engine as a policy subject. request.actor_id = actor.as_str().to_string(); let decision = engine .authorize(&request) .map_err(|err| ApiError::internal(format!("policy: {err}")))?; log_policy_decision(actor.as_str(), &request, &decision); if decision.allowed { Ok(()) } else { Err(ApiError::forbidden(decision.message)) } } #[utoipa::path( get, path = "/snapshot", tag = "snapshots", operation_id = "getSnapshot", params(SnapshotQuery), responses( (status = 200, description = "Database snapshot", body = api::SnapshotOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Read the current snapshot of a branch. /// /// Returns the manifest version plus per-table metadata (path, version, row /// count) for every table on the branch. Defaults to `main` when `branch` is /// omitted. Read-only. async fn server_snapshot( State(state): State, actor: Option>, Query(query): Query, ) -> std::result::Result, ApiError> { let branch = query.branch.unwrap_or_else(|| "main".to_string()); authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor .as_ref() .map(|Extension(actor)| actor.as_str().to_string()) .unwrap_or_default(), action: PolicyAction::Read, branch: Some(branch.clone()), target_branch: None, }, )?; let snapshot = { let db = &state.engine; db.snapshot_of(ReadTarget::branch(branch.as_str())) .await .map_err(ApiError::from_omni)? }; Ok(Json(snapshot_payload(&branch, &snapshot))) } #[utoipa::path( post, path = "/read", tag = "queries", operation_id = "read", request_body = ReadRequest, responses( (status = 200, description = "Query results", body = ReadOutput), (status = 400, description = "Bad request", body = ErrorOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Execute a GQ read query. /// /// Runs the query in `query_source` against either a branch or a frozen /// snapshot (mutually exclusive). When `query_source` defines multiple named /// queries, pick one with `query_name`. `params` is a JSON object whose keys /// match the parameters declared by the query. Returns rows as a JSON array /// plus a `columns` list. Read-only. async fn server_read( State(state): State, actor: Option>, Json(request): Json, ) -> std::result::Result, ApiError> { if request.branch.is_some() && request.snapshot.is_some() { return Err(ApiError::bad_request( "read request may specify branch or snapshot, not both", )); } let target = read_target_from_request(request.branch, request.snapshot); let policy_branch = match &target { ReadTarget::Branch(branch) => Some(branch.clone()), ReadTarget::Snapshot(_) if state.policy_engine().is_some() && actor.is_some() => { let db = &state.engine; db.resolved_branch_of(target.clone()) .await .map(|branch| branch.or_else(|| Some("main".to_string()))) .map_err(ApiError::from_omni)? } ReadTarget::Snapshot(_) => None, }; authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor .as_ref() .map(|Extension(actor)| actor.as_str().to_string()) .unwrap_or_default(), action: PolicyAction::Read, branch: policy_branch, target_branch: None, }, )?; let (selected_name, query_params) = select_named_query(&request.query_source, request.query_name.as_deref()) .map_err(|err| ApiError::bad_request(err.to_string()))?; let params = query_params_from_json(&query_params, request.params.as_ref()) .map_err(|err| ApiError::bad_request(err.to_string()))?; let result = { let db = &state.engine; db.query( target.clone(), &request.query_source, &selected_name, ¶ms, ) .await .map_err(ApiError::from_omni)? }; Ok(Json(api::read_output(selected_name, &target, result))) } #[utoipa::path( post, path = "/export", tag = "queries", operation_id = "export", request_body = ExportRequest, responses( (status = 200, description = "Exported data as NDJSON", content_type = "application/x-ndjson"), (status = 400, description = "Bad request", body = ErrorOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Stream the contents of a branch as NDJSON. /// /// Emits one JSON object per line (`application/x-ndjson`). Filter with /// `type_names` (node/edge type names) and/or `table_keys`; both empty /// streams the entire branch. Suitable for large exports — the response is /// streamed, not buffered. Read-only. async fn server_export( State(state): State, actor: Option>, Json(request): Json, ) -> std::result::Result { let branch = request.branch.unwrap_or_else(|| "main".to_string()); authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor .as_ref() .map(|Extension(actor)| actor.as_str().to_string()) .unwrap_or_default(), action: PolicyAction::Export, branch: Some(branch.clone()), target_branch: None, }, )?; let engine = Arc::clone(&state.engine); let type_names = request.type_names.clone(); let table_keys = request.table_keys.clone(); let (tx, rx) = mpsc::unbounded_channel::>(); tokio::spawn(async move { let result = { let mut writer = ExportStreamWriter { sender: tx.clone() }; engine .export_jsonl_to_writer(&branch, &type_names, &table_keys, &mut writer) .await }; if let Err(err) = result { let _ = tx.send(Err(io::Error::other(err.to_string()))); } }); let body = Body::from_stream(stream::unfold(rx, |mut rx| async move { rx.recv().await.map(|item| (item, rx)) })); Ok(( StatusCode::OK, [(CONTENT_TYPE, "application/x-ndjson; charset=utf-8")], body, ) .into_response()) } #[utoipa::path( post, path = "/change", tag = "mutations", operation_id = "change", request_body = ChangeRequest, responses( (status = 200, description = "Mutation results", body = ChangeOutput), (status = 400, description = "Bad request", body = ErrorOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), (status = 409, description = "Merge conflict", body = ErrorOutput), (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Apply a GQ mutation to a branch. /// /// Writes to the named `branch` (defaults to `main`). Mutations are atomic /// per call and produce a new commit. Returns counts of nodes and edges /// affected. **Destructive**: on success the branch is updated; rejected /// mutations may still acquire locks briefly. Returns 409 on merge conflict. async fn server_change( State(state): State, actor: Option>, Json(request): Json, ) -> std::result::Result, ApiError> { let branch = request.branch.unwrap_or_else(|| "main".to_string()); let actor_arc = actor .as_ref() .map(|Extension(actor)| Arc::clone(&actor.0)) .unwrap_or_else(|| Arc::::from("anonymous")); let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str()); authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor_id.map(str::to_string).unwrap_or_default(), action: PolicyAction::Change, branch: Some(branch.clone()), target_branch: None, }, )?; // Per-actor admission: bound concurrent in-flight mutations and // estimated bytes per actor. Cedar runs FIRST so denied requests // don't consume admission slots. Estimate uses the request body // size as a coarse proxy; engine memory pressure can run higher. let est_bytes = request.query_source.len() as u64 + request .params .as_ref() .map(|p| p.to_string().len() as u64) .unwrap_or(0); let _admission = state .workload .try_admit(&actor_arc, est_bytes) .map_err(ApiError::from_workload_reject)?; let (selected_name, query_params) = select_named_query(&request.query_source, request.query_name.as_deref()) .map_err(|err| ApiError::bad_request(err.to_string()))?; let params = query_params_from_json(&query_params, request.params.as_ref()) .map_err(|err| ApiError::bad_request(err.to_string()))?; let result = { let db = &state.engine; db.mutate_as( &branch, &request.query_source, &selected_name, ¶ms, actor_id, ) .await .map_err(ApiError::from_omni)? }; Ok(Json(ChangeOutput { branch, query_name: selected_name, affected_nodes: result.affected_nodes, affected_edges: result.affected_edges, actor_id: actor_id.map(str::to_string), })) } #[utoipa::path( get, path = "/schema", tag = "schema", operation_id = "getSchema", responses( (status = 200, description = "Current schema source", body = SchemaOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Read the current schema source. /// /// Returns the project's schema as a single string in `.pg` source form. /// Useful for clients that want to introspect available types and tables /// before constructing GQ queries. Read-only. async fn server_schema_get( State(state): State, actor: Option>, ) -> std::result::Result, ApiError> { authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor .as_ref() .map(|Extension(actor)| actor.as_str().to_string()) .unwrap_or_default(), action: PolicyAction::Read, branch: None, target_branch: None, }, )?; let schema_source = { let db = &state.engine; db.schema_source().to_string() }; Ok(Json(SchemaOutput { schema_source })) } #[utoipa::path( post, path = "/schema/apply", tag = "mutations", operation_id = "applySchema", request_body = SchemaApplyRequest, responses( (status = 200, description = "Schema apply results", body = SchemaApplyOutput), (status = 400, description = "Bad request", body = ErrorOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Apply a schema migration. /// /// Diffs `schema_source` against the current schema and applies the resulting /// migration steps (add/drop type, add/drop column, etc.). **Destructive**: /// some steps drop data. Returns the list of steps applied; if `applied` is /// false the diff was unsupported and no changes were made. async fn server_schema_apply( State(state): State, actor: Option>, Json(request): Json, ) -> std::result::Result, ApiError> { let actor_arc = actor .as_ref() .map(|Extension(actor)| Arc::clone(&actor.0)) .unwrap_or_else(|| Arc::::from("anonymous")); let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str()); authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor_id.map(str::to_string).unwrap_or_default(), action: PolicyAction::SchemaApply, branch: None, target_branch: Some("main".to_string()), }, )?; let est_bytes = request.schema_source.len() as u64; let _admission = state .workload .try_admit(&actor_arc, est_bytes) .map_err(ApiError::from_workload_reject)?; let result = { let db = &state.engine; // Engine-layer policy enforcement (MR-722): pass the resolved // actor through so apply_schema_as can call enforce() with the // authoritative identity. With a policy installed in AppState, // engine-side enforcement re-checks the same decision the // HTTP-layer authorize_request just made above. PR #3 collapses // the redundancy. db.apply_schema_as( &request.schema_source, omnigraph::db::SchemaApplyOptions { allow_data_loss: request.allow_data_loss, }, actor_id, ) .await .map_err(ApiError::from_omni)? }; Ok(Json(schema_apply_output(state.uri(), result))) } #[utoipa::path( post, path = "/ingest", tag = "mutations", operation_id = "ingest", request_body = IngestRequest, responses( (status = 200, description = "Ingest results", body = IngestOutput), (status = 400, description = "Bad request", body = ErrorOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Bulk-ingest NDJSON data into a branch. /// /// `data` is NDJSON with one record per line. `mode` controls behavior on /// existing rows: `merge` upserts by id (default), `append` blindly inserts, /// `overwrite` replaces table contents. If `branch` does not exist it is /// created from `from` (defaults to `main`). **Destructive** when `mode` is /// `overwrite` or when ingest produces conflicting writes. async fn server_ingest( State(state): State, actor: Option>, Json(request): Json, ) -> std::result::Result, ApiError> { let branch = request.branch.unwrap_or_else(|| "main".to_string()); let from = request.from.unwrap_or_else(|| "main".to_string()); let mode = request.mode.unwrap_or(omnigraph::loader::LoadMode::Merge); let actor_arc = actor .as_ref() .map(|Extension(actor)| Arc::clone(&actor.0)) .unwrap_or_else(|| Arc::::from("anonymous")); let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str()); let branch_exists = { let db = &state.engine; db.branch_list() .await .map_err(ApiError::from_omni)? .into_iter() .any(|name| name == branch) }; if !branch_exists { authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor_id.map(str::to_string).unwrap_or_default(), action: PolicyAction::BranchCreate, branch: Some(from.clone()), target_branch: Some(branch.clone()), }, )?; } authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor_id.map(str::to_string).unwrap_or_default(), action: PolicyAction::Change, branch: Some(branch.clone()), target_branch: None, }, )?; let est_bytes = request.data.len() as u64; let _admission = state .workload .try_admit(&actor_arc, est_bytes) .map_err(ApiError::from_workload_reject)?; let result = { let db = &state.engine; db.ingest_as(&branch, Some(&from), &request.data, mode, actor_id) .await .map_err(ApiError::from_omni)? }; Ok(Json(ingest_output( state.uri(), &result, actor_id.map(str::to_string), ))) } #[utoipa::path( get, path = "/branches", tag = "branches", operation_id = "listBranches", responses( (status = 200, description = "List of branches", body = BranchListOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// List all branches. /// /// Returns branch names sorted alphabetically. Read-only. async fn server_branch_list( State(state): State, actor: Option>, ) -> std::result::Result, ApiError> { authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor .as_ref() .map(|Extension(actor)| actor.as_str().to_string()) .unwrap_or_default(), action: PolicyAction::Read, branch: None, target_branch: None, }, )?; let mut branches = { let db = &state.engine; db.branch_list().await.map_err(ApiError::from_omni)? }; branches.sort(); Ok(Json(BranchListOutput { branches })) } #[utoipa::path( post, path = "/branches", tag = "branches", operation_id = "createBranch", request_body = BranchCreateRequest, responses( (status = 200, description = "Branch created", body = BranchCreateOutput), (status = 400, description = "Bad request", body = ErrorOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), (status = 409, description = "Branch already exists", body = ErrorOutput), (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Create a new branch. /// /// Forks `name` off of `from` (defaults to `main`). The new branch shares /// table data with its parent until it is mutated. Returns 409 if `name` /// already exists. async fn server_branch_create( State(state): State, actor: Option>, Json(request): Json, ) -> std::result::Result, ApiError> { let from = request.from.unwrap_or_else(|| "main".to_string()); let actor_arc = actor .as_ref() .map(|Extension(actor)| Arc::clone(&actor.0)) .unwrap_or_else(|| Arc::::from("anonymous")); authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor .as_ref() .map(|Extension(actor)| actor.as_str().to_string()) .unwrap_or_default(), action: PolicyAction::BranchCreate, branch: Some(from.clone()), target_branch: Some(request.name.clone()), }, )?; // Branch metadata only — small constant bytes estimate. The Lance // shallow-clone work is bounded by the parent's manifest size, not // the request body. let _admission = state .workload .try_admit(&actor_arc, 256) .map_err(ApiError::from_workload_reject)?; { let db = &state.engine; db.branch_create_from_as( ReadTarget::branch(&from), &request.name, actor.as_ref().map(|Extension(a)| a.as_str()), ) .await .map_err(ApiError::from_omni)?; } Ok(Json(BranchCreateOutput { uri: state.uri().to_string(), from, name: request.name, actor_id: actor.map(|Extension(actor)| actor.as_str().to_string()), })) } #[utoipa::path( delete, path = "/branches/{branch}", tag = "branches", operation_id = "deleteBranch", params( ("branch" = String, Path, description = "Branch name to delete"), ), responses( (status = 200, description = "Branch deleted", body = BranchDeleteOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), (status = 404, description = "Branch not found", body = ErrorOutput), (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Delete a branch. /// /// **Irreversible.** Removes the branch pointer; commits remain reachable /// only if referenced by another branch. Returns 404 if the branch does not /// exist. async fn server_branch_delete( State(state): State, actor: Option>, Path(branch): Path, ) -> std::result::Result, ApiError> { let actor_arc = actor .as_ref() .map(|Extension(actor)| Arc::clone(&actor.0)) .unwrap_or_else(|| Arc::::from("anonymous")); let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str()); authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor_id.map(str::to_string).unwrap_or_default(), action: PolicyAction::BranchDelete, branch: None, target_branch: Some(branch.clone()), }, )?; // Metadata-only manifest tombstone — small constant estimate. let _admission = state .workload .try_admit(&actor_arc, 256) .map_err(ApiError::from_workload_reject)?; { let db = &state.engine; db.branch_delete_as(&branch, actor_id) .await .map_err(ApiError::from_omni)?; } Ok(Json(BranchDeleteOutput { uri: state.uri().to_string(), name: branch, actor_id: actor_id.map(str::to_string), })) } #[utoipa::path( post, path = "/branches/merge", tag = "branches", operation_id = "mergeBranches", request_body = BranchMergeRequest, responses( (status = 200, description = "Branches merged", body = BranchMergeOutput), (status = 400, description = "Bad request", body = ErrorOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), (status = 409, description = "Merge conflict", body = ErrorOutput), (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Merge one branch into another. /// /// Merges `source` into `target` (defaults to `main`). Outcome is one of /// `already_up_to_date`, `fast_forward`, or `merged`. Returns 409 with the /// list of conflicts if the merge cannot be completed; the target is left /// unchanged in that case. **Destructive** to `target` on success. async fn server_branch_merge( State(state): State, actor: Option>, Json(request): Json, ) -> std::result::Result, ApiError> { let target = request.target.unwrap_or_else(|| "main".to_string()); let actor_arc = actor .as_ref() .map(|Extension(actor)| Arc::clone(&actor.0)) .unwrap_or_else(|| Arc::::from("anonymous")); let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str()); authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor_id.map(str::to_string).unwrap_or_default(), action: PolicyAction::BranchMerge, branch: Some(request.source.clone()), target_branch: Some(target.clone()), }, )?; // Merge body is small JSON; the heavy work is in the engine but is // bounded per-(table, branch) by the writer queue. Small constant // estimate suffices for the actor in-flight count. let _admission = state .workload .try_admit(&actor_arc, 256) .map_err(ApiError::from_workload_reject)?; let outcome = { let db = &state.engine; db.branch_merge_as(&request.source, &target, actor_id) .await .map_err(ApiError::from_omni)? }; Ok(Json(BranchMergeOutput { source: request.source, target, outcome: outcome.into(), actor_id: actor_id.map(str::to_string), })) } #[utoipa::path( get, path = "/commits", tag = "commits", operation_id = "listCommits", params(CommitListQuery), responses( (status = 200, description = "List of commits", body = CommitListOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// List commits. /// /// Filter by `branch` to get the commits on a single branch (most recent /// first); omit to list across all branches. Read-only. async fn server_commit_list( State(state): State, actor: Option>, Query(query): Query, ) -> std::result::Result, ApiError> { authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor .as_ref() .map(|Extension(actor)| actor.as_str().to_string()) .unwrap_or_default(), action: PolicyAction::Read, branch: query.branch.clone(), target_branch: None, }, )?; let commits = { let db = &state.engine; db.list_commits(query.branch.as_deref()) .await .map_err(ApiError::from_omni)? }; Ok(Json(CommitListOutput { commits: commits.iter().map(api::commit_output).collect(), })) } #[utoipa::path( get, path = "/commits/{commit_id}", tag = "commits", operation_id = "getCommit", params( ("commit_id" = String, Path, description = "Commit identifier"), ), responses( (status = 200, description = "Commit details", body = api::CommitOutput), (status = 401, description = "Unauthorized", body = ErrorOutput), (status = 403, description = "Forbidden", body = ErrorOutput), (status = 404, description = "Commit not found", body = ErrorOutput), ), security(("bearer_token" = [])), )] /// Get a single commit. /// /// Returns the commit's manifest version, parent commit(s), and creation /// metadata. Read-only. async fn server_commit_show( State(state): State, actor: Option>, Path(commit_id): Path, ) -> std::result::Result, ApiError> { authorize_request( &state, actor.as_ref().map(|Extension(actor)| actor), PolicyRequest { actor_id: actor .as_ref() .map(|Extension(actor)| actor.as_str().to_string()) .unwrap_or_default(), action: PolicyAction::Read, branch: None, target_branch: None, }, )?; let commit = { let db = &state.engine; db.get_commit(&commit_id) .await .map_err(ApiError::from_omni)? }; Ok(Json(api::commit_output(&commit))) } fn read_target_from_request(branch: Option, snapshot: Option) -> ReadTarget { if let Some(snapshot) = snapshot { ReadTarget::snapshot(omnigraph::db::SnapshotId::new(snapshot)) } else { ReadTarget::branch(branch.unwrap_or_else(|| "main".to_string())) } } fn select_named_query( query_source: &str, requested_name: Option<&str>, ) -> Result<(String, Vec)> { let parsed = parse_query(query_source)?; let query = if let Some(name) = requested_name { parsed .queries .into_iter() .find(|query| query.name == name) .ok_or_else(|| color_eyre::eyre::eyre!("query '{}' not found", name))? } else if parsed.queries.len() == 1 { parsed.queries.into_iter().next().unwrap() } else { bail!("query file contains multiple queries; pass --name"); }; Ok((query.name, query.params)) } fn query_params_from_json( query_params: &[omnigraph_compiler::query::ast::Param], params_json: Option<&Value>, ) -> Result { json_params_to_param_map(params_json, query_params, JsonParamMode::Standard) .map_err(|err| color_eyre::eyre::eyre!(err.to_string())) } fn normalize_bearer_token(value: Option) -> Option { value .map(|value| value.trim().to_string()) .filter(|value| !value.is_empty()) } fn normalize_bearer_actor(value: String) -> Result { let value = value.trim().to_string(); if value.is_empty() { bail!("bearer token actor names must not be blank"); } Ok(value) } fn parse_bearer_tokens_json(value: &str) -> Result> { let entries: HashMap = serde_json::from_str(value) .wrap_err("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON must be a JSON object of actor->token")?; Ok(entries.into_iter().collect()) } fn read_bearer_tokens_file(path: &str) -> Result> { let contents = fs::read_to_string(path) .wrap_err_with(|| format!("failed to read bearer tokens file at {path}"))?; parse_bearer_tokens_json(&contents) .wrap_err_with(|| format!("failed to parse bearer tokens file at {path}")) } fn validate_bearer_tokens(entries: Vec<(String, String)>) -> Result> { let mut seen_actors = HashSet::new(); let mut seen_tokens = HashSet::new(); let mut normalized = Vec::with_capacity(entries.len()); for (actor, token) in entries { let actor = normalize_bearer_actor(actor)?; let Some(token) = normalize_bearer_token(Some(token)) else { bail!("bearer token for actor '{actor}' must not be blank"); }; if !seen_actors.insert(actor.clone()) { bail!("duplicate bearer token actor '{actor}'"); } if !seen_tokens.insert(token.clone()) { bail!("duplicate bearer token value configured"); } normalized.push((actor, token)); } normalized.sort_by(|(left, _), (right, _)| left.cmp(right)); Ok(normalized) } fn server_bearer_tokens_from_env() -> Result> { let mut entries = Vec::new(); if let Some(token) = normalize_bearer_token(std::env::var("OMNIGRAPH_SERVER_BEARER_TOKEN").ok()) { entries.push(("default".to_string(), token)); } if let Some(path) = normalize_bearer_token(std::env::var("OMNIGRAPH_SERVER_BEARER_TOKENS_FILE").ok()) { entries.extend(read_bearer_tokens_file(&path)?); } else if let Some(json) = normalize_bearer_token(std::env::var("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON").ok()) { entries.extend(parse_bearer_tokens_json(&json)?); } validate_bearer_tokens(entries) } #[cfg(test)] mod tests { use super::{ ServerConfig, ServerRuntimeState, classify_server_runtime_state, hash_bearer_token, load_server_settings, normalize_bearer_token, parse_bearer_tokens_json, serve, server_bearer_tokens_from_env, }; use serial_test::serial; use std::env; use std::fs; use tempfile::tempdir; #[test] fn hash_bearer_token_produces_32_byte_output() { let hash = hash_bearer_token("any-token"); assert_eq!(hash.len(), 32); } #[test] fn hash_bearer_token_is_deterministic() { assert_eq!( hash_bearer_token("stable-input"), hash_bearer_token("stable-input"), ); } #[test] fn hash_bearer_token_differs_for_different_inputs() { assert_ne!(hash_bearer_token("token-a"), hash_bearer_token("token-b")); } #[test] fn hash_bearer_token_matches_known_sha256_vector() { // SHA-256("abc"). If this ever fails, the hash function was swapped. let hash = hash_bearer_token("abc"); let hex: String = hash.iter().map(|b| format!("{:02x}", b)).collect(); assert_eq!( hex, "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad" ); } #[test] fn server_settings_load_from_yaml_config() { let temp = tempdir().unwrap(); let config = temp.path().join("omnigraph.yaml"); fs::write( &config, r#" graphs: local: uri: /tmp/demo.omni server: graph: local bind: 0.0.0.0:9090 "#, ) .unwrap(); let settings = load_server_settings(Some(&config), None, None, None, false).unwrap(); assert_eq!(settings.uri, "/tmp/demo.omni"); assert_eq!(settings.bind, "0.0.0.0:9090"); } #[test] fn server_settings_cli_flags_override_yaml_config() { let temp = tempdir().unwrap(); let config = temp.path().join("omnigraph.yaml"); fs::write( &config, r#" graphs: local: uri: /tmp/demo.omni server: graph: local bind: 127.0.0.1:8080 "#, ) .unwrap(); let settings = load_server_settings( Some(&config), Some("/tmp/override.omni".to_string()), None, Some("0.0.0.0:9999".to_string()), false, ) .unwrap(); assert_eq!(settings.uri, "/tmp/override.omni"); assert_eq!(settings.bind, "0.0.0.0:9999"); } #[test] fn server_settings_can_resolve_named_target() { let temp = tempdir().unwrap(); let config = temp.path().join("omnigraph.yaml"); fs::write( &config, r#" graphs: local: uri: ./demo.omni dev: uri: http://127.0.0.1:8080 server: graph: local bind: 127.0.0.1:8080 "#, ) .unwrap(); let settings = load_server_settings(Some(&config), None, Some("dev".to_string()), None, false) .unwrap(); assert_eq!(settings.uri, "http://127.0.0.1:8080"); } #[test] fn server_settings_require_uri_from_cli_or_config() { let error = load_server_settings(None, None, None, None, false).unwrap_err(); assert!(error.to_string().contains("URI must be provided")); } #[test] fn classify_open_requires_explicit_unauthenticated_flag() { // State 1: no tokens, no policy, no flag → refuse to start. let error = classify_server_runtime_state(false, false, false).unwrap_err(); let msg = error.to_string(); assert!( msg.contains("--unauthenticated"), "expected refusal message mentioning --unauthenticated, got: {msg}" ); // Same matrix cell but with the flag set → Open mode permitted. assert_eq!( classify_server_runtime_state(false, false, true).unwrap(), ServerRuntimeState::Open ); } #[test] fn classify_tokens_without_policy_is_default_deny() { // State 2: tokens configured, no policy → DefaultDeny regardless // of the flag (the flag opts into the fully-open dev mode; it // doesn't downgrade default-deny back to open). assert_eq!( classify_server_runtime_state(true, false, false).unwrap(), ServerRuntimeState::DefaultDeny ); assert_eq!( classify_server_runtime_state(true, false, true).unwrap(), ServerRuntimeState::DefaultDeny ); } #[tokio::test] #[serial] async fn serve_refuses_to_start_in_state_1_without_unauthenticated() { // MR-723 PR A: pin the integration boundary that the classifier // is actually called by `serve()` before any side-effecting // work (Lance dataset open, TcpListener::bind). The classifier // itself is unit-tested above; this test guards the propagation // path from `classify_server_runtime_state` through serve's // `?` so a future refactor that drops the call returns red. // // Marked `#[serial]` because we have to clear all bearer-token // env vars, and another test in this module setting any of them // concurrently would corrupt the read inside `resolve_token_source`. let _guard = EnvGuard::set(&[ ("OMNIGRAPH_SERVER_BEARER_TOKEN", None), ("OMNIGRAPH_SERVER_BEARER_TOKENS_FILE", None), ("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON", None), ("OMNIGRAPH_SERVER_BEARER_TOKENS_AWS_SECRET", None), ("OMNIGRAPH_UNAUTHENTICATED", None), ]); let temp = tempdir().unwrap(); // Graph path doesn't need to exist — classifier fires before // `AppState::open_with_bearer_tokens_and_policy`. let config = ServerConfig { uri: temp .path() .join("graph.omni") .to_string_lossy() .into_owned(), bind: "127.0.0.1:0".to_string(), policy_file: None, allow_unauthenticated: false, }; let result = serve(config).await; let err = result.expect_err("serve should refuse to start in State 1 without --unauthenticated"); let msg = format!("{:?}", err); assert!( msg.contains("no bearer tokens") || msg.contains("policy file"), "expected refusal message naming the misconfiguration, got: {msg}", ); } #[test] #[serial] fn unauthenticated_env_var_classification() { // MR-723 PR A: closes the gap where the env-var read path inside // `load_server_settings` was structurally implemented but not // exercised by any test. Three properties to pin, all in one // sequential test because `cargo test` runs the mod test suite // in parallel and `OMNIGRAPH_UNAUTHENTICATED` is process-global // — interleaving with another test that sets the same env var // (concurrent classifier tests, even the bearer-token suite // sharing `EnvGuard`) corrupts the read. Sequential within one // test fn is the simplest race-free shape. let temp = tempdir().unwrap(); let config_path = temp.path().join("omnigraph.yaml"); fs::write( &config_path, r#" graphs: local: uri: /tmp/demo-unauth.omni server: graph: local "#, ) .unwrap(); // Truthy values flip Open mode on, even with CLI flag off. for value in ["1", "true", "yes", "TRUE", "anything"] { let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some(value))]); let settings = load_server_settings(Some(&config_path), None, None, None, false) .expect("settings load should succeed"); assert!( settings.allow_unauthenticated, "OMNIGRAPH_UNAUTHENTICATED={value:?} should enable Open mode", ); } // Falsy values keep refusal behavior, even with CLI flag off. for value in ["0", "false", "FALSE", ""] { let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some(value))]); let settings = load_server_settings(Some(&config_path), None, None, None, false) .expect("settings load should succeed"); assert!( !settings.allow_unauthenticated, "OMNIGRAPH_UNAUTHENTICATED={value:?} should NOT enable Open mode", ); } // Unset env var: also false. let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", None)]); let settings = load_server_settings(Some(&config_path), None, None, None, false) .expect("settings load should succeed"); assert!( !settings.allow_unauthenticated, "OMNIGRAPH_UNAUTHENTICATED unset should NOT enable Open mode", ); drop(_guard); // CLI flag wins even when env is falsy — `serve()` honors the // OR of both inputs. let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some("0"))]); let settings = load_server_settings(Some(&config_path), None, None, None, true) .expect("settings load should succeed"); assert!( settings.allow_unauthenticated, "--unauthenticated CLI flag should win even when env is falsy", ); } #[test] fn classify_policy_enabled_always_wins() { // State 3: any setup with a policy file → PolicyEnabled. The // flag doesn't matter and tokens-or-not doesn't matter (no // tokens + policy is unusual but valid — every request fails // 401 without a bearer, which is effectively "locked"). assert_eq!( classify_server_runtime_state(true, true, false).unwrap(), ServerRuntimeState::PolicyEnabled ); assert_eq!( classify_server_runtime_state(false, true, false).unwrap(), ServerRuntimeState::PolicyEnabled ); assert_eq!( classify_server_runtime_state(true, true, true).unwrap(), ServerRuntimeState::PolicyEnabled ); } #[test] fn normalize_bearer_token_trims_and_filters_blank_values() { assert_eq!(normalize_bearer_token(None), None); assert_eq!(normalize_bearer_token(Some(" ".to_string())), None); assert_eq!( normalize_bearer_token(Some(" demo-token ".to_string())).as_deref(), Some("demo-token") ); } struct EnvGuard { saved: Vec<(&'static str, Option)>, } impl EnvGuard { fn set(vars: &[(&'static str, Option<&str>)]) -> Self { let saved = vars .iter() .map(|(name, _)| (*name, env::var(name).ok())) .collect::>(); for (name, value) in vars { unsafe { match value { Some(value) => env::set_var(name, value), None => env::remove_var(name), } } } Self { saved } } } impl Drop for EnvGuard { fn drop(&mut self) { for (name, value) in self.saved.drain(..) { unsafe { match value { Some(value) => env::set_var(name, value), None => env::remove_var(name), } } } } } #[test] fn parse_bearer_tokens_json_reads_actor_token_map() { let tokens = parse_bearer_tokens_json(r#"{"alice":" token-a ","bob":"token-b"}"#).unwrap(); assert_eq!(tokens.len(), 2); assert!(tokens.contains(&("alice".to_string(), " token-a ".to_string()))); assert!(tokens.contains(&("bob".to_string(), "token-b".to_string()))); } #[test] #[serial] fn server_bearer_tokens_from_env_reads_legacy_token_and_token_file() { let temp = tempdir().unwrap(); let tokens_path = temp.path().join("tokens.json"); fs::write( &tokens_path, r#"{"team-01":"token-one","team-02":"token-two"}"#, ) .unwrap(); let _guard = EnvGuard::set(&[ ("OMNIGRAPH_SERVER_BEARER_TOKEN", Some(" legacy-token ")), ( "OMNIGRAPH_SERVER_BEARER_TOKENS_FILE", Some(tokens_path.to_str().unwrap()), ), ("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON", None), ]); let tokens = server_bearer_tokens_from_env().unwrap(); assert_eq!( tokens, vec![ ("default".to_string(), "legacy-token".to_string()), ("team-01".to_string(), "token-one".to_string()), ("team-02".to_string(), "token-two".to_string()), ] ); } }