omnigraph/crates/omnigraph-server/src/lib.rs
Ragnor Comerford 1ed3eea26d
mr-668: add GraphId newtype + Cloud-mode forward identity stubs (PR 1/10)
PR 1 of the MR-668 multi-graph server work. Pure types, no runtime
behavior changes yet.

Ships the validated identity vocabulary that the rest of the implementation
will consume:

- `GraphId(String)` — `^[a-zA-Z0-9-]{1,64}$`, leading underscore rejected
  (engine reserves every `_*` filename), reserved route names rejected
  (`policies`, `healthz`, `openapi`, `openapi.json`, `graphs`). Validation
  lives in `try_from` only; serde `Deserialize` re-runs it so JSON payloads
  cannot bypass it.
- `TenantId(String)` — same regex shape as GraphId. `None` in Cluster
  mode; reserved for Cloud mode (RFC 0003) where it carries the OAuth
  `org_id` claim.
- `GraphKey { tenant_id: Option<TenantId>, graph_id }` — the registry
  HashMap key. `cluster()` constructor for the Cluster-mode default.
- `Scope` enum with `Full` variant — Cluster mode default; RFC 0004 will
  extend with OAuth scopes (`graph:read`/`write`/`admin`/`*`).
- `AuthSource` enum with `Static` variant — Cluster mode default; RFC
  0001 step 1 will add `Oidc`.
- `ResolvedActor { actor_id, tenant_id, scopes, source }` — will replace
  `AuthenticatedActor(Arc<str>)` when that type is refactored in PR 4a.

Per MR-668 design decision 13: ship the Cloud-mode forward type shapes
now (no `TokenVerifier` trait yet — that's RFC 0001 step 1) so handler
signatures stay stable across the Cluster → Cloud trajectory. `Scope`
and `AuthSource` use `#[non_exhaustive]` so future variants don't break
caller matches.

Tests: 26 new (15 graph_id + 11 identity), all passing. No regressions
in the existing 36 server library tests.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2026-05-25 18:51:49 +02:00

2133 lines
75 KiB
Rust
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

pub mod api;
pub mod auth;
pub mod config;
pub mod graph_id;
pub mod identity;
pub mod policy;
pub mod workload;
pub use graph_id::GraphId;
pub use identity::{AuthSource, GraphKey, ResolvedActor, Scope, TenantId};
use std::collections::{HashMap, HashSet};
use std::fs;
use std::io;
use std::io::Write;
use std::path::PathBuf;
use std::sync::Arc;
use api::{
BranchCreateOutput, BranchCreateRequest, BranchDeleteOutput, BranchListOutput,
BranchMergeOutput, BranchMergeRequest, ChangeOutput, ChangeRequest, CommitListOutput,
CommitListQuery, ErrorCode, ErrorOutput, ExportRequest, HealthOutput, IngestOutput,
IngestRequest, ReadOutput, ReadRequest, SchemaApplyOutput, SchemaApplyRequest, SchemaOutput,
SnapshotQuery, ingest_output, schema_apply_output, snapshot_payload,
};
pub use auth::{AWS_SECRET_ENV, EnvOrFileTokenSource, TokenSource, resolve_token_source};
use axum::body::{Body, Bytes};
use axum::extract::DefaultBodyLimit;
use axum::extract::{Extension, Path, Query, Request, State};
use axum::http::StatusCode;
use axum::http::header::{AUTHORIZATION, CONTENT_TYPE};
use axum::middleware::{self, Next};
use axum::response::{IntoResponse, Response};
use axum::routing::{delete, get, post};
use axum::{Json, Router};
use color_eyre::eyre::{Result, WrapErr, bail};
pub use config::{
AliasCommand, AliasConfig, CliDefaults, DEFAULT_CONFIG_FILE, OmnigraphConfig, PolicySettings,
ProjectConfig, QueryDefaults, ReadOutputFormat, ServerDefaults, TableCellLayout, TargetConfig,
load_config,
};
use futures::stream;
use omnigraph::db::{Omnigraph, ReadTarget};
use omnigraph::error::{ManifestConflictDetails, ManifestErrorKind, OmniError};
use omnigraph_compiler::json_params_to_param_map;
use omnigraph_compiler::query::parser::parse_query;
use omnigraph_compiler::{JsonParamMode, ParamMap};
pub use policy::{
PolicyAction, PolicyCompiler, PolicyConfig, PolicyDecision, PolicyEngine, PolicyExpectation,
PolicyRequest, PolicyTestConfig,
};
use serde_json::Value;
use sha2::{Digest, Sha256};
use subtle::ConstantTimeEq;
use tokio::net::TcpListener;
use tokio::sync::mpsc;
use tower_http::trace::TraceLayer;
use tracing::{error, info, warn};
use tracing_subscriber::EnvFilter;
use utoipa::OpenApi;
use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
/// SHA-256 digest of a bearer token. [`AppState`] stores only these digests,
/// never the plaintext token.
type BearerTokenHash = [u8; 32];

/// Hash `token` with SHA-256 so it can be compared in constant time against
/// the configured token digests.
fn hash_bearer_token(token: &str) -> BearerTokenHash {
    let mut hashed: BearerTokenHash = [0; 32];
    hashed.copy_from_slice(&Sha256::digest(token.as_bytes()));
    hashed
}
// OpenAPI document for the HTTP API, generated by utoipa from the handler
// annotations listed in `paths(...)`. `SecurityAddon` registers the
// `bearer_token` security scheme; `server_openapi` strips security again
// when the server does not require bearer auth.
#[derive(OpenApi)]
#[openapi(
info(
title = "Omnigraph API",
description = "HTTP API for the Omnigraph graph database",
),
paths(
server_health,
server_snapshot,
server_read,
server_export,
server_change,
server_schema_apply,
server_schema_get,
server_ingest,
server_branch_list,
server_branch_create,
server_branch_delete,
server_branch_merge,
server_commit_list,
server_commit_show,
),
modifiers(&SecurityAddon),
)]
pub struct ApiDoc;
/// utoipa modifier that registers the `bearer_token` HTTP bearer security
/// scheme referenced by the protected handlers' `security(...)` clauses.
struct SecurityAddon;

impl utoipa::Modify for SecurityAddon {
    fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
        // Create the components section on demand, then add the scheme to it.
        let components = openapi.components.get_or_insert_with(Default::default);
        let bearer = SecurityScheme::Http(Http::new(HttpAuthScheme::Bearer));
        components.add_security_scheme("bearer_token", bearer);
    }
}
/// Request-body size limit (1 MiB) applied router-wide by `build_app`.
const DEFAULT_REQUEST_BODY_LIMIT_BYTES: usize = 1_048_576;
/// Larger request-body limit (32 MiB) applied only to the `/ingest` route.
const INGEST_REQUEST_BODY_LIMIT_BYTES: usize = 32 * 1024 * 1024;
/// Crate version reported by `/healthz`, captured from Cargo at compile time.
const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Optional source revision reported by `/healthz`; taken from the
/// `OMNIGRAPH_SOURCE_VERSION` variable at compile time, `None` when unset.
const SERVER_SOURCE_VERSION: Option<&str> = option_env!("OMNIGRAPH_SOURCE_VERSION");
/// Resolved settings consumed by [`serve`], typically produced by
/// [`load_server_settings`].
#[derive(Debug, Clone)]
pub struct ServerConfig {
/// URI of the graph to open.
pub uri: String,
/// Address handed to `TcpListener::bind`.
pub bind: String,
/// Optional policy file; when set, a `PolicyEngine` is loaded from it and
/// every authenticated request is evaluated against it.
pub policy_file: Option<PathBuf>,
/// Operator opt-in for fully-unauthenticated dev mode (MR-723).
/// When neither bearer tokens nor a policy file are configured,
/// `serve()` refuses to start unless this is true (set via
/// `--unauthenticated` or `OMNIGRAPH_UNAUTHENTICATED=1`). The
/// motivation is that "no tokens + no policy" looks like protection
/// (no Cedar errors at boot) but is actually fully open — operators
/// who set up auth and forgot the policy file would otherwise ship
/// the illusion of protection.
pub allow_unauthenticated: bool,
}
/// Shared state handed to every axum handler: the opened graph engine,
/// per-actor admission control, hashed bearer tokens, and the optional
/// policy engine. Apart from `uri`, every field sits behind an `Arc`, so
/// clones share the same engine, token table, and policy engine.
#[derive(Clone)]
pub struct AppState {
/// URI the graph was opened from (exposed via `AppState::uri`).
uri: String,
/// PR 2 (MR-686): the engine is now `Arc<Omnigraph>` — no global
/// write lock. Concurrent handlers call `&self` engine APIs
/// directly. Per-(table, branch) write queues inside the engine
/// serialize same-key writers; per-actor admission control on
/// `workload` isolates noisy actors.
engine: Arc<Omnigraph>,
/// Per-actor admission control. See `workload::WorkloadController`.
workload: Arc<workload::WorkloadController>,
/// `(SHA-256 digest of token, actor id)` pairs. Plaintext tokens are
/// never stored; an empty slice means no tokens are configured.
bearer_tokens: Arc<[(BearerTokenHash, Arc<str>)]>,
/// Policy engine used by `authorize_request`; `None` when no policy file
/// is configured.
policy_engine: Option<Arc<PolicyEngine>>,
}
/// Actor id resolved by `require_bearer_auth` from the matched bearer token
/// and stored in the request extensions. `authorize_request` treats this as
/// the only authoritative source of actor identity.
#[derive(Debug, Clone)]
struct AuthenticatedActor(Arc<str>);
/// `std::io::Write` adapter used by `/export`: every written buffer is
/// forwarded as a `Bytes` chunk over an unbounded channel whose receiver
/// feeds the streamed response body.
struct ExportStreamWriter {
sender: mpsc::UnboundedSender<std::result::Result<Bytes, io::Error>>,
}
impl Write for ExportStreamWriter {
    /// Forward a copy of `buf` to the response stream.
    ///
    /// The whole buffer is always accepted; once the receiving side has been
    /// dropped (e.g. the client went away) this fails with `BrokenPipe` so the
    /// exporter stops producing output.
    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
        let chunk = Bytes::copy_from_slice(buf);
        match self.sender.send(Ok(chunk)) {
            Ok(()) => Ok(buf.len()),
            Err(_closed) => Err(io::Error::new(
                io::ErrorKind::BrokenPipe,
                "export stream closed",
            )),
        }
    }

    /// Nothing is buffered locally, so flushing is a no-op.
    fn flush(&mut self) -> io::Result<()> {
        Ok(())
    }
}
impl AuthenticatedActor {
fn as_str(&self) -> &str {
&self.0
}
}
/// Error type returned by HTTP handlers. Its `IntoResponse` impl renders it
/// as an `ErrorOutput` JSON body with `status` as the HTTP status.
#[derive(Debug)]
pub struct ApiError {
/// HTTP status code sent to the client.
status: StatusCode,
/// Machine-readable error code included in the body.
code: ErrorCode,
/// Human-readable message (the body's `error` field).
message: String,
/// Per-row merge conflicts; only populated by merge-conflict errors.
merge_conflicts: Vec<api::MergeConflictOutput>,
/// Expected-vs-actual manifest version details, when the error is a
/// manifest version conflict.
manifest_conflict: Option<api::ManifestConflictOutput>,
}
impl AppState {
    /// Build state for an already-open graph with no bearer tokens and no
    /// policy engine.
    pub fn new(uri: String, db: Omnigraph) -> Self {
        Self::new_with_bearer_tokens(uri, db, Vec::new())
    }

    /// Build state with at most one bearer token, registered under the actor
    /// id `"default"` after passing through `normalize_bearer_token`.
    pub fn new_with_bearer_token(uri: String, db: Omnigraph, bearer_token: Option<String>) -> Self {
        Self::new_with_bearer_tokens(uri, db, Self::default_actor_tokens(bearer_token))
    }

    /// Build state from `(actor_id, token)` pairs with no policy engine.
    pub fn new_with_bearer_tokens(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Vec<(String, String)>,
    ) -> Self {
        Self::new_with_bearer_tokens_and_policy(uri, db, bearer_tokens, None)
    }

    /// Build state from `(actor_id, token)` pairs and an optional policy
    /// engine. A given policy engine is installed both on the HTTP state and
    /// on the graph engine (`Omnigraph::with_policy`). Per-actor admission
    /// limits come from `WorkloadController::from_env()`.
    pub fn new_with_bearer_tokens_and_policy(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Vec<(String, String)>,
        policy_engine: Option<PolicyEngine>,
    ) -> Self {
        let policy_engine: Option<Arc<PolicyEngine>> = policy_engine.map(Arc::new);
        // MR-722 chassis: inject the policy checker into the engine so
        // `Omnigraph::apply_schema_as` (and PR #3's fan-out of the
        // remaining writers) gates at engine-layer too. HTTP-layer
        // `authorize_request` still fires first; the engine-layer gate
        // is the redundant-but-correct backstop, plus the only path
        // that protects SDK / embedded callers. PR #3 removes the HTTP
        // redundancy once we're confident the engine gate covers it.
        let db = if let Some(engine) = policy_engine.as_ref() {
            // Unsizing coercion: Arc<PolicyEngine> → Arc<dyn PolicyChecker>.
            // Needs the explicit `as` cast — Rust 2024 doesn't infer it through
            // `Arc::clone`.
            let checker = Arc::clone(engine) as Arc<dyn omnigraph_policy::PolicyChecker>;
            db.with_policy(checker)
        } else {
            db
        };
        Self {
            uri,
            engine: Arc::new(db),
            workload: Arc::new(workload::WorkloadController::from_env()),
            bearer_tokens: Self::hash_bearer_token_pairs(bearer_tokens),
            policy_engine,
        }
    }

    /// Construct with a caller-provided [`workload::WorkloadController`].
    /// Tests and benches use this to override per-actor caps without
    /// mutating global env vars (which is unsafe in Rust 2024 once the
    /// async runtime is up — `setenv` isn't thread-safe).
    ///
    /// No policy engine is installed; see [`AppState::with_policy_engine`].
    pub fn new_with_workload(
        uri: String,
        db: Omnigraph,
        bearer_tokens: Vec<(String, String)>,
        workload: workload::WorkloadController,
    ) -> Self {
        Self {
            uri,
            engine: Arc::new(db),
            workload: Arc::new(workload),
            bearer_tokens: Self::hash_bearer_token_pairs(bearer_tokens),
            policy_engine: None,
        }
    }

    /// Install a `PolicyEngine` post-construction (MR-723). Used by
    /// integration tests that need to thread custom workload limits
    /// alongside a permit-all policy — the existing `new_with_*` and
    /// `new_with_workload` constructors don't compose. Production
    /// callers should use `open_with_bearer_tokens_and_policy` which
    /// installs the policy on both the HTTP state and the engine.
    ///
    /// Note: this only sets the HTTP-layer policy engine; it does NOT call
    /// `Omnigraph::with_policy`, so the engine-layer gate stays unset.
    pub fn with_policy_engine(mut self, engine: PolicyEngine) -> Self {
        self.policy_engine = Some(Arc::new(engine));
        self
    }

    /// Open the graph at `uri` with no bearer tokens and no policy.
    pub async fn open(uri: impl Into<String>) -> Result<Self> {
        Self::open_with_bearer_token(uri, None).await
    }

    /// Open the graph at `uri` with at most one bearer token, registered under
    /// the actor id `"default"`.
    pub async fn open_with_bearer_token(
        uri: impl Into<String>,
        bearer_token: Option<String>,
    ) -> Result<Self> {
        Self::open_with_bearer_tokens(uri, Self::default_actor_tokens(bearer_token)).await
    }

    /// Open the graph at `uri` with `(actor_id, token)` pairs and no policy.
    pub async fn open_with_bearer_tokens(
        uri: impl Into<String>,
        bearer_tokens: Vec<(String, String)>,
    ) -> Result<Self> {
        let uri = uri.into();
        let db = Omnigraph::open(&uri).await?;
        Ok(Self::new_with_bearer_tokens(uri, db, bearer_tokens))
    }

    /// Open the graph at `uri` with `(actor_id, token)` pairs and an optional
    /// policy file.
    ///
    /// # Errors
    /// Fails if the policy file cannot be loaded, if a policy is configured
    /// without any bearer tokens, or if the graph cannot be opened. The policy
    /// checks run before the graph is opened so misconfiguration fails fast
    /// without touching storage.
    pub async fn open_with_bearer_tokens_and_policy(
        uri: impl Into<String>,
        bearer_tokens: Vec<(String, String)>,
        policy_file: Option<&PathBuf>,
    ) -> Result<Self> {
        let uri = uri.into();
        let policy_engine = match policy_file {
            Some(path) => Some(PolicyEngine::load(path, &uri)?),
            None => None,
        };
        if policy_engine.is_some() && bearer_tokens.is_empty() {
            bail!("policy requires at least one configured bearer token actor");
        }
        let db = Omnigraph::open(&uri).await?;
        Ok(Self::new_with_bearer_tokens_and_policy(
            uri,
            db,
            bearer_tokens,
            policy_engine,
        ))
    }

    /// URI the graph was opened from.
    pub fn uri(&self) -> &str {
        &self.uri
    }

    /// True when requests must present a bearer token: either tokens are
    /// configured or a policy engine is installed.
    fn requires_bearer_auth(&self) -> bool {
        !self.bearer_tokens.is_empty() || self.policy_engine.is_some()
    }

    /// Resolve the actor owning `provided_token`, or `None` if it matches no
    /// configured token.
    fn authenticate_bearer_token(&self, provided_token: &str) -> Option<Arc<str>> {
        // An empty credential must never authenticate, even if a misconfigured
        // token source produced an empty token. Emptiness is not secret, so
        // this early return leaks nothing about the configured tokens.
        if provided_token.is_empty() {
            return None;
        }
        // Hash the incoming token and compare against every stored digest in
        // constant time. Iterate all entries unconditionally so total work —
        // and therefore response timing — doesn't depend on which slot matches.
        let provided_hash = hash_bearer_token(provided_token);
        let mut matched: Option<Arc<str>> = None;
        for (hash, actor) in self.bearer_tokens.iter() {
            if bool::from(hash.ct_eq(&provided_hash)) && matched.is_none() {
                matched = Some(Arc::clone(actor));
            }
        }
        matched
    }

    /// Borrow the installed policy engine, if any.
    fn policy_engine(&self) -> Option<&PolicyEngine> {
        self.policy_engine.as_deref()
    }

    /// Map the single-token configuration onto the `(actor_id, token)` form
    /// used by the multi-token constructors (actor id `"default"`).
    fn default_actor_tokens(bearer_token: Option<String>) -> Vec<(String, String)> {
        normalize_bearer_token(bearer_token)
            .into_iter()
            .map(|token| ("default".to_string(), token))
            .collect()
    }

