mirror of
https://github.com/ModernRelay/omnigraph.git
synced 2026-06-12 01:45:14 +02:00
PR 1 of the MR-668 multi-graph server work. Pure types, no runtime
behavior changes yet.
Ships the validated identity vocabulary that the rest of the implementation
will consume:
- `GraphId(String)` — `^[a-zA-Z0-9-]{1,64}$`, leading underscore rejected
(engine reserves every `_*` filename), reserved route names rejected
(`policies`, `healthz`, `openapi`, `openapi.json`, `graphs`). Validation
lives in `try_from` only; serde `Deserialize` re-runs it so JSON payloads
cannot bypass.
- `TenantId(String)` — same regex shape as GraphId. `None` in Cluster
mode; reserved for Cloud mode (RFC 0003) where it carries the OAuth
`org_id` claim.
- `GraphKey { tenant_id: Option<TenantId>, graph_id }` — the registry
HashMap key. `cluster()` constructor for the Cluster-mode default.
- `Scope` enum with `Full` variant — Cluster mode default; RFC 0004 will
extend with OAuth scopes (`graph:read`/`write`/`admin`/`*`).
- `AuthSource` enum with `Static` variant — Cluster mode default; RFC
0001 step 1 will add `Oidc`.
- `ResolvedActor { actor_id, tenant_id, scopes, source }` — will replace
`AuthenticatedActor(Arc<str>)` in the upcoming PR 4a refactor.
Per MR-668 design decision 13: ship the Cloud-mode forward type shapes
now (no `TokenVerifier` trait yet — that's RFC 0001 step 1) so handler
signatures stay stable across the Cluster → Cloud trajectory. `Scope`
and `AuthSource` use `#[non_exhaustive]` so future variants don't break
caller matches.
Tests: 26 new (15 graph_id + 11 identity), all passing. No regression
in the existing 36 server library tests.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
2133 lines
75 KiB
Rust
2133 lines
75 KiB
Rust
pub mod api;
|
||
pub mod auth;
|
||
pub mod config;
|
||
pub mod graph_id;
|
||
pub mod identity;
|
||
pub mod policy;
|
||
pub mod workload;
|
||
|
||
pub use graph_id::GraphId;
|
||
pub use identity::{AuthSource, GraphKey, ResolvedActor, Scope, TenantId};
|
||
|
||
use std::collections::{HashMap, HashSet};
|
||
use std::fs;
|
||
use std::io;
|
||
use std::io::Write;
|
||
use std::path::PathBuf;
|
||
use std::sync::Arc;
|
||
|
||
use api::{
|
||
BranchCreateOutput, BranchCreateRequest, BranchDeleteOutput, BranchListOutput,
|
||
BranchMergeOutput, BranchMergeRequest, ChangeOutput, ChangeRequest, CommitListOutput,
|
||
CommitListQuery, ErrorCode, ErrorOutput, ExportRequest, HealthOutput, IngestOutput,
|
||
IngestRequest, ReadOutput, ReadRequest, SchemaApplyOutput, SchemaApplyRequest, SchemaOutput,
|
||
SnapshotQuery, ingest_output, schema_apply_output, snapshot_payload,
|
||
};
|
||
pub use auth::{AWS_SECRET_ENV, EnvOrFileTokenSource, TokenSource, resolve_token_source};
|
||
use axum::body::{Body, Bytes};
|
||
use axum::extract::DefaultBodyLimit;
|
||
use axum::extract::{Extension, Path, Query, Request, State};
|
||
use axum::http::StatusCode;
|
||
use axum::http::header::{AUTHORIZATION, CONTENT_TYPE};
|
||
use axum::middleware::{self, Next};
|
||
use axum::response::{IntoResponse, Response};
|
||
use axum::routing::{delete, get, post};
|
||
use axum::{Json, Router};
|
||
use color_eyre::eyre::{Result, WrapErr, bail};
|
||
pub use config::{
|
||
AliasCommand, AliasConfig, CliDefaults, DEFAULT_CONFIG_FILE, OmnigraphConfig, PolicySettings,
|
||
ProjectConfig, QueryDefaults, ReadOutputFormat, ServerDefaults, TableCellLayout, TargetConfig,
|
||
load_config,
|
||
};
|
||
use futures::stream;
|
||
use omnigraph::db::{Omnigraph, ReadTarget};
|
||
use omnigraph::error::{ManifestConflictDetails, ManifestErrorKind, OmniError};
|
||
use omnigraph_compiler::json_params_to_param_map;
|
||
use omnigraph_compiler::query::parser::parse_query;
|
||
use omnigraph_compiler::{JsonParamMode, ParamMap};
|
||
pub use policy::{
|
||
PolicyAction, PolicyCompiler, PolicyConfig, PolicyDecision, PolicyEngine, PolicyExpectation,
|
||
PolicyRequest, PolicyTestConfig,
|
||
};
|
||
use serde_json::Value;
|
||
use sha2::{Digest, Sha256};
|
||
use subtle::ConstantTimeEq;
|
||
use tokio::net::TcpListener;
|
||
use tokio::sync::mpsc;
|
||
use tower_http::trace::TraceLayer;
|
||
use tracing::{error, info, warn};
|
||
use tracing_subscriber::EnvFilter;
|
||
use utoipa::OpenApi;
|
||
use utoipa::openapi::security::{Http, HttpAuthScheme, SecurityScheme};
|
||
|
||
type BearerTokenHash = [u8; 32];
|
||
|
||
fn hash_bearer_token(token: &str) -> BearerTokenHash {
|
||
let digest = Sha256::digest(token.as_bytes());
|
||
let mut out = [0u8; 32];
|
||
out.copy_from_slice(&digest);
|
||
out
|
||
}
|
||
|
||
#[derive(OpenApi)]
|
||
#[openapi(
|
||
info(
|
||
title = "Omnigraph API",
|
||
description = "HTTP API for the Omnigraph graph database",
|
||
),
|
||
paths(
|
||
server_health,
|
||
server_snapshot,
|
||
server_read,
|
||
server_export,
|
||
server_change,
|
||
server_schema_apply,
|
||
server_schema_get,
|
||
server_ingest,
|
||
server_branch_list,
|
||
server_branch_create,
|
||
server_branch_delete,
|
||
server_branch_merge,
|
||
server_commit_list,
|
||
server_commit_show,
|
||
),
|
||
modifiers(&SecurityAddon),
|
||
)]
|
||
pub struct ApiDoc;
|
||
|
||
struct SecurityAddon;
|
||
|
||
impl utoipa::Modify for SecurityAddon {
|
||
fn modify(&self, openapi: &mut utoipa::openapi::OpenApi) {
|
||
openapi
|
||
.components
|
||
.get_or_insert_with(Default::default)
|
||
.add_security_scheme(
|
||
"bearer_token",
|
||
SecurityScheme::Http(Http::new(HttpAuthScheme::Bearer)),
|
||
);
|
||
}
|
||
}
|
||
|
||
/// Request body cap applied to the whole router (1 MiB).
const DEFAULT_REQUEST_BODY_LIMIT_BYTES: usize = 1024 * 1024;

