diff --git a/Cargo.lock b/Cargo.lock index 2419e9f..92279d2 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -4855,6 +4855,7 @@ version = "0.7.0" dependencies = [ "omnigraph-compiler", "omnigraph-engine", + "regex", "serde", "serde_json", "utoipa", @@ -4964,6 +4965,20 @@ dependencies = [ "url", ] +[[package]] +name = "omnigraph-mcp" +version = "0.7.0" +dependencies = [ + "async-trait", + "axum", + "http 1.4.0", + "rmcp", + "serde_json", + "tokio", + "tower", + "tower-http", +] + [[package]] name = "omnigraph-policy" version = "0.7.0" @@ -4990,12 +5005,14 @@ dependencies = [ "color-eyre", "dashmap", "futures", + "http 1.4.0", "lance", "lance-index", "omnigraph-api-types", "omnigraph-cluster", "omnigraph-compiler", "omnigraph-engine", + "omnigraph-mcp", "omnigraph-policy", "regex", "serde", @@ -5336,6 +5353,12 @@ version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +[[package]] +name = "pastey" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2ee67f1008b1ba2321834326597b8e186293b049a023cdef258527550b9935b4" + [[package]] name = "path_abs" version = "0.5.1" @@ -6329,6 +6352,35 @@ dependencies = [ "windows-sys 0.52.0", ] +[[package]] +name = "rmcp" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0810a9f717d9828f475fe1f629f4c305c8464b7f496c3a854b58d29e65f4058e" +dependencies = [ + "async-trait", + "bytes", + "chrono", + "futures", + "http 1.4.0", + "http-body 1.0.1", + "http-body-util", + "pastey", + "pin-project-lite", + "rand 0.10.1", + "schemars 1.2.1", + "serde", + "serde_json", + "sse-stream", + "thiserror", + "tokio", + "tokio-stream", + "tokio-util", + "tower-service", + "tracing", + "uuid", +] + [[package]] name = "roaring" version = "0.11.4" @@ -6602,12 +6654,26 @@ version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2b42f36aa1cd011945615b92222f6bf73c599a102a300334cd7f8dbeec726cc" dependencies = [ + "chrono", "dyn-clone", "ref-cast", + "schemars_derive", "serde", "serde_json", ] +[[package]] +name = "schemars_derive" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d115b50f4aaeea07e79c1912f645c7513d81715d0420f8bc77a18c6260b307f" +dependencies = [ + "proc-macro2", + "quote", + "serde_derive_internals", + "syn 2.0.117", +] + [[package]] name = "scoped-tls" version = "1.0.1" @@ -6712,6 +6778,17 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "serde_derive_internals" +version = "0.29.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "18d26a20a969b9e3fdf2fc2d9f21eda6c40e2de84c9408bb5d3b05d499aae711" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.117", +] + [[package]] name = "serde_json" version = "1.0.149" @@ -7057,6 +7134,19 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "sse-stream" +version = "0.2.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3962b63f038885f15bce2c6e02c0e7925c072f1ac86bb60fd44c5c6b762fb72" +dependencies = [ + "bytes", + "futures-util", + "http-body 1.0.1", + "http-body-util", + "pin-project-lite", +] + [[package]] name = "stable_deref_trait" version = "1.2.1" diff --git a/Cargo.toml b/Cargo.toml index c442242..fed8e49 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -8,6 +8,7 @@ members = [ "crates/omnigraph-cluster", "crates/omnigraph-policy", "crates/omnigraph-server", + "crates/omnigraph-mcp", ] default-members = [ "crates/omnigraph", diff --git a/crates/omnigraph-api-types/Cargo.toml b/crates/omnigraph-api-types/Cargo.toml index d69d4fe..840aaa3 100644 --- a/crates/omnigraph-api-types/Cargo.toml +++ b/crates/omnigraph-api-types/Cargo.toml @@ -14,3 +14,7 @@ omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.7.0" } serde = { workspace = true } serde_json = { workspace = true } utoipa = { workspace = true } + +[dev-dependencies] +# Faithful `pattern` enforcement in the schema/coercer equivalence test. +regex = { workspace = true } diff --git a/crates/omnigraph-api-types/src/lib.rs b/crates/omnigraph-api-types/src/lib.rs index 2814602..b26e46c 100644 --- a/crates/omnigraph-api-types/src/lib.rs +++ b/crates/omnigraph-api-types/src/lib.rs @@ -11,7 +11,7 @@ use omnigraph_compiler::query::ast::Param; use omnigraph_compiler::result::QueryResult; use omnigraph_compiler::types::{PropType, ScalarType}; use serde::{Deserialize, Serialize}; -use serde_json::Value; +use serde_json::{Value, json}; use utoipa::{IntoParams, ToSchema}; /// Shadow enum for documenting [`LoadMode`] in the OpenAPI schema. @@ -459,6 +459,68 @@ pub fn param_descriptor(param: &Param) -> ParamDescriptor { } } +/// JSON Schema (2020-12) for a scalar param kind. **Superset of the engine +/// coercer** (`omnigraph_compiler::coerce_param_typed`, Standard mode): a +/// too-narrow schema would make a strict client reject inputs the engine +/// accepts; a too-wide one reaches the coercer and surfaces as an `isError` +/// tool result for model self-correction (SEP-1303). Locked to the coercer by +/// `tests/schema_equivalence.rs`. Exhaustive + wildcard-free: adding a +/// `ParamKind` is a compile error until its arm (and corpus row) exist. +fn scalar_schema(kind: ParamKind) -> Value { + match kind { + ParamKind::String => json!({ "type": "string" }), + ParamKind::Bool => json!({ "type": "boolean" }), + // Standard-mode integer coercion accepts a JSON number OR a numeric + // string (i64/u64 lose precision past 2^53 as a JSON number), so the + // schema accepts both; range/sign are the coercer's to enforce. + ParamKind::Int | ParamKind::BigInt => json!({ + "anyOf": [ { "type": "integer" }, { "type": "string", "pattern": r"^-?\d+$" } ] + }), + ParamKind::Float => json!({ "type": "number" }), + // Date/DateTime/Blob coerce from any string; `format` is an advisory + // annotation (non-asserting in 2020-12), so the schema accepts exactly + // what the coercer does while still hinting the shape to clients. + ParamKind::Date => json!({ "type": "string", "format": "date" }), + ParamKind::DateTime => json!({ "type": "string", "format": "date-time" }), + ParamKind::Blob => json!({ "type": "string", "format": "uri" }), + ParamKind::Vector | ParamKind::List => { + unreachable!("composite kinds are handled in param_json_schema") + } + } +} + +/// The JSON Schema (2020-12) for a stored-query parameter — the single mapping +/// both the OpenAPI catalog and the MCP tool projection consume, applying the +/// nullable rule uniformly. See [`scalar_schema`] for the superset contract. +pub fn param_json_schema(p: &ParamDescriptor) -> Value { + let base = match p.kind { + ParamKind::Vector => { + let mut schema = json!({ "type": "array", "items": { "type": "number" } }); + if let Some(dim) = p.vector_dim { + schema["minItems"] = json!(dim); + schema["maxItems"] = json!(dim); + } + schema + } + ParamKind::List => { + let item = p + .item_kind + .map(scalar_schema) + .unwrap_or_else(|| json!({ "type": "string" })); + json!({ "type": "array", "items": item }) + } + scalar => scalar_schema(scalar), + }; + // The coercer accepts explicit `null` for a nullable param (and its + // omission); a strict client would reject `null` against the bare scalar. + // Allow null at the schema level for nullable params. + if p.nullable { + json!({ "anyOf": [ base, { "type": "null" } ] }) + } else { + base + } +} + #[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)] pub struct SchemaApplyRequest { diff --git a/crates/omnigraph-api-types/tests/schema_equivalence.rs b/crates/omnigraph-api-types/tests/schema_equivalence.rs new file mode 100644 index 0000000..b65a013 --- /dev/null +++ b/crates/omnigraph-api-types/tests/schema_equivalence.rs @@ -0,0 +1,149 @@ +//! The lock that makes `param_json_schema` correct *by construction*: for a +//! fixed corpus, the generated JSON Schema must accept **at least** every value +//! the engine coercer (`coerce_param_typed`, Standard mode — the mode the +//! stored-query invoke path uses) accepts. A schema narrower than the coercer +//! would make a strict client reject inputs the engine would have taken; a +//! wider one is fine (it reaches the coercer and surfaces as `isError`, +//! SEP-1303). Drift in either direction-of-acceptance turns this red. +//! +//! Keyed by **engine type-name string** (`"I32"`, `"U64"`, `"Vector(3)"`, +//! `"[I32]"`), not `ParamKind`, because the coercer keys on the type-name and +//! `ParamKind` is a lossy bucket (`I32`/`U32` → `Int`). The descriptor is built +//! through the real projection (`param_descriptor`). + +use omnigraph_api_types::{param_descriptor, param_json_schema}; +use omnigraph_compiler::query::ast::Param; +use omnigraph_compiler::{JsonParamMode, coerce_param_typed}; +use serde_json::{Value, json}; + +/// A faithful validator for the closed schema vocabulary `param_json_schema` +/// emits: `type` (string/integer/number/boolean/array/null), `anyOf`, `items`, +/// `minItems`/`maxItems`, and `pattern` (enforced via `regex`). It interprets +/// the *emitted* schema JSON — not a hardcoded copy — so it tracks generator +/// changes. A new construct the generator emits would need a new arm here. +fn schema_accepts(schema: &Value, value: &Value) -> bool { + if let Some(any_of) = schema.get("anyOf").and_then(Value::as_array) { + return any_of.iter().any(|s| schema_accepts(s, value)); + } + match schema.get("type").and_then(Value::as_str) { + Some("string") => { + let Some(s) = value.as_str() else { return false }; + match schema.get("pattern").and_then(Value::as_str) { + Some(pat) => regex::Regex::new(pat).unwrap().is_match(s), + None => true, + } + } + Some("integer") => value.as_i64().is_some() || value.as_u64().is_some(), + Some("number") => value.is_number(), + Some("boolean") => value.is_boolean(), + Some("null") => value.is_null(), + Some("array") => { + let Some(arr) = value.as_array() else { return false }; + if let Some(min) = schema.get("minItems").and_then(Value::as_u64) { + if (arr.len() as u64) < min { + return false; + } + } + if let Some(max) = schema.get("maxItems").and_then(Value::as_u64) { + if (arr.len() as u64) > max { + return false; + } + } + match schema.get("items") { + Some(items) => arr.iter().all(|el| schema_accepts(items, el)), + None => true, + } + } + other => panic!("schema_accepts: unhandled schema {other:?} in {schema}"), + } +} + +fn descriptor(type_name: &str, nullable: bool) -> omnigraph_api_types::ParamDescriptor { + param_descriptor(&Param { + name: "p".to_string(), + type_name: type_name.to_string(), + nullable, + }) +} + +/// Every engine scalar/composite type-name a stored-query param can declare. +const TYPE_NAMES: &[&str] = &[ + "String", "Bool", "I32", "I64", "U32", "U64", "F32", "F64", "Date", "DateTime", "Blob", + "Vector(3)", "[I32]", "[String]", +]; + +/// A broad value bag thrown at every type-name; the coercer decides which it +/// accepts, and the schema must accept at least those. +fn corpus() -> Vec { + vec![ + json!("hello"), + json!("AAEC"), // base64-looking → still just a string (Blob) + json!("2024-01-02"), // date string + json!("2024-01-02T03:04:05Z"), // datetime string + json!("og://blob/abc"), // blob URI + json!(5), + json!(-5), + json!(0), + json!(5_000_000_000i64), // fits i64/u64, exceeds i32 + json!(99_999_999_999u64), + json!("5"), + json!("-5"), + json!("9999999999999999999999"), // exceeds i64 even as string + json!(1.5), + json!(true), + json!([1.0, 2.0, 3.0]), + json!([1, 2, 3]), + json!([1, 2]), + json!(["a", "b"]), + json!({ "k": 1 }), + ] +} + +#[test] +fn schema_is_a_superset_of_the_coercer() { + for &type_name in TYPE_NAMES { + let schema = param_json_schema(&descriptor(type_name, false)); + for value in corpus() { + // The coercer is the authority; null is handled by the parent + // (`json_params_to_param_map`), not `coerce_param_typed`, so skip it + // here — the null rule is pinned separately below. + if value.is_null() { + continue; + } + if coerce_param_typed("p", &value, type_name, JsonParamMode::Standard).is_ok() { + assert!( + schema_accepts(&schema, &value), + "type {type_name}: coercer accepts {value} but schema {schema} rejects it" + ); + } + } + } +} + +#[test] +fn nullable_rule_matches_the_parent_coercer() { + // The parent coercer accepts explicit `null` iff the param is nullable. + // `param_json_schema` must mirror that at the schema level. + for &type_name in TYPE_NAMES { + let nullable = param_json_schema(&descriptor(type_name, true)); + let non_nullable = param_json_schema(&descriptor(type_name, false)); + assert!( + schema_accepts(&nullable, &Value::Null), + "type {type_name}: nullable schema must accept null ({nullable})" + ); + assert!( + !schema_accepts(&non_nullable, &Value::Null), + "type {type_name}: non-nullable schema must reject null ({non_nullable})" + ); + } +} + +#[test] +fn vector_dim_bounds_are_present_or_omitted() { + let with_dim = param_json_schema(&descriptor("Vector(4)", false)); + assert_eq!(with_dim["minItems"], json!(4)); + assert_eq!(with_dim["maxItems"], json!(4)); + // A four-element array validates; three or five do not. + assert!(schema_accepts(&with_dim, &json!([1.0, 2.0, 3.0, 4.0]))); + assert!(!schema_accepts(&with_dim, &json!([1.0, 2.0, 3.0]))); +} diff --git a/crates/omnigraph-compiler/src/lib.rs b/crates/omnigraph-compiler/src/lib.rs index 4f85c08..8b121e4 100644 --- a/crates/omnigraph-compiler/src/lib.rs +++ b/crates/omnigraph-compiler/src/lib.rs @@ -27,7 +27,7 @@ pub use query::lint::{ lint_query_file, }; pub use query_input::{ - JsonParamMode, RunInputError, RunInputResult, ToParam, find_named_query, + JsonParamMode, RunInputError, RunInputResult, ToParam, coerce_param_typed, find_named_query, json_params_to_param_map, }; pub use result::{MutationExecResult, MutationResult, QueryResult, RunResult}; diff --git a/crates/omnigraph-compiler/src/query_input.rs b/crates/omnigraph-compiler/src/query_input.rs index b85decf..5d48e53 100644 --- a/crates/omnigraph-compiler/src/query_input.rs +++ b/crates/omnigraph-compiler/src/query_input.rs @@ -322,6 +322,23 @@ pub fn json_params_to_param_map( Ok(map) } +/// Coerce one JSON value to a typed [`Literal`] by the engine's input +/// contract — the single authority for what a param accepts. Exposed so the +/// shared `param_json_schema` projection (in `omnigraph-api-types`) can be +/// locked to this coercer by an equivalence test: the JSON Schema a client +/// validates against must accept at least what this accepts, or a strict +/// client would reject inputs the engine would have taken. Does **not** apply +/// the nullable rule — explicit `null` is handled by [`json_params_to_param_map`], +/// not here. +pub fn coerce_param_typed( + key: &str, + value: &Value, + type_name: &str, + mode: JsonParamMode, +) -> RunInputResult { + json_value_to_literal_typed(key, value, type_name, mode) +} + fn json_value_to_literal_typed( key: &str, value: &Value, diff --git a/crates/omnigraph-mcp/Cargo.toml b/crates/omnigraph-mcp/Cargo.toml new file mode 100644 index 0000000..34bfb4b --- /dev/null +++ b/crates/omnigraph-mcp/Cargo.toml @@ -0,0 +1,28 @@ +[package] +name = "omnigraph-mcp" +version = "0.7.0" +edition = "2024" +description = "MCP (Model Context Protocol) Streamable-HTTP transport and backend seam for Omnigraph. Contains the rmcp dependency and defines the McpBackend trait the server implements; names no omnigraph engine/server type, so the dependency edge is server → mcp." +license = "MIT" +repository = "https://github.com/ModernRelay/omnigraph" +homepage = "https://github.com/ModernRelay/omnigraph" +documentation = "https://docs.rs/omnigraph-mcp" + +[dependencies] +# rmcp is contained to this crate. `server` + `transport-streamable-http-server` +# give the StreamableHttpService tower wiring. Do NOT enable rmcp's `local` +# feature — it cfg's the tower wiring out (transport mod is gated on +# `not(feature = "local")`). +rmcp = { version = "1.7", default-features = false, features = ["server", "transport-streamable-http-server"] } +axum = { workspace = true } +http = "1" +# `limit` adds RequestBodyLimitLayer; features are additive with the workspace's +# `trace`. rmcp reads the body directly (no axum extractor), so axum's +# DefaultBodyLimit does not bound /mcp — this layer is the real bound. +tower-http = { workspace = true, features = ["limit"] } +tokio = { workspace = true } +async-trait = { workspace = true } +serde_json = { workspace = true } + +[dev-dependencies] +tower = { workspace = true } diff --git a/crates/omnigraph-mcp/src/lib.rs b/crates/omnigraph-mcp/src/lib.rs new file mode 100644 index 0000000..0946a7c --- /dev/null +++ b/crates/omnigraph-mcp/src/lib.rs @@ -0,0 +1,73 @@ +//! MCP (Model Context Protocol) server surface for Omnigraph, served over +//! **stateless Streamable HTTP**. +//! +//! This crate owns the `rmcp` dependency and the transport wiring. It defines a +//! single seam — the [`McpBackend`] trait — that the server crate implements. +//! The crate **never names an omnigraph type**: the backend reads its own types +//! (resolved actor, graph handle, …) out of `parts.extensions`, so the +//! dependency edge is `server → mcp` (never the reverse — a `mcp → server` edge +//! would cycle the binary at `server-bin → omnigraph-mcp → server-lib`). +//! +//! The transport is **stateless JSON over a single `/mcp` POST**: no SSE stream, +//! no `Mcp-Session-Id`, every request independent. See [`transport`]. + +use async_trait::async_trait; + +mod service; +pub mod transport; + +pub use transport::{McpHostPolicy, OriginPolicy, mcp_router}; + +// rmcp model types re-exported so the server speaks rmcp via `omnigraph_mcp::…` +// and carries no direct rmcp dependency. +pub use rmcp::ErrorData as McpError; // JSON-RPC error: invalid_params=-32602, internal_error=-32603 +pub use rmcp::model::{ + CallToolResult, Content, Extensions, Implementation, RawResource, ReadResourceResult, Resource, + ResourceContents, ServerCapabilities, ServerInfo, Tool, ToolAnnotations, +}; + +/// A JSON object — the shape of tool arguments and JSON Schema documents. +/// Identical to `rmcp::model::JsonObject` (`serde_json::Map`). +pub type JsonObject = serde_json::Map; + +/// The seam the server fills. One implementor (`OmnigraphMcpBackend`); the boxed +/// future from `#[async_trait]` is negligible at MCP QPS. +/// +/// **The list seam is non-paginated by contract.** `list_tools`/`list_resources` +/// return the *full* set, so the service always emits `nextCursor: null`. The +/// catalog is bounded by construction (a fixed set of built-ins; large +/// stored-query catalogs collapse to a discovery + execute meta-tool pair rather +/// than leaning on `tools/list` paging). The `Vec` return type *is* that +/// contract; a future paging need is a signature change, not a doc promise. +/// +/// Each method receives the request's [`http::request::Parts`]; the backend reads +/// its own injected extensions (`parts.extensions.get::()`) — the decoupling +/// mechanism that keeps this crate free of omnigraph types and auth-method +/// agnostic. +#[async_trait] +pub trait McpBackend: Clone + Send + Sync + 'static { + /// Server identity + advertised capabilities (`initialize` response). + fn server_info(&self) -> ServerInfo; + + /// The full, Cedar-filtered tool set for this request's actor + graph. + async fn list_tools(&self, parts: &http::request::Parts) -> Result, McpError>; + + /// Dispatch a tool call. The authoritative authorization gate. + async fn call_tool( + &self, + parts: &http::request::Parts, + name: &str, + args: JsonObject, + ) -> Result; + + /// The full, Cedar-filtered resource set for this request's actor + graph. + async fn list_resources(&self, parts: &http::request::Parts) + -> Result, McpError>; + + /// Read one resource by URI. + async fn read_resource( + &self, + parts: &http::request::Parts, + uri: &str, + ) -> Result; +} diff --git a/crates/omnigraph-mcp/src/service.rs b/crates/omnigraph-mcp/src/service.rs new file mode 100644 index 0000000..26297f4 --- /dev/null +++ b/crates/omnigraph-mcp/src/service.rs @@ -0,0 +1,80 @@ +//! `McpService` — the rmcp `ServerHandler` adapter. Pulls the request's +//! `http::request::Parts` out of the context once and delegates each method to +//! the [`McpBackend`]. Maps the backend's non-paginated `Vec` returns to +//! rmcp's `List*Result` with `next_cursor: None`. + +use rmcp::ServerHandler; +use rmcp::ErrorData as McpError; +use rmcp::model::{ + CallToolRequestParams, CallToolResult, ListResourcesResult, ListToolsResult, + PaginatedRequestParams, ReadResourceRequestParams, ReadResourceResult, ServerInfo, +}; +use rmcp::service::{RequestContext, RoleServer}; + +use crate::McpBackend; + +#[derive(Clone)] +pub(crate) struct McpService { + backend: B, +} + +impl McpService { + pub(crate) fn new(backend: B) -> Self { + Self { backend } + } + + /// The HTTP `Parts` injected by `StreamableHttpService` into the request + /// context extensions (`tower.rs` does `request.into_parts()` then + /// `req.request.extensions_mut().insert(part)`). Absent only on an internal + /// wiring error, not a client-reachable path. + fn parts(ctx: &RequestContext) -> Result<&http::request::Parts, McpError> { + ctx.extensions + .get::() + .ok_or_else(|| McpError::internal_error("request parts missing from MCP context", None)) + } +} + +impl ServerHandler for McpService { + fn get_info(&self) -> ServerInfo { + self.backend.server_info() + } + + async fn list_tools( + &self, + _request: Option, + context: RequestContext, + ) -> Result { + let parts = Self::parts(&context)?; + let tools = self.backend.list_tools(parts).await?; + Ok(ListToolsResult::with_all_items(tools)) + } + + async fn call_tool( + &self, + request: CallToolRequestParams, + context: RequestContext, + ) -> Result { + let parts = Self::parts(&context)?; + let args = request.arguments.unwrap_or_default(); + self.backend.call_tool(parts, &request.name, args).await + } + + async fn list_resources( + &self, + _request: Option, + context: RequestContext, + ) -> Result { + let parts = Self::parts(&context)?; + let resources = self.backend.list_resources(parts).await?; + Ok(ListResourcesResult::with_all_items(resources)) + } + + async fn read_resource( + &self, + request: ReadResourceRequestParams, + context: RequestContext, + ) -> Result { + let parts = Self::parts(&context)?; + self.backend.read_resource(parts, &request.uri).await + } +} diff --git a/crates/omnigraph-mcp/src/transport.rs b/crates/omnigraph-mcp/src/transport.rs new file mode 100644 index 0000000..5d9a7a0 --- /dev/null +++ b/crates/omnigraph-mcp/src/transport.rs @@ -0,0 +1,158 @@ +//! Stateless Streamable-HTTP transport for the MCP surface. +//! +//! One endpoint: `POST /mcp` returns a single `application/json` object; no SSE, +//! no session id (`NeverSessionManager` + `stateful_mode = false`). rmcp gives, +//! for free in stateless mode: `GET`/`DELETE → 405 + Allow: POST`, a disallowed +//! `Host → 403`, and an unsupported `MCP-Protocol-Version → 400` on +//! **non-`initialize`** requests. `initialize` is exempt by design — it +//! negotiates the version in its JSON-RPC body (`protocolVersion`), not the HTTP +//! header, so a bogus header there is ignored (absent ⇒ rmcp's default version). +//! +//! The one thing rmcp does **not** give is fail-closed Origin: it validates +//! `Origin` only when `allowed_origins` is non-empty (an empty list is +//! *fail-open*). [`origin_guard`] closes that — a present, disallowed `Origin` +//! is `403` regardless — and [`McpHostPolicy`] has no "absent ⇒ skip" state, so a +//! remote deployment cannot accidentally run fail-open. + +use std::net::SocketAddr; +use std::sync::Arc; + +use axum::extract::{Request, State}; +use axum::middleware::{Next, from_fn_with_state}; +use axum::response::{IntoResponse, Response}; +use rmcp::transport::streamable_http_server::{ + StreamableHttpServerConfig, StreamableHttpService, session::never::NeverSessionManager, +}; + +use crate::McpBackend; +use crate::service::McpService; + +/// Browser-`Origin` posture as a **total** choice — there is no `None ⇒ skip` +/// state to leak into a fail-open default. Every deployment lands in exactly one +/// arm, chosen once by [`McpHostPolicy::from_bind`]. +#[derive(Debug, Clone)] +pub enum OriginPolicy { + /// Browser clients from these origins; any OTHER present `Origin` → `403`. + Allow(Vec), + /// No browser clients expected; ANY present `Origin` → `403`. Non-browser + /// MCP clients (the launch tier) send no `Origin` and pass. The remote + /// default. + DenyBrowsers, + /// Explicit opt-out (loopback dev / trusted network) — never the remote + /// default. + Unchecked, +} + +/// Host + Origin posture, derived together from the deployment. The struct has +/// no skip-by-absence state, and [`from_bind`](Self::from_bind) is the only +/// constructor, so a fail-open policy is unrepresentable. +#[derive(Debug, Clone)] +pub struct McpHostPolicy { + /// `None` ⇒ accept any `Host` (DNS-rebinding defense relaxed for a known + /// public bind; bearer is the real control there). + pub allowed_hosts: Option>, + /// Total — no `Option`. + pub origin: OriginPolicy, +} + +impl McpHostPolicy { + /// The only constructor. Host and Origin posture are derived together from + /// the bind + config, **fail-closed**: a remote bind with no configured + /// origins is `DenyBrowsers` (a present `Origin` is rejected), NOT "skip". + pub fn from_bind(bind: &SocketAddr, public_hosts: &[String], browser_origins: &[String]) -> Self { + let loopback = bind.ip().is_loopback(); + Self { + allowed_hosts: if loopback { + // A loopback bind accepts every loopback Host form, not just the + // stack it bound: the Host header is independent of the socket + // (in-process tests, reverse proxies, dual-stack `localhost` + // resolution), so a `127.0.0.1`-bound server must still accept a + // `[::1]` Host and vice-versa. This mirrors rmcp's own default + // loopback set; deriving the list from `bind.ip()` alone dropped + // the sibling-stack literal and 403'd legitimate loopback clients. + Some(vec!["127.0.0.1".into(), "::1".into(), "localhost".into()]) + } else if public_hosts.is_empty() { + None + } else { + Some(public_hosts.to_vec()) + }, + origin: if !browser_origins.is_empty() { + OriginPolicy::Allow(browser_origins.to_vec()) + } else if loopback { + OriginPolicy::Unchecked + } else { + OriginPolicy::DenyBrowsers + }, + } + } +} + +/// Fail-closed Origin enforcement, run BEFORE rmcp so it is independent of +/// rmcp's empty-`allowed_origins` fail-open semantics. A *present* `Origin` that +/// the policy disallows → `403`; an *absent* `Origin` always passes (non-browser +/// MCP clients send none); `Unchecked` is a no-op. +async fn origin_guard(State(origin): State, request: Request, next: Next) -> Response { + let header = request + .headers() + .get(http::header::ORIGIN) + .and_then(|v| v.to_str().ok()); + let allowed = match header { + None => true, + Some(o) => match &origin { + OriginPolicy::Unchecked => true, + OriginPolicy::Allow(list) => list.iter().any(|a| a == o), + OriginPolicy::DenyBrowsers => false, + }, + }; + if allowed { + next.run(request).await + } else { + (http::StatusCode::FORBIDDEN, "Forbidden: Origin not allowed").into_response() + } +} + +/// Build the `/mcp` router for a backend. The returned router carries its own +/// Origin guard and body-limit layer; merge (not `.route`) it into the +/// per-graph group so the body limit does not leak onto sibling routes. +/// +/// Generic over the router state `S`: the `/mcp` route is a `route_service` +/// with no state-bearing extractors, so it composes with any caller's state +/// type (e.g. the server merges it into a `Router` before +/// `.with_state`). A standalone caller pins `S = ()` via the return-type +/// annotation. +pub fn mcp_router(backend: B, body_limit: usize, hosts: McpHostPolicy) -> axum::Router +where + B: McpBackend, + S: Clone + Send + Sync + 'static, +{ + // `StreamableHttpServerConfig` is `#[non_exhaustive]`; its Default is + // stateful_mode=true, json_response=false, allowed_hosts=loopback. Build + // from Default and flip via the with_* setters for a remote stateless JSON + // server. + let mut config = StreamableHttpServerConfig::default() + .with_stateful_mode(false) + .with_json_response(true); + config = match &hosts.allowed_hosts { + Some(list) => config.with_allowed_hosts(list.clone()), + None => config.disable_allowed_hosts(), + }; + // `Allow` also configures rmcp as defense-in-depth; `DenyBrowsers` cannot be + // expressed to rmcp (empty list ⇒ rmcp skips), so `origin_guard` is the + // fail-closed authority. + if let OriginPolicy::Allow(origins) = &hosts.origin { + config = config.with_allowed_origins(origins.clone()); + } + + // service_factory returns Result; NeverSessionManager pairs + // with stateless mode (rejects every session op). + let svc = StreamableHttpService::new( + move || Ok(McpService::new(backend.clone())), + Arc::new(NeverSessionManager::default()), + config, + ); + + axum::Router::::new() + .route_service("/mcp", svc) + .layer(from_fn_with_state(hosts.origin, origin_guard)) + .layer(tower_http::limit::RequestBodyLimitLayer::new(body_limit)) +} diff --git a/crates/omnigraph-mcp/tests/standalone.rs b/crates/omnigraph-mcp/tests/standalone.rs new file mode 100644 index 0000000..af53449 --- /dev/null +++ b/crates/omnigraph-mcp/tests/standalone.rs @@ -0,0 +1,247 @@ +//! The `omnigraph-mcp` crate stands alone: a trivial `McpBackend` drives the +//! real transport with no omnigraph dependency. Also the **rmcp surface guard** +//! — the first smoke check on any rmcp version bump (mirrors the engine's +//! `lance_surface_guards.rs`). + +use std::sync::Arc; + +use async_trait::async_trait; +use axum::body::{Body, to_bytes}; +use axum::http::{Request, StatusCode, header}; +use omnigraph_mcp::{ + CallToolResult, Content, Implementation, McpBackend, McpError, McpHostPolicy, OriginPolicy, + ReadResourceResult, Resource, ServerCapabilities, ServerInfo, Tool, mcp_router, +}; +use serde_json::{Value, json}; +use tower::ServiceExt; + +#[derive(Clone)] +struct Dummy; + +#[async_trait] +impl McpBackend for Dummy { + fn server_info(&self) -> ServerInfo { + ServerInfo::new( + ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + ) + .with_server_info(Implementation::new("omnigraph-test", "0.0.0")) + } + + async fn list_tools(&self, _parts: &http::request::Parts) -> Result, McpError> { + let schema = json!({ "type": "object", "properties": {}, "additionalProperties": false }); + let schema = schema.as_object().unwrap().clone(); + Ok(vec![Tool::new( + "graph_health", + "Liveness probe.", + Arc::new(schema), + )]) + } + + async fn call_tool( + &self, + _parts: &http::request::Parts, + name: &str, + _args: omnigraph_mcp::JsonObject, + ) -> Result { + Ok(CallToolResult::success(vec![Content::text(format!( + "called {name}" + ))])) + } + + async fn list_resources( + &self, + _parts: &http::request::Parts, + ) -> Result, McpError> { + Ok(vec![]) + } + + async fn read_resource( + &self, + _parts: &http::request::Parts, + _uri: &str, + ) -> Result { + Err(McpError::invalid_params("no resources", None)) + } +} + +fn loopback_router() -> axum::Router { + let policy = McpHostPolicy::from_bind(&"127.0.0.1:0".parse().unwrap(), &[], &[]); + mcp_router(Dummy, 1 << 20, policy) +} + +fn mcp_post(body: Value) -> Request { + Request::builder() + .method("POST") + .uri("/mcp") + .header(header::HOST, "localhost") + .header(header::CONTENT_TYPE, "application/json") + .header(header::ACCEPT, "application/json, text/event-stream") + .body(Body::from(serde_json::to_vec(&body).unwrap())) + .unwrap() +} + +async fn json_body(resp: axum::response::Response) -> Value { + let bytes = to_bytes(resp.into_body(), usize::MAX).await.unwrap(); + serde_json::from_slice(&bytes).unwrap() +} + +#[tokio::test] +async fn initialize_advertises_tools_and_resources() { + let resp = loopback_router() + .oneshot(mcp_post(json!({ + "jsonrpc": "2.0", + "id": 1, + "method": "initialize", + "params": { + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": { "name": "test", "version": "0" } + } + }))) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let v = json_body(resp).await; + assert_eq!(v["result"]["serverInfo"]["name"], "omnigraph-test"); + assert!(v["result"]["capabilities"]["tools"].is_object()); + assert!(v["result"]["capabilities"]["resources"].is_object()); +} + +#[tokio::test] +async fn tools_list_returns_full_set_with_no_next_cursor() { + let resp = loopback_router() + .oneshot(mcp_post(json!({ + "jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {} + }))) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::OK); + let v = json_body(resp).await; + let tools = v["result"]["tools"].as_array().unwrap(); + assert_eq!(tools.len(), 1); + assert_eq!(tools[0]["name"], "graph_health"); + // Non-paginated by contract: nextCursor is absent (None ⇒ omitted). + assert!(v["result"]["nextCursor"].is_null()); +} + +#[tokio::test] +async fn get_is_method_not_allowed() { + let resp = loopback_router() + .oneshot( + Request::builder() + .method("GET") + .uri("/mcp") + .header(header::HOST, "localhost") + .body(Body::empty()) + .unwrap(), + ) + .await + .unwrap(); + assert_eq!(resp.status(), StatusCode::METHOD_NOT_ALLOWED); + assert_eq!(resp.headers()[header::ALLOW], "POST"); +} + +#[tokio::test] +async fn deny_browsers_rejects_present_origin_but_allows_absent() { + // A remote-shaped policy: any present Origin is forbidden; absent passes. + let policy = McpHostPolicy { + allowed_hosts: None, + origin: OriginPolicy::DenyBrowsers, + }; + let router: axum::Router = mcp_router(Dummy, 1 << 20, policy); + + let init = json!({ + "jsonrpc": "2.0", "id": 1, "method": "initialize", + "params": { "protocolVersion": "2025-11-25", "capabilities": {}, + "clientInfo": { "name": "t", "version": "0" } } + }); + + // Present, disallowed Origin → 403 (origin_guard, not rmcp's empty-list path). + let mut with_origin = mcp_post(init.clone()); + with_origin + .headers_mut() + .insert(header::ORIGIN, "https://evil.example".parse().unwrap()); + let resp = router.clone().oneshot(with_origin).await.unwrap(); + assert_eq!(resp.status(), StatusCode::FORBIDDEN); + + // Absent Origin → 200 (non-browser MCP clients send none). + let resp = router.oneshot(mcp_post(init)).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK); +} + +#[tokio::test] +async fn loopback_bind_allows_all_loopback_host_forms() { + // A loopback bind must accept *every* loopback Host form — 127.0.0.1, [::1], + // and localhost — regardless of which stack it bound. The Host header is + // independent of the socket (in-process, reverse proxies, dual-stack + // `localhost`), so a `127.0.0.1`-bound server must still accept a `[::1]` + // Host. (Matches rmcp's default loopback set; deriving the list from the + // bound IP alone dropped the sibling-stack literal and 403'd the client.) + for bind in ["127.0.0.1:8080", "[::1]:8080"] { + let policy = McpHostPolicy::from_bind(&bind.parse().unwrap(), &[], &[]); + let hosts = policy.allowed_hosts.clone().unwrap(); + for expected in ["127.0.0.1", "::1", "localhost"] { + assert!( + hosts.iter().any(|h| h == expected), + "bind {bind}: loopback allowlist missing {expected}: {hosts:?}" + ); + } + // e2e: the IPv6 loopback Host is accepted even on the IPv4 bind. + let router: axum::Router = mcp_router(Dummy, 1 << 20, policy); + let mut req = mcp_post(json!({ + "jsonrpc": "2.0", "id": 1, "method": "initialize", + "params": { "protocolVersion": "2025-11-25", "capabilities": {}, + "clientInfo": { "name": "t", "version": "0" } } + })); + req.headers_mut().insert(header::HOST, "[::1]:8080".parse().unwrap()); + let resp = router.oneshot(req).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK, "bind {bind}: IPv6 loopback Host rejected"); + } +} + +#[tokio::test] +async fn unsupported_protocol_version_header_is_400_except_on_initialize() { + // rmcp validates `MCP-Protocol-Version` on non-`initialize` requests only: + // `initialize` negotiates the version in its JSON-RPC body, so a bogus header + // there is ignored (200), while the same header on a follow-up request is a + // 400. Pins the real contract (the transport doc-comment notes this). + // `initialize` + bogus version header → 200 (header not validated on init). + let mut init = mcp_post(json!({ + "jsonrpc": "2.0", "id": 1, "method": "initialize", + "params": { "protocolVersion": "2025-11-25", "capabilities": {}, + "clientInfo": { "name": "t", "version": "0" } } + })); + init.headers_mut() + .insert("mcp-protocol-version", "1900-01-01".parse().unwrap()); + let resp = loopback_router().oneshot(init).await.unwrap(); + assert_eq!(resp.status(), StatusCode::OK, "initialize must not validate the version header"); + + // A follow-up request (`tools/list`) + bogus version header → 400. + let mut list = mcp_post(json!({ "jsonrpc": "2.0", "id": 2, "method": "tools/list", "params": {} })); + list.headers_mut() + .insert("mcp-protocol-version", "1900-01-01".parse().unwrap()); + let resp = loopback_router().oneshot(list).await.unwrap(); + assert_eq!(resp.status(), StatusCode::BAD_REQUEST, "non-init bogus version must be 400"); +} + +/// rmcp surface guard — pins the API shapes the transport relies on. Turns red +/// (compile error) on an rmcp bump that renames/moves any of these. Compile-only. +#[allow(dead_code)] +fn _rmcp_surface_guard() { + use rmcp::transport::streamable_http_server::{ + StreamableHttpServerConfig, session::never::NeverSessionManager, + }; + let _config = StreamableHttpServerConfig::default() + .with_stateful_mode(false) + .with_json_response(true) + .disable_allowed_hosts() + .with_allowed_origins(["https://app.example".to_string()]); + let _session = NeverSessionManager::default(); + // The Parts passthrough: rmcp's RequestContext extensions hold the HTTP parts. + fn _reads_parts(ctx: &rmcp::service::RequestContext) { + let _parts: Option<&http::request::Parts> = ctx.extensions.get::(); + } +} diff --git a/crates/omnigraph-policy/src/lib.rs b/crates/omnigraph-policy/src/lib.rs index 46b380a..3d834b2 100644 --- a/crates/omnigraph-policy/src/lib.rs +++ b/crates/omnigraph-policy/src/lib.rs @@ -6,7 +6,7 @@ use std::str::FromStr; use cedar_policy::{ Authorizer, Context, Decision, Entities, Entity, EntityId, EntityTypeName, EntityUid, Policy, - PolicyId, PolicySet, Request, Schema, ValidationMode, Validator, + PolicyId, PolicySet, Request, Response, Schema, ValidationMode, Validator, }; use clap::ValueEnum; use color_eyre::eyre::{Result, bail, eyre}; @@ -506,6 +506,80 @@ impl PolicyEngine { /// the "server-authoritative actor identity" invariant — clients /// supplying a `PolicyRequest` cannot smuggle identity through the /// same struct that carries the requested action. + /// Evaluate one Cedar request for `actor_id` performing `action` in the + /// given branch `context`, returning the raw decision response. The single + /// Cedar entry point shared by [`Self::authorize`] (per-call gate) and + /// [`Self::permits_on_any_branch`] (list-time capability probe) — there is + /// no second matching implementation that could drift from the policy set. + fn evaluate(&self, actor_id: &str, action: PolicyAction, context: serde_json::Value) -> Result { + let principal = entity_uid("Actor", actor_id)?; + let action_uid = entity_uid("Action", action.as_str())?; + // Pick the resource entity based on the action's `resource_kind`. + // Server-scoped actions (`graph_list`) bind to + // `Omnigraph::Server::"root"`; per-graph actions bind to + // `Omnigraph::Graph::""`. + let resource = match action.resource_kind() { + PolicyResourceKind::Server => entity_uid("Server", SERVER_RESOURCE_ID)?, + PolicyResourceKind::Graph => entity_uid("Graph", &self.graph_id)?, + }; + let context = Context::from_json_value(context, Some((&self.schema, &action_uid)))?; + let cedar_request = Request::new(principal, action_uid, resource, context, Some(&self.schema))?; + let response = + Authorizer::new().is_authorized(&cedar_request, &self.policies, &self.entities); + let errors = response + .diagnostics() + .errors() + .map(|err| err.to_string()) + .collect::>(); + if !errors.is_empty() { + bail!("policy evaluation failed:\n{}", errors.join("\n")); + } + Ok(response) + } + + /// List-time capability probe: could `action` be permitted for `actor_id` on + /// **any** branch the per-call gate might be invoked with? Enumerates the + /// branch-shape space (omitted / protected / unprotected, on whichever of + /// `branch`/`target_branch` the action scopes on) through [`Self::evaluate`] + /// and returns true if any is allowed. + /// + /// This makes `tools/list` a faithful *relaxation* of the per-call gate: it + /// never hides a tool the caller could invoke on some branch (over-showing + /// is safe — the per-call gate is authoritative), while still hiding a tool + /// the actor has no grant for. It deliberately does not fabricate a single + /// branch name (which under a "write unprotected branches" policy answers + /// the wrong question — a `change` request with no branch, or on protected + /// `main`, is denied, yet the actor can write feature branches). + pub fn permits_on_any_branch(&self, actor_id: &str, action: PolicyAction) -> Result { + if !self.known_actors.contains(actor_id) { + return Ok(false); + } + // The compiled branch/target scope conditions depend only on + // (`has_*`, `*_is_protected`), so these shapes span every distinguishable + // request. `branch`/`target_branch` string values are unread by any rule. + let shapes: Vec = if action.uses_branch_scope() { + vec![ + branch_context(false, "", false, false, "", false), + branch_context(true, "p", true, false, "", false), + branch_context(true, "u", false, false, "", false), + ] + } else if action.uses_target_branch_scope() { + vec![ + branch_context(false, "", false, true, "p", true), + branch_context(false, "", false, true, "u", false), + ] + } else { + // Graph-scoped action (no branch dimension): one evaluation suffices. + vec![branch_context(false, "", false, false, "", false)] + }; + for context in shapes { + if matches!(self.evaluate(actor_id, action, context)?.decision(), Decision::Allow) { + return Ok(true); + } + } + Ok(false) + } + pub fn authorize(&self, actor_id: &str, request: &PolicyRequest) -> Result { if !self.known_actors.contains(actor_id) { return Ok(self.deny( @@ -517,36 +591,15 @@ impl PolicyEngine { )); } - let principal = entity_uid("Actor", actor_id)?; - let action = entity_uid("Action", request.action.as_str())?; - // Pick the resource entity based on the action's `resource_kind`. - // Server-scoped actions (`graph_list`) bind to - // `Omnigraph::Server::"root"`; per-graph actions bind to - // `Omnigraph::Graph::""`. - let resource = match request.action.resource_kind() { - PolicyResourceKind::Server => entity_uid("Server", SERVER_RESOURCE_ID)?, - PolicyResourceKind::Graph => entity_uid("Graph", &self.graph_id)?, - }; - let context_value = json!({ - "has_branch": request.branch.is_some(), - "branch": request.branch.clone().unwrap_or_default(), - "has_target_branch": request.target_branch.is_some(), - "target_branch": request.target_branch.clone().unwrap_or_default(), - "branch_is_protected": request.branch.as_ref().is_some_and(|branch| self.protected_branches.contains(branch)), - "target_branch_is_protected": request.target_branch.as_ref().is_some_and(|branch| self.protected_branches.contains(branch)), - }); - let context = Context::from_json_value(context_value, Some((&self.schema, &action)))?; - let cedar_request = Request::new(principal, action, resource, context, Some(&self.schema))?; - let response = - Authorizer::new().is_authorized(&cedar_request, &self.policies, &self.entities); - let errors = response - .diagnostics() - .errors() - .map(|err| err.to_string()) - .collect::>(); - if !errors.is_empty() { - bail!("policy evaluation failed:\n{}", errors.join("\n")); - } + let context = branch_context( + request.branch.is_some(), + request.branch.as_deref().unwrap_or_default(), + request.branch.as_ref().is_some_and(|branch| self.protected_branches.contains(branch)), + request.target_branch.is_some(), + request.target_branch.as_deref().unwrap_or_default(), + request.target_branch.as_ref().is_some_and(|branch| self.protected_branches.contains(branch)), + ); + let response = self.evaluate(actor_id, request.action, context)?; let matched_rule_id = response .diagnostics() @@ -790,6 +843,28 @@ fn compile_policy_source(rule: &PolicyRule, action: &PolicyAction, graph_id: &st ) } +/// Build the Cedar request context from the branch/target dimensions. The +/// single shape used by both the per-call gate ([`PolicyEngine::authorize`]) +/// and the list-time capability probe ([`PolicyEngine::permits_on_any_branch`]), +/// so both feed the policy set an identically-structured context. +fn branch_context( + has_branch: bool, + branch: &str, + branch_is_protected: bool, + has_target_branch: bool, + target_branch: &str, + target_branch_is_protected: bool, +) -> serde_json::Value { + json!({ + "has_branch": has_branch, + "branch": branch, + "has_target_branch": has_target_branch, + "target_branch": target_branch, + "branch_is_protected": branch_is_protected, + "target_branch_is_protected": target_branch_is_protected, + }) +} + fn branch_scope_condition(scope: PolicyBranchScope) -> String { match scope { PolicyBranchScope::Any => "true".to_string(), @@ -1200,6 +1275,65 @@ rules: assert_eq!(admin.matched_rule_id.as_deref(), Some("admins-promote")); } + #[test] + fn permits_on_any_branch_reports_capability_not_a_fabricated_branch() { + // The canonical "protected main, write unprotected branches" policy. + let policy: PolicyConfig = serde_yaml::from_str( + r#" +version: 1 +groups: + writers: [act-writer] + readers: [act-reader] +protected_branches: [main] +rules: + - id: writers-change-unprotected + allow: + actors: { group: writers } + actions: [change] + branch_scope: unprotected + - id: readers-read + allow: + actors: { group: readers } + actions: [read] + branch_scope: any +"#, + ) + .unwrap(); + let engine = PolicyCompiler::compile(&policy, "graph").unwrap(); + + // The writer can `change` *some* branch (unprotected), so the list-time + // capability probe is true — even though the per-call gate denies on + // protected `main` and on a branchless request. This is exactly the + // divergence a fabricated-`main` probe got wrong. + assert!(engine.permits_on_any_branch("act-writer", PolicyAction::Change).unwrap()); + assert!( + !engine + .authorize( + "act-writer", + &PolicyRequest { action: PolicyAction::Change, branch: Some("main".to_string()), target_branch: None }, + ) + .unwrap() + .allowed + ); + assert!( + engine + .authorize( + "act-writer", + &PolicyRequest { action: PolicyAction::Change, branch: Some("feature".to_string()), target_branch: None }, + ) + .unwrap() + .allowed + ); + + // A reader has no `change` grant on any branch → capability is false, + // so the relaxation still hides write tools from read-only actors. + assert!(!engine.permits_on_any_branch("act-reader", PolicyAction::Change).unwrap()); + assert!(engine.permits_on_any_branch("act-reader", PolicyAction::Read).unwrap()); + + // Unknown actor → false (never panics, never allows). + assert!(!engine.permits_on_any_branch("act-ghost", PolicyAction::Read).unwrap()); + } + #[test] fn policy_tests_enforce_expected_outcomes() { let policy: PolicyConfig = serde_yaml::from_str( diff --git a/crates/omnigraph-server/Cargo.toml b/crates/omnigraph-server/Cargo.toml index a6a0717..aa48875 100644 --- a/crates/omnigraph-server/Cargo.toml +++ b/crates/omnigraph-server/Cargo.toml @@ -24,7 +24,11 @@ omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.7.0" } omnigraph-policy = { path = "../omnigraph-policy", version = "0.7.0" } omnigraph-api-types = { path = "../omnigraph-api-types", version = "0.7.0" } omnigraph-cluster = { path = "../omnigraph-cluster", version = "0.7.0" } +# The MCP surface. rmcp is contained to omnigraph-mcp — the server carries NO +# direct rmcp dependency (verify: `cargo tree -p omnigraph-server -e normal | grep rmcp`). +omnigraph-mcp = { path = "../omnigraph-mcp", version = "0.7.0" } axum = { workspace = true } +http = "1" clap = { workspace = true } color-eyre = { workspace = true } serde = { workspace = true } diff --git a/crates/omnigraph-server/src/handlers.rs b/crates/omnigraph-server/src/handlers.rs index 7de38d2..c2c0263 100644 --- a/crates/omnigraph-server/src/handlers.rs +++ b/crates/omnigraph-server/src/handlers.rs @@ -426,6 +426,37 @@ pub(crate) fn authorize_request( } } +/// List-time capability probe: could `action` be permitted on *some* branch? +/// Mirrors [`authorize`]'s no-policy handling (open mode allows per-graph +/// actions; default-deny allows only `Read`; server-scoped actions are closed), +/// and otherwise delegates to [`PolicyEngine::permits_on_any_branch`]. Used to +/// filter argument-scoped tools in `tools/list` as a relaxation of the per-call +/// gate — so a tool callable on some branch is never hidden, while one the +/// actor has no grant for stays hidden. +pub(crate) fn authorize_any_branch( + actor: Option<&ResolvedActor>, + policy: Option<&PolicyEngine>, + action: PolicyAction, +) -> std::result::Result { + let Some(engine) = policy else { + if action.resource_kind() == PolicyResourceKind::Server { + return Ok(false); + } + // Default-deny mode (tokens configured, no policy): only Read; Open mode + // (no tokens): all per-graph actions. Matches `authorize` exactly. + if actor.is_some() && action != PolicyAction::Read { + return Ok(false); + } + return Ok(true); + }; + let Some(actor) = actor else { + return Err(ApiError::unauthorized("missing bearer token")); + }; + engine + .permits_on_any_branch(actor.actor_id.as_ref(), action) + .map_err(|err| ApiError::internal(format!("policy: {err}"))) +} + #[utoipa::path( get, path = "/snapshot", @@ -1216,7 +1247,7 @@ pub(crate) async fn server_schema_apply( /// Shared body for `POST /load` (canonical) and `POST /ingest` (deprecated): /// branch-exists / fork-if-`from` check, Cedar authorization, admission, the /// bulk `load_as`, and the `IngestOutput` mapping. -async fn run_ingest( +pub(crate) async fn run_ingest( state: AppState, handle: Arc, actor: Option<&ResolvedActor>, diff --git a/crates/omnigraph-server/src/lib.rs b/crates/omnigraph-server/src/lib.rs index 5451b05..80b24e6 100644 --- a/crates/omnigraph-server/src/lib.rs +++ b/crates/omnigraph-server/src/lib.rs @@ -1,5 +1,6 @@ pub mod api; mod handlers; +mod mcp; mod settings; pub use settings::{load_server_settings, classify_server_runtime_state, ServerRuntimeState}; use settings::*; @@ -257,6 +258,18 @@ pub struct AppState { /// resource. Loaded from the cluster-scoped policy binding when /// configured. Per-graph policies live on each `GraphHandle.policy`. server_policy: Option>, + /// MCP host/Origin policy inputs. Default (`None` bind + empty lists) + /// yields a loopback-safe `Unchecked` policy — correct for in-process + /// tests that never bind a socket. `serve()` overrides `mcp_bind` from + /// `listener.local_addr()` so a public bind is fail-closed + /// (`DenyBrowsers`), not silently `Unchecked` (the silent-fail-open + /// guard — see `omnigraph_mcp::McpHostPolicy::from_bind`). `public_hosts` + /// / `browser_origins` are reserved for future cluster/CLI config (empty + /// today: a public bind disables Host-allowlisting and rejects browser + /// Origins until configured). + mcp_bind: Option, + mcp_public_hosts: Vec, + mcp_browser_origins: Vec, } struct ExportStreamWriter { @@ -531,6 +544,9 @@ impl AppState { workload, bearer_tokens, server_policy: None, + mcp_bind: None, + mcp_public_hosts: Vec::new(), + mcp_browser_origins: Vec::new(), } } @@ -557,6 +573,9 @@ impl AppState { workload: Arc::new(workload), bearer_tokens, server_policy: server_policy.map(Arc::new), + mcp_bind: None, + mcp_public_hosts: Vec::new(), + mcp_browser_origins: Vec::new(), }) } @@ -567,6 +586,34 @@ impl AppState { &self.routing } + /// Install the MCP host/Origin policy inputs from the bound socket. + /// `serve()` calls this after `TcpListener::bind` (reading + /// `local_addr()` — the authoritative bound address, which resolves + /// `0.0.0.0`/hostname binds) and before `build_app`, so the derived + /// policy is fail-closed on a public bind. Tests that build an app + /// without a socket skip this and get the loopback-safe default. + pub fn with_mcp_host_inputs( + mut self, + bind: std::net::SocketAddr, + public_hosts: Vec, + browser_origins: Vec, + ) -> Self { + self.mcp_bind = Some(bind); + self.mcp_public_hosts = public_hosts; + self.mcp_browser_origins = browser_origins; + self + } + + /// Derive the MCP host/Origin policy from the stored inputs through the + /// single fail-closed constructor. A `None` bind defaults to loopback + /// (`Unchecked`), correct for in-process tests. + pub(crate) fn mcp_host_policy(&self) -> omnigraph_mcp::McpHostPolicy { + let bind = self + .mcp_bind + .unwrap_or_else(|| std::net::SocketAddr::from(([127, 0, 0, 1], 0))); + omnigraph_mcp::McpHostPolicy::from_bind(&bind, &self.mcp_public_hosts, &self.mcp_browser_origins) + } + fn requires_bearer_auth(&self) -> bool { if !self.bearer_tokens.is_empty() { return true; @@ -605,6 +652,20 @@ fn hash_bearer_tokens(bearer_tokens: Vec<(String, String)>) -> Arc<[(BearerToken } impl ApiError { + /// HTTP status this error maps to — identical to what `IntoResponse` + /// emits (`self.status`). Used by the MCP `classify` mapper to split + /// semantic 4xx (→ `isError` tool result) from operational 5xx + /// (→ JSON-RPC protocol error). + pub(crate) fn status_code(&self) -> StatusCode { + self.status + } + + /// The human-readable message — identical to the `error` field + /// `IntoResponse` puts in the body (`self.message`). + pub(crate) fn message_str(&self) -> &str { + &self.message + } + pub fn unauthorized(message: impl Into) -> Self { Self { status: StatusCode::UNAUTHORIZED, @@ -926,6 +987,12 @@ pub fn build_app(state: AppState) -> Router { .route("/branches/merge", post(server_branch_merge)) .route("/commits", get(server_commit_list)) .route("/commits/{commit_id}", get(server_commit_show)) + // The MCP surface → POST /graphs/{graph_id}/mcp. Merged (not `.route`) + // so its own tower-http body-limit + Origin-guard layers stay scoped to + // /mcp and don't leak onto the REST routes. The two route_layers below + // (bearer + handle) wrap it, so rmcp sees a request whose extensions + // already carry ResolvedActor + Arc. + .merge(mcp::mcp_router(state.clone())) .route_layer(middleware::from_fn_with_state( state.clone(), resolve_graph_handle, @@ -1018,6 +1085,15 @@ pub async fn serve(config: ServerConfig) -> Result<()> { }; let listener = TcpListener::bind(&bind).await?; + // Derive the MCP host/Origin policy from the ACTUAL bound address (not the + // configured `bind` string — `0.0.0.0`/hostname binds resolve only after + // bind). A public bind ⇒ fail-closed `DenyBrowsers`; loopback ⇒ `Unchecked`. + // `public_hosts`/`browser_origins` are empty until cluster/CLI config wires + // them (a public bind then disables Host-allowlisting, with bearer the + // control). Missing this reorder would silently leave a public bind on the + // loopback default — the fail-open class `McpHostPolicy` exists to close. + let local_addr = listener.local_addr()?; + let state = state.with_mcp_host_inputs(local_addr, Vec::new(), Vec::new()); axum::serve(listener, build_app(state)) .with_graceful_shutdown(shutdown_signal()) .await?; diff --git a/crates/omnigraph-server/src/mcp.rs b/crates/omnigraph-server/src/mcp.rs new file mode 100644 index 0000000..1e94357 --- /dev/null +++ b/crates/omnigraph-server/src/mcp.rs @@ -0,0 +1,1101 @@ +//! The server's MCP backend: projects built-in operations and the per-graph +//! stored-query registry as MCP tools + resources, each Cedar-gated by the same +//! `authorize` path the REST routes use. Adds **no business logic** — every +//! tool delegates to the same engine/handler functions the routes call. +//! +//! The `rmcp` transport lives in `omnigraph-mcp`; this module fills its +//! `McpBackend` seam. It reads the request's resolved actor + graph handle out +//! of `parts.extensions` (injected by `require_bearer_auth` / +//! `resolve_graph_handle` before rmcp sees the request). + +use std::sync::Arc; + +use async_trait::async_trait; +use serde_json::{Value, json}; + +use omnigraph::db::{ReadTarget, SchemaApplyOptions}; +use omnigraph_mcp::{ + CallToolResult, Content, Implementation, JsonObject, McpBackend, McpError, RawResource, + ReadResourceResult, Resource, ResourceContents, ServerCapabilities, ServerInfo, Tool, + ToolAnnotations, +}; + +use crate::handlers::{ + Authz, authorize, authorize_any_branch, authorize_request, run_ingest, run_mutate, run_query, +}; +use crate::queries::{QueryRegistry, StoredQuery}; +use crate::{ + ApiError, AppState, GraphHandle, INGEST_REQUEST_BODY_LIMIT_BYTES, PolicyAction, PolicyRequest, + ResolvedActor, SERVER_VERSION, api, validate_registry_against_catalog, +}; + +const SCHEMA_URI: &str = "omnigraph://schema"; +const BRANCHES_URI: &str = "omnigraph://branches"; + +/// `auto` projection flips from one-tool-per-query to the discovery+execute +/// meta pair at or above this many exposed queries (model accuracy degrades as +/// a single client's tool count climbs past a few dozen). +const STORED_QUERY_AUTO_THRESHOLD: usize = 24; +const STORED_QUERY_LIST_TOOL: &str = "stored_query_list"; +const STORED_QUERY_RUN_TOOL: &str = "stored_query_run"; + +/// The closed set of built-in tool names — reserved graph-wide. Folded into the +/// stored-query registry's uniqueness check (`QueryRegistry::from_specs`) so a +/// stored query that shadows a built-in fails loudly at load instead of being +/// silently un-served. Kept in sync with `Builtin::ALL` by +/// `builtin_tool_names_match_enum`. +pub(crate) const BUILTIN_TOOL_NAMES: &[&str] = &[ + "graph_health", + "graph_query", + "graph_snapshot", + "schema_get", + "branch_list", + "commit_list", + "commit_get", + "graph_mutate", + "graph_load", + "branch_create", + "branch_delete", + "branch_merge", + "schema_apply", +]; + +/// The server's thin wrapper over `omnigraph_mcp::mcp_router`: derives the +/// fail-closed host policy from the bound socket and passes the 32 MiB body +/// limit (`/load` parity). Merged into `per_graph_protected` in `build_app`. +pub(crate) fn mcp_router(state: AppState) -> axum::Router { + let host_policy = state.mcp_host_policy(); + omnigraph_mcp::mcp_router::<_, AppState>( + OmnigraphMcpBackend::new(state), + INGEST_REQUEST_BODY_LIMIT_BYTES, + host_policy, + ) +} + +#[derive(Clone)] +pub(crate) struct OmnigraphMcpBackend { + state: AppState, +} + +impl OmnigraphMcpBackend { + pub(crate) fn new(state: AppState) -> Self { + Self { state } + } + + /// Pull the resolved actor (absent in `--unauthenticated` mode) and the + /// graph handle out of the request extensions. The handle is always present + /// on a `/graphs/{id}/mcp` route (injected by `resolve_graph_handle`); its + /// absence is an internal wiring error. + fn ctx<'a>( + &self, + parts: &'a http::request::Parts, + ) -> Result<(Option<&'a ResolvedActor>, &'a Arc), McpError> { + let actor = parts.extensions.get::(); + let handle = parts + .extensions + .get::>() + .ok_or_else(|| McpError::internal_error("graph handle missing from request extensions", None))?; + Ok((actor, handle)) + } +} + +#[async_trait] +impl McpBackend for OmnigraphMcpBackend { + fn server_info(&self) -> ServerInfo { + ServerInfo::new( + ServerCapabilities::builder() + .enable_tools() + .enable_resources() + .build(), + ) + .with_server_info(Implementation::new("omnigraph", SERVER_VERSION)) + .with_instructions( + "Omnigraph graph database. Every tool operates on a single graph — the one in \ + the URL path — so the graph id never appears in tool arguments or output.", + ) + } + + async fn list_tools(&self, parts: &http::request::Parts) -> Result, McpError> { + let (actor, handle) = self.ctx(parts)?; + let mut tools = Vec::new(); + for builtin in Builtin::ALL { + // Visibility is a *relaxation* of the per-call gate: a built-in whose + // authorization depends on a caller-chosen branch is shown iff the + // actor could invoke it on *some* branch (`authorize_any_branch`), + // never hidden because a fabricated branch happened to be denied. The + // per-call gate inside `call` stays authoritative. + let visible = match builtin.list_action() { + None => true, + Some(action) => { + authorize_any_branch(actor, handle.policy.as_deref(), action).map_err(api_to_mcp)? + } + }; + if visible { + tools.push(builtin.descriptor()); + } + } + // Stored queries are graph-scoped behind one coarse `invoke_query` gate + // (the catalog can't be probed without it). Projection mode is chosen + // from this graph's exposed-query count. + if let Some(registry) = handle.queries.as_deref() { + let can_invoke = allowed(actor, handle, invoke_query_request())?; + if can_invoke { + match stored_mode(registry) { + StoredMode::PerQuery => { + for query in registry.exposed() { + tools.push(stored_query_tool(query)); + } + } + StoredMode::Meta => { + tools.push(stored_query_list_descriptor()); + tools.push(stored_query_run_descriptor()); + } + } + } + } + Ok(tools) + } + + async fn call_tool( + &self, + parts: &http::request::Parts, + name: &str, + args: JsonObject, + ) -> Result { + let (actor, handle) = self.ctx(parts)?; + + // 1. Built-in. Its names are a fixed public set, so a denial surfaces as + // `isError` via `classify` (no masking needed). + if let Some(builtin) = Builtin::by_name(name) { + return classify(builtin.call(&self.state, actor, handle, args).await); + } + + // 2 & 3. Stored-query paths (the meta pair and per-query tools) share the + // coarse `invoke_query` outer gate, deny-masked as an unknown tool so + // the stored-query catalog (whose names reveal business logic) can't + // be probed. The inner Read/Change gate runs in run_query/run_mutate. + // `resolve_stored_tool` is the single membership test: a name is + // callable iff `tools/list` would have advertised it (same mode + + // `expose` filter), so call and list cannot diverge — the meta pair is + // callable only in `Meta` mode and per-query tools only in `PerQuery`. + if let Some(dispatch) = handle + .queries + .as_deref() + .and_then(|registry| resolve_stored_tool(registry, name)) + { + match authorize(actor, handle.policy.as_deref(), invoke_query_request()).map_err(api_to_mcp)? { + Authz::Allowed => {} + Authz::Denied(_) => return Err(unknown_tool(name)), + } + return classify(self.dispatch_stored(actor, handle, name, dispatch, args).await); + } + + Err(unknown_tool(name)) + } + + async fn list_resources(&self, parts: &http::request::Parts) -> Result, McpError> { + let (actor, handle) = self.ctx(parts)?; + if allowed(actor, handle, read_request(None))? { + Ok(vec![ + resource(SCHEMA_URI, "schema", "This graph's schema as .pg source."), + resource(BRANCHES_URI, "branches", "Branch names on this graph (JSON)."), + ]) + } else { + Ok(Vec::new()) + } + } + + async fn read_resource( + &self, + parts: &http::request::Parts, + uri: &str, + ) -> Result { + let (actor, handle) = self.ctx(parts)?; + if uri != SCHEMA_URI && uri != BRANCHES_URI { + return Err(McpError::invalid_params(format!("unknown resource: {uri}"), None)); + } + // Read-gated; a denial masks identically to an unknown URI so the + // resource set can't be probed. Operational failures stay 5xx. + match authorize(actor, handle.policy.as_deref(), read_request(None)).map_err(api_to_mcp)? { + Authz::Allowed => {} + Authz::Denied(_) => { + return Err(McpError::invalid_params(format!("unknown resource: {uri}"), None)); + } + } + let contents = if uri == SCHEMA_URI { + ResourceContents::text(handle.engine.schema_source().to_string(), SCHEMA_URI) + } else { + let mut branches = handle + .engine + .branch_list() + .await + .map_err(|e| McpError::internal_error(e.to_string(), None))?; + branches.sort(); + let json = serde_json::to_string(&api::BranchListOutput { branches }) + .map_err(|e| McpError::internal_error(e.to_string(), None))?; + ResourceContents::text(json, BRANCHES_URI) + }; + Ok(ReadResourceResult::new(vec![contents])) + } +} + +impl OmnigraphMcpBackend { + /// Dispatch a stored-query tool call (already resolved by + /// [`resolve_stored_tool`] and past the `invoke_query` gate). + /// `stored_query_list` enumerates the exposed catalog; `stored_query_run` + /// and the per-query tools invoke a query through `run_query`/`run_mutate` + /// (the inner Read/Change gate). Both invoke paths resolve **exposed-only**, + /// so a query hidden from the tool list is unreachable even by name. + async fn dispatch_stored( + &self, + actor: Option<&ResolvedActor>, + handle: &Arc, + name: &str, + dispatch: StoredDispatch, + args: JsonObject, + ) -> Result { + let registry = handle + .queries + .as_deref() + .ok_or_else(|| ApiError::not_found("no stored queries for this graph"))?; + + let stored = match dispatch { + StoredDispatch::List => return stored_query_list(registry, &args), + StoredDispatch::Run => { + let query_name = req_str(&args, "name")?; + registry + .exposed_by_name(&query_name) + .ok_or_else(|| ApiError::not_found("stored query not found"))? + } + // Per-query tool: resolve by effective tool name, exposed-only. + StoredDispatch::PerQuery => registry + .exposed() + .find(|q| q.effective_tool_name() == name) + .ok_or_else(|| ApiError::not_found("stored query not found"))?, + }; + + self.invoke_stored( + actor, + handle, + stored, + opt_obj(&args, "params")?, + opt_str(&args, "branch")?, + opt_str(&args, "snapshot")?, + ) + .await + } + + async fn invoke_stored( + &self, + actor: Option<&ResolvedActor>, + handle: &Arc, + stored: &StoredQuery, + params: Option, + branch: Option, + snapshot: Option, + ) -> Result { + if stored.is_mutation() { + if snapshot.is_some() { + return Err(ApiError::bad_request( + "a mutation cannot target a snapshot", + )); + } + let output = run_mutate( + self.state.clone(), + Arc::clone(handle), + actor, + &stored.source, + Some(&stored.name), + params.as_ref(), + branch.unwrap_or_else(|| "main".to_string()), + ) + .await?; + json_result(&api::InvokeStoredQueryResponse::Change(output)) + } else { + let (selected, target, result) = run_query( + Arc::clone(handle), + actor, + &stored.source, + Some(&stored.name), + params.as_ref(), + branch, + snapshot, + true, + ) + .await?; + json_result(&api::InvokeStoredQueryResponse::Read(api::read_output( + selected, &target, result, + ))) + } + } +} + +// ===== shared helpers ===== + +/// Map a handler `ApiError` into an MCP tool outcome (SEP-1303): a semantic +/// 4xx (bad params, validation, 404/409) becomes an `isError` tool *result* +/// fed back to the model; an operational 5xx becomes a JSON-RPC protocol error. +fn classify(result: Result) -> Result { + match result { + Ok(out) => Ok(out), + Err(e) if e.status_code().is_client_error() => { + Ok(CallToolResult::error(vec![Content::text(e.message_str().to_owned())])) + } + Err(e) => Err(McpError::internal_error(e.message_str().to_owned(), None)), + } +} + +/// Map an *authorization* `ApiError` (operational only — 401/500) onto an MCP +/// protocol error. Used on paths where a denial is consumed as `Authz`. +fn api_to_mcp(e: ApiError) -> McpError { + if e.status_code().is_client_error() { + McpError::invalid_params(e.message_str().to_owned(), None) + } else { + McpError::internal_error(e.message_str().to_owned(), None) + } +} + +fn unknown_tool(name: &str) -> McpError { + McpError::invalid_params(format!("unknown tool: {name}"), None) +} + +/// Argument-independent Cedar check used to filter `tools/list` / +/// `resources/list`. `call_tool` is authoritative; over-showing a branch-scoped +/// grant is the safe direction. +fn allowed( + actor: Option<&ResolvedActor>, + handle: &Arc, + request: PolicyRequest, +) -> Result { + let decision = authorize(actor, handle.policy.as_deref(), request).map_err(api_to_mcp)?; + Ok(matches!(decision, Authz::Allowed)) +} + +fn read_request(branch: Option) -> PolicyRequest { + PolicyRequest { action: PolicyAction::Read, branch, target_branch: None } +} + +fn invoke_query_request() -> PolicyRequest { + // Graph-scoped: no branch dimension. + PolicyRequest { action: PolicyAction::InvokeQuery, branch: None, target_branch: None } +} + +fn actor_id<'a>(actor: Option<&'a ResolvedActor>) -> Option<&'a str> { + actor.map(|a| a.actor_id.as_ref()) +} + +fn actor_arc(actor: Option<&ResolvedActor>) -> Arc { + actor + .map(|a| Arc::clone(&a.actor_id)) + .unwrap_or_else(|| Arc::::from("anonymous")) +} + +/// Serialize a DTO into a tool result with **structured output** — +/// `structuredContent` (typed returns for code-mode runtimes) plus a text +/// mirror for clients that don't parse it. `outputSchema` declaration is a +/// tracked follow-up (R2: utoipa `ToSchema` → JSON-Schema-2020-12 fidelity). +fn json_result(value: &T) -> Result { + let value = serde_json::to_value(value) + .map_err(|e| ApiError::internal(format!("serialize tool result: {e}")))?; + Ok(CallToolResult::structured(value)) +} + +fn opt_str(args: &JsonObject, key: &str) -> Result, ApiError> { + match args.get(key) { + None | Some(Value::Null) => Ok(None), + Some(Value::String(s)) => Ok(Some(s.clone())), + Some(_) => Err(ApiError::bad_request(format!("'{key}' must be a string"))), + } +} + +fn req_str(args: &JsonObject, key: &str) -> Result { + opt_str(args, key)?.ok_or_else(|| ApiError::bad_request(format!("'{key}' is required"))) +} + +fn opt_obj(args: &JsonObject, key: &str) -> Result, ApiError> { + match args.get(key) { + None | Some(Value::Null) => Ok(None), + Some(v @ Value::Object(_)) => Ok(Some(v.clone())), + Some(_) => Err(ApiError::bad_request(format!("'{key}' must be an object"))), + } +} + +fn from_args(args: JsonObject, what: &str) -> Result { + serde_json::from_value(Value::Object(args)) + .map_err(|e| ApiError::bad_request(format!("invalid {what}: {e}"))) +} + +fn read_only_annotations() -> ToolAnnotations { + ToolAnnotations::new().read_only(true).open_world(false) +} + +fn write_annotations(destructive: bool) -> ToolAnnotations { + ToolAnnotations::new() + .read_only(false) + .destructive(destructive) + .open_world(false) +} + +fn resource(uri: &str, name: &str, description: &str) -> Resource { + let raw = RawResource { + uri: uri.to_string(), + name: name.to_string(), + title: None, + description: Some(description.to_string()), + mime_type: Some("text/plain".to_string()), + size: None, + icons: None, + meta: None, + }; + Resource::new(raw, None) +} + +fn schema(value: Value) -> Arc { + Arc::new(value.as_object().cloned().unwrap_or_default()) +} + +// ===== stored-query projection ===== + +enum StoredMode { + PerQuery, + Meta, +} + +fn stored_mode(registry: &QueryRegistry) -> StoredMode { + let exposed = registry.exposed().count(); + if exposed < STORED_QUERY_AUTO_THRESHOLD { + StoredMode::PerQuery + } else { + StoredMode::Meta + } +} + +/// Which stored-query tool a `call_tool` name resolves to. The membership test +/// for the stored surface: it returns `Some` **only** for a name `tools/list` +/// would have advertised under the same projection, so call and list cannot +/// diverge. +enum StoredDispatch { + /// `stored_query_list` — Meta projection only. + List, + /// `stored_query_run` — Meta projection only. + Run, + /// A per-query tool (PerQuery projection only); resolved exposed-only. + PerQuery, +} + +/// Map a `call_tool` name to its stored dispatch, honoring the projection mode +/// and `expose`. In `Meta` mode only the meta pair resolves (per-query names → +/// `None`); in `PerQuery` mode only an exposed per-query tool name resolves (the +/// meta names → `None`). The exact inverse of what [`list_tools`] advertises. +fn resolve_stored_tool(registry: &QueryRegistry, name: &str) -> Option { + match stored_mode(registry) { + StoredMode::Meta => match name { + STORED_QUERY_LIST_TOOL => Some(StoredDispatch::List), + STORED_QUERY_RUN_TOOL => Some(StoredDispatch::Run), + _ => None, + }, + StoredMode::PerQuery => registry + .exposed() + .any(|q| q.effective_tool_name() == name) + .then_some(StoredDispatch::PerQuery), + } +} + +/// Per-query input schema: query params nested under `params` (so a param named +/// `branch`/`snapshot` can't collide with the knobs), plus `branch` and — +/// for reads only — `snapshot`. Each param's schema is the shared +/// `param_json_schema` (locked to the engine coercer). +fn stored_query_input_schema(stored: &StoredQuery) -> Value { + let mut props = serde_json::Map::new(); + let mut required = Vec::new(); + for param in &stored.decl.params { + let descriptor = api::param_descriptor(param); + props.insert(param.name.clone(), api::param_json_schema(&descriptor)); + if !param.nullable { + required.push(Value::String(param.name.clone())); + } + } + let mut params_schema = json!({ + "type": "object", + "properties": Value::Object(props), + "additionalProperties": false + }); + if !required.is_empty() { + params_schema["required"] = Value::Array(required); + } + let mut top = json!({ + "type": "object", + "properties": { "params": params_schema, "branch": { "type": "string" } }, + "additionalProperties": false + }); + if !stored.is_mutation() { + top["properties"]["snapshot"] = json!({ "type": "string" }); + } + top +} + +fn stored_query_tool(stored: &StoredQuery) -> Tool { + let description = stored + .decl + .description + .clone() + .unwrap_or_else(|| format!("Stored query '{}'.", stored.name)); + let annotations = if stored.is_mutation() { + write_annotations(true) + } else { + read_only_annotations() + }; + Tool::new( + stored.effective_tool_name().to_string(), + description, + schema(stored_query_input_schema(stored)), + ) + .with_annotations(annotations) +} + +fn stored_query_list_descriptor() -> Tool { + Tool::new( + STORED_QUERY_LIST_TOOL, + "List this graph's stored queries (names, descriptions, params). Use this to \ + discover queries, then run one with stored_query_run.", + schema(json!({ + "type": "object", + "properties": { + "filter": { "type": "string", "description": "Case-insensitive substring over name/description." }, + "detail_level": { "type": "string", "enum": ["summary", "full"], "description": "`full` includes typed params (default summary)." } + }, + "additionalProperties": false + })), + ) + .with_annotations(read_only_annotations()) +} + +fn stored_query_run_descriptor() -> Tool { + Tool::new( + STORED_QUERY_RUN_TOOL, + "Run a stored query by name. A read returns rows; a mutation applies a write.", + schema(json!({ + "type": "object", + "properties": { + "name": { "type": "string", "description": "Stored query name (from stored_query_list)." }, + "params": { "type": "object", "description": "Query parameters." }, + "branch": { "type": "string", "description": "Branch (default main)." }, + "snapshot": { "type": "string", "description": "Snapshot id (reads only; exclusive with branch)." } + }, + "required": ["name"], + "additionalProperties": false + })), + ) + .with_annotations( + // Mixed read/write population — annotate conservatively as a writer so + // clients prompt for confirmation; the inner gate enforces per query. + write_annotations(true), + ) +} + +fn stored_query_list(registry: &QueryRegistry, args: &JsonObject) -> Result { + let filter = opt_str(args, "filter")?.map(|f| f.to_lowercase()); + let full = matches!(opt_str(args, "detail_level")?.as_deref(), Some("full")); + let mut entries = Vec::new(); + for query in registry.exposed() { + let entry = api::query_catalog_entry(query); + if let Some(filter) = &filter { + let hay = format!( + "{} {}", + entry.name.to_lowercase(), + entry.description.as_deref().unwrap_or("").to_lowercase() + ); + if !hay.contains(filter) { + continue; + } + } + if full { + entries.push(serde_json::to_value(&entry).unwrap_or(Value::Null)); + } else { + entries.push(json!({ + "name": entry.name, + "tool_name": entry.tool_name, + "description": entry.description, + "mutation": entry.mutation, + })); + } + } + json_result(&json!({ "queries": entries })) +} + +// ===== built-in tools ===== + +/// Built-in operational tools. One variant per tool; `descriptor`/`list_gate`/ +/// `call` are match arms (lower liability than a `dyn` zoo). +#[derive(Clone, Copy)] +enum Builtin { + GraphHealth, + GraphQuery, + GraphSnapshot, + SchemaGet, + BranchList, + CommitList, + CommitGet, + GraphMutate, + GraphLoad, + BranchCreate, + BranchDelete, + BranchMerge, + SchemaApply, +} + +impl Builtin { + const ALL: [Builtin; 13] = [ + Builtin::GraphHealth, + Builtin::GraphQuery, + Builtin::GraphSnapshot, + Builtin::SchemaGet, + Builtin::BranchList, + Builtin::CommitList, + Builtin::CommitGet, + Builtin::GraphMutate, + Builtin::GraphLoad, + Builtin::BranchCreate, + Builtin::BranchDelete, + Builtin::BranchMerge, + Builtin::SchemaApply, + ]; + + fn name(self) -> &'static str { + match self { + Builtin::GraphHealth => "graph_health", + Builtin::GraphQuery => "graph_query", + Builtin::GraphSnapshot => "graph_snapshot", + Builtin::SchemaGet => "schema_get", + Builtin::BranchList => "branch_list", + Builtin::CommitList => "commit_list", + Builtin::CommitGet => "commit_get", + Builtin::GraphMutate => "graph_mutate", + Builtin::GraphLoad => "graph_load", + Builtin::BranchCreate => "branch_create", + Builtin::BranchDelete => "branch_delete", + Builtin::BranchMerge => "branch_merge", + Builtin::SchemaApply => "schema_apply", + } + } + + fn by_name(name: &str) -> Option { + Builtin::ALL.into_iter().find(|b| b.name() == name) + } + + /// The action whose grant governs whether this tool is shown in + /// `tools/list`. `None` ⇒ always visible (`graph_health` is liveness, + /// ungated). The branch dimension is intentionally absent: listing probes + /// "can this actor perform `action` on *any* branch?" via + /// `authorize_any_branch`, a relaxation of the per-call gate. A read-only + /// actor (no write grant on any branch) has writers hidden; an actor who can + /// write unprotected branches sees `graph_mutate` even when `main` is + /// protected. The authoritative per-call gate runs inside `call`. + fn list_action(self) -> Option { + match self { + Builtin::GraphHealth => None, + Builtin::GraphQuery + | Builtin::GraphSnapshot + | Builtin::SchemaGet + | Builtin::BranchList + | Builtin::CommitList + | Builtin::CommitGet => Some(PolicyAction::Read), + Builtin::GraphMutate | Builtin::GraphLoad => Some(PolicyAction::Change), + Builtin::BranchCreate => Some(PolicyAction::BranchCreate), + Builtin::BranchDelete => Some(PolicyAction::BranchDelete), + Builtin::BranchMerge => Some(PolicyAction::BranchMerge), + Builtin::SchemaApply => Some(PolicyAction::SchemaApply), + } + } + + fn descriptor(self) -> Tool { + let (description, input, annotations): (&'static str, Value, ToolAnnotations) = match self { + Builtin::GraphHealth => ( + "Liveness/identity probe for this graph. No arguments.", + json!({ "type": "object", "properties": {}, "additionalProperties": false }), + read_only_annotations(), + ), + Builtin::GraphQuery => ( + "Run an ad-hoc read-only GQ query against this graph. Mutations are rejected.", + json!({ + "type": "object", + "properties": { + "query": { "type": "string", "description": "GQ query source." }, + "name": { "type": "string", "description": "Select one query by name when the source defines several." }, + "params": { "type": "object", "description": "Query parameters." }, + "branch": { "type": "string", "description": "Branch to read (default main; exclusive with snapshot)." }, + "snapshot": { "type": "string", "description": "Snapshot id to read (exclusive with branch)." } + }, + "required": ["query"], + "additionalProperties": false + }), + read_only_annotations(), + ), + Builtin::GraphSnapshot => ( + "Read the current snapshot (manifest version + per-table metadata) of a branch.", + json!({ + "type": "object", + "properties": { "branch": { "type": "string", "description": "Branch (default main)." } }, + "additionalProperties": false + }), + read_only_annotations(), + ), + Builtin::SchemaGet => ( + "Get this graph's schema as .pg source.", + json!({ "type": "object", "properties": {}, "additionalProperties": false }), + read_only_annotations(), + ), + Builtin::BranchList => ( + "List all branch names (sorted).", + json!({ "type": "object", "properties": {}, "additionalProperties": false }), + read_only_annotations(), + ), + Builtin::CommitList => ( + "List commits, optionally scoped to a branch.", + json!({ + "type": "object", + "properties": { "branch": { "type": "string", "description": "Restrict to a branch's history." } }, + "additionalProperties": false + }), + read_only_annotations(), + ), + Builtin::CommitGet => ( + "Get a single commit by id.", + json!({ + "type": "object", + "properties": { "commit_id": { "type": "string" } }, + "required": ["commit_id"], + "additionalProperties": false + }), + read_only_annotations(), + ), + Builtin::GraphMutate => ( + "Run an ad-hoc GQ mutation (insert/update/delete) against a branch.", + json!({ + "type": "object", + "properties": { + "query": { "type": "string", "description": "GQ mutation source." }, + "name": { "type": "string", "description": "Select one query by name when the source defines several." }, + "params": { "type": "object", "description": "Query parameters." }, + "branch": { "type": "string", "description": "Branch to write (default main)." } + }, + "required": ["query"], + "additionalProperties": false + }), + write_annotations(true), + ), + Builtin::GraphLoad => ( + "Bulk-load NDJSON into a branch. Without `from`, the branch must exist (a \ + missing branch is an error, never an implicit fork).", + json!({ + "type": "object", + "properties": { + "data": { "type": "string", "description": "NDJSON, one record per line." }, + "branch": { "type": "string", "description": "Target branch (default main)." }, + "from": { "type": "string", "description": "Parent to fork `branch` from if it doesn't exist." }, + "mode": { "type": "string", "enum": ["merge", "append", "overwrite"], "description": "On existing rows (default merge)." } + }, + "required": ["data"], + "additionalProperties": false + }), + write_annotations(true), + ), + Builtin::BranchCreate => ( + "Create a branch by forking `from` (default main).", + json!({ + "type": "object", + "properties": { + "name": { "type": "string", "description": "New branch name." }, + "from": { "type": "string", "description": "Parent branch (default main)." } + }, + "required": ["name"], + "additionalProperties": false + }), + // Additive — not destructive. + write_annotations(false), + ), + Builtin::BranchDelete => ( + "Delete a branch.", + json!({ + "type": "object", + "properties": { "branch": { "type": "string" } }, + "required": ["branch"], + "additionalProperties": false + }), + write_annotations(true), + ), + Builtin::BranchMerge => ( + "Merge `source` into `target` (default main).", + json!({ + "type": "object", + "properties": { + "source": { "type": "string" }, + "target": { "type": "string", "description": "Default main." } + }, + "required": ["source"], + "additionalProperties": false + }), + write_annotations(true), + ), + Builtin::SchemaApply => ( + "Apply a schema migration (.pg source). Disabled (409) on cluster-backed \ + serving — use `omnigraph cluster apply` and restart.", + json!({ + "type": "object", + "properties": { + "schema_source": { "type": "string", "description": "Target schema as .pg source." }, + "allow_data_loss": { "type": "boolean", "description": "Permit data-dropping steps (default false)." } + }, + "required": ["schema_source"], + "additionalProperties": false + }), + write_annotations(true), + ), + }; + Tool::new(self.name(), description, schema(input)).with_annotations(annotations) + } + + async fn call( + self, + state: &AppState, + actor: Option<&ResolvedActor>, + handle: &Arc, + args: JsonObject, + ) -> Result { + match self { + Builtin::GraphHealth => json_result(&json!({ + "graph_id": handle.key.graph_id.as_str(), + "status": "ok", + "version": SERVER_VERSION, + })), + Builtin::GraphQuery => { + // run_query self-authorizes Read (per its real target branch). + let (selected, target, result) = run_query( + Arc::clone(handle), + actor, + &req_str(&args, "query")?, + opt_str(&args, "name")?.as_deref(), + opt_obj(&args, "params")?.as_ref(), + opt_str(&args, "branch")?, + opt_str(&args, "snapshot")?, + true, + ) + .await?; + json_result(&api::read_output(selected, &target, result)) + } + Builtin::GraphSnapshot => { + let branch = opt_str(&args, "branch")?.unwrap_or_else(|| "main".to_string()); + authorize_request(actor, handle.policy.as_deref(), read_request(Some(branch.clone())))?; + let snapshot = handle + .engine + .snapshot_of(ReadTarget::branch(branch.as_str())) + .await + .map_err(ApiError::from_omni)?; + json_result(&api::snapshot_payload(&branch, &snapshot)) + } + Builtin::SchemaGet => { + authorize_request(actor, handle.policy.as_deref(), read_request(None))?; + json_result(&api::SchemaOutput { + schema_source: handle.engine.schema_source().to_string(), + }) + } + Builtin::BranchList => { + authorize_request(actor, handle.policy.as_deref(), read_request(None))?; + let mut branches = handle.engine.branch_list().await.map_err(ApiError::from_omni)?; + branches.sort(); + json_result(&api::BranchListOutput { branches }) + } + Builtin::CommitList => { + let branch = opt_str(&args, "branch")?; + authorize_request(actor, handle.policy.as_deref(), read_request(branch.clone()))?; + let commits = handle + .engine + .list_commits(branch.as_deref()) + .await + .map_err(ApiError::from_omni)?; + json_result(&api::CommitListOutput { + commits: commits.iter().map(api::commit_output).collect(), + }) + } + Builtin::CommitGet => { + let commit_id = req_str(&args, "commit_id")?; + authorize_request(actor, handle.policy.as_deref(), read_request(None))?; + let commit = handle + .engine + .get_commit(&commit_id) + .await + .map_err(ApiError::from_omni)?; + json_result(&api::commit_output(&commit)) + } + Builtin::GraphMutate => { + // run_mutate self-authorizes Change + admission. + let branch = opt_str(&args, "branch")?.unwrap_or_else(|| "main".to_string()); + let output = run_mutate( + state.clone(), + Arc::clone(handle), + actor, + &req_str(&args, "query")?, + opt_str(&args, "name")?.as_deref(), + opt_obj(&args, "params")?.as_ref(), + branch, + ) + .await?; + json_result(&output) + } + Builtin::GraphLoad => { + // run_ingest self-authorizes (BranchCreate iff `from`, then + // Change) + admission, and 404s a missing branch with no `from`. + let request: api::IngestRequest = from_args(args, "load request")?; + let output = run_ingest(state.clone(), Arc::clone(handle), actor, request).await?; + json_result(&output) + } + Builtin::BranchCreate => { + let request: api::BranchCreateRequest = from_args(args, "branch-create request")?; + let from = request.from.unwrap_or_else(|| "main".to_string()); + authorize_request( + actor, + handle.policy.as_deref(), + PolicyRequest { + action: PolicyAction::BranchCreate, + branch: Some(from.clone()), + target_branch: Some(request.name.clone()), + }, + )?; + let _admission = state + .workload + .try_admit(&actor_arc(actor), 256) + .map_err(ApiError::from_workload_reject)?; + handle + .engine + .branch_create_from_as(ReadTarget::branch(&from), &request.name, actor_id(actor)) + .await + .map_err(ApiError::from_omni)?; + json_result(&api::BranchCreateOutput { + uri: handle.uri.clone(), + from, + name: request.name, + actor_id: actor_id(actor).map(str::to_string), + }) + } + Builtin::BranchDelete => { + let branch = req_str(&args, "branch")?; + authorize_request( + actor, + handle.policy.as_deref(), + PolicyRequest { + action: PolicyAction::BranchDelete, + branch: None, + target_branch: Some(branch.clone()), + }, + )?; + let _admission = state + .workload + .try_admit(&actor_arc(actor), 256) + .map_err(ApiError::from_workload_reject)?; + handle + .engine + .branch_delete_as(&branch, actor_id(actor)) + .await + .map_err(ApiError::from_omni)?; + json_result(&api::BranchDeleteOutput { + uri: handle.uri.clone(), + name: branch, + actor_id: actor_id(actor).map(str::to_string), + }) + } + Builtin::BranchMerge => { + let request: api::BranchMergeRequest = from_args(args, "branch-merge request")?; + let target = request.target.unwrap_or_else(|| "main".to_string()); + authorize_request( + actor, + handle.policy.as_deref(), + PolicyRequest { + action: PolicyAction::BranchMerge, + branch: Some(request.source.clone()), + target_branch: Some(target.clone()), + }, + )?; + let _admission = state + .workload + .try_admit(&actor_arc(actor), 256) + .map_err(ApiError::from_workload_reject)?; + let outcome = handle + .engine + .branch_merge_as(&request.source, &target, actor_id(actor)) + .await + .map_err(ApiError::from_omni)?; + json_result(&api::BranchMergeOutput { + source: request.source, + target, + outcome: outcome.into(), + actor_id: actor_id(actor).map(str::to_string), + }) + } + Builtin::SchemaApply => { + let request: api::SchemaApplyRequest = from_args(args, "schema-apply request")?; + authorize_request( + actor, + handle.policy.as_deref(), + PolicyRequest { + action: PolicyAction::SchemaApply, + branch: None, + target_branch: Some("main".to_string()), + }, + )?; + // Disable on cluster-backed serving AFTER the Cedar gate, so an + // unauthorized actor gets 403, not a topology-disclosing 409. + if state.routing().config_path.is_some() { + return Err(ApiError::conflict( + "server-side schema apply is disabled for cluster-backed serving; \ + update the cluster config, run `omnigraph cluster apply`, and restart \ + the server.", + )); + } + let _admission = state + .workload + .try_admit(&actor_arc(actor), request.schema_source.len() as u64) + .map_err(ApiError::from_workload_reject)?; + let registry = handle.queries.as_deref(); + let label = handle.key.graph_id.as_str().to_string(); + let result = handle + .engine + .apply_schema_as_with_catalog_check( + &request.schema_source, + SchemaApplyOptions { allow_data_loss: request.allow_data_loss }, + actor_id(actor), + |catalog| { + if let Some(registry) = registry { + validate_registry_against_catalog(registry, catalog, &label)?; + } + Ok(()) + }, + ) + .await + .map_err(ApiError::from_omni)?; + if result.applied { + let engine = Arc::clone(&handle.engine); + tokio::spawn(async move { + if let Err(err) = engine.ensure_indices().await { + tracing::warn!(error = %err, "post-apply ensure_indices failed"); + } + }); + } + json_result(&api::schema_apply_output(handle.uri.as_str(), result)) + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn builtin_tool_names_match_enum() { + let from_enum: Vec<&str> = Builtin::ALL.iter().map(|b| b.name()).collect(); + assert_eq!( + from_enum, BUILTIN_TOOL_NAMES, + "BUILTIN_TOOL_NAMES must stay in sync with Builtin::ALL (the collision-check reserves these)" + ); + } +} diff --git a/crates/omnigraph-server/src/queries.rs b/crates/omnigraph-server/src/queries.rs index 09d2491..085800d 100644 --- a/crates/omnigraph-server/src/queries.rs +++ b/crates/omnigraph-server/src/queries.rs @@ -95,6 +95,11 @@ impl std::fmt::Display for LoadError { } } +/// Sentinel "winner" used to seed the collision check with built-in tool +/// names. Not a valid query symbol (`<`/`>` are not identifier characters), so +/// it can never collide with a real query name. +const BUILTIN_OWNER: &str = ""; + impl QueryRegistry { /// Build a registry from in-memory specs: parse each source, select /// the declaration whose symbol equals the manifest key, and assert @@ -147,14 +152,24 @@ impl QueryRegistry { // before it is moved into `Self`. { let mut claimed: BTreeMap<&str, &str> = BTreeMap::new(); + // Built-in MCP tool names are reserved graph-wide. A stored query + // that shadows one would silently never be served (built-ins win at + // dispatch) — the deny-list forbids silent drops, so seed them here + // and fail loudly at load instead. + for builtin in crate::mcp::BUILTIN_TOOL_NAMES { + claimed.insert(builtin, BUILTIN_OWNER); + } for query in by_name.values().filter(|q| q.expose) { let tool = query.effective_tool_name(); if let Some(winner) = claimed.insert(tool, &query.name) { + let message = if winner == BUILTIN_OWNER { + format!("MCP tool name '{tool}' is reserved by a built-in tool") + } else { + format!("MCP tool name '{tool}' already claimed by exposed query '{winner}'") + }; errors.push(LoadError { query: Some(query.name.clone()), - message: format!( - "MCP tool name '{tool}' already claimed by exposed query '{winner}'" - ), + message, }); } } @@ -167,14 +182,35 @@ impl QueryRegistry { } } + /// Resolve by symbol name, **ignoring `expose`**. The raw catalog accessor + /// for HTTP/service callers (`expose:false` queries are deliberately + /// HTTP-callable; see [`StoredQuery::expose`]). The MCP backend must NOT use + /// this — it resolves through [`Self::exposed_by_name`] so the agent surface + /// can never reach a query hidden from the tool list. pub fn lookup(&self, name: &str) -> Option<&StoredQuery> { self.by_name.get(name) } + /// Iterate the full catalog, **ignoring `expose`** (HTTP/service surface). pub fn iter(&self) -> impl Iterator { self.by_name.values() } + /// The MCP-reachable catalog: exactly the exposed queries. The single + /// `expose` chokepoint for the agent surface — `tools/list`, the + /// `stored_query_list` tool, and per-query tool dispatch all funnel through + /// it, so they cannot drift on which queries an agent may see or run. + pub fn exposed(&self) -> impl Iterator { + self.by_name.values().filter(|q| q.expose) + } + + /// Resolve by symbol name, **exposed-only** — the MCP `stored_query_run` + /// resolver. An unexposed query is unreachable by name through this path + /// even to a caller that knows the name (the agent surface honors `expose`). + pub fn exposed_by_name(&self, name: &str) -> Option<&StoredQuery> { + self.by_name.get(name).filter(|q| q.expose) + } + pub fn is_empty(&self) -> bool { self.by_name.is_empty() } diff --git a/crates/omnigraph-server/tests/mcp.rs b/crates/omnigraph-server/tests/mcp.rs new file mode 100644 index 0000000..960f813 --- /dev/null +++ b/crates/omnigraph-server/tests/mcp.rs @@ -0,0 +1,618 @@ +//! Black-box tests for the MCP surface (`POST /graphs/{id}/mcp`), driven over +//! `build_app` with in-process tower `oneshot`. Phase 2 covers the read tools, +//! resources, protocol conformance, Cedar-filtered listing, and the server-side +//! Origin fail-closed wiring. (Crate-level transport conformance — 405, the +//! rmcp surface guard — lives in `omnigraph-mcp/tests/standalone.rs`.) + +mod support; + +use axum::Router; +use axum::body::Body; +use axum::http::{Method, Request, StatusCode}; +use omnigraph_server::queries::{QueryRegistry, RegistrySpec}; +use omnigraph_server::{AppState, build_app}; +use serde_json::{Value, json}; +use support::{ + FIND_PERSON_GQ, INVOKE_POLICY_YAML, app_for_loaded_graph_with_auth_tokens, + app_for_loaded_graph_with_auth_tokens_and_policy, app_with_stored_queries, g, graph_path, + init_loaded_graph, json_response, +}; + +/// Build a JSON-RPC POST to `/graphs/default/mcp`. Sets the `Accept` (both +/// JSON + SSE, as rmcp requires) and `Host` (loopback policy allows it) headers, +/// and an optional bearer token. +fn mcp_request(token: Option<&str>, body: Value) -> Request { + let mut builder = Request::builder() + .uri(g("/mcp")) + .method(Method::POST) + .header("host", "localhost") + .header("content-type", "application/json") + .header("accept", "application/json, text/event-stream"); + if let Some(token) = token { + builder = builder.header("authorization", format!("Bearer {token}")); + } + builder + .body(Body::from(serde_json::to_vec(&body).unwrap())) + .unwrap() +} + +fn rpc(id: i64, method: &str, params: Value) -> Value { + json!({ "jsonrpc": "2.0", "id": id, "method": method, "params": params }) +} + +#[tokio::test] +async fn initialize_advertises_tools_and_resources() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc( + 1, + "initialize", + json!({ + "protocolVersion": "2025-11-25", + "capabilities": {}, + "clientInfo": { "name": "test", "version": "0" } + }), + ), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_eq!(v["result"]["serverInfo"]["name"], "omnigraph"); + assert!(v["result"]["capabilities"]["tools"].is_object()); + assert!(v["result"]["capabilities"]["resources"].is_object()); +} + +fn tool_names(list_result: &Value) -> Vec { + list_result["result"]["tools"] + .as_array() + .unwrap() + .iter() + .map(|t| t["name"].as_str().unwrap().to_string()) + .collect() +} + +#[tokio::test] +async fn tools_list_returns_builtins_with_no_cursor() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (status, v) = + json_response(&app, mcp_request(Some("tok"), rpc(2, "tools/list", json!({})))).await; + assert_eq!(status, StatusCode::OK); + let names = tool_names(&v); + for expected in [ + "graph_health", + "graph_query", + "graph_snapshot", + "schema_get", + "branch_list", + "commit_list", + "commit_get", + ] { + assert!(names.contains(&expected.to_string()), "missing tool {expected} in {names:?}"); + } + // Non-paginated by contract. + assert!(v["result"]["nextCursor"].is_null()); +} + +#[tokio::test] +async fn graph_health_returns_ok() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (status, v) = json_response( + &app, + mcp_request(Some("tok"), rpc(3, "tools/call", json!({ "name": "graph_health", "arguments": {} }))), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_ne!(v["result"]["isError"], json!(true)); + let text = v["result"]["content"][0]["text"].as_str().unwrap(); + assert!(text.contains("\"status\":\"ok\""), "health payload: {text}"); +} + +#[tokio::test] +async fn graph_query_runs_a_read() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc( + 4, + "tools/call", + json!({ + "name": "graph_query", + "arguments": { "query": "query all() { match { $p: Person } return { $p.name } }" } + }), + ), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_ne!(v["result"]["isError"], json!(true), "unexpected isError: {v}"); + // ReadOutput carries a row_count; the text mirror is the serialized DTO. + let text = v["result"]["content"][0]["text"].as_str().unwrap(); + assert!(text.contains("row_count"), "read output: {text}"); +} + +#[tokio::test] +async fn malformed_query_is_iserror_not_protocol_error() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc( + 5, + "tools/call", + json!({ "name": "graph_query", "arguments": { "query": "this is not gq" } }), + ), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + // A bad query is a semantic (4xx) failure → isError tool result, not a + // JSON-RPC protocol error (SEP-1303). + assert_eq!(v["result"]["isError"], json!(true), "expected isError, got {v}"); + assert!(v["error"].is_null()); +} + +#[tokio::test] +async fn unknown_tool_is_invalid_params() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (status, v) = json_response( + &app, + mcp_request(Some("tok"), rpc(6, "tools/call", json!({ "name": "no_such_tool", "arguments": {} }))), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_eq!(v["error"]["code"], json!(-32602)); +} + +const READER_ONLY_POLICY: &str = r#" +version: 1 +groups: + readers: [act-reader] +protected_branches: [main] +rules: + - id: readers-read + allow: + actors: { group: readers } + actions: [read] + branch_scope: any +"#; + +#[tokio::test] +async fn cedar_filters_listing_and_gates_calls() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens_and_policy( + &[("act-reader", "tok-r"), ("act-none", "tok-n")], + READER_ONLY_POLICY, + ) + .await; + + // The reader sees the Read-gated tools. + let (_s, reader) = + json_response(&app, mcp_request(Some("tok-r"), rpc(1, "tools/list", json!({})))).await; + let reader_names = tool_names(&reader); + assert!(reader_names.contains(&"graph_query".to_string())); + assert!(reader_names.contains(&"schema_get".to_string())); + + // act-none has no rules → Read denied → only the ungated graph_health shows. + let (_s, none) = + json_response(&app, mcp_request(Some("tok-n"), rpc(2, "tools/list", json!({})))).await; + let none_names = tool_names(&none); + assert_eq!(none_names, vec!["graph_health".to_string()], "denied actor saw {none_names:?}"); + + // And a denied call surfaces isError (the read gate inside the delegate). + let (status, v) = json_response( + &app, + mcp_request(Some("tok-n"), rpc(3, "tools/call", json!({ "name": "schema_get", "arguments": {} }))), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_eq!(v["result"]["isError"], json!(true), "expected denied schema_get to isError: {v}"); +} + +#[tokio::test] +async fn resource_read_returns_schema() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc(1, "resources/read", json!({ "uri": "omnigraph://schema" })), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + let text = v["result"]["contents"][0]["text"].as_str().unwrap(); + assert!(text.contains("node Person"), "schema resource: {text}"); +} + +/// Server-side wiring of the fail-closed Origin policy: a non-loopback bind +/// yields `DenyBrowsers`, so a present `Origin` is `403` while an absent one +/// passes. (The policy logic itself is unit-tested in omnigraph-mcp.) +async fn app_with_public_bind() -> (tempfile::TempDir, Router) { + let temp = init_loaded_graph().await; + let graph = graph_path(temp.path()); + let state = AppState::open(graph.to_string_lossy().to_string()) + .await + .unwrap() + .with_mcp_host_inputs("203.0.113.1:8080".parse().unwrap(), Vec::new(), Vec::new()); + (temp, build_app(state)) +} + +#[tokio::test] +async fn public_bind_rejects_present_origin() { + let (_t, app) = app_with_public_bind().await; + let init = rpc( + 1, + "initialize", + json!({ "protocolVersion": "2025-11-25", "capabilities": {}, + "clientInfo": { "name": "t", "version": "0" } }), + ); + + // Present, forged Origin → 403 (origin_guard). + let mut with_origin = mcp_request(None, init.clone()); + with_origin + .headers_mut() + .insert("origin", "https://evil.example".parse().unwrap()); + // A non-loopback bind also disables Host-allowlisting (allowed_hosts None), + // so the Host header is irrelevant here. + let resp = { + use tower::ServiceExt; + app.clone().oneshot(with_origin).await.unwrap() + }; + assert_eq!(resp.status(), StatusCode::FORBIDDEN); + + // Absent Origin → request proceeds (200). + let (status, _v) = json_response(&app, mcp_request(None, init)).await; + assert_eq!(status, StatusCode::OK); +} + +// ===== Phase 3: write tools, stored queries, structured output ===== + +#[tokio::test] +async fn graph_query_emits_structured_content() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (_s, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc( + 1, + "tools/call", + json!({ + "name": "graph_query", + "arguments": { "query": "query all() { match { $p: Person } return { $p.name } }" } + }), + ), + ), + ) + .await; + // Structured output: structuredContent present (never null) + text mirror. + assert!(v["result"]["structuredContent"].is_object(), "no structuredContent: {v}"); + assert!(v["result"]["structuredContent"]["row_count"].is_number()); + assert!(v["result"]["content"][0]["text"].is_string()); +} + +#[tokio::test] +async fn graph_mutate_writes_end_to_end() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc( + 1, + "tools/call", + json!({ + "name": "graph_mutate", + "arguments": { + "query": "query ins($name: String, $age: I32) { insert Person { name: $name, age: $age } }", + "params": { "name": "McpWrite", "age": 41 }, + "branch": "main" + } + }), + ), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_ne!(v["result"]["isError"], json!(true), "mutate failed: {v}"); + assert!( + v["result"]["structuredContent"]["affected_nodes"].as_u64().unwrap_or(0) >= 1, + "expected an inserted node: {v}" + ); +} + +#[tokio::test] +async fn graph_load_missing_branch_then_fork() { + let (_t, app) = app_for_loaded_graph_with_auth_tokens(&[("act", "tok")]).await; + let line = r#"{"type":"Person","data":{"name":"McpLoaded","age":7}}"#; + + // Missing branch + no `from` → 404 → isError (never an implicit fork). + let (_s, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc(1, "tools/call", json!({ "name": "graph_load", "arguments": { "data": line, "branch": "nope" } })), + ), + ) + .await; + assert_eq!(v["result"]["isError"], json!(true), "expected 404 isError: {v}"); + + // With `from` → forks the branch and loads. + let (_s, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc( + 2, + "tools/call", + json!({ "name": "graph_load", "arguments": { "data": line, "branch": "feature", "from": "main" } }), + ), + ), + ) + .await; + assert_ne!(v["result"]["isError"], json!(true), "fork-and-load failed: {v}"); +} + +#[tokio::test] +async fn stored_query_projects_as_a_tool_and_runs() { + // 1 exposed query → per_query mode → it appears as its own tool. + let (_t, app) = app_with_stored_queries( + &[("find_person", FIND_PERSON_GQ, true)], + &[("act-invoke", "tok")], + INVOKE_POLICY_YAML, + ) + .await; + + let (_s, list) = + json_response(&app, mcp_request(Some("tok"), rpc(1, "tools/list", json!({})))).await; + assert!( + tool_names(&list).contains(&"find_person".to_string()), + "stored query not projected: {:?}", + tool_names(&list) + ); + + // And it runs (params nested under `params`). + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc( + 2, + "tools/call", + json!({ "name": "find_person", "arguments": { "params": { "name": "Nobody" } } }), + ), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_ne!(v["result"]["isError"], json!(true), "stored query failed: {v}"); + assert!(v["result"]["structuredContent"]["row_count"].is_number()); +} + +#[tokio::test] +async fn stored_query_invoke_denied_masks_as_unknown_tool() { + // act-noinvoke has `read` but not `invoke_query` → the outer gate denies and + // the stored tool masks byte-identically to an unknown tool. + let (_t, app) = app_with_stored_queries( + &[("find_person", FIND_PERSON_GQ, true)], + &[("act-noinvoke", "tok")], + INVOKE_POLICY_YAML, + ) + .await; + let (_s, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc(1, "tools/call", json!({ "name": "find_person", "arguments": { "params": {} } })), + ), + ) + .await; + assert_eq!(v["error"]["code"], json!(-32602)); + assert_eq!( + v["error"]["message"].as_str().unwrap(), + "unknown tool: find_person", + "denied stored query must mask as unknown" + ); +} + +#[tokio::test] +async fn large_catalog_uses_meta_projection() { + // At/above the auto threshold (24 exposed queries) the projection collapses + // to the discovery + execute meta pair instead of N typed tools. + let sources: Vec<(String, String)> = (0..25) + .map(|i| { + let name = format!("q{i}"); + let src = format!("query {name}() {{ match {{ $p: Person }} return {{ $p.name }} }}"); + (name, src) + }) + .collect(); + let specs: Vec<(&str, &str, bool)> = sources + .iter() + .map(|(n, s)| (n.as_str(), s.as_str(), true)) + .collect(); + let (_t, app) = + app_with_stored_queries(&specs, &[("act-invoke", "tok")], INVOKE_POLICY_YAML).await; + + let (_s, list) = + json_response(&app, mcp_request(Some("tok"), rpc(1, "tools/list", json!({})))).await; + let names = tool_names(&list); + assert!(names.contains(&"stored_query_list".to_string()), "{names:?}"); + assert!(names.contains(&"stored_query_run".to_string()), "{names:?}"); + assert!(!names.contains(&"q5".to_string()), "meta mode must not list per-query tools: {names:?}"); + + // stored_query_run executes one by name. + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc(2, "tools/call", json!({ "name": "stored_query_run", "arguments": { "name": "q5" } })), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_ne!(v["result"]["isError"], json!(true), "stored_query_run failed: {v}"); +} + +const PROTECTED_MAIN_WRITE_BRANCHES_POLICY: &str = r#" +version: 1 +groups: + writers: [act-writer] + readers: [act-reader] +protected_branches: [main] +rules: + - id: writers-read + allow: + actors: { group: writers } + actions: [read] + branch_scope: any + - id: writers-change-unprotected + allow: + actors: { group: writers } + actions: [change] + branch_scope: unprotected + - id: readers-read + allow: + actors: { group: readers } + actions: [read] + branch_scope: any +"#; + +#[tokio::test] +async fn write_tool_listed_when_only_unprotected_writes_allowed() { + // The canonical workflow policy: protected `main`, writable feature branches. + // `graph_mutate`/`graph_load` must be advertised to an actor who can change + // unprotected branches — the per-call gate is authoritative and would allow + // graph_mutate(branch="feature"). Listing probes the action capability on + // *any* branch, not a fabricated `main` (which is protected → denied). A + // read-only actor must still NOT see the write tools. + let (_t, app) = app_for_loaded_graph_with_auth_tokens_and_policy( + &[("act-writer", "tok-w"), ("act-reader", "tok-r")], + PROTECTED_MAIN_WRITE_BRANCHES_POLICY, + ) + .await; + + let (_s, w) = + json_response(&app, mcp_request(Some("tok-w"), rpc(1, "tools/list", json!({})))).await; + let w_names = tool_names(&w); + assert!( + w_names.contains(&"graph_mutate".to_string()), + "graph_mutate hidden from an unprotected-branch writer (under-show): {w_names:?}" + ); + assert!(w_names.contains(&"graph_load".to_string()), "graph_load hidden: {w_names:?}"); + + let (_s, r) = + json_response(&app, mcp_request(Some("tok-r"), rpc(2, "tools/list", json!({})))).await; + let r_names = tool_names(&r); + assert!( + !r_names.contains(&"graph_mutate".to_string()), + "graph_mutate shown to a read-only actor (over-show regression): {r_names:?}" + ); + assert!( + r_names.contains(&"graph_query".to_string()), + "reader should still see read tools: {r_names:?}" + ); +} + +#[tokio::test] +async fn per_query_mode_does_not_expose_meta_tools() { + // Below the auto threshold the projection is per-query, so the discovery + + // execute meta pair was never advertised. It must not be callable either — + // `call_tool` resolves a stored tool through the same projection `tools/list` + // renders, so list and call cannot diverge. + let (_t, app) = app_with_stored_queries( + &[("find_person", FIND_PERSON_GQ, true)], + &[("act-invoke", "tok")], + INVOKE_POLICY_YAML, + ) + .await; + + for tool in ["stored_query_run", "stored_query_list"] { + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc(1, "tools/call", json!({ "name": tool, "arguments": { "name": "find_person" } })), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_eq!(v["error"]["code"], json!(-32602), "{tool} must be unknown in per_query mode: {v}"); + assert_eq!( + v["error"]["message"].as_str().unwrap(), + format!("unknown tool: {tool}"), + "{tool} must mask as unknown when the projection didn't advertise it" + ); + } +} + +#[tokio::test] +async fn stored_query_run_cannot_reach_unexposed_query() { + // Meta projection (24 exposed) plus one unexposed `hidden`. `stored_query_run` + // must not resolve the unexposed query even to a caller that knows its name — + // the agent surface honors `expose`, like every other stored-query path. + // (`expose:false` stays HTTP/service-callable; this is the MCP boundary only.) + let exposed: Vec<(String, String)> = (0..24) + .map(|i| { + let name = format!("q{i}"); + let src = format!("query {name}() {{ match {{ $p: Person }} return {{ $p.name }} }}"); + (name, src) + }) + .collect(); + let hidden_src = "query hidden() { match { $p: Person } return { $p.name } }"; + let mut specs: Vec<(&str, &str, bool)> = + exposed.iter().map(|(n, s)| (n.as_str(), s.as_str(), true)).collect(); + specs.push(("hidden", hidden_src, false)); + + let (_t, app) = + app_with_stored_queries(&specs, &[("act-invoke", "tok")], INVOKE_POLICY_YAML).await; + + // Confirm the meta projection is in force (so stored_query_run exists), and + // that the unexposed query is not discoverable via stored_query_list. + let (_s, list) = + json_response(&app, mcp_request(Some("tok"), rpc(1, "tools/list", json!({})))).await; + assert!(tool_names(&list).contains(&"stored_query_run".to_string()), "{:?}", tool_names(&list)); + let (_s, listed) = json_response( + &app, + mcp_request(Some("tok"), rpc(2, "tools/call", json!({ "name": "stored_query_list", "arguments": {} }))), + ) + .await; + let catalog = listed["result"]["structuredContent"]["queries"].as_array().unwrap(); + assert!( + catalog.iter().all(|q| q["name"] != json!("hidden")), + "unexposed query leaked into stored_query_list: {listed}" + ); + + // Running the unexposed query by name → not found (isError), never executed. + let (status, v) = json_response( + &app, + mcp_request( + Some("tok"), + rpc(3, "tools/call", json!({ "name": "stored_query_run", "arguments": { "name": "hidden" } })), + ), + ) + .await; + assert_eq!(status, StatusCode::OK); + assert_eq!(v["result"]["isError"], json!(true), "unexposed query must not run via stored_query_run: {v}"); +} + +#[test] +fn stored_query_shadowing_a_builtin_is_a_load_error() { + // A stored query whose tool name collides with a built-in must fail loudly + // at registry load, never be silently un-served. + let result = QueryRegistry::from_specs(vec![RegistrySpec { + name: "graph_query".to_string(), + source: "query graph_query() { match { $p: Person } return { $p.name } }".to_string(), + expose: true, + tool_name: None, + }]); + let errors = result.expect_err("expected a collision error"); + assert!( + errors.iter().any(|e| e.message.contains("reserved by a built-in")), + "expected built-in reservation error, got {errors:?}" + ); +} diff --git a/docs/dev/rfc-003-mcp-server-surface.md b/docs/dev/rfc-003-mcp-server-surface.md index a9fef9b..201e8e7 100644 --- a/docs/dev/rfc-003-mcp-server-surface.md +++ b/docs/dev/rfc-003-mcp-server-surface.md @@ -234,7 +234,12 @@ impl McpHostPolicy { pub fn from_bind(bind: &SocketAddr, public_hosts: &[String], browser_origins: &[String]) -> Self { let loopback = bind.ip().is_loopback(); Self { - allowed_hosts: if loopback { Some(vec!["127.0.0.1".into(), "localhost".into()]) } + // Loopback bind ⇒ the full loopback Host set (both stacks + the + // hostname alias), matching rmcp's default `["localhost","127.0.0.1","::1"]`. + // The Host header is independent of the bound socket (in-process, + // proxies, dual-stack localhost), so a 127-bound server must still + // accept a `[::1]` Host — deriving the list from `bind.ip()` alone 403'd it. + allowed_hosts: if loopback { Some(vec!["127.0.0.1".into(), "::1".into(), "localhost".into()]) } else if public_hosts.is_empty() { None } else { Some(public_hosts.to_vec()) }, origin: if !browser_origins.is_empty() { OriginPolicy::Allow(browser_origins.to_vec()) } else if loopback { OriginPolicy::Unchecked } // local dev convenience only @@ -597,11 +602,22 @@ Represent built-ins as a `Builtin` enum (one variant per tool; `descriptor` / `g `call` as match arms) — lower liability than ~13 unit structs + `dyn`. Stored-query tools are a sibling populator over `handle.queries`. -**`list_tools` / `list_resources` are Cedar-filtered** by running the *same* -authorization the call path runs, with **default args (branch `main`)** — not a -`branch: None` probe (which matches no `branch_scope` rule and would hide tools the -actor can call on a scoped branch). Over-showing a branch-scoped grant is the safe -direction; `call_tool` is the authoritative gate. +**`list_tools` / `list_resources` are Cedar-filtered as a *relaxation* of the +call-path gate** — listing never hides a tool the caller could invoke on some +branch (over-showing is the safe direction; `call_tool` is authoritative). A +built-in whose authorization depends on a caller-chosen branch (`graph_mutate`, +`graph_load`, `branch_*`) is shown iff `authorize_any_branch` → +`PolicyEngine::permits_on_any_branch(actor, action)` is true: that probes the +branch-shape space (omitted / protected / unprotected) through the same Cedar +authorizer and returns true if *any* shape is allowed. A fixed-branch probe is +wrong here — both a fabricated `main` (denied under "protect `main`, write +unprotected branches", the canonical workflow) and a `branch: None` probe +(matches no `branch_scope` rule) under-show `graph_mutate` to an actor who can +write feature branches. The stored-query surface gets the same list/call +agreement structurally: `resolve_stored_tool` is the single membership test, so +the meta pair is callable only in `meta` mode and `stored_query_run` resolves +**exposed-only** (an `expose:false` query is unreachable by name on the agent +surface, though it stays HTTP/service-callable). ## 11. Dispatch reuse + error classification