    /// Replace each plaintext token with its SHA-256 digest so `AppState`
    /// never retains the token itself.
    fn hash_bearer_token_pairs(
        bearer_tokens: Vec<(String, String)>,
    ) -> Arc<[(BearerTokenHash, Arc<str>)]> {
        bearer_tokens
            .into_iter()
            .map(|(actor, token)| (hash_bearer_token(&token), Arc::<str>::from(actor)))
            .collect()
    }
}
impl ApiError {
pub fn unauthorized(message: impl Into<String>) -> Self {
Self {
status: StatusCode::UNAUTHORIZED,
code: ErrorCode::Unauthorized,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn forbidden(message: impl Into<String>) -> Self {
Self {
status: StatusCode::FORBIDDEN,
code: ErrorCode::Forbidden,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn bad_request(message: impl Into<String>) -> Self {
Self {
status: StatusCode::BAD_REQUEST,
code: ErrorCode::BadRequest,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn not_found(message: impl Into<String>) -> Self {
Self {
status: StatusCode::NOT_FOUND,
code: ErrorCode::NotFound,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn conflict(message: impl Into<String>) -> Self {
Self {
status: StatusCode::CONFLICT,
code: ErrorCode::Conflict,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
pub fn internal(message: impl Into<String>) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
code: ErrorCode::Internal,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
/// HTTP 429 Too Many Requests — actor exceeded their per-actor
/// admission cap (count or byte budget). Clients should respect the
/// `Retry-After` header. Mapped from `RejectReason::InFlightCountExceeded`
/// and `RejectReason::ByteBudgetExceeded`.
pub fn too_many_requests(message: impl Into<String>) -> Self {
Self {
status: StatusCode::TOO_MANY_REQUESTS,
code: ErrorCode::TooManyRequests,
message: message.into(),
merge_conflicts: Vec::new(),
manifest_conflict: None,
}
}
/// Convert a `WorkloadController` rejection into the matching
/// `ApiError` variant.
pub fn from_workload_reject(reject: workload::RejectReason) -> Self {
match reject {
workload::RejectReason::InFlightCountExceeded { .. }
| workload::RejectReason::ByteBudgetExceeded { .. } => {
Self::too_many_requests(reject.to_string())
}
}
}
fn merge_conflict(conflicts: Vec<api::MergeConflictOutput>) -> Self {
Self {
status: StatusCode::CONFLICT,
code: ErrorCode::Conflict,
message: summarize_merge_conflicts(&conflicts),
merge_conflicts: conflicts,
manifest_conflict: None,
}
}
fn manifest_version_conflict(message: String, details: api::ManifestConflictOutput) -> Self {
Self {
status: StatusCode::CONFLICT,
code: ErrorCode::Conflict,
message,
merge_conflicts: Vec::new(),
manifest_conflict: Some(details),
}
}
fn from_omni(err: OmniError) -> Self {
match err {
OmniError::Compiler(err) => Self::bad_request(err.to_string()),
OmniError::DataFusion(message) => Self::bad_request(format!("query: {message}")),
OmniError::Manifest(err) => match err.kind {
ManifestErrorKind::BadRequest => Self::bad_request(err.message),
ManifestErrorKind::NotFound => Self::not_found(err.message),
ManifestErrorKind::Conflict => match err.details {
Some(ManifestConflictDetails::ExpectedVersionMismatch {
table_key,
expected,
actual,
}) => Self::manifest_version_conflict(
err.message,
api::ManifestConflictOutput {
table_key,
expected,
actual,
},
),
_ => Self::conflict(err.message),
},
ManifestErrorKind::Internal => Self::internal(err.message),
},
OmniError::MergeConflicts(conflicts) => Self::merge_conflict(
conflicts
.iter()
.map(api::MergeConflictOutput::from)
.collect(),
),
OmniError::Lance(message) => Self::internal(format!("storage: {message}")),
OmniError::Io(err) => Self::internal(format!("io: {err}")),
// Engine-layer policy enforcement (MR-722). All denials and
// evaluation failures surface here as 403. The HTTP-layer
// `authorize_request` already distinguishes 401 (missing
// bearer) from 403 (policy denial), so by the time the
// engine gate fires, the bearer is valid — any failure from
// the engine is a policy outcome, not an auth one.
OmniError::Policy(message) => Self::forbidden(message),
}
}
}
/// Build a one-line summary of merge conflicts for the error message.
///
/// Lists at most the first three conflicts as `table:row (kind)` (or
/// `table (kind)` when there is no row id), separated by `; `, followed by
/// `; and N more` when some were omitted. An empty list yields just
/// `merge conflicts`.
fn summarize_merge_conflicts(conflicts: &[api::MergeConflictOutput]) -> String {
    const PREVIEW_LIMIT: usize = 3;
    if conflicts.is_empty() {
        return "merge conflicts".to_string();
    }
    let shown = conflicts.len().min(PREVIEW_LIMIT);
    let mut summary = String::from("merge conflicts: ");
    for (index, conflict) in conflicts[..shown].iter().enumerate() {
        if index > 0 {
            summary.push_str("; ");
        }
        let kind = conflict.kind.as_str();
        let entry = if let Some(row_id) = conflict.row_id.as_deref() {
            format!("{}:{} ({})", conflict.table_key, row_id, kind)
        } else {
            format!("{} ({})", conflict.table_key, kind)
        };
        summary.push_str(&entry);
    }
    let hidden = conflicts.len() - shown;
    if hidden > 0 {
        summary.push_str(&format!("; and {} more", hidden));
    }
    summary
}
/// Constant `Retry-After` value (seconds) emitted on 429 responses.
const RETRY_AFTER_SECONDS: &str = "60";
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let mut headers = axum::http::HeaderMap::new();
if matches!(self.code, ErrorCode::TooManyRequests) {
headers.insert(
axum::http::header::RETRY_AFTER,
axum::http::HeaderValue::from_static(RETRY_AFTER_SECONDS),
);
}
(
self.status,
headers,
Json(ErrorOutput {
error: self.message,
code: Some(self.code),
merge_conflicts: self.merge_conflicts,
manifest_conflict: self.manifest_conflict,
}),
)
.into_response()
}
}
/// Install the global `tracing` subscriber.
///
/// The filter comes from the standard `RUST_LOG`-style environment variable
/// and falls back to `info` when it is unset or invalid.
pub fn init_tracing() {
    let filter = match EnvFilter::try_from_default_env() {
        Ok(from_env) => from_env,
        Err(_) => EnvFilter::new("info"),
    };
    // `try_init` can fail when a global subscriber is already installed
    // (e.g. by a test harness or an embedding binary); that is not fatal.
    let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
}
/// Resolve the [`ServerConfig`] from the config file plus CLI overrides.
///
/// `cli_uri` / `cli_target` are passed to `resolve_target_uri` together with
/// the config's server graph name, and `cli_bind` overrides the configured
/// bind address. Unauthenticated mode is enabled by
/// `cli_allow_unauthenticated` or by a truthy `OMNIGRAPH_UNAUTHENTICATED`
/// environment variable (see `unauthenticated_flag_enabled`).
///
/// # Errors
/// Propagates failures from `load_config` and target-URI resolution.
pub fn load_server_settings(
    config_path: Option<&PathBuf>,
    cli_uri: Option<String>,
    cli_target: Option<String>,
    cli_bind: Option<String>,
    cli_allow_unauthenticated: bool,
) -> Result<ServerConfig> {
    let config = load_config(config_path)?;
    let uri =
        config.resolve_target_uri(cli_uri, cli_target.as_deref(), config.server_graph_name())?;
    let bind = cli_bind.unwrap_or_else(|| config.server_bind().to_string());
    let policy_file = config.resolve_policy_file();
    // Either `--unauthenticated` or `OMNIGRAPH_UNAUTHENTICATED` flips this.
    // An unset or non-UTF-8 variable counts as "off".
    let env_unauth = std::env::var("OMNIGRAPH_UNAUTHENTICATED")
        .ok()
        .is_some_and(|value| unauthenticated_flag_enabled(&value));
    let allow_unauthenticated = cli_allow_unauthenticated || env_unauth;
    Ok(ServerConfig {
        uri,
        bind,
        policy_file,
        allow_unauthenticated,
    })
}

/// Interpret the value of `OMNIGRAPH_UNAUTHENTICATED`.
///
/// Any non-blank value enables unauthenticated mode except the conventional
/// "off" spellings `0`, `false`, `no`, and `off` (case-insensitive, surrounding
/// whitespace ignored). Because this flag opens the server completely, an
/// operator writing `OMNIGRAPH_UNAUTHENTICATED=no` must not get the opposite
/// of what they asked for.
fn unauthenticated_flag_enabled(value: &str) -> bool {
    const OFF_VALUES: [&str; 4] = ["0", "false", "no", "off"];
    let trimmed = value.trim();
    !trimmed.is_empty() && !OFF_VALUES.iter().any(|off| trimmed.eq_ignore_ascii_case(off))
}
/// MR-723 server runtime state, classified from the three-state matrix
/// of (bearer tokens configured) × (policy file configured) at startup.
///
/// * **Open** — neither tokens nor policy; requires explicit
///   `allow_unauthenticated`. Effectively a "trust the network" dev
///   mode. `serve()` refuses to start in this shape without the flag,
///   so the only way to reach this state at runtime is via deliberate
///   operator opt-in.
/// * **DefaultDeny** — tokens configured but no policy file. The
///   server requires a valid bearer token; once authenticated, every
///   action except `Read` is denied with 403. Closes the "tokens but
///   forgot the policy file" trap.
/// * **PolicyEnabled** — policy file configured. Cedar evaluates every
///   authenticated request. [`classify_server_runtime_state`] returns
///   this state with or without tokens, but
///   `AppState::open_with_bearer_tokens_and_policy` (which `serve()`
///   uses) rejects a policy file with no bearer tokens, so `serve()`
///   only actually runs in this state with tokens configured.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ServerRuntimeState {
/// No tokens and no policy; explicit unauthenticated opt-in.
Open,
/// Tokens but no policy file; only `Read` is permitted.
DefaultDeny,
/// Policy file configured; the policy engine decides.
PolicyEnabled,
}
/// Compute the [`ServerRuntimeState`] from the configured inputs.
/// Pulled out as a pure function so the 3-state matrix is unit-testable
/// without standing up the full server.
///
/// A policy file always wins; otherwise configured tokens mean default-deny;
/// otherwise the server is open only with an explicit opt-in.
///
/// # Errors
/// Fails when there are no tokens, no policy, and no unauthenticated opt-in.
pub fn classify_server_runtime_state(
    has_tokens: bool,
    has_policy: bool,
    allow_unauthenticated: bool,
) -> Result<ServerRuntimeState> {
    if has_policy {
        return Ok(ServerRuntimeState::PolicyEnabled);
    }
    if has_tokens {
        return Ok(ServerRuntimeState::DefaultDeny);
    }
    if allow_unauthenticated {
        return Ok(ServerRuntimeState::Open);
    }
    bail!(
        "server has no bearer tokens and no policy file configured. This is a fully \
         open server — pass `--unauthenticated` (or set OMNIGRAPH_UNAUTHENTICATED=1) \
         if you actually want that, otherwise configure bearer tokens (see \
         docs/user/server.md) and/or `policy.file` in omnigraph.yaml."
    )
}
/// Assemble the axum [`Router`] for the server.
///
/// `/healthz` and `/openapi.json` are public; every other route sits behind
/// the `require_bearer_auth` middleware. All routes share a default body
/// limit, except `/ingest` which gets a larger one, and every request is
/// traced.
pub fn build_app(state: AppState) -> Router {
    // Bulk ingest payloads may exceed the router-wide default limit.
    let ingest_route =
        post(server_ingest).layer(DefaultBodyLimit::max(INGEST_REQUEST_BODY_LIMIT_BYTES));

    let protected = Router::new()
        .route("/snapshot", get(server_snapshot))
        .route("/export", post(server_export))
        .route("/read", post(server_read))
        .route("/change", post(server_change))
        .route("/schema", get(server_schema_get))
        .route("/schema/apply", post(server_schema_apply))
        .route("/ingest", ingest_route)
        .route(
            "/branches",
            get(server_branch_list).post(server_branch_create),
        )
        .route("/branches/{branch}", delete(server_branch_delete))
        .route("/branches/merge", post(server_branch_merge))
        .route("/commits", get(server_commit_list))
        .route("/commits/{commit_id}", get(server_commit_show))
        .route_layer(middleware::from_fn_with_state(
            state.clone(),
            require_bearer_auth,
        ));

    // Reachable without credentials.
    let public = Router::new()
        .route("/healthz", get(server_health))
        .route("/openapi.json", get(server_openapi));

    public
        .merge(protected)
        .layer(DefaultBodyLimit::max(DEFAULT_REQUEST_BODY_LIMIT_BYTES))
        .layer(TraceLayer::new_for_http())
        .with_state(state)
}
/// Run the Omnigraph HTTP server until a shutdown signal arrives.
///
/// Loads bearer tokens from the resolved token source, classifies the auth
/// posture with [`classify_server_runtime_state`] (refusing a fully-open
/// server without an explicit opt-in), opens the graph and optional policy,
/// binds `config.bind`, and serves with graceful shutdown.
///
/// # Errors
/// Fails if the token source cannot be loaded, the auth configuration is
/// rejected, the graph or policy cannot be opened, the bind address is
/// unavailable, or the server loop returns an I/O error.
pub async fn serve(config: ServerConfig) -> Result<()> {
    let token_source = resolve_token_source().await?;
    info!(source = token_source.name(), "loaded bearer token source");
    let tokens = token_source.load().await?;
    let runtime_state = classify_server_runtime_state(
        !tokens.is_empty(),
        config.policy_file.is_some(),
        config.allow_unauthenticated,
    )?;
    match runtime_state {
        ServerRuntimeState::Open => warn!(
            "running with --unauthenticated: no bearer tokens, no policy file, all \
             requests permitted. This is for local dev only — do not expose to a \
             network you don't fully trust."
        ),
        ServerRuntimeState::DefaultDeny => warn!(
            "bearer tokens are configured but no policy file is set — running in \
             default-deny mode (only `read` actions are permitted for authenticated \
             actors). Configure `policy.file` in omnigraph.yaml to enable Cedar rules."
        ),
        ServerRuntimeState::PolicyEnabled => {}
    }
    let state = AppState::open_with_bearer_tokens_and_policy(
        config.uri.clone(),
        tokens,
        config.policy_file.as_ref(),
    )
    .await?;
    // Name the address in the error: a bare "address in use" is hard to act on.
    let listener = TcpListener::bind(&config.bind)
        .await
        .wrap_err_with(|| format!("failed to bind {}", config.bind))?;
    info!(uri = %config.uri, bind = %config.bind, "serving omnigraph");
    axum::serve(listener, build_app(state))
        .with_graceful_shutdown(shutdown_signal())
        .await?;
    Ok(())
}
async fn shutdown_signal() {
if let Err(err) = tokio::signal::ctrl_c().await {
error!(error = %err, "failed to install ctrl-c handler");
return;
}
info!("shutdown signal received");
}
#[utoipa::path(
    get,
    path = "/healthz",
    tag = "health",
    operation_id = "health",
    responses(
        (status = 200, description = "Server is healthy", body = HealthOutput),
    ),
)]
/// Liveness probe.
///
/// Returns server status and version. Unauthenticated; safe to call from any
/// caller. Use this to confirm the server is reachable before invoking other
/// endpoints.
async fn server_health() -> Json<HealthOutput> {
    // NOTE: the `///` text above feeds the OpenAPI description; edit with care.
    let source_version = SERVER_SOURCE_VERSION.map(|rev| rev.to_owned());
    let output = HealthOutput {
        status: String::from("ok"),
        version: String::from(SERVER_VERSION),
        source_version,
    };
    Json(output)
}
/// Serve the generated OpenAPI document.
///
/// When the server does not require bearer auth, all security declarations
/// are stripped so the spec doesn't advertise credentials clients don't need.
async fn server_openapi(State(state): State<AppState>) -> Json<utoipa::openapi::OpenApi> {
    let mut doc = ApiDoc::openapi();
    if state.requires_bearer_auth() {
        return Json(doc);
    }
    strip_security(&mut doc);
    Json(doc)
}
/// Remove every security scheme and per-operation security requirement from
/// `doc`.
fn strip_security(doc: &mut utoipa::openapi::OpenApi) {
    if let Some(components) = &mut doc.components {
        components.security_schemes.clear();
    }
    for item in doc.paths.paths.values_mut() {
        // Each HTTP method slot is optional; visit only the ones present.
        let slots = [
            &mut item.get,
            &mut item.post,
            &mut item.put,
            &mut item.delete,
            &mut item.options,
            &mut item.head,
            &mut item.patch,
            &mut item.trace,
        ];
        for operation in slots.into_iter().flatten() {
            operation.security = None;
        }
    }
}
/// Axum middleware guarding every protected route.
///
/// When [`AppState::requires_bearer_auth`] is false (no tokens and no policy)
/// the request passes straight through. Otherwise the `Authorization` header
/// must carry a `Bearer <token>` credential matching a configured token. The
/// resolved actor is inserted into the request extensions as an
/// [`AuthenticatedActor`].
///
/// # Errors
/// * 401 `missing bearer token` — header absent, not visible ASCII, not the
///   `Bearer` scheme, or with an empty credential.
/// * 401 `invalid bearer token` — credential matches no configured token.
async fn require_bearer_auth(
    State(state): State<AppState>,
    mut request: Request,
    next: Next,
) -> std::result::Result<Response, ApiError> {
    if !state.requires_bearer_auth() {
        return Ok(next.run(request).await);
    }
    let Some(provided_token) = request
        .headers()
        .get(AUTHORIZATION)
        .and_then(|value| value.to_str().ok())
        .and_then(extract_bearer_credential)
    else {
        return Err(ApiError::unauthorized("missing bearer token"));
    };
    let Some(actor) = state.authenticate_bearer_token(provided_token) else {
        return Err(ApiError::unauthorized("invalid bearer token"));
    };
    request.extensions_mut().insert(AuthenticatedActor(actor));
    Ok(next.run(request).await)
}

/// Extract the credential from an `Authorization` header value of the form
/// `Bearer <token>`.
///
/// The scheme name is matched case-insensitively (auth-scheme names are
/// case-insensitive per RFC 7235 §2.1). Whitespace around the credential is
/// ignored. Returns `None` for any other scheme, a bare scheme, or an empty
/// credential.
fn extract_bearer_credential(header: &str) -> Option<&str> {
    let (scheme, credential) = header.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let credential = credential.trim();
    if credential.is_empty() {
        None
    } else {
        Some(credential)
    }
}
/// Emit one structured `info` event per policy evaluation. Absent branch,
/// target branch, or matched rule are logged as empty strings.
fn log_policy_decision(actor_id: &str, request: &PolicyRequest, decision: &PolicyDecision) {
    let branch = request.branch.as_deref().unwrap_or("");
    let target_branch = request.target_branch.as_deref().unwrap_or("");
    let matched_rule_id = decision.matched_rule_id.as_deref().unwrap_or("");
    info!(
        actor_id,
        action = %request.action,
        branch,
        target_branch,
        allowed = decision.allowed,
        matched_rule_id,
        "policy decision"
    );
}
/// Decide whether `request` may proceed for the (optional) authenticated
/// `actor`.
///
/// * No policy engine installed: authenticated actors may only perform
///   `PolicyAction::Read` (403 otherwise); requests without an actor are
///   allowed (Open mode).
/// * Policy engine installed: a missing actor is rejected with 401.
///   Otherwise `request.actor_id` is overwritten with the authenticated
///   actor before evaluation. Evaluation errors map to 500, and denials
///   map to 403 carrying the decision message. Every evaluated decision
///   is logged via `log_policy_decision`.
fn authorize_request(
state: &AppState,
actor: Option<&AuthenticatedActor>,
mut request: PolicyRequest,
) -> std::result::Result<(), ApiError> {
let Some(engine) = state.policy_engine() else {
// MR-723 default-deny path. We're here when no PolicyEngine is
// installed. Two startup-validated shapes can reach this:
//
// * **Open mode** (`--unauthenticated`): no tokens, no policy.
// `require_bearer_auth` short-circuits before this is called,
// but defense in depth — if a future change makes the
// middleware call here for an unauthenticated request, we
// want every action to remain Ok rather than 403. The
// operator opted in.
// * **DefaultDeny mode**: tokens configured but no policy. The
// request went through bearer auth, so `actor` is Some and
// identifies a known actor. Only `Read` is permitted; every
// other action returns 403. This closes the "configured auth
// but forgot the policy file" trap from MR-723.
if actor.is_some() && request.action != PolicyAction::Read {
return Err(ApiError::forbidden(
"server runs in default-deny mode (bearer tokens configured but no \
policy file). Only `read` actions are permitted; configure \
`policy.file` in omnigraph.yaml to enable other actions.",
));
}
return Ok(());
};
let Some(actor) = actor else {
return Err(ApiError::unauthorized("missing bearer token"));
};
// SECURITY INVARIANT (MR-731): actor identity comes from the matched
// bearer token, never from a client-supplied request header, query
// parameter, or body field. This line is the single chokepoint where
// the authoritative actor (resolved from the bearer match by
// `require_bearer_auth`) overwrites whatever the handler put in the
// PolicyRequest. Removing or weakening it lets clients spoof identity —
// exactly the Supabase RLS footgun ("trusting raw_user_meta_data is
// asking the attacker if they're an admin"). The principle is codified
// in `docs/dev/invariants.md` Hard Invariant 11 ("clients cannot set
// actor identity directly") and pinned by the regression test
// `actor_id_resolves_from_bearer_token_ignoring_client_supplied_headers`
// in `crates/omnigraph-server/tests/server.rs`.
//
// Side effect: also prevents an empty-string default at any handler
// call site from ever reaching the engine as a policy subject.
request.actor_id = actor.as_str().to_string();
// Evaluation failures are server-side problems (500), not denials.
let decision = engine
.authorize(&request)
.map_err(|err| ApiError::internal(format!("policy: {err}")))?;
log_policy_decision(actor.as_str(), &request, &decision);
if decision.allowed {
Ok(())
} else {
Err(ApiError::forbidden(decision.message))
}
}
#[utoipa::path(
    get,
    path = "/snapshot",
    tag = "snapshots",
    operation_id = "getSnapshot",
    params(SnapshotQuery),
    responses(
        (status = 200, description = "Database snapshot", body = api::SnapshotOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Read the current snapshot of a branch.
///
/// Returns the manifest version plus per-table metadata (path, version, row
/// count) for every table on the branch. Defaults to `main` when `branch` is
/// omitted. Read-only.
async fn server_snapshot(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Query(query): Query<SnapshotQuery>,
) -> std::result::Result<Json<api::SnapshotOutput>, ApiError> {
    // Borrow the authenticated actor (if any) out of the extension wrapper.
    let actor = actor.as_ref().map(|Extension(inner)| inner);
    let branch = query.branch.unwrap_or_else(|| "main".to_string());
    let policy_request = PolicyRequest {
        actor_id: actor
            .map(|current| current.as_str().to_string())
            .unwrap_or_default(),
        action: PolicyAction::Read,
        branch: Some(branch.clone()),
        target_branch: None,
    };
    authorize_request(&state, actor, policy_request)?;
    let snapshot = state
        .engine
        .snapshot_of(ReadTarget::branch(branch.as_str()))
        .await
        .map_err(ApiError::from_omni)?;
    Ok(Json(snapshot_payload(&branch, &snapshot)))
}
#[utoipa::path(
    post,
    path = "/read",
    tag = "queries",
    operation_id = "read",
    request_body = ReadRequest,
    responses(
        (status = 200, description = "Query results", body = ReadOutput),
        (status = 400, description = "Bad request", body = ErrorOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Execute a GQ read query.
///
/// Runs the query in `query_source` against either a branch or a frozen
/// snapshot (mutually exclusive). When `query_source` defines multiple named
/// queries, pick one with `query_name`. `params` is a JSON object whose keys
/// match the parameters declared by the query. Returns rows as a JSON array
/// plus a `columns` list. Read-only.
async fn server_read(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Json(request): Json<ReadRequest>,
) -> std::result::Result<Json<ReadOutput>, ApiError> {
    // `branch` and `snapshot` each select a read target; both at once is invalid.
    let both_targets = request.branch.is_some() && request.snapshot.is_some();
    if both_targets {
        return Err(ApiError::bad_request(
            "read request may specify branch or snapshot, not both",
        ));
    }
    let actor = actor.as_ref().map(|Extension(inner)| inner);
    let target = read_target_from_request(request.branch, request.snapshot);

    // Policies are scoped by branch. A snapshot read only pays for resolving
    // its owning branch (falling back to `main`) when a policy engine will
    // actually evaluate an authenticated actor.
    let resolve_snapshot_branch = state.policy_engine().is_some() && actor.is_some();
    let policy_branch = match &target {
        ReadTarget::Branch(branch) => Some(branch.clone()),
        ReadTarget::Snapshot(_) if resolve_snapshot_branch => {
            let resolved = state
                .engine
                .resolved_branch_of(target.clone())
                .await
                .map_err(ApiError::from_omni)?;
            resolved.or_else(|| Some("main".to_string()))
        }
        ReadTarget::Snapshot(_) => None,
    };

    let policy_request = PolicyRequest {
        actor_id: actor
            .map(|current| current.as_str().to_string())
            .unwrap_or_default(),
        action: PolicyAction::Read,
        branch: policy_branch,
        target_branch: None,
    };
    authorize_request(&state, actor, policy_request)?;

    let (selected_name, query_params) =
        select_named_query(&request.query_source, request.query_name.as_deref())
            .map_err(|err| ApiError::bad_request(err.to_string()))?;
    let params = query_params_from_json(&query_params, request.params.as_ref())
        .map_err(|err| ApiError::bad_request(err.to_string()))?;
    let result = state
        .engine
        .query(
            target.clone(),
            &request.query_source,
            &selected_name,
            &params,
        )
        .await
        .map_err(ApiError::from_omni)?;
    Ok(Json(api::read_output(selected_name, &target, result)))
}
#[utoipa::path(
post,
path = "/export",
tag = "queries",
operation_id = "export",
request_body = ExportRequest,
responses(
(status = 200, description = "Exported data as NDJSON", content_type = "application/x-ndjson"),
(status = 400, description = "Bad request", body = ErrorOutput),
(status = 401, description = "Unauthorized", body = ErrorOutput),
(status = 403, description = "Forbidden", body = ErrorOutput),
),
security(("bearer_token" = [])),
)]
/// Stream the contents of a branch as NDJSON.
///
/// Emits one JSON object per line (`application/x-ndjson`). Filter with
/// `type_names` (node/edge type names) and/or `table_keys`; both empty
/// streams the entire branch. Suitable for large exports — the response is
/// streamed, not buffered. Read-only.
async fn server_export(
State(state): State<AppState>,
actor: Option<Extension<AuthenticatedActor>>,
Json(request): Json<ExportRequest>,
) -> std::result::Result<Response, ApiError> {
let branch = request.branch.unwrap_or_else(|| "main".to_string());
authorize_request(
&state,
actor.as_ref().map(|Extension(actor)| actor),
PolicyRequest {
actor_id: actor
.as_ref()
.map(|Extension(actor)| actor.as_str().to_string())
.unwrap_or_default(),
action: PolicyAction::Export,
branch: Some(branch.clone()),
target_branch: None,
},
)?;
let engine = Arc::clone(&state.engine);
let type_names = request.type_names.clone();
let table_keys = request.table_keys.clone();
let (tx, rx) = mpsc::unbounded_channel::<std::result::Result<Bytes, io::Error>>();
tokio::spawn(async move {
let result = {
let mut writer = ExportStreamWriter { sender: tx.clone() };
engine
.export_jsonl_to_writer(&branch, &type_names, &table_keys, &mut writer)
.await
};
if let Err(err) = result {
let _ = tx.send(Err(io::Error::other(err.to_string())));
}
});
let body = Body::from_stream(stream::unfold(rx, |mut rx| async move {
rx.recv().await.map(|item| (item, rx))
}));
Ok((
StatusCode::OK,
[(CONTENT_TYPE, "application/x-ndjson; charset=utf-8")],
body,
)
.into_response())
}
#[utoipa::path(
    post,
    path = "/change",
    tag = "mutations",
    operation_id = "change",
    request_body = ChangeRequest,
    responses(
        (status = 200, description = "Mutation results", body = ChangeOutput),
        (status = 400, description = "Bad request", body = ErrorOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
        (status = 409, description = "Merge conflict", body = ErrorOutput),
        (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Apply a GQ mutation to a branch.
///
/// Writes to the named `branch` (defaults to `main`). Mutations are atomic
/// per call and produce a new commit. Returns counts of nodes and edges
/// affected. **Destructive**: on success the branch is updated; rejected
/// mutations may still acquire locks briefly. Returns 409 on merge conflict.
async fn server_change(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Json(request): Json<ChangeRequest>,
) -> std::result::Result<Json<ChangeOutput>, ApiError> {
    let branch = request.branch.unwrap_or_else(|| "main".to_string());

    // Derive every view of the caller (policy id, engine id, admission key)
    // from one borrowed handle.
    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let actor_id = caller.map(|a| a.as_str());
    let actor_arc = caller.map_or_else(|| Arc::<str>::from("anonymous"), |a| Arc::clone(&a.0));

    authorize_request(
        &state,
        caller,
        PolicyRequest {
            actor_id: actor_id.map(String::from).unwrap_or_default(),
            action: PolicyAction::Change,
            branch: Some(branch.clone()),
            target_branch: None,
        },
    )?;

    // Per-actor admission runs only after Cedar, so denied requests never
    // hold a slot. The byte estimate is the query text plus the serialized
    // params, which is a coarse proxy for what the engine will actually use.
    let params_bytes = request
        .params
        .as_ref()
        .map_or(0, |p| p.to_string().len() as u64);
    let est_bytes = request.query_source.len() as u64 + params_bytes;
    let _permit = state
        .workload
        .try_admit(&actor_arc, est_bytes)
        .map_err(ApiError::from_workload_reject)?;

    let (selected_name, query_params) =
        select_named_query(&request.query_source, request.query_name.as_deref())
            .map_err(|err| ApiError::bad_request(err.to_string()))?;
    let params = query_params_from_json(&query_params, request.params.as_ref())
        .map_err(|err| ApiError::bad_request(err.to_string()))?;

    let result = state
        .engine
        .mutate_as(
            &branch,
            &request.query_source,
            &selected_name,
            &params,
            actor_id,
        )
        .await
        .map_err(ApiError::from_omni)?;

    Ok(Json(ChangeOutput {
        branch,
        query_name: selected_name,
        affected_nodes: result.affected_nodes,
        affected_edges: result.affected_edges,
        actor_id: actor_id.map(String::from),
    }))
}
#[utoipa::path(
    get,
    path = "/schema",
    tag = "schema",
    operation_id = "getSchema",
    responses(
        (status = 200, description = "Current schema source", body = SchemaOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Read the current schema source.
///
/// Returns the project's schema as a single string in `.pg` source form.
/// Useful for clients that want to introspect available types and tables
/// before constructing GQ queries. Read-only.
async fn server_schema_get(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
) -> std::result::Result<Json<SchemaOutput>, ApiError> {
    // Schema reads are not scoped to a branch, so the policy request has no
    // branch context.
    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let policy_request = PolicyRequest {
        actor_id: caller
            .map(|a| a.as_str().to_string())
            .unwrap_or_default(),
        action: PolicyAction::Read,
        branch: None,
        target_branch: None,
    };
    authorize_request(&state, caller, policy_request)?;

    let schema_source = state.engine.schema_source().to_string();
    Ok(Json(SchemaOutput { schema_source }))
}
#[utoipa::path(
    post,
    path = "/schema/apply",
    tag = "mutations",
    operation_id = "applySchema",
    request_body = SchemaApplyRequest,
    responses(
        (status = 200, description = "Schema apply results", body = SchemaApplyOutput),
        (status = 400, description = "Bad request", body = ErrorOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
        (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Apply a schema migration.
///
/// Diffs `schema_source` against the current schema and applies the resulting
/// migration steps (add/drop type, add/drop column, etc.). **Destructive**:
/// some steps drop data. Returns the list of steps applied; if `applied` is
/// false the diff was unsupported and no changes were made.
async fn server_schema_apply(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Json(request): Json<SchemaApplyRequest>,
) -> std::result::Result<Json<SchemaApplyOutput>, ApiError> {
    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let actor_arc = caller.map_or_else(|| Arc::<str>::from("anonymous"), |a| Arc::clone(&a.0));
    let actor_id = caller.map(|a| a.as_str());

    // Schema changes are authorized against `main` as the target.
    let policy_request = PolicyRequest {
        actor_id: actor_id.map(String::from).unwrap_or_default(),
        action: PolicyAction::SchemaApply,
        branch: None,
        target_branch: Some("main".to_string()),
    };
    authorize_request(&state, caller, policy_request)?;

    // Admission is sized by the submitted schema text.
    let _permit = state
        .workload
        .try_admit(&actor_arc, request.schema_source.len() as u64)
        .map_err(ApiError::from_workload_reject)?;

    // Engine-layer policy enforcement (MR-722): the resolved actor is passed
    // through so apply_schema_as can call enforce() with the authoritative
    // identity. With a policy installed in AppState, the engine re-checks the
    // same decision authorize_request made above. PR #3 collapses that
    // redundancy.
    let options = omnigraph::db::SchemaApplyOptions {
        allow_data_loss: request.allow_data_loss,
    };
    let result = state
        .engine
        .apply_schema_as(&request.schema_source, options, actor_id)
        .await
        .map_err(ApiError::from_omni)?;

    Ok(Json(schema_apply_output(state.uri(), result)))
}
#[utoipa::path(
    post,
    path = "/ingest",
    tag = "mutations",
    operation_id = "ingest",
    request_body = IngestRequest,
    responses(
        (status = 200, description = "Ingest results", body = IngestOutput),
        (status = 400, description = "Bad request", body = ErrorOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
        (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Bulk-ingest NDJSON data into a branch.
///
/// `data` is NDJSON with one record per line. `mode` controls behavior on
/// existing rows: `merge` upserts by id (default), `append` blindly inserts,
/// `overwrite` replaces table contents. If `branch` does not exist it is
/// created from `from` (defaults to `main`). **Destructive** when `mode` is
/// `overwrite` or when ingest produces conflicting writes.
async fn server_ingest(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Json(request): Json<IngestRequest>,
) -> std::result::Result<Json<IngestOutput>, ApiError> {
    let branch = request.branch.unwrap_or_else(|| "main".to_string());
    let from = request.from.unwrap_or_else(|| "main".to_string());
    let mode = request.mode.unwrap_or(omnigraph::loader::LoadMode::Merge);

    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let actor_arc = caller.map_or_else(|| Arc::<str>::from("anonymous"), |a| Arc::clone(&a.0));
    let actor_id = caller.map(|a| a.as_str());
    let policy_actor = actor_id.map(String::from).unwrap_or_default();

    // NOTE(review): branch existence is sampled here, before ingest_as runs.
    // If the branch disappears in between, ingest_as may recreate it from
    // `from` without the BranchCreate check below having fired. Confirm
    // whether the engine guards against this.
    let existing = state
        .engine
        .branch_list()
        .await
        .map_err(ApiError::from_omni)?;
    let branch_exists = existing.into_iter().any(|name| name == branch);

    // Ingesting into a missing branch implicitly forks it from `from`, so
    // the caller needs create rights on top of write rights.
    if !branch_exists {
        authorize_request(
            &state,
            caller,
            PolicyRequest {
                actor_id: policy_actor.clone(),
                action: PolicyAction::BranchCreate,
                branch: Some(from.clone()),
                target_branch: Some(branch.clone()),
            },
        )?;
    }
    authorize_request(
        &state,
        caller,
        PolicyRequest {
            actor_id: policy_actor,
            action: PolicyAction::Change,
            branch: Some(branch.clone()),
            target_branch: None,
        },
    )?;

    // Admission is sized by the raw NDJSON payload.
    let _permit = state
        .workload
        .try_admit(&actor_arc, request.data.len() as u64)
        .map_err(ApiError::from_workload_reject)?;

    let result = state
        .engine
        .ingest_as(&branch, Some(&from), &request.data, mode, actor_id)
        .await
        .map_err(ApiError::from_omni)?;

    Ok(Json(ingest_output(
        state.uri(),
        &result,
        actor_id.map(String::from),
    )))
}
#[utoipa::path(
    get,
    path = "/branches",
    tag = "branches",
    operation_id = "listBranches",
    responses(
        (status = 200, description = "List of branches", body = BranchListOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// List all branches.
///
/// Returns branch names sorted alphabetically. Read-only.
async fn server_branch_list(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
) -> std::result::Result<Json<BranchListOutput>, ApiError> {
    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let policy_request = PolicyRequest {
        actor_id: caller
            .map(|a| a.as_str().to_string())
            .unwrap_or_default(),
        action: PolicyAction::Read,
        branch: None,
        target_branch: None,
    };
    authorize_request(&state, caller, policy_request)?;

    // The engine's order is not relied on; sort for a stable response.
    let mut branches = state
        .engine
        .branch_list()
        .await
        .map_err(ApiError::from_omni)?;
    branches.sort();
    Ok(Json(BranchListOutput { branches }))
}
#[utoipa::path(
    post,
    path = "/branches",
    tag = "branches",
    operation_id = "createBranch",
    request_body = BranchCreateRequest,
    responses(
        (status = 200, description = "Branch created", body = BranchCreateOutput),
        (status = 400, description = "Bad request", body = ErrorOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
        (status = 409, description = "Branch already exists", body = ErrorOutput),
        (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Create a new branch.
///
/// Forks `name` off of `from` (defaults to `main`). The new branch shares
/// table data with its parent until it is mutated. Returns 409 if `name`
/// already exists.
async fn server_branch_create(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Json(request): Json<BranchCreateRequest>,
) -> std::result::Result<Json<BranchCreateOutput>, ApiError> {
    let from = request.from.unwrap_or_else(|| "main".to_string());

    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let actor_id = caller.map(|a| a.as_str());
    let actor_arc = caller.map_or_else(|| Arc::<str>::from("anonymous"), |a| Arc::clone(&a.0));

    authorize_request(
        &state,
        caller,
        PolicyRequest {
            actor_id: actor_id.map(String::from).unwrap_or_default(),
            action: PolicyAction::BranchCreate,
            branch: Some(from.clone()),
            target_branch: Some(request.name.clone()),
        },
    )?;

    // Only branch metadata is written, so a small constant byte estimate is
    // used. The Lance shallow-clone cost scales with the parent's manifest,
    // not with the request body.
    let _permit = state
        .workload
        .try_admit(&actor_arc, 256)
        .map_err(ApiError::from_workload_reject)?;

    state
        .engine
        .branch_create_from_as(ReadTarget::branch(&from), &request.name, actor_id)
        .await
        .map_err(ApiError::from_omni)?;

    Ok(Json(BranchCreateOutput {
        uri: state.uri().to_string(),
        from,
        name: request.name,
        actor_id: actor_id.map(String::from),
    }))
}
#[utoipa::path(
    delete,
    path = "/branches/{branch}",
    tag = "branches",
    operation_id = "deleteBranch",
    params(
        ("branch" = String, Path, description = "Branch name to delete"),
    ),
    responses(
        (status = 200, description = "Branch deleted", body = BranchDeleteOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
        (status = 404, description = "Branch not found", body = ErrorOutput),
        (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Delete a branch.
///
/// **Irreversible.** Removes the branch pointer; commits remain reachable
/// only if referenced by another branch. Returns 404 if the branch does not
/// exist.
async fn server_branch_delete(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Path(branch): Path<String>,
) -> std::result::Result<Json<BranchDeleteOutput>, ApiError> {
    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let actor_id = caller.map(|a| a.as_str());
    let actor_arc = caller.map_or_else(|| Arc::<str>::from("anonymous"), |a| Arc::clone(&a.0));

    // The branch being removed is the policy *target*; there is no source.
    let policy_request = PolicyRequest {
        actor_id: actor_id.map(String::from).unwrap_or_default(),
        action: PolicyAction::BranchDelete,
        branch: None,
        target_branch: Some(branch.clone()),
    };
    authorize_request(&state, caller, policy_request)?;

    // A manifest tombstone is metadata-only, so a small constant estimate is used.
    let _permit = state
        .workload
        .try_admit(&actor_arc, 256)
        .map_err(ApiError::from_workload_reject)?;

    state
        .engine
        .branch_delete_as(&branch, actor_id)
        .await
        .map_err(ApiError::from_omni)?;

    Ok(Json(BranchDeleteOutput {
        uri: state.uri().to_string(),
        name: branch,
        actor_id: actor_id.map(String::from),
    }))
}
#[utoipa::path(
    post,
    path = "/branches/merge",
    tag = "branches",
    operation_id = "mergeBranches",
    request_body = BranchMergeRequest,
    responses(
        (status = 200, description = "Branches merged", body = BranchMergeOutput),
        (status = 400, description = "Bad request", body = ErrorOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
        (status = 409, description = "Merge conflict", body = ErrorOutput),
        (status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Merge one branch into another.
///
/// Merges `source` into `target` (defaults to `main`). Outcome is one of
/// `already_up_to_date`, `fast_forward`, or `merged`. Returns 409 with the
/// list of conflicts if the merge cannot be completed; the target is left
/// unchanged in that case. **Destructive** to `target` on success.
async fn server_branch_merge(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Json(request): Json<BranchMergeRequest>,
) -> std::result::Result<Json<BranchMergeOutput>, ApiError> {
    let target = request.target.unwrap_or_else(|| "main".to_string());

    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let actor_id = caller.map(|a| a.as_str());
    let actor_arc = caller.map_or_else(|| Arc::<str>::from("anonymous"), |a| Arc::clone(&a.0));

    authorize_request(
        &state,
        caller,
        PolicyRequest {
            actor_id: actor_id.map(String::from).unwrap_or_default(),
            action: PolicyAction::BranchMerge,
            branch: Some(request.source.clone()),
            target_branch: Some(target.clone()),
        },
    )?;

    // The request body is tiny JSON, and the engine's heavy lifting is
    // bounded per (table, branch) by the writer queue. A small constant
    // estimate is therefore enough to count the merge against the actor's
    // in-flight cap.
    let _permit = state
        .workload
        .try_admit(&actor_arc, 256)
        .map_err(ApiError::from_workload_reject)?;

    let outcome = state
        .engine
        .branch_merge_as(&request.source, &target, actor_id)
        .await
        .map_err(ApiError::from_omni)?;

    Ok(Json(BranchMergeOutput {
        source: request.source,
        target,
        outcome: outcome.into(),
        actor_id: actor_id.map(String::from),
    }))
}
#[utoipa::path(
    get,
    path = "/commits",
    tag = "commits",
    operation_id = "listCommits",
    params(CommitListQuery),
    responses(
        (status = 200, description = "List of commits", body = CommitListOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// List commits.
///
/// Filter by `branch` to get the commits on a single branch (most recent
/// first); omit to list across all branches. Read-only.
async fn server_commit_list(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Query(query): Query<CommitListQuery>,
) -> std::result::Result<Json<CommitListOutput>, ApiError> {
    // The optional branch filter doubles as the policy's branch context.
    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let policy_request = PolicyRequest {
        actor_id: caller
            .map(|a| a.as_str().to_string())
            .unwrap_or_default(),
        action: PolicyAction::Read,
        branch: query.branch.clone(),
        target_branch: None,
    };
    authorize_request(&state, caller, policy_request)?;

    let commits = state
        .engine
        .list_commits(query.branch.as_deref())
        .await
        .map_err(ApiError::from_omni)?;

    Ok(Json(CommitListOutput {
        commits: commits.iter().map(api::commit_output).collect(),
    }))
}
#[utoipa::path(
    get,
    path = "/commits/{commit_id}",
    tag = "commits",
    operation_id = "getCommit",
    params(
        ("commit_id" = String, Path, description = "Commit identifier"),
    ),
    responses(
        (status = 200, description = "Commit details", body = api::CommitOutput),
        (status = 401, description = "Unauthorized", body = ErrorOutput),
        (status = 403, description = "Forbidden", body = ErrorOutput),
        (status = 404, description = "Commit not found", body = ErrorOutput),
    ),
    security(("bearer_token" = [])),
)]
/// Get a single commit.
///
/// Returns the commit's manifest version, parent commit(s), and creation
/// metadata. Read-only.
async fn server_commit_show(
    State(state): State<AppState>,
    actor: Option<Extension<AuthenticatedActor>>,
    Path(commit_id): Path<String>,
) -> std::result::Result<Json<api::CommitOutput>, ApiError> {
    // A commit lookup by id is not branch-scoped for policy purposes.
    let caller = actor.as_ref().map(|Extension(actor)| actor);
    let policy_request = PolicyRequest {
        actor_id: caller
            .map(|a| a.as_str().to_string())
            .unwrap_or_default(),
        action: PolicyAction::Read,
        branch: None,
        target_branch: None,
    };
    authorize_request(&state, caller, policy_request)?;

    let commit = state
        .engine
        .get_commit(&commit_id)
        .await
        .map_err(ApiError::from_omni)?;
    Ok(Json(api::commit_output(&commit)))
}
/// Resolves the read target for a request.
///
/// A `snapshot` takes precedence over `branch`. Without a snapshot, the
/// named branch is used, falling back to `main`. Callers are expected to
/// reject requests that set both.
fn read_target_from_request(branch: Option<String>, snapshot: Option<String>) -> ReadTarget {
    match snapshot {
        Some(id) => ReadTarget::snapshot(omnigraph::db::SnapshotId::new(id)),
        None => ReadTarget::branch(branch.unwrap_or_else(|| "main".to_string())),
    }
}
/// Parses `query_source` and picks the query to execute.
///
/// With `requested_name`, the query with exactly that name is selected.
/// Without it, the source must contain exactly one query.
///
/// Returns the selected query's name and its declared parameters.
///
/// # Errors
/// Fails when the source does not parse, when `requested_name` matches no
/// query, when the source contains no queries, or when it contains several
/// queries and no name was given.
fn select_named_query(
    query_source: &str,
    requested_name: Option<&str>,
) -> Result<(String, Vec<omnigraph_compiler::query::ast::Param>)> {
    let parsed = parse_query(query_source)?;
    let mut queries = parsed.queries.into_iter();
    let query = match requested_name {
        Some(name) => queries
            .find(|query| query.name == name)
            .ok_or_else(|| color_eyre::eyre::eyre!("query '{}' not found", name))?,
        // Peek at most two entries to classify zero / one / many without
        // counting the whole list. A source that parses to zero queries
        // previously fell through to the misleading "multiple queries" error.
        None => match (queries.next(), queries.next()) {
            (Some(only), None) => only,
            (None, _) => bail!("query file contains no queries"),
            (Some(_), Some(_)) => bail!("query file contains multiple queries; pass --name"),
        },
    };
    Ok((query.name, query.params))
}
/// Converts optional JSON request params into a [`ParamMap`] typed against
/// the query's declared parameters, using the standard JSON param mode.
///
/// The converter's error is flattened into an eyre report carrying only its
/// display text.
fn query_params_from_json(
    query_params: &[omnigraph_compiler::query::ast::Param],
    params_json: Option<&Value>,
) -> Result<ParamMap> {
    match json_params_to_param_map(params_json, query_params, JsonParamMode::Standard) {
        Ok(param_map) => Ok(param_map),
        Err(err) => Err(color_eyre::eyre::eyre!(err.to_string())),
    }
}
/// Trims surrounding whitespace from an optional token-like value.
///
/// Returns `None` when the input is `None` or is empty after trimming, so
/// callers can treat "unset" and "blank" the same way.
fn normalize_bearer_token(value: Option<String>) -> Option<String> {
    let trimmed = value?.trim().to_string();
    (!trimmed.is_empty()).then_some(trimmed)
}
/// Trims an actor name and rejects it if nothing is left.
///
/// # Errors
/// Fails when the name is empty or whitespace-only.
fn normalize_bearer_actor(value: String) -> Result<String> {
    match value.trim() {
        "" => bail!("bearer token actor names must not be blank"),
        trimmed => Ok(trimmed.to_string()),
    }
}
/// Parses a JSON object mapping actor name to bearer token into a list of
/// `(actor, token)` pairs.
///
/// Values are returned untrimmed, and the pair order follows the map's
/// iteration order, which is unspecified. Because the JSON is deserialized
/// into a `HashMap`, a repeated actor key does not produce two entries.
/// The error message names the env var even when the input came from a
/// file; the file reader adds its own context on top.
fn parse_bearer_tokens_json(value: &str) -> Result<Vec<(String, String)>> {
    serde_json::from_str::<HashMap<String, String>>(value)
        .wrap_err("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON must be a JSON object of actor->token")
        .map(|actor_tokens| actor_tokens.into_iter().collect())
}
/// Reads the file at `path` and parses it as an actor→token JSON object.
///
/// # Errors
/// Read failures and parse failures are each wrapped with a message that
/// names `path`.
fn read_bearer_tokens_file(path: &str) -> Result<Vec<(String, String)>> {
    fs::read_to_string(path)
        .wrap_err_with(|| format!("failed to read bearer tokens file at {path}"))
        .and_then(|raw| {
            parse_bearer_tokens_json(&raw)
                .wrap_err_with(|| format!("failed to parse bearer tokens file at {path}"))
        })
}
/// Normalizes `(actor, token)` pairs and rejects unusable configurations.
///
/// Each pair is checked in order:
/// 1. the actor name is trimmed and must be non-blank;
/// 2. the token is trimmed and must be non-blank;
/// 3. the actor must not repeat an earlier one;
/// 4. the token value must not repeat an earlier one.
///
/// The first failing check aborts. On success the pairs are returned sorted
/// by actor name.
fn validate_bearer_tokens(entries: Vec<(String, String)>) -> Result<Vec<(String, String)>> {
    let mut actors_seen: HashSet<String> = HashSet::new();
    let mut tokens_seen: HashSet<String> = HashSet::new();
    let mut accepted = Vec::with_capacity(entries.len());

    for (raw_actor, raw_token) in entries {
        let actor = normalize_bearer_actor(raw_actor)?;
        let token = match normalize_bearer_token(Some(raw_token)) {
            Some(token) => token,
            None => bail!("bearer token for actor '{actor}' must not be blank"),
        };
        if !actors_seen.insert(actor.clone()) {
            bail!("duplicate bearer token actor '{actor}'");
        }
        // The token value is deliberately left out of the error message.
        if !tokens_seen.insert(token.clone()) {
            bail!("duplicate bearer token value configured");
        }
        accepted.push((actor, token));
    }

    accepted.sort_by(|left, right| left.0.cmp(&right.0));
    Ok(accepted)
}
/// Collects bearer tokens from the process environment.
///
/// Sources, in order:
/// - `OMNIGRAPH_SERVER_BEARER_TOKEN`: a single token registered under the
///   actor `default`;
/// - `OMNIGRAPH_SERVER_BEARER_TOKENS_FILE`: a path to an actor→token JSON file;
/// - `OMNIGRAPH_SERVER_BEARER_TOKENS_JSON`: inline actor→token JSON, consulted
///   only when no file path is set.
///
/// Blank or unreadable (`std::env::var` error) values count as unset. The
/// combined list goes through `validate_bearer_tokens`.
fn server_bearer_tokens_from_env() -> Result<Vec<(String, String)>> {
    let env_value = |name: &str| normalize_bearer_token(std::env::var(name).ok());

    let mut entries = Vec::new();
    if let Some(token) = env_value("OMNIGRAPH_SERVER_BEARER_TOKEN") {
        entries.push(("default".to_string(), token));
    }

    // The file source wins; inline JSON is ignored whenever a file is configured.
    if let Some(path) = env_value("OMNIGRAPH_SERVER_BEARER_TOKENS_FILE") {
        entries.extend(read_bearer_tokens_file(&path)?);
    } else if let Some(json) = env_value("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON") {
        entries.extend(parse_bearer_tokens_json(&json)?);
    }

    validate_bearer_tokens(entries)
}
#[cfg(test)]
mod tests {
    //! Unit tests for server settings, runtime-state classification, and
    //! bearer-token configuration.
    //!
    //! Any test that reads or writes the `OMNIGRAPH_*` environment variables,
    //! directly or through `load_server_settings` /
    //! `server_bearer_tokens_from_env`, is `#[serial]`. Environment access
    //! races between threads, and `serial_test` only orders tests that carry
    //! the attribute.
    use super::{
        ServerConfig, ServerRuntimeState, classify_server_runtime_state, hash_bearer_token,
        load_server_settings, normalize_bearer_token, parse_bearer_tokens_json, serve,
        server_bearer_tokens_from_env,
    };
    use serial_test::serial;
    use std::env;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn hash_bearer_token_produces_32_byte_output() {
        let hash = hash_bearer_token("any-token");
        assert_eq!(hash.len(), 32);
    }

    #[test]
    fn hash_bearer_token_is_deterministic() {
        assert_eq!(
            hash_bearer_token("stable-input"),
            hash_bearer_token("stable-input"),
        );
    }

    #[test]
    fn hash_bearer_token_differs_for_different_inputs() {
        assert_ne!(hash_bearer_token("token-a"), hash_bearer_token("token-b"));
    }

    #[test]
    fn hash_bearer_token_matches_known_sha256_vector() {
        // SHA-256("abc"). If this ever fails, the hash function was swapped.
        let hash = hash_bearer_token("abc");
        let hex: String = hash.iter().map(|b| format!("{:02x}", b)).collect();
        assert_eq!(
            hex,
            "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad"
        );
    }

    // `load_server_settings` reads `OMNIGRAPH_UNAUTHENTICATED` from the
    // process environment. These tests are `#[serial]` so they never overlap
    // with the env-mutating tests below.
    #[test]
    #[serial]
    fn server_settings_load_from_yaml_config() {
        let temp = tempdir().unwrap();
        let config = temp.path().join("omnigraph.yaml");
        fs::write(
            &config,
            r#"
graphs:
  local:
    uri: /tmp/demo.omni
server:
  graph: local
  bind: 0.0.0.0:9090
"#,
        )
        .unwrap();
        let settings = load_server_settings(Some(&config), None, None, None, false).unwrap();
        assert_eq!(settings.uri, "/tmp/demo.omni");
        assert_eq!(settings.bind, "0.0.0.0:9090");
    }

    #[test]
    #[serial]
    fn server_settings_cli_flags_override_yaml_config() {
        let temp = tempdir().unwrap();
        let config = temp.path().join("omnigraph.yaml");
        fs::write(
            &config,
            r#"
graphs:
  local:
    uri: /tmp/demo.omni
server:
  graph: local
  bind: 127.0.0.1:8080
"#,
        )
        .unwrap();
        let settings = load_server_settings(
            Some(&config),
            Some("/tmp/override.omni".to_string()),
            None,
            Some("0.0.0.0:9999".to_string()),
            false,
        )
        .unwrap();
        assert_eq!(settings.uri, "/tmp/override.omni");
        assert_eq!(settings.bind, "0.0.0.0:9999");
    }

    #[test]
    #[serial]
    fn server_settings_can_resolve_named_target() {
        let temp = tempdir().unwrap();
        let config = temp.path().join("omnigraph.yaml");
        fs::write(
            &config,
            r#"
graphs:
  local:
    uri: ./demo.omni
  dev:
    uri: http://127.0.0.1:8080
server:
  graph: local
  bind: 127.0.0.1:8080
"#,
        )
        .unwrap();
        let settings =
            load_server_settings(Some(&config), None, Some("dev".to_string()), None, false)
                .unwrap();
        assert_eq!(settings.uri, "http://127.0.0.1:8080");
    }

    #[test]
    #[serial]
    fn server_settings_require_uri_from_cli_or_config() {
        let error = load_server_settings(None, None, None, None, false).unwrap_err();
        assert!(error.to_string().contains("URI must be provided"));
    }

    #[test]
    fn classify_open_requires_explicit_unauthenticated_flag() {
        // State 1: no tokens, no policy, no flag → refuse to start.
        let error = classify_server_runtime_state(false, false, false).unwrap_err();
        let msg = error.to_string();
        assert!(
            msg.contains("--unauthenticated"),
            "expected refusal message mentioning --unauthenticated, got: {msg}"
        );
        // Same matrix cell but with the flag set → Open mode permitted.
        assert_eq!(
            classify_server_runtime_state(false, false, true).unwrap(),
            ServerRuntimeState::Open
        );
    }

    #[test]
    fn classify_tokens_without_policy_is_default_deny() {
        // State 2: tokens configured, no policy → DefaultDeny regardless
        // of the flag (the flag opts into the fully-open dev mode; it
        // doesn't downgrade default-deny back to open).
        assert_eq!(
            classify_server_runtime_state(true, false, false).unwrap(),
            ServerRuntimeState::DefaultDeny
        );
        assert_eq!(
            classify_server_runtime_state(true, false, true).unwrap(),
            ServerRuntimeState::DefaultDeny
        );
    }

    #[tokio::test]
    #[serial]
    async fn serve_refuses_to_start_in_state_1_without_unauthenticated() {
        // MR-723 PR A: pin the integration boundary that the classifier
        // is actually called by `serve()` before any side-effecting
        // work (Lance dataset open, TcpListener::bind). The classifier
        // itself is unit-tested above; this test guards the propagation
        // path from `classify_server_runtime_state` through serve's
        // `?` so a future refactor that drops the call returns red.
        //
        // Marked `#[serial]` because we have to clear all bearer-token
        // env vars, and another test in this module setting any of them
        // concurrently would corrupt the read inside `resolve_token_source`.
        let _guard = EnvGuard::set(&[
            ("OMNIGRAPH_SERVER_BEARER_TOKEN", None),
            ("OMNIGRAPH_SERVER_BEARER_TOKENS_FILE", None),
            ("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON", None),
            ("OMNIGRAPH_SERVER_BEARER_TOKENS_AWS_SECRET", None),
            ("OMNIGRAPH_UNAUTHENTICATED", None),
        ]);
        let temp = tempdir().unwrap();
        // Graph path doesn't need to exist — classifier fires before
        // `AppState::open_with_bearer_tokens_and_policy`.
        let config = ServerConfig {
            uri: temp
                .path()
                .join("graph.omni")
                .to_string_lossy()
                .into_owned(),
            bind: "127.0.0.1:0".to_string(),
            policy_file: None,
            allow_unauthenticated: false,
        };
        let result = serve(config).await;
        let err =
            result.expect_err("serve should refuse to start in State 1 without --unauthenticated");
        let msg = format!("{:?}", err);
        assert!(
            msg.contains("no bearer tokens") || msg.contains("policy file"),
            "expected refusal message naming the misconfiguration, got: {msg}",
        );
    }

    #[test]
    #[serial]
    fn unauthenticated_env_var_classification() {
        // MR-723 PR A: closes the gap where the env-var read path inside
        // `load_server_settings` was structurally implemented but not
        // exercised by any test. Three properties to pin, all in one
        // sequential test because `cargo test` runs the mod test suite
        // in parallel and `OMNIGRAPH_UNAUTHENTICATED` is process-global
        // — interleaving with another test that sets the same env var
        // (concurrent classifier tests, even the bearer-token suite
        // sharing `EnvGuard`) corrupts the read. Sequential within one
        // test fn is the simplest race-free shape.
        let temp = tempdir().unwrap();
        let config_path = temp.path().join("omnigraph.yaml");
        fs::write(
            &config_path,
            r#"
graphs:
  local:
    uri: /tmp/demo-unauth.omni
server:
  graph: local
"#,
        )
        .unwrap();
        // Truthy values flip Open mode on, even with CLI flag off.
        for value in ["1", "true", "yes", "TRUE", "anything"] {
            let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some(value))]);
            let settings = load_server_settings(Some(&config_path), None, None, None, false)
                .expect("settings load should succeed");
            assert!(
                settings.allow_unauthenticated,
                "OMNIGRAPH_UNAUTHENTICATED={value:?} should enable Open mode",
            );
        }
        // Falsy values keep refusal behavior, even with CLI flag off.
        for value in ["0", "false", "FALSE", ""] {
            let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some(value))]);
            let settings = load_server_settings(Some(&config_path), None, None, None, false)
                .expect("settings load should succeed");
            assert!(
                !settings.allow_unauthenticated,
                "OMNIGRAPH_UNAUTHENTICATED={value:?} should NOT enable Open mode",
            );
        }
        // Unset env var: also false.
        let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", None)]);
        let settings = load_server_settings(Some(&config_path), None, None, None, false)
            .expect("settings load should succeed");
        assert!(
            !settings.allow_unauthenticated,
            "OMNIGRAPH_UNAUTHENTICATED unset should NOT enable Open mode",
        );
        drop(_guard);
        // CLI flag wins even when env is falsy — `serve()` honors the
        // OR of both inputs.
        let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some("0"))]);
        let settings = load_server_settings(Some(&config_path), None, None, None, true)
            .expect("settings load should succeed");
        assert!(
            settings.allow_unauthenticated,
            "--unauthenticated CLI flag should win even when env is falsy",
        );
    }

    #[test]
    fn classify_policy_enabled_always_wins() {
        // State 3: any setup with a policy file → PolicyEnabled. The
        // flag doesn't matter and tokens-or-not doesn't matter (no
        // tokens + policy is unusual but valid — every request fails
        // 401 without a bearer, which is effectively "locked").
        assert_eq!(
            classify_server_runtime_state(true, true, false).unwrap(),
            ServerRuntimeState::PolicyEnabled
        );
        assert_eq!(
            classify_server_runtime_state(false, true, false).unwrap(),
            ServerRuntimeState::PolicyEnabled
        );
        assert_eq!(
            classify_server_runtime_state(true, true, true).unwrap(),
            ServerRuntimeState::PolicyEnabled
        );
    }

    #[test]
    fn normalize_bearer_token_trims_and_filters_blank_values() {
        assert_eq!(normalize_bearer_token(None), None);
        assert_eq!(normalize_bearer_token(Some("   ".to_string())), None);
        assert_eq!(
            normalize_bearer_token(Some("  demo-token  ".to_string())).as_deref(),
            Some("demo-token")
        );
    }

    /// Sets or removes environment variables for the lifetime of the guard
    /// and restores the previous values on drop.
    ///
    /// Only use it from `#[serial]` tests: the process environment is shared
    /// by every test thread.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<String>)>,
    }

    impl EnvGuard {
        fn set(vars: &[(&'static str, Option<&str>)]) -> Self {
            let saved = vars
                .iter()
                .map(|(name, _)| (*name, env::var(name).ok()))
                .collect::<Vec<_>>();
            for (name, value) in vars {
                // SAFETY: callers are `#[serial]`, and every test in this
                // module that reads these variables is `#[serial]` as well,
                // so no other test in this module touches the environment
                // concurrently.
                unsafe {
                    match value {
                        Some(value) => env::set_var(name, value),
                        None => env::remove_var(name),
                    }
                }
            }
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, value) in self.saved.drain(..) {
                // SAFETY: the guard is dropped inside the same `#[serial]`
                // test that created it, so the invariant from `set` still holds.
                unsafe {
                    match value {
                        Some(value) => env::set_var(name, value),
                        None => env::remove_var(name),
                    }
                }
            }
        }
    }

    #[test]
    fn parse_bearer_tokens_json_reads_actor_token_map() {
        let tokens = parse_bearer_tokens_json(r#"{"alice":" token-a ","bob":"token-b"}"#).unwrap();
        assert_eq!(tokens.len(), 2);
        assert!(tokens.contains(&("alice".to_string(), " token-a ".to_string())));
        assert!(tokens.contains(&("bob".to_string(), "token-b".to_string())));
    }

    #[test]
    #[serial]
    fn server_bearer_tokens_from_env_reads_legacy_token_and_token_file() {
        let temp = tempdir().unwrap();
        let tokens_path = temp.path().join("tokens.json");
        fs::write(
            &tokens_path,
            r#"{"team-01":"token-one","team-02":"token-two"}"#,
        )
        .unwrap();
        let _guard = EnvGuard::set(&[
            ("OMNIGRAPH_SERVER_BEARER_TOKEN", Some(" legacy-token ")),
            (
                "OMNIGRAPH_SERVER_BEARER_TOKENS_FILE",
                Some(tokens_path.to_str().unwrap()),
            ),
            ("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON", None),
        ]);
        let tokens = server_bearer_tokens_from_env().unwrap();
        assert_eq!(
            tokens,
            vec![
                ("default".to_string(), "legacy-token".to_string()),
                ("team-01".to_string(), "token-one".to_string()),
                ("team-02".to_string(), "token-two".to_string()),
            ]
        );
    }
}