/// Larger request body cap applied to the `/ingest` route only (32 MiB).
const INGEST_REQUEST_BODY_LIMIT_BYTES: usize = 32 * DEFAULT_REQUEST_BODY_LIMIT_BYTES;
|
||
const SERVER_VERSION: &str = env!("CARGO_PKG_VERSION");
|
||
const SERVER_SOURCE_VERSION: Option<&str> = option_env!("OMNIGRAPH_SOURCE_VERSION");
|
||
|
||
/// Resolved settings `serve()` needs to start the HTTP server.
#[derive(Debug, Clone)]
pub struct ServerConfig {
    /// URI of the graph to open.
    pub uri: String,
    /// Socket address the TCP listener binds to.
    pub bind: String,
    /// Path of the policy file to load, if one is configured.
    pub policy_file: Option<PathBuf>,
    /// Explicit operator opt-in to run with no authentication at all (MR-723).
    ///
    /// With neither bearer tokens nor a policy file configured, `serve()`
    /// refuses to start unless this flag is set (via `--unauthenticated` or
    /// `OMNIGRAPH_UNAUTHENTICATED=1`). A server in that shape boots without any
    /// Cedar error yet is completely open, so an operator who meant to enable
    /// auth but left the configuration out would otherwise get only the
    /// appearance of protection.
    pub allow_unauthenticated: bool,
}
|
||
|
||
#[derive(Clone)]
|
||
pub struct AppState {
|
||
uri: String,
|
||
/// PR 2 (MR-686): the engine is now `Arc<Omnigraph>` — no global
|
||
/// write lock. Concurrent handlers call `&self` engine APIs
|
||
/// directly. Per-(table, branch) write queues inside the engine
|
||
/// serialize same-key writers; per-actor admission control on
|
||
/// `workload` isolates noisy actors.
|
||
engine: Arc<Omnigraph>,
|
||
/// Per-actor admission control. See `workload::WorkloadController`.
|
||
workload: Arc<workload::WorkloadController>,
|
||
bearer_tokens: Arc<[(BearerTokenHash, Arc<str>)]>,
|
||
policy_engine: Option<Arc<PolicyEngine>>,
|
||
}
|
||
|
||
/// Actor id resolved from a matched bearer token.
///
/// Inserted into the request extensions by `require_bearer_auth`; handlers
/// read it back through `Extension<AuthenticatedActor>`.
#[derive(Debug, Clone)]
struct AuthenticatedActor(Arc<str>);
|
||
|
||
struct ExportStreamWriter {
|
||
sender: mpsc::UnboundedSender<std::result::Result<Bytes, io::Error>>,
|
||
}
|
||
|
||
impl Write for ExportStreamWriter {
|
||
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
|
||
self.sender
|
||
.send(Ok(Bytes::copy_from_slice(buf)))
|
||
.map_err(|_| io::Error::new(io::ErrorKind::BrokenPipe, "export stream closed"))?;
|
||
Ok(buf.len())
|
||
}
|
||
|
||
fn flush(&mut self) -> io::Result<()> {
|
||
Ok(())
|
||
}
|
||
}
|
||
|
||
impl AuthenticatedActor {
|
||
fn as_str(&self) -> &str {
|
||
&self.0
|
||
}
|
||
}
|
||
|
||
#[derive(Debug)]
|
||
pub struct ApiError {
|
||
status: StatusCode,
|
||
code: ErrorCode,
|
||
message: String,
|
||
merge_conflicts: Vec<api::MergeConflictOutput>,
|
||
manifest_conflict: Option<api::ManifestConflictOutput>,
|
||
}
|
||
|
||
impl AppState {
|
||
pub fn new(uri: String, db: Omnigraph) -> Self {
|
||
Self::new_with_bearer_tokens(uri, db, Vec::new())
|
||
}
|
||
|
||
pub fn new_with_bearer_token(uri: String, db: Omnigraph, bearer_token: Option<String>) -> Self {
|
||
let bearer_tokens = normalize_bearer_token(bearer_token)
|
||
.into_iter()
|
||
.map(|token| ("default".to_string(), token))
|
||
.collect();
|
||
Self::new_with_bearer_tokens(uri, db, bearer_tokens)
|
||
}
|
||
|
||
pub fn new_with_bearer_tokens(
|
||
uri: String,
|
||
db: Omnigraph,
|
||
bearer_tokens: Vec<(String, String)>,
|
||
) -> Self {
|
||
Self::new_with_bearer_tokens_and_policy(uri, db, bearer_tokens, None)
|
||
}
|
||
|
||
pub fn new_with_bearer_tokens_and_policy(
|
||
uri: String,
|
||
db: Omnigraph,
|
||
bearer_tokens: Vec<(String, String)>,
|
||
policy_engine: Option<PolicyEngine>,
|
||
) -> Self {
|
||
let bearer_tokens: Vec<(BearerTokenHash, Arc<str>)> = bearer_tokens
|
||
.into_iter()
|
||
.map(|(actor, token)| (hash_bearer_token(&token), Arc::<str>::from(actor)))
|
||
.collect();
|
||
let policy_engine: Option<Arc<PolicyEngine>> = policy_engine.map(Arc::new);
|
||
// MR-722 chassis: inject the policy checker into the engine so
|
||
// `Omnigraph::apply_schema_as` (and PR #3's fan-out of the
|
||
// remaining writers) gates at engine-layer too. HTTP-layer
|
||
// `authorize_request` still fires first; the engine-layer gate
|
||
// is the redundant-but-correct backstop, plus the only path
|
||
// that protects SDK / embedded callers. PR #3 removes the HTTP
|
||
// redundancy once we're confident the engine gate covers it.
|
||
let db = if let Some(engine) = policy_engine.as_ref() {
|
||
// Unsizing coercion: Arc<PolicyEngine> → Arc<dyn PolicyChecker>.
|
||
// Needs the explicit `as` cast — Rust 2024 doesn't infer it through
|
||
// `Arc::clone`.
|
||
let checker = Arc::clone(engine) as Arc<dyn omnigraph_policy::PolicyChecker>;
|
||
db.with_policy(checker)
|
||
} else {
|
||
db
|
||
};
|
||
Self {
|
||
uri,
|
||
engine: Arc::new(db),
|
||
workload: Arc::new(workload::WorkloadController::from_env()),
|
||
bearer_tokens: Arc::from(bearer_tokens),
|
||
policy_engine,
|
||
}
|
||
}
|
||
|
||
/// Construct with a caller-provided [`workload::WorkloadController`].
|
||
/// Tests and benches use this to override per-actor caps without
|
||
/// mutating global env vars (which is unsafe in Rust 2024 once the
|
||
/// async runtime is up — `setenv` isn't thread-safe).
|
||
pub fn new_with_workload(
|
||
uri: String,
|
||
db: Omnigraph,
|
||
bearer_tokens: Vec<(String, String)>,
|
||
workload: workload::WorkloadController,
|
||
) -> Self {
|
||
let bearer_tokens: Vec<(BearerTokenHash, Arc<str>)> = bearer_tokens
|
||
.into_iter()
|
||
.map(|(actor, token)| (hash_bearer_token(&token), Arc::<str>::from(actor)))
|
||
.collect();
|
||
Self {
|
||
uri,
|
||
engine: Arc::new(db),
|
||
workload: Arc::new(workload),
|
||
bearer_tokens: Arc::from(bearer_tokens),
|
||
policy_engine: None,
|
||
}
|
||
}
|
||
|
||
/// Install a `PolicyEngine` post-construction (MR-723). Used by
|
||
/// integration tests that need to thread custom workload limits
|
||
/// alongside a permit-all policy — the existing `new_with_*` and
|
||
/// `new_with_workload` constructors don't compose. Production
|
||
/// callers should use `open_with_bearer_tokens_and_policy` which
|
||
/// installs the policy on both the HTTP state and the engine.
|
||
pub fn with_policy_engine(mut self, engine: PolicyEngine) -> Self {
|
||
self.policy_engine = Some(Arc::new(engine));
|
||
self
|
||
}
|
||
|
||
pub async fn open(uri: impl Into<String>) -> Result<Self> {
|
||
Self::open_with_bearer_token(uri, None).await
|
||
}
|
||
|
||
pub async fn open_with_bearer_token(
|
||
uri: impl Into<String>,
|
||
bearer_token: Option<String>,
|
||
) -> Result<Self> {
|
||
let bearer_tokens = normalize_bearer_token(bearer_token)
|
||
.into_iter()
|
||
.map(|token| ("default".to_string(), token))
|
||
.collect();
|
||
Self::open_with_bearer_tokens(uri, bearer_tokens).await
|
||
}
|
||
|
||
pub async fn open_with_bearer_tokens(
|
||
uri: impl Into<String>,
|
||
bearer_tokens: Vec<(String, String)>,
|
||
) -> Result<Self> {
|
||
let uri = uri.into();
|
||
let db = Omnigraph::open(&uri).await?;
|
||
Ok(Self::new_with_bearer_tokens(uri, db, bearer_tokens))
|
||
}
|
||
|
||
pub async fn open_with_bearer_tokens_and_policy(
|
||
uri: impl Into<String>,
|
||
bearer_tokens: Vec<(String, String)>,
|
||
policy_file: Option<&PathBuf>,
|
||
) -> Result<Self> {
|
||
let uri = uri.into();
|
||
let db = Omnigraph::open(&uri).await?;
|
||
let policy_engine = match policy_file {
|
||
Some(path) => Some(PolicyEngine::load(path, &uri)?),
|
||
None => None,
|
||
};
|
||
if policy_engine.is_some() && bearer_tokens.is_empty() {
|
||
bail!("policy requires at least one configured bearer token actor");
|
||
}
|
||
Ok(Self::new_with_bearer_tokens_and_policy(
|
||
uri,
|
||
db,
|
||
bearer_tokens,
|
||
policy_engine,
|
||
))
|
||
}
|
||
|
||
pub fn uri(&self) -> &str {
|
||
&self.uri
|
||
}
|
||
|
||
fn requires_bearer_auth(&self) -> bool {
|
||
!self.bearer_tokens.is_empty() || self.policy_engine.is_some()
|
||
}
|
||
|
||
fn authenticate_bearer_token(&self, provided_token: &str) -> Option<Arc<str>> {
|
||
// Hash the incoming token and compare against every stored digest in
|
||
// constant time. Iterate all entries unconditionally so total work —
|
||
// and therefore response timing — doesn't depend on which slot matches.
|
||
let provided_hash = hash_bearer_token(provided_token);
|
||
let mut matched: Option<Arc<str>> = None;
|
||
for (hash, actor) in self.bearer_tokens.iter() {
|
||
if bool::from(hash.ct_eq(&provided_hash)) && matched.is_none() {
|
||
matched = Some(Arc::clone(actor));
|
||
}
|
||
}
|
||
matched
|
||
}
|
||
|
||
fn policy_engine(&self) -> Option<&PolicyEngine> {
|
||
self.policy_engine.as_deref()
|
||
}
|
||
}
|
||
|
||
impl ApiError {
|
||
pub fn unauthorized(message: impl Into<String>) -> Self {
|
||
Self {
|
||
status: StatusCode::UNAUTHORIZED,
|
||
code: ErrorCode::Unauthorized,
|
||
message: message.into(),
|
||
merge_conflicts: Vec::new(),
|
||
manifest_conflict: None,
|
||
}
|
||
}
|
||
|
||
pub fn forbidden(message: impl Into<String>) -> Self {
|
||
Self {
|
||
status: StatusCode::FORBIDDEN,
|
||
code: ErrorCode::Forbidden,
|
||
message: message.into(),
|
||
merge_conflicts: Vec::new(),
|
||
manifest_conflict: None,
|
||
}
|
||
}
|
||
|
||
pub fn bad_request(message: impl Into<String>) -> Self {
|
||
Self {
|
||
status: StatusCode::BAD_REQUEST,
|
||
code: ErrorCode::BadRequest,
|
||
message: message.into(),
|
||
merge_conflicts: Vec::new(),
|
||
manifest_conflict: None,
|
||
}
|
||
}
|
||
|
||
pub fn not_found(message: impl Into<String>) -> Self {
|
||
Self {
|
||
status: StatusCode::NOT_FOUND,
|
||
code: ErrorCode::NotFound,
|
||
message: message.into(),
|
||
merge_conflicts: Vec::new(),
|
||
manifest_conflict: None,
|
||
}
|
||
}
|
||
|
||
pub fn conflict(message: impl Into<String>) -> Self {
|
||
Self {
|
||
status: StatusCode::CONFLICT,
|
||
code: ErrorCode::Conflict,
|
||
message: message.into(),
|
||
merge_conflicts: Vec::new(),
|
||
manifest_conflict: None,
|
||
}
|
||
}
|
||
|
||
pub fn internal(message: impl Into<String>) -> Self {
|
||
Self {
|
||
status: StatusCode::INTERNAL_SERVER_ERROR,
|
||
code: ErrorCode::Internal,
|
||
message: message.into(),
|
||
merge_conflicts: Vec::new(),
|
||
manifest_conflict: None,
|
||
}
|
||
}
|
||
|
||
/// HTTP 429 Too Many Requests — actor exceeded their per-actor
|
||
/// admission cap (count or byte budget). Clients should respect the
|
||
/// `Retry-After` header. Mapped from `RejectReason::InFlightCountExceeded`
|
||
/// and `RejectReason::ByteBudgetExceeded`.
|
||
pub fn too_many_requests(message: impl Into<String>) -> Self {
|
||
Self {
|
||
status: StatusCode::TOO_MANY_REQUESTS,
|
||
code: ErrorCode::TooManyRequests,
|
||
message: message.into(),
|
||
merge_conflicts: Vec::new(),
|
||
manifest_conflict: None,
|
||
}
|
||
}
|
||
|
||
/// Convert a `WorkloadController` rejection into the matching
|
||
/// `ApiError` variant.
|
||
pub fn from_workload_reject(reject: workload::RejectReason) -> Self {
|
||
match reject {
|
||
workload::RejectReason::InFlightCountExceeded { .. }
|
||
| workload::RejectReason::ByteBudgetExceeded { .. } => {
|
||
Self::too_many_requests(reject.to_string())
|
||
}
|
||
}
|
||
}
|
||
|
||
fn merge_conflict(conflicts: Vec<api::MergeConflictOutput>) -> Self {
|
||
Self {
|
||
status: StatusCode::CONFLICT,
|
||
code: ErrorCode::Conflict,
|
||
message: summarize_merge_conflicts(&conflicts),
|
||
merge_conflicts: conflicts,
|
||
manifest_conflict: None,
|
||
}
|
||
}
|
||
|
||
fn manifest_version_conflict(message: String, details: api::ManifestConflictOutput) -> Self {
|
||
Self {
|
||
status: StatusCode::CONFLICT,
|
||
code: ErrorCode::Conflict,
|
||
message,
|
||
merge_conflicts: Vec::new(),
|
||
manifest_conflict: Some(details),
|
||
}
|
||
}
|
||
|
||
fn from_omni(err: OmniError) -> Self {
|
||
match err {
|
||
OmniError::Compiler(err) => Self::bad_request(err.to_string()),
|
||
OmniError::DataFusion(message) => Self::bad_request(format!("query: {message}")),
|
||
OmniError::Manifest(err) => match err.kind {
|
||
ManifestErrorKind::BadRequest => Self::bad_request(err.message),
|
||
ManifestErrorKind::NotFound => Self::not_found(err.message),
|
||
ManifestErrorKind::Conflict => match err.details {
|
||
Some(ManifestConflictDetails::ExpectedVersionMismatch {
|
||
table_key,
|
||
expected,
|
||
actual,
|
||
}) => Self::manifest_version_conflict(
|
||
err.message,
|
||
api::ManifestConflictOutput {
|
||
table_key,
|
||
expected,
|
||
actual,
|
||
},
|
||
),
|
||
_ => Self::conflict(err.message),
|
||
},
|
||
ManifestErrorKind::Internal => Self::internal(err.message),
|
||
},
|
||
OmniError::MergeConflicts(conflicts) => Self::merge_conflict(
|
||
conflicts
|
||
.iter()
|
||
.map(api::MergeConflictOutput::from)
|
||
.collect(),
|
||
),
|
||
OmniError::Lance(message) => Self::internal(format!("storage: {message}")),
|
||
OmniError::Io(err) => Self::internal(format!("io: {err}")),
|
||
// Engine-layer policy enforcement (MR-722). All denials and
|
||
// evaluation failures surface here as 403. The HTTP-layer
|
||
// `authorize_request` already distinguishes 401 (missing
|
||
// bearer) from 403 (policy denial), so by the time the
|
||
// engine gate fires, the bearer is valid — any failure from
|
||
// the engine is a policy outcome, not an auth one.
|
||
OmniError::Policy(message) => Self::forbidden(message),
|
||
}
|
||
}
|
||
}
|
||
|
||
fn summarize_merge_conflicts(conflicts: &[api::MergeConflictOutput]) -> String {
|
||
if conflicts.is_empty() {
|
||
return "merge conflicts".to_string();
|
||
}
|
||
|
||
let preview: Vec<String> = conflicts
|
||
.iter()
|
||
.take(3)
|
||
.map(|conflict| match conflict.row_id.as_deref() {
|
||
Some(row_id) => format!(
|
||
"{}:{} ({})",
|
||
conflict.table_key,
|
||
row_id,
|
||
conflict.kind.as_str()
|
||
),
|
||
None => format!("{} ({})", conflict.table_key, conflict.kind.as_str()),
|
||
})
|
||
.collect();
|
||
|
||
let suffix = if conflicts.len() > preview.len() {
|
||
format!("; and {} more", conflicts.len() - preview.len())
|
||
} else {
|
||
String::new()
|
||
};
|
||
|
||
format!("merge conflicts: {}{}", preview.join("; "), suffix)
|
||
}
|
||
|
||
/// Fixed `Retry-After` delay, in seconds, attached to every 429 response.
/// Kept as a string because it is used directly as a static header value.
const RETRY_AFTER_SECONDS: &str = "60";
|
||
|
||
impl IntoResponse for ApiError {
|
||
fn into_response(self) -> Response {
|
||
let mut headers = axum::http::HeaderMap::new();
|
||
if matches!(self.code, ErrorCode::TooManyRequests) {
|
||
headers.insert(
|
||
axum::http::header::RETRY_AFTER,
|
||
axum::http::HeaderValue::from_static(RETRY_AFTER_SECONDS),
|
||
);
|
||
}
|
||
(
|
||
self.status,
|
||
headers,
|
||
Json(ErrorOutput {
|
||
error: self.message,
|
||
code: Some(self.code),
|
||
merge_conflicts: self.merge_conflicts,
|
||
manifest_conflict: self.manifest_conflict,
|
||
}),
|
||
)
|
||
.into_response()
|
||
}
|
||
}
|
||
|
||
pub fn init_tracing() {
|
||
let filter = EnvFilter::try_from_default_env().unwrap_or_else(|_| EnvFilter::new("info"));
|
||
let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
|
||
}
|
||
|
||
pub fn load_server_settings(
|
||
config_path: Option<&PathBuf>,
|
||
cli_uri: Option<String>,
|
||
cli_target: Option<String>,
|
||
cli_bind: Option<String>,
|
||
cli_allow_unauthenticated: bool,
|
||
) -> Result<ServerConfig> {
|
||
let config = load_config(config_path)?;
|
||
let uri =
|
||
config.resolve_target_uri(cli_uri, cli_target.as_deref(), config.server_graph_name())?;
|
||
let bind = cli_bind.unwrap_or_else(|| config.server_bind().to_string());
|
||
let policy_file = config.resolve_policy_file();
|
||
// Either `--unauthenticated` or `OMNIGRAPH_UNAUTHENTICATED=1` flips
|
||
// this. Treat any non-empty, non-"0"/"false" string as truthy —
|
||
// standard 12-factor "any value is true" reading of the env var.
|
||
let env_unauth = std::env::var("OMNIGRAPH_UNAUTHENTICATED")
|
||
.ok()
|
||
.map(|v| {
|
||
let trimmed = v.trim();
|
||
!trimmed.is_empty() && trimmed != "0" && !trimmed.eq_ignore_ascii_case("false")
|
||
})
|
||
.unwrap_or(false);
|
||
let allow_unauthenticated = cli_allow_unauthenticated || env_unauth;
|
||
|
||
Ok(ServerConfig {
|
||
uri,
|
||
bind,
|
||
policy_file,
|
||
allow_unauthenticated,
|
||
})
|
||
}
|
||
|
||
/// Auth posture of the server (MR-723), derived at startup from whether
/// bearer tokens are configured and whether a policy file is configured.
///
/// * **Open** — no tokens and no policy. Only reachable with the explicit
///   `allow_unauthenticated` opt-in; without it `serve()` refuses to start.
///   This is a "trust the network" dev mode.
/// * **DefaultDeny** — tokens but no policy file. A valid bearer token is
///   required, and an authenticated actor may perform `Read` only; every
///   other action is answered with 403. This closes the "configured tokens
///   but forgot the policy file" trap.
/// * **PolicyEnabled** — a policy file is configured, so Cedar evaluates
///   every authenticated request. Tokens are normally configured as well.
///   The classification itself accepts a policy without tokens, but
///   `AppState::open_with_bearer_tokens_and_policy` (the constructor `serve()`
///   uses) rejects that combination at startup.
#[derive(Debug, Clone, Copy, Eq, PartialEq)]
pub enum ServerRuntimeState {
    Open,
    DefaultDeny,
    PolicyEnabled,
}
|
||
|
||
/// Compute the [`ServerRuntimeState`] from the configured inputs.
|
||
/// Pulled out as a pure function so the 3-state matrix is unit-testable
|
||
/// without standing up the full server.
|
||
pub fn classify_server_runtime_state(
|
||
has_tokens: bool,
|
||
has_policy: bool,
|
||
allow_unauthenticated: bool,
|
||
) -> Result<ServerRuntimeState> {
|
||
match (has_tokens, has_policy, allow_unauthenticated) {
|
||
(false, false, false) => bail!(
|
||
"server has no bearer tokens and no policy file configured. This is a fully \
|
||
open server — pass `--unauthenticated` (or set OMNIGRAPH_UNAUTHENTICATED=1) \
|
||
if you actually want that, otherwise configure bearer tokens (see \
|
||
docs/user/server.md) and/or `policy.file` in omnigraph.yaml."
|
||
),
|
||
(false, false, true) => Ok(ServerRuntimeState::Open),
|
||
(true, false, _) => Ok(ServerRuntimeState::DefaultDeny),
|
||
(_, true, _) => Ok(ServerRuntimeState::PolicyEnabled),
|
||
}
|
||
}
|
||
|
||
pub fn build_app(state: AppState) -> Router {
|
||
let protected = Router::new()
|
||
.route("/snapshot", get(server_snapshot))
|
||
.route("/export", post(server_export))
|
||
.route("/read", post(server_read))
|
||
.route("/change", post(server_change))
|
||
.route("/schema", get(server_schema_get))
|
||
.route("/schema/apply", post(server_schema_apply))
|
||
.route(
|
||
"/ingest",
|
||
post(server_ingest).layer(DefaultBodyLimit::max(INGEST_REQUEST_BODY_LIMIT_BYTES)),
|
||
)
|
||
.route(
|
||
"/branches",
|
||
get(server_branch_list).post(server_branch_create),
|
||
)
|
||
.route("/branches/{branch}", delete(server_branch_delete))
|
||
.route("/branches/merge", post(server_branch_merge))
|
||
.route("/commits", get(server_commit_list))
|
||
.route("/commits/{commit_id}", get(server_commit_show))
|
||
.route_layer(middleware::from_fn_with_state(
|
||
state.clone(),
|
||
require_bearer_auth,
|
||
));
|
||
|
||
Router::new()
|
||
.route("/healthz", get(server_health))
|
||
.route("/openapi.json", get(server_openapi))
|
||
.merge(protected)
|
||
.layer(DefaultBodyLimit::max(DEFAULT_REQUEST_BODY_LIMIT_BYTES))
|
||
.layer(TraceLayer::new_for_http())
|
||
.with_state(state)
|
||
}
|
||
|
||
pub async fn serve(config: ServerConfig) -> Result<()> {
|
||
let token_source = resolve_token_source().await?;
|
||
info!(source = token_source.name(), "loaded bearer token source");
|
||
let tokens = token_source.load().await?;
|
||
let runtime_state = classify_server_runtime_state(
|
||
!tokens.is_empty(),
|
||
config.policy_file.is_some(),
|
||
config.allow_unauthenticated,
|
||
)?;
|
||
match runtime_state {
|
||
ServerRuntimeState::Open => warn!(
|
||
"running with --unauthenticated: no bearer tokens, no policy file, all \
|
||
requests permitted. This is for local dev only — do not expose to a \
|
||
network you don't fully trust."
|
||
),
|
||
ServerRuntimeState::DefaultDeny => warn!(
|
||
"bearer tokens are configured but no policy file is set — running in \
|
||
default-deny mode (only `read` actions are permitted for authenticated \
|
||
actors). Configure `policy.file` in omnigraph.yaml to enable Cedar rules."
|
||
),
|
||
ServerRuntimeState::PolicyEnabled => {}
|
||
}
|
||
let state = AppState::open_with_bearer_tokens_and_policy(
|
||
config.uri.clone(),
|
||
tokens,
|
||
config.policy_file.as_ref(),
|
||
)
|
||
.await?;
|
||
let listener = TcpListener::bind(&config.bind).await?;
|
||
info!(uri = %config.uri, bind = %config.bind, "serving omnigraph");
|
||
axum::serve(listener, build_app(state))
|
||
.with_graceful_shutdown(shutdown_signal())
|
||
.await?;
|
||
Ok(())
|
||
}
|
||
|
||
async fn shutdown_signal() {
|
||
if let Err(err) = tokio::signal::ctrl_c().await {
|
||
error!(error = %err, "failed to install ctrl-c handler");
|
||
return;
|
||
}
|
||
info!("shutdown signal received");
|
||
}
|
||
|
||
#[utoipa::path(
|
||
get,
|
||
path = "/healthz",
|
||
tag = "health",
|
||
operation_id = "health",
|
||
responses(
|
||
(status = 200, description = "Server is healthy", body = HealthOutput),
|
||
),
|
||
)]
|
||
/// Liveness probe.
|
||
///
|
||
/// Returns server status and version. Unauthenticated; safe to call from any
|
||
/// caller. Use this to confirm the server is reachable before invoking other
|
||
/// endpoints.
|
||
async fn server_health() -> Json<HealthOutput> {
|
||
Json(HealthOutput {
|
||
status: "ok".to_string(),
|
||
version: SERVER_VERSION.to_string(),
|
||
source_version: SERVER_SOURCE_VERSION.map(str::to_string),
|
||
})
|
||
}
|
||
|
||
async fn server_openapi(State(state): State<AppState>) -> Json<utoipa::openapi::OpenApi> {
|
||
let mut doc = ApiDoc::openapi();
|
||
if !state.requires_bearer_auth() {
|
||
strip_security(&mut doc);
|
||
}
|
||
Json(doc)
|
||
}
|
||
|
||
fn strip_security(doc: &mut utoipa::openapi::OpenApi) {
|
||
if let Some(components) = doc.components.as_mut() {
|
||
components.security_schemes.clear();
|
||
}
|
||
for path_item in doc.paths.paths.values_mut() {
|
||
for op in [
|
||
path_item.get.as_mut(),
|
||
path_item.post.as_mut(),
|
||
path_item.put.as_mut(),
|
||
path_item.delete.as_mut(),
|
||
path_item.options.as_mut(),
|
||
path_item.head.as_mut(),
|
||
path_item.patch.as_mut(),
|
||
path_item.trace.as_mut(),
|
||
]
|
||
.into_iter()
|
||
.flatten()
|
||
{
|
||
op.security = None;
|
||
}
|
||
}
|
||
}
|
||
|
||
async fn require_bearer_auth(
|
||
State(state): State<AppState>,
|
||
mut request: Request,
|
||
next: Next,
|
||
) -> std::result::Result<Response, ApiError> {
|
||
if !state.requires_bearer_auth() {
|
||
return Ok(next.run(request).await);
|
||
}
|
||
|
||
let Some(header) = request
|
||
.headers()
|
||
.get(AUTHORIZATION)
|
||
.and_then(|value| value.to_str().ok())
|
||
else {
|
||
return Err(ApiError::unauthorized("missing bearer token"));
|
||
};
|
||
|
||
let Some(provided_token) = header.strip_prefix("Bearer ") else {
|
||
return Err(ApiError::unauthorized("missing bearer token"));
|
||
};
|
||
|
||
let Some(actor) = state.authenticate_bearer_token(provided_token) else {
|
||
return Err(ApiError::unauthorized("invalid bearer token"));
|
||
};
|
||
request.extensions_mut().insert(AuthenticatedActor(actor));
|
||
|
||
Ok(next.run(request).await)
|
||
}
|
||
|
||
fn log_policy_decision(actor_id: &str, request: &PolicyRequest, decision: &PolicyDecision) {
|
||
info!(
|
||
actor_id = actor_id,
|
||
action = %request.action,
|
||
branch = request.branch.as_deref().unwrap_or(""),
|
||
target_branch = request.target_branch.as_deref().unwrap_or(""),
|
||
allowed = decision.allowed,
|
||
matched_rule_id = decision.matched_rule_id.as_deref().unwrap_or(""),
|
||
"policy decision"
|
||
);
|
||
}
|
||
|
||
fn authorize_request(
|
||
state: &AppState,
|
||
actor: Option<&AuthenticatedActor>,
|
||
mut request: PolicyRequest,
|
||
) -> std::result::Result<(), ApiError> {
|
||
let Some(engine) = state.policy_engine() else {
|
||
// MR-723 default-deny path. We're here when no PolicyEngine is
|
||
// installed. Two startup-validated shapes can reach this:
|
||
//
|
||
// * **Open mode** (`--unauthenticated`): no tokens, no policy.
|
||
// `require_bearer_auth` short-circuits before this is called,
|
||
// but defense in depth — if a future change makes the
|
||
// middleware call here for an unauthenticated request, we
|
||
// want every action to remain Ok rather than 403. The
|
||
// operator opted in.
|
||
// * **DefaultDeny mode**: tokens configured but no policy. The
|
||
// request went through bearer auth, so `actor` is Some and
|
||
// identifies a known actor. Only `Read` is permitted; every
|
||
// other action returns 403. This closes the "configured auth
|
||
// but forgot the policy file" trap from MR-723.
|
||
if actor.is_some() && request.action != PolicyAction::Read {
|
||
return Err(ApiError::forbidden(
|
||
"server runs in default-deny mode (bearer tokens configured but no \
|
||
policy file). Only `read` actions are permitted; configure \
|
||
`policy.file` in omnigraph.yaml to enable other actions.",
|
||
));
|
||
}
|
||
return Ok(());
|
||
};
|
||
let Some(actor) = actor else {
|
||
return Err(ApiError::unauthorized("missing bearer token"));
|
||
};
|
||
// SECURITY INVARIANT (MR-731): actor identity comes from the matched
|
||
// bearer token, never from a client-supplied request header, query
|
||
// parameter, or body field. This line is the single chokepoint where
|
||
// the authoritative actor (resolved from the bearer match by
|
||
// `require_bearer_auth`) overwrites whatever the handler put in the
|
||
// PolicyRequest. Removing or weakening it lets clients spoof identity —
|
||
// exactly the Supabase RLS footgun ("trusting raw_user_meta_data is
|
||
// asking the attacker if they're an admin"). The principle is codified
|
||
// in `docs/dev/invariants.md` Hard Invariant 11 ("clients cannot set
|
||
// actor identity directly") and pinned by the regression test
|
||
// `actor_id_resolves_from_bearer_token_ignoring_client_supplied_headers`
|
||
// in `crates/omnigraph-server/tests/server.rs`.
|
||
//
|
||
// Side effect: also prevents an empty-string default at any handler
|
||
// call site from ever reaching the engine as a policy subject.
|
||
request.actor_id = actor.as_str().to_string();
|
||
let decision = engine
|
||
.authorize(&request)
|
||
.map_err(|err| ApiError::internal(format!("policy: {err}")))?;
|
||
log_policy_decision(actor.as_str(), &request, &decision);
|
||
if decision.allowed {
|
||
Ok(())
|
||
} else {
|
||
Err(ApiError::forbidden(decision.message))
|
||
}
|
||
}
|
||
|
||
#[utoipa::path(
|
||
get,
|
||
path = "/snapshot",
|
||
tag = "snapshots",
|
||
operation_id = "getSnapshot",
|
||
params(SnapshotQuery),
|
||
responses(
|
||
(status = 200, description = "Database snapshot", body = api::SnapshotOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Read the current snapshot of a branch.
|
||
///
|
||
/// Returns the manifest version plus per-table metadata (path, version, row
|
||
/// count) for every table on the branch. Defaults to `main` when `branch` is
|
||
/// omitted. Read-only.
|
||
async fn server_snapshot(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Query(query): Query<SnapshotQuery>,
|
||
) -> std::result::Result<Json<api::SnapshotOutput>, ApiError> {
|
||
let branch = query.branch.unwrap_or_else(|| "main".to_string());
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| actor.as_str().to_string())
|
||
.unwrap_or_default(),
|
||
action: PolicyAction::Read,
|
||
branch: Some(branch.clone()),
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
let snapshot = {
|
||
let db = &state.engine;
|
||
db.snapshot_of(ReadTarget::branch(branch.as_str()))
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
};
|
||
Ok(Json(snapshot_payload(&branch, &snapshot)))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
post,
|
||
path = "/read",
|
||
tag = "queries",
|
||
operation_id = "read",
|
||
request_body = ReadRequest,
|
||
responses(
|
||
(status = 200, description = "Query results", body = ReadOutput),
|
||
(status = 400, description = "Bad request", body = ErrorOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Execute a GQ read query.
|
||
///
|
||
/// Runs the query in `query_source` against either a branch or a frozen
|
||
/// snapshot (mutually exclusive). When `query_source` defines multiple named
|
||
/// queries, pick one with `query_name`. `params` is a JSON object whose keys
|
||
/// match the parameters declared by the query. Returns rows as a JSON array
|
||
/// plus a `columns` list. Read-only.
|
||
async fn server_read(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Json(request): Json<ReadRequest>,
|
||
) -> std::result::Result<Json<ReadOutput>, ApiError> {
|
||
if request.branch.is_some() && request.snapshot.is_some() {
|
||
return Err(ApiError::bad_request(
|
||
"read request may specify branch or snapshot, not both",
|
||
));
|
||
}
|
||
|
||
let target = read_target_from_request(request.branch, request.snapshot);
|
||
let policy_branch = match &target {
|
||
ReadTarget::Branch(branch) => Some(branch.clone()),
|
||
ReadTarget::Snapshot(_) if state.policy_engine().is_some() && actor.is_some() => {
|
||
let db = &state.engine;
|
||
db.resolved_branch_of(target.clone())
|
||
.await
|
||
.map(|branch| branch.or_else(|| Some("main".to_string())))
|
||
.map_err(ApiError::from_omni)?
|
||
}
|
||
ReadTarget::Snapshot(_) => None,
|
||
};
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| actor.as_str().to_string())
|
||
.unwrap_or_default(),
|
||
action: PolicyAction::Read,
|
||
branch: policy_branch,
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
let (selected_name, query_params) =
|
||
select_named_query(&request.query_source, request.query_name.as_deref())
|
||
.map_err(|err| ApiError::bad_request(err.to_string()))?;
|
||
let params = query_params_from_json(&query_params, request.params.as_ref())
|
||
.map_err(|err| ApiError::bad_request(err.to_string()))?;
|
||
|
||
let result = {
|
||
let db = &state.engine;
|
||
db.query(
|
||
target.clone(),
|
||
&request.query_source,
|
||
&selected_name,
|
||
¶ms,
|
||
)
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
};
|
||
Ok(Json(api::read_output(selected_name, &target, result)))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
post,
|
||
path = "/export",
|
||
tag = "queries",
|
||
operation_id = "export",
|
||
request_body = ExportRequest,
|
||
responses(
|
||
(status = 200, description = "Exported data as NDJSON", content_type = "application/x-ndjson"),
|
||
(status = 400, description = "Bad request", body = ErrorOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Stream the contents of a branch as NDJSON.
|
||
///
|
||
/// Emits one JSON object per line (`application/x-ndjson`). Filter with
|
||
/// `type_names` (node/edge type names) and/or `table_keys`; both empty
|
||
/// streams the entire branch. Suitable for large exports — the response is
|
||
/// streamed, not buffered. Read-only.
|
||
async fn server_export(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Json(request): Json<ExportRequest>,
|
||
) -> std::result::Result<Response, ApiError> {
|
||
let branch = request.branch.unwrap_or_else(|| "main".to_string());
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| actor.as_str().to_string())
|
||
.unwrap_or_default(),
|
||
action: PolicyAction::Export,
|
||
branch: Some(branch.clone()),
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
let engine = Arc::clone(&state.engine);
|
||
let type_names = request.type_names.clone();
|
||
let table_keys = request.table_keys.clone();
|
||
let (tx, rx) = mpsc::unbounded_channel::<std::result::Result<Bytes, io::Error>>();
|
||
tokio::spawn(async move {
|
||
let result = {
|
||
let mut writer = ExportStreamWriter { sender: tx.clone() };
|
||
engine
|
||
.export_jsonl_to_writer(&branch, &type_names, &table_keys, &mut writer)
|
||
.await
|
||
};
|
||
if let Err(err) = result {
|
||
let _ = tx.send(Err(io::Error::other(err.to_string())));
|
||
}
|
||
});
|
||
let body = Body::from_stream(stream::unfold(rx, |mut rx| async move {
|
||
rx.recv().await.map(|item| (item, rx))
|
||
}));
|
||
Ok((
|
||
StatusCode::OK,
|
||
[(CONTENT_TYPE, "application/x-ndjson; charset=utf-8")],
|
||
body,
|
||
)
|
||
.into_response())
|
||
}
|
||
|
||
#[utoipa::path(
|
||
post,
|
||
path = "/change",
|
||
tag = "mutations",
|
||
operation_id = "change",
|
||
request_body = ChangeRequest,
|
||
responses(
|
||
(status = 200, description = "Mutation results", body = ChangeOutput),
|
||
(status = 400, description = "Bad request", body = ErrorOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
(status = 409, description = "Merge conflict", body = ErrorOutput),
|
||
(status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Apply a GQ mutation to a branch.
|
||
///
|
||
/// Writes to the named `branch` (defaults to `main`). Mutations are atomic
|
||
/// per call and produce a new commit. Returns counts of nodes and edges
|
||
/// affected. **Destructive**: on success the branch is updated; rejected
|
||
/// mutations may still acquire locks briefly. Returns 409 on merge conflict.
|
||
async fn server_change(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Json(request): Json<ChangeRequest>,
|
||
) -> std::result::Result<Json<ChangeOutput>, ApiError> {
|
||
let branch = request.branch.unwrap_or_else(|| "main".to_string());
|
||
let actor_arc = actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| Arc::clone(&actor.0))
|
||
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
|
||
let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str());
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor_id.map(str::to_string).unwrap_or_default(),
|
||
action: PolicyAction::Change,
|
||
branch: Some(branch.clone()),
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
// Per-actor admission: bound concurrent in-flight mutations and
|
||
// estimated bytes per actor. Cedar runs FIRST so denied requests
|
||
// don't consume admission slots. Estimate uses the request body
|
||
// size as a coarse proxy; engine memory pressure can run higher.
|
||
let est_bytes = request.query_source.len() as u64
|
||
+ request
|
||
.params
|
||
.as_ref()
|
||
.map(|p| p.to_string().len() as u64)
|
||
.unwrap_or(0);
|
||
let _admission = state
|
||
.workload
|
||
.try_admit(&actor_arc, est_bytes)
|
||
.map_err(ApiError::from_workload_reject)?;
|
||
let (selected_name, query_params) =
|
||
select_named_query(&request.query_source, request.query_name.as_deref())
|
||
.map_err(|err| ApiError::bad_request(err.to_string()))?;
|
||
let params = query_params_from_json(&query_params, request.params.as_ref())
|
||
.map_err(|err| ApiError::bad_request(err.to_string()))?;
|
||
|
||
let result = {
|
||
let db = &state.engine;
|
||
db.mutate_as(
|
||
&branch,
|
||
&request.query_source,
|
||
&selected_name,
|
||
¶ms,
|
||
actor_id,
|
||
)
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
};
|
||
Ok(Json(ChangeOutput {
|
||
branch,
|
||
query_name: selected_name,
|
||
affected_nodes: result.affected_nodes,
|
||
affected_edges: result.affected_edges,
|
||
actor_id: actor_id.map(str::to_string),
|
||
}))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
get,
|
||
path = "/schema",
|
||
tag = "schema",
|
||
operation_id = "getSchema",
|
||
responses(
|
||
(status = 200, description = "Current schema source", body = SchemaOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Read the current schema source.
|
||
///
|
||
/// Returns the project's schema as a single string in `.pg` source form.
|
||
/// Useful for clients that want to introspect available types and tables
|
||
/// before constructing GQ queries. Read-only.
|
||
async fn server_schema_get(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
) -> std::result::Result<Json<SchemaOutput>, ApiError> {
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| actor.as_str().to_string())
|
||
.unwrap_or_default(),
|
||
action: PolicyAction::Read,
|
||
branch: None,
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
let schema_source = {
|
||
let db = &state.engine;
|
||
db.schema_source().to_string()
|
||
};
|
||
Ok(Json(SchemaOutput { schema_source }))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
post,
|
||
path = "/schema/apply",
|
||
tag = "mutations",
|
||
operation_id = "applySchema",
|
||
request_body = SchemaApplyRequest,
|
||
responses(
|
||
(status = 200, description = "Schema apply results", body = SchemaApplyOutput),
|
||
(status = 400, description = "Bad request", body = ErrorOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
(status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Apply a schema migration.
|
||
///
|
||
/// Diffs `schema_source` against the current schema and applies the resulting
|
||
/// migration steps (add/drop type, add/drop column, etc.). **Destructive**:
|
||
/// some steps drop data. Returns the list of steps applied; if `applied` is
|
||
/// false the diff was unsupported and no changes were made.
|
||
async fn server_schema_apply(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Json(request): Json<SchemaApplyRequest>,
|
||
) -> std::result::Result<Json<SchemaApplyOutput>, ApiError> {
|
||
let actor_arc = actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| Arc::clone(&actor.0))
|
||
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
|
||
let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str());
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor_id.map(str::to_string).unwrap_or_default(),
|
||
action: PolicyAction::SchemaApply,
|
||
branch: None,
|
||
target_branch: Some("main".to_string()),
|
||
},
|
||
)?;
|
||
let est_bytes = request.schema_source.len() as u64;
|
||
let _admission = state
|
||
.workload
|
||
.try_admit(&actor_arc, est_bytes)
|
||
.map_err(ApiError::from_workload_reject)?;
|
||
let result = {
|
||
let db = &state.engine;
|
||
// Engine-layer policy enforcement (MR-722): pass the resolved
|
||
// actor through so apply_schema_as can call enforce() with the
|
||
// authoritative identity. With a policy installed in AppState,
|
||
// engine-side enforcement re-checks the same decision the
|
||
// HTTP-layer authorize_request just made above. PR #3 collapses
|
||
// the redundancy.
|
||
db.apply_schema_as(
|
||
&request.schema_source,
|
||
omnigraph::db::SchemaApplyOptions {
|
||
allow_data_loss: request.allow_data_loss,
|
||
},
|
||
actor_id,
|
||
)
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
};
|
||
Ok(Json(schema_apply_output(state.uri(), result)))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
post,
|
||
path = "/ingest",
|
||
tag = "mutations",
|
||
operation_id = "ingest",
|
||
request_body = IngestRequest,
|
||
responses(
|
||
(status = 200, description = "Ingest results", body = IngestOutput),
|
||
(status = 400, description = "Bad request", body = ErrorOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
(status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Bulk-ingest NDJSON data into a branch.
|
||
///
|
||
/// `data` is NDJSON with one record per line. `mode` controls behavior on
|
||
/// existing rows: `merge` upserts by id (default), `append` blindly inserts,
|
||
/// `overwrite` replaces table contents. If `branch` does not exist it is
|
||
/// created from `from` (defaults to `main`). **Destructive** when `mode` is
|
||
/// `overwrite` or when ingest produces conflicting writes.
|
||
async fn server_ingest(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Json(request): Json<IngestRequest>,
|
||
) -> std::result::Result<Json<IngestOutput>, ApiError> {
|
||
let branch = request.branch.unwrap_or_else(|| "main".to_string());
|
||
let from = request.from.unwrap_or_else(|| "main".to_string());
|
||
let mode = request.mode.unwrap_or(omnigraph::loader::LoadMode::Merge);
|
||
let actor_arc = actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| Arc::clone(&actor.0))
|
||
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
|
||
let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str());
|
||
|
||
let branch_exists = {
|
||
let db = &state.engine;
|
||
db.branch_list()
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
.into_iter()
|
||
.any(|name| name == branch)
|
||
};
|
||
|
||
if !branch_exists {
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor_id.map(str::to_string).unwrap_or_default(),
|
||
action: PolicyAction::BranchCreate,
|
||
branch: Some(from.clone()),
|
||
target_branch: Some(branch.clone()),
|
||
},
|
||
)?;
|
||
}
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor_id.map(str::to_string).unwrap_or_default(),
|
||
action: PolicyAction::Change,
|
||
branch: Some(branch.clone()),
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
let est_bytes = request.data.len() as u64;
|
||
let _admission = state
|
||
.workload
|
||
.try_admit(&actor_arc, est_bytes)
|
||
.map_err(ApiError::from_workload_reject)?;
|
||
|
||
let result = {
|
||
let db = &state.engine;
|
||
db.ingest_as(&branch, Some(&from), &request.data, mode, actor_id)
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
};
|
||
|
||
Ok(Json(ingest_output(
|
||
state.uri(),
|
||
&result,
|
||
actor_id.map(str::to_string),
|
||
)))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
get,
|
||
path = "/branches",
|
||
tag = "branches",
|
||
operation_id = "listBranches",
|
||
responses(
|
||
(status = 200, description = "List of branches", body = BranchListOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// List all branches.
|
||
///
|
||
/// Returns branch names sorted alphabetically. Read-only.
|
||
async fn server_branch_list(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
) -> std::result::Result<Json<BranchListOutput>, ApiError> {
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| actor.as_str().to_string())
|
||
.unwrap_or_default(),
|
||
action: PolicyAction::Read,
|
||
branch: None,
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
let mut branches = {
|
||
let db = &state.engine;
|
||
db.branch_list().await.map_err(ApiError::from_omni)?
|
||
};
|
||
branches.sort();
|
||
Ok(Json(BranchListOutput { branches }))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
post,
|
||
path = "/branches",
|
||
tag = "branches",
|
||
operation_id = "createBranch",
|
||
request_body = BranchCreateRequest,
|
||
responses(
|
||
(status = 200, description = "Branch created", body = BranchCreateOutput),
|
||
(status = 400, description = "Bad request", body = ErrorOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
(status = 409, description = "Branch already exists", body = ErrorOutput),
|
||
(status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Create a new branch.
|
||
///
|
||
/// Forks `name` off of `from` (defaults to `main`). The new branch shares
|
||
/// table data with its parent until it is mutated. Returns 409 if `name`
|
||
/// already exists.
|
||
async fn server_branch_create(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Json(request): Json<BranchCreateRequest>,
|
||
) -> std::result::Result<Json<BranchCreateOutput>, ApiError> {
|
||
let from = request.from.unwrap_or_else(|| "main".to_string());
|
||
let actor_arc = actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| Arc::clone(&actor.0))
|
||
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| actor.as_str().to_string())
|
||
.unwrap_or_default(),
|
||
action: PolicyAction::BranchCreate,
|
||
branch: Some(from.clone()),
|
||
target_branch: Some(request.name.clone()),
|
||
},
|
||
)?;
|
||
// Branch metadata only — small constant bytes estimate. The Lance
|
||
// shallow-clone work is bounded by the parent's manifest size, not
|
||
// the request body.
|
||
let _admission = state
|
||
.workload
|
||
.try_admit(&actor_arc, 256)
|
||
.map_err(ApiError::from_workload_reject)?;
|
||
{
|
||
let db = &state.engine;
|
||
db.branch_create_from_as(
|
||
ReadTarget::branch(&from),
|
||
&request.name,
|
||
actor.as_ref().map(|Extension(a)| a.as_str()),
|
||
)
|
||
.await
|
||
.map_err(ApiError::from_omni)?;
|
||
}
|
||
Ok(Json(BranchCreateOutput {
|
||
uri: state.uri().to_string(),
|
||
from,
|
||
name: request.name,
|
||
actor_id: actor.map(|Extension(actor)| actor.as_str().to_string()),
|
||
}))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
delete,
|
||
path = "/branches/{branch}",
|
||
tag = "branches",
|
||
operation_id = "deleteBranch",
|
||
params(
|
||
("branch" = String, Path, description = "Branch name to delete"),
|
||
),
|
||
responses(
|
||
(status = 200, description = "Branch deleted", body = BranchDeleteOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
(status = 404, description = "Branch not found", body = ErrorOutput),
|
||
(status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Delete a branch.
|
||
///
|
||
/// **Irreversible.** Removes the branch pointer; commits remain reachable
|
||
/// only if referenced by another branch. Returns 404 if the branch does not
|
||
/// exist.
|
||
async fn server_branch_delete(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Path(branch): Path<String>,
|
||
) -> std::result::Result<Json<BranchDeleteOutput>, ApiError> {
|
||
let actor_arc = actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| Arc::clone(&actor.0))
|
||
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
|
||
let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str());
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor_id.map(str::to_string).unwrap_or_default(),
|
||
action: PolicyAction::BranchDelete,
|
||
branch: None,
|
||
target_branch: Some(branch.clone()),
|
||
},
|
||
)?;
|
||
// Metadata-only manifest tombstone — small constant estimate.
|
||
let _admission = state
|
||
.workload
|
||
.try_admit(&actor_arc, 256)
|
||
.map_err(ApiError::from_workload_reject)?;
|
||
{
|
||
let db = &state.engine;
|
||
db.branch_delete_as(&branch, actor_id)
|
||
.await
|
||
.map_err(ApiError::from_omni)?;
|
||
}
|
||
Ok(Json(BranchDeleteOutput {
|
||
uri: state.uri().to_string(),
|
||
name: branch,
|
||
actor_id: actor_id.map(str::to_string),
|
||
}))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
post,
|
||
path = "/branches/merge",
|
||
tag = "branches",
|
||
operation_id = "mergeBranches",
|
||
request_body = BranchMergeRequest,
|
||
responses(
|
||
(status = 200, description = "Branches merged", body = BranchMergeOutput),
|
||
(status = 400, description = "Bad request", body = ErrorOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
(status = 409, description = "Merge conflict", body = ErrorOutput),
|
||
(status = 429, description = "Per-actor admission cap exceeded; honor `Retry-After` header", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Merge one branch into another.
|
||
///
|
||
/// Merges `source` into `target` (defaults to `main`). Outcome is one of
|
||
/// `already_up_to_date`, `fast_forward`, or `merged`. Returns 409 with the
|
||
/// list of conflicts if the merge cannot be completed; the target is left
|
||
/// unchanged in that case. **Destructive** to `target` on success.
|
||
async fn server_branch_merge(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Json(request): Json<BranchMergeRequest>,
|
||
) -> std::result::Result<Json<BranchMergeOutput>, ApiError> {
|
||
let target = request.target.unwrap_or_else(|| "main".to_string());
|
||
let actor_arc = actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| Arc::clone(&actor.0))
|
||
.unwrap_or_else(|| Arc::<str>::from("anonymous"));
|
||
let actor_id = actor.as_ref().map(|Extension(actor)| actor.as_str());
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor_id.map(str::to_string).unwrap_or_default(),
|
||
action: PolicyAction::BranchMerge,
|
||
branch: Some(request.source.clone()),
|
||
target_branch: Some(target.clone()),
|
||
},
|
||
)?;
|
||
// Merge body is small JSON; the heavy work is in the engine but is
|
||
// bounded per-(table, branch) by the writer queue. Small constant
|
||
// estimate suffices for the actor in-flight count.
|
||
let _admission = state
|
||
.workload
|
||
.try_admit(&actor_arc, 256)
|
||
.map_err(ApiError::from_workload_reject)?;
|
||
let outcome = {
|
||
let db = &state.engine;
|
||
db.branch_merge_as(&request.source, &target, actor_id)
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
};
|
||
Ok(Json(BranchMergeOutput {
|
||
source: request.source,
|
||
target,
|
||
outcome: outcome.into(),
|
||
actor_id: actor_id.map(str::to_string),
|
||
}))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
get,
|
||
path = "/commits",
|
||
tag = "commits",
|
||
operation_id = "listCommits",
|
||
params(CommitListQuery),
|
||
responses(
|
||
(status = 200, description = "List of commits", body = CommitListOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// List commits.
|
||
///
|
||
/// Filter by `branch` to get the commits on a single branch (most recent
|
||
/// first); omit to list across all branches. Read-only.
|
||
async fn server_commit_list(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Query(query): Query<CommitListQuery>,
|
||
) -> std::result::Result<Json<CommitListOutput>, ApiError> {
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| actor.as_str().to_string())
|
||
.unwrap_or_default(),
|
||
action: PolicyAction::Read,
|
||
branch: query.branch.clone(),
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
let commits = {
|
||
let db = &state.engine;
|
||
db.list_commits(query.branch.as_deref())
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
};
|
||
Ok(Json(CommitListOutput {
|
||
commits: commits.iter().map(api::commit_output).collect(),
|
||
}))
|
||
}
|
||
|
||
#[utoipa::path(
|
||
get,
|
||
path = "/commits/{commit_id}",
|
||
tag = "commits",
|
||
operation_id = "getCommit",
|
||
params(
|
||
("commit_id" = String, Path, description = "Commit identifier"),
|
||
),
|
||
responses(
|
||
(status = 200, description = "Commit details", body = api::CommitOutput),
|
||
(status = 401, description = "Unauthorized", body = ErrorOutput),
|
||
(status = 403, description = "Forbidden", body = ErrorOutput),
|
||
(status = 404, description = "Commit not found", body = ErrorOutput),
|
||
),
|
||
security(("bearer_token" = [])),
|
||
)]
|
||
/// Get a single commit.
|
||
///
|
||
/// Returns the commit's manifest version, parent commit(s), and creation
|
||
/// metadata. Read-only.
|
||
async fn server_commit_show(
|
||
State(state): State<AppState>,
|
||
actor: Option<Extension<AuthenticatedActor>>,
|
||
Path(commit_id): Path<String>,
|
||
) -> std::result::Result<Json<api::CommitOutput>, ApiError> {
|
||
authorize_request(
|
||
&state,
|
||
actor.as_ref().map(|Extension(actor)| actor),
|
||
PolicyRequest {
|
||
actor_id: actor
|
||
.as_ref()
|
||
.map(|Extension(actor)| actor.as_str().to_string())
|
||
.unwrap_or_default(),
|
||
action: PolicyAction::Read,
|
||
branch: None,
|
||
target_branch: None,
|
||
},
|
||
)?;
|
||
let commit = {
|
||
let db = &state.engine;
|
||
db.get_commit(&commit_id)
|
||
.await
|
||
.map_err(ApiError::from_omni)?
|
||
};
|
||
Ok(Json(api::commit_output(&commit)))
|
||
}
|
||
|
||
fn read_target_from_request(branch: Option<String>, snapshot: Option<String>) -> ReadTarget {
|
||
if let Some(snapshot) = snapshot {
|
||
ReadTarget::snapshot(omnigraph::db::SnapshotId::new(snapshot))
|
||
} else {
|
||
ReadTarget::branch(branch.unwrap_or_else(|| "main".to_string()))
|
||
}
|
||
}
|
||
|
||
fn select_named_query(
|
||
query_source: &str,
|
||
requested_name: Option<&str>,
|
||
) -> Result<(String, Vec<omnigraph_compiler::query::ast::Param>)> {
|
||
let parsed = parse_query(query_source)?;
|
||
let query = if let Some(name) = requested_name {
|
||
parsed
|
||
.queries
|
||
.into_iter()
|
||
.find(|query| query.name == name)
|
||
.ok_or_else(|| color_eyre::eyre::eyre!("query '{}' not found", name))?
|
||
} else if parsed.queries.len() == 1 {
|
||
parsed.queries.into_iter().next().unwrap()
|
||
} else {
|
||
bail!("query file contains multiple queries; pass --name");
|
||
};
|
||
|
||
Ok((query.name, query.params))
|
||
}
|
||
|
||
fn query_params_from_json(
|
||
query_params: &[omnigraph_compiler::query::ast::Param],
|
||
params_json: Option<&Value>,
|
||
) -> Result<ParamMap> {
|
||
json_params_to_param_map(params_json, query_params, JsonParamMode::Standard)
|
||
.map_err(|err| color_eyre::eyre::eyre!(err.to_string()))
|
||
}
|
||
|
||
/// Trims surrounding whitespace from an optional value and treats a blank
/// result as absent.
///
/// Returns `None` when `value` is `None` or trims down to the empty string;
/// otherwise returns the trimmed text.
fn normalize_bearer_token(value: Option<String>) -> Option<String> {
    let trimmed = value?.trim().to_string();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed)
    }
}
|
||
|
||
fn normalize_bearer_actor(value: String) -> Result<String> {
|
||
let value = value.trim().to_string();
|
||
if value.is_empty() {
|
||
bail!("bearer token actor names must not be blank");
|
||
}
|
||
Ok(value)
|
||
}
|
||
|
||
fn parse_bearer_tokens_json(value: &str) -> Result<Vec<(String, String)>> {
|
||
let entries: HashMap<String, String> = serde_json::from_str(value)
|
||
.wrap_err("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON must be a JSON object of actor->token")?;
|
||
Ok(entries.into_iter().collect())
|
||
}
|
||
|
||
fn read_bearer_tokens_file(path: &str) -> Result<Vec<(String, String)>> {
|
||
let contents = fs::read_to_string(path)
|
||
.wrap_err_with(|| format!("failed to read bearer tokens file at {path}"))?;
|
||
parse_bearer_tokens_json(&contents)
|
||
.wrap_err_with(|| format!("failed to parse bearer tokens file at {path}"))
|
||
}
|
||
|
||
fn validate_bearer_tokens(entries: Vec<(String, String)>) -> Result<Vec<(String, String)>> {
|
||
let mut seen_actors = HashSet::new();
|
||
let mut seen_tokens = HashSet::new();
|
||
let mut normalized = Vec::with_capacity(entries.len());
|
||
|
||
for (actor, token) in entries {
|
||
let actor = normalize_bearer_actor(actor)?;
|
||
let Some(token) = normalize_bearer_token(Some(token)) else {
|
||
bail!("bearer token for actor '{actor}' must not be blank");
|
||
};
|
||
if !seen_actors.insert(actor.clone()) {
|
||
bail!("duplicate bearer token actor '{actor}'");
|
||
}
|
||
if !seen_tokens.insert(token.clone()) {
|
||
bail!("duplicate bearer token value configured");
|
||
}
|
||
normalized.push((actor, token));
|
||
}
|
||
|
||
normalized.sort_by(|(left, _), (right, _)| left.cmp(right));
|
||
Ok(normalized)
|
||
}
|
||
|
||
fn server_bearer_tokens_from_env() -> Result<Vec<(String, String)>> {
|
||
let mut entries = Vec::new();
|
||
|
||
if let Some(token) = normalize_bearer_token(std::env::var("OMNIGRAPH_SERVER_BEARER_TOKEN").ok())
|
||
{
|
||
entries.push(("default".to_string(), token));
|
||
}
|
||
|
||
if let Some(path) =
|
||
normalize_bearer_token(std::env::var("OMNIGRAPH_SERVER_BEARER_TOKENS_FILE").ok())
|
||
{
|
||
entries.extend(read_bearer_tokens_file(&path)?);
|
||
} else if let Some(json) =
|
||
normalize_bearer_token(std::env::var("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON").ok())
|
||
{
|
||
entries.extend(parse_bearer_tokens_json(&json)?);
|
||
}
|
||
|
||
validate_bearer_tokens(entries)
|
||
}
|
||
|
||
#[cfg(test)]
|
||
mod tests {
|
||
use super::{
|
||
ServerConfig, ServerRuntimeState, classify_server_runtime_state, hash_bearer_token,
|
||
load_server_settings, normalize_bearer_token, parse_bearer_tokens_json, serve,
|
||
server_bearer_tokens_from_env,
|
||
};
|
||
use serial_test::serial;
|
||
use std::env;
|
||
use std::fs;
|
||
use tempfile::tempdir;
|
||
|
||
/// The digest returned by `hash_bearer_token` is 32 bytes long.
#[test]
fn hash_bearer_token_produces_32_byte_output() {
    let digest = hash_bearer_token("any-token");
    let digest_len = digest.len();
    assert_eq!(digest_len, 32);
}
|
||
|
||
/// Hashing the same input twice yields equal digests.
#[test]
fn hash_bearer_token_is_deterministic() {
    let first = hash_bearer_token("stable-input");
    let second = hash_bearer_token("stable-input");
    assert_eq!(first, second);
}
|
||
|
||
/// Two distinct inputs do not hash to the same digest.
#[test]
fn hash_bearer_token_differs_for_different_inputs() {
    let left = hash_bearer_token("token-a");
    let right = hash_bearer_token("token-b");
    assert_ne!(left, right);
}
|
||
|
||
/// Pins the digest against the published SHA-256 test vector for "abc".
#[test]
fn hash_bearer_token_matches_known_sha256_vector() {
    // SHA-256("abc"). A mismatch here means the hash function was swapped.
    const ABC_SHA256_HEX: &str =
        "ba7816bf8f01cfea414140de5dae2223b00361a396177a9cb410ff61f20015ad";

    let mut hex = String::new();
    for byte in hash_bearer_token("abc").iter() {
        hex.push_str(&format!("{byte:02x}"));
    }
    assert_eq!(hex, ABC_SHA256_HEX);
}
|
||
|
||
/// With no CLI overrides, the URI and bind address come from the YAML config.
#[test]
fn server_settings_load_from_yaml_config() {
    const YAML: &str = r#"
graphs:
  local:
    uri: /tmp/demo.omni
server:
  graph: local
  bind: 0.0.0.0:9090
"#;
    let dir = tempdir().unwrap();
    let config_path = dir.path().join("omnigraph.yaml");
    fs::write(&config_path, YAML).unwrap();

    let loaded = load_server_settings(Some(&config_path), None, None, None, false).unwrap();

    assert_eq!(loaded.uri, "/tmp/demo.omni");
    assert_eq!(loaded.bind, "0.0.0.0:9090");
}
|
||
|
||
/// An explicitly passed URI and bind address win over the values in the YAML
/// config.
#[test]
fn server_settings_cli_flags_override_yaml_config() {
    const YAML: &str = r#"
graphs:
  local:
    uri: /tmp/demo.omni
server:
  graph: local
  bind: 127.0.0.1:8080
"#;
    let dir = tempdir().unwrap();
    let config_path = dir.path().join("omnigraph.yaml");
    fs::write(&config_path, YAML).unwrap();

    let uri_override = Some("/tmp/override.omni".to_string());
    let bind_override = Some("0.0.0.0:9999".to_string());
    let loaded =
        load_server_settings(Some(&config_path), uri_override, None, bind_override, false)
            .unwrap();

    assert_eq!(loaded.uri, "/tmp/override.omni");
    assert_eq!(loaded.bind, "0.0.0.0:9999");
}
|
||
|
||
/// Passing the graph name `dev` selects that entry's URI instead of the
/// `server.graph` default (`local`).
#[test]
fn server_settings_can_resolve_named_target() {
    const YAML: &str = r#"
graphs:
  local:
    uri: ./demo.omni
  dev:
    uri: http://127.0.0.1:8080
server:
  graph: local
  bind: 127.0.0.1:8080
"#;
    let dir = tempdir().unwrap();
    let config_path = dir.path().join("omnigraph.yaml");
    fs::write(&config_path, YAML).unwrap();

    let target = Some("dev".to_string());
    let loaded = load_server_settings(Some(&config_path), None, target, None, false).unwrap();

    assert_eq!(loaded.uri, "http://127.0.0.1:8080");
}
|
||
|
||
/// Loading settings with neither a config file nor a URI argument is an error
/// whose message says a URI must be provided.
#[test]
fn server_settings_require_uri_from_cli_or_config() {
    let message = load_server_settings(None, None, None, None, false)
        .unwrap_err()
        .to_string();
    assert!(message.contains("URI must be provided"));
}
|
||
|
||
/// Open mode is only reachable when the unauthenticated flag is set.
#[test]
fn classify_open_requires_explicit_unauthenticated_flag() {
    // State 1: no tokens, no policy, no flag → refuse to start.
    let msg = classify_server_runtime_state(false, false, false)
        .unwrap_err()
        .to_string();
    let mentions_flag = msg.contains("--unauthenticated");
    assert!(
        mentions_flag,
        "expected refusal message mentioning --unauthenticated, got: {msg}"
    );

    // Same matrix cell but with the flag set → Open mode permitted.
    let with_flag = classify_server_runtime_state(false, false, true).unwrap();
    assert_eq!(with_flag, ServerRuntimeState::Open);
}
|
||
|
||
/// Tokens without a policy classify as `DefaultDeny`, whatever the flag says.
#[test]
fn classify_tokens_without_policy_is_default_deny() {
    // State 2: tokens configured, no policy → DefaultDeny regardless of the
    // flag (the flag opts into the fully-open dev mode; it doesn't downgrade
    // default-deny back to open).
    for unauthenticated_flag in [false, true] {
        let state = classify_server_runtime_state(true, false, unauthenticated_flag).unwrap();
        assert_eq!(state, ServerRuntimeState::DefaultDeny);
    }
}
|
||
|
||
#[tokio::test]
|
||
#[serial]
|
||
async fn serve_refuses_to_start_in_state_1_without_unauthenticated() {
|
||
// MR-723 PR A: pin the integration boundary that the classifier
|
||
// is actually called by `serve()` before any side-effecting
|
||
// work (Lance dataset open, TcpListener::bind). The classifier
|
||
// itself is unit-tested above; this test guards the propagation
|
||
// path from `classify_server_runtime_state` through serve's
|
||
// `?` so a future refactor that drops the call returns red.
|
||
//
|
||
// Marked `#[serial]` because we have to clear all bearer-token
|
||
// env vars, and another test in this module setting any of them
|
||
// concurrently would corrupt the read inside `resolve_token_source`.
|
||
let _guard = EnvGuard::set(&[
|
||
("OMNIGRAPH_SERVER_BEARER_TOKEN", None),
|
||
("OMNIGRAPH_SERVER_BEARER_TOKENS_FILE", None),
|
||
("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON", None),
|
||
("OMNIGRAPH_SERVER_BEARER_TOKENS_AWS_SECRET", None),
|
||
("OMNIGRAPH_UNAUTHENTICATED", None),
|
||
]);
|
||
let temp = tempdir().unwrap();
|
||
// Graph path doesn't need to exist — classifier fires before
|
||
// `AppState::open_with_bearer_tokens_and_policy`.
|
||
let config = ServerConfig {
|
||
uri: temp
|
||
.path()
|
||
.join("graph.omni")
|
||
.to_string_lossy()
|
||
.into_owned(),
|
||
bind: "127.0.0.1:0".to_string(),
|
||
policy_file: None,
|
||
allow_unauthenticated: false,
|
||
};
|
||
let result = serve(config).await;
|
||
let err =
|
||
result.expect_err("serve should refuse to start in State 1 without --unauthenticated");
|
||
let msg = format!("{:?}", err);
|
||
assert!(
|
||
msg.contains("no bearer tokens") || msg.contains("policy file"),
|
||
"expected refusal message naming the misconfiguration, got: {msg}",
|
||
);
|
||
}
|
||
|
||
/// Exercises how `OMNIGRAPH_UNAUTHENTICATED` and the CLI flag combine into
/// `allow_unauthenticated` when settings are loaded.
#[test]
#[serial]
fn unauthenticated_env_var_classification() {
    // MR-723 PR A: closes the gap where the env-var read path inside
    // `load_server_settings` was structurally implemented but not exercised
    // by any test. All cases live in one sequential test because `cargo test`
    // runs the module's tests in parallel and `OMNIGRAPH_UNAUTHENTICATED` is
    // process-global — interleaving with another test that sets the same
    // variable (concurrent classifier tests, even the bearer-token suite
    // sharing `EnvGuard`) corrupts the read. Sequential within one test fn is
    // the simplest race-free shape.
    const YAML: &str = r#"
graphs:
  local:
    uri: /tmp/demo-unauth.omni
server:
  graph: local
"#;
    let temp = tempdir().unwrap();
    let config_path = temp.path().join("omnigraph.yaml");
    fs::write(&config_path, YAML).unwrap();

    // Loads settings from the config above with only the CLI flag varying.
    let load_with_cli_flag = |cli_flag: bool| {
        load_server_settings(Some(&config_path), None, None, None, cli_flag)
            .expect("settings load should succeed")
    };

    // Truthy values flip Open mode on, even with CLI flag off.
    for value in ["1", "true", "yes", "TRUE", "anything"] {
        let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some(value))]);
        let settings = load_with_cli_flag(false);
        assert!(
            settings.allow_unauthenticated,
            "OMNIGRAPH_UNAUTHENTICATED={value:?} should enable Open mode",
        );
    }

    // Falsy values keep refusal behavior, even with CLI flag off.
    for value in ["0", "false", "FALSE", ""] {
        let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some(value))]);
        let settings = load_with_cli_flag(false);
        assert!(
            !settings.allow_unauthenticated,
            "OMNIGRAPH_UNAUTHENTICATED={value:?} should NOT enable Open mode",
        );
    }

    // Unset env var: also false. The guard is scoped so the previous value is
    // restored before the final case installs its own override.
    {
        let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", None)]);
        let settings = load_with_cli_flag(false);
        assert!(
            !settings.allow_unauthenticated,
            "OMNIGRAPH_UNAUTHENTICATED unset should NOT enable Open mode",
        );
    }

    // CLI flag wins even when env is falsy — `serve()` honors the OR of both
    // inputs.
    let _guard = EnvGuard::set(&[("OMNIGRAPH_UNAUTHENTICATED", Some("0"))]);
    let settings = load_with_cli_flag(true);
    assert!(
        settings.allow_unauthenticated,
        "--unauthenticated CLI flag should win even when env is falsy",
    );
}
|
||
|
||
/// Whenever a policy is present the classifier reports `PolicyEnabled`.
#[test]
fn classify_policy_enabled_always_wins() {
    // State 3: any setup with a policy file → PolicyEnabled. The flag doesn't
    // matter and tokens-or-not doesn't matter (no tokens + policy is unusual
    // but valid — every request fails 401 without a bearer, which is
    // effectively "locked").
    let cases = [(true, false), (false, false), (true, true)];
    for (tokens, unauthenticated_flag) in cases {
        let state = classify_server_runtime_state(tokens, true, unauthenticated_flag).unwrap();
        assert_eq!(state, ServerRuntimeState::PolicyEnabled);
    }
}
|
||
|
||
/// Missing and whitespace-only values normalize to `None`; other values are
/// trimmed.
#[test]
fn normalize_bearer_token_trims_and_filters_blank_values() {
    let missing = normalize_bearer_token(None);
    assert_eq!(missing, None);

    let blank = normalize_bearer_token(Some(" ".to_string()));
    assert_eq!(blank, None);

    let padded = normalize_bearer_token(Some(" demo-token ".to_string()));
    assert_eq!(padded.as_deref(), Some("demo-token"));
}
|
||
|
||
/// Test-only guard that overrides process environment variables and puts the
/// previous values back when it is dropped.
struct EnvGuard {
    /// Each overridden variable paired with the value it held beforehand
    /// (`None` when it was unset). Kept as `OsString` so that a value which
    /// is not valid Unicode is restored intact instead of being dropped.
    saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
}

impl EnvGuard {
    /// Applies `vars` to the process environment: `Some(value)` sets the
    /// variable, `None` removes it.
    ///
    /// The previous values of all listed variables are captured before the
    /// first change is made.
    fn set(vars: &[(&'static str, Option<&str>)]) -> Self {
        // `var_os` (rather than `var`) so a non-Unicode value is captured too;
        // `var(..).ok()` would report it as unset and `drop` would delete it.
        let saved = vars
            .iter()
            .map(|(name, _)| (*name, env::var_os(name)))
            .collect::<Vec<_>>();
        for (name, value) in vars {
            // SAFETY: mutating the environment is only sound while no other
            // thread accesses it. The tests in this module that build a guard
            // are marked `#[serial]`.
            // NOTE(review): tests without `#[serial]` that read env vars could
            // still overlap — confirm none do.
            unsafe {
                match value {
                    Some(value) => env::set_var(name, value),
                    None => env::remove_var(name),
                }
            }
        }
        Self { saved }
    }
}

impl Drop for EnvGuard {
    /// Restores every captured variable to its previous value, removing those
    /// that were unset before the guard was created.
    fn drop(&mut self) {
        for (name, value) in self.saved.drain(..) {
            // SAFETY: same requirement as in `EnvGuard::set` — no concurrent
            // access to the process environment while it is mutated.
            unsafe {
                match value {
                    Some(value) => env::set_var(name, value),
                    None => env::remove_var(name),
                }
            }
        }
    }
}
|
||
|
||
/// Parsing returns one pair per JSON key and leaves token whitespace as-is.
#[test]
fn parse_bearer_tokens_json_reads_actor_token_map() {
    let parsed = parse_bearer_tokens_json(r#"{"alice":" token-a ","bob":"token-b"}"#).unwrap();

    let expected = [("alice", " token-a "), ("bob", "token-b")];
    assert_eq!(parsed.len(), expected.len());
    for (actor, token) in expected {
        // Order is not asserted; only membership.
        assert!(parsed.contains(&(actor.to_string(), token.to_string())));
    }
}
|
||
|
||
/// The single-token variable and the tokens file are merged, trimmed and
/// returned sorted by actor name.
#[test]
#[serial]
fn server_bearer_tokens_from_env_reads_legacy_token_and_token_file() {
    let temp = tempdir().unwrap();
    let tokens_path = temp.path().join("tokens.json");
    let file_body = r#"{"team-01":"token-one","team-02":"token-two"}"#;
    fs::write(&tokens_path, file_body).unwrap();
    let tokens_file = tokens_path.to_str().unwrap();

    let _guard = EnvGuard::set(&[
        ("OMNIGRAPH_SERVER_BEARER_TOKEN", Some(" legacy-token ")),
        ("OMNIGRAPH_SERVER_BEARER_TOKENS_FILE", Some(tokens_file)),
        ("OMNIGRAPH_SERVER_BEARER_TOKENS_JSON", None),
    ]);

    let tokens = server_bearer_tokens_from_env().unwrap();

    let expected: Vec<(String, String)> = [
        ("default", "legacy-token"),
        ("team-01", "token-one"),
        ("team-02", "token-two"),
    ]
    .into_iter()
    .map(|(actor, token)| (actor.to_string(), token.to_string()))
    .collect();
    assert_eq!(tokens, expected);
}
|
||
}
|