mirror of
https://github.com/samvallad33/vestige.git
synced 2026-05-08 15:22:37 +02:00
feat: add MCP Streamable HTTP transport with Bearer auth
Adds a second transport layer alongside stdio — Streamable HTTP on port 3928. Enables Claude.ai, remote clients, and web integrations to connect to Vestige over HTTP with per-session McpServer instances.

- POST /mcp (JSON-RPC) + DELETE /mcp (session cleanup)
- Bearer token auth with constant-time comparison (subtle crate)
- Auto-generated UUID v4 token persisted with 0o600 permissions
- Per-session McpServer instances with 30-min idle reaper
- 100 max sessions, 50 concurrency limit, 256KB body limit
- --http-port flag + VESTIGE_HTTP_PORT / VESTIGE_HTTP_BIND env vars
- Module exports moved from binary to lib.rs for reusability
- vestige CLI gains `serve` subcommand via shared lib

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
070889ef26
commit
816b577f69
11 changed files with 849 additions and 170 deletions
|
|
@ -4,6 +4,7 @@
|
|||
|
||||
use std::io::{BufWriter, Write};
|
||||
use std::path::PathBuf;
|
||||
use std::sync::Arc;
|
||||
|
||||
use chrono::{NaiveDate, Utc};
|
||||
use clap::{Parser, Subcommand};
|
||||
|
|
@ -109,6 +110,19 @@ enum Commands {
|
|||
#[arg(long)]
|
||||
source: Option<String>,
|
||||
},
|
||||
|
||||
/// Start standalone HTTP MCP server (no stdio, for remote access)
|
||||
Serve {
|
||||
/// HTTP transport port
|
||||
#[arg(long, default_value = "3928")]
|
||||
port: u16,
|
||||
/// Also start the dashboard
|
||||
#[arg(long)]
|
||||
dashboard: bool,
|
||||
/// Dashboard port
|
||||
#[arg(long, default_value = "3927")]
|
||||
dashboard_port: u16,
|
||||
},
|
||||
}
|
||||
|
||||
fn main() -> anyhow::Result<()> {
|
||||
|
|
@ -139,6 +153,11 @@ fn main() -> anyhow::Result<()> {
|
|||
node_type,
|
||||
source,
|
||||
} => run_ingest(content, tags, node_type, source),
|
||||
Commands::Serve {
|
||||
port,
|
||||
dashboard,
|
||||
dashboard_port,
|
||||
} => run_serve(port, dashboard, dashboard_port),
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -967,6 +986,82 @@ fn run_dashboard(port: u16, open_browser: bool) -> anyhow::Result<()> {
|
|||
})
|
||||
}
|
||||
|
||||
/// Start standalone HTTP MCP server (no stdio transport)
|
||||
fn run_serve(port: u16, with_dashboard: bool, dashboard_port: u16) -> anyhow::Result<()> {
|
||||
use vestige_mcp::cognitive::CognitiveEngine;
|
||||
|
||||
println!("{}", "=== Vestige HTTP Server ===".cyan().bold());
|
||||
println!();
|
||||
|
||||
let storage = Storage::new(None)?;
|
||||
|
||||
#[cfg(feature = "embeddings")]
|
||||
{
|
||||
if let Err(e) = storage.init_embeddings() {
|
||||
println!(
|
||||
" {} Embeddings unavailable: {} (search will use keyword-only)",
|
||||
"!".yellow(),
|
||||
e
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
let storage = Arc::new(storage);
|
||||
|
||||
let rt = tokio::runtime::Runtime::new()?;
|
||||
rt.block_on(async move {
|
||||
let cognitive = Arc::new(tokio::sync::Mutex::new(CognitiveEngine::new()));
|
||||
{
|
||||
let mut cog = cognitive.lock().await;
|
||||
cog.hydrate(&storage);
|
||||
}
|
||||
|
||||
let (event_tx, _) =
|
||||
tokio::sync::broadcast::channel::<vestige_mcp::dashboard::events::VestigeEvent>(1024);
|
||||
|
||||
// Optionally start dashboard
|
||||
if with_dashboard {
|
||||
let ds = Arc::clone(&storage);
|
||||
let dc = Arc::clone(&cognitive);
|
||||
let dtx = event_tx.clone();
|
||||
tokio::spawn(async move {
|
||||
match vestige_mcp::dashboard::start_background_with_event_tx(ds, Some(dc), dtx, dashboard_port).await {
|
||||
Ok(_) => println!(" {} Dashboard: http://127.0.0.1:{}", ">".cyan(), dashboard_port),
|
||||
Err(e) => eprintln!(" {} Dashboard failed: {}", "!".yellow(), e),
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
// Get auth token
|
||||
let token = vestige_mcp::protocol::auth::get_or_create_auth_token()
|
||||
.map_err(|e| anyhow::anyhow!("Failed to create auth token: {}", e))?;
|
||||
|
||||
let bind = std::env::var("VESTIGE_HTTP_BIND").unwrap_or_else(|_| "127.0.0.1".to_string());
|
||||
println!(" {} HTTP transport: http://{}:{}/mcp", ">".cyan(), bind, port);
|
||||
println!(" {} Auth token: {}...", ">".cyan(), &token[..8]);
|
||||
println!();
|
||||
println!("{}", "Press Ctrl+C to stop.".dimmed());
|
||||
|
||||
// Start HTTP transport (blocks on the server, no stdio)
|
||||
vestige_mcp::protocol::http::start_http_transport(
|
||||
Arc::clone(&storage),
|
||||
Arc::clone(&cognitive),
|
||||
event_tx,
|
||||
token,
|
||||
port,
|
||||
)
|
||||
.await
|
||||
.map_err(|e| anyhow::anyhow!("HTTP transport failed: {}", e))?;
|
||||
|
||||
// Keep the process alive (the HTTP server runs in a spawned task)
|
||||
tokio::signal::ctrl_c().await.ok();
|
||||
println!();
|
||||
println!("{}", "Shutting down...".dimmed());
|
||||
|
||||
Ok(())
|
||||
})
|
||||
}
|
||||
|
||||
/// Truncate a string for display (UTF-8 safe)
|
||||
fn truncate(s: &str, max_chars: usize) -> String {
|
||||
let s = s.replace('\n', " ");
|
||||
|
|
|
|||
|
|
@ -4,3 +4,7 @@
|
|||
|
||||
pub mod cognitive;
|
||||
pub mod dashboard;
|
||||
pub mod protocol;
|
||||
pub mod resources;
|
||||
pub mod server;
|
||||
pub mod tools;
|
||||
|
|
|
|||
|
|
@ -27,12 +27,9 @@
|
|||
//! - Reconsolidation (memories editable on retrieval)
|
||||
//! - Memory Chains (reasoning paths)
|
||||
|
||||
// cognitive is exported from lib.rs for dashboard access
|
||||
use vestige_mcp::cognitive;
|
||||
mod protocol;
|
||||
mod resources;
|
||||
mod server;
|
||||
mod tools;
|
||||
use vestige_mcp::protocol;
|
||||
use vestige_mcp::server;
|
||||
|
||||
use std::io;
|
||||
use std::path::PathBuf;
|
||||
|
|
@ -44,15 +41,24 @@ use tracing_subscriber::EnvFilter;
|
|||
// Use vestige-core for the cognitive science engine
|
||||
use vestige_core::Storage;
|
||||
|
||||
use crate::protocol::stdio::StdioTransport;
|
||||
use crate::server::McpServer;
|
||||
use protocol::stdio::StdioTransport;
|
||||
use server::McpServer;
|
||||
|
||||
/// Parse command-line arguments and return the optional data directory path.
|
||||
/// Returns `None` for the path if no `--data-dir` was specified.
|
||||
/// Parsed CLI configuration.
|
||||
struct Config {
|
||||
data_dir: Option<PathBuf>,
|
||||
http_port: u16,
|
||||
}
|
||||
|
||||
/// Parse command-line arguments into a `Config`.
|
||||
/// Exits the process if `--help` or `--version` is requested.
|
||||
fn parse_args() -> Option<PathBuf> {
|
||||
fn parse_args() -> Config {
|
||||
let args: Vec<String> = std::env::args().collect();
|
||||
let mut data_dir: Option<PathBuf> = None;
|
||||
let mut http_port: u16 = std::env::var("VESTIGE_HTTP_PORT")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or(3928);
|
||||
let mut i = 1;
|
||||
|
||||
while i < args.len() {
|
||||
|
|
@ -69,13 +75,18 @@ fn parse_args() -> Option<PathBuf> {
|
|||
println!(" -h, --help Print help information");
|
||||
println!(" -V, --version Print version information");
|
||||
println!(" --data-dir <PATH> Custom data directory");
|
||||
println!(" --http-port <PORT> HTTP transport port (default: 3928)");
|
||||
println!();
|
||||
println!("ENVIRONMENT:");
|
||||
println!(" RUST_LOG Log level filter (e.g., debug, info, warn, error)");
|
||||
println!(" RUST_LOG Log level filter (e.g., debug, info, warn, error)");
|
||||
println!(" VESTIGE_AUTH_TOKEN Override the bearer token for HTTP transport");
|
||||
println!(" VESTIGE_HTTP_PORT HTTP transport port (default: 3928)");
|
||||
println!(" VESTIGE_DASHBOARD_PORT Dashboard port (default: 3927)");
|
||||
println!();
|
||||
println!("EXAMPLES:");
|
||||
println!(" vestige-mcp");
|
||||
println!(" vestige-mcp --data-dir /custom/path");
|
||||
println!(" vestige-mcp --http-port 8080");
|
||||
println!(" RUST_LOG=debug vestige-mcp");
|
||||
std::process::exit(0);
|
||||
}
|
||||
|
|
@ -102,6 +113,31 @@ fn parse_args() -> Option<PathBuf> {
|
|||
}
|
||||
data_dir = Some(PathBuf::from(path));
|
||||
}
|
||||
"--http-port" => {
|
||||
i += 1;
|
||||
if i >= args.len() {
|
||||
eprintln!("error: --http-port requires a port number");
|
||||
eprintln!("Usage: vestige-mcp --http-port <PORT>");
|
||||
std::process::exit(1);
|
||||
}
|
||||
http_port = match args[i].parse() {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
eprintln!("error: invalid port number '{}'", args[i]);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
}
|
||||
arg if arg.starts_with("--http-port=") => {
|
||||
let val = arg.strip_prefix("--http-port=").unwrap_or("");
|
||||
http_port = match val.parse() {
|
||||
Ok(p) => p,
|
||||
Err(_) => {
|
||||
eprintln!("error: invalid port number '{}'", val);
|
||||
std::process::exit(1);
|
||||
}
|
||||
};
|
||||
}
|
||||
arg => {
|
||||
eprintln!("error: unknown argument '{}'", arg);
|
||||
eprintln!("Usage: vestige-mcp [OPTIONS]");
|
||||
|
|
@ -112,13 +148,13 @@ fn parse_args() -> Option<PathBuf> {
|
|||
i += 1;
|
||||
}
|
||||
|
||||
data_dir
|
||||
Config { data_dir, http_port }
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() {
|
||||
// Parse CLI arguments first (before logging init, so --help/--version work cleanly)
|
||||
let data_dir = parse_args();
|
||||
let config = parse_args();
|
||||
|
||||
// Initialize logging to stderr (stdout is for JSON-RPC)
|
||||
tracing_subscriber::fmt()
|
||||
|
|
@ -134,7 +170,7 @@ async fn main() {
|
|||
info!("Vestige MCP Server v{} starting...", env!("CARGO_PKG_VERSION"));
|
||||
|
||||
// Initialize storage with optional custom data directory
|
||||
let storage = match Storage::new(data_dir) {
|
||||
let storage = match Storage::new(config.data_dir) {
|
||||
Ok(s) => {
|
||||
info!("Storage initialized successfully");
|
||||
|
||||
|
|
@ -260,6 +296,38 @@ async fn main() {
|
|||
});
|
||||
}
|
||||
|
||||
// Start HTTP MCP transport (Streamable HTTP for Claude.ai / remote clients)
|
||||
{
|
||||
let http_storage = Arc::clone(&storage);
|
||||
let http_cognitive = Arc::clone(&cognitive);
|
||||
let http_event_tx = event_tx.clone();
|
||||
let http_port = config.http_port;
|
||||
|
||||
match protocol::auth::get_or_create_auth_token() {
|
||||
Ok(token) => {
|
||||
let bind = std::env::var("VESTIGE_HTTP_BIND").unwrap_or_else(|_| "127.0.0.1".to_string());
|
||||
eprintln!("Vestige HTTP transport: http://{}:{}/mcp", bind, http_port);
|
||||
eprintln!("Auth token: {}...", &token[..8]);
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = protocol::http::start_http_transport(
|
||||
http_storage,
|
||||
http_cognitive,
|
||||
http_event_tx,
|
||||
token,
|
||||
http_port,
|
||||
)
|
||||
.await
|
||||
{
|
||||
warn!("HTTP transport failed to start: {}", e);
|
||||
}
|
||||
});
|
||||
}
|
||||
Err(e) => {
|
||||
warn!("Could not create auth token, HTTP transport disabled: {}", e);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Load cross-encoder reranker in the background (downloads ~150MB on first run)
|
||||
#[cfg(feature = "embeddings")]
|
||||
{
|
||||
|
|
|
|||
104
crates/vestige-mcp/src/protocol/auth.rs
Normal file
104
crates/vestige-mcp/src/protocol/auth.rs
Normal file
|
|
@ -0,0 +1,104 @@
|
|||
//! Bearer token authentication for the HTTP transport.
|
||||
//!
|
||||
//! Token priority:
|
||||
//! 1. `VESTIGE_AUTH_TOKEN` env var (override)
|
||||
//! 2. Read from `<data_dir>/auth_token` file
|
||||
//! 3. Generate `uuid::Uuid::new_v4()`, write to file with 0o600 permissions
|
||||
//!
|
||||
//! Security: The token file is created with restricted permissions from the
|
||||
//! start (via OpenOptionsExt on Unix) to prevent a TOCTOU race where another
|
||||
//! process could read the token before permissions are set.
|
||||
|
||||
use std::fs;
|
||||
use std::path::PathBuf;
|
||||
|
||||
use directories::ProjectDirs;
|
||||
use tracing::{info, warn};
|
||||
|
||||
/// Minimum recommended token length when provided via env var.
|
||||
const MIN_TOKEN_LENGTH: usize = 32;
|
||||
|
||||
/// Return the auth token file path inside the Vestige data directory.
|
||||
fn token_path() -> Result<PathBuf, Box<dyn std::error::Error>> {
|
||||
let dirs = ProjectDirs::from("com", "vestige", "core")
|
||||
.ok_or("could not determine project directories")?;
|
||||
Ok(dirs.data_dir().join("auth_token"))
|
||||
}
|
||||
|
||||
/// Get (or create) the bearer token used for HTTP transport authentication.
|
||||
///
|
||||
/// Priority:
|
||||
/// 1. `VESTIGE_AUTH_TOKEN` environment variable
|
||||
/// 2. Existing `auth_token` file in the data directory
|
||||
/// 3. Newly generated UUID v4, persisted to file
|
||||
pub fn get_or_create_auth_token() -> Result<String, Box<dyn std::error::Error>> {
|
||||
// 1. Env var override
|
||||
if let Ok(token) = std::env::var("VESTIGE_AUTH_TOKEN") {
|
||||
let token = token.trim().to_string();
|
||||
if !token.is_empty() {
|
||||
if token.len() < MIN_TOKEN_LENGTH {
|
||||
warn!(
|
||||
"VESTIGE_AUTH_TOKEN is only {} chars (recommended >= {}). \
|
||||
Short tokens are vulnerable to brute-force attacks.",
|
||||
token.len(),
|
||||
MIN_TOKEN_LENGTH
|
||||
);
|
||||
}
|
||||
info!("Using auth token from VESTIGE_AUTH_TOKEN env var");
|
||||
return Ok(token);
|
||||
}
|
||||
}
|
||||
|
||||
let path = token_path()?;
|
||||
|
||||
// 2. Read existing file
|
||||
if path.exists() {
|
||||
let token = fs::read_to_string(&path)?.trim().to_string();
|
||||
if !token.is_empty() {
|
||||
info!("Using auth token from {}", path.display());
|
||||
return Ok(token);
|
||||
}
|
||||
}
|
||||
|
||||
// 3. Generate new token and persist
|
||||
let token = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
// Ensure parent directory exists with restricted permissions
|
||||
if let Some(parent) = path.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
|
||||
// Restrict parent directory permissions on Unix (owner only)
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::os::unix::fs::PermissionsExt;
|
||||
let _ = fs::set_permissions(parent, fs::Permissions::from_mode(0o700));
|
||||
}
|
||||
}
|
||||
|
||||
// Write token file with restricted permissions from the start.
|
||||
// On Unix, we use OpenOptionsExt to set mode 0o600 at creation time,
|
||||
// avoiding the TOCTOU race of write-then-chmod.
|
||||
#[cfg(unix)]
|
||||
{
|
||||
use std::io::Write;
|
||||
use std::os::unix::fs::OpenOptionsExt;
|
||||
|
||||
let mut file = fs::OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.truncate(true)
|
||||
.mode(0o600) // Owner read/write only — set at creation, no race window
|
||||
.open(&path)?;
|
||||
file.write_all(token.as_bytes())?;
|
||||
file.sync_all()?;
|
||||
}
|
||||
|
||||
// On non-Unix (Windows), fall back to regular write (Windows ACLs are different)
|
||||
#[cfg(not(unix))]
|
||||
{
|
||||
fs::write(&path, &token)?;
|
||||
}
|
||||
|
||||
info!("Generated new auth token at {}", path.display());
|
||||
Ok(token)
|
||||
}
|
||||
308
crates/vestige-mcp/src/protocol/http.rs
Normal file
308
crates/vestige-mcp/src/protocol/http.rs
Normal file
|
|
@ -0,0 +1,308 @@
|
|||
//! Streamable HTTP transport for MCP.
|
||||
//!
|
||||
//! Implements the MCP Streamable HTTP transport specification:
|
||||
//! - `POST /mcp` — JSON-RPC endpoint (initialize, tools/call, etc.)
|
||||
//! - `DELETE /mcp` — session cleanup
|
||||
//!
|
||||
//! Each client gets a per-session `McpServer` instance (owns `initialized` state).
|
||||
//! Shared state (Storage, CognitiveEngine, event bus) is shared across sessions.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::time::{Duration, Instant};
|
||||
|
||||
use axum::extract::{DefaultBodyLimit, State};
|
||||
use axum::http::{HeaderMap, StatusCode};
|
||||
use axum::response::IntoResponse;
|
||||
use axum::routing::{delete, post};
|
||||
use axum::{Json, Router};
|
||||
use subtle::ConstantTimeEq;
|
||||
use tokio::sync::{broadcast, Mutex, RwLock};
|
||||
use tower::ServiceBuilder;
|
||||
use tower::limit::ConcurrencyLimitLayer;
|
||||
use tower_http::cors::CorsLayer;
|
||||
use tracing::{info, warn};
|
||||
|
||||
use crate::cognitive::CognitiveEngine;
|
||||
use crate::protocol::types::JsonRpcRequest;
|
||||
use crate::server::McpServer;
|
||||
use vestige_core::Storage;
|
||||
use crate::dashboard::events::VestigeEvent;
|
||||
|
||||
/// Upper bound on the number of sessions held at the same time.
const MAX_SESSIONS: usize = 100;

/// Idle time (30 minutes) after which a session is eligible for reaping.
const SESSION_TIMEOUT: Duration = Duration::from_secs(60 * 30);

/// Pause (5 minutes) between two passes of the reaper task.
const REAPER_INTERVAL: Duration = Duration::from_secs(60 * 5);

/// In-flight request cap handed to the tower concurrency-limit layer.
const CONCURRENCY_LIMIT: usize = 50;

/// Largest accepted request body, 256 KB (JSON-RPC requests should be small).
const MAX_BODY_SIZE: usize = 256 << 10;
|
||||
|
||||
/// A per-client session holding its own McpServer instance.
|
||||
struct Session {
|
||||
server: McpServer,
|
||||
last_active: Instant,
|
||||
}
|
||||
|
||||
/// Shared state cloned into every axum handler.
|
||||
#[derive(Clone)]
|
||||
pub struct HttpTransportState {
|
||||
sessions: Arc<RwLock<HashMap<String, Arc<Mutex<Session>>>>>,
|
||||
storage: Arc<Storage>,
|
||||
cognitive: Arc<Mutex<CognitiveEngine>>,
|
||||
event_tx: broadcast::Sender<VestigeEvent>,
|
||||
auth_token: String,
|
||||
}
|
||||
|
||||
/// Start the HTTP MCP transport on `127.0.0.1:<port>`.
|
||||
///
|
||||
/// This function spawns a background tokio task and returns immediately.
|
||||
pub async fn start_http_transport(
|
||||
storage: Arc<Storage>,
|
||||
cognitive: Arc<Mutex<CognitiveEngine>>,
|
||||
event_tx: broadcast::Sender<VestigeEvent>,
|
||||
auth_token: String,
|
||||
port: u16,
|
||||
) -> Result<(), Box<dyn std::error::Error>> {
|
||||
let state = HttpTransportState {
|
||||
sessions: Arc::new(RwLock::new(HashMap::new())),
|
||||
storage,
|
||||
cognitive,
|
||||
event_tx,
|
||||
auth_token,
|
||||
};
|
||||
|
||||
// Spawn session reaper
|
||||
{
|
||||
let sessions = Arc::clone(&state.sessions);
|
||||
tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::time::sleep(REAPER_INTERVAL).await;
|
||||
let mut map = sessions.write().await;
|
||||
let before = map.len();
|
||||
map.retain(|_id, session| {
|
||||
// Try to check last_active without blocking; skip if locked
|
||||
match session.try_lock() {
|
||||
Ok(s) => s.last_active.elapsed() < SESSION_TIMEOUT,
|
||||
Err(_) => true, // in-use, keep
|
||||
}
|
||||
});
|
||||
let removed = before - map.len();
|
||||
if removed > 0 {
|
||||
info!("Session reaper: removed {} idle sessions ({} active)", removed, map.len());
|
||||
}
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
let app = Router::new()
|
||||
.route("/mcp", post(post_mcp))
|
||||
.route("/mcp", delete(delete_mcp))
|
||||
.layer(
|
||||
ServiceBuilder::new()
|
||||
.layer(DefaultBodyLimit::max(MAX_BODY_SIZE))
|
||||
.layer(ConcurrencyLimitLayer::new(CONCURRENCY_LIMIT))
|
||||
.layer(CorsLayer::permissive()),
|
||||
)
|
||||
.with_state(state);
|
||||
|
||||
// Bind to localhost only — use VESTIGE_HTTP_BIND=0.0.0.0 for remote access
|
||||
let bind_addr: std::net::IpAddr = std::env::var("VESTIGE_HTTP_BIND")
|
||||
.ok()
|
||||
.and_then(|s| s.parse().ok())
|
||||
.unwrap_or_else(|| std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST));
|
||||
|
||||
let addr = std::net::SocketAddr::from((bind_addr, port));
|
||||
let listener = tokio::net::TcpListener::bind(addr).await?;
|
||||
info!("HTTP MCP transport listening on http://{}/mcp", addr);
|
||||
|
||||
tokio::spawn(async move {
|
||||
if let Err(e) = axum::serve(listener, app).await {
|
||||
warn!("HTTP transport error: {}", e);
|
||||
}
|
||||
});
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Helpers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// Validate the `Authorization: Bearer <token>` header using constant-time
|
||||
/// comparison to prevent timing side-channel attacks.
|
||||
fn validate_auth(headers: &HeaderMap, expected: &str) -> Result<(), (StatusCode, &'static str)> {
|
||||
let header = headers
|
||||
.get("authorization")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.ok_or((StatusCode::UNAUTHORIZED, "Missing Authorization header"))?;
|
||||
|
||||
let token = header
|
||||
.strip_prefix("Bearer ")
|
||||
.ok_or((StatusCode::UNAUTHORIZED, "Invalid Authorization scheme (expected Bearer)"))?;
|
||||
|
||||
// Constant-time comparison: prevents timing side-channel attacks.
|
||||
// We first check lengths match (length itself is not secret since UUIDs
|
||||
// have a fixed public format), then compare bytes in constant time.
|
||||
let token_bytes = token.as_bytes();
|
||||
let expected_bytes = expected.as_bytes();
|
||||
|
||||
if token_bytes.len() != expected_bytes.len()
|
||||
|| token_bytes.ct_eq(expected_bytes).unwrap_u8() != 1
|
||||
{
|
||||
return Err((StatusCode::FORBIDDEN, "Invalid auth token"));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Extract and validate the `Mcp-Session-Id` header value.
|
||||
///
|
||||
/// Only accepts valid UUID v4 format (8-4-4-4-12 hex) to prevent header
|
||||
/// injection and ensure session IDs match server-generated format.
|
||||
fn session_id_from_headers(headers: &HeaderMap) -> Option<String> {
|
||||
headers
|
||||
.get("mcp-session-id")
|
||||
.and_then(|v| v.to_str().ok())
|
||||
.filter(|s| uuid::Uuid::parse_str(s).is_ok())
|
||||
.map(|s| s.to_string())
|
||||
}
|
||||
|
||||
// ---------------------------------------------------------------------------
|
||||
// Handlers
|
||||
// ---------------------------------------------------------------------------
|
||||
|
||||
/// `POST /mcp` — main JSON-RPC handler.
|
||||
async fn post_mcp(
|
||||
State(state): State<HttpTransportState>,
|
||||
headers: HeaderMap,
|
||||
Json(request): Json<JsonRpcRequest>,
|
||||
) -> impl IntoResponse {
|
||||
// Auth check
|
||||
if let Err((status, msg)) = validate_auth(&headers, &state.auth_token) {
|
||||
return (status, HeaderMap::new(), msg.to_string()).into_response();
|
||||
}
|
||||
|
||||
let is_initialize = request.method == "initialize";
|
||||
|
||||
if is_initialize {
|
||||
// ── New session ──
|
||||
// Take write lock immediately to avoid TOCTOU race on MAX_SESSIONS check.
|
||||
let mut sessions = state.sessions.write().await;
|
||||
if sessions.len() >= MAX_SESSIONS {
|
||||
return (
|
||||
StatusCode::SERVICE_UNAVAILABLE,
|
||||
"Too many active sessions",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
|
||||
let server = McpServer::new_with_events(
|
||||
Arc::clone(&state.storage),
|
||||
Arc::clone(&state.cognitive),
|
||||
state.event_tx.clone(),
|
||||
);
|
||||
|
||||
let session_id = uuid::Uuid::new_v4().to_string();
|
||||
|
||||
let session = Arc::new(Mutex::new(Session {
|
||||
server,
|
||||
last_active: Instant::now(),
|
||||
}));
|
||||
|
||||
// Handle the initialize request
|
||||
let response = {
|
||||
let mut sess = session.lock().await;
|
||||
sess.server.handle_request(request).await
|
||||
};
|
||||
|
||||
// Insert session while still holding write lock — atomic check-and-insert
|
||||
sessions.insert(session_id.clone(), session);
|
||||
drop(sessions);
|
||||
|
||||
match response {
|
||||
Some(resp) => {
|
||||
let mut resp_headers = HeaderMap::new();
|
||||
resp_headers.insert("mcp-session-id", session_id.parse().unwrap());
|
||||
(StatusCode::OK, resp_headers, Json(resp)).into_response()
|
||||
}
|
||||
None => {
|
||||
// Notifications return 202
|
||||
let mut resp_headers = HeaderMap::new();
|
||||
resp_headers.insert("mcp-session-id", session_id.parse().unwrap());
|
||||
(StatusCode::ACCEPTED, resp_headers).into_response()
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// ── Existing session ──
|
||||
let session_id = match session_id_from_headers(&headers) {
|
||||
Some(id) => id,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::BAD_REQUEST,
|
||||
"Missing or invalid Mcp-Session-Id header",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let session = {
|
||||
let sessions = state.sessions.read().await;
|
||||
sessions.get(&session_id).cloned()
|
||||
};
|
||||
|
||||
let session = match session {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
return (
|
||||
StatusCode::NOT_FOUND,
|
||||
"Session not found or expired",
|
||||
)
|
||||
.into_response();
|
||||
}
|
||||
};
|
||||
|
||||
let response = {
|
||||
let mut sess = session.lock().await;
|
||||
sess.last_active = Instant::now();
|
||||
sess.server.handle_request(request).await
|
||||
};
|
||||
|
||||
let mut resp_headers = HeaderMap::new();
|
||||
resp_headers.insert("mcp-session-id", session_id.parse().unwrap());
|
||||
|
||||
match response {
|
||||
Some(resp) => (StatusCode::OK, resp_headers, Json(resp)).into_response(),
|
||||
None => (StatusCode::ACCEPTED, resp_headers).into_response(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// `DELETE /mcp` — explicit session cleanup.
|
||||
async fn delete_mcp(
|
||||
State(state): State<HttpTransportState>,
|
||||
headers: HeaderMap,
|
||||
) -> impl IntoResponse {
|
||||
if let Err((status, msg)) = validate_auth(&headers, &state.auth_token) {
|
||||
return (status, msg).into_response();
|
||||
}
|
||||
|
||||
let session_id = match session_id_from_headers(&headers) {
|
||||
Some(id) => id,
|
||||
None => return (StatusCode::BAD_REQUEST, "Missing or invalid Mcp-Session-Id header").into_response(),
|
||||
};
|
||||
|
||||
let mut sessions = state.sessions.write().await;
|
||||
if sessions.remove(&session_id).is_some() {
|
||||
info!("Session {} deleted via DELETE /mcp", &session_id[..8]);
|
||||
(StatusCode::OK, "Session deleted").into_response()
|
||||
} else {
|
||||
(StatusCode::NOT_FOUND, "Session not found").into_response()
|
||||
}
|
||||
}
|
||||
|
|
@ -2,6 +2,8 @@
|
|||
//!
|
||||
//! JSON-RPC 2.0 over stdio for the Model Context Protocol.
|
||||
|
||||
pub mod auth;
|
||||
pub mod http;
|
||||
pub mod messages;
|
||||
pub mod stdio;
|
||||
pub mod types;
|
||||
|
|
|
|||
|
|
@ -11,7 +11,7 @@ use tokio::sync::{broadcast, Mutex};
|
|||
use tracing::{debug, info, warn};
|
||||
|
||||
use crate::cognitive::CognitiveEngine;
|
||||
use vestige_mcp::dashboard::events::VestigeEvent;
|
||||
use crate::dashboard::events::VestigeEvent;
|
||||
use crate::protocol::messages::{
|
||||
CallToolRequest, CallToolResult, InitializeRequest, InitializeResult,
|
||||
ListResourcesResult, ListToolsResult, ReadResourceRequest, ReadResourceResult,
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue