mirror of
https://github.com/ModernRelay/omnigraph.git
synced 2026-06-09 01:35:18 +02:00
Initial public Omnigraph repository
This commit is contained in:
commit
338289656a
110 changed files with 60747 additions and 0 deletions
4
.dockerignore
Normal file
4
.dockerignore
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
**
|
||||
!Dockerfile
|
||||
!docker/entrypoint.sh
|
||||
!target/release/omnigraph-server
|
||||
123
.github/workflows/ci.yml
vendored
Normal file
123
.github/workflows/ci.yml
vendored
Normal file
|
|
@ -0,0 +1,123 @@
|
|||
name: CI
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
push:
|
||||
branches:
|
||||
- main
|
||||
tags:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
|
||||
concurrency:
|
||||
group: ci-${{ github.workflow }}-${{ github.ref }}
|
||||
cancel-in-progress: true
|
||||
|
||||
jobs:
|
||||
test:
|
||||
name: Test Workspace
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
contents: read
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
steps:
|
||||
- name: Checkout source
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y protobuf-compiler libprotobuf-dev
|
||||
|
||||
- name: Install Rust stable
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
toolchain: stable
|
||||
|
||||
- name: Cache Rust build data
|
||||
uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: |
|
||||
. -> target
|
||||
|
||||
- name: Run workspace tests
|
||||
run: cargo test --workspace --locked
|
||||
|
||||
rustfs_integration:
|
||||
name: RustFS S3 Integration
|
||||
needs: test
|
||||
runs-on: ubuntu-latest
|
||||
timeout-minutes: 45
|
||||
permissions:
|
||||
contents: read
|
||||
env:
|
||||
AWS_ACCESS_KEY_ID: rustfsadmin
|
||||
AWS_SECRET_ACCESS_KEY: rustfsadmin
|
||||
AWS_REGION: us-east-1
|
||||
AWS_ENDPOINT_URL: http://127.0.0.1:9000
|
||||
AWS_ENDPOINT_URL_S3: http://127.0.0.1:9000
|
||||
AWS_ALLOW_HTTP: "true"
|
||||
AWS_S3_FORCE_PATH_STYLE: "true"
|
||||
OMNIGRAPH_S3_TEST_BUCKET: omnigraph-ci
|
||||
OMNIGRAPH_S3_TEST_PREFIX: github-actions
|
||||
CARGO_TERM_COLOR: always
|
||||
steps:
|
||||
- name: Checkout source
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install system dependencies
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y protobuf-compiler libprotobuf-dev python3-pip
|
||||
|
||||
- name: Install Rust stable
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
toolchain: stable
|
||||
|
||||
- name: Cache Rust build data
|
||||
uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: |
|
||||
. -> target
|
||||
|
||||
- name: Start RustFS
|
||||
run: |
|
||||
docker rm -f rustfs >/dev/null 2>&1 || true
|
||||
docker run -d \
|
||||
--name rustfs \
|
||||
-p 9000:9000 \
|
||||
-p 9001:9001 \
|
||||
-e RUSTFS_ACCESS_KEY="${AWS_ACCESS_KEY_ID}" \
|
||||
-e RUSTFS_SECRET_KEY="${AWS_SECRET_ACCESS_KEY}" \
|
||||
rustfs/rustfs:latest \
|
||||
/data
|
||||
|
||||
- name: Install AWS CLI
|
||||
run: |
|
||||
python3 -m pip install --user awscli
|
||||
echo "$HOME/.local/bin" >> "$GITHUB_PATH"
|
||||
|
||||
- name: Create RustFS test bucket
|
||||
run: |
|
||||
for _ in $(seq 1 30); do
|
||||
if aws --endpoint-url "${AWS_ENDPOINT_URL_S3}" s3api list-buckets >/dev/null 2>&1; then
|
||||
break
|
||||
fi
|
||||
sleep 2
|
||||
done
|
||||
aws --endpoint-url "${AWS_ENDPOINT_URL_S3}" \
|
||||
s3api create-bucket \
|
||||
--bucket "${OMNIGRAPH_S3_TEST_BUCKET}" >/dev/null 2>&1 || true
|
||||
|
||||
- name: Run RustFS-backed repo tests
|
||||
run: |
|
||||
cargo test --locked -p omnigraph --test s3_storage -- --nocapture
|
||||
cargo test --locked -p omnigraph-server --test server server_opens_s3_repo_directly_and_serves_snapshot_and_read -- --nocapture
|
||||
cargo test --locked -p omnigraph-cli --test system_local local_cli_s3_end_to_end_init_load_read_flow -- --nocapture
|
||||
|
||||
- name: Dump RustFS logs on failure
|
||||
if: failure()
|
||||
run: docker logs rustfs
|
||||
68
.github/workflows/release.yml
vendored
Normal file
68
.github/workflows/release.yml
vendored
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
name: Release
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- "v*"
|
||||
workflow_dispatch:
|
||||
|
||||
jobs:
|
||||
build_release:
|
||||
name: Build ${{ matrix.asset_name }}
|
||||
runs-on: ${{ matrix.runner }}
|
||||
permissions:
|
||||
contents: write
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
include:
|
||||
- runner: ubuntu-latest
|
||||
asset_name: omnigraph-linux-x86_64
|
||||
- runner: macos-13
|
||||
asset_name: omnigraph-macos-x86_64
|
||||
- runner: macos-14
|
||||
asset_name: omnigraph-macos-arm64
|
||||
env:
|
||||
CARGO_TERM_COLOR: always
|
||||
steps:
|
||||
- name: Checkout source
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install Linux dependencies
|
||||
if: runner.os == 'Linux'
|
||||
run: |
|
||||
sudo apt-get update
|
||||
sudo apt-get install -y protobuf-compiler libprotobuf-dev
|
||||
|
||||
- name: Install macOS dependencies
|
||||
if: runner.os == 'macOS'
|
||||
run: brew install protobuf
|
||||
|
||||
- name: Install Rust stable
|
||||
uses: dtolnay/rust-toolchain@stable
|
||||
with:
|
||||
toolchain: stable
|
||||
|
||||
- name: Cache Rust build data
|
||||
uses: Swatinem/rust-cache@v2
|
||||
with:
|
||||
workspaces: |
|
||||
. -> target
|
||||
|
||||
- name: Build release binaries
|
||||
run: cargo build --release --locked -p omnigraph-cli -p omnigraph-server
|
||||
|
||||
- name: Package release archive
|
||||
run: |
|
||||
mkdir -p release
|
||||
install -m 0755 target/release/omnigraph release/omnigraph
|
||||
install -m 0755 target/release/omnigraph-server release/omnigraph-server
|
||||
tar -C release -czf "${{ matrix.asset_name }}.tar.gz" omnigraph omnigraph-server
|
||||
shasum -a 256 "${{ matrix.asset_name }}.tar.gz" > "${{ matrix.asset_name }}.sha256"
|
||||
|
||||
- name: Publish GitHub release assets
|
||||
uses: softprops/action-gh-release@v2
|
||||
with:
|
||||
files: |
|
||||
${{ matrix.asset_name }}.tar.gz
|
||||
${{ matrix.asset_name }}.sha256
|
||||
18
.gitignore
vendored
Normal file
18
.gitignore
vendored
Normal file
|
|
@ -0,0 +1,18 @@
|
|||
/target
|
||||
**/target
|
||||
*.lance
|
||||
*.nano
|
||||
*.nanograph
|
||||
.DS_Store
|
||||
.env
|
||||
.env.*
|
||||
*.tfstate
|
||||
*.tfstate.*
|
||||
.terraform/
|
||||
.terraform.lock.hcl
|
||||
*.tfvars
|
||||
!*.tfvars.example
|
||||
__pycache__/
|
||||
*.pyc
|
||||
demo/*.omni/
|
||||
.omnigraph-rustfs-demo/
|
||||
13
CODE_OF_CONDUCT.md
Normal file
13
CODE_OF_CONDUCT.md
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
# Code Of Conduct
|
||||
|
||||
This project follows a simple rule: be direct, respectful, and constructive.
|
||||
|
||||
Expected behavior:
|
||||
|
||||
- focus on technical substance
|
||||
- assume good intent
|
||||
- give actionable feedback
|
||||
- avoid harassment, personal attacks, and pile-ons
|
||||
|
||||
Maintainers may remove comments, issues, or pull requests that make the project
|
||||
harder to collaborate in productively.
|
||||
23
CONTRIBUTING.md
Normal file
23
CONTRIBUTING.md
Normal file
|
|
@ -0,0 +1,23 @@
|
|||
# Contributing
|
||||
|
||||
Small bug fixes and documentation improvements are welcome directly through pull
|
||||
requests.
|
||||
|
||||
For larger changes, please open an issue or design discussion first so the
|
||||
proposed direction is clear before implementation starts.
|
||||
|
||||
## Development
|
||||
|
||||
```bash
|
||||
cargo build --workspace
|
||||
cargo test --workspace
|
||||
```
|
||||
|
||||
If you touch S3-backed flows, note that CI runs the S3 integration tests
against a local RustFS instance.
|
||||
|
||||
## Pull Requests
|
||||
|
||||
- keep changes focused
|
||||
- include tests for behavior changes when practical
|
||||
- update public docs when the user-facing surface changes
|
||||
7646
Cargo.lock
generated
Normal file
7646
Cargo.lock
generated
Normal file
File diff suppressed because it is too large
Load diff
79
Cargo.toml
Normal file
79
Cargo.toml
Normal file
|
|
@ -0,0 +1,79 @@
|
|||
[workspace]
|
||||
resolver = "2"
|
||||
members = [
|
||||
"crates/omnigraph-compiler",
|
||||
"crates/omnigraph",
|
||||
"crates/omnigraph-cli",
|
||||
"crates/omnigraph-server",
|
||||
]
|
||||
default-members = [
|
||||
"crates/omnigraph",
|
||||
"crates/omnigraph-cli",
|
||||
"crates/omnigraph-server",
|
||||
]
|
||||
|
||||
[workspace.dependencies]
|
||||
arrow-array = "57"
|
||||
arrow-ipc = "57"
|
||||
arrow-schema = "57"
|
||||
arrow-select = "57"
|
||||
arrow-cast = { version = "57", features = ["prettyprint"] }
|
||||
arrow-ord = "57"
|
||||
|
||||
datafusion-physical-plan = "52"
|
||||
datafusion-physical-expr = "52"
|
||||
datafusion-execution = "52"
|
||||
datafusion-common = "52"
|
||||
datafusion-expr = "52"
|
||||
datafusion-functions-aggregate = "52"
|
||||
|
||||
lance = { version = "4.0.0", default-features = false, features = ["aws"] }
|
||||
lance-datafusion = "4.0.0"
|
||||
lance-file = "4.0.0"
|
||||
lance-index = "4.0.0"
|
||||
lance-linalg = "4.0.0"
|
||||
lance-namespace = "4.0.0"
|
||||
lance-namespace-impls = "4.0.0"
|
||||
lance-table = "4.0.0"
|
||||
|
||||
ulid = "1"
|
||||
futures = "0.3"
|
||||
async-trait = "0.1"
|
||||
pest = "2"
|
||||
pest_derive = "2"
|
||||
thiserror = "2"
|
||||
tokio = { version = "1", features = ["rt-multi-thread", "macros", "time", "net", "signal", "sync"] }
|
||||
clap = { version = "4", features = ["derive"] }
|
||||
serde = { version = "1", features = ["derive"] }
|
||||
serde_json = "1"
|
||||
serde_yaml = "0.9"
|
||||
tracing = "0.1"
|
||||
tracing-subscriber = { version = "0.3", features = ["env-filter", "fmt"] }
|
||||
tower = "0.5"
|
||||
tower-http = { version = "0.6", features = ["trace"] }
|
||||
color-eyre = "0.6"
|
||||
tempfile = "3"
|
||||
ahash = "0.8"
|
||||
base64 = "0.22"
|
||||
ariadne = "0.4"
|
||||
regex = "1"
|
||||
reqwest = { version = "0.12", default-features = false, features = ["json", "rustls-tls"] }
|
||||
object_store = { version = "0.12.5", default-features = false, features = ["aws"] }
|
||||
fail = "0.5"
|
||||
time = { version = "0.3", features = ["formatting"] }
|
||||
axum = { version = "0.8", features = ["json", "macros"] }
|
||||
url = "2"
|
||||
cedar-policy = "4.9"
|
||||
sha2 = "0.10"
|
||||
|
||||
[profile.dev]
|
||||
debug = 0
|
||||
|
||||
[profile.dev.package."*"]
|
||||
opt-level = 2
|
||||
|
||||
[profile.release]
|
||||
opt-level = 2
|
||||
lto = "thin"
|
||||
codegen-units = 16
|
||||
strip = true
|
||||
25
Dockerfile
Normal file
25
Dockerfile
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
FROM debian:bookworm-slim
|
||||
|
||||
RUN apt-get update \
|
||||
&& apt-get install -y --no-install-recommends ca-certificates curl \
|
||||
&& rm -rf /var/lib/apt/lists/*
|
||||
|
||||
RUN groupadd --system omnigraph \
|
||||
&& useradd --system --gid omnigraph --create-home --home-dir /var/lib/omnigraph omnigraph
|
||||
|
||||
COPY target/release/omnigraph-server /usr/local/bin/omnigraph-server
|
||||
COPY docker/entrypoint.sh /usr/local/bin/omnigraph-entrypoint
|
||||
|
||||
RUN chmod 0755 /usr/local/bin/omnigraph-server /usr/local/bin/omnigraph-entrypoint
|
||||
|
||||
ENV OMNIGRAPH_BIND=0.0.0.0:8080
|
||||
|
||||
WORKDIR /var/lib/omnigraph
|
||||
USER omnigraph:omnigraph
|
||||
|
||||
EXPOSE 8080
|
||||
|
||||
HEALTHCHECK --interval=30s --timeout=5s --start-period=10s --retries=3 \
|
||||
CMD curl -fsS http://127.0.0.1:8080/healthz >/dev/null || exit 1
|
||||
|
||||
ENTRYPOINT ["/usr/local/bin/omnigraph-entrypoint"]
|
||||
21
LICENSE
Normal file
21
LICENSE
Normal file
|
|
@ -0,0 +1,21 @@
|
|||
MIT License
|
||||
|
||||
Copyright (c) 2026 NanoGraph Contributors
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
||||
154
README.md
Normal file
154
README.md
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
# Omnigraph
|
||||
|
||||
Omnigraph is a typed property graph database built on Lance. It combines
|
||||
schema-first graph modeling, typed queries and mutations, Git-style graph
|
||||
workflows, and storage that runs equally well on a local directory or an
|
||||
`s3://` URI.
|
||||
|
||||
## Quick Install
|
||||
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/ModernRelay/omnigraph-public/main/scripts/install.sh | bash
|
||||
```
|
||||
|
||||
This installs `omnigraph` and `omnigraph-server` into `~/.local/bin`. If no
|
||||
tagged release exists for your platform yet, the installer falls back to a
|
||||
source build.
|
||||
|
||||
## One-Command Local RustFS Bootstrap
|
||||
|
||||
```bash
|
||||
curl -fsSL https://raw.githubusercontent.com/ModernRelay/omnigraph-public/main/scripts/local-rustfs-bootstrap.sh | bash
|
||||
```
|
||||
|
||||
That bootstrap:
|
||||
|
||||
- starts RustFS on `127.0.0.1:9000`
|
||||
- creates a bucket and S3-backed repo
|
||||
- loads the checked-in context fixture
|
||||
- launches `omnigraph-server` on `127.0.0.1:8080`
|
||||
|
||||
Docker must be installed and running first.
|
||||
|
||||
## Good Fit For
|
||||
|
||||
- Team knowledge graphs and internal context graphs
|
||||
- Research, decisions, and evidence tracking
|
||||
- Collaborative knowledge systems with reviewable changes
|
||||
- Private self-hosted graph backends for local or on-prem AI tooling
|
||||
|
||||
## Why Omnigraph
|
||||
|
||||
- Typed schema, typed queries, and typed mutations
|
||||
- Git-style graph workflows: branches, commits, merges, and transactional runs
|
||||
- Local-first and S3-native storage with snapshot-pinned reads
|
||||
- Graph traversal plus text, fuzzy, BM25, vector, and RRF search in one runtime
|
||||
- Policy as code for server-side access control
|
||||
|
||||
## Quick Start
|
||||
|
||||
From a checkout of this repo:
|
||||
|
||||
```bash
|
||||
cargo build --workspace
|
||||
|
||||
cargo run -p omnigraph-cli -- init \
|
||||
--schema crates/omnigraph/tests/fixtures/test.pg \
|
||||
./repo.omni
|
||||
|
||||
cargo run -p omnigraph-cli -- load \
|
||||
--data crates/omnigraph/tests/fixtures/test.jsonl \
|
||||
./repo.omni
|
||||
|
||||
cargo run -p omnigraph-cli -- read \
|
||||
./repo.omni \
|
||||
--query crates/omnigraph/tests/fixtures/test.gq \
|
||||
--name friends_of \
|
||||
--params '{"name":"Alice"}'
|
||||
```
|
||||
|
||||
`init` also scaffolds an `omnigraph.yaml` next to the repo if one does not
|
||||
already exist.
|
||||
|
||||
## Run A Server
|
||||
|
||||
Serve the same repo over HTTP:
|
||||
|
||||
```bash
|
||||
cargo run -p omnigraph-server -- ./repo.omni --bind 127.0.0.1:8080
|
||||
```
|
||||
|
||||
Then query it remotely:
|
||||
|
||||
```bash
|
||||
cargo run -p omnigraph-cli -- read \
|
||||
--target http://127.0.0.1:8080 \
|
||||
--query crates/omnigraph/tests/fixtures/test.gq \
|
||||
--name get_person \
|
||||
--params '{"name":"Alice"}'
|
||||
```
|
||||
|
||||
Server routes include `/healthz`, `/snapshot`, `/export`, `/read`, `/change`,
|
||||
`/ingest`, `/branches`, `/runs`, and `/commits`.
|
||||
|
||||
To require auth, set `OMNIGRAPH_SERVER_BEARER_TOKEN` on the server and set the
|
||||
matching bearer token env var in your CLI target config.
|
||||
|
||||
## Common Commands
|
||||
|
||||
Core repo flow:
|
||||
|
||||
```bash
|
||||
omnigraph init --schema ./schema.pg ./repo.omni
|
||||
omnigraph load --data ./data.jsonl --mode overwrite ./repo.omni
|
||||
omnigraph snapshot ./repo.omni --branch main --json
|
||||
omnigraph read ./repo.omni --query ./queries.gq --name get_person --params '{"name":"Alice"}'
|
||||
omnigraph change ./repo.omni --query ./queries.gq --name insert_person --params '{"name":"Mina","age":28}'
|
||||
omnigraph branch create --uri ./repo.omni --from main feature-x
|
||||
omnigraph branch merge --uri ./repo.omni feature-x --into main
|
||||
```
|
||||
|
||||
More CLI examples, config patterns, and admin commands live in
|
||||
[docs/cli.md](docs/cli.md).
|
||||
|
||||
## Production Features
|
||||
|
||||
- Branches, commits, merge-base-aware graph merges, and transactional runs
|
||||
- Snapshot-pinned reads across local and S3-backed repos
|
||||
- Traversal plus text, fuzzy, BM25, vector, and RRF search
|
||||
- Axum server for reads, changes, export, branches, commits, and runs
|
||||
- Cedar-based server-side authorization
|
||||
|
||||
## Docs
|
||||
|
||||
- [Install guide](docs/install.md)
|
||||
- [CLI guide](docs/cli.md)
|
||||
- [Deployment guide](docs/deployment.md)
|
||||
|
||||
## Build And Test
|
||||
|
||||
```bash
|
||||
cargo build --workspace
|
||||
cargo check --workspace
|
||||
cargo test --workspace
|
||||
```
|
||||
|
||||
Notes:
|
||||
|
||||
- Rust stable toolchain, edition 2024
|
||||
- CI runs `cargo test --workspace --locked`
|
||||
- Full CI and some local test flows require `protobuf-compiler`
|
||||
- S3 integration tests expect an S3-compatible endpoint such as RustFS
|
||||
|
||||
## Workspace Crates
|
||||
|
||||
- `crates/omnigraph-compiler`: shared schema/query parser, typechecker, catalog, and IR lowering
|
||||
- `crates/omnigraph`: storage/runtime, branching, merge, change detection, and query execution
|
||||
- `crates/omnigraph-cli`: CLI for init/load/ingest/read/change/branch/snapshot/export/policy operations
|
||||
- `crates/omnigraph-server`: Axum HTTP server for remote reads, changes, ingest, export, branches, commits, and runs
|
||||
|
||||
## Contributing
|
||||
|
||||
Please open an issue, spec, or design discussion before sending large code
|
||||
changes. Design feedback and concrete problem statements are the fastest way to
|
||||
collaborate on the roadmap.
|
||||
14
SECURITY.md
Normal file
14
SECURITY.md
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
# Security Policy
|
||||
|
||||
Please do not report security issues through public GitHub issues.
|
||||
|
||||
If GitHub private vulnerability reporting is enabled for this repository, use
|
||||
that channel. Otherwise, contact the maintainers directly through a private
|
||||
channel before publishing details.
|
||||
|
||||
When reporting an issue, include:
|
||||
|
||||
- affected version or commit
|
||||
- impact
|
||||
- reproduction steps
|
||||
- any proposed mitigation
|
||||
28
crates/omnigraph-cli/Cargo.toml
Normal file
28
crates/omnigraph-cli/Cargo.toml
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
[package]
|
||||
name = "omnigraph-cli"
|
||||
version = "0.4.0"
|
||||
edition = "2024"
|
||||
description = "CLI for the Omnigraph graph database."
|
||||
license = "MIT"
|
||||
|
||||
[[bin]]
|
||||
name = "omnigraph"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
omnigraph = { path = "../omnigraph", version = "0.4.0" }
|
||||
omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.4.0" }
|
||||
omnigraph-server = { path = "../omnigraph-server", version = "0.4.0" }
|
||||
clap = { workspace = true }
|
||||
color-eyre = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
serde_yaml = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
reqwest = { workspace = true, features = ["blocking"] }
|
||||
|
||||
[dev-dependencies]
|
||||
assert_cmd = "2"
|
||||
predicates = "3"
|
||||
serde_json = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
586
crates/omnigraph-cli/src/embed.rs
Normal file
586
crates/omnigraph-cli/src/embed.rs
Normal file
|
|
@ -0,0 +1,586 @@
|
|||
use std::collections::{BTreeMap, HashSet};
|
||||
use std::fs::{self, File};
|
||||
use std::io::{BufRead, BufReader, BufWriter, Write};
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use clap::Args;
|
||||
use color_eyre::eyre::{Result, bail, eyre};
|
||||
use omnigraph::embedding::EmbeddingClient;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::{Map, Value, json};
|
||||
|
||||
/// Embedding model assumed when a spec omits `model`; also the only model
/// `resolve_embed_job` currently accepts.
const DEFAULT_EMBED_MODEL: &str = "gemini-embedding-2-preview";
|
||||
|
||||
#[derive(Debug, Args, Clone)]
|
||||
pub(crate) struct EmbedArgs {
|
||||
/// Seed manifest path
|
||||
#[arg(long, conflicts_with_all = ["input", "output", "spec"])]
|
||||
pub seed: Option<PathBuf>,
|
||||
/// Raw seed JSONL input path
|
||||
#[arg(long, requires_all = ["output", "spec"], conflicts_with = "seed")]
|
||||
pub input: Option<PathBuf>,
|
||||
/// Embedded JSONL output path
|
||||
#[arg(long)]
|
||||
pub output: Option<PathBuf>,
|
||||
/// Embedding spec JSON path
|
||||
#[arg(long, requires_all = ["input", "output"], conflicts_with = "seed")]
|
||||
pub spec: Option<PathBuf>,
|
||||
/// Remove embedding fields instead of generating embeddings
|
||||
#[arg(long, conflicts_with = "reembed_all")]
|
||||
pub clean: bool,
|
||||
/// Regenerate embeddings for all matching rows
|
||||
#[arg(long, conflicts_with = "clean")]
|
||||
pub reembed_all: bool,
|
||||
/// Restrict processing to these type names
|
||||
#[arg(long = "type")]
|
||||
pub types: Vec<String>,
|
||||
/// Reembed or clean matching rows only. Syntax: Type:field=value or field=value
|
||||
#[arg(long = "select")]
|
||||
pub selectors: Vec<String>,
|
||||
/// Print JSON summary
|
||||
#[arg(long)]
|
||||
pub json: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
pub(crate) struct EmbedOutput {
|
||||
pub input: String,
|
||||
pub output: String,
|
||||
pub rows: usize,
|
||||
pub selected_rows: usize,
|
||||
pub embedded_rows: usize,
|
||||
pub cleaned_rows: usize,
|
||||
pub mode: &'static str,
|
||||
pub dimension: usize,
|
||||
pub model: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(crate) struct EmbedJob {
|
||||
input: PathBuf,
|
||||
output: PathBuf,
|
||||
spec: EmbedSpec,
|
||||
mode: EmbedMode,
|
||||
type_filter: HashSet<String>,
|
||||
selectors: Vec<RowSelector>,
|
||||
}
|
||||
|
||||
/// What an embed run does to rows that have an embedding spec.
#[derive(Debug, Clone, Copy)]
enum EmbedMode {
    /// Embed selected rows whose target is missing (or every selected row
    /// when selectors were supplied).
    FillMissing,
    /// Embed every selected row, regardless of existing embeddings.
    ReembedAll,
    /// Remove the embedding target field from selected rows.
    Clean,
}

impl EmbedMode {
    /// Returns the label reported in the run summary.
    ///
    /// `selectors_present` only changes the label for [`EmbedMode::FillMissing`],
    /// which is reported as `"reembed_selected"` when selectors were given.
    fn as_str(self, selectors_present: bool) -> &'static str {
        match (self, selectors_present) {
            (Self::Clean, _) => "clean",
            (Self::ReembedAll, _) => "reembed_all",
            (Self::FillMissing, true) => "reembed_selected",
            (Self::FillMissing, false) => "fill_missing",
        }
    }
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct EmbedSpec {
|
||||
#[serde(default = "default_embed_model")]
|
||||
model: String,
|
||||
dimension: usize,
|
||||
types: BTreeMap<String, EmbedTypeSpec>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct EmbedTypeSpec {
|
||||
target: String,
|
||||
fields: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct SeedManifest {
|
||||
#[serde(default)]
|
||||
sources: Option<SeedSources>,
|
||||
#[serde(default)]
|
||||
artifacts: Option<SeedArtifacts>,
|
||||
#[serde(default)]
|
||||
embeddings: Option<EmbedSpec>,
|
||||
#[serde(default)]
|
||||
seed: Option<LegacySeed>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct SeedSources {
|
||||
raw_seed: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct SeedArtifacts {
|
||||
embedded_seed: PathBuf,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
struct LegacySeed {
|
||||
data: PathBuf,
|
||||
}
|
||||
|
||||
/// One parsed `--select` expression (`field=value` or `Type:field=value`).
#[derive(Debug, Clone)]
struct RowSelector {
    /// Type the selector is limited to; `None` applies it to every type.
    type_name: Option<String>,
    /// Key looked up in the row's `data` object.
    field: String,
    /// Value the rendered field must equal.
    expected: String,
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum EmbedRow {
|
||||
Entity {
|
||||
type_name: String,
|
||||
data: Map<String, Value>,
|
||||
root: Map<String, Value>,
|
||||
},
|
||||
Passthrough(Map<String, Value>),
|
||||
}
|
||||
|
||||
pub(crate) fn resolve_embed_job(args: &EmbedArgs) -> Result<EmbedJob> {
|
||||
let mode = if args.clean {
|
||||
EmbedMode::Clean
|
||||
} else if args.reembed_all {
|
||||
EmbedMode::ReembedAll
|
||||
} else {
|
||||
EmbedMode::FillMissing
|
||||
};
|
||||
let selectors = args
|
||||
.selectors
|
||||
.iter()
|
||||
.map(|selector| RowSelector::parse(selector))
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
let type_filter = args.types.iter().cloned().collect::<HashSet<_>>();
|
||||
|
||||
let (input, output, spec) = if let Some(seed_path) = &args.seed {
|
||||
let manifest = load_seed_manifest(seed_path)?;
|
||||
(
|
||||
manifest.raw_seed,
|
||||
args.output.clone().unwrap_or(manifest.embedded_seed),
|
||||
manifest.spec,
|
||||
)
|
||||
} else {
|
||||
let input = args
|
||||
.input
|
||||
.clone()
|
||||
.ok_or_else(|| eyre!("--input is required when --seed is not provided"))?;
|
||||
let output = args
|
||||
.output
|
||||
.clone()
|
||||
.ok_or_else(|| eyre!("--output is required when --seed is not provided"))?;
|
||||
let spec_path = args
|
||||
.spec
|
||||
.clone()
|
||||
.ok_or_else(|| eyre!("--spec is required when --seed is not provided"))?;
|
||||
let spec = load_embed_spec(&spec_path)?;
|
||||
(input, output, spec)
|
||||
};
|
||||
|
||||
if spec.model != DEFAULT_EMBED_MODEL {
|
||||
bail!(
|
||||
"only {} is supported for explicit seed embeddings right now",
|
||||
DEFAULT_EMBED_MODEL
|
||||
);
|
||||
}
|
||||
|
||||
Ok(EmbedJob {
|
||||
input,
|
||||
output,
|
||||
spec,
|
||||
mode,
|
||||
type_filter,
|
||||
selectors,
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn execute_embed(args: &EmbedArgs) -> Result<EmbedOutput> {
|
||||
let job = resolve_embed_job(args)?;
|
||||
run_embed_job(&job).await
|
||||
}
|
||||
|
||||
pub(crate) async fn run_embed_job(job: &EmbedJob) -> Result<EmbedOutput> {
|
||||
if !job.input.exists() {
|
||||
bail!("seed input does not exist: {}", job.input.display());
|
||||
}
|
||||
|
||||
if let Some(parent) = job.output.parent() {
|
||||
fs::create_dir_all(parent)?;
|
||||
}
|
||||
|
||||
let temp_output = temp_output_path(&job.output);
|
||||
let mut reader = BufReader::new(File::open(&job.input)?);
|
||||
let mut writer = BufWriter::new(File::create(&temp_output)?);
|
||||
let client = match job.mode {
|
||||
EmbedMode::Clean => None,
|
||||
_ => Some(EmbeddingClient::from_env()?),
|
||||
};
|
||||
|
||||
let mut line = String::new();
|
||||
let mut rows = 0usize;
|
||||
let mut selected_rows = 0usize;
|
||||
let mut embedded_rows = 0usize;
|
||||
let mut cleaned_rows = 0usize;
|
||||
|
||||
loop {
|
||||
line.clear();
|
||||
let bytes = reader.read_line(&mut line)?;
|
||||
if bytes == 0 {
|
||||
break;
|
||||
}
|
||||
let raw = line.trim();
|
||||
if raw.is_empty() {
|
||||
continue;
|
||||
}
|
||||
rows += 1;
|
||||
let mut row = parse_row(raw, rows)?;
|
||||
let selected = row_matches_selection(&row, &job.type_filter, &job.selectors);
|
||||
if selected {
|
||||
selected_rows += 1;
|
||||
}
|
||||
|
||||
if let Some(type_spec) = row
|
||||
.type_name()
|
||||
.and_then(|type_name| job.spec.types.get(type_name))
|
||||
{
|
||||
match job.mode {
|
||||
EmbedMode::Clean => {
|
||||
if selected
|
||||
&& row
|
||||
.data_mut()
|
||||
.is_some_and(|data| data.remove(&type_spec.target).is_some())
|
||||
{
|
||||
cleaned_rows += 1;
|
||||
}
|
||||
}
|
||||
EmbedMode::ReembedAll => {
|
||||
if selected {
|
||||
embed_row(
|
||||
&mut row,
|
||||
type_spec,
|
||||
job.spec.dimension,
|
||||
client.as_ref().unwrap(),
|
||||
)
|
||||
.await?;
|
||||
embedded_rows += 1;
|
||||
}
|
||||
}
|
||||
EmbedMode::FillMissing => {
|
||||
let reembed_selected = !job.selectors.is_empty();
|
||||
if selected
|
||||
&& (reembed_selected
|
||||
|| embedding_missing(
|
||||
row.data().and_then(|data| data.get(&type_spec.target)),
|
||||
))
|
||||
{
|
||||
embed_row(
|
||||
&mut row,
|
||||
type_spec,
|
||||
job.spec.dimension,
|
||||
client.as_ref().unwrap(),
|
||||
)
|
||||
.await?;
|
||||
embedded_rows += 1;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
writer.write_all(serde_json::to_string(&row.into_value())?.as_bytes())?;
|
||||
writer.write_all(b"\n")?;
|
||||
}
|
||||
|
||||
writer.flush()?;
|
||||
fs::rename(&temp_output, &job.output)?;
|
||||
|
||||
Ok(EmbedOutput {
|
||||
input: job.input.display().to_string(),
|
||||
output: job.output.display().to_string(),
|
||||
rows,
|
||||
selected_rows,
|
||||
embedded_rows,
|
||||
cleaned_rows,
|
||||
mode: job.mode.as_str(!job.selectors.is_empty()),
|
||||
dimension: job.spec.dimension,
|
||||
model: job.spec.model.clone(),
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns `output` with `.tmp` appended to its full path (not substituted
/// for the existing extension).
fn temp_output_path(output: &Path) -> PathBuf {
    let mut suffixed = output.to_path_buf().into_os_string();
    suffixed.push(".tmp");
    suffixed.into()
}
|
||||
|
||||
fn default_embed_model() -> String {
|
||||
DEFAULT_EMBED_MODEL.to_string()
|
||||
}
|
||||
|
||||
fn load_embed_spec(path: &Path) -> Result<EmbedSpec> {
|
||||
Ok(serde_json::from_str(&fs::read_to_string(path)?)?)
|
||||
}
|
||||
|
||||
struct ResolvedSeedManifest {
|
||||
raw_seed: PathBuf,
|
||||
embedded_seed: PathBuf,
|
||||
spec: EmbedSpec,
|
||||
}
|
||||
|
||||
fn load_seed_manifest(path: &Path) -> Result<ResolvedSeedManifest> {
|
||||
let base_dir = path
|
||||
.parent()
|
||||
.map(Path::to_path_buf)
|
||||
.unwrap_or(std::env::current_dir()?);
|
||||
let manifest: SeedManifest = serde_yaml::from_str(&fs::read_to_string(path)?)?;
|
||||
let raw_seed = manifest
|
||||
.sources
|
||||
.as_ref()
|
||||
.map(|sources| sources.raw_seed.clone())
|
||||
.or_else(|| manifest.seed.as_ref().map(|seed| seed.data.clone()))
|
||||
.ok_or_else(|| eyre!("seed manifest is missing sources.raw_seed"))?;
|
||||
let embedded_seed = manifest
|
||||
.artifacts
|
||||
.as_ref()
|
||||
.map(|artifacts| artifacts.embedded_seed.clone())
|
||||
.unwrap_or_else(|| PathBuf::from("./build/seed.embedded.jsonl"));
|
||||
let spec = manifest
|
||||
.embeddings
|
||||
.ok_or_else(|| eyre!("seed manifest is missing embeddings"))?;
|
||||
|
||||
Ok(ResolvedSeedManifest {
|
||||
raw_seed: base_dir.join(raw_seed),
|
||||
embedded_seed: base_dir.join(embedded_seed),
|
||||
spec,
|
||||
})
|
||||
}
|
||||
|
||||
impl RowSelector {
|
||||
fn parse(value: &str) -> Result<Self> {
|
||||
let (lhs, expected) = value
|
||||
.split_once('=')
|
||||
.ok_or_else(|| eyre!("selector must be field=value or Type:field=value"))?;
|
||||
let (type_name, field) = if let Some((type_name, field)) = lhs.split_once(':') {
|
||||
(
|
||||
Some(type_name.trim().to_string()).filter(|value| !value.is_empty()),
|
||||
field.trim().to_string(),
|
||||
)
|
||||
} else {
|
||||
(None, lhs.trim().to_string())
|
||||
};
|
||||
|
||||
if field.is_empty() {
|
||||
bail!("selector field cannot be empty");
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
type_name,
|
||||
field,
|
||||
expected: expected.trim().to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn matches(&self, type_name: &str, data: &Map<String, Value>) -> bool {
|
||||
if self
|
||||
.type_name
|
||||
.as_deref()
|
||||
.is_some_and(|expected| expected != type_name)
|
||||
{
|
||||
return false;
|
||||
}
|
||||
|
||||
data.get(&self.field)
|
||||
.map(render_value)
|
||||
.is_some_and(|value| value == self.expected)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_row(raw: &str, line_number: usize) -> Result<EmbedRow> {
|
||||
let mut root = serde_json::from_str::<Map<String, Value>>(raw)
|
||||
.map_err(|err| eyre!("line {} is not valid JSON: {}", line_number, err))?;
|
||||
let Some(type_name) = root.get("type").and_then(Value::as_str).map(str::to_string) else {
|
||||
return Ok(EmbedRow::Passthrough(root));
|
||||
};
|
||||
let data = root
|
||||
.remove("data")
|
||||
.and_then(|value| value.as_object().cloned())
|
||||
.ok_or_else(|| eyre!("line {} is missing object field 'data'", line_number))?;
|
||||
|
||||
Ok(EmbedRow::Entity {
|
||||
type_name,
|
||||
data,
|
||||
root,
|
||||
})
|
||||
}
|
||||
|
||||
impl EmbedRow {
|
||||
fn into_value(self) -> Value {
|
||||
match self {
|
||||
Self::Entity {
|
||||
type_name,
|
||||
data,
|
||||
mut root,
|
||||
} => {
|
||||
root.insert("type".to_string(), Value::String(type_name));
|
||||
root.insert("data".to_string(), Value::Object(data));
|
||||
Value::Object(root)
|
||||
}
|
||||
Self::Passthrough(root) => Value::Object(root),
|
||||
}
|
||||
}
|
||||
|
||||
fn type_name(&self) -> Option<&str> {
|
||||
match self {
|
||||
Self::Entity { type_name, .. } => Some(type_name.as_str()),
|
||||
Self::Passthrough(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn data(&self) -> Option<&Map<String, Value>> {
|
||||
match self {
|
||||
Self::Entity { data, .. } => Some(data),
|
||||
Self::Passthrough(_) => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn data_mut(&mut self) -> Option<&mut Map<String, Value>> {
|
||||
match self {
|
||||
Self::Entity { data, .. } => Some(data),
|
||||
Self::Passthrough(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn row_matches_selection(
|
||||
row: &EmbedRow,
|
||||
type_filter: &HashSet<String>,
|
||||
selectors: &[RowSelector],
|
||||
) -> bool {
|
||||
let Some(type_name) = row.type_name() else {
|
||||
return false;
|
||||
};
|
||||
let Some(data) = row.data() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
let matches_type = type_filter.is_empty() || type_filter.contains(type_name);
|
||||
if !matches_type {
|
||||
return false;
|
||||
}
|
||||
if selectors.is_empty() {
|
||||
return true;
|
||||
}
|
||||
selectors
|
||||
.iter()
|
||||
.any(|selector| selector.matches(type_name, data))
|
||||
}
|
||||
|
||||
fn embedding_missing(value: Option<&Value>) -> bool {
|
||||
match value {
|
||||
None | Some(Value::Null) => true,
|
||||
Some(Value::Array(values)) => values.is_empty(),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn render_value(value: &Value) -> String {
|
||||
match value {
|
||||
Value::Null => String::new(),
|
||||
Value::String(value) => value.trim().to_string(),
|
||||
Value::Bool(value) => {
|
||||
if *value {
|
||||
"true".to_string()
|
||||
} else {
|
||||
"false".to_string()
|
||||
}
|
||||
}
|
||||
Value::Number(value) => value.to_string(),
|
||||
Value::Array(values) => values
|
||||
.iter()
|
||||
.map(render_value)
|
||||
.filter(|value| !value.is_empty())
|
||||
.collect::<Vec<_>>()
|
||||
.join(", "),
|
||||
other => other.to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
fn build_embedding_text(type_name: &str, data: &Map<String, Value>, fields: &[String]) -> String {
|
||||
let mut parts = vec![format!("type: {}", type_name)];
|
||||
for field in fields {
|
||||
if let Some(value) = data.get(field) {
|
||||
let rendered = render_value(value);
|
||||
if !rendered.is_empty() {
|
||||
parts.push(format!("{}: {}", field, rendered));
|
||||
}
|
||||
}
|
||||
}
|
||||
parts.join("\n")
|
||||
}
|
||||
|
||||
async fn embed_row(
|
||||
row: &mut EmbedRow,
|
||||
spec: &EmbedTypeSpec,
|
||||
dimension: usize,
|
||||
client: &EmbeddingClient,
|
||||
) -> Result<()> {
|
||||
let type_name = row
|
||||
.type_name()
|
||||
.ok_or_else(|| eyre!("cannot embed non-entity seed rows"))?
|
||||
.to_string();
|
||||
let data = row
|
||||
.data_mut()
|
||||
.ok_or_else(|| eyre!("cannot embed non-entity seed rows"))?;
|
||||
let text = build_embedding_text(&type_name, data, &spec.fields);
|
||||
if text.trim().is_empty() {
|
||||
return Ok(());
|
||||
}
|
||||
let embedding = client.embed_document_text(&text, dimension).await?;
|
||||
data.insert(spec.target.clone(), json!(embedding));
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::{RowSelector, build_embedding_text, render_value};
    use serde_json::json;

    /// Both the `Type:field=value` and the bare `field=value` forms parse.
    #[test]
    fn selector_parses_type_and_field_forms() {
        let cases = [
            ("Decision:slug=dec-1", Some("Decision"), "slug", "dec-1"),
            ("slug=dec-2", None, "slug", "dec-2"),
        ];
        for (input, type_name, field, expected) in cases {
            let selector = RowSelector::parse(input).unwrap();
            assert_eq!(selector.type_name.as_deref(), type_name);
            assert_eq!(selector.field, field);
            assert_eq!(selector.expected, expected);
        }
    }

    /// Arrays are comma-joined; booleans and numbers use their plain text form.
    #[test]
    fn render_value_handles_lists_and_scalars() {
        let cases = [
            (json!(["a", "b"]), "a, b"),
            (json!(true), "true"),
            (json!(3), "3"),
        ];
        for (value, rendered) in cases {
            assert_eq!(render_value(&value), rendered);
        }
    }

    /// The embedding text carries the type line plus every requested field.
    #[test]
    fn build_embedding_text_prefixes_type_and_fields() {
        let data = json!({
            "slug": "dec-1",
            "intent": "Ship it"
        });
        let fields = ["slug".to_string(), "intent".to_string()];
        let text = build_embedding_text("Decision", data.as_object().unwrap(), &fields);
        for needle in ["type: Decision", "slug: dec-1", "intent: Ship it"] {
            assert!(text.contains(needle));
        }
    }
}
|
||||
2410
crates/omnigraph-cli/src/main.rs
Normal file
2410
crates/omnigraph-cli/src/main.rs
Normal file
File diff suppressed because it is too large
Load diff
356
crates/omnigraph-cli/src/read_format.rs
Normal file
356
crates/omnigraph-cli/src/read_format.rs
Normal file
|
|
@ -0,0 +1,356 @@
|
|||
use color_eyre::eyre::Result;
|
||||
use omnigraph_server::ReadOutputFormat;
|
||||
use omnigraph_server::api::ReadOutput;
|
||||
use omnigraph_server::config::TableCellLayout;
|
||||
use serde_json::{Map, Value};
|
||||
|
||||
pub struct ReadRenderOptions {
|
||||
pub max_column_width: usize,
|
||||
pub cell_layout: TableCellLayout,
|
||||
}
|
||||
|
||||
pub fn render_read(
|
||||
output: &ReadOutput,
|
||||
format: ReadOutputFormat,
|
||||
options: &ReadRenderOptions,
|
||||
) -> Result<String> {
|
||||
match format {
|
||||
ReadOutputFormat::Json => Ok(serde_json::to_string_pretty(output)?),
|
||||
ReadOutputFormat::Jsonl => render_jsonl(output),
|
||||
ReadOutputFormat::Csv => render_csv(output),
|
||||
ReadOutputFormat::Kv => Ok(render_kv(output)),
|
||||
ReadOutputFormat::Table => Ok(render_table(output, options)),
|
||||
}
|
||||
}
|
||||
|
||||
fn render_jsonl(output: &ReadOutput) -> Result<String> {
|
||||
let mut lines = Vec::new();
|
||||
lines.push(serde_json::to_string(&serde_json::json!({
|
||||
"kind": "metadata",
|
||||
"query_name": output.query_name,
|
||||
"target": output.target,
|
||||
"row_count": output.row_count,
|
||||
}))?);
|
||||
for row in rows(output) {
|
||||
lines.push(serde_json::to_string(&row)?);
|
||||
}
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
|
||||
fn render_csv(output: &ReadOutput) -> Result<String> {
|
||||
let rows = rows(output);
|
||||
let columns = columns(output, &rows);
|
||||
let mut lines = Vec::new();
|
||||
lines.push(
|
||||
columns
|
||||
.iter()
|
||||
.map(|column| csv_escape(column))
|
||||
.collect::<Vec<_>>()
|
||||
.join(","),
|
||||
);
|
||||
for row in rows {
|
||||
lines.push(
|
||||
columns
|
||||
.iter()
|
||||
.map(|column| csv_escape(&stringify_value(row.get(column).unwrap_or(&Value::Null))))
|
||||
.collect::<Vec<_>>()
|
||||
.join(","),
|
||||
);
|
||||
}
|
||||
Ok(lines.join("\n"))
|
||||
}
|
||||
|
||||
fn render_kv(output: &ReadOutput) -> String {
|
||||
let mut lines = vec![header_line(output)];
|
||||
let rows = rows(output);
|
||||
if rows.is_empty() {
|
||||
lines.push("(no rows)".to_string());
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
for (idx, row) in rows.iter().enumerate() {
|
||||
if idx > 0 {
|
||||
lines.push(String::new());
|
||||
}
|
||||
lines.push(format!("row {}", idx + 1));
|
||||
for column in columns(output, &rows) {
|
||||
lines.push(format!(
|
||||
"{}: {}",
|
||||
column,
|
||||
stringify_value(row.get(&column).unwrap_or(&Value::Null))
|
||||
));
|
||||
}
|
||||
}
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn render_table(output: &ReadOutput, options: &ReadRenderOptions) -> String {
|
||||
let mut lines = vec![header_line(output)];
|
||||
let rows = rows(output);
|
||||
let columns = columns(output, &rows);
|
||||
|
||||
if columns.is_empty() {
|
||||
lines.push("(no rows)".to_string());
|
||||
return lines.join("\n");
|
||||
}
|
||||
|
||||
let widths = columns
|
||||
.iter()
|
||||
.map(|column| {
|
||||
let mut width = column.chars().count();
|
||||
for row in &rows {
|
||||
let rendered =
|
||||
normalize_cell(&stringify_value(row.get(column).unwrap_or(&Value::Null)));
|
||||
let longest = rendered
|
||||
.lines()
|
||||
.map(|line| line.chars().count())
|
||||
.max()
|
||||
.unwrap_or(0);
|
||||
width = width.max(longest.min(options.max_column_width));
|
||||
}
|
||||
width.min(options.max_column_width.max(8))
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
lines.push(render_table_line(&columns, &widths));
|
||||
lines.push(
|
||||
widths
|
||||
.iter()
|
||||
.map(|width| "-".repeat(*width))
|
||||
.collect::<Vec<_>>()
|
||||
.join("-+-"),
|
||||
);
|
||||
|
||||
for row in rows {
|
||||
let cell_lines = columns
|
||||
.iter()
|
||||
.zip(widths.iter())
|
||||
.map(|(column, width)| {
|
||||
split_cell(
|
||||
&normalize_cell(&stringify_value(row.get(column).unwrap_or(&Value::Null))),
|
||||
*width,
|
||||
options.cell_layout,
|
||||
)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
let line_count = cell_lines.iter().map(Vec::len).max().unwrap_or(1);
|
||||
for line_idx in 0..line_count {
|
||||
let rendered = cell_lines
|
||||
.iter()
|
||||
.zip(widths.iter())
|
||||
.map(|(segments, width)| {
|
||||
let segment = segments.get(line_idx).cloned().unwrap_or_default();
|
||||
pad_to_width(&segment, *width)
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
lines.push(rendered.join(" | "));
|
||||
}
|
||||
}
|
||||
|
||||
lines.join("\n")
|
||||
}
|
||||
|
||||
fn render_table_line(columns: &[String], widths: &[usize]) -> String {
|
||||
columns
|
||||
.iter()
|
||||
.zip(widths.iter())
|
||||
.map(|(column, width)| pad_to_width(column, *width))
|
||||
.collect::<Vec<_>>()
|
||||
.join(" | ")
|
||||
}
|
||||
|
||||
fn header_line(output: &ReadOutput) -> String {
|
||||
format!(
|
||||
"{} rows from {} via {}",
|
||||
output.row_count,
|
||||
output
|
||||
.target
|
||||
.snapshot
|
||||
.as_deref()
|
||||
.map(|id| format!("snapshot {}", id))
|
||||
.or_else(|| {
|
||||
output
|
||||
.target
|
||||
.branch
|
||||
.as_deref()
|
||||
.map(|branch| format!("branch {}", branch))
|
||||
})
|
||||
.unwrap_or_else(|| "target".to_string()),
|
||||
output.query_name
|
||||
)
|
||||
}
|
||||
|
||||
fn rows(output: &ReadOutput) -> Vec<Map<String, Value>> {
|
||||
output
|
||||
.rows
|
||||
.as_array()
|
||||
.into_iter()
|
||||
.flatten()
|
||||
.map(|row| match row {
|
||||
Value::Object(map) => map.clone(),
|
||||
other => {
|
||||
let mut map = Map::new();
|
||||
map.insert("value".to_string(), other.clone());
|
||||
map
|
||||
}
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Determines the column order used for rendering.
///
/// Uses `output.columns` when it is non-empty. Otherwise the columns are
/// discovered from the keys of all rows, sorted and de-duplicated.
fn columns(output: &ReadOutput, rows: &[Map<String, Value>]) -> Vec<String> {
    if output.columns.is_empty() {
        let mut discovered: Vec<String> = Vec::new();
        for row in rows {
            discovered.extend(row.keys().cloned());
        }
        discovered.sort();
        discovered.dedup();
        discovered
    } else {
        output.columns.clone()
    }
}
|
||||
|
||||
fn stringify_value(value: &Value) -> String {
|
||||
match value {
|
||||
Value::Null => "null".to_string(),
|
||||
Value::String(text) => text.clone(),
|
||||
Value::Bool(boolean) => boolean.to_string(),
|
||||
Value::Number(number) => number.to_string(),
|
||||
other => serde_json::to_string(other).unwrap_or_else(|_| "<invalid json>".to_string()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Makes a cell value safe for single-line table output by replacing line
/// breaks with visible escapes.
///
/// Both `\n` and `\r` are escaped: a bare carriage return would otherwise
/// move the terminal cursor to the start of the line and overwrite the row.
fn normalize_cell(value: &str) -> String {
    value.replace('\r', "\\r").replace('\n', "\\n")
}
|
||||
|
||||
fn split_cell(value: &str, width: usize, layout: TableCellLayout) -> Vec<String> {
|
||||
if value.is_empty() {
|
||||
return vec![String::new()];
|
||||
}
|
||||
if value.chars().count() <= width {
|
||||
return vec![value.to_string()];
|
||||
}
|
||||
match layout {
|
||||
TableCellLayout::Truncate => vec![truncate(value, width)],
|
||||
TableCellLayout::Wrap => wrap(value, width),
|
||||
}
|
||||
}
|
||||
|
||||
/// Shortens `value` to `width - 1` characters followed by an ellipsis (`…`).
///
/// For widths of 0 or 1 there is no room for the ellipsis, so the first
/// `width` characters are returned as-is. The ellipsis is appended
/// unconditionally, so callers should only pass values longer than `width`.
fn truncate(value: &str, width: usize) -> String {
    let chars = value.chars();
    if width <= 1 {
        return chars.take(width).collect();
    }

    let mut shortened: String = chars.take(width - 1).collect();
    shortened.push('…');
    shortened
}
|
||||
|
||||
/// Hard-wraps `value` into segments of at most `width` characters.
///
/// Splitting is by character count, not by word. A width of 0 is treated as
/// 1. An empty input yields no segments.
fn wrap(value: &str, width: usize) -> Vec<String> {
    let step = width.max(1);
    let mut segments = Vec::new();
    let mut current = String::new();
    let mut filled = 0usize;

    for ch in value.chars() {
        if filled == step {
            segments.push(std::mem::take(&mut current));
            filled = 0;
        }
        current.push(ch);
        filled += 1;
    }
    if filled > 0 {
        segments.push(current);
    }
    segments
}
|
||||
|
||||
/// Right-pads `value` with spaces up to `width` characters.
///
/// Values that are already at least `width` characters long are returned
/// unchanged; they are never shortened.
fn pad_to_width(value: &str, width: usize) -> String {
    let missing = width.saturating_sub(value.chars().count());
    let mut padded = String::with_capacity(value.len() + missing);
    padded.push_str(value);
    padded.extend(std::iter::repeat(' ').take(missing));
    padded
}
|
||||
|
||||
/// Quotes a CSV field when required.
///
/// Fields containing a comma, a double quote, `\n` or `\r` are wrapped in
/// double quotes, with embedded quotes doubled. All other fields are returned
/// unchanged.
fn csv_escape(value: &str) -> String {
    let needs_quotes = value
        .chars()
        .any(|ch| matches!(ch, ',' | '"' | '\n' | '\r'));
    if !needs_quotes {
        return value.to_string();
    }

    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('"');
    for ch in value.chars() {
        if ch == '"' {
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use omnigraph_server::api::{ReadOutput, ReadTargetOutput};

    use super::*;

    /// A single-row result on branch `main` with explicit column order.
    fn sample_output() -> ReadOutput {
        ReadOutput {
            query_name: "get_person".to_string(),
            target: ReadTargetOutput {
                branch: Some("main".to_string()),
                snapshot: None,
            },
            row_count: 1,
            columns: vec!["name".to_string(), "age".to_string()],
            rows: serde_json::json!([{ "name": "Alice", "age": 30 }]),
        }
    }

    /// Renders `output` with the options shared by every test in this module.
    fn render_with_defaults(output: &ReadOutput, format: ReadOutputFormat) -> String {
        let options = ReadRenderOptions {
            max_column_width: 80,
            cell_layout: TableCellLayout::Truncate,
        };
        render_read(output, format, &options).unwrap()
    }

    #[test]
    fn csv_format_outputs_header_and_rows() {
        let rendered = render_with_defaults(&sample_output(), ReadOutputFormat::Csv);

        assert!(rendered.lines().next().unwrap().contains("name,age"));
        assert!(rendered.contains("Alice,30"));
    }

    #[test]
    fn jsonl_format_emits_metadata_first() {
        let rendered = render_with_defaults(&sample_output(), ReadOutputFormat::Jsonl);

        let mut lines = rendered.lines();
        let first = lines.next().unwrap();
        assert!(first.contains("\"kind\":\"metadata\""));
        let second = lines.next().unwrap();
        assert!(second.contains("\"name\":\"Alice\""));
    }

    #[test]
    fn render_falls_back_to_discovered_columns_for_legacy_payloads() {
        let mut output = sample_output();
        output.columns.clear();

        let rendered = render_with_defaults(&output, ReadOutputFormat::Csv);

        // Discovered columns are sorted, so `age` precedes `name`.
        assert!(rendered.lines().next().unwrap().contains("age,name"));
    }

    #[test]
    fn csv_quotes_carriage_returns() {
        assert_eq!(csv_escape("hello\rworld"), "\"hello\rworld\"");
    }
}
|
||||
1408
crates/omnigraph-cli/tests/cli.rs
Normal file
1408
crates/omnigraph-cli/tests/cli.rs
Normal file
File diff suppressed because it is too large
Load diff
292
crates/omnigraph-cli/tests/support/mod.rs
Normal file
292
crates/omnigraph-cli/tests/support/mod.rs
Normal file
|
|
@ -0,0 +1,292 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use std::fs;
|
||||
use std::net::TcpListener;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::process::{Child, Command as StdCommand, Output, Stdio};
|
||||
use std::thread::sleep;
|
||||
use std::time::Duration;
|
||||
|
||||
use assert_cmd::Command;
|
||||
use omnigraph::db::Omnigraph;
|
||||
use omnigraph::loader::LoadMode;
|
||||
use reqwest::blocking::Client;
|
||||
use serde_json::Value;
|
||||
use tempfile::{TempDir, tempdir};
|
||||
|
||||
pub fn cli() -> Command {
|
||||
Command::cargo_bin("omnigraph").unwrap()
|
||||
}
|
||||
|
||||
pub fn cli_process() -> StdCommand {
|
||||
StdCommand::new(assert_cmd::cargo::cargo_bin("omnigraph"))
|
||||
}
|
||||
|
||||
fn server_process() -> StdCommand {
|
||||
if let Some(path) = std::env::var_os("CARGO_BIN_EXE_omnigraph-server") {
|
||||
StdCommand::new(path)
|
||||
} else if let Some(path) = built_server_binary() {
|
||||
StdCommand::new(path)
|
||||
} else {
|
||||
let cargo = std::env::var_os("CARGO").unwrap_or_else(|| "cargo".into());
|
||||
let mut cmd = StdCommand::new(cargo);
|
||||
cmd.arg("run")
|
||||
.arg("--quiet")
|
||||
.arg("-p")
|
||||
.arg("omnigraph-server")
|
||||
.arg("--");
|
||||
cmd
|
||||
}
|
||||
}
|
||||
|
||||
fn built_server_binary() -> Option<PathBuf> {
|
||||
let workspace_root = PathBuf::from(env!("CARGO_MANIFEST_DIR")).join("../..");
|
||||
let candidate = workspace_root
|
||||
.join("target")
|
||||
.join("debug")
|
||||
.join(format!("omnigraph-server{}", std::env::consts::EXE_SUFFIX));
|
||||
candidate.exists().then_some(candidate)
|
||||
}
|
||||
|
||||
pub fn fixture(name: &str) -> PathBuf {
|
||||
PathBuf::from(env!("CARGO_MANIFEST_DIR"))
|
||||
.join("../omnigraph/tests/fixtures")
|
||||
.join(name)
|
||||
}
|
||||
|
||||
/// Path of the test repository (`demo.omni`) inside `root`.
pub fn repo_path(root: &Path) -> PathBuf {
    let mut path = root.to_path_buf();
    path.push("demo.omni");
    path
}
|
||||
|
||||
pub fn output_success(cmd: &mut Command) -> Output {
|
||||
let output = cmd.output().unwrap();
|
||||
assert!(
|
||||
output.status.success(),
|
||||
"command failed\nstdout:\n{}\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
output
|
||||
}
|
||||
|
||||
pub fn output_failure(cmd: &mut Command) -> Output {
|
||||
let output = cmd.output().unwrap();
|
||||
assert!(
|
||||
!output.status.success(),
|
||||
"command unexpectedly succeeded\nstdout:\n{}\nstderr:\n{}",
|
||||
String::from_utf8_lossy(&output.stdout),
|
||||
String::from_utf8_lossy(&output.stderr)
|
||||
);
|
||||
output
|
||||
}
|
||||
|
||||
/// Returns the captured stdout as a `String`; panics when it is not valid UTF-8.
pub fn stdout_string(output: &Output) -> String {
    let bytes = output.stdout.to_vec();
    String::from_utf8(bytes).unwrap()
}
|
||||
|
||||
pub fn parse_stdout_json(output: &Output) -> Value {
|
||||
serde_json::from_slice(&output.stdout).unwrap()
|
||||
}
|
||||
|
||||
pub fn init_repo(repo: &Path) {
|
||||
let schema = fixture("test.pg");
|
||||
output_success(cli().arg("init").arg("--schema").arg(&schema).arg(repo));
|
||||
}
|
||||
|
||||
pub fn load_fixture(repo: &Path) {
|
||||
let data = fixture("test.jsonl");
|
||||
output_success(cli().arg("load").arg("--data").arg(&data).arg(repo));
|
||||
}
|
||||
|
||||
/// Writes JSONL `rows` to `path`, replacing any existing file; panics on I/O errors.
pub fn write_jsonl(path: &Path, rows: &str) {
    let written = fs::write(path, rows);
    written.unwrap();
}
|
||||
|
||||
/// Writes query `source` to `path`, replacing any existing file; panics on I/O errors.
pub fn write_query_file(path: &Path, source: &str) {
    let written = fs::write(path, source);
    written.unwrap();
}
|
||||
|
||||
/// Writes config `source` to `path`, replacing any existing file; panics on I/O errors.
pub fn write_config(path: &Path, source: &str) {
    let written = fs::write(path, source);
    written.unwrap();
}
|
||||
|
||||
/// Wraps `value` in single quotes for embedding in YAML, doubling any
/// single quote it contains.
fn yaml_string(value: &str) -> String {
    let mut quoted = String::with_capacity(value.len() + 2);
    quoted.push('\'');
    for ch in value.chars() {
        if ch == '\'' {
            quoted.push_str("''");
        } else {
            quoted.push(ch);
        }
    }
    quoted.push('\'');
    quoted
}
|
||||
|
||||
pub fn local_yaml_config(repo: &Path) -> String {
|
||||
format!(
|
||||
"\
|
||||
targets:
|
||||
local:
|
||||
uri: {}
|
||||
cli:
|
||||
target: local
|
||||
branch: main
|
||||
query:
|
||||
roots:
|
||||
- .
|
||||
policy: {{}}
|
||||
",
|
||||
yaml_string(&repo.to_string_lossy())
|
||||
)
|
||||
}
|
||||
|
||||
pub fn remote_yaml_config(url: &str) -> String {
|
||||
format!(
|
||||
"\
|
||||
targets:
|
||||
dev:
|
||||
uri: {}
|
||||
cli:
|
||||
target: dev
|
||||
branch: main
|
||||
query:
|
||||
roots:
|
||||
- .
|
||||
policy: {{}}
|
||||
",
|
||||
yaml_string(url)
|
||||
)
|
||||
}
|
||||
|
||||
/// Handle to a server process spawned for a test.
///
/// The process is killed and reaped when the handle is dropped.
pub struct TestServer {
    /// The running server process.
    child: Child,
    /// Base URL of the server, of the form `http://127.0.0.1:<port>`.
    pub base_url: String,
}
|
||||
|
||||
impl Drop for TestServer {
|
||||
fn drop(&mut self) {
|
||||
let _ = self.child.kill();
|
||||
let _ = self.child.wait();
|
||||
}
|
||||
}
|
||||
|
||||
/// Asks the OS for an unused loopback TCP port by binding to port 0, then
/// releases the socket and returns the port number.
///
/// The port is free at the time of the call, but nothing reserves it
/// afterwards, so another process could take it before the caller binds.
fn free_port() -> u16 {
    let probe = TcpListener::bind("127.0.0.1:0").unwrap();
    let address = probe.local_addr().unwrap();
    drop(probe);
    address.port()
}
|
||||
|
||||
fn spawn_server_process(mut command: StdCommand) -> TestServer {
|
||||
let port = free_port();
|
||||
let bind = format!("127.0.0.1:{}", port);
|
||||
let mut child = command
|
||||
.arg("--bind")
|
||||
.arg(&bind)
|
||||
.stdout(Stdio::null())
|
||||
.stderr(Stdio::null())
|
||||
.spawn()
|
||||
.unwrap();
|
||||
let base_url = format!("http://{}", bind);
|
||||
let client = Client::new();
|
||||
for _ in 0..300 {
|
||||
if client
|
||||
.get(format!("{}/healthz", base_url))
|
||||
.send()
|
||||
.map(|response| response.status().is_success())
|
||||
.unwrap_or(false)
|
||||
{
|
||||
return TestServer { child, base_url };
|
||||
}
|
||||
if let Some(status) = child.try_wait().unwrap() {
|
||||
panic!("server exited before becoming healthy: {status}");
|
||||
}
|
||||
sleep(Duration::from_millis(100));
|
||||
}
|
||||
panic!("server did not become healthy");
|
||||
}
|
||||
|
||||
pub fn spawn_server(repo: &Path) -> TestServer {
|
||||
let mut command = server_process();
|
||||
command.arg(repo);
|
||||
spawn_server_process(command)
|
||||
}
|
||||
|
||||
pub fn spawn_server_with_config(config: &Path) -> TestServer {
|
||||
let mut command = server_process();
|
||||
command.arg("--config").arg(config);
|
||||
spawn_server_process(command)
|
||||
}
|
||||
|
||||
pub fn spawn_server_with_config_env(config: &Path, envs: &[(&str, &str)]) -> TestServer {
|
||||
let mut command = server_process();
|
||||
command.arg("--config").arg(config);
|
||||
for (name, value) in envs {
|
||||
command.env(name, value);
|
||||
}
|
||||
spawn_server_process(command)
|
||||
}
|
||||
|
||||
pub async fn begin_manual_run(repo: &Path, target_branch: &str) -> String {
|
||||
let mut db = Omnigraph::open(repo.to_str().unwrap()).await.unwrap();
|
||||
let run = db
|
||||
.begin_run(target_branch, Some("cli-test-run"))
|
||||
.await
|
||||
.unwrap();
|
||||
db.load(
|
||||
&run.run_branch,
|
||||
r#"{"type":"Person","data":{"name":"Eve","age":29}}"#,
|
||||
LoadMode::Append,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
run.run_id.as_str().to_string()
|
||||
}
|
||||
|
||||
pub struct SystemRepo {
|
||||
_temp: TempDir,
|
||||
repo: PathBuf,
|
||||
}
|
||||
|
||||
impl SystemRepo {
|
||||
pub fn initialized() -> Self {
|
||||
let temp = tempdir().unwrap();
|
||||
let repo = repo_path(temp.path());
|
||||
init_repo(&repo);
|
||||
Self { _temp: temp, repo }
|
||||
}
|
||||
|
||||
pub fn loaded() -> Self {
|
||||
let temp = tempdir().unwrap();
|
||||
let repo = repo_path(temp.path());
|
||||
init_repo(&repo);
|
||||
load_fixture(&repo);
|
||||
Self { _temp: temp, repo }
|
||||
}
|
||||
|
||||
pub fn path(&self) -> &Path {
|
||||
&self.repo
|
||||
}
|
||||
|
||||
pub fn write_query(&self, name: &str, source: &str) -> PathBuf {
|
||||
let path = self.repo.parent().unwrap().join(name);
|
||||
write_query_file(&path, source);
|
||||
path
|
||||
}
|
||||
|
||||
pub fn write_jsonl(&self, name: &str, rows: &str) -> PathBuf {
|
||||
let path = self.repo.parent().unwrap().join(name);
|
||||
write_jsonl(&path, rows);
|
||||
path
|
||||
}
|
||||
|
||||
pub fn write_config(&self, name: &str, source: &str) -> PathBuf {
|
||||
let path = self.repo.parent().unwrap().join(name);
|
||||
write_config(&path, source);
|
||||
path
|
||||
}
|
||||
|
||||
pub fn spawn_server(&self) -> TestServer {
|
||||
spawn_server(&self.repo)
|
||||
}
|
||||
|
||||
pub fn spawn_server_with_config(&self, config: &Path) -> TestServer {
|
||||
spawn_server_with_config(config)
|
||||
}
|
||||
|
||||
pub fn spawn_server_with_config_env(&self, config: &Path, envs: &[(&str, &str)]) -> TestServer {
|
||||
spawn_server_with_config_env(config, envs)
|
||||
}
|
||||
}
|
||||
1162
crates/omnigraph-cli/tests/system_local.rs
Normal file
1162
crates/omnigraph-cli/tests/system_local.rs
Normal file
File diff suppressed because it is too large
Load diff
810
crates/omnigraph-cli/tests/system_remote.rs
Normal file
810
crates/omnigraph-cli/tests/system_remote.rs
Normal file
|
|
@ -0,0 +1,810 @@
|
|||
mod support;
|
||||
|
||||
use std::fs;
|
||||
|
||||
use reqwest::blocking::Client;
|
||||
use serde_json::json;
|
||||
|
||||
use support::*;
|
||||
|
||||
const REMOTE_POLICY_E2E_YAML: &str = r#"
|
||||
version: 1
|
||||
groups:
|
||||
team: [act-bruno]
|
||||
admins: [act-ragnor]
|
||||
protected_branches: [main]
|
||||
rules:
|
||||
- id: team-read
|
||||
allow:
|
||||
actors: { group: team }
|
||||
actions: [read]
|
||||
branch_scope: any
|
||||
- id: team-branch-create
|
||||
allow:
|
||||
actors: { group: team }
|
||||
actions: [branch_create]
|
||||
target_branch_scope: unprotected
|
||||
- id: team-write-unprotected
|
||||
allow:
|
||||
actors: { group: team }
|
||||
actions: [change]
|
||||
branch_scope: unprotected
|
||||
- id: admins-promote
|
||||
allow:
|
||||
actors: { group: admins }
|
||||
actions: [branch_merge, run_publish]
|
||||
target_branch_scope: protected
|
||||
"#;
|
||||
|
||||
/// Single-quotes `value` for use as a YAML scalar; embedded single quotes
/// are escaped by doubling them.
fn yaml_string(value: &str) -> String {
    let escaped = value.split('\'').collect::<Vec<_>>().join("''");
    let mut quoted = String::with_capacity(escaped.len() + 2);
    quoted.push('\'');
    quoted.push_str(&escaped);
    quoted.push('\'');
    quoted
}
|
||||
|
||||
fn remote_policy_server_config(repo: &SystemRepo) -> String {
|
||||
format!(
|
||||
"\
|
||||
project:
|
||||
name: remote-policy-e2e
|
||||
targets:
|
||||
local:
|
||||
uri: {}
|
||||
server:
|
||||
target: local
|
||||
policy:
|
||||
file: ./policy.yaml
|
||||
",
|
||||
yaml_string(&repo.path().to_string_lossy())
|
||||
)
|
||||
}
|
||||
|
||||
fn remote_policy_client_config(url: &str) -> String {
|
||||
format!(
|
||||
"\
|
||||
targets:
|
||||
dev:
|
||||
uri: {}
|
||||
bearer_token_env: POLICY_TEST_TOKEN
|
||||
cli:
|
||||
target: dev
|
||||
branch: main
|
||||
query:
|
||||
roots:
|
||||
- .
|
||||
auth:
|
||||
env_file: ./.env.omni
|
||||
",
|
||||
yaml_string(url)
|
||||
)
|
||||
}
|
||||
|
||||
#[test]
|
||||
#[ignore = "requires loopback socket permissions in sandboxed runners"]
|
||||
fn remote_server_and_cli_end_to_end_flow() {
|
||||
let repo = SystemRepo::loaded();
|
||||
let server = repo.spawn_server();
|
||||
let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
|
||||
let mutation_file = repo.write_query(
|
||||
"system-remote-change.gq",
|
||||
r#"
|
||||
query insert_person($name: String, $age: I32) {
|
||||
insert Person { name: $name, age: $age }
|
||||
}
|
||||
"#,
|
||||
);
|
||||
let client = Client::new();
|
||||
|
||||
let health = client
|
||||
.get(format!("{}/healthz", server.base_url))
|
||||
.send()
|
||||
.unwrap()
|
||||
.error_for_status()
|
||||
.unwrap()
|
||||
.json::<serde_json::Value>()
|
||||
.unwrap();
|
||||
assert_eq!(health["status"], "ok");
|
||||
|
||||
let local_snapshot = parse_stdout_json(&output_success(
|
||||
cli().arg("snapshot").arg(repo.path()).arg("--json"),
|
||||
));
|
||||
let snapshot = parse_stdout_json(&output_success(
|
||||
cli()
|
||||
.arg("snapshot")
|
||||
.arg("--config")
|
||||
.arg(&config)
|
||||
.arg("--json"),
|
||||
));
|
||||
assert_eq!(snapshot["branch"], "main");
|
||||
assert_eq!(snapshot["tables"], local_snapshot["tables"]);
|
||||
|
||||
let local_read = parse_stdout_json(&output_success(
|
||||
cli()
|
||||
.arg("read")
|
||||
.arg(repo.path())
|
||||
.arg("--query")
|
||||
.arg(fixture("test.gq"))
|
||||
.arg("--name")
|
||||
.arg("get_person")
|
||||
.arg("--params")
|
||||
.arg(r#"{"name":"Alice"}"#)
|
||||
.arg("--json"),
|
||||
));
|
||||
let read_payload = parse_stdout_json(&output_success(
|
||||
cli()
|
||||
.arg("read")
|
||||
.arg("--config")
|
||||
.arg(&config)
|
||||
.arg("--query")
|
||||
.arg(fixture("test.gq"))
|
||||
.arg("--name")
|
||||
.arg("get_person")
|
||||
.arg("--params")
|
||||
.arg(r#"{"name":"Alice"}"#)
|
||||
.arg("--json"),
|
||||
));
|
||||
assert_eq!(read_payload, local_read);
|
||||
assert_eq!(read_payload["row_count"], 1);
|
||||
assert_eq!(read_payload["rows"][0]["p.name"], "Alice");
|
||||
|
||||
let change_payload = parse_stdout_json(&output_success(
|
||||
cli()
|
||||
.arg("change")
|
||||
.arg("--config")
|
||||
.arg(&config)
|
||||
.arg("--query")
|
||||
.arg(&mutation_file)
|
||||
.arg("--params")
|
||||
.arg(r#"{"name":"Mina","age":28}"#)
|
||||
.arg("--json"),
|
||||
));
|
||||
assert_eq!(change_payload["affected_nodes"], 1);
|
||||
|
||||
let query_source = fs::read_to_string(fixture("test.gq")).unwrap();
|
||||
let http_read = client
|
||||
.post(format!("{}/read", server.base_url))
|
||||
.json(&json!({
|
||||
"branch": "main",
|
||||
"query_source": query_source,
|
||||
"query_name": "get_person",
|
||||
"params": { "name": "Mina" }
|
||||
}))
|
||||
.send()
|
||||
.unwrap()
|
||||
.error_for_status()
|
||||
.unwrap()
|
||||
.json::<serde_json::Value>()
|
||||
.unwrap();
|
||||
assert_eq!(http_read["row_count"], 1);
|
||||
assert_eq!(http_read["rows"][0]["p.name"], "Mina");
|
||||
|
||||
let local_verify = parse_stdout_json(&output_success(
|
||||
cli()
|
||||
.arg("read")
|
||||
.arg(repo.path())
|
||||
.arg("--query")
|
||||
.arg(fixture("test.gq"))
|
||||
.arg("--name")
|
||||
.arg("get_person")
|
||||
.arg("--params")
|
||||
.arg(r#"{"name":"Mina"}"#)
|
||||
.arg("--json"),
|
||||
));
|
||||
assert_eq!(local_verify["row_count"], 1);
|
||||
assert_eq!(local_verify["rows"][0]["p.name"], "Mina");
|
||||
|
||||
let manual_run = tokio::runtime::Runtime::new()
|
||||
.unwrap()
|
||||
.block_on(begin_manual_run(repo.path(), "main"));
|
||||
let publish_payload = parse_stdout_json(&output_success(
|
||||
cli()
|
||||
.arg("run")
|
||||
.arg("publish")
|
||||
.arg("--config")
|
||||
.arg(&config)
|
||||
.arg(&manual_run)
|
||||
.arg("--json"),
|
||||
));
|
||||
assert_eq!(publish_payload["run_id"], manual_run);
|
||||
assert_eq!(publish_payload["status"], "published");
|
||||
|
||||
let runs_payload = parse_stdout_json(&output_success(
|
||||
cli()
|
||||
.arg("run")
|
||||
.arg("list")
|
||||
.arg("--config")
|
||||
.arg(&config)
|
||||
.arg("--json"),
|
||||
));
|
||||
assert!(runs_payload["runs"].as_array().unwrap().len() >= 2);
|
||||
}
|
||||
|
||||
/// A remote `read` must report columns in the order the query's `return`
/// clause lists them, both in the JSON payload and in CSV output.
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_read_preserves_projection_order_in_json_and_csv() {
    let repo = SystemRepo::loaded();
    let server = repo.spawn_server();
    let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
    // The projection lists `age` before `name` on purpose.
    let query_path = repo.write_query(
        "ordered-remote.gq",
        r#"
query ordered_person($name: String) {
    match {
        $p: Person { name: $name }
    }
    return { $p.age, $p.name }
}
"#,
    );

    // JSON rendering: the `columns` array must follow the projection order.
    let json_output = output_success(
        cli()
            .args(["read", "--config"])
            .arg(&config)
            .arg("--query")
            .arg(&query_path)
            .args(["--name", "ordered_person"])
            .args(["--params", r#"{"name":"Alice"}"#])
            .arg("--json"),
    );
    let json_payload = parse_stdout_json(&json_output);
    let column_names: Vec<_> = json_payload["columns"]
        .as_array()
        .unwrap()
        .iter()
        .map(|column| column.as_str().unwrap())
        .collect();
    assert_eq!(column_names, vec!["p.age", "p.name"]);

    // CSV rendering: header and first data row must use the same order.
    let csv_output = output_success(
        cli()
            .args(["read", "--config"])
            .arg(&config)
            .arg("--query")
            .arg(&query_path)
            .args(["--name", "ordered_person"])
            .args(["--params", r#"{"name":"Alice"}"#])
            .args(["--format", "csv"]),
    );
    let csv_text = stdout_string(&csv_output);
    let mut csv_lines = csv_text.lines();
    assert_eq!(csv_lines.next().unwrap(), "p.age,p.name");
    assert_eq!(csv_lines.next().unwrap(), "30,Alice");
}
|
||||
|
||||
/// Walks the remote branch workflow end to end: list, create, list again,
/// mutate on the new branch, merge back into `main`, then read the merged row.
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_branch_create_list_merge_flow() {
    let repo = SystemRepo::loaded();
    let server = repo.spawn_server();
    let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
    let mutation_path = repo.write_query(
        "system-remote-branch-change.gq",
        r#"
query insert_person($name: String, $age: I32) {
    insert Person { name: $name, age: $age }
}
"#,
    );

    // `branch list --json` is issued before and after the create step.
    let list_branches = || {
        parse_stdout_json(&output_success(
            cli()
                .args(["branch", "list", "--config"])
                .arg(&config)
                .arg("--json"),
        ))
    };

    let before_create = list_branches();
    assert_eq!(before_create["branches"], json!(["main"]));

    let create_payload = parse_stdout_json(&output_success(
        cli()
            .args(["branch", "create", "--config"])
            .arg(&config)
            .args(["--from", "main", "feature", "--json"]),
    ));
    assert_eq!(create_payload["from"], "main");
    assert_eq!(create_payload["name"], "feature");

    let after_create = list_branches();
    assert_eq!(after_create["branches"], json!(["feature", "main"]));

    // Insert one person on the feature branch only.
    let change_payload = parse_stdout_json(&output_success(
        cli()
            .args(["change", "--config"])
            .arg(&config)
            .arg("--query")
            .arg(&mutation_path)
            .args(["--branch", "feature"])
            .args(["--params", r#"{"name":"Zoe","age":33}"#])
            .arg("--json"),
    ));
    assert_eq!(change_payload["branch"], "feature");
    assert_eq!(change_payload["affected_nodes"], 1);

    let merge_payload = parse_stdout_json(&output_success(
        cli()
            .args(["branch", "merge", "--config"])
            .arg(&config)
            .args(["feature", "--into", "main", "--json"]),
    ));
    assert_eq!(merge_payload["source"], "feature");
    assert_eq!(merge_payload["target"], "main");
    assert_eq!(merge_payload["outcome"], "fast_forward");

    // No `--branch` is passed here; the merged row is expected to be readable.
    let read_payload = parse_stdout_json(&output_success(
        cli()
            .args(["read", "--config"])
            .arg(&config)
            .arg("--query")
            .arg(fixture("test.gq"))
            .args(["--name", "get_person"])
            .args(["--params", r#"{"name":"Zoe"}"#])
            .arg("--json"),
    ));
    assert_eq!(read_payload["row_count"], 1);
    assert_eq!(read_payload["rows"][0]["p.name"], "Zoe");
}
|
||||
|
||||
/// Creating and then deleting a branch through the remote CLI must leave only
/// `main` in the branch listing.
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_branch_delete_removes_branch() {
    let repo = SystemRepo::loaded();
    let server = repo.spawn_server();
    let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));

    // The create payload is parsed (so it must be valid JSON) but not inspected.
    parse_stdout_json(&output_success(
        cli()
            .args(["branch", "create", "--config"])
            .arg(&config)
            .args(["--from", "main", "feature", "--json"]),
    ));

    let delete_payload = parse_stdout_json(&output_success(
        cli()
            .args(["branch", "delete", "--config"])
            .arg(&config)
            .args(["feature", "--json"]),
    ));
    assert_eq!(delete_payload["name"], "feature");

    let list_payload = parse_stdout_json(&output_success(
        cli()
            .args(["branch", "list", "--config"])
            .arg(&config)
            .arg("--json"),
    ));
    assert_eq!(list_payload["branches"], json!(["main"]));
}
|
||||
|
||||
/// Exports a mutated feature branch from the remote server as JSONL, loads it
/// into a freshly initialised local repo, and checks that the node and edge
/// tables carry the expected row counts and the inserted person is readable.
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_export_round_trips_full_branch_graph() {
    let repo = SystemRepo::loaded();
    let server = repo.spawn_server();
    let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
    let mutation_path = repo.write_query(
        "system-remote-export-change.gq",
        r#"
query insert_person($name: String, $age: I32) {
    insert Person { name: $name, age: $age }
}

query add_friend($from: String, $to: String) {
    insert Knows { from: $from, to: $to }
}
"#,
    );

    output_success(
        cli()
            .args(["branch", "create", "--config"])
            .arg(&config)
            .args(["--from", "main", "feature"]),
    );

    // Both mutations target the feature branch and differ only in the named
    // query and its parameters.
    let change_on_feature = |query_name: &str, params: &str| {
        output_success(
            cli()
                .args(["change", "--config"])
                .arg(&config)
                .arg("--query")
                .arg(&mutation_path)
                .args(["--name", query_name])
                .args(["--branch", "feature"])
                .args(["--params", params])
                .arg("--json"),
        );
    };
    change_on_feature("insert_person", r#"{"name":"Eve","age":29}"#);
    change_on_feature("add_friend", r#"{"from":"Alice","to":"Eve"}"#);

    // Export the feature branch and persist the JSONL next to the repo files.
    let export_output = output_success(
        cli()
            .args(["export", "--config"])
            .arg(&config)
            .args(["--branch", "feature", "--jsonl"]),
    );
    let exported_jsonl = stdout_string(&export_output);
    let export_path = repo.write_jsonl("system-remote-exported.jsonl", &exported_jsonl);
    let imported_repo = repo
        .path()
        .parent()
        .unwrap()
        .join("imported-remote-export.omni");

    // Re-create a repo from the fixture schema and load the exported data.
    output_success(
        cli()
            .args(["init", "--schema"])
            .arg(fixture("test.pg"))
            .arg(&imported_repo),
    );
    output_success(
        cli()
            .args(["load", "--data"])
            .arg(&export_path)
            .arg(&imported_repo),
    );

    let snapshot = parse_stdout_json(&output_success(
        cli().arg("snapshot").arg(&imported_repo).arg("--json"),
    ));
    let row_count_of = |table_key: &str| {
        snapshot["tables"]
            .as_array()
            .unwrap()
            .iter()
            .find(|table| table["table_key"] == table_key)
            .unwrap()["row_count"]
            .clone()
    };
    assert_eq!(row_count_of("node:Person"), 5);
    assert_eq!(row_count_of("edge:Knows"), 4);

    let eve_payload = parse_stdout_json(&output_success(
        cli()
            .arg("read")
            .arg(&imported_repo)
            .arg("--query")
            .arg(fixture("test.gq"))
            .args(["--name", "get_person"])
            .args(["--params", r#"{"name":"Eve"}"#])
            .arg("--json"),
    ));
    assert_eq!(eve_payload["row_count"], 1);
    assert_eq!(eve_payload["rows"][0]["p.name"], "Eve");
}
|
||||
|
||||
/// Ingesting into a branch that does not exist yet must create it from `main`,
/// report the loaded rows, and leave the branch readable afterwards.
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_ingest_creates_review_branch_and_keeps_it_readable() {
    let repo = SystemRepo::loaded();
    let server = repo.spawn_server();
    let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));
    let data_path = repo.write_jsonl(
        "system-remote-ingest.jsonl",
        r#"{"type":"Person","data":{"name":"Zoe","age":33}}
{"type":"Person","data":{"name":"Bob","age":26}}"#,
    );

    let ingest = parse_stdout_json(&output_success(
        cli()
            .args(["ingest", "--config"])
            .arg(&config)
            .arg("--data")
            .arg(&data_path)
            .args(["--branch", "feature-ingest", "--json"]),
    ));
    assert_eq!(ingest["branch"], "feature-ingest");
    assert_eq!(ingest["base_branch"], "main");
    assert_eq!(ingest["branch_created"], true);
    assert_eq!(ingest["mode"], "merge");
    let first_table = &ingest["tables"][0];
    assert_eq!(first_table["table_key"], "node:Person");
    assert_eq!(first_table["rows_loaded"], 2);

    // The new branch can be snapshotted...
    let branch_snapshot = parse_stdout_json(&output_success(
        cli()
            .args(["snapshot", "--config"])
            .arg(&config)
            .args(["--branch", "feature-ingest", "--json"]),
    ));
    assert_eq!(branch_snapshot["branch"], "feature-ingest");

    // ...and queried for one of the ingested rows.
    let zoe_payload = parse_stdout_json(&output_success(
        cli()
            .args(["read", "--config"])
            .arg(&config)
            .arg("--query")
            .arg(fixture("test.gq"))
            .args(["--name", "get_person"])
            .args(["--branch", "feature-ingest"])
            .args(["--params", r#"{"name":"Zoe"}"#])
            .arg("--json"),
    ));
    assert_eq!(zoe_payload["row_count"], 1);
    assert_eq!(zoe_payload["rows"][0]["p.name"], "Zoe");
}
|
||||
|
||||
/// Ingesting into an already existing branch must reuse it (reporting
/// `branch_created: false`) even when `--from` names a base that was never
/// created, and the ingested rows must then be readable on that branch.
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_ingest_reuses_existing_branch_and_merges_updates() {
    let repo = SystemRepo::loaded();
    let server = repo.spawn_server();
    let config = repo.write_config("omnigraph.yaml", &remote_yaml_config(&server.base_url));

    // Pre-create the target branch so that ingest has nothing to create.
    output_success(
        cli()
            .args(["branch", "create", "--config"])
            .arg(&config)
            .args(["--from", "main", "feature-ingest"]),
    );

    let data_path = repo.write_jsonl(
        "system-remote-ingest-merge.jsonl",
        r#"{"type":"Person","data":{"name":"Bob","age":26}}
{"type":"Person","data":{"name":"Zoe","age":33}}"#,
    );

    let ingest = parse_stdout_json(&output_success(
        cli()
            .args(["ingest", "--config"])
            .arg(&config)
            .arg("--data")
            .arg(&data_path)
            .args(["--branch", "feature-ingest"])
            .args(["--from", "missing-base"])
            .arg("--json"),
    ));
    assert_eq!(ingest["branch"], "feature-ingest");
    assert_eq!(ingest["base_branch"], "missing-base");
    assert_eq!(ingest["branch_created"], false);
    assert_eq!(ingest["mode"], "merge");
    let first_table = &ingest["tables"][0];
    assert_eq!(first_table["table_key"], "node:Person");
    assert_eq!(first_table["rows_loaded"], 2);

    // Both follow-up reads hit the same branch and query; only params differ.
    let read_person = |params: &str| {
        parse_stdout_json(&output_success(
            cli()
                .args(["read", "--config"])
                .arg(&config)
                .arg("--query")
                .arg(fixture("test.gq"))
                .args(["--name", "get_person"])
                .args(["--branch", "feature-ingest"])
                .args(["--params", params])
                .arg("--json"),
        ))
    };

    let bob_payload = read_person(r#"{"name":"Bob"}"#);
    assert_eq!(bob_payload["row_count"], 1);
    assert_eq!(bob_payload["rows"][0]["p.age"], 26);

    let zoe_payload = read_person(r#"{"name":"Zoe"}"#);
    assert_eq!(zoe_payload["row_count"], 1);
    assert_eq!(zoe_payload["rows"][0]["p.name"], "Zoe");
}
|
||||
|
||||
/// Exercises a policy-protected server: the team token may read `main`,
/// create a branch and change it, but is denied changing `main` directly and
/// denied merging; the merge succeeds once the admin token is supplied.
#[test]
#[ignore = "requires loopback socket permissions in sandboxed runners"]
fn remote_policy_enforces_branch_first_cli_workflow() {
    let repo = SystemRepo::loaded();
    let server_config =
        repo.write_config("server-policy.yaml", &remote_policy_server_config(&repo));
    repo.write_config("policy.yaml", REMOTE_POLICY_E2E_YAML);
    let bearer_tokens = [(
        "OMNIGRAPH_SERVER_BEARER_TOKENS_JSON",
        r#"{"act-bruno":"team-token","act-ragnor":"admin-token"}"#,
    )];
    let server = repo.spawn_server_with_config_env(&server_config, &bearer_tokens);
    let client_config = repo.write_config(
        "omnigraph-policy.yaml",
        &remote_policy_client_config(&server.base_url),
    );
    // The env file supplies the team token for the CLI calls below.
    repo.write_config(".env.omni", "POLICY_TEST_TOKEN=team-token\n");
    let mutation_path = repo.write_query(
        "system-remote-policy-change.gq",
        r#"
query insert_person($name: String, $age: I32) {
    insert Person { name: $name, age: $age }
}
"#,
    );

    // Reading `main` is allowed.
    let snapshot_payload = parse_stdout_json(&output_success(
        cli()
            .args(["snapshot", "--config"])
            .arg(&client_config)
            .arg("--json"),
    ));
    assert_eq!(snapshot_payload["branch"], "main");

    // Changing `main` directly is rejected by policy.
    let main_change_attempt = output_failure(
        cli()
            .args(["change", "--config"])
            .arg(&client_config)
            .arg("--query")
            .arg(&mutation_path)
            .args(["--params", r#"{"name":"PolicyRemote","age":41}"#])
            .arg("--json"),
    );
    let main_change_stderr = String::from_utf8(main_change_attempt.stderr).unwrap();
    assert!(main_change_stderr.contains("policy denied action 'change' on branch 'main'"));

    // Creating a branch and changing it is allowed.
    let create_payload = parse_stdout_json(&output_success(
        cli()
            .args(["branch", "create", "--config"])
            .arg(&client_config)
            .args(["--from", "main", "feature", "--json"]),
    ));
    assert_eq!(create_payload["name"], "feature");

    let change_payload = parse_stdout_json(&output_success(
        cli()
            .args(["change", "--config"])
            .arg(&client_config)
            .arg("--query")
            .arg(&mutation_path)
            .args(["--branch", "feature"])
            .args(["--params", r#"{"name":"PolicyRemote","age":41}"#])
            .arg("--json"),
    ));
    assert_eq!(change_payload["branch"], "feature");
    assert_eq!(change_payload["affected_nodes"], 1);

    // Merging with the team token is rejected...
    let merge_attempt = output_failure(
        cli()
            .args(["branch", "merge", "--config"])
            .arg(&client_config)
            .args(["feature", "--into", "main", "--json"]),
    );
    let merge_stderr = String::from_utf8(merge_attempt.stderr).unwrap();
    assert!(merge_stderr.contains("policy denied action 'branch_merge'"));

    // ...and succeeds when the admin token is set in the process environment.
    let merge_payload = parse_stdout_json(&output_success(
        cli()
            .env("POLICY_TEST_TOKEN", "admin-token")
            .args(["branch", "merge", "--config"])
            .arg(&client_config)
            .args(["feature", "--into", "main", "--json"]),
    ));
    assert_eq!(merge_payload["target"], "main");

    let read_payload = parse_stdout_json(&output_success(
        cli()
            .args(["read", "--config"])
            .arg(&client_config)
            .arg("--query")
            .arg(fixture("test.gq"))
            .args(["--name", "get_person"])
            .args(["--params", r#"{"name":"PolicyRemote"}"#])
            .arg("--json"),
    ));
    assert_eq!(read_payload["row_count"], 1);
    assert_eq!(read_payload["rows"][0]["p.name"], "PolicyRemote");
}
|
||||
26
crates/omnigraph-compiler/Cargo.toml
Normal file
26
crates/omnigraph-compiler/Cargo.toml
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
[package]
|
||||
name = "omnigraph-compiler"
|
||||
version = "0.4.0"
|
||||
edition = "2024"
|
||||
description = "Schema/query compiler for Omnigraph. Zero Lance dependency."
|
||||
license = "MIT"
|
||||
|
||||
[dependencies]
|
||||
arrow-array = { workspace = true }
|
||||
arrow-ipc = { workspace = true }
|
||||
arrow-schema = { workspace = true }
|
||||
arrow-select = { workspace = true }
|
||||
arrow-cast = { workspace = true }
|
||||
arrow-ord = { workspace = true }
|
||||
pest = { workspace = true }
|
||||
pest_derive = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
ahash = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
sha2 = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tokio = { workspace = true }
|
||||
594
crates/omnigraph-compiler/src/catalog/mod.rs
Normal file
594
crates/omnigraph-compiler/src/catalog/mod.rs
Normal file
|
|
@ -0,0 +1,594 @@
|
|||
pub mod schema_ir;
|
||||
pub mod schema_plan;
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||
|
||||
use crate::error::{NanoError, Result};
|
||||
use crate::schema::ast::{Cardinality, Constraint, ConstraintBound, SchemaDecl, SchemaFile};
|
||||
use crate::types::{PropType, ScalarType};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Catalog {
|
||||
pub node_types: HashMap<String, NodeType>,
|
||||
pub edge_types: HashMap<String, EdgeType>,
|
||||
/// Maps normalized lowercase edge name -> EdgeType key (e.g. "knows" -> "Knows")
|
||||
pub edge_name_index: HashMap<String, String>,
|
||||
/// Interface declarations (for Phase 2 polymorphic queries)
|
||||
pub interfaces: HashMap<String, InterfaceType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InterfaceType {
|
||||
pub name: String,
|
||||
pub properties: HashMap<String, PropType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct NodeType {
|
||||
pub name: String,
|
||||
/// Interface names this type implements
|
||||
pub implements: Vec<String>,
|
||||
pub properties: HashMap<String, PropType>,
|
||||
/// Key property names (from `@key` or `@key(name)`). Usually 0 or 1 element.
|
||||
pub key: Option<Vec<String>>,
|
||||
/// Uniqueness constraints (each entry is a list of column names)
|
||||
pub unique_constraints: Vec<Vec<String>>,
|
||||
/// Index declarations (each entry is a list of column names)
|
||||
pub indices: Vec<Vec<String>>,
|
||||
/// Value range constraints
|
||||
pub range_constraints: Vec<RangeConstraint>,
|
||||
/// Regex check constraints
|
||||
pub check_constraints: Vec<CheckConstraint>,
|
||||
/// Maps @embed target property -> source text property
|
||||
pub embed_sources: HashMap<String, String>,
|
||||
pub blob_properties: HashSet<String>,
|
||||
pub arrow_schema: SchemaRef,
|
||||
}
|
||||
|
||||
impl NodeType {
|
||||
/// Backward-compatible accessor: returns the first (and typically only) key property name.
|
||||
pub fn key_property(&self) -> Option<&str> {
|
||||
self.key
|
||||
.as_ref()
|
||||
.and_then(|v| v.first())
|
||||
.map(|s| s.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct RangeConstraint {
|
||||
pub property: String,
|
||||
pub min: Option<LiteralValue>,
|
||||
pub max: Option<LiteralValue>,
|
||||
}
|
||||
|
||||
/// A numeric literal used as a bound in a range constraint.
///
/// Derives `PartialEq` so bounds can be compared and asserted on directly.
/// `Eq` is intentionally not derived because the `Float` payload is an `f64`.
#[derive(Debug, Clone, PartialEq)]
pub enum LiteralValue {
    /// Integer bound.
    Integer(i64),
    /// Floating-point bound.
    Float(f64),
}
|
||||
|
||||
/// A regex check constraint attached to a single property.
#[derive(Debug, Clone)]
pub struct CheckConstraint {
    /// Name of the constrained property.
    pub property: String,
    /// Pattern text as written in the schema.
    pub pattern: String,
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EdgeType {
|
||||
pub name: String,
|
||||
pub from_type: String,
|
||||
pub to_type: String,
|
||||
pub cardinality: Cardinality,
|
||||
pub properties: HashMap<String, PropType>,
|
||||
/// Uniqueness constraints on edge columns (e.g. `@unique(src, dst)`)
|
||||
pub unique_constraints: Vec<Vec<String>>,
|
||||
/// Index declarations on edge properties
|
||||
pub indices: Vec<Vec<String>>,
|
||||
pub blob_properties: HashSet<String>,
|
||||
pub arrow_schema: SchemaRef,
|
||||
}
|
||||
|
||||
impl Catalog {
|
||||
pub fn lookup_edge_by_name(&self, name: &str) -> Option<&EdgeType> {
|
||||
if let Some(et) = self.edge_types.get(name) {
|
||||
return Some(et);
|
||||
}
|
||||
if let Some(key) = self.edge_name_index.get(&normalize_edge_name(name)) {
|
||||
return self.edge_types.get(key);
|
||||
}
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Produces the case-insensitive lookup key for an edge name by lowercasing it.
fn normalize_edge_name(name: &str) -> String {
    str::to_lowercase(name)
}
|
||||
|
||||
fn bound_to_literal(b: &ConstraintBound) -> LiteralValue {
|
||||
match b {
|
||||
ConstraintBound::Integer(n) => LiteralValue::Integer(*n),
|
||||
ConstraintBound::Float(f) => LiteralValue::Float(*f),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn build_catalog(schema: &SchemaFile) -> Result<Catalog> {
|
||||
let mut node_types = HashMap::new();
|
||||
let mut edge_types = HashMap::new();
|
||||
let mut edge_name_index = HashMap::new();
|
||||
let mut interfaces = HashMap::new();
|
||||
|
||||
// Pass 0: collect interfaces
|
||||
for decl in &schema.declarations {
|
||||
if let SchemaDecl::Interface(iface) = decl {
|
||||
let mut properties = HashMap::new();
|
||||
for prop in &iface.properties {
|
||||
properties.insert(prop.name.clone(), prop.prop_type.clone());
|
||||
}
|
||||
interfaces.insert(
|
||||
iface.name.clone(),
|
||||
InterfaceType {
|
||||
name: iface.name.clone(),
|
||||
properties,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 1: collect node types
|
||||
for decl in &schema.declarations {
|
||||
if let SchemaDecl::Node(node) = decl {
|
||||
if node_types.contains_key(&node.name) {
|
||||
return Err(NanoError::Catalog(format!(
|
||||
"duplicate node type: {}",
|
||||
node.name
|
||||
)));
|
||||
}
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
let mut embed_sources = HashMap::new();
|
||||
let mut blob_properties = HashSet::new();
|
||||
for prop in &node.properties {
|
||||
properties.insert(prop.name.clone(), prop.prop_type.clone());
|
||||
if matches!(prop.prop_type.scalar, ScalarType::Blob) {
|
||||
blob_properties.insert(prop.name.clone());
|
||||
}
|
||||
// Extract @embed from property annotations (stays as annotation)
|
||||
if let Some(source_prop) = prop
|
||||
.annotations
|
||||
.iter()
|
||||
.find(|ann| ann.name == "embed")
|
||||
.and_then(|ann| ann.value.clone())
|
||||
{
|
||||
embed_sources.insert(prop.name.clone(), source_prop);
|
||||
}
|
||||
}
|
||||
|
||||
// Extract constraints from the typed Constraint enum
|
||||
let mut key: Option<Vec<String>> = None;
|
||||
let mut unique_constraints = Vec::new();
|
||||
let mut indices = Vec::new();
|
||||
let mut range_constraints = Vec::new();
|
||||
let mut check_constraints = Vec::new();
|
||||
|
||||
for constraint in &node.constraints {
|
||||
match constraint {
|
||||
Constraint::Key(cols) => {
|
||||
key = Some(cols.clone());
|
||||
// @key implies index on key columns
|
||||
indices.push(cols.clone());
|
||||
}
|
||||
Constraint::Unique(cols) => {
|
||||
unique_constraints.push(cols.clone());
|
||||
}
|
||||
Constraint::Index(cols) => {
|
||||
indices.push(cols.clone());
|
||||
}
|
||||
Constraint::Range { property, min, max } => {
|
||||
range_constraints.push(RangeConstraint {
|
||||
property: property.clone(),
|
||||
min: min.as_ref().map(bound_to_literal),
|
||||
max: max.as_ref().map(bound_to_literal),
|
||||
});
|
||||
}
|
||||
Constraint::Check { property, pattern } => {
|
||||
check_constraints.push(CheckConstraint {
|
||||
property: property.clone(),
|
||||
pattern: pattern.clone(),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Build Arrow schema: id: Utf8 + all properties
|
||||
let mut fields = vec![Field::new("id", DataType::Utf8, false)];
|
||||
for prop in &node.properties {
|
||||
fields.push(Field::new(
|
||||
&prop.name,
|
||||
prop.prop_type.to_arrow(),
|
||||
prop.prop_type.nullable,
|
||||
));
|
||||
}
|
||||
let arrow_schema = Arc::new(Schema::new(fields));
|
||||
|
||||
node_types.insert(
|
||||
node.name.clone(),
|
||||
NodeType {
|
||||
name: node.name.clone(),
|
||||
implements: node.implements.clone(),
|
||||
properties,
|
||||
key,
|
||||
unique_constraints,
|
||||
indices,
|
||||
range_constraints,
|
||||
check_constraints,
|
||||
embed_sources,
|
||||
blob_properties,
|
||||
arrow_schema,
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// Pass 2: collect edge types, validate endpoints
|
||||
for decl in &schema.declarations {
|
||||
if let SchemaDecl::Edge(edge) = decl {
|
||||
if edge_types.contains_key(&edge.name) {
|
||||
return Err(NanoError::Catalog(format!(
|
||||
"duplicate edge type: {}",
|
||||
edge.name
|
||||
)));
|
||||
}
|
||||
if !node_types.contains_key(&edge.from_type) {
|
||||
return Err(NanoError::Catalog(format!(
|
||||
"edge {} references unknown source type: {}",
|
||||
edge.name, edge.from_type
|
||||
)));
|
||||
}
|
||||
if !node_types.contains_key(&edge.to_type) {
|
||||
return Err(NanoError::Catalog(format!(
|
||||
"edge {} references unknown target type: {}",
|
||||
edge.name, edge.to_type
|
||||
)));
|
||||
}
|
||||
|
||||
let mut properties = HashMap::new();
|
||||
let mut blob_properties = HashSet::new();
|
||||
let mut fields = vec![
|
||||
Field::new("id", DataType::Utf8, false),
|
||||
Field::new("src", DataType::Utf8, false),
|
||||
Field::new("dst", DataType::Utf8, false),
|
||||
];
|
||||
for prop in &edge.properties {
|
||||
properties.insert(prop.name.clone(), prop.prop_type.clone());
|
||||
if matches!(prop.prop_type.scalar, ScalarType::Blob) {
|
||||
blob_properties.insert(prop.name.clone());
|
||||
}
|
||||
fields.push(Field::new(
|
||||
&prop.name,
|
||||
prop.prop_type.to_arrow(),
|
||||
prop.prop_type.nullable,
|
||||
));
|
||||
}
|
||||
|
||||
// Extract edge constraints
|
||||
let mut unique_constraints = Vec::new();
|
||||
let mut edge_indices = Vec::new();
|
||||
for constraint in &edge.constraints {
|
||||
match constraint {
|
||||
Constraint::Unique(cols) => unique_constraints.push(cols.clone()),
|
||||
Constraint::Index(cols) => edge_indices.push(cols.clone()),
|
||||
_ => {} // Key/Range/Check validated at parse time to not appear on edges
|
||||
}
|
||||
}
|
||||
|
||||
let normalized_name = normalize_edge_name(&edge.name);
|
||||
if let Some(existing) = edge_name_index.get(&normalized_name)
|
||||
&& existing != &edge.name
|
||||
{
|
||||
return Err(NanoError::Catalog(format!(
|
||||
"edge name collision after case folding: '{}' conflicts with '{}'",
|
||||
edge.name, existing
|
||||
)));
|
||||
}
|
||||
edge_name_index.insert(normalized_name, edge.name.clone());
|
||||
|
||||
edge_types.insert(
|
||||
edge.name.clone(),
|
||||
EdgeType {
|
||||
name: edge.name.clone(),
|
||||
from_type: edge.from_type.clone(),
|
||||
to_type: edge.to_type.clone(),
|
||||
cardinality: edge.cardinality.clone(),
|
||||
properties,
|
||||
unique_constraints,
|
||||
indices: edge_indices,
|
||||
blob_properties,
|
||||
arrow_schema: Arc::new(Schema::new(fields)),
|
||||
},
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(Catalog {
|
||||
node_types,
|
||||
edge_types,
|
||||
edge_name_index,
|
||||
interfaces,
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use super::*;
|
||||
use crate::schema::ast::{EdgeDecl, NodeDecl};
|
||||
use crate::schema::parser::parse_schema;
|
||||
use crate::types::PropType;
|
||||
|
||||
/// Shared fixture schema: two node types (`Person`, `Company`) and two edge
/// types (`Knows`, `WorksAt`).
fn test_schema() -> &'static str {
    const SCHEMA: &str = r#"
node Person {
    name: String
    age: I32?
}
node Company {
    name: String
}
edge Knows: Person -> Person {
    since: Date?
}
edge WorksAt: Person -> Company {
    title: String?
}
"#;
    SCHEMA
}
|
||||
|
||||
/// The fixture schema yields two node types and two edge types.
#[test]
fn test_build_catalog() {
    let parsed = parse_schema(test_schema()).unwrap();
    let catalog = build_catalog(&parsed).unwrap();
    assert_eq!(catalog.node_types.len(), 2);
    assert_eq!(catalog.edge_types.len(), 2);
    for node_name in ["Person", "Company"] {
        assert!(catalog.node_types.contains_key(node_name));
    }
}
|
||||
|
||||
/// Edge lookup resolves names case-insensitively to the declared edge.
#[test]
fn test_edge_lookup() {
    let parsed = parse_schema(test_schema()).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let lower_hit = catalog.lookup_edge_by_name("knows").unwrap();
    assert_eq!(lower_hit.from_type, "Person");
    assert_eq!(lower_hit.to_type, "Person");

    let upper_hit = catalog.lookup_edge_by_name("KNOWS").unwrap();
    assert_eq!(upper_hit.name, "Knows");
}
|
||||
|
||||
/// `Person` gets one Arrow field per declared property plus the `id` column.
#[test]
fn test_node_arrow_schema() {
    let parsed = parse_schema(test_schema()).unwrap();
    let catalog = build_catalog(&parsed).unwrap();
    let person_fields = catalog.node_types["Person"].arrow_schema.fields();
    assert_eq!(person_fields.len(), 3); // id, name, age
}
|
||||
|
||||
/// Declaring the same node type twice parses but fails catalog construction.
#[test]
fn test_duplicate_node_error() {
    let source = r#"
node Person { name: String }
node Person { age: I32 }
"#;
    let parsed = parse_schema(source).unwrap();
    let outcome = build_catalog(&parsed);
    assert!(outcome.is_err());
}
|
||||
|
||||
/// An edge pointing at an undeclared node type fails catalog construction.
#[test]
fn test_bad_edge_endpoint() {
    let source = r#"
node Person { name: String }
edge Knows: Person -> Alien
"#;
    let parsed = parse_schema(source).unwrap();
    let outcome = build_catalog(&parsed);
    assert!(outcome.is_err());
}
|
||||
|
||||
/// The node `id` column and the edge `id`/`src`/`dst` columns are all `Utf8`.
#[test]
fn test_id_fields_are_utf8() {
    let parsed = parse_schema(test_schema()).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let person = &catalog.node_types["Person"];
    let person_id = person.arrow_schema.field_with_name("id").unwrap();
    assert_eq!(person_id.data_type(), &DataType::Utf8);

    let knows = &catalog.edge_types["Knows"];
    for column in ["id", "src", "dst"] {
        let field = knows.arrow_schema.field_with_name(column).unwrap();
        assert_eq!(field.data_type(), &DataType::Utf8);
    }
}
|
||||
|
||||
/// `key_property` reports the `@key` column when present and `None` otherwise.
#[test]
fn test_key_property_tracking() {
    let source = r#"
node Signal {
    slug: String @key
    title: String
}
node Person {
    name: String
}
edge Emits: Person -> Signal
"#;
    let parsed = parse_schema(source).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let signal = &catalog.node_types["Signal"];
    assert_eq!(signal.key_property(), Some("slug"));

    let person = &catalog.node_types["Person"];
    assert_eq!(person.key_property(), None);
}
|
||||
|
||||
/// Case-insensitive edge lookup also works when the edge name starts with a
/// non-ASCII letter. The schema is built by hand rather than parsed.
#[test]
fn test_edge_lookup_handles_non_ascii_leading_character() {
    let person_node = SchemaDecl::Node(NodeDecl {
        name: "Person".to_string(),
        annotations: vec![],
        implements: vec![],
        properties: vec![crate::schema::ast::PropDecl {
            name: "name".to_string(),
            prop_type: PropType::scalar(ScalarType::String, false),
            annotations: vec![],
        }],
        constraints: vec![],
    });
    let accented_edge = SchemaDecl::Edge(EdgeDecl {
        name: "Édges".to_string(),
        from_type: "Person".to_string(),
        to_type: "Person".to_string(),
        cardinality: Default::default(),
        annotations: vec![],
        properties: vec![],
        constraints: vec![],
    });
    let schema = SchemaFile {
        declarations: vec![person_node, accented_edge],
    };

    let catalog = build_catalog(&schema).unwrap();
    assert!(catalog.lookup_edge_by_name("édges").is_some());
}
|
||||
|
||||
/// Two edges whose names differ only by case are rejected with a
/// "case folding" error.
#[test]
fn test_edge_lookup_rejects_case_fold_collisions() {
    let source = r#"
node Person { name: String }
edge Knows: Person -> Person
edge KNOWS: Person -> Person
"#;
    let parsed = parse_schema(source).unwrap();
    let message = build_catalog(&parsed).unwrap_err().to_string();
    assert!(message.contains("case folding"));
}
|
||||
|
||||
/// A body-level `@unique(first, last)` is recorded on the node type as a
/// two-column unique constraint.
#[test]
fn test_catalog_composite_unique() {
    let source = r#"
node Person {
first: String
last: String
@unique(first, last)
}
"#;
    let parsed = parse_schema(source).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let expected_columns = vec!["first".to_string(), "last".to_string()];
    let person = &catalog.node_types["Person"];
    assert!(person.unique_constraints.contains(&expected_columns));
}
|
||||
|
||||
/// A body-level `@index(category, date)` is recorded on the node type as a
/// two-column index.
#[test]
fn test_catalog_composite_index() {
    let source = r#"
node Event {
category: String
date: Date
@index(category, date)
}
"#;
    let parsed = parse_schema(source).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let expected_columns = vec!["category".to_string(), "date".to_string()];
    let event = &catalog.node_types["Event"];
    assert!(event.indices.contains(&expected_columns));
}
|
||||
|
||||
/// `@card(0..1)` on an edge is stored as `min == 0`, `max == Some(1)`.
#[test]
fn test_catalog_edge_cardinality() {
    let source = r#"
node Person { name: String }
node Company { name: String }
edge WorksAt: Person -> Company @card(0..1)
"#;
    let parsed = parse_schema(source).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let works_at = &catalog.edge_types["WorksAt"];
    assert_eq!(works_at.cardinality.min, 0);
    assert_eq!(works_at.cardinality.max, Some(1));
}
|
||||
|
||||
/// Declared interfaces are kept in the catalog together with their properties.
#[test]
fn test_catalog_interfaces_stored() {
    let source = r#"
interface Named {
name: String
}
node Person implements Named {
age: I32?
}
"#;
    let parsed = parse_schema(source).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let interfaces = &catalog.interfaces;
    assert!(interfaces.contains_key("Named"));
    assert!(interfaces["Named"].properties.contains_key("name"));
}
|
||||
|
||||
/// A node's `implements` list is carried over into its catalog entry.
#[test]
fn test_catalog_node_implements() {
    let source = r#"
interface Named {
name: String
}
node Person implements Named {
age: I32?
}
"#;
    let parsed = parse_schema(source).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let person = &catalog.node_types["Person"];
    assert_eq!(person.implements, vec!["Named"]);
}
|
||||
|
||||
/// Marking a property `@key` also registers a single-column index on it.
#[test]
fn test_key_implies_index() {
    let source = r#"
node Signal {
slug: String @key
title: String
}
"#;
    let parsed = parse_schema(source).unwrap();
    let catalog = build_catalog(&parsed).unwrap();

    let slug_index = vec!["slug".to_string()];
    let signal = &catalog.node_types["Signal"];
    assert!(signal.indices.contains(&slug_index));
}
|
||||
}
|
||||
393
crates/omnigraph-compiler/src/catalog/schema_ir.rs
Normal file
393
crates/omnigraph-compiler/src/catalog/schema_ir.rs
Normal file
|
|
@ -0,0 +1,393 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sha2::{Digest, Sha256};
|
||||
|
||||
use crate::catalog::{Catalog, build_catalog};
|
||||
use crate::error::{NanoError, Result};
|
||||
use crate::schema::ast::{Annotation, Cardinality, Constraint, PropDecl, SchemaDecl, SchemaFile};
|
||||
use crate::types::PropType;
|
||||
|
||||
// Format version written into `SchemaIR::ir_version` by `build_schema_ir`;
// `build_catalog_from_ir` rejects any IR whose version differs from this.
const SCHEMA_IR_VERSION: u32 = 1;
|
||||
|
||||
/// Canonical, serializable form of a schema file.
///
/// `build_schema_ir` fills each list sorted by declaration name, so two
/// sources that differ only in declaration order produce equal values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SchemaIR {
|
||||
/// IR format version; set to `SCHEMA_IR_VERSION` when built.
pub ir_version: u32,
|
||||
pub interfaces: Vec<InterfaceIR>,
|
||||
pub nodes: Vec<NodeIR>,
|
||||
pub edges: Vec<EdgeIR>,
|
||||
}
|
||||
|
||||
/// Canonical form of an `interface` declaration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct InterfaceIR {
|
||||
pub name: String,
|
||||
/// Hash-derived id computed from `"interface:<name>"`.
pub type_id: u32,
|
||||
/// Properties, sorted by name when produced by `build_schema_ir`.
pub properties: Vec<PropertyIR>,
|
||||
}
|
||||
|
||||
/// Canonical form of a `node` declaration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct NodeIR {
|
||||
pub name: String,
|
||||
/// Hash-derived id computed from `"node:<name>"`.
pub type_id: u32,
|
||||
/// Type-level annotations, sorted by name then value.
pub annotations: Vec<Annotation>,
|
||||
/// Implemented interface names, sorted and de-duplicated.
pub implements: Vec<String>,
|
||||
/// Properties, sorted by name when produced by `build_schema_ir`.
pub properties: Vec<PropertyIR>,
|
||||
/// Constraints after column normalization and sorting.
pub constraints: Vec<Constraint>,
|
||||
}
|
||||
|
||||
/// Canonical form of an `edge` declaration.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EdgeIR {
|
||||
pub name: String,
|
||||
/// Hash-derived id computed from `"edge:<name>"`.
pub type_id: u32,
|
||||
/// Name of the source endpoint type, copied from the declaration.
pub from_type: String,
|
||||
/// Name of the target endpoint type, copied from the declaration.
pub to_type: String,
|
||||
pub cardinality: Cardinality,
|
||||
/// Type-level annotations, sorted by name then value.
pub annotations: Vec<Annotation>,
|
||||
/// Properties, sorted by name when produced by `build_schema_ir`.
pub properties: Vec<PropertyIR>,
|
||||
/// Constraints after column normalization and sorting.
pub constraints: Vec<Constraint>,
|
||||
}
|
||||
|
||||
/// Canonical form of a single property of an interface, node, or edge.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct PropertyIR {
|
||||
pub name: String,
|
||||
/// Hash-derived id computed from `"<kind>:<owner>:<name>"`.
pub prop_id: u32,
|
||||
/// Property type; enum value lists are sorted and de-duplicated.
pub prop_type: PropType,
|
||||
/// Property annotations, sorted by name then value.
pub annotations: Vec<Annotation>,
|
||||
}
|
||||
|
||||
pub fn build_schema_ir(schema: &SchemaFile) -> Result<SchemaIR> {
|
||||
let mut seen_type_ids = HashMap::<u32, String>::new();
|
||||
let mut interfaces = Vec::new();
|
||||
let mut nodes = Vec::new();
|
||||
let mut edges = Vec::new();
|
||||
|
||||
for decl in &schema.declarations {
|
||||
match decl {
|
||||
SchemaDecl::Interface(interface) => {
|
||||
let type_id = stable_type_id("interface", &interface.name);
|
||||
check_type_id_collision(&mut seen_type_ids, type_id, &interface.name)?;
|
||||
interfaces.push(InterfaceIR {
|
||||
name: interface.name.clone(),
|
||||
type_id,
|
||||
properties: canonical_properties(
|
||||
"interface",
|
||||
&interface.name,
|
||||
&interface.properties,
|
||||
)?,
|
||||
});
|
||||
}
|
||||
SchemaDecl::Node(node) => {
|
||||
let type_id = stable_type_id("node", &node.name);
|
||||
check_type_id_collision(&mut seen_type_ids, type_id, &node.name)?;
|
||||
nodes.push(NodeIR {
|
||||
name: node.name.clone(),
|
||||
type_id,
|
||||
annotations: canonical_annotations(&node.annotations),
|
||||
implements: canonical_strings(&node.implements),
|
||||
properties: canonical_properties("node", &node.name, &node.properties)?,
|
||||
constraints: canonical_constraints(&node.constraints),
|
||||
});
|
||||
}
|
||||
SchemaDecl::Edge(edge) => {
|
||||
let type_id = stable_type_id("edge", &edge.name);
|
||||
check_type_id_collision(&mut seen_type_ids, type_id, &edge.name)?;
|
||||
edges.push(EdgeIR {
|
||||
name: edge.name.clone(),
|
||||
type_id,
|
||||
from_type: edge.from_type.clone(),
|
||||
to_type: edge.to_type.clone(),
|
||||
cardinality: edge.cardinality.clone(),
|
||||
annotations: canonical_annotations(&edge.annotations),
|
||||
properties: canonical_properties("edge", &edge.name, &edge.properties)?,
|
||||
constraints: canonical_constraints(&edge.constraints),
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
interfaces.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
nodes.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
edges.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
|
||||
Ok(SchemaIR {
|
||||
ir_version: SCHEMA_IR_VERSION,
|
||||
interfaces,
|
||||
nodes,
|
||||
edges,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn build_catalog_from_ir(ir: &SchemaIR) -> Result<Catalog> {
|
||||
if ir.ir_version != SCHEMA_IR_VERSION {
|
||||
return Err(NanoError::Catalog(format!(
|
||||
"unsupported schema ir_version {} (expected {})",
|
||||
ir.ir_version, SCHEMA_IR_VERSION
|
||||
)));
|
||||
}
|
||||
|
||||
let schema = SchemaFile {
|
||||
declarations: ir
|
||||
.interfaces
|
||||
.iter()
|
||||
.map(|interface| {
|
||||
SchemaDecl::Interface(crate::schema::ast::InterfaceDecl {
|
||||
name: interface.name.clone(),
|
||||
properties: interface
|
||||
.properties
|
||||
.iter()
|
||||
.map(property_decl_from_ir)
|
||||
.collect(),
|
||||
})
|
||||
})
|
||||
.chain(ir.nodes.iter().map(|node| {
|
||||
SchemaDecl::Node(crate::schema::ast::NodeDecl {
|
||||
name: node.name.clone(),
|
||||
annotations: node.annotations.clone(),
|
||||
implements: node.implements.clone(),
|
||||
properties: node.properties.iter().map(property_decl_from_ir).collect(),
|
||||
constraints: node.constraints.clone(),
|
||||
})
|
||||
}))
|
||||
.chain(ir.edges.iter().map(|edge| {
|
||||
SchemaDecl::Edge(crate::schema::ast::EdgeDecl {
|
||||
name: edge.name.clone(),
|
||||
from_type: edge.from_type.clone(),
|
||||
to_type: edge.to_type.clone(),
|
||||
cardinality: edge.cardinality.clone(),
|
||||
annotations: edge.annotations.clone(),
|
||||
properties: edge.properties.iter().map(property_decl_from_ir).collect(),
|
||||
constraints: edge.constraints.clone(),
|
||||
})
|
||||
}))
|
||||
.collect(),
|
||||
};
|
||||
|
||||
build_catalog(&schema)
|
||||
}
|
||||
|
||||
pub fn schema_ir_json(ir: &SchemaIR) -> Result<String> {
|
||||
serde_json::to_string(ir)
|
||||
.map_err(|err| NanoError::Catalog(format!("serialize schema ir error: {}", err)))
|
||||
}
|
||||
|
||||
pub fn schema_ir_pretty_json(ir: &SchemaIR) -> Result<String> {
|
||||
serde_json::to_string_pretty(ir)
|
||||
.map_err(|err| NanoError::Catalog(format!("serialize schema ir error: {}", err)))
|
||||
}
|
||||
|
||||
pub fn schema_ir_hash(ir: &SchemaIR) -> Result<String> {
|
||||
let json = schema_ir_json(ir)?;
|
||||
let mut hasher = Sha256::new();
|
||||
hasher.update(json.as_bytes());
|
||||
Ok(format!("sha256:{:x}", hasher.finalize()))
|
||||
}
|
||||
|
||||
fn property_decl_from_ir(property: &PropertyIR) -> PropDecl {
|
||||
PropDecl {
|
||||
name: property.name.clone(),
|
||||
prop_type: property.prop_type.clone(),
|
||||
annotations: property.annotations.clone(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the input strings sorted ascending with duplicates removed.
fn canonical_strings(values: &[String]) -> Vec<String> {
    // An ordered set yields unique elements in ascending order.
    values
        .iter()
        .cloned()
        .collect::<std::collections::BTreeSet<String>>()
        .into_iter()
        .collect()
}
|
||||
|
||||
/// Returns a copy of the annotations ordered by name, then by value.
/// The sort is stable, so fully equal annotations keep their relative order.
fn canonical_annotations(annotations: &[Annotation]) -> Vec<Annotation> {
    let mut ordered = annotations.to_vec();
    ordered.sort_by(|lhs, rhs| (&lhs.name, &lhs.value).cmp(&(&rhs.name, &rhs.value)));
    ordered
}
|
||||
|
||||
/// Returns a copy of the property type whose enum value list, when present,
/// is sorted and de-duplicated.
fn canonical_prop_type(prop_type: &PropType) -> PropType {
    let mut canonical = prop_type.clone();
    if let Some(variants) = canonical.enum_values.as_mut() {
        variants.sort();
        variants.dedup();
    }
    canonical
}
|
||||
|
||||
fn canonical_properties(
|
||||
kind: &str,
|
||||
owner_name: &str,
|
||||
properties: &[PropDecl],
|
||||
) -> Result<Vec<PropertyIR>> {
|
||||
let mut seen_prop_ids = HashMap::<u32, String>::new();
|
||||
let owner_key = format!("{}:{}", kind, owner_name);
|
||||
let mut canonical = properties
|
||||
.iter()
|
||||
.map(|property| {
|
||||
let prop_id = stable_prop_id(&owner_key, &property.name);
|
||||
if let Some(previous) = seen_prop_ids.insert(prop_id, property.name.clone()) {
|
||||
return Err(NanoError::Catalog(format!(
|
||||
"property id collision on {}: '{}' and '{}' both hash to {}",
|
||||
owner_name, previous, property.name, prop_id
|
||||
)));
|
||||
}
|
||||
Ok(PropertyIR {
|
||||
name: property.name.clone(),
|
||||
prop_id,
|
||||
prop_type: canonical_prop_type(&property.prop_type),
|
||||
annotations: canonical_annotations(&property.annotations),
|
||||
})
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
canonical.sort_by(|a, b| a.name.cmp(&b.name));
|
||||
Ok(canonical)
|
||||
}
|
||||
|
||||
fn canonical_constraints(constraints: &[Constraint]) -> Vec<Constraint> {
|
||||
let mut constraints = constraints
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(normalize_constraint)
|
||||
.collect::<Vec<_>>();
|
||||
constraints.sort_by_key(constraint_sort_key);
|
||||
constraints
|
||||
}
|
||||
|
||||
fn normalize_constraint(constraint: Constraint) -> Constraint {
|
||||
match constraint {
|
||||
Constraint::Key(mut columns) => {
|
||||
columns.sort();
|
||||
Constraint::Key(columns)
|
||||
}
|
||||
Constraint::Unique(mut columns) => {
|
||||
columns.sort();
|
||||
Constraint::Unique(columns)
|
||||
}
|
||||
Constraint::Index(mut columns) => {
|
||||
columns.sort();
|
||||
Constraint::Index(columns)
|
||||
}
|
||||
other => other,
|
||||
}
|
||||
}
|
||||
|
||||
fn constraint_sort_key(constraint: &Constraint) -> String {
|
||||
match constraint {
|
||||
Constraint::Key(columns) => format!("key:{}", columns.join(",")),
|
||||
Constraint::Unique(columns) => format!("unique:{}", columns.join(",")),
|
||||
Constraint::Index(columns) => format!("index:{}", columns.join(",")),
|
||||
Constraint::Range { property, min, max } => {
|
||||
format!("range:{}:{:?}:{:?}", property, min, max)
|
||||
}
|
||||
Constraint::Check { property, pattern } => format!("check:{}:{}", property, pattern),
|
||||
}
|
||||
}
|
||||
|
||||
fn stable_type_id(kind: &str, name: &str) -> u32 {
|
||||
fnv1a_u32(&format!("{}:{}", kind, name))
|
||||
}
|
||||
|
||||
fn stable_prop_id(owner: &str, name: &str) -> u32 {
|
||||
fnv1a_u32(&format!("{}:{}", owner, name))
|
||||
}
|
||||
|
||||
/// 32-bit FNV-1a hash of the string's UTF-8 bytes.
///
/// A computed hash of `0` is mapped to `1`, so this function never returns 0.
fn fnv1a_u32(value: &str) -> u32 {
    const OFFSET_BASIS: u32 = 0x811C_9DC5; // 2_166_136_261
    const PRIME: u32 = 0x0100_0193; // 16_777_619

    let hash = value.bytes().fold(OFFSET_BASIS, |acc, byte| {
        (acc ^ u32::from(byte)).wrapping_mul(PRIME)
    });

    match hash {
        0 => 1,
        nonzero => nonzero,
    }
}
|
||||
|
||||
fn check_type_id_collision(
|
||||
seen_type_ids: &mut HashMap<u32, String>,
|
||||
type_id: u32,
|
||||
name: &str,
|
||||
) -> Result<()> {
|
||||
if let Some(previous) = seen_type_ids.insert(type_id, name.to_string()) {
|
||||
return Err(NanoError::Catalog(format!(
|
||||
"type id collision: '{}' and '{}' both hash to {}",
|
||||
previous, name, type_id
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::catalog::build_catalog;
    use crate::schema::parser::parse_schema;

    /// Reordering declarations and properties in the source must change
    /// neither the IR nor its hash.
    #[test]
    fn schema_ir_hash_is_stable_across_source_ordering_noise() {
        let node_first = r#"
node Person {
age: I32?
name: String @key
}

edge Knows: Person -> Person {
since: Date?
}
"#;
        let edge_first = r#"
edge Knows: Person -> Person {
since: Date?
}

node Person {
name: String @key
age: I32?
}
"#;
        let schema_a = parse_schema(node_first).unwrap();
        let schema_b = parse_schema(edge_first).unwrap();

        let ir_a = build_schema_ir(&schema_a).unwrap();
        let ir_b = build_schema_ir(&schema_b).unwrap();
        assert_eq!(ir_a, ir_b);

        let hash_a = schema_ir_hash(&ir_a).unwrap();
        let hash_b = schema_ir_hash(&ir_b).unwrap();
        assert_eq!(hash_a, hash_b);
    }

    /// A catalog rebuilt from the IR agrees with the directly built catalog
    /// on type counts, the key property and edge cardinality.
    #[test]
    fn build_catalog_from_ir_round_trips_core_catalog_fields() {
        let source = r#"
node Person @description("person") {
name: String @key
age: I32? @description("age")
}

edge Knows: Person -> Person @instruction("friendship") {
since: Date?
}
"#;
        let schema = parse_schema(source).unwrap();
        let direct = build_catalog(&schema).unwrap();
        let ir = build_schema_ir(&schema).unwrap();
        let rebuilt = build_catalog_from_ir(&ir).unwrap();

        assert_eq!(direct.node_types.len(), rebuilt.node_types.len());
        assert_eq!(direct.edge_types.len(), rebuilt.edge_types.len());

        let direct_person = &direct.node_types["Person"];
        let rebuilt_person = &rebuilt.node_types["Person"];
        assert_eq!(direct_person.key_property(), rebuilt_person.key_property());

        let direct_knows = &direct.edge_types["Knows"];
        let rebuilt_knows = &rebuilt.edge_types["Knows"];
        assert_eq!(direct_knows.cardinality, rebuilt_knows.cardinality);
    }
}
|
||||
895
crates/omnigraph-compiler/src/catalog/schema_plan.rs
Normal file
895
crates/omnigraph-compiler/src/catalog/schema_plan.rs
Normal file
|
|
@ -0,0 +1,895 @@
|
|||
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::Result;
|
||||
use crate::schema::ast::{Annotation, Constraint};
|
||||
use crate::types::PropType;
|
||||
|
||||
use super::schema_ir::{EdgeIR, InterfaceIR, NodeIR, PropertyIR, SchemaIR};
|
||||
|
||||
/// Which kind of schema declaration a migration step refers to.
/// Serialized in snake_case.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum SchemaTypeKind {
|
||||
Interface,
|
||||
Node,
|
||||
Edge,
|
||||
}
|
||||
|
||||
/// Result of diffing an accepted schema IR against a desired one.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SchemaMigrationPlan {
|
||||
/// `false` when `steps` contains at least one `UnsupportedChange`.
pub supported: bool,
|
||||
pub steps: Vec<SchemaMigrationStep>,
|
||||
}
|
||||
|
||||
/// One step of a schema migration plan.
///
/// Serialized as an internally tagged enum: the variant name is written in
/// snake_case under the `kind` key.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(tag = "kind", rename_all = "snake_case")]
|
||||
pub enum SchemaMigrationStep {
|
||||
/// A type present only in the desired schema.
AddType {
|
||||
type_kind: SchemaTypeKind,
|
||||
name: String,
|
||||
},
|
||||
/// A desired type matched to a differently named accepted type.
RenameType {
|
||||
type_kind: SchemaTypeKind,
|
||||
from: String,
|
||||
to: String,
|
||||
},
|
||||
/// A new property; the planner emits this only for nullable properties.
AddProperty {
|
||||
type_kind: SchemaTypeKind,
|
||||
type_name: String,
|
||||
property_name: String,
|
||||
property_type: PropType,
|
||||
},
|
||||
/// A desired property matched to a differently named accepted property.
RenameProperty {
|
||||
type_kind: SchemaTypeKind,
|
||||
type_name: String,
|
||||
from: String,
|
||||
to: String,
|
||||
},
|
||||
/// A new constraint; the planner emits this only for `Constraint::Index`.
AddConstraint {
|
||||
type_kind: SchemaTypeKind,
|
||||
type_name: String,
|
||||
constraint: Constraint,
|
||||
},
|
||||
/// Metadata annotations of a type changed; carries the desired set.
UpdateTypeMetadata {
|
||||
type_kind: SchemaTypeKind,
|
||||
name: String,
|
||||
annotations: Vec<Annotation>,
|
||||
},
|
||||
/// Metadata annotations of a property changed; carries the desired set.
UpdatePropertyMetadata {
|
||||
type_kind: SchemaTypeKind,
|
||||
type_name: String,
|
||||
property_name: String,
|
||||
annotations: Vec<Annotation>,
|
||||
},
|
||||
/// A difference the planner cannot migrate; its presence marks the whole
/// plan as unsupported.
UnsupportedChange {
|
||||
entity: String,
|
||||
reason: String,
|
||||
},
|
||||
}
|
||||
|
||||
pub fn plan_schema_migration(
|
||||
accepted: &SchemaIR,
|
||||
desired: &SchemaIR,
|
||||
) -> Result<SchemaMigrationPlan> {
|
||||
let mut steps = Vec::new();
|
||||
let interface_renames = plan_interfaces(&accepted.interfaces, &desired.interfaces, &mut steps);
|
||||
let node_renames = plan_nodes(
|
||||
&accepted.nodes,
|
||||
&desired.nodes,
|
||||
&interface_renames,
|
||||
&mut steps,
|
||||
);
|
||||
plan_edges(&accepted.edges, &desired.edges, &node_renames, &mut steps);
|
||||
|
||||
Ok(SchemaMigrationPlan {
|
||||
supported: !steps
|
||||
.iter()
|
||||
.any(|step| matches!(step, SchemaMigrationStep::UnsupportedChange { .. })),
|
||||
steps,
|
||||
})
|
||||
}
|
||||
|
||||
fn plan_interfaces(
|
||||
accepted: &[InterfaceIR],
|
||||
desired: &[InterfaceIR],
|
||||
steps: &mut Vec<SchemaMigrationStep>,
|
||||
) -> HashMap<String, String> {
|
||||
let accepted_by_name = accepted
|
||||
.iter()
|
||||
.map(|interface| (interface.name.as_str(), interface))
|
||||
.collect::<HashMap<_, _>>();
|
||||
let mut consumed = HashSet::new();
|
||||
|
||||
for interface in desired {
|
||||
if let Some(existing) = accepted_by_name.get(interface.name.as_str()) {
|
||||
consumed.insert(existing.name.clone());
|
||||
let _property_renames = plan_properties(
|
||||
SchemaTypeKind::Interface,
|
||||
&interface.name,
|
||||
&existing.properties,
|
||||
&interface.properties,
|
||||
steps,
|
||||
);
|
||||
continue;
|
||||
}
|
||||
|
||||
steps.push(SchemaMigrationStep::AddType {
|
||||
type_kind: SchemaTypeKind::Interface,
|
||||
name: interface.name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
for leftover in accepted
|
||||
.iter()
|
||||
.filter(|interface| !consumed.contains(&interface.name))
|
||||
{
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("interface:{}", leftover.name),
|
||||
reason: format!(
|
||||
"removing interface '{}' is not supported in schema migration v1",
|
||||
leftover.name
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
HashMap::new()
|
||||
}
|
||||
|
||||
fn plan_nodes(
|
||||
accepted: &[NodeIR],
|
||||
desired: &[NodeIR],
|
||||
interface_renames: &HashMap<String, String>,
|
||||
steps: &mut Vec<SchemaMigrationStep>,
|
||||
) -> HashMap<String, String> {
|
||||
let accepted_by_name = accepted
|
||||
.iter()
|
||||
.map(|node| (node.name.as_str(), node))
|
||||
.collect::<HashMap<_, _>>();
|
||||
let mut consumed = HashSet::new();
|
||||
let mut renames = HashMap::new();
|
||||
|
||||
for node in desired {
|
||||
let rename_from = rename_from_value(&node.annotations);
|
||||
let matched = accepted_by_name
|
||||
.get(node.name.as_str())
|
||||
.copied()
|
||||
.or_else(|| {
|
||||
rename_from.and_then(|from| {
|
||||
accepted_by_name
|
||||
.get(from)
|
||||
.copied()
|
||||
.filter(|candidate| candidate.name != node.name)
|
||||
})
|
||||
});
|
||||
|
||||
let Some(existing) = matched else {
|
||||
if let Some(from) = rename_from {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("node:{}", node.name),
|
||||
reason: format!(
|
||||
"node '{}' declares @rename_from(\"{}\") but no accepted node with that name exists",
|
||||
node.name, from
|
||||
),
|
||||
});
|
||||
} else {
|
||||
steps.push(SchemaMigrationStep::AddType {
|
||||
type_kind: SchemaTypeKind::Node,
|
||||
name: node.name.clone(),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
};
|
||||
|
||||
consumed.insert(existing.name.clone());
|
||||
if existing.name != node.name {
|
||||
renames.insert(existing.name.clone(), node.name.clone());
|
||||
steps.push(SchemaMigrationStep::RenameType {
|
||||
type_kind: SchemaTypeKind::Node,
|
||||
from: existing.name.clone(),
|
||||
to: node.name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
if normalize_strings(&existing.implements, interface_renames)
|
||||
!= normalize_strings(&node.implements, &HashMap::new())
|
||||
{
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("node:{}", node.name),
|
||||
reason: format!(
|
||||
"changing implemented interfaces on node '{}' is not supported in schema migration v1",
|
||||
node.name
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
plan_type_metadata(
|
||||
SchemaTypeKind::Node,
|
||||
&node.name,
|
||||
&existing.annotations,
|
||||
&node.annotations,
|
||||
steps,
|
||||
);
|
||||
let property_renames = plan_properties(
|
||||
SchemaTypeKind::Node,
|
||||
&node.name,
|
||||
&existing.properties,
|
||||
&node.properties,
|
||||
steps,
|
||||
);
|
||||
plan_constraints(
|
||||
SchemaTypeKind::Node,
|
||||
&node.name,
|
||||
&existing.constraints,
|
||||
&node.constraints,
|
||||
&property_renames,
|
||||
steps,
|
||||
);
|
||||
}
|
||||
|
||||
for leftover in accepted
|
||||
.iter()
|
||||
.filter(|node| !consumed.contains(&node.name))
|
||||
{
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("node:{}", leftover.name),
|
||||
reason: format!(
|
||||
"removing node type '{}' is not supported in schema migration v1",
|
||||
leftover.name
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
renames
|
||||
}
|
||||
|
||||
fn plan_edges(
|
||||
accepted: &[EdgeIR],
|
||||
desired: &[EdgeIR],
|
||||
node_renames: &HashMap<String, String>,
|
||||
steps: &mut Vec<SchemaMigrationStep>,
|
||||
) {
|
||||
let accepted_by_name = accepted
|
||||
.iter()
|
||||
.map(|edge| (edge.name.as_str(), edge))
|
||||
.collect::<HashMap<_, _>>();
|
||||
let mut consumed = HashSet::new();
|
||||
|
||||
for edge in desired {
|
||||
let rename_from = rename_from_value(&edge.annotations);
|
||||
let matched = accepted_by_name
|
||||
.get(edge.name.as_str())
|
||||
.copied()
|
||||
.or_else(|| {
|
||||
rename_from.and_then(|from| {
|
||||
accepted_by_name
|
||||
.get(from)
|
||||
.copied()
|
||||
.filter(|candidate| candidate.name != edge.name)
|
||||
})
|
||||
});
|
||||
|
||||
let Some(existing) = matched else {
|
||||
if let Some(from) = rename_from {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("edge:{}", edge.name),
|
||||
reason: format!(
|
||||
"edge '{}' declares @rename_from(\"{}\") but no accepted edge with that name exists",
|
||||
edge.name, from
|
||||
),
|
||||
});
|
||||
} else {
|
||||
steps.push(SchemaMigrationStep::AddType {
|
||||
type_kind: SchemaTypeKind::Edge,
|
||||
name: edge.name.clone(),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
};
|
||||
|
||||
consumed.insert(existing.name.clone());
|
||||
if existing.name != edge.name {
|
||||
steps.push(SchemaMigrationStep::RenameType {
|
||||
type_kind: SchemaTypeKind::Edge,
|
||||
from: existing.name.clone(),
|
||||
to: edge.name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
let normalized_from = normalize_type_ref(&existing.from_type, node_renames);
|
||||
let normalized_to = normalize_type_ref(&existing.to_type, node_renames);
|
||||
if normalized_from != edge.from_type || normalized_to != edge.to_type {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("edge:{}", edge.name),
|
||||
reason: format!(
|
||||
"changing edge endpoints on '{}' is not supported in schema migration v1",
|
||||
edge.name
|
||||
),
|
||||
});
|
||||
}
|
||||
if existing.cardinality != edge.cardinality {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("edge:{}", edge.name),
|
||||
reason: format!(
|
||||
"changing cardinality on edge '{}' is not supported in schema migration v1",
|
||||
edge.name
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
plan_type_metadata(
|
||||
SchemaTypeKind::Edge,
|
||||
&edge.name,
|
||||
&existing.annotations,
|
||||
&edge.annotations,
|
||||
steps,
|
||||
);
|
||||
let property_renames = plan_properties(
|
||||
SchemaTypeKind::Edge,
|
||||
&edge.name,
|
||||
&existing.properties,
|
||||
&edge.properties,
|
||||
steps,
|
||||
);
|
||||
plan_constraints(
|
||||
SchemaTypeKind::Edge,
|
||||
&edge.name,
|
||||
&existing.constraints,
|
||||
&edge.constraints,
|
||||
&property_renames,
|
||||
steps,
|
||||
);
|
||||
}
|
||||
|
||||
for leftover in accepted
|
||||
.iter()
|
||||
.filter(|edge| !consumed.contains(&edge.name))
|
||||
{
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("edge:{}", leftover.name),
|
||||
reason: format!(
|
||||
"removing edge type '{}' is not supported in schema migration v1",
|
||||
leftover.name
|
||||
),
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn plan_properties(
|
||||
type_kind: SchemaTypeKind,
|
||||
type_name: &str,
|
||||
accepted: &[PropertyIR],
|
||||
desired: &[PropertyIR],
|
||||
steps: &mut Vec<SchemaMigrationStep>,
|
||||
) -> HashMap<String, String> {
|
||||
let accepted_by_name = accepted
|
||||
.iter()
|
||||
.map(|property| (property.name.as_str(), property))
|
||||
.collect::<HashMap<_, _>>();
|
||||
let mut consumed = HashSet::new();
|
||||
let mut renames = HashMap::new();
|
||||
|
||||
for property in desired {
|
||||
let rename_from = rename_from_value(&property.annotations);
|
||||
let matched = accepted_by_name
|
||||
.get(property.name.as_str())
|
||||
.copied()
|
||||
.or_else(|| {
|
||||
rename_from.and_then(|from| {
|
||||
accepted_by_name
|
||||
.get(from)
|
||||
.copied()
|
||||
.filter(|candidate| candidate.name != property.name)
|
||||
})
|
||||
});
|
||||
|
||||
let Some(existing) = matched else {
|
||||
if let Some(from) = rename_from {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!(
|
||||
"{}:{}.{}",
|
||||
schema_type_kind_key(type_kind),
|
||||
type_name,
|
||||
property.name
|
||||
),
|
||||
reason: format!(
|
||||
"property '{}.{}' declares @rename_from(\"{}\") but no accepted property with that name exists",
|
||||
type_name, property.name, from
|
||||
),
|
||||
});
|
||||
} else if property.prop_type.nullable {
|
||||
steps.push(SchemaMigrationStep::AddProperty {
|
||||
type_kind,
|
||||
type_name: type_name.to_string(),
|
||||
property_name: property.name.clone(),
|
||||
property_type: property.prop_type.clone(),
|
||||
});
|
||||
} else {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!(
|
||||
"{}:{}.{}",
|
||||
schema_type_kind_key(type_kind),
|
||||
type_name,
|
||||
property.name
|
||||
),
|
||||
reason: format!(
|
||||
"adding required property '{}.{}' requires a backfill and is not supported in schema migration v1",
|
||||
type_name, property.name
|
||||
),
|
||||
});
|
||||
}
|
||||
continue;
|
||||
};
|
||||
|
||||
consumed.insert(existing.name.clone());
|
||||
if existing.name != property.name {
|
||||
renames.insert(existing.name.clone(), property.name.clone());
|
||||
steps.push(SchemaMigrationStep::RenameProperty {
|
||||
type_kind,
|
||||
type_name: type_name.to_string(),
|
||||
from: existing.name.clone(),
|
||||
to: property.name.clone(),
|
||||
});
|
||||
}
|
||||
|
||||
if existing.prop_type != property.prop_type {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!(
|
||||
"{}:{}.{}",
|
||||
schema_type_kind_key(type_kind),
|
||||
type_name,
|
||||
property.name
|
||||
),
|
||||
reason: format!(
|
||||
"changing property type for '{}.{}' is not supported in schema migration v1",
|
||||
type_name, property.name
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
plan_property_metadata(
|
||||
type_kind,
|
||||
type_name,
|
||||
&property.name,
|
||||
&existing.annotations,
|
||||
&property.annotations,
|
||||
steps,
|
||||
);
|
||||
}
|
||||
|
||||
for leftover in accepted
|
||||
.iter()
|
||||
.filter(|property| !consumed.contains(&property.name))
|
||||
{
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!(
|
||||
"{}:{}.{}",
|
||||
schema_type_kind_key(type_kind),
|
||||
type_name,
|
||||
leftover.name
|
||||
),
|
||||
reason: format!(
|
||||
"removing property '{}.{}' is not supported in schema migration v1",
|
||||
type_name, leftover.name
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
renames
|
||||
}
|
||||
|
||||
fn plan_constraints(
|
||||
type_kind: SchemaTypeKind,
|
||||
type_name: &str,
|
||||
accepted: &[Constraint],
|
||||
desired: &[Constraint],
|
||||
property_renames: &HashMap<String, String>,
|
||||
steps: &mut Vec<SchemaMigrationStep>,
|
||||
) {
|
||||
let accepted = accepted
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|constraint| rename_constraint_properties(constraint, property_renames))
|
||||
.collect::<Vec<_>>();
|
||||
let desired_map = desired
|
||||
.iter()
|
||||
.cloned()
|
||||
.map(|constraint| (constraint_key(&constraint), constraint))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
let accepted_map = accepted
|
||||
.into_iter()
|
||||
.map(|constraint| (constraint_key(&constraint), constraint))
|
||||
.collect::<BTreeMap<_, _>>();
|
||||
|
||||
let removed = accepted_map
|
||||
.keys()
|
||||
.filter(|key| !desired_map.contains_key(*key))
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
if !removed.is_empty() {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("{}:{}", schema_type_kind_key(type_kind), type_name),
|
||||
reason: format!(
|
||||
"removing constraints from '{}' is not supported in schema migration v1",
|
||||
type_name
|
||||
),
|
||||
});
|
||||
}
|
||||
|
||||
for (key, constraint) in desired_map {
|
||||
if accepted_map.contains_key(&key) {
|
||||
continue;
|
||||
}
|
||||
match constraint {
|
||||
Constraint::Index(_) => steps.push(SchemaMigrationStep::AddConstraint {
|
||||
type_kind,
|
||||
type_name: type_name.to_string(),
|
||||
constraint,
|
||||
}),
|
||||
_ => steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("{}:{}", schema_type_kind_key(type_kind), type_name),
|
||||
reason: format!(
|
||||
"adding constraint '{}' to '{}' is not supported in schema migration v1",
|
||||
key, type_name
|
||||
),
|
||||
}),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn plan_type_metadata(
|
||||
type_kind: SchemaTypeKind,
|
||||
name: &str,
|
||||
accepted: &[Annotation],
|
||||
desired: &[Annotation],
|
||||
steps: &mut Vec<SchemaMigrationStep>,
|
||||
) {
|
||||
match annotation_change_kind(accepted, desired) {
|
||||
AnnotationChangeKind::None => {}
|
||||
AnnotationChangeKind::MetadataOnly(metadata) => {
|
||||
steps.push(SchemaMigrationStep::UpdateTypeMetadata {
|
||||
type_kind,
|
||||
name: name.to_string(),
|
||||
annotations: metadata,
|
||||
});
|
||||
}
|
||||
AnnotationChangeKind::Unsupported(reason) => {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!("{}:{}", schema_type_kind_key(type_kind), name),
|
||||
reason,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn plan_property_metadata(
|
||||
type_kind: SchemaTypeKind,
|
||||
type_name: &str,
|
||||
property_name: &str,
|
||||
accepted: &[Annotation],
|
||||
desired: &[Annotation],
|
||||
steps: &mut Vec<SchemaMigrationStep>,
|
||||
) {
|
||||
match annotation_change_kind(accepted, desired) {
|
||||
AnnotationChangeKind::None => {}
|
||||
AnnotationChangeKind::MetadataOnly(metadata) => {
|
||||
steps.push(SchemaMigrationStep::UpdatePropertyMetadata {
|
||||
type_kind,
|
||||
type_name: type_name.to_string(),
|
||||
property_name: property_name.to_string(),
|
||||
annotations: metadata,
|
||||
});
|
||||
}
|
||||
AnnotationChangeKind::Unsupported(reason) => {
|
||||
steps.push(SchemaMigrationStep::UnsupportedChange {
|
||||
entity: format!(
|
||||
"{}:{}.{}",
|
||||
schema_type_kind_key(type_kind),
|
||||
type_name,
|
||||
property_name
|
||||
),
|
||||
reason,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
enum AnnotationChangeKind {
|
||||
None,
|
||||
MetadataOnly(Vec<Annotation>),
|
||||
Unsupported(String),
|
||||
}
|
||||
|
||||
fn annotation_change_kind(accepted: &[Annotation], desired: &[Annotation]) -> AnnotationChangeKind {
|
||||
let accepted_non_metadata = strip_metadata_annotations(accepted);
|
||||
let desired_non_metadata = strip_metadata_annotations(desired);
|
||||
if accepted_non_metadata != desired_non_metadata {
|
||||
return AnnotationChangeKind::Unsupported(
|
||||
"changing annotations beyond @description/@instruction is not supported in schema migration v1"
|
||||
.to_string(),
|
||||
);
|
||||
}
|
||||
|
||||
let accepted_metadata = metadata_annotations(accepted);
|
||||
let desired_metadata = metadata_annotations(desired);
|
||||
if accepted_metadata == desired_metadata {
|
||||
AnnotationChangeKind::None
|
||||
} else {
|
||||
AnnotationChangeKind::MetadataOnly(desired_metadata)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns clones of the annotations that are not handled elsewhere by the planner.
///
/// Annotations named `description`, `instruction`, `rename_from`, `key`,
/// `unique` or `index` are dropped; the rest are kept in their original order.
fn strip_metadata_annotations(annotations: &[Annotation]) -> Vec<Annotation> {
    const EXCLUDED: [&str; 6] = [
        "description",
        "instruction",
        "rename_from",
        "key",
        "unique",
        "index",
    ];

    let mut kept = Vec::new();
    for annotation in annotations {
        if !EXCLUDED.contains(&annotation.name.as_str()) {
            kept.push(annotation.clone());
        }
    }
    kept
}
|
||||
|
||||
/// Returns clones of the `description` and `instruction` annotations, in their
/// original order.
fn metadata_annotations(annotations: &[Annotation]) -> Vec<Annotation> {
    let mut metadata = Vec::new();
    for annotation in annotations {
        let name = annotation.name.as_str();
        if name == "description" || name == "instruction" {
            metadata.push(annotation.clone());
        }
    }
    metadata
}
|
||||
|
||||
fn normalize_strings(values: &[String], renames: &HashMap<String, String>) -> BTreeSet<String> {
|
||||
values
|
||||
.iter()
|
||||
.map(|value| normalize_type_ref(value, renames))
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Maps a type name through `renames`, returning the name unchanged when it
/// has no entry.
fn normalize_type_ref(value: &str, renames: &HashMap<String, String>) -> String {
    match renames.get(value) {
        Some(renamed) => renamed.clone(),
        None => value.to_string(),
    }
}
|
||||
|
||||
fn rename_constraint_properties(
|
||||
constraint: Constraint,
|
||||
property_renames: &HashMap<String, String>,
|
||||
) -> Constraint {
|
||||
match constraint {
|
||||
Constraint::Key(columns) => {
|
||||
Constraint::Key(rename_constraint_columns(columns, property_renames))
|
||||
}
|
||||
Constraint::Unique(columns) => {
|
||||
Constraint::Unique(rename_constraint_columns(columns, property_renames))
|
||||
}
|
||||
Constraint::Index(columns) => {
|
||||
Constraint::Index(rename_constraint_columns(columns, property_renames))
|
||||
}
|
||||
Constraint::Range { property, min, max } => Constraint::Range {
|
||||
property: normalize_property_ref(&property, property_renames),
|
||||
min,
|
||||
max,
|
||||
},
|
||||
Constraint::Check { property, pattern } => Constraint::Check {
|
||||
property: normalize_property_ref(&property, property_renames),
|
||||
pattern,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn rename_constraint_columns(
|
||||
columns: Vec<String>,
|
||||
property_renames: &HashMap<String, String>,
|
||||
) -> Vec<String> {
|
||||
let mut columns = columns
|
||||
.into_iter()
|
||||
.map(|column| normalize_property_ref(&column, property_renames))
|
||||
.collect::<Vec<_>>();
|
||||
columns.sort();
|
||||
columns
|
||||
}
|
||||
|
||||
/// Maps a property name through `renames`, returning the name unchanged when
/// it has no entry.
fn normalize_property_ref(value: &str, renames: &HashMap<String, String>) -> String {
    if let Some(renamed) = renames.get(value) {
        renamed.clone()
    } else {
        value.to_string()
    }
}
|
||||
|
||||
fn constraint_key(constraint: &Constraint) -> String {
|
||||
match constraint {
|
||||
Constraint::Key(columns) => format!("key:{}", columns.join(",")),
|
||||
Constraint::Unique(columns) => format!("unique:{}", columns.join(",")),
|
||||
Constraint::Index(columns) => format!("index:{}", columns.join(",")),
|
||||
Constraint::Range { property, min, max } => {
|
||||
format!("range:{}:{:?}:{:?}", property, min, max)
|
||||
}
|
||||
Constraint::Check { property, pattern } => format!("check:{}:{}", property, pattern),
|
||||
}
|
||||
}
|
||||
|
||||
fn rename_from_value(annotations: &[Annotation]) -> Option<&str> {
|
||||
annotations
|
||||
.iter()
|
||||
.find(|annotation| annotation.name == "rename_from")
|
||||
.and_then(|annotation| annotation.value.as_deref())
|
||||
}
|
||||
|
||||
fn schema_type_kind_key(kind: SchemaTypeKind) -> &'static str {
|
||||
match kind {
|
||||
SchemaTypeKind::Interface => "interface",
|
||||
SchemaTypeKind::Node => "node",
|
||||
SchemaTypeKind::Edge => "edge",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use crate::catalog::schema_ir::build_schema_ir;
    use crate::schema::parser::parse_schema;

    use super::SchemaMigrationStep::{
        AddConstraint, AddProperty, RenameProperty, RenameType, UnsupportedChange,
        UpdateTypeMetadata,
    };
    use super::*;

    /// Parses schema source text and builds its schema IR, panicking on failure.
    macro_rules! schema_ir {
        ($source:expr) => {
            build_schema_ir(&parse_schema($source).unwrap()).unwrap()
        };
    }

    /// Adding a nullable property and an index to an existing property is a
    /// supported, additive migration.
    #[test]
    fn plan_supports_additive_nullable_property_and_index() {
        let accepted = schema_ir!(
            r#"
node Person {
name: String @key
age: I32?
}
"#
        );
        let desired = schema_ir!(
            r#"
node Person {
name: String @key
age: I32? @index
nickname: String?
}
"#
        );

        let plan = plan_schema_migration(&accepted, &desired).unwrap();
        assert!(plan.supported);

        let expected_property = AddProperty {
            type_kind: SchemaTypeKind::Node,
            type_name: "Person".to_string(),
            property_name: "nickname".to_string(),
            property_type: PropType::scalar(crate::types::ScalarType::String, true),
        };
        assert!(plan.steps.contains(&expected_property));

        let expected_index = AddConstraint {
            type_kind: SchemaTypeKind::Node,
            type_name: "Person".to_string(),
            constraint: Constraint::Index(vec!["age".to_string()]),
        };
        assert!(plan.steps.contains(&expected_index));
    }

    /// `@rename_from` on a type and on a property produces explicit rename steps.
    #[test]
    fn plan_supports_explicit_type_and_property_rename() {
        let accepted = schema_ir!(
            r#"
node User {
name: String @key
}
"#
        );
        let desired = schema_ir!(
            r#"
node Account @rename_from("User") {
full_name: String @key @rename_from("name")
}
"#
        );

        let plan = plan_schema_migration(&accepted, &desired).unwrap();
        assert!(plan.supported);

        let expected_type_rename = RenameType {
            type_kind: SchemaTypeKind::Node,
            from: "User".to_string(),
            to: "Account".to_string(),
        };
        assert!(plan.steps.contains(&expected_type_rename));

        let expected_property_rename = RenameProperty {
            type_kind: SchemaTypeKind::Node,
            type_name: "Account".to_string(),
            from: "name".to_string(),
            to: "full_name".to_string(),
        };
        assert!(plan.steps.contains(&expected_property_rename));
    }

    /// Adding a required (non-nullable) property is reported as unsupported.
    #[test]
    fn plan_rejects_required_property_addition() {
        let accepted = schema_ir!(
            r#"
node Person {
name: String @key
}
"#
        );
        let desired = schema_ir!(
            r#"
node Person {
name: String @key
age: I32
}
"#
        );

        let plan = plan_schema_migration(&accepted, &desired).unwrap();
        assert!(!plan.supported);

        let rejected = plan.steps.iter().any(|step| {
            matches!(
                step,
                UnsupportedChange { entity, reason }
                    if entity.contains("Person.age")
                        && reason.contains("adding required property")
            )
        });
        assert!(rejected);
    }

    /// Changing only `@description` yields a metadata update step.
    #[test]
    fn plan_supports_metadata_only_annotation_changes() {
        let accepted = schema_ir!(
            r#"
node Person @description("old") {
name: String @key
}
"#
        );
        let desired = schema_ir!(
            r#"
node Person @description("new") {
name: String @key
}
"#
        );

        let plan = plan_schema_migration(&accepted, &desired).unwrap();
        assert!(plan.supported);

        let expected_update = UpdateTypeMetadata {
            type_kind: SchemaTypeKind::Node,
            name: "Person".to_string(),
            annotations: vec![Annotation {
                name: "description".to_string(),
                value: Some("new".to_string()),
            }],
        };
        assert!(plan.steps.contains(&expected_update));
    }
}
|
||||
379
crates/omnigraph-compiler/src/embedding.rs
Normal file
379
crates/omnigraph-compiler/src/embedding.rs
Normal file
|
|
@ -0,0 +1,379 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::error::{NanoError, Result};
|
||||
|
||||
/// Embedding model used when `NANOGRAPH_EMBED_MODEL` is unset or blank.
const DEFAULT_EMBED_MODEL: &str = "text-embedding-3-small";

/// API base URL used when `OPENAI_BASE_URL` is unset or empty.
const DEFAULT_OPENAI_BASE_URL: &str = "https://api.openai.com/v1";

/// HTTP client timeout in milliseconds when `NANOGRAPH_EMBED_TIMEOUT_MS` is not usable.
const DEFAULT_TIMEOUT_MS: u64 = 30_000;

/// Maximum attempts per embedding request when `NANOGRAPH_EMBED_RETRY_ATTEMPTS` is not usable.
const DEFAULT_RETRY_ATTEMPTS: usize = 4;

/// Base retry delay in milliseconds when `NANOGRAPH_EMBED_RETRY_BACKOFF_MS` is not usable.
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
|
||||
|
||||
#[derive(Clone)]
|
||||
enum EmbeddingTransport {
|
||||
Mock,
|
||||
OpenAi {
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
http: Client,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct EmbeddingClient {
|
||||
model: String,
|
||||
retry_attempts: usize,
|
||||
retry_backoff_ms: u64,
|
||||
transport: EmbeddingTransport,
|
||||
}
|
||||
|
||||
/// Failure of a single embedding request attempt.
struct EmbedCallError {
    /// Human-readable description, surfaced to the caller if no retry succeeds.
    message: String,
    /// Whether the retry loop may attempt the request again.
    retryable: bool,
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiEmbeddingResponse {
|
||||
data: Vec<OpenAiEmbeddingDatum>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiEmbeddingDatum {
|
||||
index: usize,
|
||||
embedding: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiErrorEnvelope {
|
||||
error: OpenAiErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct OpenAiErrorBody {
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl EmbeddingClient {
|
||||
pub(crate) fn from_env() -> Result<Self> {
|
||||
let model = std::env::var("NANOGRAPH_EMBED_MODEL")
|
||||
.ok()
|
||||
.map(|v| v.trim().to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.unwrap_or_else(|| DEFAULT_EMBED_MODEL.to_string());
|
||||
let retry_attempts =
|
||||
parse_env_usize("NANOGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
|
||||
let retry_backoff_ms =
|
||||
parse_env_u64("NANOGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
|
||||
|
||||
if env_flag("NANOGRAPH_EMBEDDINGS_MOCK") {
|
||||
return Ok(Self {
|
||||
model,
|
||||
retry_attempts,
|
||||
retry_backoff_ms,
|
||||
transport: EmbeddingTransport::Mock,
|
||||
});
|
||||
}
|
||||
|
||||
let api_key = std::env::var("OPENAI_API_KEY")
|
||||
.ok()
|
||||
.map(|v| v.trim().to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.ok_or_else(|| {
|
||||
NanoError::Execution(
|
||||
"OPENAI_API_KEY is required when an embedding call is needed".to_string(),
|
||||
)
|
||||
})?;
|
||||
let base_url = std::env::var("OPENAI_BASE_URL")
|
||||
.ok()
|
||||
.map(|v| v.trim_end_matches('/').to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.unwrap_or_else(|| DEFAULT_OPENAI_BASE_URL.to_string());
|
||||
let timeout_ms = parse_env_u64("NANOGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
|
||||
let http = Client::builder()
|
||||
.timeout(Duration::from_millis(timeout_ms))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
NanoError::Execution(format!("failed to initialize HTTP client: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
model,
|
||||
retry_attempts,
|
||||
retry_backoff_ms,
|
||||
transport: EmbeddingTransport::OpenAi {
|
||||
api_key,
|
||||
base_url,
|
||||
http,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn mock_for_tests() -> Self {
|
||||
Self {
|
||||
model: DEFAULT_EMBED_MODEL.to_string(),
|
||||
retry_attempts: DEFAULT_RETRY_ATTEMPTS,
|
||||
retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
|
||||
transport: EmbeddingTransport::Mock,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn model(&self) -> &str {
|
||||
&self.model
|
||||
}
|
||||
|
||||
pub(crate) async fn embed_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
|
||||
let mut vectors = self.embed_texts(&[input.to_string()], expected_dim).await?;
|
||||
vectors.pop().ok_or_else(|| {
|
||||
NanoError::Execution("embedding provider returned no vector".to_string())
|
||||
})
|
||||
}
|
||||
|
||||
pub(crate) async fn embed_texts(
|
||||
&self,
|
||||
inputs: &[String],
|
||||
expected_dim: usize,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
if expected_dim == 0 {
|
||||
return Err(NanoError::Execution(
|
||||
"embedding dimension must be greater than zero".to_string(),
|
||||
));
|
||||
}
|
||||
if inputs.is_empty() {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
|
||||
match &self.transport {
|
||||
EmbeddingTransport::Mock => Ok(inputs
|
||||
.iter()
|
||||
.map(|input| mock_embedding(input, expected_dim))
|
||||
.collect()),
|
||||
EmbeddingTransport::OpenAi { .. } => {
|
||||
self.embed_texts_openai_with_retry(inputs, expected_dim)
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn embed_texts_openai_with_retry(
|
||||
&self,
|
||||
inputs: &[String],
|
||||
expected_dim: usize,
|
||||
) -> Result<Vec<Vec<f32>>> {
|
||||
let max_attempt = self.retry_attempts.max(1);
|
||||
let mut attempt = 0usize;
|
||||
loop {
|
||||
attempt += 1;
|
||||
match self.embed_texts_openai_once(inputs, expected_dim).await {
|
||||
Ok(vectors) => return Ok(vectors),
|
||||
Err(err) => {
|
||||
if !err.retryable || attempt >= max_attempt {
|
||||
return Err(NanoError::Execution(err.message));
|
||||
}
|
||||
let shift = (attempt - 1).min(10) as u32;
|
||||
let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn embed_texts_openai_once(
|
||||
&self,
|
||||
inputs: &[String],
|
||||
expected_dim: usize,
|
||||
) -> std::result::Result<Vec<Vec<f32>>, EmbedCallError> {
|
||||
let (api_key, base_url, http) = match &self.transport {
|
||||
EmbeddingTransport::OpenAi {
|
||||
api_key,
|
||||
base_url,
|
||||
http,
|
||||
} => (api_key, base_url, http),
|
||||
EmbeddingTransport::Mock => unreachable!("mock transport should not call OpenAI"),
|
||||
};
|
||||
|
||||
let request = serde_json::json!({
|
||||
"model": self.model,
|
||||
"input": inputs,
|
||||
"dimensions": expected_dim,
|
||||
});
|
||||
let url = format!("{}/embeddings", base_url);
|
||||
let response = http
|
||||
.post(&url)
|
||||
.bearer_auth(api_key)
|
||||
.json(&request)
|
||||
.send()
|
||||
.await;
|
||||
|
||||
let response = match response {
|
||||
Ok(resp) => resp,
|
||||
Err(err) => {
|
||||
let retryable = err.is_timeout() || err.is_connect() || err.is_request();
|
||||
return Err(EmbedCallError {
|
||||
message: format!("embedding request failed: {}", err),
|
||||
retryable,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
let body = match response.text().await {
|
||||
Ok(body) => body,
|
||||
Err(err) => {
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding response read failed (status {}): {}",
|
||||
status, err
|
||||
),
|
||||
retryable: status.is_server_error() || status.as_u16() == 429,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if !status.is_success() {
|
||||
let message = parse_openai_error_message(&body).unwrap_or_else(|| body.clone());
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding request failed with status {}: {}",
|
||||
status, message
|
||||
),
|
||||
retryable: status.is_server_error() || status.as_u16() == 429,
|
||||
});
|
||||
}
|
||||
|
||||
let mut parsed: OpenAiEmbeddingResponse =
|
||||
serde_json::from_str(&body).map_err(|err| EmbedCallError {
|
||||
message: format!("embedding response decode failed: {}", err),
|
||||
retryable: false,
|
||||
})?;
|
||||
|
||||
if parsed.data.len() != inputs.len() {
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding response size mismatch: expected {}, got {}",
|
||||
inputs.len(),
|
||||
parsed.data.len()
|
||||
),
|
||||
retryable: false,
|
||||
});
|
||||
}
|
||||
|
||||
parsed.data.sort_by_key(|item| item.index);
|
||||
let mut vectors = Vec::with_capacity(parsed.data.len());
|
||||
for (idx, item) in parsed.data.into_iter().enumerate() {
|
||||
if item.index != idx {
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding response index mismatch at position {}: got {}",
|
||||
idx, item.index
|
||||
),
|
||||
retryable: false,
|
||||
});
|
||||
}
|
||||
if item.embedding.len() != expected_dim {
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding dimension mismatch: expected {}, got {}",
|
||||
expected_dim,
|
||||
item.embedding.len()
|
||||
),
|
||||
retryable: false,
|
||||
});
|
||||
}
|
||||
vectors.push(item.embedding);
|
||||
}
|
||||
Ok(vectors)
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_openai_error_message(body: &str) -> Option<String> {
|
||||
serde_json::from_str::<OpenAiErrorEnvelope>(body)
|
||||
.ok()
|
||||
.map(|e| e.error.message)
|
||||
.filter(|msg| !msg.trim().is_empty())
|
||||
}
|
||||
|
||||
/// Reads a positive `usize` from the environment variable `name`.
///
/// Surrounding whitespace is ignored. Returns `default` when the variable is
/// unset, is not valid Unicode, does not parse as `usize`, or is zero.
fn parse_env_usize(name: &str, default: usize) -> usize {
    std::env::var(name)
        .ok()
        // Trim first so a value such as "4 " is not silently discarded.
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}
|
||||
|
||||
/// Reads a positive `u64` from the environment variable `name`.
///
/// Surrounding whitespace is ignored. Returns `default` when the variable is
/// unset, is not valid Unicode, does not parse as `u64`, or is zero.
fn parse_env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        // Trim first so a value such as "200 " is not silently discarded.
        .and_then(|v| v.trim().parse::<u64>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}
|
||||
|
||||
/// Reports whether the environment variable `name` holds a truthy value.
///
/// After trimming and ASCII-lowercasing, the accepted values are `1`, `true`,
/// `yes` and `on`. Unset or unreadable variables count as false.
fn env_flag(name: &str) -> bool {
    match std::env::var(name) {
        Ok(raw) => {
            let normalized = raw.trim().to_ascii_lowercase();
            matches!(normalized.as_str(), "1" | "true" | "yes" | "on")
        }
        Err(_) => false,
    }
}
|
||||
|
||||
fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
|
||||
let mut seed = fnv1a64(input.as_bytes());
|
||||
let mut out = Vec::with_capacity(dim);
|
||||
for _ in 0..dim {
|
||||
seed = xorshift64(seed);
|
||||
let ratio = (seed as f64 / u64::MAX as f64) as f32;
|
||||
out.push((ratio * 2.0) - 1.0);
|
||||
}
|
||||
|
||||
let norm = out
|
||||
.iter()
|
||||
.map(|v| (*v as f64) * (*v as f64))
|
||||
.sum::<f64>()
|
||||
.sqrt() as f32;
|
||||
if norm > f32::EPSILON {
|
||||
for value in &mut out {
|
||||
*value /= norm;
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Computes the 64-bit FNV-1a hash of `bytes`.
fn fnv1a64(bytes: &[u8]) -> u64 {
    const OFFSET_BASIS: u64 = 14695981039346656037u64;
    const PRIME: u64 = 1099511628211u64;

    // FNV-1a: xor the byte in first, then multiply by the prime.
    bytes.iter().fold(OFFSET_BASIS, |hash, byte| {
        (hash ^ *byte as u64).wrapping_mul(PRIME)
    })
}
|
||||
|
||||
/// Advances a 64-bit xorshift state by one step (shifts 13, 7, 17).
///
/// A zero state maps to zero.
fn xorshift64(mut x: u64) -> u64 {
    let after_left = x ^ (x << 13);
    let after_right = after_left ^ (after_left >> 7);
    x = after_right ^ (after_right << 17);
    x
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// The mock transport returns the same vector for the same text, a
    /// different vector for different text, and honors the requested dimension.
    #[tokio::test]
    async fn mock_embeddings_are_deterministic() {
        let client = EmbeddingClient::mock_for_tests();

        let first = client.embed_text("alpha", 8).await.unwrap();
        let repeated = client.embed_text("alpha", 8).await.unwrap();
        let other = client.embed_text("beta", 8).await.unwrap();

        assert_eq!(first, repeated);
        assert_ne!(first, other);
        assert_eq!(first.len(), 8);
    }
}
|
||||
146
crates/omnigraph-compiler/src/error.rs
Normal file
146
crates/omnigraph-compiler/src/error.rs
Normal file
|
|
@ -0,0 +1,146 @@
|
|||
use thiserror::Error;
|
||||
|
||||
/// A `start..end` range of positions within source text.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SourceSpan {
    /// Position where the span begins.
    pub start: usize,
    /// Position where the span ends; may equal `start` for a zero-width span.
    pub end: usize,
}
|
||||
|
||||
impl SourceSpan {
|
||||
pub fn new(start: usize, end: usize) -> Self {
|
||||
Self { start, end }
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ParseDiagnostic {
|
||||
pub message: String,
|
||||
pub span: Option<SourceSpan>,
|
||||
}
|
||||
|
||||
impl ParseDiagnostic {
|
||||
pub fn new(message: String, span: Option<SourceSpan>) -> Self {
|
||||
Self { message, span }
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ParseDiagnostic {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "{}", self.message)
|
||||
}
|
||||
}
|
||||
|
||||
// Marker impl: lets `ParseDiagnostic` be used wherever a `std::error::Error` is expected.
impl std::error::Error for ParseDiagnostic {}
|
||||
|
||||
pub fn render_span(span: SourceSpan) -> SourceSpan {
|
||||
SourceSpan {
|
||||
start: span.start,
|
||||
end: span.end.max(span.start.saturating_add(1)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn decode_string_literal(raw: &str) -> Result<String> {
|
||||
let inner = raw
|
||||
.strip_prefix('"')
|
||||
.and_then(|inner| inner.strip_suffix('"'))
|
||||
.unwrap_or(raw);
|
||||
|
||||
let mut decoded = String::with_capacity(inner.len());
|
||||
let mut chars = inner.chars();
|
||||
while let Some(ch) = chars.next() {
|
||||
if ch != '\\' {
|
||||
decoded.push(ch);
|
||||
continue;
|
||||
}
|
||||
|
||||
let escaped = chars
|
||||
.next()
|
||||
.ok_or_else(|| NanoError::Parse("unterminated escape sequence".to_string()))?;
|
||||
match escaped {
|
||||
'"' => decoded.push('"'),
|
||||
'\\' => decoded.push('\\'),
|
||||
'n' => decoded.push('\n'),
|
||||
'r' => decoded.push('\r'),
|
||||
't' => decoded.push('\t'),
|
||||
other => {
|
||||
return Err(NanoError::Parse(format!(
|
||||
"unsupported escape sequence: \\{}",
|
||||
other
|
||||
)));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(decoded)
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum NanoError {
|
||||
#[error("parse error: {0}")]
|
||||
Parse(String),
|
||||
|
||||
#[error("catalog error: {0}")]
|
||||
Catalog(String),
|
||||
|
||||
#[error("type error: {0}")]
|
||||
Type(String),
|
||||
|
||||
#[error("storage error: {0}")]
|
||||
Storage(String),
|
||||
|
||||
#[error(
|
||||
"@unique constraint violation on {type_name}.{property}: duplicate value '{value}' at rows {first_row} and {second_row}"
|
||||
)]
|
||||
UniqueConstraint {
|
||||
type_name: String,
|
||||
property: String,
|
||||
value: String,
|
||||
first_row: usize,
|
||||
second_row: usize,
|
||||
},
|
||||
|
||||
#[error("plan error: {0}")]
|
||||
Plan(String),
|
||||
|
||||
#[error("execution error: {0}")]
|
||||
Execution(String),
|
||||
|
||||
#[error(transparent)]
|
||||
Arrow(#[from] arrow_schema::ArrowError),
|
||||
|
||||
#[error("io error: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
|
||||
#[error("lance error: {0}")]
|
||||
Lance(String),
|
||||
|
||||
#[error("manifest error: {0}")]
|
||||
Manifest(String),
|
||||
}
|
||||
|
||||
/// Result alias whose error type is [`NanoError`].
pub type Result<T> = std::result::Result<T, NanoError>;
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::{SourceSpan, decode_string_literal, render_span};

    /// Constructing a span must not alter a zero-width range.
    #[test]
    fn source_span_preserves_zero_width() {
        let SourceSpan { start, end } = SourceSpan::new(7, 7);
        assert_eq!(start, 7);
        assert_eq!(end, 7);
    }

    /// Rendering widens a zero-width span to one position.
    #[test]
    fn render_span_widens_zero_width_for_diagnostics() {
        let zero_width = SourceSpan::new(7, 7);
        let rendered = render_span(zero_width);
        assert_eq!(rendered.start, 7);
        assert_eq!(rendered.end, 8);
    }

    /// All supported escapes decode to their literal characters.
    #[test]
    fn decode_string_literal_supports_common_escapes() {
        let literal = "\"a\\n\\r\\t\\\\\\\"b\"";
        let decoded = decode_string_literal(literal).unwrap();
        assert_eq!(decoded, "a\n\r\t\\\"b");
    }
}
|
||||
657
crates/omnigraph-compiler/src/ir/lower.rs
Normal file
657
crates/omnigraph-compiler/src/ir/lower.rs
Normal file
|
|
@ -0,0 +1,657 @@
|
|||
use std::collections::HashSet;
|
||||
|
||||
use crate::catalog::Catalog;
|
||||
use crate::error::Result;
|
||||
use crate::query::ast::*;
|
||||
use crate::query::typecheck::TypeContext;
|
||||
use crate::types::Direction;
|
||||
|
||||
use super::*;
|
||||
|
||||
pub fn lower_query(
|
||||
catalog: &Catalog,
|
||||
query: &QueryDecl,
|
||||
type_ctx: &TypeContext,
|
||||
) -> Result<QueryIR> {
|
||||
if query.mutation.is_some() {
|
||||
return Err(crate::error::NanoError::Plan(
|
||||
"cannot lower mutation query with read-query lowerer".to_string(),
|
||||
));
|
||||
}
|
||||
let param_names: HashSet<String> = query.params.iter().map(|p| p.name.clone()).collect();
|
||||
|
||||
let mut pipeline = Vec::new();
|
||||
let mut bound_vars = HashSet::new();
|
||||
|
||||
lower_clauses(
|
||||
catalog,
|
||||
&query.match_clause,
|
||||
type_ctx,
|
||||
&mut pipeline,
|
||||
&mut bound_vars,
|
||||
¶m_names,
|
||||
)?;
|
||||
|
||||
let return_exprs: Vec<IRProjection> = query
|
||||
.return_clause
|
||||
.iter()
|
||||
.map(|p| IRProjection {
|
||||
expr: lower_expr(&p.expr, ¶m_names),
|
||||
alias: p.alias.clone(),
|
||||
})
|
||||
.collect();
|
||||
|
||||
let order_by: Vec<IROrdering> = query
|
||||
.order_clause
|
||||
.iter()
|
||||
.map(|o| IROrdering {
|
||||
expr: lower_expr(&o.expr, ¶m_names),
|
||||
descending: o.descending,
|
||||
})
|
||||
.collect();
|
||||
|
||||
Ok(QueryIR {
|
||||
name: query.name.clone(),
|
||||
params: query.params.clone(),
|
||||
pipeline,
|
||||
return_exprs,
|
||||
order_by,
|
||||
limit: query.limit,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn lower_mutation_query(query: &QueryDecl) -> Result<MutationIR> {
|
||||
let mutation = query.mutation.as_ref().ok_or_else(|| {
|
||||
crate::error::NanoError::Plan("query does not contain a mutation body".to_string())
|
||||
})?;
|
||||
let param_names: HashSet<String> = query.params.iter().map(|p| p.name.clone()).collect();
|
||||
|
||||
let op = match mutation {
|
||||
Mutation::Insert(insert) => MutationOpIR::Insert {
|
||||
type_name: insert.type_name.clone(),
|
||||
assignments: insert
|
||||
.assignments
|
||||
.iter()
|
||||
.map(|a| IRAssignment {
|
||||
property: a.property.clone(),
|
||||
value: lower_match_value(&a.value, ¶m_names),
|
||||
})
|
||||
.collect(),
|
||||
},
|
||||
Mutation::Update(update) => MutationOpIR::Update {
|
||||
type_name: update.type_name.clone(),
|
||||
assignments: update
|
||||
.assignments
|
||||
.iter()
|
||||
.map(|a| IRAssignment {
|
||||
property: a.property.clone(),
|
||||
value: lower_match_value(&a.value, ¶m_names),
|
||||
})
|
||||
.collect(),
|
||||
predicate: IRMutationPredicate {
|
||||
property: update.predicate.property.clone(),
|
||||
op: update.predicate.op,
|
||||
value: lower_match_value(&update.predicate.value, ¶m_names),
|
||||
},
|
||||
},
|
||||
Mutation::Delete(delete) => MutationOpIR::Delete {
|
||||
type_name: delete.type_name.clone(),
|
||||
predicate: IRMutationPredicate {
|
||||
property: delete.predicate.property.clone(),
|
||||
op: delete.predicate.op,
|
||||
value: lower_match_value(&delete.predicate.value, ¶m_names),
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
Ok(MutationIR {
|
||||
name: query.name.clone(),
|
||||
params: query.params.clone(),
|
||||
op,
|
||||
})
|
||||
}
|
||||
|
||||
fn lower_clauses(
|
||||
catalog: &Catalog,
|
||||
clauses: &[Clause],
|
||||
type_ctx: &TypeContext,
|
||||
pipeline: &mut Vec<IROp>,
|
||||
bound_vars: &mut HashSet<String>,
|
||||
param_names: &HashSet<String>,
|
||||
) -> Result<()> {
|
||||
// Separate clause types for ordering: bindings first, then traversals, then filters
|
||||
let mut bindings = Vec::new();
|
||||
let mut traversals = Vec::new();
|
||||
let mut filters = Vec::new();
|
||||
let mut negations = Vec::new();
|
||||
|
||||
for clause in clauses {
|
||||
match clause {
|
||||
Clause::Binding(b) => bindings.push(b),
|
||||
Clause::Traversal(t) => traversals.push(t),
|
||||
Clause::Filter(f) => filters.push(f),
|
||||
Clause::Negation(inner) => negations.push(inner),
|
||||
}
|
||||
}
|
||||
|
||||
// Lower bindings into NodeScan ops
|
||||
for binding in &bindings {
|
||||
let node_type = catalog
|
||||
.node_types
|
||||
.get(&binding.type_name)
|
||||
.expect("binding type was validated during typecheck");
|
||||
// Collect inline filters from prop matches
|
||||
let mut scan_filters = Vec::new();
|
||||
for pm in &binding.prop_matches {
|
||||
let prop = node_type
|
||||
.properties
|
||||
.get(&pm.prop_name)
|
||||
.expect("binding property was validated during typecheck");
|
||||
let op = if prop.list {
|
||||
CompOp::Contains
|
||||
} else {
|
||||
CompOp::Eq
|
||||
};
|
||||
match &pm.value {
|
||||
MatchValue::Literal(lit) => {
|
||||
scan_filters.push(IRFilter {
|
||||
left: IRExpr::PropAccess {
|
||||
variable: binding.variable.clone(),
|
||||
property: pm.prop_name.clone(),
|
||||
},
|
||||
op,
|
||||
right: IRExpr::Literal(lit.clone()),
|
||||
});
|
||||
}
|
||||
MatchValue::Now => {
|
||||
scan_filters.push(IRFilter {
|
||||
left: IRExpr::PropAccess {
|
||||
variable: binding.variable.clone(),
|
||||
property: pm.prop_name.clone(),
|
||||
},
|
||||
op,
|
||||
right: IRExpr::Param(NOW_PARAM_NAME.to_string()),
|
||||
});
|
||||
}
|
||||
MatchValue::Variable(v) => {
|
||||
let right = if param_names.contains(v) {
|
||||
IRExpr::Param(v.clone())
|
||||
} else {
|
||||
IRExpr::Variable(v.clone())
|
||||
};
|
||||
scan_filters.push(IRFilter {
|
||||
left: IRExpr::PropAccess {
|
||||
variable: binding.variable.clone(),
|
||||
property: pm.prop_name.clone(),
|
||||
},
|
||||
op,
|
||||
right,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pipeline.push(IROp::NodeScan {
|
||||
variable: binding.variable.clone(),
|
||||
type_name: binding.type_name.clone(),
|
||||
filters: scan_filters,
|
||||
});
|
||||
bound_vars.insert(binding.variable.clone());
|
||||
}
|
||||
|
||||
// Lower traversals into Expand ops
|
||||
// Handle "cycle closing" — if both src and dst are already bound, use a filter
|
||||
for traversal in &traversals {
|
||||
let edge = catalog
|
||||
.lookup_edge_by_name(&traversal.edge_name)
|
||||
.ok_or_else(|| {
|
||||
crate::error::NanoError::Plan(format!(
|
||||
"lowering traversal referenced missing edge '{}' after typecheck",
|
||||
traversal.edge_name
|
||||
))
|
||||
})?;
|
||||
|
||||
// Determine direction from type context
|
||||
let direction = type_ctx
|
||||
.traversals
|
||||
.iter()
|
||||
.find(|rt| {
|
||||
rt.src == traversal.src && rt.dst == traversal.dst && rt.edge_type == edge.name
|
||||
})
|
||||
.map(|rt| rt.direction)
|
||||
.unwrap_or(Direction::Out);
|
||||
|
||||
let dst_type = match direction {
|
||||
Direction::Out => edge.to_type.clone(),
|
||||
Direction::In => edge.from_type.clone(),
|
||||
};
|
||||
|
||||
if bound_vars.contains(&traversal.src) && bound_vars.contains(&traversal.dst) {
|
||||
// Cycle closing: emit expand to a temp var, then filter temp.id = dst.id
|
||||
let temp_var = format!("__temp_{}", traversal.dst);
|
||||
pipeline.push(IROp::Expand {
|
||||
src_var: traversal.src.clone(),
|
||||
dst_var: temp_var.clone(),
|
||||
edge_type: edge.name.clone(),
|
||||
direction,
|
||||
dst_type,
|
||||
min_hops: traversal.min_hops,
|
||||
max_hops: traversal.max_hops,
|
||||
});
|
||||
pipeline.push(IROp::Filter(IRFilter {
|
||||
left: IRExpr::PropAccess {
|
||||
variable: temp_var,
|
||||
property: "id".to_string(),
|
||||
},
|
||||
op: CompOp::Eq,
|
||||
right: IRExpr::PropAccess {
|
||||
variable: traversal.dst.clone(),
|
||||
property: "id".to_string(),
|
||||
},
|
||||
}));
|
||||
} else if !bound_vars.contains(&traversal.src) && bound_vars.contains(&traversal.dst) {
|
||||
// Reverse expand: dst is bound, src is not.
|
||||
// Swap direction and expand from dst to discover src.
|
||||
let reverse_dir = match direction {
|
||||
Direction::Out => Direction::In,
|
||||
Direction::In => Direction::Out,
|
||||
};
|
||||
let src_type = match direction {
|
||||
Direction::Out => edge.from_type.clone(),
|
||||
Direction::In => edge.to_type.clone(),
|
||||
};
|
||||
pipeline.push(IROp::Expand {
|
||||
src_var: traversal.dst.clone(),
|
||||
dst_var: traversal.src.clone(),
|
||||
edge_type: edge.name.clone(),
|
||||
direction: reverse_dir,
|
||||
dst_type: src_type,
|
||||
min_hops: traversal.min_hops,
|
||||
max_hops: traversal.max_hops,
|
||||
});
|
||||
if traversal.src != "_" {
|
||||
bound_vars.insert(traversal.src.clone());
|
||||
}
|
||||
} else {
|
||||
pipeline.push(IROp::Expand {
|
||||
src_var: traversal.src.clone(),
|
||||
dst_var: traversal.dst.clone(),
|
||||
edge_type: edge.name.clone(),
|
||||
direction,
|
||||
dst_type,
|
||||
min_hops: traversal.min_hops,
|
||||
max_hops: traversal.max_hops,
|
||||
});
|
||||
if traversal.dst != "_" {
|
||||
bound_vars.insert(traversal.dst.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Lower explicit filters
|
||||
for filter in &filters {
|
||||
pipeline.push(IROp::Filter(IRFilter {
|
||||
left: lower_expr(&filter.left, param_names),
|
||||
op: filter.op,
|
||||
right: lower_expr(&filter.right, param_names),
|
||||
}));
|
||||
}
|
||||
|
||||
// Lower negations into AntiJoin ops
|
||||
for neg_clauses in &negations {
|
||||
// Find outer-bound variable referenced in the negation
|
||||
let outer_var = find_outer_var(neg_clauses, bound_vars);
|
||||
|
||||
let mut inner_pipeline = Vec::new();
|
||||
let mut inner_bound = bound_vars.clone();
|
||||
lower_clauses(
|
||||
catalog,
|
||||
neg_clauses,
|
||||
type_ctx,
|
||||
&mut inner_pipeline,
|
||||
&mut inner_bound,
|
||||
param_names,
|
||||
)?;
|
||||
|
||||
pipeline.push(IROp::AntiJoin {
|
||||
outer_var: outer_var.unwrap_or_default(),
|
||||
inner: inner_pipeline,
|
||||
});
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn find_outer_var(clauses: &[Clause], outer_bound: &HashSet<String>) -> Option<String> {
|
||||
for clause in clauses {
|
||||
match clause {
|
||||
Clause::Traversal(t) => {
|
||||
if outer_bound.contains(&t.src) {
|
||||
return Some(t.src.clone());
|
||||
}
|
||||
if outer_bound.contains(&t.dst) {
|
||||
return Some(t.dst.clone());
|
||||
}
|
||||
}
|
||||
Clause::Filter(f) => {
|
||||
if let Some(v) = expr_var(&f.left)
|
||||
&& outer_bound.contains(&v)
|
||||
{
|
||||
return Some(v);
|
||||
}
|
||||
if let Some(v) = expr_var(&f.right)
|
||||
&& outer_bound.contains(&v)
|
||||
{
|
||||
return Some(v);
|
||||
}
|
||||
}
|
||||
Clause::Binding(b) => {
|
||||
if outer_bound.contains(&b.variable) {
|
||||
return Some(b.variable.clone());
|
||||
}
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
None
|
||||
}
|
||||
|
||||
fn expr_var(expr: &Expr) -> Option<String> {
|
||||
match expr {
|
||||
Expr::Now => None,
|
||||
Expr::PropAccess { variable, .. } => Some(variable.clone()),
|
||||
Expr::Variable(v) => Some(v.clone()),
|
||||
Expr::Nearest { variable, .. } => Some(variable.clone()),
|
||||
Expr::Search { field, query } => expr_var(field).or_else(|| expr_var(query)),
|
||||
Expr::Fuzzy {
|
||||
field,
|
||||
query,
|
||||
max_edits,
|
||||
} => expr_var(field)
|
||||
.or_else(|| expr_var(query))
|
||||
.or_else(|| max_edits.as_deref().and_then(expr_var)),
|
||||
Expr::MatchText { field, query } => expr_var(field).or_else(|| expr_var(query)),
|
||||
Expr::Bm25 { field, query } => expr_var(field).or_else(|| expr_var(query)),
|
||||
Expr::Rrf {
|
||||
primary,
|
||||
secondary,
|
||||
k,
|
||||
} => expr_var(primary)
|
||||
.or_else(|| expr_var(secondary))
|
||||
.or_else(|| k.as_deref().and_then(expr_var)),
|
||||
Expr::Aggregate { arg, .. } => expr_var(arg),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn lower_expr(expr: &Expr, param_names: &HashSet<String>) -> IRExpr {
|
||||
match expr {
|
||||
Expr::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
|
||||
Expr::PropAccess { variable, property } => IRExpr::PropAccess {
|
||||
variable: variable.clone(),
|
||||
property: property.clone(),
|
||||
},
|
||||
Expr::Nearest {
|
||||
variable,
|
||||
property,
|
||||
query,
|
||||
} => IRExpr::Nearest {
|
||||
variable: variable.clone(),
|
||||
property: property.clone(),
|
||||
query: Box::new(lower_expr(query, param_names)),
|
||||
},
|
||||
Expr::Search { field, query } => IRExpr::Search {
|
||||
field: Box::new(lower_expr(field, param_names)),
|
||||
query: Box::new(lower_expr(query, param_names)),
|
||||
},
|
||||
Expr::Fuzzy {
|
||||
field,
|
||||
query,
|
||||
max_edits,
|
||||
} => IRExpr::Fuzzy {
|
||||
field: Box::new(lower_expr(field, param_names)),
|
||||
query: Box::new(lower_expr(query, param_names)),
|
||||
max_edits: max_edits
|
||||
.as_ref()
|
||||
.map(|expr| Box::new(lower_expr(expr, param_names))),
|
||||
},
|
||||
Expr::MatchText { field, query } => IRExpr::MatchText {
|
||||
field: Box::new(lower_expr(field, param_names)),
|
||||
query: Box::new(lower_expr(query, param_names)),
|
||||
},
|
||||
Expr::Bm25 { field, query } => IRExpr::Bm25 {
|
||||
field: Box::new(lower_expr(field, param_names)),
|
||||
query: Box::new(lower_expr(query, param_names)),
|
||||
},
|
||||
Expr::Rrf {
|
||||
primary,
|
||||
secondary,
|
||||
k,
|
||||
} => IRExpr::Rrf {
|
||||
primary: Box::new(lower_expr(primary, param_names)),
|
||||
secondary: Box::new(lower_expr(secondary, param_names)),
|
||||
k: k.as_ref()
|
||||
.map(|expr| Box::new(lower_expr(expr, param_names))),
|
||||
},
|
||||
Expr::Variable(v) => {
|
||||
if param_names.contains(v) {
|
||||
IRExpr::Param(v.clone())
|
||||
} else {
|
||||
IRExpr::Variable(v.clone())
|
||||
}
|
||||
}
|
||||
Expr::Literal(l) => IRExpr::Literal(l.clone()),
|
||||
Expr::Aggregate { func, arg } => IRExpr::Aggregate {
|
||||
func: *func,
|
||||
arg: Box::new(lower_expr(arg, param_names)),
|
||||
},
|
||||
Expr::AliasRef(name) => IRExpr::AliasRef(name.clone()),
|
||||
}
|
||||
}
|
||||
|
||||
fn lower_match_value(value: &MatchValue, param_names: &HashSet<String>) -> IRExpr {
|
||||
match value {
|
||||
MatchValue::Now => IRExpr::Param(NOW_PARAM_NAME.to_string()),
|
||||
MatchValue::Literal(l) => IRExpr::Literal(l.clone()),
|
||||
MatchValue::Variable(v) => {
|
||||
if param_names.contains(v) {
|
||||
IRExpr::Param(v.clone())
|
||||
} else {
|
||||
IRExpr::Variable(v.clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use crate::catalog::build_catalog;
    use crate::query::parser::parse_query;
    use crate::query::typecheck::{CheckedQuery, typecheck_query, typecheck_query_decl};
    use crate::schema::parser::parse_schema;

    /// Small people/company schema shared by most lowering tests.
    const PEOPLE_SCHEMA: &str = r#"
node Person { name: String age: I32? }
node Company { name: String }
edge Knows: Person -> Person { since: Date? }
edge WorksAt: Person -> Company
"#;

    /// Builds the catalog for [`PEOPLE_SCHEMA`].
    fn setup() -> Catalog {
        let schema = parse_schema(PEOPLE_SCHEMA).unwrap();
        build_catalog(&schema).unwrap()
    }

    /// A binding plus one traversal lowers to a two-op pipeline.
    #[test]
    fn test_lower_basic() {
        let catalog = setup();
        let parsed = parse_query(
            r#"
query q($name: String) {
match {
$p: Person { name: $name }
$p knows $f
}
return { $f.name, $f.age }
}
"#,
        )
        .unwrap();
        let query = &parsed.queries[0];
        let type_ctx = typecheck_query(&catalog, query).unwrap();
        let lowered = lower_query(&catalog, query, &type_ctx).unwrap();

        // NodeScan + Expand
        assert_eq!(lowered.pipeline.len(), 2);
        assert_eq!(lowered.return_exprs.len(), 2);
    }

    /// A `not { ... }` clause lowers to an anti-join after the scan.
    #[test]
    fn test_lower_negation() {
        let catalog = setup();
        let parsed = parse_query(
            r#"
query q() {
match {
$p: Person
not { $p worksAt $_ }
}
return { $p.name }
}
"#,
        )
        .unwrap();
        let query = &parsed.queries[0];
        let type_ctx = typecheck_query(&catalog, query).unwrap();
        let lowered = lower_query(&catalog, query, &type_ctx).unwrap();

        // NodeScan + AntiJoin
        assert_eq!(lowered.pipeline.len(), 2);
        assert!(matches!(&lowered.pipeline[1], IROp::AntiJoin { .. }));
    }

    /// An `update` statement lowers to `MutationOpIR::Update` with its
    /// assignment and predicate preserved.
    #[test]
    fn test_lower_mutation_update() {
        let catalog = setup();
        let parsed = parse_query(
            r#"
query q($name: String, $age: I32) {
update Person set { age: $age } where name = $name
}
"#,
        )
        .unwrap();
        let query = &parsed.queries[0];
        let checked = typecheck_query_decl(&catalog, query).unwrap();
        assert!(matches!(checked, CheckedQuery::Mutation(_)));

        let lowered = lower_mutation_query(query).unwrap();
        let MutationOpIR::Update {
            type_name,
            assignments,
            predicate,
        } = lowered.op
        else {
            panic!("expected update mutation op")
        };
        assert_eq!(type_name, "Person");
        assert_eq!(assignments.len(), 1);
        assert_eq!(assignments[0].property, "age");
        assert_eq!(predicate.property, "name");
    }

    /// Hop bounds written as `{1,3}` survive lowering on the expand op.
    #[test]
    fn test_lower_bounded_traversal() {
        let catalog = setup();
        let parsed = parse_query(
            r#"
query q() {
match {
$p: Person
$p knows{1,3} $f
}
return { $f.name }
}
"#,
        )
        .unwrap();
        let query = &parsed.queries[0];
        let type_ctx = typecheck_query(&catalog, query).unwrap();
        let lowered = lower_query(&catalog, query, &type_ctx).unwrap();

        let (min_hops, max_hops) = lowered
            .pipeline
            .iter()
            .find_map(|op| match op {
                IROp::Expand {
                    min_hops, max_hops, ..
                } => Some((*min_hops, *max_hops)),
                _ => None,
            })
            .expect("expected expand op");
        assert_eq!(min_hops, 1);
        assert_eq!(max_hops, Some(3));
    }

    /// `now()` in a projection lowers to the reserved runtime parameter.
    #[test]
    fn test_lower_now_uses_reserved_runtime_param() {
        let catalog = setup();
        let parsed = parse_query(
            r#"
query stamp() {
match { $p: Person }
return { now() as ts }
}
"#,
        )
        .unwrap();
        let query = &parsed.queries[0];
        let type_ctx = typecheck_query(&catalog, query).unwrap();
        let lowered = lower_query(&catalog, query, &type_ctx).unwrap();

        assert!(matches!(
            lowered.return_exprs[0].expr,
            IRExpr::Param(ref name) if name == NOW_PARAM_NAME
        ));
    }

    /// `now()` in both the assignment and the predicate of an update lowers to
    /// the reserved runtime parameter.
    #[test]
    fn test_lower_mutation_now_uses_reserved_runtime_param() {
        let schema = parse_schema(
            r#"
node Event {
slug: String @key
updated_at: DateTime?
}
"#,
        )
        .unwrap();
        let catalog = build_catalog(&schema).unwrap();
        let parsed = parse_query(
            r#"
query stamp() {
update Event set { updated_at: now() } where updated_at = now()
}
"#,
        )
        .unwrap();
        let query = &parsed.queries[0];
        let checked = typecheck_query_decl(&catalog, query).unwrap();
        assert!(matches!(checked, CheckedQuery::Mutation(_)));

        let lowered = lower_mutation_query(query).unwrap();
        let MutationOpIR::Update {
            assignments,
            predicate,
            ..
        } = lowered.op
        else {
            panic!("expected update mutation op")
        };
        assert!(matches!(
            assignments[0].value,
            IRExpr::Param(ref name) if name == NOW_PARAM_NAME
        ));
        assert!(matches!(
            predicate.value,
            IRExpr::Param(ref name) if name == NOW_PARAM_NAME
        ));
    }
}
|
||||
143
crates/omnigraph-compiler/src/ir/mod.rs
Normal file
143
crates/omnigraph-compiler/src/ir/mod.rs
Normal file
|
|
@ -0,0 +1,143 @@
|
|||
pub(crate) mod lower;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::query::ast::{AggFunc, CompOp, Literal, Param};
|
||||
use crate::types::Direction;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QueryIR {
|
||||
pub name: String,
|
||||
pub params: Vec<Param>,
|
||||
pub pipeline: Vec<IROp>,
|
||||
pub return_exprs: Vec<IRProjection>,
|
||||
pub order_by: Vec<IROrdering>,
|
||||
pub limit: Option<u64>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MutationIR {
|
||||
pub name: String,
|
||||
pub params: Vec<Param>,
|
||||
pub op: MutationOpIR,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MutationOpIR {
|
||||
Insert {
|
||||
type_name: String,
|
||||
assignments: Vec<IRAssignment>,
|
||||
},
|
||||
Update {
|
||||
type_name: String,
|
||||
assignments: Vec<IRAssignment>,
|
||||
predicate: IRMutationPredicate,
|
||||
},
|
||||
Delete {
|
||||
type_name: String,
|
||||
predicate: IRMutationPredicate,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IRAssignment {
|
||||
pub property: String,
|
||||
pub value: IRExpr,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IRMutationPredicate {
|
||||
pub property: String,
|
||||
pub op: CompOp,
|
||||
pub value: IRExpr,
|
||||
}
|
||||
|
||||
/// Resolved runtime parameters: param name → literal value.
|
||||
pub type ParamMap = HashMap<String, Literal>;
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum IROp {
|
||||
NodeScan {
|
||||
variable: String,
|
||||
type_name: String,
|
||||
filters: Vec<IRFilter>,
|
||||
},
|
||||
Expand {
|
||||
src_var: String,
|
||||
dst_var: String,
|
||||
edge_type: String,
|
||||
direction: Direction,
|
||||
dst_type: String,
|
||||
min_hops: u32,
|
||||
max_hops: Option<u32>,
|
||||
},
|
||||
Filter(IRFilter),
|
||||
AntiJoin {
|
||||
/// The outer variable whose id is used for the join key
|
||||
outer_var: String,
|
||||
/// The inner pipeline that produces rows to anti-join against
|
||||
inner: Vec<IROp>,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IRFilter {
|
||||
pub left: IRExpr,
|
||||
pub op: CompOp,
|
||||
pub right: IRExpr,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum IRExpr {
|
||||
PropAccess {
|
||||
variable: String,
|
||||
property: String,
|
||||
},
|
||||
Nearest {
|
||||
variable: String,
|
||||
property: String,
|
||||
query: Box<IRExpr>,
|
||||
},
|
||||
Search {
|
||||
field: Box<IRExpr>,
|
||||
query: Box<IRExpr>,
|
||||
},
|
||||
Fuzzy {
|
||||
field: Box<IRExpr>,
|
||||
query: Box<IRExpr>,
|
||||
max_edits: Option<Box<IRExpr>>,
|
||||
},
|
||||
MatchText {
|
||||
field: Box<IRExpr>,
|
||||
query: Box<IRExpr>,
|
||||
},
|
||||
Bm25 {
|
||||
field: Box<IRExpr>,
|
||||
query: Box<IRExpr>,
|
||||
},
|
||||
Rrf {
|
||||
primary: Box<IRExpr>,
|
||||
secondary: Box<IRExpr>,
|
||||
k: Option<Box<IRExpr>>,
|
||||
},
|
||||
Variable(String),
|
||||
Param(String),
|
||||
Literal(Literal),
|
||||
Aggregate {
|
||||
func: AggFunc,
|
||||
arg: Box<IRExpr>,
|
||||
},
|
||||
AliasRef(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IRProjection {
|
||||
pub expr: IRExpr,
|
||||
pub alias: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct IROrdering {
|
||||
pub expr: IRExpr,
|
||||
pub descending: bool,
|
||||
}
|
||||
352
crates/omnigraph-compiler/src/json_output.rs
Normal file
352
crates/omnigraph-compiler/src/json_output.rs
Normal file
|
|
@ -0,0 +1,352 @@
|
|||
use arrow_array::{
|
||||
Array, ArrayRef, BooleanArray, Date32Array, Date64Array, FixedSizeListArray, Float32Array,
|
||||
Float64Array, Int32Array, Int64Array, ListArray, RecordBatch, StringArray, StructArray,
|
||||
UInt32Array, UInt64Array,
|
||||
};
|
||||
use arrow_schema::DataType;
|
||||
|
||||
/// Largest integer magnitude a JavaScript `Number` represents exactly
/// (2^53 - 1), as an `i64`.
pub const JS_MAX_SAFE_INTEGER_I64: i64 = (1 << 53) - 1;
/// Same bound as [`JS_MAX_SAFE_INTEGER_I64`], as a `u64`.
pub const JS_MAX_SAFE_INTEGER_U64: u64 = (1 << 53) - 1;

/// How 64-bit integers are rendered in JSON output.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum JsonIntegerMode {
    /// Values outside the JS-safe range are emitted as strings.
    JavaScript,
    /// Values are always emitted as JSON numbers.
    Native,
}

/// Reports whether `value` lies in the inclusive range
/// `-(2^53 - 1) ..= 2^53 - 1`, i.e. survives a round-trip through a
/// JavaScript `Number` without loss.
pub fn is_js_safe_integer_i64(value: i64) -> bool {
    // `unsigned_abs` is total (it handles `i64::MIN`), so a single magnitude
    // comparison covers both ends of the symmetric range.
    value.unsigned_abs() <= JS_MAX_SAFE_INTEGER_U64
}
|
||||
|
||||
/// Convert Arrow RecordBatches into a Vec of JSON objects (one per row).
|
||||
pub fn record_batches_to_json_rows(results: &[RecordBatch]) -> Vec<serde_json::Value> {
|
||||
record_batches_to_json_rows_with_mode(results, JsonIntegerMode::JavaScript)
|
||||
}
|
||||
|
||||
/// Convert Arrow RecordBatches into JSON rows without JS-safe integer coercion.
|
||||
pub fn record_batches_to_rust_json_rows(results: &[RecordBatch]) -> Vec<serde_json::Value> {
|
||||
record_batches_to_json_rows_with_mode(results, JsonIntegerMode::Native)
|
||||
}
|
||||
|
||||
fn record_batches_to_json_rows_with_mode(
|
||||
results: &[RecordBatch],
|
||||
integer_mode: JsonIntegerMode,
|
||||
) -> Vec<serde_json::Value> {
|
||||
let total_rows = results.iter().map(RecordBatch::num_rows).sum();
|
||||
let mut out = Vec::with_capacity(total_rows);
|
||||
for batch in results {
|
||||
let schema = batch.schema();
|
||||
for row in 0..batch.num_rows() {
|
||||
let mut map = serde_json::Map::new();
|
||||
for (col_idx, field) in schema.fields().iter().enumerate() {
|
||||
let col_arr = batch.column(col_idx);
|
||||
map.insert(
|
||||
field.name().clone(),
|
||||
array_value_to_json_with_mode(col_arr, row, integer_mode),
|
||||
);
|
||||
}
|
||||
out.push(serde_json::Value::Object(map));
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
/// Convert a single cell from an Arrow array to a serde_json::Value.
|
||||
pub fn array_value_to_json(array: &ArrayRef, row: usize) -> serde_json::Value {
|
||||
array_value_to_json_with_mode(array, row, JsonIntegerMode::JavaScript)
|
||||
}
|
||||
|
||||
fn array_value_to_json_with_mode(
|
||||
array: &ArrayRef,
|
||||
row: usize,
|
||||
integer_mode: JsonIntegerMode,
|
||||
) -> serde_json::Value {
|
||||
if array.is_null(row) {
|
||||
return serde_json::Value::Null;
|
||||
}
|
||||
|
||||
match array.data_type() {
|
||||
DataType::Utf8 => array
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.map(|a| serde_json::Value::String(a.value(row).to_string()))
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::Boolean => array
|
||||
.as_any()
|
||||
.downcast_ref::<BooleanArray>()
|
||||
.map(|a| serde_json::Value::Bool(a.value(row)))
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::Int32 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.map(|a| serde_json::Value::Number((a.value(row) as i64).into()))
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::Int64 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.map(|a| {
|
||||
let value = a.value(row);
|
||||
match integer_mode {
|
||||
JsonIntegerMode::JavaScript if !is_js_safe_integer_i64(value) => {
|
||||
serde_json::Value::String(value.to_string())
|
||||
}
|
||||
JsonIntegerMode::JavaScript | JsonIntegerMode::Native => {
|
||||
serde_json::Value::Number(value.into())
|
||||
}
|
||||
}
|
||||
})
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::UInt32 => array
|
||||
.as_any()
|
||||
.downcast_ref::<UInt32Array>()
|
||||
.map(|a| serde_json::Value::Number((a.value(row) as u64).into()))
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::UInt64 => array
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.map(|a| {
|
||||
let value = a.value(row);
|
||||
match integer_mode {
|
||||
JsonIntegerMode::JavaScript if value > JS_MAX_SAFE_INTEGER_U64 => {
|
||||
serde_json::Value::String(value.to_string())
|
||||
}
|
||||
JsonIntegerMode::JavaScript | JsonIntegerMode::Native => {
|
||||
serde_json::Value::Number(value.into())
|
||||
}
|
||||
}
|
||||
})
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::Float32 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Float32Array>()
|
||||
.map(|a| json_float_value(a.value(row) as f64))
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::Float64 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Float64Array>()
|
||||
.map(|a| json_float_value(a.value(row)))
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::Date32 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Date32Array>()
|
||||
.map(|a| {
|
||||
let days = a.value(row);
|
||||
arrow_array::temporal_conversions::date32_to_datetime(days)
|
||||
.map(|dt| serde_json::Value::String(dt.format("%Y-%m-%d").to_string()))
|
||||
.unwrap_or_else(|| serde_json::Value::Number((days as i64).into()))
|
||||
})
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::Date64 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Date64Array>()
|
||||
.map(|a| {
|
||||
let ms = a.value(row);
|
||||
arrow_array::temporal_conversions::date64_to_datetime(ms)
|
||||
.map(|dt| {
|
||||
serde_json::Value::String(dt.format("%Y-%m-%dT%H:%M:%S%.3fZ").to_string())
|
||||
})
|
||||
.unwrap_or_else(|| serde_json::Value::Number(ms.into()))
|
||||
})
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::List(_) => array
|
||||
.as_any()
|
||||
.downcast_ref::<ListArray>()
|
||||
.map(|a| {
|
||||
let values = a.value(row);
|
||||
serde_json::Value::Array(
|
||||
(0..values.len())
|
||||
.map(|idx| array_value_to_json_with_mode(&values, idx, integer_mode))
|
||||
.collect(),
|
||||
)
|
||||
})
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::FixedSizeList(_, _) => array
|
||||
.as_any()
|
||||
.downcast_ref::<FixedSizeListArray>()
|
||||
.map(|a| fixed_size_list_value_to_json(a, row, integer_mode))
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
DataType::Struct(_) => array
|
||||
.as_any()
|
||||
.downcast_ref::<StructArray>()
|
||||
.map(|struct_arr| {
|
||||
let mut obj = serde_json::Map::new();
|
||||
for (i, field) in struct_arr.fields().iter().enumerate() {
|
||||
let col = struct_arr.column(i);
|
||||
obj.insert(
|
||||
field.name().clone(),
|
||||
array_value_to_json_with_mode(col, row, integer_mode),
|
||||
);
|
||||
}
|
||||
serde_json::Value::Object(obj)
|
||||
})
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
_ => {
|
||||
let display =
|
||||
arrow_cast::display::array_value_to_string(array, row).unwrap_or_default();
|
||||
serde_json::Value::String(display)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders a float as JSON.
///
/// JSON has no representation for non-finite numbers, so NaN and the two
/// infinities are emitted as the strings `"NaN"`, `"Infinity"` and
/// `"-Infinity"`. Every other value becomes a JSON number.
fn json_float_value(value: f64) -> serde_json::Value {
    let label = if value.is_nan() {
        "NaN"
    } else if value == f64::INFINITY {
        "Infinity"
    } else if value == f64::NEG_INFINITY {
        "-Infinity"
    } else {
        // `from_f64` only rejects non-finite input, which is handled above;
        // the `Null` fallback is kept for parity with the rest of the module.
        return serde_json::Number::from_f64(value)
            .map_or(serde_json::Value::Null, serde_json::Value::Number);
    };
    serde_json::Value::String(label.to_string())
}
|
||||
|
||||
fn fixed_size_list_value_to_json(
|
||||
array: &FixedSizeListArray,
|
||||
row: usize,
|
||||
integer_mode: JsonIntegerMode,
|
||||
) -> serde_json::Value {
|
||||
let value_len = array.value_length() as usize;
|
||||
let values = array.values();
|
||||
if let Some(float_values) = values.as_any().downcast_ref::<Float32Array>() {
|
||||
let start = row.saturating_mul(value_len);
|
||||
return float32_json_array(float_values, start, value_len);
|
||||
}
|
||||
|
||||
let values = array.value(row);
|
||||
serde_json::Value::Array(
|
||||
(0..values.len())
|
||||
.map(|idx| array_value_to_json_with_mode(&values, idx, integer_mode))
|
||||
.collect(),
|
||||
)
|
||||
}
|
||||
|
||||
fn float32_json_array(values: &Float32Array, start: usize, len: usize) -> serde_json::Value {
|
||||
let mut out = Vec::with_capacity(len);
|
||||
let end = start.saturating_add(len).min(values.len());
|
||||
for idx in start..end {
|
||||
if values.is_null(idx) {
|
||||
out.push(serde_json::Value::Null);
|
||||
continue;
|
||||
}
|
||||
let value = values.value(idx) as f64;
|
||||
out.push(
|
||||
serde_json::Number::from_f64(value)
|
||||
.map(serde_json::Value::Number)
|
||||
.unwrap_or(serde_json::Value::Null),
|
||||
);
|
||||
}
|
||||
serde_json::Value::Array(out)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::{array_value_to_json, record_batches_to_rust_json_rows};
    use std::sync::Arc;

    use arrow_array::builder::{FixedSizeListBuilder, Float32Builder};
    use arrow_array::{ArrayRef, Float64Array, Int64Array, RecordBatch, UInt64Array};
    use arrow_schema::{DataType, Field, Schema};

    /// 2^53 as a signed value is one past the JS-safe range and must be a string.
    #[test]
    fn int64_outside_js_safe_range_is_stringified() {
        let column: ArrayRef = Arc::new(Int64Array::from(vec![Some(9_007_199_254_740_992)]));
        let expected = serde_json::Value::String("9007199254740992".to_string());
        assert_eq!(array_value_to_json(&column, 0), expected);
    }

    /// 2^53 as an unsigned value is one past the JS-safe range and must be a string.
    #[test]
    fn uint64_outside_js_safe_range_is_stringified() {
        let column: ArrayRef = Arc::new(UInt64Array::from(vec![Some(9_007_199_254_740_992)]));
        let expected = serde_json::Value::String("9007199254740992".to_string());
        assert_eq!(array_value_to_json(&column, 0), expected);
    }

    /// 2^53 - 1 is the largest safe value and stays a JSON number.
    #[test]
    fn uint64_within_js_safe_range_stays_numeric() {
        let column: ArrayRef = Arc::new(UInt64Array::from(vec![Some(9_007_199_254_740_991)]));
        let expected = serde_json::json!(9_007_199_254_740_991u64);
        assert_eq!(array_value_to_json(&column, 0), expected);
    }

    /// The native row converter never stringifies 64-bit integers.
    #[test]
    fn rust_json_rows_preserve_full_width_integers() {
        let fields = vec![
            Field::new("signed", DataType::Int64, false),
            Field::new("unsigned", DataType::UInt64, false),
        ];
        let columns: Vec<ArrayRef> = vec![
            Arc::new(Int64Array::from(vec![i64::MIN])),
            Arc::new(UInt64Array::from(vec![u64::MAX])),
        ];
        let batch = RecordBatch::try_new(Arc::new(Schema::new(fields)), columns).expect("batch");

        let expected = vec![serde_json::json!({
            "signed": i64::MIN,
            "unsigned": u64::MAX,
        })];
        assert_eq!(record_batches_to_rust_json_rows(&[batch]), expected);
    }

    /// Float32 vectors are sliced per row; a null row yields `null`.
    #[test]
    fn fixed_size_float32_vectors_serialize_without_recursive_dispatch() {
        let mut builder = FixedSizeListBuilder::new(Float32Builder::new(), 3);

        // Row 0: a valid vector.
        for component in [0.25, 0.5, 0.75] {
            builder.values().append_value(component);
        }
        builder.append(true);

        // Row 1: a null vector (child slots still have to be filled).
        for _ in 0..3 {
            builder.values().append_null();
        }
        builder.append(false);

        // Row 2: another valid vector.
        for component in [1.0, 2.0, 3.0] {
            builder.values().append_value(component);
        }
        builder.append(true);

        let column: ArrayRef = Arc::new(builder.finish());
        assert_eq!(
            array_value_to_json(&column, 0),
            serde_json::json!([0.25, 0.5, 0.75])
        );
        assert_eq!(array_value_to_json(&column, 1), serde_json::Value::Null);
        assert_eq!(
            array_value_to_json(&column, 2),
            serde_json::json!([1.0, 2.0, 3.0])
        );
    }

    /// NaN and the infinities have no JSON number form and become strings.
    #[test]
    fn non_finite_floats_are_stringified() {
        let column: ArrayRef = Arc::new(Float64Array::from(vec![
            Some(f64::NAN),
            Some(f64::INFINITY),
            Some(f64::NEG_INFINITY),
        ]));
        let expected = ["NaN", "Infinity", "-Infinity"];
        for (row, label) in expected.into_iter().enumerate() {
            assert_eq!(array_value_to_json(&column, row), serde_json::json!(label));
        }
    }
}
|
||||
28
crates/omnigraph-compiler/src/lib.rs
Normal file
28
crates/omnigraph-compiler/src/lib.rs
Normal file
|
|
@ -0,0 +1,28 @@
|
|||
pub mod catalog;
|
||||
pub mod embedding;
|
||||
pub mod error;
|
||||
pub mod ir;
|
||||
pub mod json_output;
|
||||
pub mod query;
|
||||
pub mod query_input;
|
||||
pub mod result;
|
||||
pub mod schema;
|
||||
pub mod types;
|
||||
|
||||
pub use catalog::build_catalog;
|
||||
pub use catalog::schema_ir::{
|
||||
SchemaIR, build_catalog_from_ir, build_schema_ir, schema_ir_hash, schema_ir_json,
|
||||
schema_ir_pretty_json,
|
||||
};
|
||||
pub use catalog::schema_plan::{
|
||||
SchemaMigrationPlan, SchemaMigrationStep, SchemaTypeKind, plan_schema_migration,
|
||||
};
|
||||
pub use ir::ParamMap;
|
||||
pub use ir::lower::{lower_mutation_query, lower_query};
|
||||
pub use query::ast::Literal;
|
||||
pub use query_input::{
|
||||
JsonParamMode, RunInputError, RunInputResult, ToParam, find_named_query,
|
||||
json_params_to_param_map,
|
||||
};
|
||||
pub use result::{MutationExecResult, MutationResult, QueryResult, RunResult};
|
||||
pub use types::{Direction, PropType, ScalarType};
|
||||
221
crates/omnigraph-compiler/src/query/ast.rs
Normal file
221
crates/omnigraph-compiler/src/query/ast.rs
Normal file
|
|
@ -0,0 +1,221 @@
|
|||
/// Name of the reserved parameter that `now()` is lowered to: the lowering
/// code turns `Expr::Now` / `MatchValue::Now` into `IRExpr::Param` with this
/// name.
// NOTE(review): presumably the executor binds this parameter to the current
// time at run time — confirm against the runtime crate.
pub const NOW_PARAM_NAME: &str = "__nanograph_now";
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QueryFile {
|
||||
pub queries: Vec<QueryDecl>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QueryDecl {
|
||||
pub name: String,
|
||||
pub description: Option<String>,
|
||||
pub instruction: Option<String>,
|
||||
pub params: Vec<Param>,
|
||||
pub match_clause: Vec<Clause>,
|
||||
pub return_clause: Vec<Projection>,
|
||||
pub order_clause: Vec<Ordering>,
|
||||
pub limit: Option<u64>,
|
||||
pub mutation: Option<Mutation>,
|
||||
}
|
||||
|
||||
/// A declared query parameter, e.g. `$name: String`.
#[derive(Clone, Debug)]
pub struct Param {
    /// Parameter name.
    pub name: String,
    /// Declared type, as written in the query.
    pub type_name: String,
    /// Whether the parameter was declared nullable.
    pub nullable: bool,
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Clause {
|
||||
Binding(Binding),
|
||||
Traversal(Traversal),
|
||||
Filter(Filter),
|
||||
Negation(Vec<Clause>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Binding {
|
||||
pub variable: String,
|
||||
pub type_name: String,
|
||||
pub prop_matches: Vec<PropMatch>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PropMatch {
|
||||
pub prop_name: String,
|
||||
pub value: MatchValue,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum MatchValue {
|
||||
Literal(Literal),
|
||||
Variable(String),
|
||||
Now,
|
||||
}
|
||||
|
||||
/// Edge traversal between two variables, e.g. `$p knows{1,3} $f`.
#[derive(Clone, Debug)]
pub struct Traversal {
    /// Variable on the left-hand side.
    pub src: String,
    /// Edge name as written in the query.
    pub edge_name: String,
    /// Variable on the right-hand side.
    pub dst: String,
    /// Lower hop bound.
    pub min_hops: u32,
    /// Upper hop bound, when one was given.
    pub max_hops: Option<u32>,
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Filter {
|
||||
pub left: Expr,
|
||||
pub op: CompOp,
|
||||
pub right: Expr,
|
||||
}
|
||||
|
||||
/// Comparison operators usable in filters and mutation predicates.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum CompOp {
    /// `=`
    Eq,
    /// `!=`
    Ne,
    /// `>`
    Gt,
    /// `<`
    Lt,
    /// `>=`
    Ge,
    /// `<=`
    Le,
    /// `contains`
    Contains,
}

impl std::fmt::Display for CompOp {
    /// Writes the operator's textual form; width and fill flags are not applied.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let symbol = match self {
            Self::Eq => "=",
            Self::Ne => "!=",
            Self::Gt => ">",
            Self::Lt => "<",
            Self::Ge => ">=",
            Self::Le => "<=",
            Self::Contains => "contains",
        };
        f.write_str(symbol)
    }
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Expr {
|
||||
Now,
|
||||
PropAccess {
|
||||
variable: String,
|
||||
property: String,
|
||||
},
|
||||
Nearest {
|
||||
variable: String,
|
||||
property: String,
|
||||
query: Box<Expr>,
|
||||
},
|
||||
Search {
|
||||
field: Box<Expr>,
|
||||
query: Box<Expr>,
|
||||
},
|
||||
Fuzzy {
|
||||
field: Box<Expr>,
|
||||
query: Box<Expr>,
|
||||
max_edits: Option<Box<Expr>>,
|
||||
},
|
||||
MatchText {
|
||||
field: Box<Expr>,
|
||||
query: Box<Expr>,
|
||||
},
|
||||
Bm25 {
|
||||
field: Box<Expr>,
|
||||
query: Box<Expr>,
|
||||
},
|
||||
Rrf {
|
||||
primary: Box<Expr>,
|
||||
secondary: Box<Expr>,
|
||||
k: Option<Box<Expr>>,
|
||||
},
|
||||
Variable(String),
|
||||
Literal(Literal),
|
||||
Aggregate {
|
||||
func: AggFunc,
|
||||
arg: Box<Expr>,
|
||||
},
|
||||
AliasRef(String),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum AggFunc {
|
||||
Count,
|
||||
Sum,
|
||||
Avg,
|
||||
Min,
|
||||
Max,
|
||||
}
|
||||
|
||||
impl std::fmt::Display for AggFunc {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Count => write!(f, "count"),
|
||||
Self::Sum => write!(f, "sum"),
|
||||
Self::Avg => write!(f, "avg"),
|
||||
Self::Min => write!(f, "min"),
|
||||
Self::Max => write!(f, "max"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Literal {
|
||||
String(String),
|
||||
Integer(i64),
|
||||
Float(f64),
|
||||
Bool(bool),
|
||||
Date(String),
|
||||
DateTime(String),
|
||||
List(Vec<Literal>),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Projection {
|
||||
pub expr: Expr,
|
||||
pub alias: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Ordering {
|
||||
pub expr: Expr,
|
||||
pub descending: bool,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum Mutation {
|
||||
Insert(InsertMutation),
|
||||
Update(UpdateMutation),
|
||||
Delete(DeleteMutation),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct InsertMutation {
|
||||
pub type_name: String,
|
||||
pub assignments: Vec<MutationAssignment>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct UpdateMutation {
|
||||
pub type_name: String,
|
||||
pub assignments: Vec<MutationAssignment>,
|
||||
pub predicate: MutationPredicate,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct DeleteMutation {
|
||||
pub type_name: String,
|
||||
pub predicate: MutationPredicate,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MutationAssignment {
|
||||
pub property: String,
|
||||
pub value: MatchValue,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct MutationPredicate {
|
||||
pub property: String,
|
||||
pub op: CompOp,
|
||||
pub value: MatchValue,
|
||||
}
|
||||
3
crates/omnigraph-compiler/src/query/mod.rs
Normal file
3
crates/omnigraph-compiler/src/query/mod.rs
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
pub mod ast;
|
||||
pub mod parser;
|
||||
pub mod typecheck;
|
||||
1689
crates/omnigraph-compiler/src/query/parser.rs
Normal file
1689
crates/omnigraph-compiler/src/query/parser.rs
Normal file
File diff suppressed because it is too large
Load diff
114
crates/omnigraph-compiler/src/query/query.pest
Normal file
114
crates/omnigraph-compiler/src/query/query.pest
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
// NanoGraph Query Grammar (.gq files)
//
// Notation reminder (pest):
//   { }   normal rule: WHITESPACE / COMMENT are skipped implicitly around `~`
//   _{ }  silent rule: matches but produces no token
//   @{ }  atomic rule: no implicit skipping, inner rules produce no tokens

// Trivia skipped implicitly between the tokens of every non-atomic rule.
WHITESPACE = _{ " " | "\t" | "\r" | "\n" }
COMMENT = _{ LINE_COMMENT | BLOCK_COMMENT }
LINE_COMMENT = _{ "//" ~ (!"\n" ~ ANY)* }
BLOCK_COMMENT = _{ "/*" ~ (!"*/" ~ ANY)* ~ "*/" }

// Entry point: zero or more query declarations spanning the whole input.
query_file = { SOI ~ query_decl* ~ EOI }

// query name($param: Type, ...) @annotation(...)* { body }
query_decl = {
    "query" ~ ident ~ "(" ~ param_list? ~ ")" ~ query_annotation* ~ "{"
    ~ query_body
    ~ "}"
}
query_annotation = { description_annotation | instruction_annotation }
description_annotation = { "@description" ~ "(" ~ string_lit ~ ")" }
instruction_annotation = { "@instruction" ~ "(" ~ string_lit ~ ")" }

// A body is either a read query or exactly one mutation statement.
query_body = { read_query_body | mutation_stmt }
read_query_body = {
    match_clause
    ~ return_clause
    ~ order_clause?
    ~ limit_clause?
}

// Mutations
mutation_stmt = { insert_stmt | update_stmt | delete_stmt }
insert_stmt = { "insert" ~ type_name ~ "{" ~ mutation_assignment+ ~ "}" }
update_stmt = { "update" ~ type_name ~ "set" ~ "{" ~ mutation_assignment+ ~ "}" ~ "where" ~ mutation_predicate }
delete_stmt = { "delete" ~ type_name ~ "where" ~ mutation_predicate }
mutation_assignment = { ident ~ ":" ~ match_value ~ ","? }
mutation_predicate = { ident ~ comp_op ~ match_value }

// Parameters
param_list = { param ~ ("," ~ param)* }
param = { variable ~ ":" ~ type_ref }

// Parameter types; a trailing "?" is allowed on any of them.
type_ref = { (list_type | base_type | vector_type) ~ "?"? }
list_type = { "[" ~ base_type ~ "]" }
vector_type = { "Vector" ~ "(" ~ integer ~ ")" }
// "DateTime" must stay ahead of "Date": ordered choice takes the first match.
base_type = { "String" | "Blob" | "Bool" | "I32" | "I64" | "U32" | "U64" | "F32" | "F64" | "DateTime" | "Date" }

// Match clause
match_clause = { "match" ~ "{" ~ clause+ ~ "}" }

clause = { negation | binding | traversal | filter | text_search_clause }
text_search_clause = { search_call | fuzzy_call | match_text_call }

// Binding: $p: Person { name: "Alice" }
binding = { variable ~ ":" ~ type_name ~ ("{" ~ prop_match_list ~ "}")? }

prop_match_list = { prop_match ~ ("," ~ prop_match)* ~ ","? }
prop_match = { ident ~ ":" ~ match_value }
match_value = { literal | variable | now_call }

// Traversal: $p knows $f
traversal = { variable ~ edge_ident ~ traversal_bounds? ~ variable }
// Hop bounds: {min,} or {min,max}
traversal_bounds = { "{" ~ integer ~ "," ~ integer? ~ "}" }

// Filter: $f.age > 25
filter = { expr ~ filter_op ~ expr }

// Negation: not { ... }
negation = { "not" ~ "{" ~ clause+ ~ "}" }

// Return clause — projections separated by commas or newlines
return_clause = { "return" ~ "{" ~ projection+ ~ "}" }
projection = { expr ~ ("as" ~ ident)? ~ ","? }

// Order clause
order_clause = { "order" ~ "{" ~ ordering ~ ("," ~ ordering)* ~ "}" }
ordering = { nearest_ordering | (expr ~ order_dir?) }
nearest_ordering = { "nearest" ~ "(" ~ prop_access ~ "," ~ expr ~ ")" }
order_dir = { "asc" | "desc" }

// Limit clause
limit_clause = { "limit" ~ integer }

// Expressions
// Call forms come first so their keywords are not swallowed by `ident`;
// a call keyword without "(" backtracks and falls through to `ident`.
expr = { now_call | nearest_ordering | search_call | fuzzy_call | match_text_call | bm25_call | rrf_call | agg_call | prop_access | variable | literal | ident }
now_call = { "now" ~ "(" ~ ")" }
search_call = { "search" ~ "(" ~ expr ~ "," ~ expr ~ ")" }
fuzzy_call = { "fuzzy" ~ "(" ~ expr ~ "," ~ expr ~ ("," ~ expr)? ~ ")" }
match_text_call = { "match_text" ~ "(" ~ expr ~ "," ~ expr ~ ")" }
bm25_call = { "bm25" ~ "(" ~ expr ~ "," ~ expr ~ ")" }
rank_expr = { nearest_ordering | bm25_call }
rrf_call = { "rrf" ~ "(" ~ rank_expr ~ "," ~ rank_expr ~ ("," ~ expr)? ~ ")" }

prop_access = { variable ~ "." ~ ident }

agg_call = { agg_func ~ "(" ~ expr ~ ")" }
agg_func = { "count" | "sum" | "avg" | "min" | "max" }

// Two-character operators are listed first so ">" does not shadow ">=".
comp_op = { ">=" | "<=" | "!=" | ">" | "<" | "=" }
filter_op = { "contains" | comp_op }

// Terminals
// `ident_chars` already accepts a lone "_", so no extra alternative is needed.
variable = @{ "$" ~ ident_chars }
ident_chars = @{ (ASCII_ALPHA_LOWER | "_") ~ (ASCII_ALPHANUMERIC | "_")* }

// Edge identifier — lowercase start, same as ident but used in traversal context
// Must not match keywords
edge_ident = @{ !("not" ~ !ASCII_ALPHANUMERIC) ~ (ASCII_ALPHA_LOWER | "_") ~ (ASCII_ALPHANUMERIC | "_")* }

type_name = @{ ASCII_ALPHA_UPPER ~ (ASCII_ALPHANUMERIC | "_")* }
ident = @{ (ASCII_ALPHA_LOWER | "_") ~ (ASCII_ALPHANUMERIC | "_")* }

// Literals. Order matters: `datetime_lit` before `date_lit`, `float_lit` before `integer`.
literal = { list_lit | datetime_lit | date_lit | string_lit | float_lit | integer | bool_lit }
date_lit = { "date" ~ "(" ~ string_lit ~ ")" }
datetime_lit = { "datetime" ~ "(" ~ string_lit ~ ")" }
list_lit = { "[" ~ (literal ~ ("," ~ literal)*)? ~ "]" }
string_lit = @{ "\"" ~ string_char* ~ "\"" }
string_char = @{ !("\"" | "\\") ~ ANY | "\\" ~ ANY }
float_lit = @{ ASCII_DIGIT+ ~ "." ~ ASCII_DIGIT+ }
integer = @{ ASCII_DIGIT+ }
// Atomic with a word-boundary check: without it an identifier such as
// `true_count` is consumed as the literal `true` inside `expr` (where
// `literal` is tried before `ident`) and the enclosing rule then fails.
// The matched text is still exactly "true" or "false".
bool_lit = @{ ("true" | "false") ~ !(ASCII_ALPHANUMERIC | "_") }
|
||||
2776
crates/omnigraph-compiler/src/query/typecheck.rs
Normal file
2776
crates/omnigraph-compiler/src/query/typecheck.rs
Normal file
File diff suppressed because it is too large
Load diff
892
crates/omnigraph-compiler/src/query_input.rs
Normal file
892
crates/omnigraph-compiler/src/query_input.rs
Normal file
|
|
@ -0,0 +1,892 @@
|
|||
use std::error::Error;
|
||||
use std::fmt;
|
||||
|
||||
use serde_json::Value;
|
||||
|
||||
use crate::error::NanoError;
|
||||
use crate::ir::ParamMap;
|
||||
use crate::json_output::{JS_MAX_SAFE_INTEGER_U64, is_js_safe_integer_i64};
|
||||
use crate::query::ast::{Literal, Param, QueryDecl};
|
||||
use crate::query::parser::parse_query;
|
||||
|
||||
/// Selects how JSON parameter values are validated and how errors are worded.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum JsonParamMode {
|
||||
/// Integer JSON numbers are taken as-is (subject to the declared type's range);
/// error messages use the shorter wording.
Standard,
|
||||
/// Integer JSON numbers outside the JS safe-integer range are rejected
/// (decimal strings remain accepted for typed integer parameters), and most
/// error messages also name the JSON type that was received.
JavaScript,
|
||||
}
|
||||
|
||||
/// Error produced while resolving a named query or converting its parameters.
#[derive(Debug)]
|
||||
pub enum RunInputError {
|
||||
/// An error raised by the core crate (e.g. while parsing the query source).
Core(NanoError),
|
||||
/// A plain message describing invalid input (unknown query, bad parameter).
Message(String),
|
||||
}
|
||||
|
||||
impl RunInputError {
|
||||
fn message(message: impl Into<String>) -> Self {
|
||||
Self::Message(message.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for RunInputError {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
match self {
|
||||
Self::Core(err) => err.fmt(f),
|
||||
Self::Message(message) => f.write_str(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Error for RunInputError {
|
||||
fn source(&self) -> Option<&(dyn Error + 'static)> {
|
||||
match self {
|
||||
Self::Core(err) => Some(err),
|
||||
Self::Message(_) => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<NanoError> for RunInputError {
|
||||
fn from(value: NanoError) -> Self {
|
||||
Self::Core(value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Result alias whose error type is [`RunInputError`].
pub type RunInputResult<T> = std::result::Result<T, RunInputError>;
|
||||
|
||||
/// Conversion of a Rust value into a query-parameter [`Literal`].
pub trait ToParam {
|
||||
/// Consumes `self` and returns the equivalent literal, or an error when the
/// value cannot be represented (the impls in this file reject integers that
/// do not fit in `i64` and non-finite floats).
fn to_param(self) -> crate::error::Result<Literal>;
|
||||
}
|
||||
|
||||
impl ToParam for Literal {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(self)
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for &Literal {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(self.clone())
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for String {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::String(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for &String {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::String(self.clone()))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for &str {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::String(self.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for bool {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::Bool(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for i8 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::Integer(i64::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for i16 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::Integer(i64::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for i32 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::Integer(i64::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for i64 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::Integer(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for isize {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
let value = i64::try_from(self).map_err(|_| {
|
||||
NanoError::Execution(format!(
|
||||
"param value {} exceeds current engine range for numeric literals (max {})",
|
||||
self,
|
||||
i64::MAX
|
||||
))
|
||||
})?;
|
||||
Ok(Literal::Integer(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for u8 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::Integer(i64::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for u16 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::Integer(i64::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for u32 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
Ok(Literal::Integer(i64::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for u64 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
let value = i64::try_from(self).map_err(|_| {
|
||||
NanoError::Execution(format!(
|
||||
"param value {} exceeds current engine range for numeric literals (max {})",
|
||||
self,
|
||||
i64::MAX
|
||||
))
|
||||
})?;
|
||||
Ok(Literal::Integer(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for usize {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
let value = i64::try_from(self).map_err(|_| {
|
||||
NanoError::Execution(format!(
|
||||
"param value {} exceeds current engine range for numeric literals (max {})",
|
||||
self,
|
||||
i64::MAX
|
||||
))
|
||||
})?;
|
||||
Ok(Literal::Integer(value))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for f32 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
if !self.is_finite() {
|
||||
return Err(NanoError::Execution(format!(
|
||||
"invalid float parameter {}",
|
||||
self
|
||||
)));
|
||||
}
|
||||
Ok(Literal::Float(f64::from(self)))
|
||||
}
|
||||
}
|
||||
|
||||
impl ToParam for f64 {
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
if !self.is_finite() {
|
||||
return Err(NanoError::Execution(format!(
|
||||
"invalid float parameter {}",
|
||||
self
|
||||
)));
|
||||
}
|
||||
Ok(Literal::Float(self))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ToParam for Vec<T>
|
||||
where
|
||||
T: ToParam,
|
||||
{
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
let mut out = Vec::with_capacity(self.len());
|
||||
for value in self {
|
||||
out.push(value.to_param()?);
|
||||
}
|
||||
Ok(Literal::List(out))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> ToParam for &[T]
|
||||
where
|
||||
T: Clone + ToParam,
|
||||
{
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
let mut out = Vec::with_capacity(self.len());
|
||||
for value in self {
|
||||
out.push(value.clone().to_param()?);
|
||||
}
|
||||
Ok(Literal::List(out))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T, const N: usize> ToParam for [T; N]
|
||||
where
|
||||
T: ToParam,
|
||||
{
|
||||
fn to_param(self) -> crate::error::Result<Literal> {
|
||||
let mut out = Vec::with_capacity(N);
|
||||
for value in self {
|
||||
out.push(value.to_param()?);
|
||||
}
|
||||
Ok(Literal::List(out))
|
||||
}
|
||||
}
|
||||
|
||||
/// Builds a `ParamMap` from `key => value` pairs.
///
/// The expansion evaluates to a `Result`: `Ok(map)` when every value converts
/// through `ToParam::to_param`, otherwise the first conversion error. Keys are
/// converted with `Into<String>`. A trailing comma is accepted, and an empty
/// invocation yields `Ok` of an empty map.
#[macro_export]
|
||||
macro_rules! params {
|
||||
// Empty invocation: nothing can fail, so return an empty map directly.
() => {
|
||||
::std::result::Result::Ok($crate::ParamMap::new())
|
||||
};
|
||||
($($key:expr => $value:expr),+ $(,)?) => {{
|
||||
// The immediately-invoked closure gives `?` a `Result`-returning scope,
// so the macro can be used in functions with any return type.
(|| -> $crate::error::Result<$crate::ParamMap> {
|
||||
let mut map = $crate::ParamMap::new();
|
||||
$(
|
||||
map.insert(::std::convert::Into::<String>::into($key), $crate::ToParam::to_param($value)?);
|
||||
)+
|
||||
Ok(map)
|
||||
})()
|
||||
}};
|
||||
}
|
||||
|
||||
pub fn find_named_query(query_source: &str, query_name: &str) -> RunInputResult<QueryDecl> {
|
||||
let queries = parse_query(query_source)?;
|
||||
queries
|
||||
.queries
|
||||
.into_iter()
|
||||
.find(|query| query.name == query_name)
|
||||
.ok_or_else(|| RunInputError::message(format!("query '{}' not found", query_name)))
|
||||
}
|
||||
|
||||
pub fn json_params_to_param_map(
|
||||
params: Option<&Value>,
|
||||
query_params: &[Param],
|
||||
mode: JsonParamMode,
|
||||
) -> RunInputResult<ParamMap> {
|
||||
let mut map = ParamMap::new();
|
||||
let object = match params {
|
||||
Some(Value::Object(object)) => object,
|
||||
Some(Value::Null) | None => return Ok(map),
|
||||
Some(other) => {
|
||||
let message = match mode {
|
||||
JsonParamMode::Standard => "params must be a JSON object".to_string(),
|
||||
JsonParamMode::JavaScript => {
|
||||
format!("params must be an object, got {}", json_type_name(other))
|
||||
}
|
||||
};
|
||||
return Err(RunInputError::message(message));
|
||||
}
|
||||
};
|
||||
|
||||
for (key, value) in object {
|
||||
let decl = query_params.iter().find(|param| param.name == *key);
|
||||
let literal = if let Some(decl) = decl {
|
||||
json_value_to_literal_typed(key, value, &decl.type_name, mode)?
|
||||
} else {
|
||||
json_value_to_literal_inferred(key, value, mode)?
|
||||
};
|
||||
map.insert(key.clone(), literal);
|
||||
}
|
||||
|
||||
Ok(map)
|
||||
}
|
||||
|
||||
fn json_value_to_literal_typed(
|
||||
key: &str,
|
||||
value: &Value,
|
||||
type_name: &str,
|
||||
mode: JsonParamMode,
|
||||
) -> RunInputResult<Literal> {
|
||||
match type_name {
|
||||
"String" => match value {
|
||||
Value::String(value) => Ok(Literal::String(value.clone())),
|
||||
other => Err(RunInputError::message(format!(
|
||||
"param '{}': expected string, got {}",
|
||||
key,
|
||||
json_type_name(other)
|
||||
))),
|
||||
},
|
||||
"I32" => match mode {
|
||||
JsonParamMode::Standard => {
|
||||
let value = parse_i64_param(key, value, mode)?;
|
||||
let value = i32::try_from(value).map_err(|_| {
|
||||
RunInputError::message(format!("param '{}': value {} exceeds I32", key, value))
|
||||
})?;
|
||||
Ok(Literal::Integer(i64::from(value)))
|
||||
}
|
||||
JsonParamMode::JavaScript => {
|
||||
let value = parse_i64_param(key, value, mode)?;
|
||||
let value = i32::try_from(value).map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': value {} exceeds I32 range",
|
||||
key, value
|
||||
))
|
||||
})?;
|
||||
Ok(Literal::Integer(i64::from(value)))
|
||||
}
|
||||
},
|
||||
"I64" => Ok(Literal::Integer(parse_i64_param(key, value, mode)?)),
|
||||
"U32" => {
|
||||
let value = parse_u64_param(key, value, mode)?;
|
||||
let value = match mode {
|
||||
JsonParamMode::Standard => u32::try_from(value).map_err(|_| {
|
||||
RunInputError::message(format!("param '{}': value {} exceeds U32", key, value))
|
||||
})?,
|
||||
JsonParamMode::JavaScript => u32::try_from(value).map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': value {} exceeds U32 range",
|
||||
key, value
|
||||
))
|
||||
})?,
|
||||
};
|
||||
Ok(Literal::Integer(i64::from(value)))
|
||||
}
|
||||
"U64" => {
|
||||
let value = parse_u64_param(key, value, mode)?;
|
||||
let value = match mode {
|
||||
JsonParamMode::Standard => i64::try_from(value).map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': value {} exceeds current engine range for U64 (max {})",
|
||||
key,
|
||||
value,
|
||||
i64::MAX
|
||||
))
|
||||
})?,
|
||||
JsonParamMode::JavaScript => i64::try_from(value).map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': value {} exceeds current engine range for U64 parameters (max {})",
|
||||
key,
|
||||
value,
|
||||
i64::MAX
|
||||
))
|
||||
})?,
|
||||
};
|
||||
Ok(Literal::Integer(value))
|
||||
}
|
||||
"F32" | "F64" => {
|
||||
let value = value.as_f64().ok_or_else(|| match mode {
|
||||
JsonParamMode::Standard => {
|
||||
RunInputError::message(format!("param '{}': expected float", key))
|
||||
}
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': expected float, got {}",
|
||||
key,
|
||||
json_type_name(value)
|
||||
)),
|
||||
})?;
|
||||
Ok(Literal::Float(value))
|
||||
}
|
||||
"Bool" => {
|
||||
let value = value.as_bool().ok_or_else(|| match mode {
|
||||
JsonParamMode::Standard => {
|
||||
RunInputError::message(format!("param '{}': expected boolean", key))
|
||||
}
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': expected boolean, got {}",
|
||||
key,
|
||||
json_type_name(value)
|
||||
)),
|
||||
})?;
|
||||
Ok(Literal::Bool(value))
|
||||
}
|
||||
"Date" => match value {
|
||||
Value::String(value) => Ok(Literal::Date(value.clone())),
|
||||
other => Err(match mode {
|
||||
JsonParamMode::Standard => {
|
||||
RunInputError::message(format!("param '{}': expected date string", key))
|
||||
}
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': expected date string, got {}",
|
||||
key,
|
||||
json_type_name(other)
|
||||
)),
|
||||
}),
|
||||
},
|
||||
"DateTime" => match value {
|
||||
Value::String(value) => Ok(Literal::DateTime(value.clone())),
|
||||
other => Err(match mode {
|
||||
JsonParamMode::Standard => {
|
||||
RunInputError::message(format!("param '{}': expected datetime string", key))
|
||||
}
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': expected datetime string, got {}",
|
||||
key,
|
||||
json_type_name(other)
|
||||
)),
|
||||
}),
|
||||
},
|
||||
"Blob" => match value {
|
||||
Value::String(value) => Ok(Literal::String(value.clone())),
|
||||
other => Err(RunInputError::message(format!(
|
||||
"param '{}': expected blob URI string, got {}",
|
||||
key,
|
||||
json_type_name(other)
|
||||
))),
|
||||
},
|
||||
other if parse_list_item_type(other).is_some() => {
|
||||
let item_type = parse_list_item_type(other).unwrap();
|
||||
let items = value.as_array().ok_or_else(|| match mode {
|
||||
JsonParamMode::Standard => {
|
||||
RunInputError::message(format!("param '{}': expected array for {}", key, other))
|
||||
}
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': expected array for {}, got {}",
|
||||
key,
|
||||
other,
|
||||
json_type_name(value)
|
||||
)),
|
||||
})?;
|
||||
let mut out = Vec::with_capacity(items.len());
|
||||
for item in items {
|
||||
out.push(json_value_to_literal_typed(key, item, item_type, mode)?);
|
||||
}
|
||||
Ok(Literal::List(out))
|
||||
}
|
||||
other if other.starts_with("Vector(") => {
|
||||
let expected_dim = parse_vector_dim(other).ok_or_else(|| match mode {
|
||||
JsonParamMode::Standard => RunInputError::message(format!(
|
||||
"param '{}': invalid vector type '{}'",
|
||||
key, other
|
||||
)),
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': invalid vector type '{}' (expected Vector(N))",
|
||||
key, other
|
||||
)),
|
||||
})?;
|
||||
let items = value.as_array().ok_or_else(|| match mode {
|
||||
JsonParamMode::Standard => {
|
||||
RunInputError::message(format!("param '{}': expected array for {}", key, other))
|
||||
}
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': expected array for {}, got {}",
|
||||
key,
|
||||
other,
|
||||
json_type_name(value)
|
||||
)),
|
||||
})?;
|
||||
if items.len() != expected_dim {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': expected {} values for {}, got {}",
|
||||
key,
|
||||
expected_dim,
|
||||
other,
|
||||
items.len()
|
||||
)));
|
||||
}
|
||||
let mut out = Vec::with_capacity(items.len());
|
||||
for item in items {
|
||||
let value = item.as_f64().ok_or_else(|| match mode {
|
||||
JsonParamMode::Standard => RunInputError::message(format!(
|
||||
"param '{}': vector element is not numeric",
|
||||
key
|
||||
)),
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': vector element '{}' is not numeric",
|
||||
key, item
|
||||
)),
|
||||
})?;
|
||||
out.push(Literal::Float(value));
|
||||
}
|
||||
Ok(Literal::List(out))
|
||||
}
|
||||
_ => match value {
|
||||
Value::String(value) => Ok(Literal::String(value.clone())),
|
||||
other => Err(RunInputError::message(format!(
|
||||
"param '{}': expected string for type '{}', got {}",
|
||||
key,
|
||||
type_name,
|
||||
json_type_name(other)
|
||||
))),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn json_value_to_literal_inferred(
|
||||
key: &str,
|
||||
value: &Value,
|
||||
mode: JsonParamMode,
|
||||
) -> RunInputResult<Literal> {
|
||||
match value {
|
||||
Value::String(value) => Ok(Literal::String(value.clone())),
|
||||
Value::Bool(value) => Ok(Literal::Bool(*value)),
|
||||
Value::Number(number) => match mode {
|
||||
JsonParamMode::Standard => {
|
||||
if let Some(value) = number.as_i64() {
|
||||
Ok(Literal::Integer(value))
|
||||
} else if let Some(value) = number.as_f64() {
|
||||
Ok(Literal::Float(value))
|
||||
} else {
|
||||
Err(RunInputError::message(format!(
|
||||
"param '{}': unsupported numeric value",
|
||||
key
|
||||
)))
|
||||
}
|
||||
}
|
||||
JsonParamMode::JavaScript => {
|
||||
if let Some(value) = number.as_i64() {
|
||||
if !is_js_safe_integer_i64(value) {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': integer {} exceeds JS safe integer range; use a decimal string and a typed query parameter for exact values",
|
||||
key, value
|
||||
)));
|
||||
}
|
||||
Ok(Literal::Integer(value))
|
||||
} else if let Some(value) = number.as_u64() {
|
||||
if value > JS_MAX_SAFE_INTEGER_U64 {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': integer {} exceeds JS safe integer range; use a decimal string and a typed query parameter for exact values",
|
||||
key, value
|
||||
)));
|
||||
}
|
||||
let value = i64::try_from(value).map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': integer {} exceeds supported range (max {})",
|
||||
key,
|
||||
value,
|
||||
i64::MAX
|
||||
))
|
||||
})?;
|
||||
Ok(Literal::Integer(value))
|
||||
} else if let Some(value) = number.as_f64() {
|
||||
Ok(Literal::Float(value))
|
||||
} else {
|
||||
Err(RunInputError::message(format!(
|
||||
"param '{}': unsupported number value",
|
||||
key
|
||||
)))
|
||||
}
|
||||
}
|
||||
},
|
||||
Value::Array(values) => {
|
||||
let mut out = Vec::with_capacity(values.len());
|
||||
for value in values {
|
||||
out.push(json_value_to_literal_inferred(key, value, mode)?);
|
||||
}
|
||||
Ok(Literal::List(out))
|
||||
}
|
||||
Value::Null => Err(match mode {
|
||||
JsonParamMode::Standard => {
|
||||
RunInputError::message(format!("param '{}': null is not supported", key))
|
||||
}
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': null values are not supported as query parameters",
|
||||
key
|
||||
)),
|
||||
}),
|
||||
Value::Object(_) => Err(match mode {
|
||||
JsonParamMode::Standard => {
|
||||
RunInputError::message(format!("param '{}': object is not supported", key))
|
||||
}
|
||||
JsonParamMode::JavaScript => RunInputError::message(format!(
|
||||
"param '{}': object values are not supported as query parameters",
|
||||
key
|
||||
)),
|
||||
}),
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_i64_param(key: &str, value: &Value, mode: JsonParamMode) -> RunInputResult<i64> {
|
||||
match mode {
|
||||
JsonParamMode::Standard => match value {
|
||||
Value::Number(number) => number.as_i64().ok_or_else(|| {
|
||||
RunInputError::message(format!("param '{}': expected integer number", key))
|
||||
}),
|
||||
Value::String(value) => value.parse::<i64>().map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': expected integer string, got '{}'",
|
||||
key, value
|
||||
))
|
||||
}),
|
||||
_ => Err(RunInputError::message(format!(
|
||||
"param '{}': expected integer",
|
||||
key
|
||||
))),
|
||||
},
|
||||
JsonParamMode::JavaScript => match value {
|
||||
Value::Number(number) => {
|
||||
let parsed = if let Some(parsed) = number.as_i64() {
|
||||
parsed
|
||||
} else if let Some(parsed) = number.as_f64() {
|
||||
if !parsed.is_finite() || parsed.fract() != 0.0 {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': expected integer, got number",
|
||||
key
|
||||
)));
|
||||
}
|
||||
if parsed < i64::MIN as f64 || parsed > i64::MAX as f64 {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': integer {} is outside i64 range",
|
||||
key, parsed
|
||||
)));
|
||||
}
|
||||
parsed as i64
|
||||
} else {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': expected integer, got number",
|
||||
key
|
||||
)));
|
||||
};
|
||||
if !is_js_safe_integer_i64(parsed) {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': integer {} exceeds JS safe integer range; pass a decimal string for exact values",
|
||||
key, parsed
|
||||
)));
|
||||
}
|
||||
Ok(parsed)
|
||||
}
|
||||
Value::String(value) => value.parse::<i64>().map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': expected integer string, got '{}'",
|
||||
key, value
|
||||
))
|
||||
}),
|
||||
other => Err(RunInputError::message(format!(
|
||||
"param '{}': expected integer, got {}",
|
||||
key,
|
||||
json_type_name(other)
|
||||
))),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn parse_u64_param(key: &str, value: &Value, mode: JsonParamMode) -> RunInputResult<u64> {
|
||||
match mode {
|
||||
JsonParamMode::Standard => match value {
|
||||
Value::Number(number) => number.as_u64().ok_or_else(|| {
|
||||
RunInputError::message(format!("param '{}': expected unsigned integer number", key))
|
||||
}),
|
||||
Value::String(value) => value.parse::<u64>().map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': expected unsigned integer string, got '{}'",
|
||||
key, value
|
||||
))
|
||||
}),
|
||||
_ => Err(RunInputError::message(format!(
|
||||
"param '{}': expected unsigned integer",
|
||||
key
|
||||
))),
|
||||
},
|
||||
JsonParamMode::JavaScript => match value {
|
||||
Value::Number(number) => {
|
||||
let parsed = if let Some(parsed) = number.as_u64() {
|
||||
parsed
|
||||
} else if let Some(parsed) = number.as_f64() {
|
||||
if !parsed.is_finite() || parsed.fract() != 0.0 || parsed < 0.0 {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': expected unsigned integer, got number",
|
||||
key
|
||||
)));
|
||||
}
|
||||
if parsed > u64::MAX as f64 {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': integer {} is outside u64 range",
|
||||
key, parsed
|
||||
)));
|
||||
}
|
||||
parsed as u64
|
||||
} else {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': expected unsigned integer, got number",
|
||||
key
|
||||
)));
|
||||
};
|
||||
if parsed > JS_MAX_SAFE_INTEGER_U64 {
|
||||
return Err(RunInputError::message(format!(
|
||||
"param '{}': integer {} exceeds JS safe integer range; pass a decimal string for exact values",
|
||||
key, parsed
|
||||
)));
|
||||
}
|
||||
Ok(parsed)
|
||||
}
|
||||
Value::String(value) => value.parse::<u64>().map_err(|_| {
|
||||
RunInputError::message(format!(
|
||||
"param '{}': expected unsigned integer string, got '{}'",
|
||||
key, value
|
||||
))
|
||||
}),
|
||||
other => Err(RunInputError::message(format!(
|
||||
"param '{}': expected unsigned integer, got {}",
|
||||
key,
|
||||
json_type_name(other)
|
||||
))),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Extracts `N` from a type name of the form `Vector(N)`.
///
/// Whitespace around the dimension is ignored, mirroring how
/// `parse_list_item_type` trims the item type of `[T]`.
///
/// Returns `None` when the name is not of that form, when the dimension is not
/// a valid `usize`, or when it is zero.
fn parse_vector_dim(type_name: &str) -> Option<usize> {
    type_name
        .strip_prefix("Vector(")?
        .strip_suffix(')')?
        .trim()
        .parse::<usize>()
        .ok()
        // A zero-length vector type is treated as invalid.
        .filter(|&dim| dim != 0)
}
|
||||
|
||||
/// Returns the item type name inside a `[T]` list type name.
///
/// The surrounding brackets are removed and the remainder is trimmed of
/// leading/trailing whitespace. Yields `None` when either bracket is missing.
fn parse_list_item_type(type_name: &str) -> Option<&str> {
    let after_open = type_name.strip_prefix('[')?;
    let item = after_open.strip_suffix(']')?;
    Some(item.trim())
}
|
||||
|
||||
fn json_type_name(value: &Value) -> &'static str {
|
||||
match value {
|
||||
Value::Null => "null",
|
||||
Value::Bool(_) => "boolean",
|
||||
Value::Number(_) => "number",
|
||||
Value::String(_) => "string",
|
||||
Value::Array(_) => "array",
|
||||
Value::Object(_) => "object",
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{JsonParamMode, ToParam, find_named_query, json_params_to_param_map};
    use crate::query::ast::Literal;

    /// A JSON number one past the JS safe-integer limit must be rejected in
    /// JavaScript mode with the dedicated error message.
    #[test]
    fn js_mode_rejects_unsafe_integer_numbers() {
        let parsed_query = find_named_query(
            "query find($id: U64) { match { $u: User } return { $u } }",
            "find",
        )
        .expect("query should parse");
        let payload = json!({ "id": 9_007_199_254_740_992u64 });

        let failure = json_params_to_param_map(
            Some(&payload),
            &parsed_query.params,
            JsonParamMode::JavaScript,
        )
        .expect_err("unsafe integer should fail");

        assert_eq!(
            failure.to_string(),
            "param 'id': integer 9007199254740992 exceeds JS safe integer range; pass a decimal string for exact values"
        );
    }

    /// Non-object params keep the plain "must be a JSON object" error in
    /// standard mode.
    #[test]
    fn standard_mode_preserves_ffi_param_object_error() {
        let payload = json!(["nope"]);

        let failure = json_params_to_param_map(Some(&payload), &[], JsonParamMode::Standard)
            .expect_err("non-object params should fail");

        assert_eq!(failure.to_string(), "params must be a JSON object");
    }

    /// `ToParam` turns a `Vec<i32>` into a list literal and passes an explicit
    /// date literal through unchanged.
    #[test]
    fn to_param_supports_lists_and_explicit_date_literals() {
        let list_param = vec![1_i32, 2_i32, 3_i32].to_param().expect("vector param");
        match list_param {
            Literal::List(items) => {
                assert!(matches!(items.first(), Some(Literal::Integer(1))));
                assert!(matches!(items.get(1), Some(Literal::Integer(2))));
                assert!(matches!(items.get(2), Some(Literal::Integer(3))));
            }
            other => panic!("expected list param, got {:?}", other),
        }

        let date_param = Literal::Date("2026-03-06".to_string())
            .to_param()
            .expect("date param");
        assert!(matches!(date_param, Literal::Date(ref value) if value == "2026-03-06"));
    }

    /// `u64::MAX` is reported as exceeding the engine's numeric literal range.
    #[test]
    fn to_param_rejects_unsigned_values_outside_engine_range() {
        let failure = u64::MAX.to_param().expect_err("oversized u64 should fail");
        let expected = format!(
            "execution error: param value {} exceeds current engine range for numeric literals (max {})",
            u64::MAX,
            i64::MAX
        );

        assert_eq!(failure.to_string(), expected);
    }

    /// The `params!` macro builds a map holding string, integer, list and
    /// datetime literals.
    #[test]
    fn params_macro_builds_param_map() {
        let built = params! {
            "name" => "Alice",
            "age" => 41_i32,
            "scores" => [1_u8, 2_u8, 3_u8],
            "published_at" => Literal::DateTime("2026-03-06T12:00:00Z".to_string()),
        }
        .expect("params");

        assert!(matches!(
            built.get("name"),
            Some(Literal::String(value)) if value == "Alice"
        ));
        assert!(matches!(built.get("age"), Some(Literal::Integer(41))));
        match built.get("scores") {
            Some(Literal::List(items)) => {
                assert!(matches!(items.first(), Some(Literal::Integer(1))));
                assert!(matches!(items.get(1), Some(Literal::Integer(2))));
                assert!(matches!(items.get(2), Some(Literal::Integer(3))));
            }
            other => panic!("expected list param, got {:?}", other),
        }
        assert!(matches!(
            built.get("published_at"),
            Some(Literal::DateTime(value)) if value == "2026-03-06T12:00:00Z"
        ));
    }

    /// Declared `[String]`, `[Date]?` and `DateTime` params are converted to
    /// the matching literal kinds.
    #[test]
    fn typed_json_params_support_list_and_datetime_types() {
        let parsed_query = find_named_query(
            r#"
            query q($tags: [String], $days: [Date]?, $due_at: DateTime) {
                match { $t: Task }
                return { $t.slug }
            }
            "#,
            "q",
        )
        .expect("query");
        let payload = json!({
            "tags": ["launch", "priority"],
            "days": ["2026-04-01", "2026-04-02"],
            "due_at": "2026-04-03T10:15:00Z"
        });

        let converted = json_params_to_param_map(
            Some(&payload),
            &parsed_query.params,
            JsonParamMode::Standard,
        )
        .expect("typed params");

        assert!(matches!(
            converted.get("due_at"),
            Some(Literal::DateTime(value)) if value == "2026-04-03T10:15:00Z"
        ));
        match converted.get("tags") {
            Some(Literal::List(items)) => {
                assert!(
                    matches!(items.first(), Some(Literal::String(value)) if value == "launch")
                );
                assert!(
                    matches!(items.get(1), Some(Literal::String(value)) if value == "priority")
                );
            }
            other => panic!("expected string list param, got {:?}", other),
        }
        match converted.get("days") {
            Some(Literal::List(items)) => {
                assert!(
                    matches!(items.first(), Some(Literal::Date(value)) if value == "2026-04-01")
                );
                assert!(
                    matches!(items.get(1), Some(Literal::Date(value)) if value == "2026-04-02")
                );
            }
            other => panic!("expected date list param, got {:?}", other),
        }
    }
}
|
||||
286
crates/omnigraph-compiler/src/result.rs
Normal file
286
crates/omnigraph-compiler/src/result.rs
Normal file
|
|
@ -0,0 +1,286 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
use arrow_ipc::writer::StreamWriter;
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||
use serde::de::DeserializeOwned;
|
||||
|
||||
use crate::error::{NanoError, Result};
|
||||
use crate::json_output::{record_batches_to_json_rows, record_batches_to_rust_json_rows};
|
||||
|
||||
/// Counts of nodes and edges affected by an executed mutation.
///
/// Both counters default to zero. The type is plain data, so it also derives
/// `PartialEq`/`Eq` to allow direct comparison (e.g. in `assert_eq!`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MutationExecResult {
    /// Number of nodes the mutation affected.
    pub affected_nodes: usize,
    /// Number of edges the mutation affected.
    pub affected_edges: usize,
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct QueryResult {
|
||||
schema: SchemaRef,
|
||||
batches: Vec<RecordBatch>,
|
||||
}
|
||||
|
||||
impl QueryResult {
|
||||
pub fn new(schema: SchemaRef, batches: Vec<RecordBatch>) -> Self {
|
||||
Self { schema, batches }
|
||||
}
|
||||
|
||||
pub fn schema(&self) -> &SchemaRef {
|
||||
&self.schema
|
||||
}
|
||||
|
||||
pub fn batches(&self) -> &[RecordBatch] {
|
||||
&self.batches
|
||||
}
|
||||
|
||||
pub fn into_batches(self) -> Vec<RecordBatch> {
|
||||
self.batches
|
||||
}
|
||||
|
||||
pub fn num_rows(&self) -> usize {
|
||||
self.batches.iter().map(RecordBatch::num_rows).sum()
|
||||
}
|
||||
|
||||
pub fn concat_batches(&self) -> Result<RecordBatch> {
|
||||
if self.batches.is_empty() {
|
||||
return Ok(RecordBatch::new_empty(self.schema.clone()));
|
||||
}
|
||||
|
||||
arrow_select::concat::concat_batches(&self.schema, &self.batches)
|
||||
.map_err(|err| NanoError::Execution(err.to_string()))
|
||||
}
|
||||
|
||||
pub fn to_sdk_json(&self) -> serde_json::Value {
|
||||
serde_json::Value::Array(record_batches_to_json_rows(&self.batches))
|
||||
}
|
||||
|
||||
pub fn to_rust_json(&self) -> serde_json::Value {
|
||||
serde_json::Value::Array(record_batches_to_rust_json_rows(&self.batches))
|
||||
}
|
||||
|
||||
pub fn deserialize<T: DeserializeOwned>(&self) -> Result<T> {
|
||||
serde_json::from_value(self.to_rust_json()).map_err(|err| {
|
||||
NanoError::Execution(format!("failed to deserialize query result: {}", err))
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_arrow_ipc(&self) -> Result<Vec<u8>> {
|
||||
let mut buffer = Vec::new();
|
||||
let mut writer = StreamWriter::try_new(&mut buffer, &self.schema)?;
|
||||
for batch in &self.batches {
|
||||
writer.write(batch)?;
|
||||
}
|
||||
writer.finish()?;
|
||||
drop(writer);
|
||||
Ok(buffer)
|
||||
}
|
||||
}
|
||||
|
||||
/// Mutation outcome: how many nodes and edges were affected.
///
/// Both counters default to zero. The type is plain data, so it also derives
/// `PartialEq`/`Eq` to allow direct comparison (e.g. in `assert_eq!`).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct MutationResult {
    /// Number of nodes the mutation affected.
    pub affected_nodes: usize,
    /// Number of edges the mutation affected.
    pub affected_edges: usize,
}
|
||||
|
||||
impl MutationResult {
|
||||
pub fn to_sdk_json(&self) -> serde_json::Value {
|
||||
serde_json::json!({
|
||||
"affectedNodes": self.affected_nodes,
|
||||
"affectedEdges": self.affected_edges,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn to_record_batch(&self) -> Result<RecordBatch> {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("affected_nodes", DataType::UInt64, false),
|
||||
Field::new("affected_edges", DataType::UInt64, false),
|
||||
]));
|
||||
Ok(RecordBatch::try_new(
|
||||
schema,
|
||||
vec![
|
||||
Arc::new(UInt64Array::from(vec![self.affected_nodes as u64])),
|
||||
Arc::new(UInt64Array::from(vec![self.affected_edges as u64])),
|
||||
],
|
||||
)?)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MutationExecResult> for MutationResult {
|
||||
fn from(value: MutationExecResult) -> Self {
|
||||
Self {
|
||||
affected_nodes: value.affected_nodes,
|
||||
affected_edges: value.affected_edges,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub enum RunResult {
|
||||
Query(QueryResult),
|
||||
Mutation(MutationResult),
|
||||
}
|
||||
|
||||
impl RunResult {
|
||||
pub fn to_sdk_json(&self) -> serde_json::Value {
|
||||
match self {
|
||||
Self::Query(result) => result.to_sdk_json(),
|
||||
Self::Mutation(result) => result.to_sdk_json(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_record_batches(self) -> Result<Vec<RecordBatch>> {
|
||||
match self {
|
||||
Self::Query(result) => Ok(result.into_batches()),
|
||||
Self::Mutation(result) => Ok(vec![result.to_record_batch()?]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::io::Cursor;

    use arrow_array::Int64Array;
    use arrow_ipc::reader::StreamReader;
    use serde::Deserialize;

    use super::*;

    /// An empty result still encodes its schema into the IPC stream.
    #[test]
    fn query_result_arrow_ipc_round_trips_empty_schema() {
        let schema = Arc::new(Schema::new(vec![Field::new("name", DataType::Utf8, false)]));
        let empty = QueryResult::new(schema.clone(), vec![]);

        let bytes = empty.to_arrow_ipc().expect("encode empty result");
        let stream = StreamReader::try_new(Cursor::new(bytes), None).expect("open stream");

        assert_eq!(stream.schema().as_ref(), schema.as_ref());
        assert_eq!(stream.count(), 0);
    }

    /// A single batch survives an IPC encode/decode round trip.
    #[test]
    fn query_result_arrow_ipc_round_trips_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let ids = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
        )
        .expect("batch");
        let result = QueryResult::new(schema.clone(), vec![ids]);

        let bytes = result.to_arrow_ipc().expect("encode result");
        let mut stream = StreamReader::try_new(Cursor::new(bytes), None).expect("open stream");
        let first = stream.next().expect("first batch").expect("decode batch");

        assert_eq!(stream.schema().as_ref(), schema.as_ref());
        assert_eq!(first.num_rows(), 2);
        assert_eq!(first.schema().as_ref(), schema.as_ref());
    }

    /// Row counting and concatenation both span every batch.
    #[test]
    fn query_result_num_rows_and_concat_cover_multiple_batches() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let first = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![1_u64, 2_u64]))],
        )
        .expect("batch1");
        let second = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(UInt64Array::from(vec![3_u64]))],
        )
        .expect("batch2");
        let result = QueryResult::new(schema.clone(), vec![first, second]);

        assert_eq!(result.num_rows(), 3);

        let merged = result.concat_batches().expect("concat batches");
        let id_column = merged
            .column(0)
            .as_any()
            .downcast_ref::<UInt64Array>()
            .expect("u64 ids");
        assert_eq!(merged.schema().as_ref(), schema.as_ref());
        assert_eq!(id_column.values(), &[1, 2, 3]);
    }

    /// Concatenating zero batches yields an empty batch with the schema.
    #[test]
    fn query_result_concat_empty_batches_returns_empty_batch() {
        let schema = Arc::new(Schema::new(vec![Field::new("id", DataType::UInt64, false)]));
        let empty = QueryResult::new(schema.clone(), vec![]);

        let merged = empty.concat_batches().expect("concat empty");

        assert_eq!(merged.schema().as_ref(), schema.as_ref());
        assert_eq!(merged.num_rows(), 0);
    }

    /// `to_rust_json` keeps `i64::MIN` and `u64::MAX` as exact JSON numbers.
    #[test]
    fn query_result_to_rust_json_preserves_wide_integers() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("signed", DataType::Int64, false),
            Field::new("unsigned", DataType::UInt64, false),
        ]));
        let extremes = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(Int64Array::from(vec![i64::MIN])),
                Arc::new(UInt64Array::from(vec![u64::MAX])),
            ],
        )
        .expect("batch");
        let result = QueryResult::new(schema, vec![extremes]);

        let expected = serde_json::json!([{
            "signed": i64::MIN,
            "unsigned": u64::MAX,
        }]);
        assert_eq!(result.to_rust_json(), expected);
    }

    /// Row shape used by the deserialization test below.
    #[derive(Debug, Deserialize, PartialEq)]
    struct PersonRow {
        id: u64,
        age: i64,
    }

    /// Rows from several batches deserialize into typed structs in order.
    #[test]
    fn query_result_deserialize_decodes_rust_rows() {
        let schema = Arc::new(Schema::new(vec![
            Field::new("id", DataType::UInt64, false),
            Field::new("age", DataType::Int64, false),
        ]));
        let first = RecordBatch::try_new(
            schema.clone(),
            vec![
                Arc::new(UInt64Array::from(vec![1_u64])),
                Arc::new(Int64Array::from(vec![40_i64])),
            ],
        )
        .expect("batch1");
        let second = RecordBatch::try_new(
            schema,
            vec![
                Arc::new(UInt64Array::from(vec![u64::MAX])),
                Arc::new(Int64Array::from(vec![-5_i64])),
            ],
        )
        .expect("batch2");
        let result = QueryResult::new(first.schema(), vec![first, second]);

        let people: Vec<PersonRow> = result.deserialize().expect("deserialize rows");

        let expected = vec![
            PersonRow { id: 1, age: 40 },
            PersonRow {
                id: u64::MAX,
                age: -5,
            },
        ];
        assert_eq!(people, expected);
    }
}
|
||||
111
crates/omnigraph-compiler/src/schema/ast.rs
Normal file
111
crates/omnigraph-compiler/src/schema/ast.rs
Normal file
|
|
@ -0,0 +1,111 @@
|
|||
use crate::types::PropType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct SchemaFile {
|
||||
pub declarations: Vec<SchemaDecl>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum SchemaDecl {
|
||||
Interface(InterfaceDecl),
|
||||
Node(NodeDecl),
|
||||
Edge(EdgeDecl),
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct InterfaceDecl {
|
||||
pub name: String,
|
||||
pub properties: Vec<PropDecl>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct NodeDecl {
|
||||
pub name: String,
|
||||
pub annotations: Vec<Annotation>,
|
||||
pub implements: Vec<String>,
|
||||
pub properties: Vec<PropDecl>,
|
||||
pub constraints: Vec<Constraint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct EdgeDecl {
|
||||
pub name: String,
|
||||
pub from_type: String,
|
||||
pub to_type: String,
|
||||
pub cardinality: Cardinality,
|
||||
pub annotations: Vec<Annotation>,
|
||||
pub properties: Vec<PropDecl>,
|
||||
pub constraints: Vec<Constraint>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct PropDecl {
|
||||
pub name: String,
|
||||
pub prop_type: PropType,
|
||||
pub annotations: Vec<Annotation>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub struct Annotation {
|
||||
pub name: String,
|
||||
pub value: Option<String>,
|
||||
}
|
||||
|
||||
/// A typed constraint declared in a node or edge body.
|
||||
///
|
||||
/// Property-level annotations (`@key`, `@unique`, `@index`) are desugared
|
||||
/// into these during parsing, so both syntactic positions produce the same
|
||||
/// representation.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum Constraint {
|
||||
Key(Vec<String>),
|
||||
Unique(Vec<String>),
|
||||
Index(Vec<String>),
|
||||
Range {
|
||||
property: String,
|
||||
min: Option<ConstraintBound>,
|
||||
max: Option<ConstraintBound>,
|
||||
},
|
||||
Check {
|
||||
property: String,
|
||||
pattern: String,
|
||||
},
|
||||
}
|
||||
|
||||
/// A numeric bound used in `@range` constraints.
|
||||
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
|
||||
pub enum ConstraintBound {
|
||||
Integer(i64),
|
||||
Float(f64),
|
||||
}
|
||||
|
||||
/// Edge cardinality: `@card(min..max)`. Default is `0..*`.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub struct Cardinality {
|
||||
pub min: u32,
|
||||
pub max: Option<u32>,
|
||||
}
|
||||
|
||||
impl Default for Cardinality {
|
||||
fn default() -> Self {
|
||||
Self { min: 0, max: None }
|
||||
}
|
||||
}
|
||||
|
||||
impl Cardinality {
|
||||
pub fn is_default(&self) -> bool {
|
||||
self.min == 0 && self.max.is_none()
|
||||
}
|
||||
}
|
||||
|
||||
pub fn has_annotation(annotations: &[Annotation], name: &str) -> bool {
|
||||
annotations.iter().any(|ann| ann.name == name)
|
||||
}
|
||||
|
||||
pub fn annotation_value<'a>(annotations: &'a [Annotation], name: &str) -> Option<&'a str> {
|
||||
annotations
|
||||
.iter()
|
||||
.find(|ann| ann.name == name)
|
||||
.and_then(|ann| ann.value.as_deref())
|
||||
}
|
||||
2
crates/omnigraph-compiler/src/schema/mod.rs
Normal file
2
crates/omnigraph-compiler/src/schema/mod.rs
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
pub mod ast;
|
||||
pub mod parser;
|
||||
1950
crates/omnigraph-compiler/src/schema/parser.rs
Normal file
1950
crates/omnigraph-compiler/src/schema/parser.rs
Normal file
File diff suppressed because it is too large
Load diff
60
crates/omnigraph-compiler/src/schema/schema.pest
Normal file
60
crates/omnigraph-compiler/src/schema/schema.pest
Normal file
|
|
@ -0,0 +1,60 @@
|
|||
// Omnigraph Schema Grammar (.pg files)

// Implicit whitespace and comments (silent rules).
WHITESPACE = _{ " " | "\t" | "\r" | "\n" }
COMMENT = _{ LINE_COMMENT | BLOCK_COMMENT }
LINE_COMMENT = _{ "//" ~ (!"\n" ~ ANY)* }
BLOCK_COMMENT = _{ "/*" ~ (!"*/" ~ ANY)* ~ "*/" }

schema_file = { SOI ~ schema_decl* ~ EOI }

schema_decl = { interface_decl | node_decl | edge_decl }

// interface Named { name: String @key }
interface_decl = { "interface" ~ type_name ~ "{" ~ prop_decl* ~ "}" }

// node Person implements Named, Described { ... }
node_decl = { "node" ~ type_name ~ annotation* ~ implements_clause? ~ "{" ~ (prop_decl | body_constraint)* ~ "}" }
implements_clause = { "implements" ~ type_name ~ ("," ~ type_name)* }

// edge Knows: Person -> Person @card(0..1) { ... }
// edge Knows: Person -> Person
edge_decl = { "edge" ~ type_name ~ ":" ~ type_name ~ "->" ~ type_name ~ cardinality? ~ annotation* ~ ("{" ~ (prop_decl | body_constraint)* ~ "}")? }

// @card(0..1), @card(1..), @card(0..)
cardinality = { "@card" ~ "(" ~ integer ~ ".." ~ integer? ~ ")" }

prop_decl = { ident ~ ":" ~ type_ref ~ annotation* }

// Body-level constraints: @key(name), @unique(a, b), @index(a, b), @range(age, 0..200), @check(code, "regex")
body_constraint = { "@" ~ constraint_name ~ "(" ~ constraint_args ~ ")" }
constraint_name = { "key" | "unique" | "index" | "range" | "check" }
constraint_args = { constraint_arg ~ ("," ~ constraint_arg)* }
constraint_arg = { range_bound | literal | ident }
range_bound = { (signed_float | signed_integer) ~ ".." ~ (signed_float | signed_integer)? | ".." ~ (signed_float | signed_integer) }

type_ref = { core_type ~ "?"? }
core_type = { list_type | enum_type | vector_type | base_type }
list_type = { "[" ~ base_type ~ "]" }
enum_type = { "enum" ~ "(" ~ enum_value ~ ("," ~ enum_value)* ~ ")" }
vector_type = { "Vector" ~ "(" ~ integer ~ ")" }
enum_value = @{ (ASCII_ALPHANUMERIC | "_" | "-")+ }

// "DateTime" must be tried before "Date" so the longer keyword wins.
base_type = { "String" | "Blob" | "Bool" | "I32" | "I64" | "U32" | "U64" | "F32" | "F64" | "DateTime" | "Date" }

// Annotation rule excludes constraint keywords followed by "(" — those are body_constraints
annotation = { "@" ~ !(constraint_name ~ "(") ~ ident ~ ("(" ~ annotation_arg ~ ")")? }
annotation_arg = { literal | ident }

// float_lit is tried before integer so "1.5" is not split at the dot.
literal = { string_lit | float_lit | integer | bool_lit }

string_lit = @{ "\"" ~ string_char* ~ "\"" }
string_char = @{ !("\"" | "\\") ~ ANY | "\\" ~ ANY }
float_lit = @{ ASCII_DIGIT+ ~ "." ~ ASCII_DIGIT+ }
integer = @{ ASCII_DIGIT+ }

signed_float = @{ "-"? ~ ASCII_DIGIT+ ~ "." ~ ASCII_DIGIT+ }
signed_integer = @{ "-"? ~ ASCII_DIGIT+ }
// Atomic, with a trailing word-boundary check: `literal` is tried before
// `ident` in constraint_arg / annotation_arg, so without the lookahead an
// identifier that merely starts with "true"/"false" (e.g. `true_rate`) would
// be half-consumed as a boolean and the surrounding rule would fail.
bool_lit = @{ ("true" | "false") ~ !(ASCII_ALPHANUMERIC | "_") }

type_name = @{ ASCII_ALPHA_UPPER ~ (ASCII_ALPHANUMERIC | "_")* }
ident = @{ (ASCII_ALPHA_LOWER | "_") ~ (ASCII_ALPHANUMERIC | "_")* }
|
||||
227
crates/omnigraph-compiler/src/types.rs
Normal file
227
crates/omnigraph-compiler/src/types.rs
Normal file
|
|
@ -0,0 +1,227 @@
|
|||
use arrow_schema::DataType;
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// Largest accepted vector dimension: `i32::MAX`, the bound that
/// `ScalarType::to_arrow` enforces when building a `FixedSizeList`.
const MAX_VECTOR_DIM: u32 = i32::MAX.unsigned_abs();
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub enum ScalarType {
|
||||
String,
|
||||
Bool,
|
||||
I32,
|
||||
I64,
|
||||
U32,
|
||||
U64,
|
||||
F32,
|
||||
F64,
|
||||
Date,
|
||||
DateTime,
|
||||
Vector(u32),
|
||||
Blob,
|
||||
}
|
||||
|
||||
impl ScalarType {
|
||||
pub fn from_str_name(s: &str) -> Option<Self> {
|
||||
if let Some(inner) = s.strip_prefix("Vector(").and_then(|t| t.strip_suffix(')')) {
|
||||
let dim = inner.parse::<u32>().ok()?;
|
||||
if dim == 0 || dim > MAX_VECTOR_DIM {
|
||||
return None;
|
||||
}
|
||||
return Some(Self::Vector(dim));
|
||||
}
|
||||
|
||||
match s {
|
||||
"String" => Some(Self::String),
|
||||
"Bool" => Some(Self::Bool),
|
||||
"I32" => Some(Self::I32),
|
||||
"I64" => Some(Self::I64),
|
||||
"U32" => Some(Self::U32),
|
||||
"U64" => Some(Self::U64),
|
||||
"F32" => Some(Self::F32),
|
||||
"F64" => Some(Self::F64),
|
||||
"Date" => Some(Self::Date),
|
||||
"DateTime" => Some(Self::DateTime),
|
||||
"Blob" => Some(Self::Blob),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn to_arrow(&self) -> DataType {
|
||||
match self {
|
||||
Self::String => DataType::Utf8,
|
||||
Self::Bool => DataType::Boolean,
|
||||
Self::I32 => DataType::Int32,
|
||||
Self::I64 => DataType::Int64,
|
||||
Self::U32 => DataType::UInt32,
|
||||
Self::U64 => DataType::UInt64,
|
||||
Self::F32 => DataType::Float32,
|
||||
Self::F64 => DataType::Float64,
|
||||
Self::Date => DataType::Date32,
|
||||
Self::DateTime => DataType::Date64,
|
||||
Self::Blob => DataType::LargeBinary,
|
||||
Self::Vector(dim) => {
|
||||
let dim = i32::try_from(*dim)
|
||||
.expect("vector dimension exceeds Arrow FixedSizeList i32 bound");
|
||||
DataType::FixedSizeList(
|
||||
std::sync::Arc::new(arrow_schema::Field::new("item", DataType::Float32, true)),
|
||||
dim,
|
||||
)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_numeric(&self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::I32 | Self::I64 | Self::U32 | Self::U64 | Self::F32 | Self::F64
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl std::fmt::Display for ScalarType {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
let s = match self {
|
||||
Self::String => "String",
|
||||
Self::Bool => "Bool",
|
||||
Self::I32 => "I32",
|
||||
Self::I64 => "I64",
|
||||
Self::U32 => "U32",
|
||||
Self::U64 => "U64",
|
||||
Self::F32 => "F32",
|
||||
Self::F64 => "F64",
|
||||
Self::Date => "Date",
|
||||
Self::DateTime => "DateTime",
|
||||
Self::Blob => "Blob",
|
||||
Self::Vector(dim) => return write!(f, "Vector({})", dim),
|
||||
};
|
||||
write!(f, "{}", s)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
|
||||
pub struct PropType {
|
||||
pub scalar: ScalarType,
|
||||
pub nullable: bool,
|
||||
pub list: bool,
|
||||
pub enum_values: Option<Vec<String>>,
|
||||
}
|
||||
|
||||
impl PropType {
|
||||
pub fn from_param_type_name(s: &str, nullable: bool) -> Option<Self> {
|
||||
if let Some(inner) = s
|
||||
.strip_prefix('[')
|
||||
.and_then(|value| value.strip_suffix(']'))
|
||||
{
|
||||
let scalar = ScalarType::from_str_name(inner)?;
|
||||
return Some(Self::list_of(scalar, nullable));
|
||||
}
|
||||
|
||||
let scalar = ScalarType::from_str_name(s)?;
|
||||
Some(Self::scalar(scalar, nullable))
|
||||
}
|
||||
|
||||
pub fn scalar(scalar: ScalarType, nullable: bool) -> Self {
|
||||
Self {
|
||||
scalar,
|
||||
nullable,
|
||||
list: false,
|
||||
enum_values: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn list_of(scalar: ScalarType, nullable: bool) -> Self {
|
||||
Self {
|
||||
scalar,
|
||||
nullable,
|
||||
list: true,
|
||||
enum_values: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn enum_type(mut values: Vec<String>, nullable: bool) -> Self {
|
||||
values.sort();
|
||||
values.dedup();
|
||||
Self {
|
||||
scalar: ScalarType::String,
|
||||
nullable,
|
||||
list: false,
|
||||
enum_values: Some(values),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn is_enum(&self) -> bool {
|
||||
self.enum_values.is_some()
|
||||
}
|
||||
|
||||
pub fn to_arrow(&self) -> DataType {
|
||||
let scalar_dt = self.scalar.to_arrow();
|
||||
if self.list {
|
||||
DataType::List(std::sync::Arc::new(arrow_schema::Field::new(
|
||||
"item", scalar_dt, true,
|
||||
)))
|
||||
} else {
|
||||
scalar_dt
|
||||
}
|
||||
}
|
||||
|
||||
pub fn display_name(&self) -> String {
|
||||
let base = if let Some(values) = &self.enum_values {
|
||||
format!("enum({})", values.join(", "))
|
||||
} else {
|
||||
self.scalar.to_string()
|
||||
};
|
||||
let wrapped = if self.list {
|
||||
format!("[{}]", base)
|
||||
} else {
|
||||
base
|
||||
};
|
||||
if self.nullable {
|
||||
format!("{}?", wrapped)
|
||||
} else {
|
||||
wrapped
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Two-valued direction marker.
// NOTE(review): presumably the direction of an edge traversal — confirm
// against the callers, which are not visible here.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Direction {
    Out,
    In,
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;
    use arrow_schema::{DataType, Field};
    use std::sync::Arc;

    /// Vectors map to a fixed-size list of nullable `Float32` items.
    #[test]
    fn vector_to_arrow_uses_nullable_float32_child() {
        let expected =
            DataType::FixedSizeList(Arc::new(Field::new("item", DataType::Float32, true)), 4);

        assert_eq!(ScalarType::Vector(4).to_arrow(), expected);
    }

    /// Dimensions above `i32::MAX` are rejected; `i32::MAX` itself is accepted.
    #[test]
    fn scalar_type_from_str_name_rejects_vector_dimensions_outside_arrow_bounds() {
        let over_limit = format!("Vector({})", (i32::MAX as u64) + 1);

        assert!(ScalarType::from_str_name(&over_limit).is_none());
        assert_eq!(
            ScalarType::from_str_name("Vector(2147483647)"),
            Some(ScalarType::Vector(2147483647))
        );
    }

    /// List names and nullable scalars both parse into the expected types.
    #[test]
    fn prop_type_from_param_type_name_supports_lists_and_nullable_scalars() {
        let list = PropType::from_param_type_name("[DateTime]", false);
        let nullable_scalar = PropType::from_param_type_name("DateTime", true);

        assert_eq!(list, Some(PropType::list_of(ScalarType::DateTime, false)));
        assert_eq!(
            nullable_scalar,
            Some(PropType::scalar(ScalarType::DateTime, true))
        );
    }
}
|
||||
30
crates/omnigraph-server/Cargo.toml
Normal file
30
crates/omnigraph-server/Cargo.toml
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
[package]
|
||||
name = "omnigraph-server"
|
||||
version = "0.4.0"
|
||||
edition = "2024"
|
||||
description = "HTTP server for the Omnigraph graph database."
|
||||
license = "MIT"
|
||||
|
||||
[[bin]]
|
||||
name = "omnigraph-server"
|
||||
path = "src/main.rs"
|
||||
|
||||
[dependencies]
|
||||
omnigraph = { path = "../omnigraph", version = "0.4.0" }
|
||||
omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.4.0" }
|
||||
axum = { workspace = true }
|
||||
clap = { workspace = true }
|
||||
color-eyre = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
serde_yaml = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
tracing-subscriber = { workspace = true }
|
||||
tower-http = { workspace = true }
|
||||
cedar-policy = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
tempfile = { workspace = true }
|
||||
tower = { workspace = true }
|
||||
serial_test = "3"
|
||||
395
crates/omnigraph-server/src/api.rs
Normal file
395
crates/omnigraph-server/src/api.rs
Normal file
|
|
@ -0,0 +1,395 @@
|
|||
use omnigraph::db::{GraphCommit, MergeOutcome, ReadTarget, RunRecord, Snapshot};
|
||||
use omnigraph::error::{MergeConflict, MergeConflictKind};
|
||||
use omnigraph::loader::{IngestResult, LoadMode};
|
||||
use omnigraph_compiler::result::QueryResult;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SnapshotTableOutput {
|
||||
pub table_key: String,
|
||||
pub table_path: String,
|
||||
pub table_version: u64,
|
||||
pub table_branch: Option<String>,
|
||||
pub row_count: u64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct SnapshotOutput {
|
||||
pub branch: String,
|
||||
pub manifest_version: u64,
|
||||
pub tables: Vec<SnapshotTableOutput>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RunOutput {
|
||||
pub run_id: String,
|
||||
pub target_branch: String,
|
||||
pub run_branch: String,
|
||||
pub base_snapshot_id: String,
|
||||
pub base_manifest_version: u64,
|
||||
pub operation_hash: Option<String>,
|
||||
pub actor_id: Option<String>,
|
||||
pub status: String,
|
||||
pub published_snapshot_id: Option<String>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct RunListOutput {
|
||||
pub runs: Vec<RunOutput>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BranchCreateRequest {
|
||||
pub from: Option<String>,
|
||||
pub name: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BranchCreateOutput {
|
||||
pub uri: String,
|
||||
pub from: String,
|
||||
pub name: String,
|
||||
pub actor_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BranchListOutput {
|
||||
pub branches: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BranchDeleteOutput {
|
||||
pub uri: String,
|
||||
pub name: String,
|
||||
pub actor_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BranchMergeRequest {
|
||||
pub source: String,
|
||||
pub target: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum BranchMergeOutcome {
|
||||
AlreadyUpToDate,
|
||||
FastForward,
|
||||
Merged,
|
||||
}
|
||||
|
||||
impl From<MergeOutcome> for BranchMergeOutcome {
|
||||
fn from(value: MergeOutcome) -> Self {
|
||||
match value {
|
||||
MergeOutcome::AlreadyUpToDate => Self::AlreadyUpToDate,
|
||||
MergeOutcome::FastForward => Self::FastForward,
|
||||
MergeOutcome::Merged => Self::Merged,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl BranchMergeOutcome {
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::AlreadyUpToDate => "already_up_to_date",
|
||||
Self::FastForward => "fast_forward",
|
||||
Self::Merged => "merged",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct BranchMergeOutput {
|
||||
pub source: String,
|
||||
pub target: String,
|
||||
pub outcome: BranchMergeOutcome,
|
||||
pub actor_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum MergeConflictKindOutput {
|
||||
DivergentInsert,
|
||||
DivergentUpdate,
|
||||
DeleteVsUpdate,
|
||||
OrphanEdge,
|
||||
UniqueViolation,
|
||||
CardinalityViolation,
|
||||
ValueConstraintViolation,
|
||||
}
|
||||
|
||||
impl MergeConflictKindOutput {
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::DivergentInsert => "divergent_insert",
|
||||
Self::DivergentUpdate => "divergent_update",
|
||||
Self::DeleteVsUpdate => "delete_vs_update",
|
||||
Self::OrphanEdge => "orphan_edge",
|
||||
Self::UniqueViolation => "unique_violation",
|
||||
Self::CardinalityViolation => "cardinality_violation",
|
||||
Self::ValueConstraintViolation => "value_constraint_violation",
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<MergeConflictKind> for MergeConflictKindOutput {
|
||||
fn from(value: MergeConflictKind) -> Self {
|
||||
match value {
|
||||
MergeConflictKind::DivergentInsert => Self::DivergentInsert,
|
||||
MergeConflictKind::DivergentUpdate => Self::DivergentUpdate,
|
||||
MergeConflictKind::DeleteVsUpdate => Self::DeleteVsUpdate,
|
||||
MergeConflictKind::OrphanEdge => Self::OrphanEdge,
|
||||
MergeConflictKind::UniqueViolation => Self::UniqueViolation,
|
||||
MergeConflictKind::CardinalityViolation => Self::CardinalityViolation,
|
||||
MergeConflictKind::ValueConstraintViolation => Self::ValueConstraintViolation,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct MergeConflictOutput {
|
||||
pub table_key: String,
|
||||
pub row_id: Option<String>,
|
||||
pub kind: MergeConflictKindOutput,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl From<&MergeConflict> for MergeConflictOutput {
|
||||
fn from(value: &MergeConflict) -> Self {
|
||||
Self {
|
||||
table_key: value.table_key.clone(),
|
||||
row_id: value.row_id.clone(),
|
||||
kind: value.kind.into(),
|
||||
message: value.message.clone(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadTargetOutput {
|
||||
pub branch: Option<String>,
|
||||
pub snapshot: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadOutput {
|
||||
pub query_name: String,
|
||||
pub target: ReadTargetOutput,
|
||||
pub row_count: usize,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub columns: Vec<String>,
|
||||
pub rows: Value,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChangeOutput {
|
||||
pub branch: String,
|
||||
pub query_name: String,
|
||||
pub affected_nodes: usize,
|
||||
pub affected_edges: usize,
|
||||
pub actor_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IngestTableOutput {
|
||||
pub table_key: String,
|
||||
pub rows_loaded: usize,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IngestOutput {
|
||||
pub uri: String,
|
||||
pub branch: String,
|
||||
pub base_branch: String,
|
||||
pub branch_created: bool,
|
||||
pub mode: LoadMode,
|
||||
pub tables: Vec<IngestTableOutput>,
|
||||
pub actor_id: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CommitOutput {
|
||||
pub graph_commit_id: String,
|
||||
pub manifest_branch: Option<String>,
|
||||
pub manifest_version: u64,
|
||||
pub parent_commit_id: Option<String>,
|
||||
pub merged_parent_commit_id: Option<String>,
|
||||
pub actor_id: Option<String>,
|
||||
pub created_at: i64,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct CommitListOutput {
|
||||
pub commits: Vec<CommitOutput>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ReadRequest {
|
||||
pub query_source: String,
|
||||
pub query_name: Option<String>,
|
||||
pub params: Option<Value>,
|
||||
pub branch: Option<String>,
|
||||
pub snapshot: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ChangeRequest {
|
||||
pub query_source: String,
|
||||
pub query_name: Option<String>,
|
||||
pub params: Option<Value>,
|
||||
pub branch: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct IngestRequest {
|
||||
pub branch: Option<String>,
|
||||
pub from: Option<String>,
|
||||
pub mode: Option<LoadMode>,
|
||||
pub data: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ExportRequest {
|
||||
pub branch: Option<String>,
|
||||
#[serde(default)]
|
||||
pub type_names: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub table_keys: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct SnapshotQuery {
|
||||
pub branch: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Deserialize)]
|
||||
pub struct CommitListQuery {
|
||||
pub branch: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct HealthOutput {
|
||||
pub status: String,
|
||||
pub version: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub source_version: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ErrorCode {
|
||||
Unauthorized,
|
||||
Forbidden,
|
||||
BadRequest,
|
||||
NotFound,
|
||||
Conflict,
|
||||
Internal,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct ErrorOutput {
|
||||
pub error: String,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub code: Option<ErrorCode>,
|
||||
#[serde(default, skip_serializing_if = "Vec::is_empty")]
|
||||
pub merge_conflicts: Vec<MergeConflictOutput>,
|
||||
}
|
||||
|
||||
pub fn snapshot_payload(branch: &str, snapshot: &Snapshot) -> SnapshotOutput {
|
||||
let mut entries: Vec<_> = snapshot.entries().cloned().collect();
|
||||
entries.sort_by(|a, b| a.table_key.cmp(&b.table_key));
|
||||
let tables = entries
|
||||
.iter()
|
||||
.map(|entry| SnapshotTableOutput {
|
||||
table_key: entry.table_key.clone(),
|
||||
table_path: entry.table_path.clone(),
|
||||
table_version: entry.table_version,
|
||||
table_branch: entry.table_branch.clone(),
|
||||
row_count: entry.row_count,
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
SnapshotOutput {
|
||||
branch: branch.to_string(),
|
||||
manifest_version: snapshot.version(),
|
||||
tables,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn run_output(run: &RunRecord) -> RunOutput {
|
||||
RunOutput {
|
||||
run_id: run.run_id.as_str().to_string(),
|
||||
target_branch: run.target_branch.clone(),
|
||||
run_branch: run.run_branch.clone(),
|
||||
base_snapshot_id: run.base_snapshot_id.as_str().to_string(),
|
||||
base_manifest_version: run.base_manifest_version,
|
||||
operation_hash: run.operation_hash.clone(),
|
||||
actor_id: run.actor_id.clone(),
|
||||
status: run.status.as_str().to_string(),
|
||||
published_snapshot_id: run.published_snapshot_id.clone(),
|
||||
created_at: run.created_at,
|
||||
updated_at: run.updated_at,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn commit_output(commit: &GraphCommit) -> CommitOutput {
|
||||
CommitOutput {
|
||||
graph_commit_id: commit.graph_commit_id.clone(),
|
||||
manifest_branch: commit.manifest_branch.clone(),
|
||||
manifest_version: commit.manifest_version,
|
||||
parent_commit_id: commit.parent_commit_id.clone(),
|
||||
merged_parent_commit_id: commit.merged_parent_commit_id.clone(),
|
||||
actor_id: commit.actor_id.clone(),
|
||||
created_at: commit.created_at,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_output(query_name: String, target: &ReadTarget, result: QueryResult) -> ReadOutput {
|
||||
let columns = result
|
||||
.schema()
|
||||
.fields()
|
||||
.iter()
|
||||
.map(|field| field.name().clone())
|
||||
.collect();
|
||||
ReadOutput {
|
||||
query_name,
|
||||
target: read_target_output(target),
|
||||
row_count: result.num_rows(),
|
||||
columns,
|
||||
rows: result.to_rust_json(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn ingest_output(uri: &str, result: &IngestResult, actor_id: Option<String>) -> IngestOutput {
|
||||
IngestOutput {
|
||||
uri: uri.to_string(),
|
||||
branch: result.branch.clone(),
|
||||
base_branch: result.base_branch.clone(),
|
||||
branch_created: result.branch_created,
|
||||
mode: result.mode,
|
||||
tables: result
|
||||
.tables
|
||||
.iter()
|
||||
.map(|table| IngestTableOutput {
|
||||
table_key: table.table_key.clone(),
|
||||
rows_loaded: table.rows_loaded,
|
||||
})
|
||||
.collect(),
|
||||
actor_id,
|
||||
}
|
||||
}
|
||||
|
||||
pub fn read_target_output(target: &ReadTarget) -> ReadTargetOutput {
|
||||
match target {
|
||||
ReadTarget::Branch(branch) => ReadTargetOutput {
|
||||
branch: Some(branch.clone()),
|
||||
snapshot: None,
|
||||
},
|
||||
ReadTarget::Snapshot(snapshot) => ReadTargetOutput {
|
||||
branch: None,
|
||||
snapshot: Some(snapshot.as_str().to_string()),
|
||||
},
|
||||
}
|
||||
}
|
||||
479
crates/omnigraph-server/src/config.rs
Normal file
479
crates/omnigraph-server/src/config.rs
Normal file
|
|
@ -0,0 +1,479 @@
|
|||
use std::collections::BTreeMap;
|
||||
use std::env;
|
||||
use std::fs;
|
||||
use std::path::{Path, PathBuf};
|
||||
|
||||
use clap::ValueEnum;
|
||||
use color_eyre::eyre::{Result, bail};
|
||||
use serde::{Deserialize, Serialize};
|
||||
/// File name of the configuration file looked up in the working directory when
/// no explicit config path is supplied.
pub const DEFAULT_CONFIG_FILE: &str = "omnigraph.yaml";
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ProjectConfig {
|
||||
pub name: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct TargetConfig {
|
||||
pub uri: String,
|
||||
pub bearer_token_env: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Serialize, Deserialize, ValueEnum)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ReadOutputFormat {
|
||||
#[default]
|
||||
Table,
|
||||
Kv,
|
||||
Csv,
|
||||
Jsonl,
|
||||
Json,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Default, Eq, PartialEq, Serialize, Deserialize, ValueEnum)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum TableCellLayout {
|
||||
#[default]
|
||||
Truncate,
|
||||
Wrap,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct CliDefaults {
|
||||
pub target: Option<String>,
|
||||
pub branch: Option<String>,
|
||||
pub output_format: Option<ReadOutputFormat>,
|
||||
pub table_max_column_width: Option<usize>,
|
||||
pub table_cell_layout: Option<TableCellLayout>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct ServerDefaults {
|
||||
pub target: Option<String>,
|
||||
pub bind: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct AuthDefaults {
|
||||
pub env_file: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct QueryDefaults {
|
||||
#[serde(default)]
|
||||
pub roots: Vec<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
|
||||
pub struct PolicySettings {
|
||||
pub file: Option<String>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum AliasCommand {
|
||||
Read,
|
||||
Change,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct AliasConfig {
|
||||
pub command: AliasCommand,
|
||||
pub query: String,
|
||||
pub name: Option<String>,
|
||||
#[serde(default)]
|
||||
pub args: Vec<String>,
|
||||
pub target: Option<String>,
|
||||
pub branch: Option<String>,
|
||||
pub format: Option<ReadOutputFormat>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct OmnigraphConfig {
|
||||
#[serde(default)]
|
||||
pub project: ProjectConfig,
|
||||
#[serde(default)]
|
||||
pub targets: BTreeMap<String, TargetConfig>,
|
||||
#[serde(default)]
|
||||
pub server: ServerDefaults,
|
||||
#[serde(default)]
|
||||
pub auth: AuthDefaults,
|
||||
#[serde(default)]
|
||||
pub cli: CliDefaults,
|
||||
#[serde(default)]
|
||||
pub query: QueryDefaults,
|
||||
#[serde(default)]
|
||||
pub aliases: BTreeMap<String, AliasConfig>,
|
||||
#[serde(default)]
|
||||
pub policy: PolicySettings,
|
||||
#[serde(skip)]
|
||||
base_dir: PathBuf,
|
||||
}
|
||||
|
||||
impl Default for OmnigraphConfig {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
project: ProjectConfig::default(),
|
||||
targets: BTreeMap::new(),
|
||||
server: ServerDefaults::default(),
|
||||
auth: AuthDefaults::default(),
|
||||
cli: CliDefaults::default(),
|
||||
query: QueryDefaults::default(),
|
||||
aliases: BTreeMap::new(),
|
||||
policy: PolicySettings::default(),
|
||||
base_dir: PathBuf::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl OmnigraphConfig {
|
||||
pub fn base_dir(&self) -> &Path {
|
||||
&self.base_dir
|
||||
}
|
||||
|
||||
pub fn cli_branch(&self) -> &str {
|
||||
self.cli.branch.as_deref().unwrap_or("main")
|
||||
}
|
||||
|
||||
pub fn cli_output_format(&self) -> ReadOutputFormat {
|
||||
self.cli.output_format.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn table_max_column_width(&self) -> usize {
|
||||
self.cli.table_max_column_width.unwrap_or(80)
|
||||
}
|
||||
|
||||
pub fn table_cell_layout(&self) -> TableCellLayout {
|
||||
self.cli.table_cell_layout.unwrap_or_default()
|
||||
}
|
||||
|
||||
pub fn cli_target_name(&self) -> Option<&str> {
|
||||
self.cli.target.as_deref()
|
||||
}
|
||||
|
||||
pub fn server_target_name(&self) -> Option<&str> {
|
||||
self.server.target.as_deref()
|
||||
}
|
||||
|
||||
pub fn server_bind(&self) -> &str {
|
||||
self.server.bind.as_deref().unwrap_or("127.0.0.1:8080")
|
||||
}
|
||||
|
||||
pub fn resolve_target_name<'a>(
|
||||
&self,
|
||||
explicit_uri: Option<&str>,
|
||||
explicit_target: Option<&'a str>,
|
||||
default_target: Option<&'a str>,
|
||||
) -> Option<&'a str> {
|
||||
explicit_target.or_else(|| {
|
||||
if explicit_uri.is_some() {
|
||||
None
|
||||
} else {
|
||||
default_target
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn target_bearer_token_env(
|
||||
&self,
|
||||
explicit_uri: Option<&str>,
|
||||
explicit_target: Option<&str>,
|
||||
default_target: Option<&str>,
|
||||
) -> Option<&str> {
|
||||
let target_name =
|
||||
self.resolve_target_name(explicit_uri, explicit_target, default_target)?;
|
||||
self.targets
|
||||
.get(target_name)
|
||||
.and_then(|target| target.bearer_token_env.as_deref())
|
||||
}
|
||||
|
||||
pub fn resolve_auth_env_file(&self) -> Option<PathBuf> {
|
||||
let path = self.auth.env_file.as_deref()?;
|
||||
let path = Path::new(path);
|
||||
Some(if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
self.base_dir.join(path)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn resolve_policy_file(&self) -> Option<PathBuf> {
|
||||
let path = self.policy.file.as_deref()?;
|
||||
let path = Path::new(path);
|
||||
Some(if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
self.base_dir.join(path)
|
||||
})
|
||||
}
|
||||
|
||||
pub fn resolve_policy_tests_file(&self) -> Option<PathBuf> {
|
||||
let policy_file = self.resolve_policy_file()?;
|
||||
Some(policy_file.with_file_name("policy.tests.yaml"))
|
||||
}
|
||||
|
||||
pub fn alias(&self, name: &str) -> Result<&AliasConfig> {
|
||||
self.aliases
|
||||
.get(name)
|
||||
.ok_or_else(|| color_eyre::eyre::eyre!("alias '{}' not found", name))
|
||||
}
|
||||
|
||||
pub fn resolve_target_uri(
|
||||
&self,
|
||||
explicit_uri: Option<String>,
|
||||
explicit_target: Option<&str>,
|
||||
default_target: Option<&str>,
|
||||
) -> Result<String> {
|
||||
if let Some(uri) = explicit_uri {
|
||||
return Ok(uri);
|
||||
}
|
||||
|
||||
let target_name = explicit_target.or(default_target).ok_or_else(|| {
|
||||
color_eyre::eyre::eyre!("URI must be provided via <URI>, --target, or config")
|
||||
})?;
|
||||
let target = self.targets.get(target_name).ok_or_else(|| {
|
||||
color_eyre::eyre::eyre!(
|
||||
"target '{}' not found in {}",
|
||||
target_name,
|
||||
DEFAULT_CONFIG_FILE
|
||||
)
|
||||
})?;
|
||||
Ok(self.resolve_config_uri(&target.uri))
|
||||
}
|
||||
|
||||
pub fn resolve_query_path(&self, query: &Path) -> Result<PathBuf> {
|
||||
if query.is_absolute() {
|
||||
return Ok(query.to_path_buf());
|
||||
}
|
||||
|
||||
let direct = self.base_dir.join(query);
|
||||
if direct.exists() {
|
||||
return Ok(direct);
|
||||
}
|
||||
|
||||
for root in &self.query.roots {
|
||||
let candidate = self.base_dir.join(root).join(query);
|
||||
if candidate.exists() {
|
||||
return Ok(candidate);
|
||||
}
|
||||
}
|
||||
|
||||
bail!("query file '{}' not found", query.display());
|
||||
}
|
||||
|
||||
fn resolve_config_uri(&self, value: &str) -> String {
|
||||
if value.contains("://") {
|
||||
return value.to_string();
|
||||
}
|
||||
|
||||
let path = Path::new(value);
|
||||
if path.is_absolute() {
|
||||
value.to_string()
|
||||
} else {
|
||||
self.base_dir.join(path).to_string_lossy().to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn default_config_path() -> PathBuf {
|
||||
PathBuf::from(DEFAULT_CONFIG_FILE)
|
||||
}
|
||||
|
||||
pub fn load_config(config_path: Option<&PathBuf>) -> Result<OmnigraphConfig> {
|
||||
load_config_in(&env::current_dir()?, config_path)
|
||||
}
|
||||
|
||||
fn load_config_in(cwd: &Path, config_path: Option<&PathBuf>) -> Result<OmnigraphConfig> {
|
||||
let explicit_path = config_path.cloned();
|
||||
let config_path = explicit_path.or_else(|| {
|
||||
let default_path = cwd.join(DEFAULT_CONFIG_FILE);
|
||||
default_path.exists().then_some(default_path)
|
||||
});
|
||||
|
||||
let mut config = if let Some(path) = &config_path {
|
||||
serde_yaml::from_str::<OmnigraphConfig>(&fs::read_to_string(path)?)?
|
||||
} else {
|
||||
OmnigraphConfig::default()
|
||||
};
|
||||
|
||||
config.base_dir = if let Some(path) = config_path {
|
||||
absolute_base_dir(cwd, &path)?
|
||||
} else {
|
||||
cwd.to_path_buf()
|
||||
};
|
||||
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
fn absolute_base_dir(cwd: &Path, path: &Path) -> Result<PathBuf> {
|
||||
let path = if path.is_absolute() {
|
||||
path.to_path_buf()
|
||||
} else {
|
||||
cwd.join(path)
|
||||
};
|
||||
Ok(path
|
||||
.parent()
|
||||
.map(Path::to_path_buf)
|
||||
.unwrap_or_else(|| cwd.to_path_buf()))
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    //! Tests for config loading and path/target resolution. The YAML fixtures
    //! rely on indentation for nesting (`targets.<name>.uri`, `cli.target`, ...).

    use std::fs;
    use std::path::{Path, PathBuf};

    use tempfile::tempdir;

    use super::{ReadOutputFormat, TableCellLayout, load_config_in};

    /// A config in the working directory is picked up and every section is honoured.
    #[test]
    fn load_config_reads_yaml_defaults_from_current_dir() {
        let temp = tempdir().unwrap();
        fs::write(
            temp.path().join("omnigraph.yaml"),
            r#"
targets:
  local:
    uri: ./demo.omni
    bearer_token_env: DEMO_TOKEN
auth:
  env_file: .env.omni
cli:
  target: local
  branch: main
  output_format: kv
  table_max_column_width: 40
  table_cell_layout: wrap
policy: {}
"#,
        )
        .unwrap();

        let config = load_config_in(temp.path(), None).unwrap();
        assert_eq!(config.cli_target_name(), Some("local"));
        assert_eq!(config.cli_branch(), "main");
        assert_eq!(config.cli_output_format(), ReadOutputFormat::Kv);
        assert_eq!(config.table_max_column_width(), 40);
        assert_eq!(config.table_cell_layout(), TableCellLayout::Wrap);
        assert_eq!(
            config.target_bearer_token_env(None, None, config.cli_target_name()),
            Some("DEMO_TOKEN")
        );
        assert_eq!(
            config.resolve_auth_env_file().unwrap(),
            temp.path().join(".env.omni")
        );
        assert_eq!(
            PathBuf::from(
                config
                    .resolve_target_uri(None, None, config.cli_target_name())
                    .unwrap()
            ),
            temp.path().join("./demo.omni")
        );
    }

    /// Only the working directory itself is searched for the default config file.
    #[test]
    fn load_config_does_not_walk_parent_directories() {
        let temp = tempdir().unwrap();
        let child = temp.path().join("child");
        fs::create_dir_all(&child).unwrap();
        fs::write(
            temp.path().join("omnigraph.yaml"),
            "targets:\n  local:\n    uri: ./demo.omni\n",
        )
        .unwrap();

        let config = load_config_in(&child, None).unwrap();
        assert!(config.targets.is_empty());
    }

    /// Query files missing from the base directory are found under configured roots.
    #[test]
    fn resolve_query_path_searches_config_roots() {
        let temp = tempdir().unwrap();
        fs::create_dir_all(temp.path().join("queries")).unwrap();
        fs::write(
            temp.path().join("omnigraph.yaml"),
            "query:\n  roots:\n    - queries\npolicy: {}\n",
        )
        .unwrap();
        fs::write(
            temp.path().join("queries").join("test.gq"),
            "query q { return {} }",
        )
        .unwrap();

        let config = load_config_in(temp.path(), None).unwrap();
        let resolved = config.resolve_query_path(Path::new("test.gq")).unwrap();
        assert_eq!(resolved, temp.path().join("queries").join("test.gq"));
    }

    /// Relative query paths resolve against the config file's directory, not the cwd.
    #[test]
    fn resolve_query_path_prefers_config_base_dir_over_ambient_cwd() {
        let workspace = tempdir().unwrap();
        let config_dir = workspace.path().join("config");
        let ambient_dir = workspace.path().join("ambient");
        fs::create_dir_all(&config_dir).unwrap();
        fs::create_dir_all(&ambient_dir).unwrap();
        fs::write(config_dir.join("omnigraph.yaml"), "policy: {}\n").unwrap();
        fs::write(config_dir.join("local.gq"), "query local { return {} }").unwrap();
        fs::write(ambient_dir.join("local.gq"), "query ambient { return {} }").unwrap();

        let config =
            load_config_in(&ambient_dir, Some(&config_dir.join("omnigraph.yaml"))).unwrap();
        let resolved = config.resolve_query_path(Path::new("local.gq")).unwrap();

        assert_eq!(resolved, config_dir.join("local.gq"));
    }

    /// A `policy` block with a `file` entry resolves relative to the config directory.
    #[test]
    fn policy_block_accepts_non_empty_mapping() {
        let temp = tempdir().unwrap();
        fs::write(
            temp.path().join("omnigraph.yaml"),
            "policy:\n  file: ./policy.yaml\n",
        )
        .unwrap();

        let config = load_config_in(temp.path(), None).unwrap();
        assert_eq!(
            config.resolve_policy_file().unwrap(),
            temp.path().join("policy.yaml")
        );
    }

    /// An explicit URI suppresses the default target's token env; an explicit target does not.
    #[test]
    fn scoped_auth_env_ignores_default_target_when_uri_is_explicit() {
        let temp = tempdir().unwrap();
        fs::write(
            temp.path().join("omnigraph.yaml"),
            r#"
targets:
  demo:
    uri: https://example.com
    bearer_token_env: DEMO_TOKEN
cli:
  target: demo
"#,
        )
        .unwrap();

        let config = load_config_in(temp.path(), None).unwrap();
        assert_eq!(
            config.target_bearer_token_env(
                Some("https://override.example.com"),
                None,
                config.cli_target_name()
            ),
            None
        );
        assert_eq!(
            config.target_bearer_token_env(
                Some("https://override.example.com"),
                Some("demo"),
                config.cli_target_name()
            ),
            Some("DEMO_TOKEN")
        );
    }
}
|
||||
1257
crates/omnigraph-server/src/lib.rs
Normal file
1257
crates/omnigraph-server/src/lib.rs
Normal file
File diff suppressed because it is too large
Load diff
30
crates/omnigraph-server/src/main.rs
Normal file
30
crates/omnigraph-server/src/main.rs
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
use std::path::PathBuf;
|
||||
|
||||
use clap::Parser;
|
||||
use color_eyre::eyre::Result;
|
||||
use omnigraph_server::{ServerConfig, init_tracing, load_server_settings, serve};
|
||||
|
||||
#[derive(Debug, Parser)]
|
||||
#[command(name = "omnigraph-server")]
|
||||
#[command(about = "HTTP server for the Omnigraph graph database")]
|
||||
struct Cli {
|
||||
/// Repo URI
|
||||
uri: Option<String>,
|
||||
#[arg(long)]
|
||||
target: Option<String>,
|
||||
#[arg(long)]
|
||||
config: Option<PathBuf>,
|
||||
#[arg(long)]
|
||||
bind: Option<String>,
|
||||
}
|
||||
|
||||
#[tokio::main]
|
||||
async fn main() -> Result<()> {
|
||||
color_eyre::install()?;
|
||||
init_tracing();
|
||||
|
||||
let cli = Cli::parse();
|
||||
let settings: ServerConfig =
|
||||
load_server_settings(cli.config.as_ref(), cli.uri, cli.target, cli.bind)?;
|
||||
serve(settings).await
|
||||
}
|
||||
812
crates/omnigraph-server/src/policy.rs
Normal file
812
crates/omnigraph-server/src/policy.rs
Normal file
|
|
@ -0,0 +1,812 @@
|
|||
use std::collections::{BTreeMap, BTreeSet, HashMap, HashSet};
|
||||
use std::fmt;
|
||||
use std::fs;
|
||||
use std::path::Path;
|
||||
use std::str::FromStr;
|
||||
|
||||
use cedar_policy::{
|
||||
Authorizer, Context, Decision, Entities, Entity, EntityId, EntityTypeName, EntityUid, Policy,
|
||||
PolicyId, PolicySet, Request, Schema, ValidationMode, Validator,
|
||||
};
|
||||
use clap::ValueEnum;
|
||||
use color_eyre::eyre::{Result, bail, eyre};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize, ValueEnum)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PolicyAction {
|
||||
Read,
|
||||
Export,
|
||||
Change,
|
||||
BranchCreate,
|
||||
BranchDelete,
|
||||
BranchMerge,
|
||||
RunPublish,
|
||||
RunAbort,
|
||||
Admin,
|
||||
}
|
||||
|
||||
impl PolicyAction {
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
Self::Read => "read",
|
||||
Self::Export => "export",
|
||||
Self::Change => "change",
|
||||
Self::BranchCreate => "branch_create",
|
||||
Self::BranchDelete => "branch_delete",
|
||||
Self::BranchMerge => "branch_merge",
|
||||
Self::RunPublish => "run_publish",
|
||||
Self::RunAbort => "run_abort",
|
||||
Self::Admin => "admin",
|
||||
}
|
||||
}
|
||||
|
||||
fn uses_branch_scope(self) -> bool {
|
||||
matches!(self, Self::Read | Self::Export | Self::Change)
|
||||
}
|
||||
|
||||
fn uses_target_branch_scope(self) -> bool {
|
||||
matches!(
|
||||
self,
|
||||
Self::BranchCreate
|
||||
| Self::BranchDelete
|
||||
| Self::BranchMerge
|
||||
| Self::RunPublish
|
||||
| Self::RunAbort
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl fmt::Display for PolicyAction {
|
||||
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||
f.write_str(self.as_str())
|
||||
}
|
||||
}
|
||||
|
||||
impl FromStr for PolicyAction {
|
||||
type Err = color_eyre::eyre::Error;
|
||||
|
||||
fn from_str(value: &str) -> Result<Self> {
|
||||
match value.trim() {
|
||||
"read" => Ok(Self::Read),
|
||||
"export" => Ok(Self::Export),
|
||||
"change" => Ok(Self::Change),
|
||||
"branch_create" => Ok(Self::BranchCreate),
|
||||
"branch_delete" => Ok(Self::BranchDelete),
|
||||
"branch_merge" => Ok(Self::BranchMerge),
|
||||
"run_publish" => Ok(Self::RunPublish),
|
||||
"run_abort" => Ok(Self::RunAbort),
|
||||
"admin" => Ok(Self::Admin),
|
||||
other => bail!("unknown policy action '{other}'"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PolicyBranchScope {
|
||||
Any,
|
||||
Protected,
|
||||
Unprotected,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PolicyActorSelector {
|
||||
pub group: String,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PolicyAllowRule {
|
||||
pub actors: PolicyActorSelector,
|
||||
pub actions: Vec<PolicyAction>,
|
||||
pub branch_scope: Option<PolicyBranchScope>,
|
||||
pub target_branch_scope: Option<PolicyBranchScope>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PolicyRule {
|
||||
pub id: String,
|
||||
pub allow: PolicyAllowRule,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PolicyConfig {
|
||||
pub version: u32,
|
||||
#[serde(default)]
|
||||
pub groups: BTreeMap<String, Vec<String>>,
|
||||
#[serde(default)]
|
||||
pub protected_branches: Vec<String>,
|
||||
#[serde(default)]
|
||||
pub rules: Vec<PolicyRule>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PolicyTestConfig {
|
||||
pub version: u32,
|
||||
#[serde(default)]
|
||||
pub cases: Vec<PolicyTestCase>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||
pub struct PolicyTestCase {
|
||||
pub id: String,
|
||||
pub actor: String,
|
||||
pub action: PolicyAction,
|
||||
pub branch: Option<String>,
|
||||
pub target_branch: Option<String>,
|
||||
pub expect: PolicyExpectation,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum PolicyExpectation {
|
||||
Allow,
|
||||
Deny,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct PolicyRequest {
|
||||
pub actor_id: String,
|
||||
pub action: PolicyAction,
|
||||
pub branch: Option<String>,
|
||||
pub target_branch: Option<String>,
|
||||
}
|
||||
|
||||
/// Outcome of evaluating a [`PolicyRequest`].
#[derive(Debug, Clone)]
pub struct PolicyDecision {
    /// Whether the request is allowed.
    pub allowed: bool,
    /// Id of the rule that matched, when there is one.
    pub matched_rule_id: Option<String>,
    /// Human-readable explanation of the decision.
    pub message: String,
}
|
||||
|
||||
/// Field-less marker type; any behavior is attached through `impl` blocks not shown here.
/// NOTE(review): presumably compiles a `PolicyConfig` into Cedar policies — confirm against its impl.
pub struct PolicyCompiler;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct PolicyEngine {
|
||||
repo_id: String,
|
||||
protected_branches: BTreeSet<String>,
|
||||
known_actors: BTreeSet<String>,
|
||||
schema: Schema,
|
||||
entities: Entities,
|
||||
policies: PolicySet,
|
||||
policy_to_rule: HashMap<String, String>,
|
||||
}
|
||||
|
||||
impl PolicyConfig {
|
||||
pub fn load(path: &Path) -> Result<Self> {
|
||||
let config: Self = serde_yaml::from_str(&fs::read_to_string(path)?)?;
|
||||
config.validate()?;
|
||||
Ok(config)
|
||||
}
|
||||
|
||||
pub fn validate(&self) -> Result<()> {
|
||||
if self.version != 1 {
|
||||
bail!("policy version must be 1");
|
||||
}
|
||||
|
||||
for (group, members) in &self.groups {
|
||||
if group.trim().is_empty() {
|
||||
bail!("policy group names must not be blank");
|
||||
}
|
||||
if members.is_empty() {
|
||||
bail!("policy group '{group}' must not be empty");
|
||||
}
|
||||
for actor in members {
|
||||
if actor.trim().is_empty() {
|
||||
bail!("policy group '{group}' contains a blank actor id");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for branch in &self.protected_branches {
|
||||
if branch.trim().is_empty() {
|
||||
bail!("protected branch names must not be blank");
|
||||
}
|
||||
}
|
||||
|
||||
let mut seen_rule_ids = HashSet::new();
|
||||
for rule in &self.rules {
|
||||
if rule.id.trim().is_empty() {
|
||||
bail!("policy rule ids must not be blank");
|
||||
}
|
||||
if !seen_rule_ids.insert(rule.id.clone()) {
|
||||
bail!("duplicate policy rule id '{}'", rule.id);
|
||||
}
|
||||
if rule.allow.actors.group.trim().is_empty() {
|
||||
bail!("policy rule '{}' must reference a non-blank group", rule.id);
|
||||
}
|
||||
if !self.groups.contains_key(rule.allow.actors.group.as_str()) {
|
||||
bail!(
|
||||
"policy rule '{}' references unknown group '{}'",
|
||||
rule.id,
|
||||
rule.allow.actors.group
|
||||
);
|
||||
}
|
||||
if rule.allow.actions.is_empty() {
|
||||
bail!("policy rule '{}' must include at least one action", rule.id);
|
||||
}
|
||||
if rule.allow.branch_scope.is_some() && rule.allow.target_branch_scope.is_some() {
|
||||
bail!(
|
||||
"policy rule '{}' may specify branch_scope or target_branch_scope, not both",
|
||||
rule.id
|
||||
);
|
||||
}
|
||||
if let Some(_) = rule.allow.branch_scope {
|
||||
for action in &rule.allow.actions {
|
||||
if !action.uses_branch_scope() {
|
||||
bail!(
|
||||
"policy rule '{}' uses branch_scope with unsupported action '{}'",
|
||||
rule.id,
|
||||
action
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
if let Some(_) = rule.allow.target_branch_scope {
|
||||
for action in &rule.allow.actions {
|
||||
if !action.uses_target_branch_scope() {
|
||||
bail!(
|
||||
"policy rule '{}' uses target_branch_scope with unsupported action '{}'",
|
||||
rule.id,
|
||||
action
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyTestConfig {
|
||||
pub fn load(path: &Path) -> Result<Self> {
|
||||
let config: Self = serde_yaml::from_str(&fs::read_to_string(path)?)?;
|
||||
if config.version != 1 {
|
||||
bail!("policy test version must be 1");
|
||||
}
|
||||
let mut seen = HashSet::new();
|
||||
for case in &config.cases {
|
||||
if case.id.trim().is_empty() {
|
||||
bail!("policy test case ids must not be blank");
|
||||
}
|
||||
if !seen.insert(case.id.clone()) {
|
||||
bail!("duplicate policy test case id '{}'", case.id);
|
||||
}
|
||||
if case.actor.trim().is_empty() {
|
||||
bail!("policy test case '{}' must not use a blank actor", case.id);
|
||||
}
|
||||
}
|
||||
Ok(config)
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyCompiler {
|
||||
pub fn compile(config: &PolicyConfig, repo_id: &str) -> Result<PolicyEngine> {
|
||||
config.validate()?;
|
||||
let (schema, schema_warnings) = Schema::from_cedarschema_str(policy_schema_source())?;
|
||||
let schema_warnings = schema_warnings
|
||||
.map(|warning| warning.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
if !schema_warnings.is_empty() {
|
||||
bail!("policy schema warnings:\n{}", schema_warnings.join("\n"));
|
||||
}
|
||||
let entities = compile_entities(config, repo_id, &schema)?;
|
||||
let (policies, policy_to_rule) = compile_policies(config, repo_id)?;
|
||||
let validator = Validator::new(schema.clone());
|
||||
let validation = validator.validate(&policies, ValidationMode::Strict);
|
||||
let errors = validation
|
||||
.validation_errors()
|
||||
.map(|err| err.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
if !errors.is_empty() {
|
||||
bail!("policy validation failed:\n{}", errors.join("\n"));
|
||||
}
|
||||
|
||||
let known_actors = config
|
||||
.groups
|
||||
.values()
|
||||
.flat_map(|members| members.iter().cloned())
|
||||
.collect();
|
||||
Ok(PolicyEngine {
|
||||
repo_id: repo_id.to_string(),
|
||||
protected_branches: config.protected_branches.iter().cloned().collect(),
|
||||
known_actors,
|
||||
schema,
|
||||
entities,
|
||||
policies,
|
||||
policy_to_rule,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl PolicyEngine {
|
||||
pub fn load(path: &Path, repo_id: &str) -> Result<Self> {
|
||||
let config = PolicyConfig::load(path)?;
|
||||
PolicyCompiler::compile(&config, repo_id)
|
||||
}
|
||||
|
||||
pub fn authorize(&self, request: &PolicyRequest) -> Result<PolicyDecision> {
|
||||
if !self.known_actors.contains(request.actor_id.as_str()) {
|
||||
return Ok(self.deny(
|
||||
request,
|
||||
None,
|
||||
format!(
|
||||
"policy denied action '{}' for unknown actor '{}'",
|
||||
request.action, request.actor_id
|
||||
),
|
||||
));
|
||||
}
|
||||
|
||||
let principal = entity_uid("Actor", &request.actor_id)?;
|
||||
let action = entity_uid("Action", request.action.as_str())?;
|
||||
let resource = entity_uid("Repo", &self.repo_id)?;
|
||||
let context_value = json!({
|
||||
"has_branch": request.branch.is_some(),
|
||||
"branch": request.branch.clone().unwrap_or_default(),
|
||||
"has_target_branch": request.target_branch.is_some(),
|
||||
"target_branch": request.target_branch.clone().unwrap_or_default(),
|
||||
"branch_is_protected": request.branch.as_ref().is_some_and(|branch| self.protected_branches.contains(branch)),
|
||||
"target_branch_is_protected": request.target_branch.as_ref().is_some_and(|branch| self.protected_branches.contains(branch)),
|
||||
});
|
||||
let context = Context::from_json_value(context_value, Some((&self.schema, &action)))?;
|
||||
let cedar_request = Request::new(principal, action, resource, context, Some(&self.schema))?;
|
||||
let response =
|
||||
Authorizer::new().is_authorized(&cedar_request, &self.policies, &self.entities);
|
||||
let errors = response
|
||||
.diagnostics()
|
||||
.errors()
|
||||
.map(|err| err.to_string())
|
||||
.collect::<Vec<_>>();
|
||||
if !errors.is_empty() {
|
||||
bail!("policy evaluation failed:\n{}", errors.join("\n"));
|
||||
}
|
||||
|
||||
let matched_rule_id = response
|
||||
.diagnostics()
|
||||
.reason()
|
||||
.filter_map(|policy_id| {
|
||||
let key: &str = policy_id.as_ref();
|
||||
self.policy_to_rule.get(key).cloned()
|
||||
})
|
||||
.min();
|
||||
|
||||
Ok(match response.decision() {
|
||||
Decision::Allow => PolicyDecision {
|
||||
allowed: true,
|
||||
matched_rule_id: matched_rule_id.clone(),
|
||||
message: format!(
|
||||
"policy allowed action '{}' for actor '{}'",
|
||||
request.action, request.actor_id
|
||||
),
|
||||
},
|
||||
Decision::Deny => {
|
||||
let message = format!(
|
||||
"policy denied action '{}'{}{} for actor '{}'",
|
||||
request.action,
|
||||
request
|
||||
.branch
|
||||
.as_deref()
|
||||
.map(|branch| format!(" on branch '{}'", branch))
|
||||
.unwrap_or_default(),
|
||||
request
|
||||
.target_branch
|
||||
.as_deref()
|
||||
.map(|branch| format!(" targeting branch '{}'", branch))
|
||||
.unwrap_or_default(),
|
||||
request.actor_id
|
||||
);
|
||||
self.deny(request, matched_rule_id, message)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
pub fn validate_request(&self, request: &PolicyRequest) -> Result<()> {
|
||||
let _ = self.authorize(request)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn run_tests(&self, tests: &PolicyTestConfig) -> Result<()> {
|
||||
if tests.version != 1 {
|
||||
bail!("policy test version must be 1");
|
||||
}
|
||||
let mut failures = Vec::new();
|
||||
for case in &tests.cases {
|
||||
let decision = self.authorize(&PolicyRequest {
|
||||
actor_id: case.actor.clone(),
|
||||
action: case.action,
|
||||
branch: case.branch.clone(),
|
||||
target_branch: case.target_branch.clone(),
|
||||
})?;
|
||||
let expected_allowed = matches!(case.expect, PolicyExpectation::Allow);
|
||||
if decision.allowed != expected_allowed {
|
||||
failures.push(format!(
|
||||
"{}: expected {:?} but got {}",
|
||||
case.id,
|
||||
case.expect,
|
||||
if decision.allowed { "allow" } else { "deny" }
|
||||
));
|
||||
}
|
||||
}
|
||||
if failures.is_empty() {
|
||||
Ok(())
|
||||
} else {
|
||||
bail!("policy tests failed:\n{}", failures.join("\n"))
|
||||
}
|
||||
}
|
||||
|
||||
pub fn known_actor_count(&self) -> usize {
|
||||
self.known_actors.len()
|
||||
}
|
||||
|
||||
fn deny(
|
||||
&self,
|
||||
_request: &PolicyRequest,
|
||||
matched_rule_id: Option<String>,
|
||||
message: String,
|
||||
) -> PolicyDecision {
|
||||
PolicyDecision {
|
||||
allowed: false,
|
||||
matched_rule_id,
|
||||
message,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn compile_entities(config: &PolicyConfig, repo_id: &str, schema: &Schema) -> Result<Entities> {
|
||||
let mut group_entities = Vec::new();
|
||||
for group in config.groups.keys() {
|
||||
group_entities.push(Entity::new(
|
||||
entity_uid("Group", group)?,
|
||||
HashMap::new(),
|
||||
HashSet::<EntityUid>::new(),
|
||||
)?);
|
||||
}
|
||||
|
||||
let mut actor_groups: BTreeMap<String, BTreeSet<String>> = BTreeMap::new();
|
||||
for (group, members) in &config.groups {
|
||||
for actor in members {
|
||||
actor_groups
|
||||
.entry(actor.clone())
|
||||
.or_default()
|
||||
.insert(group.clone());
|
||||
}
|
||||
}
|
||||
|
||||
let mut actor_entities = Vec::new();
|
||||
for (actor, groups) in actor_groups {
|
||||
let parents = groups
|
||||
.iter()
|
||||
.map(|group| entity_uid("Group", group))
|
||||
.collect::<Result<HashSet<_>>>()?;
|
||||
actor_entities.push(Entity::new(
|
||||
entity_uid("Actor", &actor)?,
|
||||
HashMap::new(),
|
||||
parents,
|
||||
)?);
|
||||
}
|
||||
|
||||
let repo_entity = Entity::new(
|
||||
entity_uid("Repo", repo_id)?,
|
||||
HashMap::new(),
|
||||
HashSet::<EntityUid>::new(),
|
||||
)?;
|
||||
|
||||
let mut entities = Vec::new();
|
||||
entities.extend(group_entities);
|
||||
entities.extend(actor_entities);
|
||||
entities.push(repo_entity);
|
||||
Ok(Entities::from_entities(entities, Some(schema))?)
|
||||
}
|
||||
|
||||
fn compile_policies(
|
||||
config: &PolicyConfig,
|
||||
repo_id: &str,
|
||||
) -> Result<(PolicySet, HashMap<String, String>)> {
|
||||
let mut policies = Vec::new();
|
||||
let mut policy_to_rule = HashMap::new();
|
||||
|
||||
for rule in &config.rules {
|
||||
for action in &rule.allow.actions {
|
||||
let policy_id = PolicyId::new(format!("{}:{}", rule.id, action.as_str()));
|
||||
let source = compile_policy_source(rule, action, repo_id);
|
||||
let policy = Policy::parse(Some(policy_id.clone()), source.as_str())?;
|
||||
policy_to_rule.insert(policy_id.to_string(), rule.id.clone());
|
||||
policies.push(policy);
|
||||
}
|
||||
}
|
||||
|
||||
Ok((PolicySet::from_policies(policies)?, policy_to_rule))
|
||||
}
|
||||
|
||||
fn compile_policy_source(rule: &PolicyRule, action: &PolicyAction, repo_id: &str) -> String {
|
||||
let mut conditions = Vec::new();
|
||||
if let Some(scope) = rule.allow.branch_scope {
|
||||
conditions.push(branch_scope_condition(scope));
|
||||
}
|
||||
if let Some(scope) = rule.allow.target_branch_scope {
|
||||
conditions.push(target_branch_scope_condition(scope));
|
||||
}
|
||||
|
||||
let when = if conditions.is_empty() {
|
||||
String::new()
|
||||
} else {
|
||||
format!("\nwhen {{ {} }}", conditions.join(" && "))
|
||||
};
|
||||
|
||||
format!(
|
||||
r#"permit (
|
||||
principal in Omnigraph::Group::{group},
|
||||
action == Omnigraph::Action::{action},
|
||||
resource == Omnigraph::Repo::{repo}
|
||||
){when};"#,
|
||||
group = cedar_literal(&rule.allow.actors.group),
|
||||
action = cedar_literal(action.as_str()),
|
||||
repo = cedar_literal(repo_id),
|
||||
when = when,
|
||||
)
|
||||
}
|
||||
|
||||
fn branch_scope_condition(scope: PolicyBranchScope) -> String {
|
||||
match scope {
|
||||
PolicyBranchScope::Any => "true".to_string(),
|
||||
PolicyBranchScope::Protected => {
|
||||
"context.has_branch && context.branch_is_protected".to_string()
|
||||
}
|
||||
PolicyBranchScope::Unprotected => {
|
||||
"context.has_branch && context.branch_is_protected == false".to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn target_branch_scope_condition(scope: PolicyBranchScope) -> String {
|
||||
match scope {
|
||||
PolicyBranchScope::Any => "true".to_string(),
|
||||
PolicyBranchScope::Protected => {
|
||||
"context.has_target_branch && context.target_branch_is_protected".to_string()
|
||||
}
|
||||
PolicyBranchScope::Unprotected => {
|
||||
"context.has_target_branch && context.target_branch_is_protected == false".to_string()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Cedar schema (human-readable syntax) shared by every compiled engine.
///
/// Declares the `Omnigraph` namespace with `Actor`, `Group` and `Repo`
/// entities, the request context record, and one action per supported
/// operation; every action applies to an `Actor` principal and a `Repo`
/// resource.
fn policy_schema_source() -> &'static str {
    const SCHEMA: &str = r#"
namespace Omnigraph {
type RequestContext = {
has_branch: Bool,
branch: String,
has_target_branch: Bool,
target_branch: String,
branch_is_protected: Bool,
target_branch_is_protected: Bool,
};

entity Actor in [Group];
entity Group;
entity Repo;

action "read" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "export" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "change" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "branch_create" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "branch_delete" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "branch_merge" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "run_publish" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "run_abort" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
action "admin" appliesTo { principal: Actor, resource: Repo, context: RequestContext };
}
"#;
    SCHEMA
}
|
||||
|
||||
fn entity_uid(entity_type: &str, id: &str) -> Result<EntityUid> {
|
||||
let typename = EntityTypeName::from_str(&format!("Omnigraph::{entity_type}"))?;
|
||||
let entity_id = EntityId::from_str(id).map_err(|err| eyre!(err.to_string()))?;
|
||||
Ok(EntityUid::from_type_name_and_id(typename, entity_id))
|
||||
}
|
||||
|
||||
fn cedar_literal(value: &str) -> String {
|
||||
serde_json::to_string(value).expect("string literal should serialize")
|
||||
}
|
||||
|
||||
impl PolicyRequest {
|
||||
pub fn actor_id(&self) -> &str {
|
||||
&self.actor_id
|
||||
}
|
||||
|
||||
pub fn action(&self) -> PolicyAction {
|
||||
self.action
|
||||
}
|
||||
|
||||
pub fn branch(&self) -> Option<&str> {
|
||||
self.branch.as_deref()
|
||||
}
|
||||
|
||||
pub fn target_branch(&self) -> Option<&str> {
|
||||
self.target_branch.as_deref()
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::{
        PolicyAction, PolicyCompiler, PolicyConfig, PolicyExpectation, PolicyRequest,
        PolicyTestCase, PolicyTestConfig,
    };

    /// Deserializes an inline YAML policy fixture, panicking on malformed
    /// input. Validation is left to each test.
    fn parse_policy(yaml: &str) -> PolicyConfig {
        serde_yaml::from_str(yaml).unwrap()
    }

    /// Two rules sharing an id must be rejected by `validate`.
    #[test]
    fn rejects_duplicate_rule_ids() {
        let policy = parse_policy(
            r#"
version: 1
groups:
  team: [act-andrew]
rules:
  - id: same
    allow:
      actors: { group: team }
      actions: [read]
      branch_scope: any
  - id: same
    allow:
      actors: { group: team }
      actions: [export]
      branch_scope: any
"#,
        );

        let err = policy.validate().unwrap_err();
        assert!(err.to_string().contains("duplicate policy rule id"));
    }

    /// A rule naming a group that is not declared must be rejected.
    #[test]
    fn rejects_unknown_group_references() {
        let policy = parse_policy(
            r#"
version: 1
groups:
  team: [act-andrew]
rules:
  - id: bad
    allow:
      actors: { group: admins }
      actions: [read]
      branch_scope: any
"#,
        );

        let err = policy.validate().unwrap_err();
        assert!(err.to_string().contains("references unknown group"));
    }

    /// `branch_merge` combined with `branch_scope` must be rejected.
    #[test]
    fn rejects_invalid_scope_action_combinations() {
        let policy = parse_policy(
            r#"
version: 1
groups:
  team: [act-andrew]
rules:
  - id: bad
    allow:
      actors: { group: team }
      actions: [branch_merge]
      branch_scope: protected
"#,
        );

        let err = policy.validate().unwrap_err();
        assert!(err.to_string().contains("unsupported action"));
    }

    /// End-to-end: compile a policy and check allow/deny decisions for both
    /// branch-scoped and target-branch-scoped rules.
    #[test]
    fn compiles_and_authorizes_branch_and_target_rules() {
        let policy = parse_policy(
            r#"
version: 1
groups:
  team: [act-andrew, act-bruno]
  admins: [act-andrew]
protected_branches: [main]
rules:
  - id: team-read
    allow:
      actors: { group: team }
      actions: [read, export]
      branch_scope: any
  - id: team-write
    allow:
      actors: { group: team }
      actions: [change]
      branch_scope: unprotected
  - id: admins-promote
    allow:
      actors: { group: admins }
      actions: [branch_delete, branch_merge, run_publish]
      target_branch_scope: protected
"#,
        );

        let engine = PolicyCompiler::compile(&policy, "repo").unwrap();

        // Team member may change an unprotected branch.
        let allow = engine
            .authorize(&PolicyRequest {
                actor_id: "act-bruno".to_string(),
                action: PolicyAction::Change,
                branch: Some("feature".to_string()),
                target_branch: None,
            })
            .unwrap();
        assert!(allow.allowed);
        assert_eq!(allow.matched_rule_id.as_deref(), Some("team-write"));

        // Non-admin may not delete a protected target branch.
        let deny = engine
            .authorize(&PolicyRequest {
                actor_id: "act-bruno".to_string(),
                action: PolicyAction::BranchDelete,
                branch: None,
                target_branch: Some("main".to_string()),
            })
            .unwrap();
        assert!(!deny.allowed);

        // Admin may.
        let admin = engine
            .authorize(&PolicyRequest {
                actor_id: "act-andrew".to_string(),
                action: PolicyAction::BranchDelete,
                branch: None,
                target_branch: Some("main".to_string()),
            })
            .unwrap();
        assert!(admin.allowed);
        assert_eq!(admin.matched_rule_id.as_deref(), Some("admins-promote"));
    }

    /// `run_tests` succeeds when every case gets its expected outcome.
    #[test]
    fn policy_tests_enforce_expected_outcomes() {
        let policy = parse_policy(
            r#"
version: 1
groups:
  team: [act-andrew]
protected_branches: [main]
rules:
  - id: team-read
    allow:
      actors: { group: team }
      actions: [read]
      branch_scope: any
"#,
        );
        let engine = PolicyCompiler::compile(&policy, "repo").unwrap();
        let tests = PolicyTestConfig {
            version: 1,
            cases: vec![
                PolicyTestCase {
                    id: "allow-read".to_string(),
                    actor: "act-andrew".to_string(),
                    action: PolicyAction::Read,
                    branch: Some("main".to_string()),
                    target_branch: None,
                    expect: PolicyExpectation::Allow,
                },
                PolicyTestCase {
                    id: "deny-change".to_string(),
                    actor: "act-andrew".to_string(),
                    action: PolicyAction::Change,
                    branch: Some("main".to_string()),
                    target_branch: None,
                    expect: PolicyExpectation::Deny,
                },
            ],
        };

        engine.run_tests(&tests).unwrap();
    }
}
|
||||
1773
crates/omnigraph-server/tests/server.rs
Normal file
1773
crates/omnigraph-server/tests/server.rs
Normal file
File diff suppressed because it is too large
Load diff
47
crates/omnigraph/Cargo.toml
Normal file
47
crates/omnigraph/Cargo.toml
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
[package]
|
||||
name = "omnigraph"
|
||||
version = "0.4.0"
|
||||
edition = "2024"
|
||||
description = "Lance-native graph database with git-style branching."
|
||||
license = "MIT"
|
||||
|
||||
[features]
|
||||
default = []
|
||||
failpoints = ["dep:fail", "fail/failpoints"]
|
||||
|
||||
[dependencies]
|
||||
omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.4.0" }
|
||||
lance = { workspace = true }
|
||||
lance-datafusion = { workspace = true }
|
||||
lance-file = { workspace = true }
|
||||
lance-index = { workspace = true }
|
||||
lance-linalg = { workspace = true }
|
||||
lance-namespace = { workspace = true }
|
||||
lance-table = { workspace = true }
|
||||
arrow-array = { workspace = true }
|
||||
arrow-schema = { workspace = true }
|
||||
arrow-ord = { workspace = true }
|
||||
arrow-select = { workspace = true }
|
||||
arrow-cast = { workspace = true }
|
||||
tokio = { workspace = true }
|
||||
serde = { workspace = true }
|
||||
serde_json = { workspace = true }
|
||||
reqwest = { workspace = true }
|
||||
object_store = { workspace = true }
|
||||
ulid = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
futures = { workspace = true }
|
||||
tracing = { workspace = true }
|
||||
thiserror = { workspace = true }
|
||||
regex = { workspace = true }
|
||||
tempfile = { workspace = true }
|
||||
fail = { workspace = true, optional = true }
|
||||
time = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
url = { workspace = true }
|
||||
|
||||
[dev-dependencies]
|
||||
omnigraph-compiler = { path = "../omnigraph-compiler", version = "0.4.0" }
|
||||
tokio = { workspace = true }
|
||||
lance-namespace-impls = { workspace = true }
|
||||
serial_test = "3"
|
||||
598
crates/omnigraph/src/changes/mod.rs
Normal file
598
crates/omnigraph/src/changes/mod.rs
Normal file
|
|
@ -0,0 +1,598 @@
|
|||
use std::collections::HashSet;
|
||||
|
||||
use arrow_array::{Array, RecordBatch, StringArray, UInt64Array};
|
||||
use arrow_cast::display::array_value_to_string;
|
||||
use lance::dataset::scanner::ColumnOrdering;
|
||||
|
||||
use crate::db::SubTableEntry;
|
||||
use crate::db::manifest::Snapshot;
|
||||
use crate::error::Result;
|
||||
use crate::table_store::TableStore;
|
||||
|
||||
// ─── Types ──────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Whether a changed entity lives in a node table or an edge table.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum EntityKind {
    /// Entity from a node table.
    Node,
    /// Entity from an edge table.
    Edge,
}
|
||||
|
||||
/// Kind of net change observed for an entity between two snapshots.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ChangeOp {
    /// Present in the target snapshot only.
    Insert,
    /// Present in both snapshots but modified.
    Update,
    /// Present in the source snapshot only.
    Delete,
}
|
||||
|
||||
/// Source and destination ids attached to a change.
// NOTE(review): presumably only populated for edge changes — confirm against
// the code that builds `EntityChange` values.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct Endpoints {
    /// Value of the `src` side.
    pub src: String,
    /// Value of the `dst` side.
    pub dst: String,
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct EntityChange {
|
||||
pub table_key: String,
|
||||
pub kind: EntityKind,
|
||||
pub type_name: String,
|
||||
pub id: String,
|
||||
pub op: ChangeOp,
|
||||
pub manifest_version: u64,
|
||||
pub endpoints: Option<Endpoints>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct ChangeFilter {
|
||||
pub kinds: Option<Vec<EntityKind>>,
|
||||
pub type_names: Option<Vec<String>>,
|
||||
pub ops: Option<Vec<ChangeOp>>,
|
||||
}
|
||||
|
||||
/// Aggregate counters over a list of changes.
#[derive(Clone, Debug, Default)]
pub struct ChangeStats {
    /// Number of `Insert` changes.
    pub inserts: usize,
    /// Number of `Update` changes.
    pub updates: usize,
    /// Number of `Delete` changes.
    pub deletes: usize,
    /// Type names that have at least one change.
    pub types_affected: Vec<String>,
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct ChangeSet {
|
||||
pub from_version: u64,
|
||||
pub to_version: u64,
|
||||
pub branch: Option<String>,
|
||||
pub changes: Vec<EntityChange>,
|
||||
pub stats: ChangeStats,
|
||||
}
|
||||
|
||||
// ─── Filter helpers ─────────────────────────────────────────────────────────
|
||||
|
||||
fn parse_table_key(table_key: &str) -> (EntityKind, &str) {
|
||||
if let Some(name) = table_key.strip_prefix("node:") {
|
||||
(EntityKind::Node, name)
|
||||
} else if let Some(name) = table_key.strip_prefix("edge:") {
|
||||
(EntityKind::Edge, name)
|
||||
} else {
|
||||
(EntityKind::Node, table_key)
|
||||
}
|
||||
}
|
||||
|
||||
impl ChangeFilter {
|
||||
fn matches_table(&self, table_key: &str) -> bool {
|
||||
let (kind, type_name) = parse_table_key(table_key);
|
||||
if let Some(ref kinds) = self.kinds {
|
||||
if !kinds.contains(&kind) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
if let Some(ref names) = self.type_names {
|
||||
if !names.iter().any(|n| n == type_name) {
|
||||
return false;
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
|
||||
fn wants_op(&self, op: ChangeOp) -> bool {
|
||||
match &self.ops {
|
||||
Some(ops) => ops.contains(&op),
|
||||
None => true,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Core diff ──────────────────────────────────────────────────────────────
|
||||
|
||||
/// Net-current diff between two snapshots.
|
||||
///
|
||||
/// Uses a three-level algorithm:
|
||||
/// 1. Manifest diff — skip unchanged sub-tables
|
||||
/// 2. Lineage check — same branch → version-column diff; different → ID-based diff
|
||||
/// 3. Row-level diff
|
||||
pub async fn diff_snapshots(
|
||||
root_uri: &str,
|
||||
from: &Snapshot,
|
||||
to: &Snapshot,
|
||||
filter: &ChangeFilter,
|
||||
branch: Option<String>,
|
||||
) -> Result<ChangeSet> {
|
||||
let table_store = TableStore::new(root_uri);
|
||||
let mut all_keys: HashSet<String> = HashSet::new();
|
||||
for entry in from.entries() {
|
||||
all_keys.insert(entry.table_key.clone());
|
||||
}
|
||||
for entry in to.entries() {
|
||||
all_keys.insert(entry.table_key.clone());
|
||||
}
|
||||
|
||||
let mut changes = Vec::new();
|
||||
|
||||
for table_key in &all_keys {
|
||||
if !filter.matches_table(table_key) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let from_entry = from.entry(table_key);
|
||||
let to_entry = to.entry(table_key);
|
||||
|
||||
// Skip if both snapshots have identical state for this table
|
||||
if same_state(from_entry, to_entry) {
|
||||
continue;
|
||||
}
|
||||
|
||||
let (kind, type_name) = parse_table_key(table_key);
|
||||
let is_edge = kind == EntityKind::Edge;
|
||||
|
||||
let table_changes = if from_entry.is_none() {
|
||||
// Table added — all rows are inserts
|
||||
diff_table_added(&table_store, to, table_key, is_edge, filter).await?
|
||||
} else if to_entry.is_none() {
|
||||
// Table removed — all rows are deletes
|
||||
diff_table_removed(&table_store, from, table_key, is_edge, filter).await?
|
||||
} else if same_lineage(from_entry, to_entry) {
|
||||
// Fast path: version-column diff
|
||||
diff_table_same_lineage(
|
||||
&table_store,
|
||||
from_entry.unwrap(),
|
||||
to_entry.unwrap(),
|
||||
is_edge,
|
||||
filter,
|
||||
)
|
||||
.await?
|
||||
} else {
|
||||
// Cross-branch path: streaming ID-based diff
|
||||
diff_table_cross_branch(&table_store, from, to, table_key, is_edge, filter).await?
|
||||
};
|
||||
|
||||
for mut c in table_changes {
|
||||
c.table_key = table_key.clone();
|
||||
c.kind = kind;
|
||||
c.type_name = type_name.to_string();
|
||||
if c.manifest_version == 0 {
|
||||
c.manifest_version = to.version();
|
||||
}
|
||||
changes.push(c);
|
||||
}
|
||||
}
|
||||
|
||||
let stats = compute_stats(&changes);
|
||||
Ok(ChangeSet {
|
||||
from_version: from.version(),
|
||||
to_version: to.version(),
|
||||
branch,
|
||||
changes,
|
||||
stats,
|
||||
})
|
||||
}
|
||||
|
||||
/// `true` when a table is in the same state on both sides: absent from both,
/// or present in both with equal table version and table branch.
fn same_state(a: Option<&SubTableEntry>, b: Option<&SubTableEntry>) -> bool {
    match (a, b) {
        (Some(left), Some(right)) => {
            left.table_version == right.table_version && left.table_branch == right.table_branch
        }
        (None, None) => true,
        (Some(_), None) | (None, Some(_)) => false,
    }
}
|
||||
|
||||
/// `true` only when both entries exist and sit on the same table branch.
fn same_lineage(from: Option<&SubTableEntry>, to: Option<&SubTableEntry>) -> bool {
    matches!(
        (from, to),
        (Some(source), Some(target)) if source.table_branch == target.table_branch
    )
}
|
||||
|
||||
fn compute_stats(changes: &[EntityChange]) -> ChangeStats {
|
||||
let mut stats = ChangeStats::default();
|
||||
let mut types = HashSet::new();
|
||||
for c in changes {
|
||||
match c.op {
|
||||
ChangeOp::Insert => stats.inserts += 1,
|
||||
ChangeOp::Update => stats.updates += 1,
|
||||
ChangeOp::Delete => stats.deletes += 1,
|
||||
}
|
||||
types.insert(c.type_name.clone());
|
||||
}
|
||||
stats.types_affected = types.into_iter().collect();
|
||||
stats.types_affected.sort();
|
||||
stats
|
||||
}
|
||||
|
||||
// ─── Fast path: version-column diff ─────────────────────────────────────────
|
||||
|
||||
async fn diff_table_same_lineage(
|
||||
table_store: &TableStore,
|
||||
from_entry: &SubTableEntry,
|
||||
to_entry: &SubTableEntry,
|
||||
is_edge: bool,
|
||||
filter: &ChangeFilter,
|
||||
) -> Result<Vec<EntityChange>> {
|
||||
let vf = from_entry.table_version;
|
||||
let vt = to_entry.table_version;
|
||||
let to_ds = table_store.open_at_entry(to_entry).await?;
|
||||
|
||||
let cols: Vec<&str> = if is_edge {
|
||||
vec!["id", "src", "dst", "_row_last_updated_at_version"]
|
||||
} else {
|
||||
vec!["id", "_row_last_updated_at_version"]
|
||||
};
|
||||
|
||||
let wants_inserts = filter.wants_op(ChangeOp::Insert);
|
||||
let wants_updates = filter.wants_op(ChangeOp::Update);
|
||||
let wants_deletes = filter.wants_op(ChangeOp::Delete);
|
||||
|
||||
let mut changes = Vec::new();
|
||||
|
||||
// Inserts + Updates: use _row_last_updated_at_version to find all rows
|
||||
// touched since Vf, then classify by checking whether the ID existed at Vf.
|
||||
//
|
||||
// Why not _row_created_at_version for inserts: Lance's merge_insert stamps
|
||||
// new rows with _row_created_at_version = dataset_creation_version (v1),
|
||||
// not the merge_insert commit version. This makes _row_created_at_version
|
||||
// unreliable for detecting inserts from merge_insert writes. Using
|
||||
// _row_last_updated_at_version catches all touched rows regardless of
|
||||
// write mode, and ID-set membership distinguishes inserts from updates.
|
||||
if wants_inserts || wants_updates {
|
||||
let filter_sql = format!(
|
||||
"_row_last_updated_at_version > {} AND _row_last_updated_at_version <= {}",
|
||||
vf, vt
|
||||
);
|
||||
let changed_rows = scan_with_filter(table_store, &to_ds, &cols, &filter_sql).await?;
|
||||
|
||||
if !changed_rows.is_empty() {
|
||||
// Build the set of IDs that existed at the from version
|
||||
let from_ds = table_store.open_at_entry(from_entry).await?;
|
||||
let from_ids: HashSet<String> = scan_id_set(table_store, &from_ds, &["id"])
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|r| r.id)
|
||||
.collect();
|
||||
|
||||
for row in changed_rows {
|
||||
if from_ids.contains(&row.id) {
|
||||
if wants_updates {
|
||||
changes.push(entity_change_from_row(&row, ChangeOp::Update, is_edge));
|
||||
}
|
||||
} else if wants_inserts {
|
||||
changes.push(entity_change_from_row(&row, ChangeOp::Insert, is_edge));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Deletes: ID set-difference
|
||||
if wants_deletes {
|
||||
let from_ds = table_store.open_at_entry(from_entry).await?;
|
||||
let deleted = deleted_ids_by_set_diff(table_store, &from_ds, &to_ds, is_edge).await?;
|
||||
changes.extend(deleted);
|
||||
}
|
||||
|
||||
Ok(changes)
|
||||
}
|
||||
|
||||
// ─── Cross-branch path: streaming ID-based diff ────────────────────────────
|
||||
|
||||
async fn diff_table_cross_branch(
|
||||
table_store: &TableStore,
|
||||
from_snap: &Snapshot,
|
||||
to_snap: &Snapshot,
|
||||
table_key: &str,
|
||||
is_edge: bool,
|
||||
filter: &ChangeFilter,
|
||||
) -> Result<Vec<EntityChange>> {
|
||||
let from_ds = table_store
|
||||
.open_snapshot_table(from_snap, table_key)
|
||||
.await?;
|
||||
let to_ds = table_store.open_snapshot_table(to_snap, table_key).await?;
|
||||
|
||||
let from_rows = scan_all_rows_ordered(table_store, &from_ds, is_edge).await?;
|
||||
let to_rows = scan_all_rows_ordered(table_store, &to_ds, is_edge).await?;
|
||||
|
||||
let mut changes = Vec::new();
|
||||
let mut fi = 0;
|
||||
let mut ti = 0;
|
||||
|
||||
while fi < from_rows.len() || ti < to_rows.len() {
|
||||
let from_id = from_rows.get(fi).map(|r| r.id.as_str());
|
||||
let to_id = to_rows.get(ti).map(|r| r.id.as_str());
|
||||
|
||||
match (from_id, to_id) {
|
||||
(Some(fid), Some(tid)) if fid < tid => {
|
||||
// ID only in from → Delete
|
||||
if filter.wants_op(ChangeOp::Delete) {
|
||||
changes.push(entity_change_from_row(
|
||||
&from_rows[fi],
|
||||
ChangeOp::Delete,
|
||||
is_edge,
|
||||
));
|
||||
}
|
||||
fi += 1;
|
||||
}
|
||||
(Some(fid), Some(tid)) if fid > tid => {
|
||||
// ID only in to → Insert
|
||||
if filter.wants_op(ChangeOp::Insert) {
|
||||
changes.push(entity_change_from_row(
|
||||
&to_rows[ti],
|
||||
ChangeOp::Insert,
|
||||
is_edge,
|
||||
));
|
||||
}
|
||||
ti += 1;
|
||||
}
|
||||
(Some(_), Some(_)) => {
|
||||
// Same ID — check signature
|
||||
if from_rows[fi].signature != to_rows[ti].signature
|
||||
&& filter.wants_op(ChangeOp::Update)
|
||||
{
|
||||
changes.push(entity_change_from_row(
|
||||
&to_rows[ti],
|
||||
ChangeOp::Update,
|
||||
is_edge,
|
||||
));
|
||||
}
|
||||
fi += 1;
|
||||
ti += 1;
|
||||
}
|
||||
(Some(_), None) => {
|
||||
if filter.wants_op(ChangeOp::Delete) {
|
||||
changes.push(entity_change_from_row(
|
||||
&from_rows[fi],
|
||||
ChangeOp::Delete,
|
||||
is_edge,
|
||||
));
|
||||
}
|
||||
fi += 1;
|
||||
}
|
||||
(None, Some(_)) => {
|
||||
if filter.wants_op(ChangeOp::Insert) {
|
||||
changes.push(entity_change_from_row(
|
||||
&to_rows[ti],
|
||||
ChangeOp::Insert,
|
||||
is_edge,
|
||||
));
|
||||
}
|
||||
ti += 1;
|
||||
}
|
||||
(None, None) => break,
|
||||
}
|
||||
}
|
||||
|
||||
Ok(changes)
|
||||
}
|
||||
|
||||
// ─── Table added/removed ────────────────────────────────────────────────────
|
||||
|
||||
async fn diff_table_added(
|
||||
table_store: &TableStore,
|
||||
to_snap: &Snapshot,
|
||||
table_key: &str,
|
||||
is_edge: bool,
|
||||
filter: &ChangeFilter,
|
||||
) -> Result<Vec<EntityChange>> {
|
||||
if !filter.wants_op(ChangeOp::Insert) {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let ds = table_store.open_snapshot_table(to_snap, table_key).await?;
|
||||
let rows = scan_all_rows_ordered(table_store, &ds, is_edge).await?;
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| entity_change_from_row(&r, ChangeOp::Insert, is_edge))
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn diff_table_removed(
|
||||
table_store: &TableStore,
|
||||
from_snap: &Snapshot,
|
||||
table_key: &str,
|
||||
is_edge: bool,
|
||||
filter: &ChangeFilter,
|
||||
) -> Result<Vec<EntityChange>> {
|
||||
if !filter.wants_op(ChangeOp::Delete) {
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let ds = table_store
|
||||
.open_snapshot_table(from_snap, table_key)
|
||||
.await?;
|
||||
let rows = scan_all_rows_ordered(table_store, &ds, is_edge).await?;
|
||||
Ok(rows
|
||||
.into_iter()
|
||||
.map(|r| entity_change_from_row(&r, ChangeOp::Delete, is_edge))
|
||||
.collect())
|
||||
}
|
||||
|
||||
// ─── Helpers ────────────────────────────────────────────────────────────────
|
||||
|
||||
/// Scan with a SQL filter, projecting specific columns.
|
||||
async fn scan_with_filter(
|
||||
table_store: &TableStore,
|
||||
ds: &lance::Dataset,
|
||||
cols: &[&str],
|
||||
filter_sql: &str,
|
||||
) -> Result<Vec<ScannedRow>> {
|
||||
let batches = table_store
|
||||
.scan(ds, Some(cols), Some(filter_sql), None)
|
||||
.await?;
|
||||
Ok(extract_rows(&batches))
|
||||
}
|
||||
|
||||
/// Scan all rows ordered by id, projecting id (+ src/dst for edges) + all columns for signature.
|
||||
async fn scan_all_rows_ordered(
|
||||
table_store: &TableStore,
|
||||
ds: &lance::Dataset,
|
||||
is_edge: bool,
|
||||
) -> Result<Vec<ScannedRow>> {
|
||||
let batches = table_store
|
||||
.scan(
|
||||
ds,
|
||||
None,
|
||||
None,
|
||||
Some(vec![ColumnOrdering::asc_nulls_last("id".to_string())]),
|
||||
)
|
||||
.await?;
|
||||
Ok(extract_rows_with_signature(&batches, is_edge))
|
||||
}
|
||||
|
||||
/// Compute deleted IDs: scan id at from and to, set-difference.
|
||||
async fn deleted_ids_by_set_diff(
|
||||
table_store: &TableStore,
|
||||
from_ds: &lance::Dataset,
|
||||
to_ds: &lance::Dataset,
|
||||
is_edge: bool,
|
||||
) -> Result<Vec<EntityChange>> {
|
||||
let cols: Vec<&str> = if is_edge {
|
||||
vec!["id", "src", "dst"]
|
||||
} else {
|
||||
vec!["id"]
|
||||
};
|
||||
|
||||
let from_rows = scan_id_set(table_store, from_ds, &cols).await?;
|
||||
let to_ids: HashSet<String> = scan_id_set(table_store, to_ds, &["id"])
|
||||
.await?
|
||||
.into_iter()
|
||||
.map(|r| r.id)
|
||||
.collect();
|
||||
|
||||
Ok(from_rows
|
||||
.into_iter()
|
||||
.filter(|r| !to_ids.contains(&r.id))
|
||||
.map(|r| entity_change_from_row(&r, ChangeOp::Delete, is_edge))
|
||||
.collect())
|
||||
}
|
||||
|
||||
async fn scan_id_set(
|
||||
table_store: &TableStore,
|
||||
ds: &lance::Dataset,
|
||||
cols: &[&str],
|
||||
) -> Result<Vec<ScannedRow>> {
|
||||
let batches = table_store.scan(ds, Some(cols), None, None).await?;
|
||||
Ok(extract_rows(&batches))
|
||||
}
|
||||
|
||||
// ─── Row extraction ─────────────────────────────────────────────────────────
|
||||
|
||||
/// One row read back from a table scan, reduced to the fields the diff needs.
#[derive(Debug, Clone)]
struct ScannedRow {
    /// Value of the row's `id` column.
    id: String,
    /// Value of the `src` column, when that column was present in the scanned
    /// batch (and, for signature scans, when the table is an edge table).
    src: Option<String>,
    /// Value of the `dst` column, under the same conditions as `src`.
    dst: Option<String>,
    /// Column values joined with `\x1f`, used to detect content changes.
    /// Left empty by `extract_rows`; filled by `extract_rows_with_signature`.
    signature: String,
    /// Value of the `_row_last_updated_at_version` column, when the scanned
    /// batch contains it as a `UInt64` column.
    change_version: Option<u64>,
}
|
||||
|
||||
fn extract_rows(batches: &[RecordBatch]) -> Vec<ScannedRow> {
|
||||
let mut rows = Vec::new();
|
||||
for batch in batches {
|
||||
let ids = batch
|
||||
.column_by_name("id")
|
||||
.and_then(|c| c.as_any().downcast_ref::<StringArray>());
|
||||
let Some(ids) = ids else { continue };
|
||||
let srcs = batch
|
||||
.column_by_name("src")
|
||||
.and_then(|c| c.as_any().downcast_ref::<StringArray>());
|
||||
let dsts = batch
|
||||
.column_by_name("dst")
|
||||
.and_then(|c| c.as_any().downcast_ref::<StringArray>());
|
||||
for i in 0..ids.len() {
|
||||
rows.push(ScannedRow {
|
||||
id: ids.value(i).to_string(),
|
||||
src: srcs.map(|a| a.value(i).to_string()),
|
||||
dst: dsts.map(|a| a.value(i).to_string()),
|
||||
signature: String::new(),
|
||||
change_version: batch
|
||||
.column_by_name("_row_last_updated_at_version")
|
||||
.and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
|
||||
.map(|versions| versions.value(i)),
|
||||
});
|
||||
}
|
||||
}
|
||||
rows
|
||||
}
|
||||
|
||||
fn extract_rows_with_signature(batches: &[RecordBatch], is_edge: bool) -> Vec<ScannedRow> {
|
||||
let mut rows = Vec::new();
|
||||
for batch in batches {
|
||||
let ids = batch
|
||||
.column_by_name("id")
|
||||
.and_then(|c| c.as_any().downcast_ref::<StringArray>());
|
||||
let Some(ids) = ids else { continue };
|
||||
let srcs = if is_edge {
|
||||
batch
|
||||
.column_by_name("src")
|
||||
.and_then(|c| c.as_any().downcast_ref::<StringArray>())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let dsts = if is_edge {
|
||||
batch
|
||||
.column_by_name("dst")
|
||||
.and_then(|c| c.as_any().downcast_ref::<StringArray>())
|
||||
} else {
|
||||
None
|
||||
};
|
||||
for i in 0..ids.len() {
|
||||
let mut values = Vec::with_capacity(batch.num_columns());
|
||||
for (field, col) in batch.schema().fields().iter().zip(batch.columns()) {
|
||||
if field.name().starts_with("_row_") {
|
||||
continue;
|
||||
}
|
||||
if let Ok(v) = array_value_to_string(col.as_ref(), i) {
|
||||
values.push(v);
|
||||
}
|
||||
}
|
||||
rows.push(ScannedRow {
|
||||
id: ids.value(i).to_string(),
|
||||
src: srcs.map(|a| a.value(i).to_string()),
|
||||
dst: dsts.map(|a| a.value(i).to_string()),
|
||||
signature: values.join("\x1f"),
|
||||
change_version: batch
|
||||
.column_by_name("_row_last_updated_at_version")
|
||||
.and_then(|c| c.as_any().downcast_ref::<UInt64Array>())
|
||||
.map(|versions| versions.value(i)),
|
||||
});
|
||||
}
|
||||
}
|
||||
rows
|
||||
}
|
||||
|
||||
fn entity_change_from_row(row: &ScannedRow, op: ChangeOp, is_edge: bool) -> EntityChange {
|
||||
EntityChange {
|
||||
table_key: String::new(),
|
||||
kind: if is_edge {
|
||||
EntityKind::Edge
|
||||
} else {
|
||||
EntityKind::Node
|
||||
},
|
||||
type_name: String::new(),
|
||||
id: row.id.clone(),
|
||||
op,
|
||||
manifest_version: row.change_version.unwrap_or(0),
|
||||
endpoints: if is_edge {
|
||||
Some(Endpoints {
|
||||
src: row.src.clone().unwrap_or_default(),
|
||||
dst: row.dst.clone().unwrap_or_default(),
|
||||
})
|
||||
} else {
|
||||
None
|
||||
},
|
||||
}
|
||||
}
|
||||
692
crates/omnigraph/src/db/commit_graph.rs
Normal file
692
crates/omnigraph/src/db/commit_graph.rs
Normal file
|
|
@ -0,0 +1,692 @@
|
|||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use arrow_array::{
|
||||
Array, RecordBatch, RecordBatchIterator, StringArray, TimestampMicrosecondArray, UInt64Array,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
|
||||
use futures::TryStreamExt;
|
||||
use lance::Dataset;
|
||||
use lance::dataset::{WriteMode, WriteParams};
|
||||
use lance_file::version::LanceFileVersion;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
const GRAPH_COMMITS_DIR: &str = "_graph_commits.lance";
|
||||
const GRAPH_COMMIT_ACTORS_DIR: &str = "_graph_commit_actors.lance";
|
||||
|
||||
/// One entry in the graph commit history.
#[derive(Debug, Clone)]
pub struct GraphCommit {
    /// Unique commit identifier; new commits get a freshly generated ULID string.
    pub graph_commit_id: String,
    /// Manifest branch recorded for this commit; `None` for the genesis commit
    /// written by `CommitGraph::init`.
    pub manifest_branch: Option<String>,
    /// Manifest version this commit refers to.
    pub manifest_version: u64,
    /// First parent; `None` for the genesis commit.
    pub parent_commit_id: Option<String>,
    /// Second parent, set only for commits created via `append_merge_commit`.
    pub merged_parent_commit_id: Option<String>,
    /// Actor attributed to the commit. Not a column of the commit table: it is
    /// filled in from the separate commit-actor dataset when commits are loaded.
    pub actor_id: Option<String>,
    /// Creation time in microseconds since the Unix epoch.
    pub created_at: i64,
}
|
||||
|
||||
/// Handle on the persisted commit graph plus an in-memory cache of its contents.
pub struct CommitGraph {
    /// Repository root URI with trailing slashes removed.
    root_uri: String,
    /// Lance dataset holding the commit rows (checked out at `active_branch` if set).
    dataset: Dataset,
    /// Lance dataset holding commit → actor records; `None` when it could not be opened.
    actor_dataset: Option<Dataset>,
    /// Branch the commit dataset is checked out at; `None` when opened without a branch.
    active_branch: Option<String>,
    /// Cache of actor IDs keyed by commit ID, loaded from `actor_dataset`.
    actor_by_commit_id: HashMap<String, String>,
    /// Cache of all loaded commits keyed by commit ID.
    commit_by_id: HashMap<String, GraphCommit>,
    /// Cached head: the commit ranked highest by `should_replace_head`.
    head_commit: Option<GraphCommit>,
}
|
||||
|
||||
impl CommitGraph {
|
||||
pub async fn init(root_uri: &str, manifest_version: u64) -> Result<Self> {
|
||||
let root = root_uri.trim_end_matches('/');
|
||||
let uri = graph_commits_uri(root);
|
||||
let genesis = GraphCommit {
|
||||
graph_commit_id: ulid::Ulid::new().to_string(),
|
||||
manifest_branch: None,
|
||||
manifest_version,
|
||||
parent_commit_id: None,
|
||||
merged_parent_commit_id: None,
|
||||
actor_id: None,
|
||||
created_at: now_micros()?,
|
||||
};
|
||||
|
||||
let batch = commits_to_batch(&[genesis.clone()])?;
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], commit_graph_schema());
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
..Default::default()
|
||||
};
|
||||
let dataset = Dataset::write(reader, &uri as &str, Some(params))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let actor_dataset = create_commit_actor_dataset(root).await?;
|
||||
|
||||
Ok(Self {
|
||||
root_uri: root.to_string(),
|
||||
dataset,
|
||||
actor_dataset: Some(actor_dataset),
|
||||
active_branch: None,
|
||||
actor_by_commit_id: HashMap::new(),
|
||||
commit_by_id: HashMap::from([(genesis.graph_commit_id.clone(), genesis.clone())]),
|
||||
head_commit: Some(genesis),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn open(root_uri: &str) -> Result<Self> {
|
||||
let root = root_uri.trim_end_matches('/');
|
||||
let dataset = Dataset::open(&graph_commits_uri(root))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let actor_dataset = Dataset::open(&graph_commit_actors_uri(root)).await.ok();
|
||||
let actor_by_commit_id = match &actor_dataset {
|
||||
Some(dataset) => load_commit_actor_cache(dataset).await?,
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let (commit_by_id, head_commit) = load_commit_cache(&dataset, &actor_by_commit_id).await?;
|
||||
Ok(Self {
|
||||
root_uri: root.to_string(),
|
||||
dataset,
|
||||
actor_dataset,
|
||||
active_branch: None,
|
||||
actor_by_commit_id,
|
||||
commit_by_id,
|
||||
head_commit,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn open_at_branch(root_uri: &str, branch: &str) -> Result<Self> {
|
||||
let root = root_uri.trim_end_matches('/');
|
||||
let dataset = Dataset::open(&graph_commits_uri(root))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let dataset = dataset
|
||||
.checkout_branch(branch)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let actor_dataset = Dataset::open(&graph_commit_actors_uri(root)).await.ok();
|
||||
let actor_by_commit_id = match &actor_dataset {
|
||||
Some(dataset) => load_commit_actor_cache(dataset).await?,
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let (commit_by_id, head_commit) = load_commit_cache(&dataset, &actor_by_commit_id).await?;
|
||||
Ok(Self {
|
||||
root_uri: root.to_string(),
|
||||
dataset,
|
||||
actor_dataset,
|
||||
active_branch: Some(branch.to_string()),
|
||||
actor_by_commit_id,
|
||||
commit_by_id,
|
||||
head_commit,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn refresh(&mut self) -> Result<()> {
|
||||
let root = self.root_uri.clone();
|
||||
self.dataset = Dataset::open(&graph_commits_uri(&root))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
if let Some(branch) = &self.active_branch {
|
||||
self.dataset = self
|
||||
.dataset
|
||||
.checkout_branch(branch)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
}
|
||||
self.actor_dataset = Dataset::open(&graph_commit_actors_uri(&root)).await.ok();
|
||||
self.actor_by_commit_id = match &self.actor_dataset {
|
||||
Some(dataset) => load_commit_actor_cache(dataset).await?,
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let (commit_by_id, head_commit) =
|
||||
load_commit_cache(&self.dataset, &self.actor_by_commit_id).await?;
|
||||
self.commit_by_id = commit_by_id;
|
||||
self.head_commit = head_commit;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn version(&self) -> u64 {
|
||||
self.dataset.version().version
|
||||
}
|
||||
|
||||
pub async fn create_branch(&mut self, name: &str) -> Result<()> {
|
||||
let mut ds = self.dataset.clone();
|
||||
ds.create_branch(name, self.version(), None)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_branch(&mut self, name: &str) -> Result<()> {
|
||||
let mut ds = Dataset::open(&graph_commits_uri(&self.root_uri))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
ds.delete_branch(name)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.refresh().await
|
||||
}
|
||||
|
||||
pub async fn append_commit(
|
||||
&mut self,
|
||||
manifest_branch: Option<&str>,
|
||||
manifest_version: u64,
|
||||
actor_id: Option<&str>,
|
||||
) -> Result<String> {
|
||||
let parent_commit_id = self.head_commit_id().await?;
|
||||
self.append_commit_with_parents(
|
||||
manifest_branch,
|
||||
manifest_version,
|
||||
parent_commit_id.as_deref(),
|
||||
None,
|
||||
actor_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn append_merge_commit(
|
||||
&mut self,
|
||||
manifest_branch: Option<&str>,
|
||||
manifest_version: u64,
|
||||
parent_commit_id: &str,
|
||||
merged_parent_commit_id: &str,
|
||||
actor_id: Option<&str>,
|
||||
) -> Result<String> {
|
||||
self.append_commit_with_parents(
|
||||
manifest_branch,
|
||||
manifest_version,
|
||||
Some(parent_commit_id),
|
||||
Some(merged_parent_commit_id),
|
||||
actor_id,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn append_commit_with_parents(
|
||||
&mut self,
|
||||
manifest_branch: Option<&str>,
|
||||
manifest_version: u64,
|
||||
parent_commit_id: Option<&str>,
|
||||
merged_parent_commit_id: Option<&str>,
|
||||
actor_id: Option<&str>,
|
||||
) -> Result<String> {
|
||||
let graph_commit_id = ulid::Ulid::new().to_string();
|
||||
let commit = GraphCommit {
|
||||
graph_commit_id: graph_commit_id.clone(),
|
||||
manifest_branch: manifest_branch.map(|s| s.to_string()),
|
||||
manifest_version,
|
||||
parent_commit_id: parent_commit_id.map(|s| s.to_string()),
|
||||
merged_parent_commit_id: merged_parent_commit_id.map(|s| s.to_string()),
|
||||
actor_id: actor_id.map(str::to_string),
|
||||
created_at: now_micros()?,
|
||||
};
|
||||
|
||||
let batch = commits_to_batch(&[commit.clone()])?;
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], commit_graph_schema());
|
||||
let mut ds = self.dataset.clone();
|
||||
ds.append(reader, None)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.dataset = ds;
|
||||
if let Some(actor_id) = actor_id {
|
||||
self.append_actor(&graph_commit_id, actor_id).await?;
|
||||
}
|
||||
self.commit_by_id
|
||||
.insert(graph_commit_id.clone(), commit.clone());
|
||||
if should_replace_head(self.head_commit.as_ref(), &commit) {
|
||||
self.head_commit = Some(commit);
|
||||
}
|
||||
|
||||
Ok(graph_commit_id)
|
||||
}
|
||||
|
||||
async fn append_actor(&mut self, graph_commit_id: &str, actor_id: &str) -> Result<()> {
|
||||
if self
|
||||
.actor_by_commit_id
|
||||
.get(graph_commit_id)
|
||||
.is_some_and(|existing| existing == actor_id)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let record = CommitActorRecord {
|
||||
graph_commit_id: graph_commit_id.to_string(),
|
||||
actor_id: actor_id.to_string(),
|
||||
created_at: now_micros()?,
|
||||
};
|
||||
let batch = commit_actors_to_batch(&[record])?;
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], commit_actor_schema());
|
||||
let mut dataset = match self.actor_dataset.take() {
|
||||
Some(dataset) => dataset,
|
||||
None => create_commit_actor_dataset(&self.root_uri).await?,
|
||||
};
|
||||
dataset
|
||||
.append(reader, None)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.actor_by_commit_id
|
||||
.insert(graph_commit_id.to_string(), actor_id.to_string());
|
||||
self.actor_dataset = Some(dataset);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns a copy of the cached head commit, or `None` when the cache holds
/// no commits. Reads only in-memory state.
pub async fn head_commit(&self) -> Result<Option<GraphCommit>> {
    let head = self.head_commit.as_ref().cloned();
    Ok(head)
}
|
||||
|
||||
/// Returns the ID of the cached head commit, or `None` when there is no head.
pub async fn head_commit_id(&self) -> Result<Option<String>> {
    match self.head_commit().await? {
        Some(head) => Ok(Some(head.graph_commit_id)),
        None => Ok(None),
    }
}
|
||||
|
||||
/// Returns every cached commit, sorted ascending by manifest version, then by
/// creation time, then by commit ID. Reads only in-memory state.
pub async fn load_commits(&self) -> Result<Vec<GraphCommit>> {
    let mut ordered: Vec<GraphCommit> = Vec::with_capacity(self.commit_by_id.len());
    ordered.extend(self.commit_by_id.values().cloned());
    ordered.sort_by(|left, right| {
        let left_key = (left.manifest_version, left.created_at, &left.graph_commit_id);
        let right_key = (
            right.manifest_version,
            right.created_at,
            &right.graph_commit_id,
        );
        left_key.cmp(&right_key)
    });
    Ok(ordered)
}
|
||||
|
||||
/// Looks up a commit by ID in the in-memory cache and returns a copy of it.
pub fn get_commit(&self, commit_id: &str) -> Option<GraphCommit> {
    let found = self.commit_by_id.get(commit_id)?;
    Some(found.clone())
}
|
||||
|
||||
pub async fn merge_base(
|
||||
root_uri: &str,
|
||||
source_branch: Option<&str>,
|
||||
target_branch: Option<&str>,
|
||||
) -> Result<Option<GraphCommit>> {
|
||||
let source = open_for_branch(root_uri, source_branch).await?;
|
||||
let target = open_for_branch(root_uri, target_branch).await?;
|
||||
|
||||
let source_head = match source.head_commit().await? {
|
||||
Some(commit) => commit,
|
||||
None => return Ok(None),
|
||||
};
|
||||
let target_head = match target.head_commit().await? {
|
||||
Some(commit) => commit,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
||||
let mut commits = HashMap::new();
|
||||
for commit in source.load_commits().await? {
|
||||
commits.insert(commit.graph_commit_id.clone(), commit);
|
||||
}
|
||||
for commit in target.load_commits().await? {
|
||||
commits.insert(commit.graph_commit_id.clone(), commit);
|
||||
}
|
||||
|
||||
let source_distances = ancestor_distances(&source_head.graph_commit_id, &commits);
|
||||
let target_distances = ancestor_distances(&target_head.graph_commit_id, &commits);
|
||||
|
||||
let best = source_distances
|
||||
.iter()
|
||||
.filter_map(|(id, source_distance)| {
|
||||
target_distances.get(id).and_then(|target_distance| {
|
||||
commits.get(id).map(|commit| {
|
||||
(
|
||||
(
|
||||
*source_distance + *target_distance,
|
||||
u64::MAX - commit.manifest_version,
|
||||
),
|
||||
commit.clone(),
|
||||
)
|
||||
})
|
||||
})
|
||||
})
|
||||
.min_by_key(|(score, _)| *score)
|
||||
.map(|(_, commit)| commit);
|
||||
|
||||
Ok(best)
|
||||
}
|
||||
}
|
||||
|
||||
fn graph_commits_uri(root_uri: &str) -> String {
|
||||
format!("{}/{}", root_uri.trim_end_matches('/'), GRAPH_COMMITS_DIR)
|
||||
}
|
||||
|
||||
fn graph_commit_actors_uri(root_uri: &str) -> String {
|
||||
format!(
|
||||
"{}/{}",
|
||||
root_uri.trim_end_matches('/'),
|
||||
GRAPH_COMMIT_ACTORS_DIR
|
||||
)
|
||||
}
|
||||
|
||||
fn commit_graph_schema() -> SchemaRef {
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("graph_commit_id", DataType::Utf8, false),
|
||||
Field::new("manifest_branch", DataType::Utf8, true),
|
||||
Field::new("manifest_version", DataType::UInt64, false),
|
||||
Field::new("parent_commit_id", DataType::Utf8, true),
|
||||
Field::new("merged_parent_commit_id", DataType::Utf8, true),
|
||||
Field::new(
|
||||
"created_at",
|
||||
DataType::Timestamp(TimeUnit::Microsecond, None),
|
||||
false,
|
||||
),
|
||||
]))
|
||||
}
|
||||
|
||||
fn commit_actor_schema() -> SchemaRef {
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("graph_commit_id", DataType::Utf8, false),
|
||||
Field::new("actor_id", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"created_at",
|
||||
DataType::Timestamp(TimeUnit::Microsecond, None),
|
||||
false,
|
||||
),
|
||||
]))
|
||||
}
|
||||
|
||||
/// One row of the commit-actor dataset: which actor authored which commit.
#[derive(Debug, Clone)]
struct CommitActorRecord {
    /// ID of the commit the actor is attributed to.
    graph_commit_id: String,
    /// ID of the actor.
    actor_id: String,
    /// Time the record was created, in microseconds since the Unix epoch.
    created_at: i64,
}
|
||||
|
||||
async fn create_commit_actor_dataset(root_uri: &str) -> Result<Dataset> {
|
||||
let uri = graph_commit_actors_uri(root_uri);
|
||||
let batch = RecordBatch::new_empty(commit_actor_schema());
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], commit_actor_schema());
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
..Default::default()
|
||||
};
|
||||
match Dataset::write(reader, &uri as &str, Some(params)).await {
|
||||
Ok(dataset) => Ok(dataset),
|
||||
Err(err) if err.to_string().contains("Dataset already exists") => Dataset::open(&uri)
|
||||
.await
|
||||
.map_err(|open_err| OmniError::Lance(open_err.to_string())),
|
||||
Err(err) => Err(OmniError::Lance(err.to_string())),
|
||||
}
|
||||
}
|
||||
|
||||
fn commits_to_batch(commits: &[GraphCommit]) -> Result<RecordBatch> {
|
||||
let ids: Vec<&str> = commits.iter().map(|c| c.graph_commit_id.as_str()).collect();
|
||||
let branches: Vec<Option<&str>> = commits
|
||||
.iter()
|
||||
.map(|c| c.manifest_branch.as_deref())
|
||||
.collect();
|
||||
let versions: Vec<u64> = commits.iter().map(|c| c.manifest_version).collect();
|
||||
let parents: Vec<Option<&str>> = commits
|
||||
.iter()
|
||||
.map(|c| c.parent_commit_id.as_deref())
|
||||
.collect();
|
||||
let merged_parents: Vec<Option<&str>> = commits
|
||||
.iter()
|
||||
.map(|c| c.merged_parent_commit_id.as_deref())
|
||||
.collect();
|
||||
let created_at: Vec<i64> = commits.iter().map(|c| c.created_at).collect();
|
||||
|
||||
RecordBatch::try_new(
|
||||
commit_graph_schema(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(ids)),
|
||||
Arc::new(StringArray::from(branches)),
|
||||
Arc::new(UInt64Array::from(versions)),
|
||||
Arc::new(StringArray::from(parents)),
|
||||
Arc::new(StringArray::from(merged_parents)),
|
||||
Arc::new(TimestampMicrosecondArray::from(created_at)),
|
||||
],
|
||||
)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
async fn load_commit_cache(
|
||||
dataset: &Dataset,
|
||||
actor_by_commit_id: &HashMap<String, String>,
|
||||
) -> Result<(HashMap<String, GraphCommit>, Option<GraphCommit>)> {
|
||||
let batches: Vec<RecordBatch> = dataset
|
||||
.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
let mut commits = load_commits_from_batches(&batches)?;
|
||||
for commit in &mut commits {
|
||||
commit.actor_id = actor_by_commit_id
|
||||
.get(commit.graph_commit_id.as_str())
|
||||
.cloned();
|
||||
}
|
||||
let mut commit_by_id = HashMap::with_capacity(commits.len());
|
||||
let mut head_commit = None;
|
||||
for commit in commits {
|
||||
if should_replace_head(head_commit.as_ref(), &commit) {
|
||||
head_commit = Some(commit.clone());
|
||||
}
|
||||
commit_by_id.insert(commit.graph_commit_id.clone(), commit);
|
||||
}
|
||||
Ok((commit_by_id, head_commit))
|
||||
}
|
||||
|
||||
async fn load_commit_actor_cache(dataset: &Dataset) -> Result<HashMap<String, String>> {
|
||||
let batches: Vec<RecordBatch> = dataset
|
||||
.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
let mut actors = HashMap::new();
|
||||
for batch in batches {
|
||||
let commit_ids = string_column(&batch, "graph_commit_id", "commit actor registry")?;
|
||||
let actor_ids = string_column(&batch, "actor_id", "commit actor registry")?;
|
||||
for row in 0..batch.num_rows() {
|
||||
actors.insert(
|
||||
commit_ids.value(row).to_string(),
|
||||
actor_ids.value(row).to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(actors)
|
||||
}
|
||||
|
||||
fn load_commits_from_batches(batches: &[RecordBatch]) -> Result<Vec<GraphCommit>> {
|
||||
let mut commits = Vec::new();
|
||||
for batch in batches {
|
||||
let ids = string_column(batch, "graph_commit_id", "commit graph")?;
|
||||
let branches = string_column(batch, "manifest_branch", "commit graph")?;
|
||||
let versions = u64_column(batch, "manifest_version", "commit graph")?;
|
||||
let parents = string_column(batch, "parent_commit_id", "commit graph")?;
|
||||
let merged_parents = string_column(batch, "merged_parent_commit_id", "commit graph")?;
|
||||
let created = timestamp_micros_column(batch, "created_at", "commit graph")?;
|
||||
|
||||
for row in 0..batch.num_rows() {
|
||||
commits.push(GraphCommit {
|
||||
graph_commit_id: ids.value(row).to_string(),
|
||||
manifest_branch: if branches.is_null(row) {
|
||||
None
|
||||
} else {
|
||||
Some(branches.value(row).to_string())
|
||||
},
|
||||
manifest_version: versions.value(row),
|
||||
parent_commit_id: if parents.is_null(row) {
|
||||
None
|
||||
} else {
|
||||
Some(parents.value(row).to_string())
|
||||
},
|
||||
merged_parent_commit_id: if merged_parents.is_null(row) {
|
||||
None
|
||||
} else {
|
||||
Some(merged_parents.value(row).to_string())
|
||||
},
|
||||
actor_id: None,
|
||||
created_at: created.value(row),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(commits)
|
||||
}
|
||||
|
||||
fn commit_actors_to_batch(records: &[CommitActorRecord]) -> Result<RecordBatch> {
|
||||
let commit_ids: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.graph_commit_id.as_str())
|
||||
.collect();
|
||||
let actor_ids: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.actor_id.as_str())
|
||||
.collect();
|
||||
let created_at: Vec<i64> = records.iter().map(|record| record.created_at).collect();
|
||||
|
||||
RecordBatch::try_new(
|
||||
commit_actor_schema(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(commit_ids)),
|
||||
Arc::new(StringArray::from(actor_ids)),
|
||||
Arc::new(TimestampMicrosecondArray::from(created_at)),
|
||||
],
|
||||
)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
/// Decides whether `candidate` should become the head commit.
///
/// With no current head the answer is always yes. Otherwise the candidate
/// wins only if it ranks strictly higher when comparing, in order, manifest
/// version, creation time and commit ID.
fn should_replace_head(current: Option<&GraphCommit>, candidate: &GraphCommit) -> bool {
    let Some(existing) = current else {
        return true;
    };

    // Tuples compare lexicographically, field by field.
    let candidate_rank = (
        candidate.manifest_version,
        candidate.created_at,
        &candidate.graph_commit_id,
    );
    let existing_rank = (
        existing.manifest_version,
        existing.created_at,
        &existing.graph_commit_id,
    );
    candidate_rank > existing_rank
}
|
||||
|
||||
fn string_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a StringArray> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} column '{name}' is not Utf8"))
|
||||
})
|
||||
}
|
||||
|
||||
fn u64_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a UInt64Array> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} column '{name}' is not UInt64"))
|
||||
})
|
||||
}
|
||||
|
||||
fn timestamp_micros_column<'a>(
|
||||
batch: &'a RecordBatch,
|
||||
name: &str,
|
||||
context: &str,
|
||||
) -> Result<&'a TimestampMicrosecondArray> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampMicrosecondArray>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!(
|
||||
"{context} column '{name}' is not Timestamp(Microsecond)"
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn ancestor_distances(
|
||||
start_id: &str,
|
||||
commits: &HashMap<String, GraphCommit>,
|
||||
) -> HashMap<String, u64> {
|
||||
let mut distances = HashMap::new();
|
||||
let mut queue = VecDeque::from([(start_id.to_string(), 0u64)]);
|
||||
|
||||
while let Some((id, distance)) = queue.pop_front() {
|
||||
if let Some(existing) = distances.get(&id) {
|
||||
if *existing <= distance {
|
||||
continue;
|
||||
}
|
||||
}
|
||||
distances.insert(id.clone(), distance);
|
||||
|
||||
if let Some(commit) = commits.get(&id) {
|
||||
if let Some(parent) = &commit.parent_commit_id {
|
||||
queue.push_back((parent.clone(), distance + 1));
|
||||
}
|
||||
if let Some(parent) = &commit.merged_parent_commit_id {
|
||||
queue.push_back((parent.clone(), distance + 1));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
distances
|
||||
}
|
||||
|
||||
async fn open_for_branch(root_uri: &str, branch: Option<&str>) -> Result<CommitGraph> {
|
||||
match branch {
|
||||
Some(branch) if branch != "main" => CommitGraph::open_at_branch(root_uri, branch).await,
|
||||
_ => CommitGraph::open(root_uri).await,
|
||||
}
|
||||
}
|
||||
|
||||
fn now_micros() -> Result<i64> {
|
||||
let duration = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map_err(|e| OmniError::manifest(format!("system clock before UNIX_EPOCH: {}", e)))?;
|
||||
Ok(duration.as_micros() as i64)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// A batch whose `graph_commit_id` column is UInt64 instead of Utf8 must
    /// be rejected with an error that names the offending column.
    #[test]
    fn load_commits_from_batches_returns_error_for_bad_schema() {
        let nullable_text = |name: &str| Field::new(name, DataType::Utf8, true);
        let bad_schema = Arc::new(Schema::new(vec![
            // Wrong on purpose: the real schema uses Utf8 here.
            Field::new("graph_commit_id", DataType::UInt64, false),
            nullable_text("manifest_branch"),
            Field::new("manifest_version", DataType::UInt64, false),
            nullable_text("parent_commit_id"),
            nullable_text("merged_parent_commit_id"),
            Field::new(
                "created_at",
                DataType::Timestamp(TimeUnit::Microsecond, None),
                false,
            ),
        ]));

        let null_text = StringArray::from(vec![None::<&str>]);
        let columns: Vec<Arc<dyn Array>> = vec![
            Arc::new(UInt64Array::from(vec![1_u64])),
            Arc::new(null_text.clone()),
            Arc::new(UInt64Array::from(vec![1_u64])),
            Arc::new(null_text.clone()),
            Arc::new(null_text),
            Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
        ];
        let batch = RecordBatch::try_new(bad_schema, columns).unwrap();

        let message = load_commits_from_batches(&[batch])
            .unwrap_err()
            .to_string();
        assert!(message.contains("graph_commit_id"));
    }
}
|
||||
562
crates/omnigraph/src/db/graph_coordinator.rs
Normal file
562
crates/omnigraph/src/db/graph_coordinator.rs
Normal file
|
|
@ -0,0 +1,562 @@
|
|||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
|
||||
use omnigraph_compiler::catalog::Catalog;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
use crate::failpoints;
|
||||
use crate::storage::{StorageAdapter, join_uri, normalize_root_uri};
|
||||
|
||||
use super::commit_graph::{CommitGraph, GraphCommit};
|
||||
use super::manifest::{ManifestCoordinator, Snapshot, SubTableUpdate};
|
||||
use super::run_registry::{RunId, RunRecord, RunRegistry, graph_runs_uri, is_internal_run_branch};
|
||||
|
||||
/// Directory name of the commit-graph dataset; `graph_commits_uri` joins it onto the repo root.
const GRAPH_COMMITS_DIR: &str = "_graph_commits.lance";
|
||||
|
||||
/// Identifier of a graph snapshot, stored as an opaque string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SnapshotId(String);

impl SnapshotId {
    /// Wraps any string-like value as a snapshot id.
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        SnapshotId(raw)
    }

    /// Borrows the id as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }

    /// Builds a placeholder id of the form `manifest:<branch>:v<version>`.
    ///
    /// A missing branch is rendered as `main`.
    pub(crate) fn synthetic(branch: Option<&str>, version: u64) -> Self {
        let branch_label = branch.unwrap_or("main");
        SnapshotId(format!("manifest:{}:v{}", branch_label, version))
    }
}

impl fmt::Display for SnapshotId {
    /// Formats the id exactly like the wrapped string, honoring formatter flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(self.as_str(), f)
    }
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub enum ReadTarget {
|
||||
Branch(String),
|
||||
Snapshot(SnapshotId),
|
||||
}
|
||||
|
||||
impl ReadTarget {
|
||||
pub fn branch(name: impl Into<String>) -> Self {
|
||||
Self::Branch(name.into())
|
||||
}
|
||||
|
||||
pub fn snapshot(id: impl Into<SnapshotId>) -> Self {
|
||||
Self::Snapshot(id.into())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&str> for ReadTarget {
|
||||
fn from(value: &str) -> Self {
|
||||
Self::branch(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<String> for ReadTarget {
|
||||
fn from(value: String) -> Self {
|
||||
Self::Branch(value)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<SnapshotId> for ReadTarget {
|
||||
fn from(value: SnapshotId) -> Self {
|
||||
Self::Snapshot(value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Outcome of resolving a [`ReadTarget`] to a concrete snapshot.
#[derive(Debug, Clone)]
pub struct ResolvedTarget {
    /// The target exactly as the caller supplied it.
    pub requested: ReadTarget,
    /// Branch associated with the resolved snapshot, if any.
    pub branch: Option<String>,
    /// Id of the resolved snapshot.
    pub snapshot_id: SnapshotId,
    /// The manifest snapshot the target resolved to.
    pub snapshot: Snapshot,
}
|
||||
|
||||
/// Result of publishing a set of updates: the new manifest version plus the
/// id recorded for it.
#[derive(Debug, Clone)]
pub(crate) struct PublishedSnapshot {
    /// Manifest version produced by the commit.
    pub manifest_version: u64,
    /// Id recorded for the commit.
    // NOTE(review): the leading underscore suggests no caller reads this yet — confirm.
    pub _snapshot_id: SnapshotId,
}
|
||||
|
||||
/// Ties together the manifest, the optional commit graph and the optional run
/// registry that live under one repo root.
pub struct GraphCoordinator {
    /// Normalized root URI of the repo.
    root_uri: String,
    /// Storage backend; used in this file for existence checks.
    storage: Arc<dyn StorageAdapter>,
    /// Coordinator for the manifest.
    manifest: ManifestCoordinator,
    /// Commit graph handle; `None` when the commit-graph dataset was absent at
    /// open time and has not been lazily created yet.
    commit_graph: Option<CommitGraph>,
    /// Run registry handle; `None` when the run dataset was absent at open time
    /// and has not been lazily created yet.
    run_registry: Option<RunRegistry>,
    /// Branch this coordinator was opened on; `None` for the default branch.
    bound_branch: Option<String>,
}
|
||||
|
||||
impl GraphCoordinator {
|
||||
pub async fn init(
|
||||
root_uri: &str,
|
||||
catalog: &Catalog,
|
||||
storage: Arc<dyn StorageAdapter>,
|
||||
) -> Result<Self> {
|
||||
let root = normalize_root_uri(root_uri)?;
|
||||
let manifest = ManifestCoordinator::init(&root, catalog).await?;
|
||||
let commit_graph = Some(CommitGraph::init(&root, manifest.version()).await?);
|
||||
Ok(Self {
|
||||
root_uri: root,
|
||||
storage,
|
||||
manifest,
|
||||
commit_graph,
|
||||
run_registry: None,
|
||||
bound_branch: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn open(root_uri: &str, storage: Arc<dyn StorageAdapter>) -> Result<Self> {
|
||||
let root = normalize_root_uri(root_uri)?;
|
||||
let manifest = ManifestCoordinator::open(&root).await?;
|
||||
let commit_graph = if storage.exists(&graph_commits_uri(&root)).await? {
|
||||
Some(CommitGraph::open(&root).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let run_registry = if storage.exists(&graph_runs_uri(&root)).await? {
|
||||
Some(RunRegistry::open(&root).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
Ok(Self {
|
||||
root_uri: root,
|
||||
storage,
|
||||
manifest,
|
||||
commit_graph,
|
||||
run_registry,
|
||||
bound_branch: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn open_branch(
|
||||
root_uri: &str,
|
||||
branch: &str,
|
||||
storage: Arc<dyn StorageAdapter>,
|
||||
) -> Result<Self> {
|
||||
let branch = normalize_branch_name(branch)?;
|
||||
let Some(branch_name) = branch else {
|
||||
return Self::open(root_uri, storage).await;
|
||||
};
|
||||
|
||||
let root = normalize_root_uri(root_uri)?;
|
||||
let manifest = ManifestCoordinator::open_at_branch(&root, &branch_name).await?;
|
||||
let commit_graph = if storage.exists(&graph_commits_uri(&root)).await? {
|
||||
Some(CommitGraph::open_at_branch(&root, &branch_name).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
let run_registry = if storage.exists(&graph_runs_uri(&root)).await? {
|
||||
Some(RunRegistry::open(&root).await?)
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
Ok(Self {
|
||||
root_uri: root,
|
||||
storage,
|
||||
manifest,
|
||||
commit_graph,
|
||||
run_registry,
|
||||
bound_branch: Some(branch_name),
|
||||
})
|
||||
}
|
||||
|
||||
/// Root URI of the repo, as normalized when the coordinator was created.
pub fn root_uri(&self) -> &str {
    &self.root_uri
}
|
||||
|
||||
/// Manifest version as reported by the underlying manifest coordinator.
pub fn version(&self) -> u64 {
    self.manifest.version()
}
|
||||
|
||||
/// Snapshot built from the manifest coordinator's currently known state.
pub fn snapshot(&self) -> Snapshot {
    self.manifest.snapshot()
}
|
||||
|
||||
/// Branch this coordinator is bound to, or `None` when it was opened on the
/// default branch.
pub fn current_branch(&self) -> Option<&str> {
    self.bound_branch.as_deref()
}
|
||||
|
||||
/// Refreshes the manifest and, when they are loaded, the commit graph and the
/// run registry.
///
/// # Errors
///
/// Stops at and returns the first refresh that fails; components refreshed
/// before the failure keep their refreshed state.
pub async fn refresh(&mut self) -> Result<()> {
    self.manifest.refresh().await?;
    if let Some(commit_graph) = self.commit_graph.as_mut() {
        commit_graph.refresh().await?;
    }
    if let Some(run_registry) = self.run_registry.as_mut() {
        // `root_uri` and `run_registry` are disjoint fields, so the URI can be
        // borrowed in place instead of cloning it for every refresh.
        run_registry.refresh(&self.root_uri).await?;
    }
    Ok(())
}
|
||||
|
||||
pub async fn branch_list(&self) -> Result<Vec<String>> {
|
||||
self.manifest.list_branches().await.map(|branches| {
|
||||
branches
|
||||
.into_iter()
|
||||
.filter(|branch| !is_internal_run_branch(branch))
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn branch_descendants(&self, name: &str) -> Result<Vec<String>> {
|
||||
self.manifest
|
||||
.descendant_branches(name)
|
||||
.await
|
||||
.map(|branches| {
|
||||
branches
|
||||
.into_iter()
|
||||
.filter(|branch| !is_internal_run_branch(branch))
|
||||
.collect()
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn branch_create(&mut self, name: &str) -> Result<()> {
|
||||
let branch = normalize_branch_name(name)?
|
||||
.ok_or_else(|| OmniError::manifest("cannot create branch 'main'".to_string()))?;
|
||||
self.ensure_commit_graph_initialized().await?;
|
||||
self.manifest.create_branch(&branch).await?;
|
||||
failpoints::maybe_fail("branch_create.after_manifest_branch_create")?;
|
||||
if let Some(commit_graph) = &mut self.commit_graph {
|
||||
commit_graph.create_branch(&branch).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn branch_delete(&mut self, name: &str) -> Result<()> {
|
||||
let branch = normalize_branch_name(name)?
|
||||
.ok_or_else(|| OmniError::manifest("cannot delete branch 'main'".to_string()))?;
|
||||
if self.current_branch() == Some(branch.as_str()) {
|
||||
return Err(OmniError::manifest_conflict(format!(
|
||||
"cannot delete currently active branch '{}'",
|
||||
branch
|
||||
)));
|
||||
}
|
||||
|
||||
self.manifest.delete_branch(&branch).await?;
|
||||
|
||||
if let Some(commit_graph) = &mut self.commit_graph {
|
||||
commit_graph.delete_branch(&branch).await?;
|
||||
} else if self
|
||||
.storage
|
||||
.exists(&graph_commits_uri(self.root_uri()))
|
||||
.await?
|
||||
{
|
||||
let mut commit_graph = CommitGraph::open(self.root_uri()).await?;
|
||||
commit_graph.delete_branch(&branch).await?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Loads the manifest snapshot at `version` on this coordinator's branch
/// (the default branch when unbound).
pub async fn snapshot_at_version(&self, version: u64) -> Result<Snapshot> {
    ManifestCoordinator::snapshot_at(self.root_uri(), self.current_branch(), version).await
}
|
||||
|
||||
pub async fn resolve_snapshot_id(&self, branch: &str) -> Result<SnapshotId> {
|
||||
let normalized = normalize_branch_name(branch)?;
|
||||
let other = match normalized.as_deref() {
|
||||
Some(branch) => {
|
||||
GraphCoordinator::open_branch(self.root_uri(), branch, Arc::clone(&self.storage))
|
||||
.await?
|
||||
}
|
||||
None => GraphCoordinator::open(self.root_uri(), Arc::clone(&self.storage)).await?,
|
||||
};
|
||||
|
||||
Ok(other
|
||||
.head_commit_id()
|
||||
.await?
|
||||
.unwrap_or_else(|| SnapshotId::synthetic(other.current_branch(), other.version())))
|
||||
}
|
||||
|
||||
pub async fn resolve_target(&self, target: &ReadTarget) -> Result<ResolvedTarget> {
|
||||
match target {
|
||||
ReadTarget::Branch(branch) => {
|
||||
let normalized = normalize_branch_name(branch)?;
|
||||
let other = match normalized.as_deref() {
|
||||
Some(branch) => {
|
||||
GraphCoordinator::open_branch(
|
||||
self.root_uri(),
|
||||
branch,
|
||||
Arc::clone(&self.storage),
|
||||
)
|
||||
.await?
|
||||
}
|
||||
None => {
|
||||
GraphCoordinator::open(self.root_uri(), Arc::clone(&self.storage)).await?
|
||||
}
|
||||
};
|
||||
let snapshot_id = other.head_commit_id().await?.unwrap_or_else(|| {
|
||||
SnapshotId::synthetic(other.current_branch(), other.version())
|
||||
});
|
||||
Ok(ResolvedTarget {
|
||||
requested: target.clone(),
|
||||
branch: other.bound_branch.clone(),
|
||||
snapshot_id,
|
||||
snapshot: other.snapshot(),
|
||||
})
|
||||
}
|
||||
ReadTarget::Snapshot(snapshot_id) => {
|
||||
let commit = self.resolve_commit(snapshot_id).await?;
|
||||
let snapshot = ManifestCoordinator::snapshot_at(
|
||||
self.root_uri(),
|
||||
commit.manifest_branch.as_deref(),
|
||||
commit.manifest_version,
|
||||
)
|
||||
.await?;
|
||||
Ok(ResolvedTarget {
|
||||
requested: target.clone(),
|
||||
branch: commit.manifest_branch.clone(),
|
||||
snapshot_id: snapshot_id.clone(),
|
||||
snapshot,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn resolve_commit(&self, snapshot_id: &SnapshotId) -> Result<GraphCommit> {
|
||||
if let Some(commit_graph) = &self.commit_graph {
|
||||
if let Some(commit) = commit_graph.get_commit(snapshot_id.as_str()) {
|
||||
return Ok(commit);
|
||||
}
|
||||
}
|
||||
|
||||
for branch in self.manifest.list_branches().await? {
|
||||
let normalized = normalize_branch_name(&branch)?;
|
||||
let Some(commit_graph) = self
|
||||
.open_commit_graph_for_branch(normalized.as_deref())
|
||||
.await?
|
||||
else {
|
||||
break;
|
||||
};
|
||||
if let Some(commit) = commit_graph.get_commit(snapshot_id.as_str()) {
|
||||
return Ok(commit);
|
||||
}
|
||||
}
|
||||
|
||||
Err(OmniError::manifest_not_found(format!(
|
||||
"commit '{}' not found",
|
||||
snapshot_id
|
||||
)))
|
||||
}
|
||||
|
||||
pub(crate) async fn head_commit_id(&self) -> Result<Option<SnapshotId>> {
|
||||
match &self.commit_graph {
|
||||
Some(commit_graph) => commit_graph
|
||||
.head_commit_id()
|
||||
.await
|
||||
.map(|id| id.map(SnapshotId::new)),
|
||||
None => Ok(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn ensure_commit_graph_initialized(&mut self) -> Result<()> {
|
||||
if self.commit_graph.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
if !self
|
||||
.storage
|
||||
.exists(&graph_commits_uri(self.root_uri()))
|
||||
.await?
|
||||
{
|
||||
let _ = CommitGraph::init(self.root_uri(), self.manifest.version()).await?;
|
||||
}
|
||||
self.commit_graph = match self.current_branch() {
|
||||
Some(branch) => Some(CommitGraph::open_at_branch(self.root_uri(), branch).await?),
|
||||
None => Some(CommitGraph::open(self.root_uri()).await?),
|
||||
};
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn ensure_run_registry_initialized(&mut self) -> Result<()> {
|
||||
if self.run_registry.is_some() {
|
||||
return Ok(());
|
||||
}
|
||||
if !self
|
||||
.storage
|
||||
.exists(&graph_runs_uri(self.root_uri()))
|
||||
.await?
|
||||
{
|
||||
let _ = RunRegistry::init(self.root_uri()).await?;
|
||||
}
|
||||
self.run_registry = Some(RunRegistry::open(self.root_uri()).await?);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn commit_updates_with_actor(
|
||||
&mut self,
|
||||
updates: &[SubTableUpdate],
|
||||
actor_id: Option<&str>,
|
||||
) -> Result<PublishedSnapshot> {
|
||||
let manifest_version = self.commit_manifest_updates(updates).await?;
|
||||
let snapshot_id = self.record_graph_commit(manifest_version, actor_id).await?;
|
||||
Ok(PublishedSnapshot {
|
||||
manifest_version,
|
||||
_snapshot_id: snapshot_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// Commits `updates` to the manifest and returns the resulting manifest
/// version.
///
/// # Errors
///
/// Propagates the manifest commit error, or the error injected by the
/// `graph_publish.after_manifest_commit` failpoint.
pub(crate) async fn commit_manifest_updates(
    &mut self,
    updates: &[SubTableUpdate],
) -> Result<u64> {
    let manifest_version = self.manifest.commit(updates).await?;
    // Failpoint placed after the manifest commit has already succeeded.
    failpoints::maybe_fail("graph_publish.after_manifest_commit")?;
    Ok(manifest_version)
}
|
||||
|
||||
pub(crate) async fn record_graph_commit(
|
||||
&mut self,
|
||||
manifest_version: u64,
|
||||
actor_id: Option<&str>,
|
||||
) -> Result<SnapshotId> {
|
||||
self.ensure_commit_graph_initialized().await?;
|
||||
let current_branch = self.current_branch().map(str::to_string);
|
||||
let Some(commit_graph) = &mut self.commit_graph else {
|
||||
return Ok(SnapshotId::synthetic(
|
||||
current_branch.as_deref(),
|
||||
manifest_version,
|
||||
));
|
||||
};
|
||||
failpoints::maybe_fail("graph_publish.before_commit_append")?;
|
||||
let graph_commit_id = commit_graph
|
||||
.append_commit(current_branch.as_deref(), manifest_version, actor_id)
|
||||
.await?;
|
||||
Ok(SnapshotId::new(graph_commit_id))
|
||||
}
|
||||
|
||||
pub(crate) async fn record_merge_commit(
|
||||
&mut self,
|
||||
manifest_version: u64,
|
||||
parent_commit_id: &str,
|
||||
merged_parent_commit_id: &str,
|
||||
actor_id: Option<&str>,
|
||||
) -> Result<SnapshotId> {
|
||||
self.ensure_commit_graph_initialized().await?;
|
||||
let current_branch = self.current_branch().map(str::to_string);
|
||||
let commit_graph = self.commit_graph.as_mut().ok_or_else(|| {
|
||||
OmniError::manifest("branch merge requires _graph_commits.lance".to_string())
|
||||
})?;
|
||||
failpoints::maybe_fail("graph_publish.before_commit_append")?;
|
||||
let graph_commit_id = commit_graph
|
||||
.append_merge_commit(
|
||||
current_branch.as_deref(),
|
||||
manifest_version,
|
||||
parent_commit_id,
|
||||
merged_parent_commit_id,
|
||||
actor_id,
|
||||
)
|
||||
.await?;
|
||||
Ok(SnapshotId::new(graph_commit_id))
|
||||
}
|
||||
|
||||
async fn open_commit_graph_for_branch(
|
||||
&self,
|
||||
branch: Option<&str>,
|
||||
) -> Result<Option<CommitGraph>> {
|
||||
if !self
|
||||
.storage
|
||||
.exists(&graph_commits_uri(self.root_uri()))
|
||||
.await?
|
||||
{
|
||||
return Ok(None);
|
||||
}
|
||||
let graph = match branch {
|
||||
Some(branch) => CommitGraph::open_at_branch(self.root_uri(), branch).await?,
|
||||
None => CommitGraph::open(self.root_uri()).await?,
|
||||
};
|
||||
Ok(Some(graph))
|
||||
}
|
||||
|
||||
pub(crate) async fn append_run_record(&mut self, record: &RunRecord) -> Result<()> {
|
||||
self.ensure_run_registry_initialized().await?;
|
||||
let Some(run_registry) = &mut self.run_registry else {
|
||||
return Err(OmniError::manifest(
|
||||
"run registry not initialized".to_string(),
|
||||
));
|
||||
};
|
||||
run_registry.append_record(record).await
|
||||
}
|
||||
|
||||
pub(crate) async fn get_run(&self, run_id: &RunId) -> Result<RunRecord> {
|
||||
if let Some(run_registry) = &self.run_registry {
|
||||
if let Some(run) = run_registry.get_run(run_id).await? {
|
||||
return Ok(run);
|
||||
}
|
||||
}
|
||||
if !self
|
||||
.storage
|
||||
.exists(&graph_runs_uri(self.root_uri()))
|
||||
.await?
|
||||
{
|
||||
return Err(OmniError::manifest_not_found(format!(
|
||||
"run '{}' not found",
|
||||
run_id
|
||||
)));
|
||||
}
|
||||
let run_registry = RunRegistry::open(self.root_uri()).await?;
|
||||
run_registry
|
||||
.get_run(run_id)
|
||||
.await?
|
||||
.ok_or_else(|| OmniError::manifest_not_found(format!("run '{}' not found", run_id)))
|
||||
}
|
||||
|
||||
pub(crate) async fn list_runs(&self) -> Result<Vec<RunRecord>> {
|
||||
if let Some(run_registry) = &self.run_registry {
|
||||
return run_registry.list_runs().await;
|
||||
}
|
||||
if !self
|
||||
.storage
|
||||
.exists(&graph_runs_uri(self.root_uri()))
|
||||
.await?
|
||||
{
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let run_registry = RunRegistry::open(self.root_uri()).await?;
|
||||
run_registry.list_runs().await
|
||||
}
|
||||
|
||||
pub(crate) async fn list_commits(&self) -> Result<Vec<GraphCommit>> {
|
||||
if let Some(commit_graph) = &self.commit_graph {
|
||||
return commit_graph.load_commits().await;
|
||||
}
|
||||
if !self
|
||||
.storage
|
||||
.exists(&graph_commits_uri(self.root_uri()))
|
||||
.await?
|
||||
{
|
||||
return Ok(Vec::new());
|
||||
}
|
||||
let commit_graph = match self.current_branch() {
|
||||
Some(branch) => CommitGraph::open_at_branch(self.root_uri(), branch).await?,
|
||||
None => CommitGraph::open(self.root_uri()).await?,
|
||||
};
|
||||
commit_graph.load_commits().await
|
||||
}
|
||||
}
|
||||
|
||||
/// URI of the commit-graph dataset: `GRAPH_COMMITS_DIR` joined onto `root_uri`.
fn graph_commits_uri(root_uri: &str) -> String {
    join_uri(root_uri, GRAPH_COMMITS_DIR)
}
|
||||
|
||||
fn normalize_branch_name(branch: &str) -> Result<Option<String>> {
|
||||
let branch = branch.trim();
|
||||
if branch.is_empty() {
|
||||
return Err(OmniError::manifest(
|
||||
"branch name cannot be empty".to_string(),
|
||||
));
|
||||
}
|
||||
if branch == "main" {
|
||||
return Ok(None);
|
||||
}
|
||||
Ok(Some(branch.to_string()))
|
||||
}
|
||||
339
crates/omnigraph/src/db/manifest.rs
Normal file
339
crates/omnigraph/src/db/manifest.rs
Normal file
|
|
@ -0,0 +1,339 @@
|
|||
use std::collections::{HashMap, HashSet};
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
use lance::Dataset;
|
||||
use lance_namespace::models::CreateTableVersionRequest;
|
||||
use omnigraph_compiler::catalog::Catalog;
|
||||
|
||||
#[path = "manifest/layout.rs"]
|
||||
mod layout;
|
||||
#[path = "manifest/metadata.rs"]
|
||||
mod metadata;
|
||||
#[path = "manifest/namespace.rs"]
|
||||
mod namespace;
|
||||
#[path = "manifest/publisher.rs"]
|
||||
mod publisher;
|
||||
#[path = "manifest/repo.rs"]
|
||||
mod repo;
|
||||
#[path = "manifest/state.rs"]
|
||||
mod state;
|
||||
|
||||
use layout::{manifest_uri, open_manifest_dataset};
|
||||
pub(crate) use metadata::TableVersionMetadata;
|
||||
#[cfg(test)]
|
||||
use metadata::{OMNIGRAPH_ROW_COUNT_KEY, table_version_metadata_for_state};
|
||||
use namespace::open_table_at_version_from_manifest;
|
||||
pub(crate) use namespace::open_table_head_for_write;
|
||||
#[cfg(test)]
|
||||
use namespace::{branch_manifest_namespace, staged_table_namespace};
|
||||
use publisher::{GraphNamespacePublisher, ManifestBatchPublisher};
|
||||
use repo::{init_manifest_repo, open_manifest_repo, snapshot_state_at};
|
||||
pub use state::SubTableEntry;
|
||||
#[cfg(test)]
|
||||
use state::string_column;
|
||||
use state::{ManifestState, read_manifest_state};
|
||||
|
||||
/// Object-type label `"table"`.
// NOTE(review): not referenced in the visible part of this file — presumably tags
// table-registration rows in `__manifest`; confirm against the submodules.
const OBJECT_TYPE_TABLE: &str = "table";
|
||||
/// Object-type label `"table_version"`.
// NOTE(review): not referenced in the visible part of this file — presumably tags the
// append-only version rows in `__manifest`; confirm against the submodules.
const OBJECT_TYPE_TABLE_VERSION: &str = "table_version";
|
||||
/// Key string `"table_version_management"`.
// NOTE(review): consumers are not visible in this part of the file — confirm where
// and how this key is read before relying on it.
const TABLE_VERSION_MANAGEMENT_KEY: &str = "table_version_management";
|
||||
|
||||
/// Immutable point-in-time view of the database.
|
||||
///
|
||||
/// Cheap to create (no storage I/O). All reads within a query go through one
|
||||
/// Snapshot to guarantee cross-type consistency.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct Snapshot {
|
||||
root_uri: String,
|
||||
version: u64,
|
||||
entries: HashMap<String, SubTableEntry>,
|
||||
}
|
||||
|
||||
impl Snapshot {
|
||||
/// Open a sub-table dataset at its pinned version.
|
||||
pub async fn open(&self, table_key: &str) -> Result<Dataset> {
|
||||
let entry = self
|
||||
.entries
|
||||
.get(table_key)
|
||||
.ok_or_else(|| OmniError::manifest(format!("no manifest entry for {}", table_key)))?;
|
||||
entry.open(&self.root_uri).await
|
||||
}
|
||||
|
||||
/// Manifest version this snapshot was taken from.
|
||||
pub fn version(&self) -> u64 {
|
||||
self.version
|
||||
}
|
||||
|
||||
/// Look up a sub-table entry by key.
|
||||
pub fn entry(&self, table_key: &str) -> Option<&SubTableEntry> {
|
||||
self.entries.get(table_key)
|
||||
}
|
||||
|
||||
pub fn entries(&self) -> impl Iterator<Item = &SubTableEntry> {
|
||||
self.entries.values()
|
||||
}
|
||||
}
|
||||
|
||||
impl SubTableUpdate {
|
||||
pub(crate) fn to_create_table_version_request(&self) -> CreateTableVersionRequest {
|
||||
self.version_metadata.to_create_table_version_request(
|
||||
&self.table_key,
|
||||
self.table_version,
|
||||
self.row_count,
|
||||
self.table_branch.as_deref(),
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
impl SubTableEntry {
|
||||
pub(crate) async fn open(&self, root_uri: &str) -> Result<Dataset> {
|
||||
open_table_at_version_from_manifest(
|
||||
root_uri,
|
||||
&self.table_key,
|
||||
self.table_branch.as_deref(),
|
||||
self.table_version,
|
||||
)
|
||||
.await
|
||||
}
|
||||
}
|
||||
|
||||
/// An update to apply to the manifest via `commit`.
#[derive(Debug, Clone)]
pub struct SubTableUpdate {
    /// Key of the sub-table being updated.
    pub table_key: String,
    /// Version of the sub-table to publish.
    pub table_version: u64,
    /// Branch of the sub-table, if any.
    pub table_branch: Option<String>,
    /// Row count reported for this table version.
    pub row_count: u64,
    /// Metadata used to build the `CreateTableVersionRequest` for this update.
    pub(crate) version_metadata: TableVersionMetadata,
}
|
||||
|
||||
/// Coordinates cross-dataset state through the namespace `__manifest` table.
///
/// Table rows register stable metadata such as location. Append-only
/// `table_version` rows are the graph publish boundary and reconstruct the
/// current graph snapshot by selecting the latest visible version row per
/// sub-table.
pub struct ManifestCoordinator {
    /// Repo root, stored without trailing slashes.
    root_uri: String,
    /// Handle to the manifest dataset; `version()` is read from it.
    dataset: Dataset,
    /// Manifest state last read from `dataset`; `snapshot()` is built from it.
    known_state: ManifestState,
    /// Branch the manifest was opened on; `None` for the default branch.
    active_branch: Option<String>,
    /// Publisher that `commit` delegates to.
    publisher: Arc<dyn ManifestBatchPublisher>,
}
|
||||
|
||||
impl ManifestCoordinator {
|
||||
/// Builds the default publisher: a `GraphNamespacePublisher` for `root_uri`
/// and the optional active branch.
fn default_batch_publisher(
    root_uri: &str,
    active_branch: Option<&str>,
) -> Arc<dyn ManifestBatchPublisher> {
    Arc::new(GraphNamespacePublisher::new(root_uri, active_branch))
}
|
||||
|
||||
fn from_parts(
|
||||
root_uri: &str,
|
||||
dataset: Dataset,
|
||||
known_state: ManifestState,
|
||||
active_branch: Option<String>,
|
||||
publisher: Arc<dyn ManifestBatchPublisher>,
|
||||
) -> Self {
|
||||
Self {
|
||||
root_uri: root_uri.trim_end_matches('/').to_string(),
|
||||
dataset,
|
||||
known_state,
|
||||
active_branch,
|
||||
publisher,
|
||||
}
|
||||
}
|
||||
|
||||
/// Like `from_parts`, but uses the publisher from `default_batch_publisher`
/// for the same root and branch.
fn from_parts_with_default_publisher(
    root_uri: &str,
    dataset: Dataset,
    known_state: ManifestState,
    active_branch: Option<String>,
) -> Self {
    let publisher = Self::default_batch_publisher(root_uri, active_branch.as_deref());
    Self::from_parts(root_uri, dataset, known_state, active_branch, publisher)
}
|
||||
|
||||
fn snapshot_from_state(root_uri: &str, state: ManifestState) -> Snapshot {
|
||||
Snapshot {
|
||||
root_uri: root_uri.trim_end_matches('/').to_string(),
|
||||
version: state.version,
|
||||
entries: state
|
||||
.entries
|
||||
.into_iter()
|
||||
.map(|entry| (entry.table_key.clone(), entry))
|
||||
.collect(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Test-only builder step that swaps in a custom publisher.
#[cfg(test)]
fn with_batch_publisher(mut self, publisher: Arc<dyn ManifestBatchPublisher>) -> Self {
    self.publisher = publisher;
    self
}
|
||||
|
||||
/// Create a new repo at `root_uri` from a catalog.
|
||||
///
|
||||
/// Creates per-type Lance datasets and the namespace `__manifest` table.
|
||||
pub async fn init(root_uri: &str, catalog: &Catalog) -> Result<Self> {
|
||||
let root = root_uri.trim_end_matches('/');
|
||||
let (dataset, known_state) = init_manifest_repo(root, catalog).await?;
|
||||
|
||||
Ok(Self::from_parts_with_default_publisher(
|
||||
root,
|
||||
dataset,
|
||||
known_state,
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
/// Open an existing repo's manifest.
|
||||
pub async fn open(root_uri: &str) -> Result<Self> {
|
||||
let root = root_uri.trim_end_matches('/');
|
||||
let (dataset, known_state) = open_manifest_repo(root, None).await?;
|
||||
Ok(Self::from_parts_with_default_publisher(
|
||||
root,
|
||||
dataset,
|
||||
known_state,
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
/// Open an existing repo's manifest at a specific branch.
|
||||
pub async fn open_at_branch(root_uri: &str, branch: &str) -> Result<Self> {
|
||||
if branch == "main" {
|
||||
return Self::open(root_uri).await;
|
||||
}
|
||||
|
||||
let root = root_uri.trim_end_matches('/');
|
||||
let (dataset, known_state) = open_manifest_repo(root, Some(branch)).await?;
|
||||
Ok(Self::from_parts_with_default_publisher(
|
||||
root,
|
||||
dataset,
|
||||
known_state,
|
||||
Some(branch.to_string()),
|
||||
))
|
||||
}
|
||||
|
||||
pub async fn snapshot_at(
|
||||
root_uri: &str,
|
||||
branch: Option<&str>,
|
||||
version: u64,
|
||||
) -> Result<Snapshot> {
|
||||
let root = root_uri.trim_end_matches('/');
|
||||
Ok(Self::snapshot_from_state(
|
||||
root,
|
||||
snapshot_state_at(root, branch, version).await?,
|
||||
))
|
||||
}
|
||||
|
||||
/// Return a Snapshot from the known manifest state. No storage I/O.
///
/// The state is cloned, so the snapshot is independent of later refreshes.
pub fn snapshot(&self) -> Snapshot {
    Self::snapshot_from_state(&self.root_uri, self.known_state.clone())
}
|
||||
|
||||
/// Re-read manifest from storage to see other writers' commits.
|
||||
pub async fn refresh(&mut self) -> Result<()> {
|
||||
self.dataset = open_manifest_dataset(&self.root_uri, self.active_branch.as_deref()).await?;
|
||||
self.known_state = read_manifest_state(&self.dataset).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Commit updated sub-table versions to the manifest.
|
||||
///
|
||||
/// Atomically inserts one immutable `table_version` row per updated table.
|
||||
/// The merge-insert commit on `__manifest` is the graph-level publish point.
|
||||
pub async fn commit(&mut self, updates: &[SubTableUpdate]) -> Result<u64> {
|
||||
if updates.is_empty() {
|
||||
return Ok(self.version());
|
||||
}
|
||||
|
||||
self.dataset = self.publisher.publish(updates).await?;
|
||||
|
||||
self.known_state = read_manifest_state(&self.dataset).await?;
|
||||
Ok(self.version())
|
||||
}
|
||||
|
||||
/// Current manifest version, read from the held dataset handle.
pub fn version(&self) -> u64 {
    self.dataset.version().version
}
|
||||
|
||||
/// Branch this manifest was opened on, or `None` for the default branch.
pub fn active_branch(&self) -> Option<&str> {
    self.active_branch.as_deref()
}
|
||||
|
||||
pub async fn create_branch(&mut self, name: &str) -> Result<()> {
|
||||
let mut ds = self.dataset.clone();
|
||||
ds.create_branch(name, self.version(), None)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn delete_branch(&mut self, name: &str) -> Result<()> {
|
||||
let uri = manifest_uri(&self.root_uri);
|
||||
let mut ds = Dataset::open(&uri)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
ds.delete_branch(name)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.dataset = open_manifest_dataset(&self.root_uri, self.active_branch.as_deref()).await?;
|
||||
self.known_state = read_manifest_state(&self.dataset).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn list_branches(&self) -> Result<Vec<String>> {
|
||||
let branches = self
|
||||
.dataset
|
||||
.list_branches()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let mut names: Vec<String> = branches.into_keys().filter(|name| name != "main").collect();
|
||||
names.sort();
|
||||
let mut all = vec!["main".to_string()];
|
||||
all.extend(names);
|
||||
Ok(all)
|
||||
}
|
||||
|
||||
pub async fn descendant_branches(&self, name: &str) -> Result<Vec<String>> {
|
||||
let branches = self
|
||||
.dataset
|
||||
.list_branches()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let mut frontier = vec![name.to_string()];
|
||||
let mut descendants = Vec::new();
|
||||
let mut seen = HashSet::new();
|
||||
|
||||
while let Some(parent) = frontier.pop() {
|
||||
let mut children = branches
|
||||
.iter()
|
||||
.filter_map(|(branch, contents)| {
|
||||
(contents.parent_branch.as_deref() == Some(parent.as_str()))
|
||||
.then_some(branch.clone())
|
||||
})
|
||||
.collect::<Vec<_>>();
|
||||
children.sort();
|
||||
for child in children {
|
||||
if seen.insert(child.clone()) {
|
||||
frontier.push(child.clone());
|
||||
descendants.push(child);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Ok(descendants)
|
||||
}
|
||||
|
||||
/// Root URI of the repo, stored without trailing slashes.
pub fn root_uri(&self) -> &str {
    &self.root_uri
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "manifest/tests.rs"]
|
||||
mod tests;
|
||||
74
crates/omnigraph/src/db/manifest/layout.rs
Normal file
74
crates/omnigraph/src/db/manifest/layout.rs
Normal file
|
|
@ -0,0 +1,74 @@
|
|||
use lance::Dataset;
|
||||
use lance_namespace::Error as LanceNamespaceError;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
use crate::storage::{StorageKind, join_uri, storage_kind_for_uri};
|
||||
|
||||
/// Directory name of the manifest dataset; `manifest_uri` appends it to the repo root.
const MANIFEST_DIR: &str = "__manifest";
|
||||
|
||||
pub(super) fn type_name_hash(name: &str) -> String {
|
||||
let mut h: u64 = 0xcbf29ce484222325;
|
||||
for byte in name.as_bytes() {
|
||||
h ^= *byte as u64;
|
||||
h = h.wrapping_mul(0x100000001b3);
|
||||
}
|
||||
format!("{:016x}", h)
|
||||
}
|
||||
|
||||
pub(super) fn manifest_uri(root: &str) -> String {
|
||||
format!("{}/{}", root.trim_end_matches('/'), MANIFEST_DIR)
|
||||
}
|
||||
|
||||
pub(super) async fn open_manifest_dataset(root_uri: &str, branch: Option<&str>) -> Result<Dataset> {
|
||||
let dataset = Dataset::open(&manifest_uri(root_uri.trim_end_matches('/')))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
match branch {
|
||||
Some(branch) if branch != "main" => dataset
|
||||
.checkout_branch(branch)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string())),
|
||||
_ => Ok(dataset),
|
||||
}
|
||||
}
|
||||
|
||||
/// Renders `version` as a decimal string zero-padded to 20 digits, which is
/// wide enough for any `u64`.
fn format_table_version(version: u64) -> String {
    format!("{:020}", version)
}
|
||||
|
||||
pub(super) fn version_object_id(table_key: &str, version: u64) -> String {
|
||||
format!("{}${}", table_key, format_table_version(version))
|
||||
}
|
||||
|
||||
pub(super) fn table_id_to_key(request_id: Option<&Vec<String>>) -> lance_namespace::Result<String> {
|
||||
match request_id {
|
||||
Some(request_id) if request_id.len() == 1 && !request_id[0].is_empty() => {
|
||||
Ok(request_id[0].clone())
|
||||
}
|
||||
Some(request_id) => Err(LanceNamespaceError::invalid_input(format!(
|
||||
"expected single table id component, got {:?}",
|
||||
request_id
|
||||
))),
|
||||
None => Err(LanceNamespaceError::invalid_input("table id is required")),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn table_uri_for_path(root_uri: &str, table_path: &str, branch: Option<&str>) -> String {
|
||||
let mut dataset_location = join_uri(root_uri, table_path);
|
||||
if let Some(branch) = branch.filter(|branch| *branch != "main") {
|
||||
dataset_location = join_uri(&dataset_location, "tree");
|
||||
for segment in branch.split('/') {
|
||||
dataset_location = join_uri(&dataset_location, segment);
|
||||
}
|
||||
}
|
||||
match storage_kind_for_uri(root_uri) {
|
||||
StorageKind::Local => url::Url::from_file_path(&dataset_location)
|
||||
.map(|uri| uri.to_string())
|
||||
.unwrap_or(dataset_location),
|
||||
StorageKind::S3 => dataset_location,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn namespace_internal_error(message: impl Into<String>) -> LanceNamespaceError {
|
||||
LanceNamespaceError::namespace_source(Box::new(std::io::Error::other(message.into())))
|
||||
}
|
||||
244
crates/omnigraph/src/db/manifest/metadata.rs
Normal file
244
crates/omnigraph/src/db/manifest/metadata.rs
Normal file
|
|
@ -0,0 +1,244 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use lance::Dataset;
|
||||
use lance_namespace::Error as LanceNamespaceError;
|
||||
use lance_namespace::models::{CreateTableVersionRequest, TableVersion};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
use crate::storage::{StorageKind, join_uri, storage_kind_for_uri};
|
||||
|
||||
use super::layout::table_id_to_key;
|
||||
|
||||
/// Metadata key under which a table version's row count is stored.
pub(super) const OMNIGRAPH_ROW_COUNT_KEY: &str = "omnigraph.row_count";
|
||||
/// Metadata key under which a table version's branch name is stored, when one is given.
const OMNIGRAPH_TABLE_BRANCH_KEY: &str = "omnigraph.table_branch";
|
||||
|
||||
pub(super) fn namespace_version_metadata(
|
||||
row_count: u64,
|
||||
table_branch: Option<&str>,
|
||||
) -> HashMap<String, String> {
|
||||
let mut metadata =
|
||||
HashMap::from([(OMNIGRAPH_ROW_COUNT_KEY.to_string(), row_count.to_string())]);
|
||||
if let Some(table_branch) = table_branch {
|
||||
metadata.insert(
|
||||
OMNIGRAPH_TABLE_BRANCH_KEY.to_string(),
|
||||
table_branch.to_string(),
|
||||
);
|
||||
}
|
||||
metadata
|
||||
}
|
||||
|
||||
pub(super) fn parse_namespace_version_request(
|
||||
request: &CreateTableVersionRequest,
|
||||
) -> lance_namespace::Result<(String, u64, u64, Option<String>, TableVersionMetadata)> {
|
||||
let table_key = table_id_to_key(request.id.as_ref())?;
|
||||
let version = u64::try_from(request.version)
|
||||
.map_err(|_| LanceNamespaceError::invalid_input("table version must be non-negative"))?;
|
||||
let metadata = request.metadata.as_ref().ok_or_else(|| {
|
||||
LanceNamespaceError::invalid_input("version metadata is required for Omnigraph rows")
|
||||
})?;
|
||||
let row_count = metadata
|
||||
.get(OMNIGRAPH_ROW_COUNT_KEY)
|
||||
.ok_or_else(|| {
|
||||
LanceNamespaceError::invalid_input("missing omnigraph.row_count in metadata")
|
||||
})?
|
||||
.parse::<u64>()
|
||||
.map_err(|e| {
|
||||
LanceNamespaceError::invalid_input(format!("invalid omnigraph.row_count value: {}", e))
|
||||
})?;
|
||||
let table_branch = metadata.get(OMNIGRAPH_TABLE_BRANCH_KEY).cloned();
|
||||
let version_metadata = TableVersionMetadata {
|
||||
manifest_path: request.manifest_path.clone(),
|
||||
manifest_size: request.manifest_size.map(|size| size as u64),
|
||||
e_tag: request.e_tag.clone(),
|
||||
naming_scheme: request.naming_scheme.clone(),
|
||||
};
|
||||
|
||||
Ok((
|
||||
table_key,
|
||||
version,
|
||||
row_count,
|
||||
table_branch,
|
||||
version_metadata,
|
||||
))
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub(crate) struct TableVersionMetadata {
|
||||
manifest_path: String,
|
||||
manifest_size: Option<u64>,
|
||||
e_tag: Option<String>,
|
||||
naming_scheme: Option<String>,
|
||||
}
|
||||
|
||||
impl TableVersionMetadata {
|
||||
pub(crate) fn from_dataset(
|
||||
root_uri: &str,
|
||||
table_path: &str,
|
||||
dataset: &Dataset,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
manifest_path: full_manifest_object_store_path(
|
||||
root_uri,
|
||||
table_path,
|
||||
&dataset.manifest_location().path.to_string(),
|
||||
)?,
|
||||
manifest_size: dataset.manifest_location().size,
|
||||
e_tag: dataset.manifest_location().e_tag.clone(),
|
||||
naming_scheme: Some(format!("{:?}", dataset.manifest_location().naming_scheme)),
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn from_json_str(value: &str) -> Result<Self> {
|
||||
serde_json::from_str(value).map_err(|e| {
|
||||
OmniError::manifest_internal(format!("failed to decode manifest metadata: {e}"))
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn to_json_string(&self) -> Result<String> {
|
||||
serde_json::to_string(self).map_err(|e| {
|
||||
OmniError::manifest_internal(format!("failed to encode manifest metadata: {e}"))
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn manifest_path(&self) -> &str {
|
||||
&self.manifest_path
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn manifest_size(&self) -> Option<u64> {
|
||||
self.manifest_size
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn e_tag(&self) -> Option<&str> {
|
||||
self.e_tag.as_deref()
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn naming_scheme(&self) -> Option<&str> {
|
||||
self.naming_scheme.as_deref()
|
||||
}
|
||||
|
||||
pub(crate) fn to_create_table_version_request(
|
||||
&self,
|
||||
table_key: &str,
|
||||
table_version: u64,
|
||||
row_count: u64,
|
||||
table_branch: Option<&str>,
|
||||
) -> CreateTableVersionRequest {
|
||||
let mut request =
|
||||
CreateTableVersionRequest::new(table_version as i64, self.manifest_path.clone());
|
||||
request.id = Some(vec![table_key.to_string()]);
|
||||
request.manifest_size = self.manifest_size.map(|size| size as i64);
|
||||
request.e_tag = self.e_tag.clone();
|
||||
request.naming_scheme = self.naming_scheme.clone();
|
||||
request.metadata = Some(namespace_version_metadata(row_count, table_branch));
|
||||
request
|
||||
}
|
||||
|
||||
pub(super) fn to_namespace_version(&self, version: u64) -> TableVersion {
|
||||
self.to_namespace_version_with_details(version, None, None)
|
||||
}
|
||||
|
||||
pub(super) fn to_namespace_version_with_details(
|
||||
&self,
|
||||
version: u64,
|
||||
timestamp_millis: Option<i64>,
|
||||
metadata: Option<HashMap<String, String>>,
|
||||
) -> TableVersion {
|
||||
let mut metadata = metadata.unwrap_or_default();
|
||||
if let Some(naming_scheme) = &self.naming_scheme {
|
||||
metadata.insert("naming_scheme".to_string(), naming_scheme.clone());
|
||||
}
|
||||
|
||||
TableVersion {
|
||||
version: version as i64,
|
||||
manifest_path: self.manifest_path.clone(),
|
||||
manifest_size: self.manifest_size.map(|size| size as i64),
|
||||
e_tag: self.e_tag.clone(),
|
||||
timestamp_millis,
|
||||
metadata: (!metadata.is_empty()).then_some(metadata),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn object_store_path_from_uri(uri: &str) -> Result<String> {
|
||||
match storage_kind_for_uri(uri) {
|
||||
StorageKind::Local => {
|
||||
if uri.strip_prefix("file://").is_some() {
|
||||
let path = url::Url::parse(uri)
|
||||
.map_err(|e| {
|
||||
OmniError::manifest_internal(format!("invalid file uri '{}': {}", uri, e))
|
||||
})?
|
||||
.to_file_path()
|
||||
.map_err(|_| {
|
||||
OmniError::manifest_internal(format!("invalid file uri '{}'", uri))
|
||||
})?;
|
||||
Ok(path.to_string_lossy().to_string())
|
||||
} else {
|
||||
Ok(uri.to_string())
|
||||
}
|
||||
}
|
||||
StorageKind::S3 => {
|
||||
let url = url::Url::parse(uri).map_err(|e| {
|
||||
OmniError::manifest_internal(format!("invalid s3 uri '{}': {}", uri, e))
|
||||
})?;
|
||||
Ok(url.path().trim_start_matches('/').to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn full_manifest_object_store_path(
|
||||
root_uri: &str,
|
||||
table_path: &str,
|
||||
manifest_path: &str,
|
||||
) -> Result<String> {
|
||||
if manifest_path.contains("://") {
|
||||
return object_store_path_from_uri(manifest_path);
|
||||
}
|
||||
|
||||
if manifest_path.contains(table_path) {
|
||||
return Ok(manifest_path.to_string());
|
||||
}
|
||||
|
||||
let dataset_uri = join_uri(root_uri, table_path);
|
||||
let dataset_path = object_store_path_from_uri(&dataset_uri)?;
|
||||
let manifest_path = manifest_path.trim_start_matches('/');
|
||||
|
||||
if manifest_path.is_empty() {
|
||||
return Ok(dataset_path);
|
||||
}
|
||||
|
||||
Ok(format!(
|
||||
"{}/{}",
|
||||
dataset_path.trim_end_matches('/'),
|
||||
manifest_path
|
||||
))
|
||||
}
|
||||
|
||||
/// Test helper: reads the manifest metadata of a table at a given version.
///
/// Opens the dataset at `<root_uri>/<table_path>`, checks out `branch` when
/// one is given, then checks out `version` and captures its manifest location.
///
/// # Errors
///
/// Lance failures are returned as `OmniError::Lance`; path resolution errors
/// from `TableVersionMetadata::from_dataset` are propagated.
#[cfg(test)]
pub(super) async fn table_version_metadata_for_state(
    root_uri: &str,
    table_path: &str,
    branch: Option<&str>,
    version: u64,
) -> Result<TableVersionMetadata> {
    let table_uri = format!("{}/{}", root_uri.trim_end_matches('/'), table_path);

    let mut ds = Dataset::open(&table_uri)
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;

    if let Some(branch_name) = branch {
        ds = ds
            .checkout_branch(branch_name)
            .await
            .map_err(|e| OmniError::Lance(e.to_string()))?;
    }

    let at_version = ds
        .checkout_version(version)
        .await
        .map_err(|e| OmniError::Lance(e.to_string()))?;

    TableVersionMetadata::from_dataset(root_uri, table_path, &at_version)
}
|
||||
549
crates/omnigraph/src/db/manifest/namespace.rs
Normal file
549
crates/omnigraph/src/db/manifest/namespace.rs
Normal file
|
|
@ -0,0 +1,549 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use lance::Dataset;
|
||||
use lance::dataset::builder::DatasetBuilder;
|
||||
use lance_namespace::models::{
|
||||
CreateTableVersionRequest, CreateTableVersionResponse, DescribeTableRequest,
|
||||
DescribeTableResponse, DescribeTableVersionRequest, DescribeTableVersionResponse,
|
||||
ListTableVersionsRequest, ListTableVersionsResponse, TableExistsRequest, TableVersion,
|
||||
};
|
||||
use lance_namespace::{Error as LanceNamespaceError, LanceNamespace, NamespaceError};
|
||||
use lance_table::io::commit::ManifestNamingScheme;
|
||||
use object_store::{Error as ObjectStoreError, ObjectStore as _, PutMode, PutOptions, path::Path};
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
use super::layout::{
|
||||
namespace_internal_error, open_manifest_dataset, table_id_to_key, table_uri_for_path,
|
||||
};
|
||||
use super::metadata::{
|
||||
TableVersionMetadata, namespace_version_metadata, parse_namespace_version_request,
|
||||
};
|
||||
use super::publisher::GraphNamespacePublisher;
|
||||
use super::state::{ManifestState, SubTableEntry, read_manifest_entries, read_manifest_state};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct BranchManifestNamespace {
|
||||
root_uri: String,
|
||||
branch: Option<String>,
|
||||
}
|
||||
|
||||
impl BranchManifestNamespace {
|
||||
fn new(root_uri: &str, branch: Option<&str>) -> Self {
|
||||
Self {
|
||||
root_uri: root_uri.trim_end_matches('/').to_string(),
|
||||
branch: branch
|
||||
.filter(|branch| *branch != "main")
|
||||
.map(ToOwned::to_owned),
|
||||
}
|
||||
}
|
||||
|
||||
async fn dataset(&self) -> Result<Dataset> {
|
||||
open_manifest_dataset(&self.root_uri, self.branch.as_deref()).await
|
||||
}
|
||||
|
||||
async fn state(&self) -> Result<ManifestState> {
|
||||
let dataset = self.dataset().await?;
|
||||
read_manifest_state(&dataset).await
|
||||
}
|
||||
|
||||
async fn version_entries(&self) -> Result<Vec<SubTableEntry>> {
|
||||
let dataset = self.dataset().await?;
|
||||
read_manifest_entries(&dataset).await
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct StagedTableNamespace {
|
||||
root_uri: String,
|
||||
table_id: Vec<String>,
|
||||
table_path: String,
|
||||
branch: Option<String>,
|
||||
}
|
||||
|
||||
impl StagedTableNamespace {
|
||||
fn new(root_uri: &str, table_key: &str, table_path: &str, branch: Option<&str>) -> Self {
|
||||
Self {
|
||||
root_uri: root_uri.trim_end_matches('/').to_string(),
|
||||
table_id: vec![table_key.to_string()],
|
||||
table_path: table_path.to_string(),
|
||||
branch: branch
|
||||
.filter(|branch| *branch != "main")
|
||||
.map(ToOwned::to_owned),
|
||||
}
|
||||
}
|
||||
|
||||
fn table_key(&self) -> &str {
|
||||
&self.table_id[0]
|
||||
}
|
||||
|
||||
fn table_uri(&self) -> String {
|
||||
table_uri_for_path(&self.root_uri, &self.table_path, self.branch.as_deref())
|
||||
}
|
||||
|
||||
fn ensure_request_table(
|
||||
&self,
|
||||
request_id: Option<&Vec<String>>,
|
||||
) -> lance_namespace::Result<()> {
|
||||
match request_id {
|
||||
Some(request_id) if request_id == &self.table_id => Ok(()),
|
||||
Some(request_id) => Err(LanceNamespaceError::namespace_source(Box::new(
|
||||
NamespaceError::TableNotFound {
|
||||
message: format!("table {:?} not found", request_id),
|
||||
},
|
||||
))),
|
||||
None => Err(LanceNamespaceError::invalid_input("table id is required")),
|
||||
}
|
||||
}
|
||||
|
||||
async fn open_head(&self) -> Result<Dataset> {
|
||||
Dataset::open(&self.table_uri())
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
async fn open_version(&self, version: u64) -> Result<Dataset> {
|
||||
let ds = self.open_head().await?;
|
||||
if ds.version().version == version {
|
||||
Ok(ds)
|
||||
} else {
|
||||
ds.checkout_version(version)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
fn to_table_version(
|
||||
&self,
|
||||
dataset: &Dataset,
|
||||
version: &lance::dataset::Version,
|
||||
) -> Result<TableVersion> {
|
||||
let metadata =
|
||||
TableVersionMetadata::from_dataset(&self.root_uri, &self.table_path, dataset)?;
|
||||
Ok(metadata.to_namespace_version_with_details(
|
||||
version.version,
|
||||
Some(version.timestamp.timestamp_millis()),
|
||||
Some(
|
||||
version
|
||||
.metadata
|
||||
.iter()
|
||||
.map(|(k, v)| (k.clone(), v.clone()))
|
||||
.collect(),
|
||||
),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn branch_manifest_namespace(
|
||||
root_uri: &str,
|
||||
branch: Option<&str>,
|
||||
) -> Arc<dyn LanceNamespace> {
|
||||
Arc::new(BranchManifestNamespace::new(root_uri, branch))
|
||||
}
|
||||
|
||||
pub(crate) fn staged_table_namespace(
|
||||
root_uri: &str,
|
||||
table_key: &str,
|
||||
table_path: &str,
|
||||
branch: Option<&str>,
|
||||
) -> Arc<dyn LanceNamespace> {
|
||||
Arc::new(StagedTableNamespace::new(
|
||||
root_uri, table_key, table_path, branch,
|
||||
))
|
||||
}
|
||||
|
||||
async fn load_table_from_namespace(
|
||||
namespace: Arc<dyn LanceNamespace>,
|
||||
table_key: &str,
|
||||
branch: Option<&str>,
|
||||
version: Option<u64>,
|
||||
) -> Result<Dataset> {
|
||||
let builder = DatasetBuilder::from_namespace(namespace, vec![table_key.to_string()])
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let builder = match (branch, version) {
|
||||
(Some(branch), version) => builder.with_branch(branch, version),
|
||||
(None, Some(version)) => builder.with_version(version),
|
||||
(None, None) => builder,
|
||||
};
|
||||
builder
|
||||
.load()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub(crate) async fn open_table_at_version_from_manifest(
|
||||
root_uri: &str,
|
||||
table_key: &str,
|
||||
branch: Option<&str>,
|
||||
version: u64,
|
||||
) -> Result<Dataset> {
|
||||
load_table_from_namespace(
|
||||
branch_manifest_namespace(root_uri, branch),
|
||||
table_key,
|
||||
branch,
|
||||
Some(version),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LanceNamespace for BranchManifestNamespace {
|
||||
fn namespace_id(&self) -> String {
|
||||
"__manifest".to_string()
|
||||
}
|
||||
|
||||
async fn describe_table(
|
||||
&self,
|
||||
request: DescribeTableRequest,
|
||||
) -> lance_namespace::Result<DescribeTableResponse> {
|
||||
let table_key = table_id_to_key(request.id.as_ref())?;
|
||||
let state = self
|
||||
.state()
|
||||
.await
|
||||
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?;
|
||||
let entry = state
|
||||
.entries
|
||||
.into_iter()
|
||||
.find(|entry| entry.table_key == table_key);
|
||||
let entry = entry.ok_or_else(|| {
|
||||
LanceNamespaceError::namespace_source(Box::new(NamespaceError::TableNotFound {
|
||||
message: format!("table {} not found", table_key),
|
||||
}))
|
||||
})?;
|
||||
let table_uri = table_uri_for_path(
|
||||
&self.root_uri,
|
||||
&entry.table_path,
|
||||
entry.table_branch.as_deref(),
|
||||
);
|
||||
|
||||
Ok(DescribeTableResponse {
|
||||
table: Some(entry.table_key.clone()),
|
||||
namespace: Some(Vec::new()),
|
||||
version: Some(entry.table_version as i64),
|
||||
location: Some(table_uri.clone()),
|
||||
table_uri: request.with_table_uri.unwrap_or(false).then_some(table_uri),
|
||||
schema: None,
|
||||
storage_options: None,
|
||||
stats: None,
|
||||
metadata: None,
|
||||
properties: None,
|
||||
managed_versioning: Some(true),
|
||||
})
|
||||
}
|
||||
|
||||
async fn table_exists(&self, request: TableExistsRequest) -> lance_namespace::Result<()> {
|
||||
let table_key = table_id_to_key(request.id.as_ref())?;
|
||||
let state = self
|
||||
.state()
|
||||
.await
|
||||
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?;
|
||||
if state
|
||||
.entries
|
||||
.iter()
|
||||
.any(|entry| entry.table_key == table_key)
|
||||
{
|
||||
Ok(())
|
||||
} else {
|
||||
Err(LanceNamespaceError::namespace_source(Box::new(
|
||||
NamespaceError::TableNotFound {
|
||||
message: format!("table {} not found", table_key),
|
||||
},
|
||||
)))
|
||||
}
|
||||
}
|
||||
|
||||
async fn list_table_versions(
|
||||
&self,
|
||||
request: ListTableVersionsRequest,
|
||||
) -> lance_namespace::Result<ListTableVersionsResponse> {
|
||||
let table_key = table_id_to_key(request.id.as_ref())?;
|
||||
let mut versions: Vec<TableVersion> = self
|
||||
.version_entries()
|
||||
.await
|
||||
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?
|
||||
.into_iter()
|
||||
.filter(|entry| entry.table_key == table_key)
|
||||
.map(|entry| {
|
||||
entry
|
||||
.version_metadata
|
||||
.to_namespace_version(entry.table_version)
|
||||
})
|
||||
.collect();
|
||||
|
||||
if request.descending.unwrap_or(false) {
|
||||
versions.sort_by(|a, b| b.version.cmp(&a.version));
|
||||
} else {
|
||||
versions.sort_by(|a, b| a.version.cmp(&b.version));
|
||||
}
|
||||
if let Some(limit) = request.limit {
|
||||
versions.truncate(limit as usize);
|
||||
}
|
||||
|
||||
Ok(ListTableVersionsResponse {
|
||||
versions,
|
||||
page_token: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn describe_table_version(
|
||||
&self,
|
||||
request: DescribeTableVersionRequest,
|
||||
) -> lance_namespace::Result<DescribeTableVersionResponse> {
|
||||
let table_key = table_id_to_key(request.id.as_ref())?;
|
||||
let version = request
|
||||
.version
|
||||
.ok_or_else(|| LanceNamespaceError::invalid_input("table version is required"))?;
|
||||
let version = u64::try_from(version).map_err(|_| {
|
||||
LanceNamespaceError::invalid_input("table version must be non-negative")
|
||||
})?;
|
||||
let entry = self
|
||||
.version_entries()
|
||||
.await
|
||||
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?
|
||||
.into_iter()
|
||||
.find(|entry| entry.table_key == table_key && entry.table_version == version)
|
||||
.ok_or_else(|| {
|
||||
LanceNamespaceError::namespace_source(Box::new(
|
||||
NamespaceError::TableVersionNotFound {
|
||||
message: format!("table version {} not found for {}", version, table_key),
|
||||
},
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(DescribeTableVersionResponse::new(
|
||||
entry
|
||||
.version_metadata
|
||||
.to_namespace_version(entry.table_version),
|
||||
))
|
||||
}
|
||||
|
||||
async fn create_table_version(
|
||||
&self,
|
||||
request: CreateTableVersionRequest,
|
||||
) -> lance_namespace::Result<CreateTableVersionResponse> {
|
||||
let (table_key, table_version, row_count, table_branch, version_metadata) =
|
||||
parse_namespace_version_request(&request)?;
|
||||
GraphNamespacePublisher::new(&self.root_uri, self.branch.as_deref())
|
||||
.publish_requests(std::slice::from_ref(&request))
|
||||
.await
|
||||
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))?;
|
||||
let mut response = CreateTableVersionResponse::new();
|
||||
response.version = Some(Box::new(
|
||||
version_metadata.to_namespace_version_with_details(
|
||||
table_version,
|
||||
None,
|
||||
Some(namespace_version_metadata(
|
||||
row_count,
|
||||
table_branch.as_deref(),
|
||||
)),
|
||||
),
|
||||
));
|
||||
let _ = table_key;
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl LanceNamespace for StagedTableNamespace {
|
||||
fn namespace_id(&self) -> String {
|
||||
"__manifest".to_string()
|
||||
}
|
||||
|
||||
async fn describe_table(
|
||||
&self,
|
||||
request: DescribeTableRequest,
|
||||
) -> lance_namespace::Result<DescribeTableResponse> {
|
||||
self.ensure_request_table(request.id.as_ref())?;
|
||||
let ds = self
|
||||
.open_head()
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?;
|
||||
let table_uri = self.table_uri();
|
||||
Ok(DescribeTableResponse {
|
||||
table: Some(self.table_key().to_string()),
|
||||
namespace: Some(Vec::new()),
|
||||
version: Some(ds.version().version as i64),
|
||||
location: Some(table_uri.clone()),
|
||||
table_uri: request.with_table_uri.unwrap_or(false).then_some(table_uri),
|
||||
schema: None,
|
||||
storage_options: None,
|
||||
stats: None,
|
||||
metadata: None,
|
||||
properties: None,
|
||||
managed_versioning: Some(true),
|
||||
})
|
||||
}
|
||||
|
||||
async fn table_exists(&self, request: TableExistsRequest) -> lance_namespace::Result<()> {
|
||||
self.ensure_request_table(request.id.as_ref())?;
|
||||
self.open_head()
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|e| LanceNamespaceError::namespace_source(Box::new(e)))
|
||||
}
|
||||
|
||||
async fn list_table_versions(
|
||||
&self,
|
||||
request: ListTableVersionsRequest,
|
||||
) -> lance_namespace::Result<ListTableVersionsResponse> {
|
||||
self.ensure_request_table(request.id.as_ref())?;
|
||||
if request.limit == Some(0) {
|
||||
return Ok(ListTableVersionsResponse {
|
||||
versions: Vec::new(),
|
||||
page_token: None,
|
||||
});
|
||||
}
|
||||
let head = self
|
||||
.open_head()
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?;
|
||||
let dataset_versions = head
|
||||
.versions()
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?;
|
||||
let mut versions = Vec::with_capacity(dataset_versions.len());
|
||||
for version in dataset_versions {
|
||||
let dataset = if version.version == head.version().version {
|
||||
head.clone()
|
||||
} else {
|
||||
head.checkout_version(version.version)
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?
|
||||
};
|
||||
versions.push(
|
||||
self.to_table_version(&dataset, &version)
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?,
|
||||
);
|
||||
}
|
||||
if request.descending.unwrap_or(false) {
|
||||
versions.sort_by(|a, b| b.version.cmp(&a.version));
|
||||
} else {
|
||||
versions.sort_by(|a, b| a.version.cmp(&b.version));
|
||||
}
|
||||
if let Some(limit) = request.limit {
|
||||
versions.truncate(limit as usize);
|
||||
}
|
||||
Ok(ListTableVersionsResponse {
|
||||
versions,
|
||||
page_token: None,
|
||||
})
|
||||
}
|
||||
|
||||
async fn describe_table_version(
|
||||
&self,
|
||||
request: DescribeTableVersionRequest,
|
||||
) -> lance_namespace::Result<DescribeTableVersionResponse> {
|
||||
self.ensure_request_table(request.id.as_ref())?;
|
||||
let version = request
|
||||
.version
|
||||
.ok_or_else(|| LanceNamespaceError::invalid_input("table version is required"))?;
|
||||
let version = u64::try_from(version).map_err(|_| {
|
||||
LanceNamespaceError::invalid_input("table version must be non-negative")
|
||||
})?;
|
||||
let ds = self
|
||||
.open_version(version)
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?;
|
||||
let version_info = self
|
||||
.to_table_version(
|
||||
&ds,
|
||||
&lance::dataset::Version {
|
||||
version: ds.version().version,
|
||||
timestamp: ds.version().timestamp,
|
||||
metadata: ds.version().metadata,
|
||||
},
|
||||
)
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?;
|
||||
Ok(DescribeTableVersionResponse::new(version_info))
|
||||
}
|
||||
|
||||
async fn create_table_version(
|
||||
&self,
|
||||
request: CreateTableVersionRequest,
|
||||
) -> lance_namespace::Result<CreateTableVersionResponse> {
|
||||
self.ensure_request_table(request.id.as_ref())?;
|
||||
let version = u64::try_from(request.version).map_err(|_| {
|
||||
LanceNamespaceError::invalid_input("table version must be non-negative")
|
||||
})?;
|
||||
let naming_scheme = match request.naming_scheme.as_deref() {
|
||||
Some("V1") => ManifestNamingScheme::V1,
|
||||
_ => ManifestNamingScheme::V2,
|
||||
};
|
||||
let (object_store, base_path, _) = DatasetBuilder::from_uri(&self.table_uri())
|
||||
.build_object_store()
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?;
|
||||
let staging_path = Path::from(request.manifest_path.clone());
|
||||
let manifest_data = object_store
|
||||
.inner
|
||||
.get(&staging_path)
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?;
|
||||
let final_path = naming_scheme.manifest_path(&base_path, version);
|
||||
object_store
|
||||
.inner
|
||||
.put_opts(
|
||||
&final_path,
|
||||
manifest_data.into(),
|
||||
PutOptions {
|
||||
mode: PutMode::Create,
|
||||
..Default::default()
|
||||
},
|
||||
)
|
||||
.await
|
||||
.map_err(|e| match e {
|
||||
ObjectStoreError::AlreadyExists { .. } | ObjectStoreError::Precondition { .. } => {
|
||||
LanceNamespaceError::namespace_source(Box::new(
|
||||
NamespaceError::ConcurrentModification {
|
||||
message: format!(
|
||||
"table version {} already exists for {}",
|
||||
version,
|
||||
self.table_key()
|
||||
),
|
||||
},
|
||||
))
|
||||
}
|
||||
other => namespace_internal_error(other.to_string()),
|
||||
})?;
|
||||
let meta = object_store
|
||||
.inner
|
||||
.head(&final_path)
|
||||
.await
|
||||
.map_err(|e| namespace_internal_error(e.to_string()))?;
|
||||
match object_store.inner.delete(&staging_path).await {
|
||||
Ok(_) | Err(ObjectStoreError::NotFound { .. }) => {}
|
||||
Err(e) => return Err(namespace_internal_error(e.to_string())),
|
||||
}
|
||||
|
||||
let mut response = CreateTableVersionResponse::new();
|
||||
response.version = Some(Box::new(TableVersion {
|
||||
version: version as i64,
|
||||
manifest_path: final_path.to_string(),
|
||||
manifest_size: Some(meta.size as i64),
|
||||
e_tag: meta.e_tag,
|
||||
timestamp_millis: None,
|
||||
metadata: request.metadata,
|
||||
}));
|
||||
Ok(response)
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn open_table_head_for_write(
|
||||
root_uri: &str,
|
||||
table_key: &str,
|
||||
table_path: &str,
|
||||
branch: Option<&str>,
|
||||
) -> Result<Dataset> {
|
||||
load_table_from_namespace(
|
||||
staged_table_namespace(root_uri, table_key, table_path, branch),
|
||||
table_key,
|
||||
branch,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
}
|
||||
236
crates/omnigraph/src/db/manifest/publisher.rs
Normal file
236
crates/omnigraph/src/db/manifest/publisher.rs
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
//! Graph-level batch publish over the namespace `__manifest` table.
|
||||
//!
|
||||
//! Lance now owns most of the table/version control plane for Omnigraph:
|
||||
//! table storage, table-local versioning, namespace lookup, and native table
|
||||
//! history. This module exists for the remaining graph-specific gap:
|
||||
//! Omnigraph needs one atomic publish point across multiple tables and the
|
||||
//! current Rust namespace surface does not expose a branch-aware
|
||||
//! `BatchCreateTableVersions` path for `DirectoryNamespace`.
|
||||
//!
|
||||
//! Until Lance exposes that operation directly, this publisher owns only:
|
||||
//! - validating batch publish invariants against the current `__manifest` state
|
||||
//! - atomically inserting immutable `table_version` rows into `__manifest`
|
||||
//! - returning the refreshed manifest dataset that defines the visible graph
|
||||
//!
|
||||
//! This module should disappear once Lance Rust can do branch-aware batch table
|
||||
//! version publication against a managed namespace manifest.
|
||||
|
||||
use async_trait::async_trait;
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::RecordBatchIterator;
|
||||
use lance::Dataset;
|
||||
use lance::dataset::{MergeInsertBuilder, WhenMatched, WhenNotMatched};
|
||||
use lance_namespace::NamespaceError;
|
||||
use lance_namespace::models::CreateTableVersionRequest;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
use super::layout::{open_manifest_dataset, version_object_id};
|
||||
use super::metadata::parse_namespace_version_request;
|
||||
use super::state::{
|
||||
manifest_rows_batch, manifest_schema, read_manifest_entries, read_manifest_state,
|
||||
};
|
||||
use super::{OBJECT_TYPE_TABLE_VERSION, SubTableEntry, SubTableUpdate};
|
||||
|
||||
#[async_trait]
|
||||
pub(super) trait ManifestBatchPublisher: Send + Sync {
|
||||
async fn publish(&self, updates: &[SubTableUpdate]) -> Result<Dataset>;
|
||||
}
|
||||
|
||||
pub(super) struct GraphNamespacePublisher {
|
||||
root_uri: String,
|
||||
branch: Option<String>,
|
||||
}
|
||||
|
||||
/// One table-version row waiting to be written to the manifest.
#[derive(Debug)]
struct PendingVersionRow {
    /// Row id, built by `version_object_id` from table key and version.
    object_id: String,
    /// JSON-encoded `TableVersionMetadata`, when present.
    metadata: Option<String>,
    /// Key of the table this version belongs to.
    table_key: String,
    /// Version number of the table.
    table_version: Option<u64>,
    /// Branch recorded for the version, if any.
    table_branch: Option<String>,
    /// Row count recorded for the version.
    row_count: Option<u64>,
}
|
||||
|
||||
impl GraphNamespacePublisher {
|
||||
pub(super) fn new(root_uri: &str, branch: Option<&str>) -> Self {
|
||||
Self {
|
||||
root_uri: root_uri.trim_end_matches('/').to_string(),
|
||||
branch: branch
|
||||
.filter(|branch| *branch != "main")
|
||||
.map(ToOwned::to_owned),
|
||||
}
|
||||
}
|
||||
|
||||
async fn dataset(&self) -> Result<Dataset> {
|
||||
open_manifest_dataset(&self.root_uri, self.branch.as_deref()).await
|
||||
}
|
||||
|
||||
async fn load_publish_state(
|
||||
&self,
|
||||
) -> Result<(
|
||||
Dataset,
|
||||
HashMap<String, ()>,
|
||||
HashMap<(String, u64), SubTableEntry>,
|
||||
)> {
|
||||
let dataset = self.dataset().await?;
|
||||
let current = read_manifest_state(&dataset).await?;
|
||||
let existing_entries = read_manifest_entries(&dataset).await?;
|
||||
let known_tables = current
|
||||
.entries
|
||||
.iter()
|
||||
.map(|entry| (entry.table_key.clone(), ()))
|
||||
.collect();
|
||||
let existing_versions = existing_entries
|
||||
.iter()
|
||||
.map(|entry| {
|
||||
(
|
||||
(entry.table_key.clone(), entry.table_version),
|
||||
entry.clone(),
|
||||
)
|
||||
})
|
||||
.collect();
|
||||
Ok((dataset, known_tables, existing_versions))
|
||||
}
|
||||
|
||||
fn build_pending_rows(
|
||||
requests: &[CreateTableVersionRequest],
|
||||
known_tables: &HashMap<String, ()>,
|
||||
existing_versions: &HashMap<(String, u64), SubTableEntry>,
|
||||
) -> Result<Vec<PendingVersionRow>> {
|
||||
let mut request_versions = HashMap::<(String, u64), ()>::new();
|
||||
let mut rows = Vec::with_capacity(requests.len());
|
||||
|
||||
for request in requests {
|
||||
let (table_key, table_version, row_count, table_branch, version_metadata) =
|
||||
parse_namespace_version_request(request)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
if !known_tables.contains_key(table_key.as_str()) {
|
||||
return Err(OmniError::Lance(
|
||||
NamespaceError::TableNotFound {
|
||||
message: format!("table {} not found", table_key),
|
||||
}
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
if request_versions
|
||||
.insert((table_key.clone(), table_version), ())
|
||||
.is_some()
|
||||
{
|
||||
return Err(OmniError::Lance(
|
||||
NamespaceError::ConcurrentModification {
|
||||
message: format!(
|
||||
"table version {} already exists for {}",
|
||||
table_version, table_key
|
||||
),
|
||||
}
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
if let Some(existing) = existing_versions.get(&(table_key.clone(), table_version)) {
|
||||
let is_owner_branch_handoff =
|
||||
existing.row_count == row_count && existing.table_branch != table_branch;
|
||||
if !is_owner_branch_handoff {
|
||||
return Err(OmniError::Lance(
|
||||
NamespaceError::ConcurrentModification {
|
||||
message: format!(
|
||||
"table version {} already exists for {}",
|
||||
table_version, table_key
|
||||
),
|
||||
}
|
||||
.to_string(),
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
rows.push(PendingVersionRow {
|
||||
object_id: version_object_id(&table_key, table_version),
|
||||
metadata: Some(version_metadata.to_json_string()?),
|
||||
table_key,
|
||||
table_version: Some(table_version),
|
||||
table_branch,
|
||||
row_count: Some(row_count),
|
||||
});
|
||||
}
|
||||
|
||||
Ok(rows)
|
||||
}
|
||||
|
||||
fn pending_rows_to_batch(rows: Vec<PendingVersionRow>) -> Result<arrow_array::RecordBatch> {
|
||||
let mut object_ids = Vec::with_capacity(rows.len());
|
||||
let mut object_types = Vec::with_capacity(rows.len());
|
||||
let mut locations: Vec<Option<String>> = Vec::with_capacity(rows.len());
|
||||
let mut metadata = Vec::with_capacity(rows.len());
|
||||
let mut table_keys = Vec::with_capacity(rows.len());
|
||||
let mut table_versions: Vec<Option<u64>> = Vec::with_capacity(rows.len());
|
||||
let mut table_branches = Vec::with_capacity(rows.len());
|
||||
let mut row_counts: Vec<Option<u64>> = Vec::with_capacity(rows.len());
|
||||
|
||||
for row in rows {
|
||||
object_ids.push(row.object_id);
|
||||
object_types.push(OBJECT_TYPE_TABLE_VERSION.to_string());
|
||||
locations.push(None);
|
||||
metadata.push(row.metadata);
|
||||
table_keys.push(row.table_key);
|
||||
table_versions.push(row.table_version);
|
||||
table_branches.push(row.table_branch);
|
||||
row_counts.push(row.row_count);
|
||||
}
|
||||
|
||||
manifest_rows_batch(
|
||||
object_ids,
|
||||
object_types,
|
||||
locations,
|
||||
metadata,
|
||||
table_keys,
|
||||
table_versions,
|
||||
table_branches,
|
||||
row_counts,
|
||||
)
|
||||
}
|
||||
|
||||
async fn merge_rows(&self, dataset: Dataset, rows: Vec<PendingVersionRow>) -> Result<Dataset> {
|
||||
let batch = Self::pending_rows_to_batch(rows)?;
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], manifest_schema());
|
||||
let dataset = Arc::new(dataset);
|
||||
let mut merge_builder = MergeInsertBuilder::try_new(dataset, vec!["object_id".to_string()])
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
merge_builder.when_matched(WhenMatched::UpdateAll);
|
||||
merge_builder.when_not_matched(WhenNotMatched::InsertAll);
|
||||
merge_builder.conflict_retries(5);
|
||||
merge_builder.use_index(false);
|
||||
let (new_dataset, _stats) = merge_builder
|
||||
.try_build()
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.execute_reader(Box::new(reader))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
Ok(Arc::try_unwrap(new_dataset).unwrap_or_else(|arc| (*arc).clone()))
|
||||
}
|
||||
|
||||
pub(super) async fn publish_requests(
|
||||
&self,
|
||||
requests: &[CreateTableVersionRequest],
|
||||
) -> Result<Dataset> {
|
||||
if requests.is_empty() {
|
||||
return self.dataset().await;
|
||||
}
|
||||
|
||||
let (dataset, known_tables, existing_versions) = self.load_publish_state().await?;
|
||||
let rows = Self::build_pending_rows(requests, &known_tables, &existing_versions)?;
|
||||
self.merge_rows(dataset, rows).await
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl ManifestBatchPublisher for GraphNamespacePublisher {
|
||||
async fn publish(&self, updates: &[SubTableUpdate]) -> Result<Dataset> {
|
||||
let requests: Vec<CreateTableVersionRequest> = updates
|
||||
.iter()
|
||||
.map(SubTableUpdate::to_create_table_version_request)
|
||||
.collect();
|
||||
self.publish_requests(&requests).await
|
||||
}
|
||||
}
|
||||
133
crates/omnigraph/src/db/manifest/repo.rs
Normal file
133
crates/omnigraph/src/db/manifest/repo.rs
Normal file
|
|
@ -0,0 +1,133 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use arrow_array::{RecordBatch, RecordBatchIterator};
|
||||
use arrow_schema::SchemaRef;
|
||||
use lance::Dataset;
|
||||
use lance::dataset::{WriteMode, WriteParams};
|
||||
use lance_file::version::LanceFileVersion;
|
||||
use omnigraph_compiler::catalog::Catalog;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
use super::TABLE_VERSION_MANAGEMENT_KEY;
|
||||
use super::layout::{manifest_uri, open_manifest_dataset, type_name_hash};
|
||||
use super::metadata::TableVersionMetadata;
|
||||
use super::state::{
|
||||
ManifestState, SubTableEntry, entries_to_batch, manifest_schema, read_manifest_state,
|
||||
};
|
||||
|
||||
pub(super) async fn init_manifest_repo(
|
||||
root_uri: &str,
|
||||
catalog: &Catalog,
|
||||
) -> Result<(Dataset, ManifestState)> {
|
||||
let root = root_uri.trim_end_matches('/');
|
||||
let (entries, version_metadata) = build_initial_entries(root, catalog).await?;
|
||||
|
||||
let manifest_batch = entries_to_batch(&entries, &version_metadata)?;
|
||||
let schema = manifest_schema();
|
||||
let reader = RecordBatchIterator::new(vec![Ok(manifest_batch)], schema);
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
..Default::default()
|
||||
};
|
||||
let manifest_path = manifest_uri(root);
|
||||
let mut dataset = Dataset::write(reader, &manifest_path, Some(params))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
dataset
|
||||
.update_config([(TABLE_VERSION_MANAGEMENT_KEY, Some("true"))])
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
let known_state = read_manifest_state(&dataset).await?;
|
||||
Ok((dataset, known_state))
|
||||
}
|
||||
|
||||
pub(super) async fn open_manifest_repo(
|
||||
root_uri: &str,
|
||||
branch: Option<&str>,
|
||||
) -> Result<(Dataset, ManifestState)> {
|
||||
let dataset = open_manifest_dataset(root_uri.trim_end_matches('/'), branch).await?;
|
||||
let known_state = read_manifest_state(&dataset).await?;
|
||||
Ok((dataset, known_state))
|
||||
}
|
||||
|
||||
pub(super) async fn snapshot_state_at(
|
||||
root_uri: &str,
|
||||
branch: Option<&str>,
|
||||
version: u64,
|
||||
) -> Result<ManifestState> {
|
||||
let dataset = open_manifest_dataset(root_uri.trim_end_matches('/'), branch).await?;
|
||||
let dataset = dataset
|
||||
.checkout_version(version)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
read_manifest_state(&dataset).await
|
||||
}
|
||||
|
||||
async fn build_initial_entries(
|
||||
root_uri: &str,
|
||||
catalog: &Catalog,
|
||||
) -> Result<(Vec<SubTableEntry>, HashMap<String, String>)> {
|
||||
let mut entries = Vec::new();
|
||||
let mut version_metadata = HashMap::new();
|
||||
|
||||
for (name, node_type) in &catalog.node_types {
|
||||
let hash = type_name_hash(name);
|
||||
let table_path = format!("nodes/{}", hash);
|
||||
let full_path = format!("{}/{}", root_uri, table_path);
|
||||
|
||||
let ds = create_empty_dataset(&full_path, &node_type.arrow_schema).await?;
|
||||
let table_key = format!("node:{}", name);
|
||||
let metadata = TableVersionMetadata::from_dataset(root_uri, &table_path, &ds)?;
|
||||
|
||||
entries.push(SubTableEntry {
|
||||
table_key: table_key.clone(),
|
||||
table_path: table_path.clone(),
|
||||
table_version: ds.version().version,
|
||||
table_branch: None,
|
||||
row_count: 0,
|
||||
version_metadata: metadata.clone(),
|
||||
});
|
||||
version_metadata.insert(table_key, metadata.to_json_string()?);
|
||||
}
|
||||
|
||||
for (name, edge_type) in &catalog.edge_types {
|
||||
let hash = type_name_hash(name);
|
||||
let table_path = format!("edges/{}", hash);
|
||||
let full_path = format!("{}/{}", root_uri, table_path);
|
||||
|
||||
let ds = create_empty_dataset(&full_path, &edge_type.arrow_schema).await?;
|
||||
let table_key = format!("edge:{}", name);
|
||||
let metadata = TableVersionMetadata::from_dataset(root_uri, &table_path, &ds)?;
|
||||
|
||||
entries.push(SubTableEntry {
|
||||
table_key: table_key.clone(),
|
||||
table_path: table_path.clone(),
|
||||
table_version: ds.version().version,
|
||||
table_branch: None,
|
||||
row_count: 0,
|
||||
version_metadata: metadata.clone(),
|
||||
});
|
||||
version_metadata.insert(table_key, metadata.to_json_string()?);
|
||||
}
|
||||
|
||||
Ok((entries, version_metadata))
|
||||
}
|
||||
|
||||
async fn create_empty_dataset(uri: &str, schema: &SchemaRef) -> Result<Dataset> {
|
||||
let batch = RecordBatch::new_empty(schema.clone());
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema.clone());
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
allow_external_blob_outside_bases: true,
|
||||
..Default::default()
|
||||
};
|
||||
Dataset::write(reader, uri, Some(params))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
274
crates/omnigraph/src/db/manifest/state.rs
Normal file
274
crates/omnigraph/src/db/manifest/state.rs
Normal file
|
|
@ -0,0 +1,274 @@
|
|||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{Array, RecordBatch, StringArray, UInt64Array, new_null_array};
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef};
|
||||
use futures::TryStreamExt;
|
||||
use lance::Dataset;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
use super::layout::version_object_id;
|
||||
use super::metadata::TableVersionMetadata;
|
||||
use super::{OBJECT_TYPE_TABLE, OBJECT_TYPE_TABLE_VERSION};
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct SubTableEntry {
|
||||
pub table_key: String,
|
||||
pub table_path: String,
|
||||
pub table_version: u64,
|
||||
pub table_branch: Option<String>,
|
||||
pub row_count: u64,
|
||||
pub(crate) version_metadata: TableVersionMetadata,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
pub(super) struct ManifestState {
|
||||
pub(super) version: u64,
|
||||
pub(super) entries: Vec<SubTableEntry>,
|
||||
}
|
||||
|
||||
pub(super) fn manifest_schema() -> SchemaRef {
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("object_id", DataType::Utf8, false),
|
||||
Field::new("object_type", DataType::Utf8, false),
|
||||
Field::new("location", DataType::Utf8, true),
|
||||
Field::new("metadata", DataType::Utf8, true),
|
||||
Field::new(
|
||||
"base_objects",
|
||||
DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
|
||||
true,
|
||||
),
|
||||
Field::new("table_key", DataType::Utf8, false),
|
||||
Field::new("table_version", DataType::UInt64, true),
|
||||
Field::new("table_branch", DataType::Utf8, true),
|
||||
Field::new("row_count", DataType::UInt64, true),
|
||||
]))
|
||||
}
|
||||
|
||||
pub(super) async fn read_manifest_state(dataset: &Dataset) -> Result<ManifestState> {
|
||||
let version = dataset.version().version;
|
||||
let entries = read_manifest_entries(dataset).await?;
|
||||
let mut latest_versions = HashMap::<String, SubTableEntry>::new();
|
||||
|
||||
for entry in entries {
|
||||
match latest_versions.get(&entry.table_key) {
|
||||
Some(existing) if existing.table_version >= entry.table_version => {}
|
||||
_ => {
|
||||
latest_versions.insert(entry.table_key.clone(), entry);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut entries: Vec<SubTableEntry> = latest_versions.into_values().collect();
|
||||
entries.sort_by(|a, b| a.table_key.cmp(&b.table_key));
|
||||
|
||||
Ok(ManifestState { version, entries })
|
||||
}
|
||||
|
||||
pub(super) async fn read_manifest_entries(dataset: &Dataset) -> Result<Vec<SubTableEntry>> {
|
||||
let batches: Vec<RecordBatch> = dataset
|
||||
.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
let mut table_locations = HashMap::new();
|
||||
let mut version_entries = Vec::new();
|
||||
|
||||
for batch in &batches {
|
||||
let object_types = string_column(batch, "object_type")?;
|
||||
let locations = string_column(batch, "location")?;
|
||||
let metadata = string_column(batch, "metadata")?;
|
||||
let table_keys = string_column(batch, "table_key")?;
|
||||
let versions = u64_column(batch, "table_version")?;
|
||||
let branches = string_column(batch, "table_branch")?;
|
||||
let row_counts = u64_column(batch, "row_count")?;
|
||||
|
||||
for row in 0..batch.num_rows() {
|
||||
let table_key = table_keys.value(row).to_string();
|
||||
match object_types.value(row) {
|
||||
OBJECT_TYPE_TABLE => {
|
||||
if locations.is_null(row) {
|
||||
return Err(OmniError::manifest_internal(format!(
|
||||
"manifest table row missing location for {}",
|
||||
table_key
|
||||
)));
|
||||
}
|
||||
table_locations.insert(table_key, locations.value(row).to_string());
|
||||
}
|
||||
OBJECT_TYPE_TABLE_VERSION => {
|
||||
let table_version = required_u64(versions, row, "table_version")?;
|
||||
let row_count = required_u64(row_counts, row, "row_count")?;
|
||||
if metadata.is_null(row) {
|
||||
return Err(OmniError::manifest_internal(format!(
|
||||
"manifest table_version row missing metadata for {}",
|
||||
table_key
|
||||
)));
|
||||
}
|
||||
let table_branch = if branches.is_null(row) {
|
||||
None
|
||||
} else {
|
||||
Some(branches.value(row).to_string())
|
||||
};
|
||||
version_entries.push(SubTableEntry {
|
||||
table_key: table_key.clone(),
|
||||
table_path: String::new(),
|
||||
table_version,
|
||||
table_branch,
|
||||
row_count,
|
||||
version_metadata: TableVersionMetadata::from_json_str(metadata.value(row))?,
|
||||
});
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut entries = version_entries
|
||||
.into_iter()
|
||||
.map(|mut entry| {
|
||||
entry.table_path = table_locations
|
||||
.get(&entry.table_key)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!(
|
||||
"manifest missing table row for {}",
|
||||
entry.table_key
|
||||
))
|
||||
})?;
|
||||
Ok(entry)
|
||||
})
|
||||
.collect::<Result<Vec<_>>>()?;
|
||||
entries.sort_by(|a, b| {
|
||||
a.table_key
|
||||
.cmp(&b.table_key)
|
||||
.then(a.table_version.cmp(&b.table_version))
|
||||
});
|
||||
|
||||
Ok(entries)
|
||||
}
|
||||
|
||||
pub(super) fn entries_to_batch(
|
||||
entries: &[SubTableEntry],
|
||||
version_metadata: &HashMap<String, String>,
|
||||
) -> Result<RecordBatch> {
|
||||
let mut object_ids = Vec::with_capacity(entries.len() * 2);
|
||||
let mut object_types = Vec::with_capacity(entries.len() * 2);
|
||||
let mut locations = Vec::with_capacity(entries.len() * 2);
|
||||
let mut metadata = Vec::with_capacity(entries.len() * 2);
|
||||
let mut table_keys = Vec::with_capacity(entries.len() * 2);
|
||||
let mut table_versions = Vec::with_capacity(entries.len() * 2);
|
||||
let mut table_branches = Vec::with_capacity(entries.len() * 2);
|
||||
let mut row_counts = Vec::with_capacity(entries.len() * 2);
|
||||
|
||||
for entry in entries {
|
||||
object_ids.push(entry.table_key.clone());
|
||||
object_types.push(OBJECT_TYPE_TABLE.to_string());
|
||||
locations.push(Some(entry.table_path.clone()));
|
||||
metadata.push(None);
|
||||
table_keys.push(entry.table_key.clone());
|
||||
table_versions.push(None);
|
||||
table_branches.push(None);
|
||||
row_counts.push(None);
|
||||
|
||||
object_ids.push(version_object_id(&entry.table_key, entry.table_version));
|
||||
object_types.push(OBJECT_TYPE_TABLE_VERSION.to_string());
|
||||
locations.push(None);
|
||||
metadata.push(Some(
|
||||
version_metadata
|
||||
.get(&entry.table_key)
|
||||
.cloned()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!(
|
||||
"missing initial version metadata for {}",
|
||||
entry.table_key
|
||||
))
|
||||
})?,
|
||||
));
|
||||
table_keys.push(entry.table_key.clone());
|
||||
table_versions.push(Some(entry.table_version));
|
||||
table_branches.push(entry.table_branch.clone());
|
||||
row_counts.push(Some(entry.row_count));
|
||||
}
|
||||
|
||||
manifest_rows_batch(
|
||||
object_ids,
|
||||
object_types,
|
||||
locations,
|
||||
metadata,
|
||||
table_keys,
|
||||
table_versions,
|
||||
table_branches,
|
||||
row_counts,
|
||||
)
|
||||
}
|
||||
|
||||
pub(super) fn manifest_rows_batch(
|
||||
object_ids: Vec<String>,
|
||||
object_types: Vec<String>,
|
||||
locations: Vec<Option<String>>,
|
||||
metadata: Vec<Option<String>>,
|
||||
table_keys: Vec<String>,
|
||||
table_versions: Vec<Option<u64>>,
|
||||
table_branches: Vec<Option<String>>,
|
||||
row_counts: Vec<Option<u64>>,
|
||||
) -> Result<RecordBatch> {
|
||||
let len = object_ids.len();
|
||||
RecordBatch::try_new(
|
||||
manifest_schema(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(object_ids)),
|
||||
Arc::new(StringArray::from(object_types)),
|
||||
Arc::new(StringArray::from(locations)),
|
||||
Arc::new(StringArray::from(metadata)),
|
||||
new_null_array(
|
||||
&DataType::List(Arc::new(Field::new("item", DataType::Utf8, true))),
|
||||
len,
|
||||
),
|
||||
Arc::new(StringArray::from(table_keys)),
|
||||
Arc::new(UInt64Array::from(table_versions)),
|
||||
Arc::new(StringArray::from(table_branches)),
|
||||
Arc::new(UInt64Array::from(row_counts)),
|
||||
],
|
||||
)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub(super) fn string_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a StringArray> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("manifest batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("manifest column '{name}' is not Utf8"))
|
||||
})
|
||||
}
|
||||
|
||||
fn u64_column<'a>(batch: &'a RecordBatch, name: &str) -> Result<&'a UInt64Array> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("manifest batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("manifest column '{name}' is not UInt64"))
|
||||
})
|
||||
}
|
||||
|
||||
fn required_u64(column: &UInt64Array, row: usize, name: &str) -> Result<u64> {
|
||||
if column.is_null(row) {
|
||||
return Err(OmniError::manifest_internal(format!(
|
||||
"manifest column '{name}' is null at row {row}"
|
||||
)));
|
||||
}
|
||||
Ok(column.value(row))
|
||||
}
|
||||
1064
crates/omnigraph/src/db/manifest/tests.rs
Normal file
1064
crates/omnigraph/src/db/manifest/tests.rs
Normal file
File diff suppressed because it is too large
Load diff
13
crates/omnigraph/src/db/mod.rs
Normal file
13
crates/omnigraph/src/db/mod.rs
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
pub mod commit_graph;
|
||||
pub mod graph_coordinator;
|
||||
pub mod manifest;
|
||||
mod omnigraph;
|
||||
mod run_registry;
|
||||
mod schema_state;
|
||||
|
||||
pub use commit_graph::GraphCommit;
|
||||
pub use graph_coordinator::{GraphCoordinator, ReadTarget, ResolvedTarget, SnapshotId};
|
||||
pub use manifest::{Snapshot, SubTableEntry, SubTableUpdate};
|
||||
pub use omnigraph::{MergeOutcome, Omnigraph};
|
||||
pub(crate) use run_registry::is_internal_run_branch;
|
||||
pub use run_registry::{RunId, RunRecord, RunStatus};
|
||||
2636
crates/omnigraph/src/db/omnigraph.rs
Normal file
2636
crates/omnigraph/src/db/omnigraph.rs
Normal file
File diff suppressed because it is too large
Load diff
622
crates/omnigraph/src/db/run_registry.rs
Normal file
622
crates/omnigraph/src/db/run_registry.rs
Normal file
|
|
@ -0,0 +1,622 @@
|
|||
use std::collections::HashMap;
|
||||
use std::fmt;
|
||||
use std::sync::Arc;
|
||||
use std::time::{SystemTime, UNIX_EPOCH};
|
||||
|
||||
use arrow_array::{
|
||||
Array, RecordBatch, RecordBatchIterator, StringArray, TimestampMicrosecondArray, UInt64Array,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema, SchemaRef, TimeUnit};
|
||||
use futures::TryStreamExt;
|
||||
use lance::Dataset;
|
||||
use lance::dataset::{WriteMode, WriteParams};
|
||||
use lance_file::version::LanceFileVersion;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
const GRAPH_RUNS_DIR: &str = "_graph_runs.lance";
|
||||
const GRAPH_RUN_ACTORS_DIR: &str = "_graph_run_actors.lance";
|
||||
pub(crate) const INTERNAL_RUN_BRANCH_PREFIX: &str = "__run__";
|
||||
|
||||
/// Identifier of a run, stored as an opaque string.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct RunId(String);

impl RunId {
    /// Wraps anything convertible into a `String` as a run id.
    pub fn new(id: impl Into<String>) -> Self {
        let raw: String = id.into();
        RunId(raw)
    }

    /// Borrows the id as a string slice.
    pub fn as_str(&self) -> &str {
        self.0.as_str()
    }
}

impl fmt::Display for RunId {
    /// Formats exactly like the wrapped string, honouring width and padding
    /// flags.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        fmt::Display::fmt(&self.0, f)
    }
}
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum RunStatus {
|
||||
Running,
|
||||
Published,
|
||||
Failed,
|
||||
Aborted,
|
||||
}
|
||||
|
||||
impl RunStatus {
|
||||
pub fn as_str(self) -> &'static str {
|
||||
match self {
|
||||
RunStatus::Running => "running",
|
||||
RunStatus::Published => "published",
|
||||
RunStatus::Failed => "failed",
|
||||
RunStatus::Aborted => "aborted",
|
||||
}
|
||||
}
|
||||
|
||||
fn parse(value: &str) -> Result<Self> {
|
||||
match value {
|
||||
"running" => Ok(Self::Running),
|
||||
"published" => Ok(Self::Published),
|
||||
"failed" => Ok(Self::Failed),
|
||||
"aborted" => Ok(Self::Aborted),
|
||||
other => Err(OmniError::manifest(format!(
|
||||
"invalid run status '{}'",
|
||||
other
|
||||
))),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct RunRecord {
|
||||
pub run_id: RunId,
|
||||
pub target_branch: String,
|
||||
pub run_branch: String,
|
||||
pub base_snapshot_id: String,
|
||||
pub base_manifest_version: u64,
|
||||
pub operation_hash: Option<String>,
|
||||
pub actor_id: Option<String>,
|
||||
pub status: RunStatus,
|
||||
pub published_snapshot_id: Option<String>,
|
||||
pub created_at: i64,
|
||||
pub updated_at: i64,
|
||||
}
|
||||
|
||||
impl RunRecord {
|
||||
pub fn new(
|
||||
target_branch: impl Into<String>,
|
||||
base_snapshot_id: impl Into<String>,
|
||||
base_manifest_version: u64,
|
||||
operation_hash: Option<String>,
|
||||
actor_id: Option<String>,
|
||||
) -> Result<Self> {
|
||||
let now = now_micros()?;
|
||||
let run_id = RunId::new(ulid::Ulid::new().to_string());
|
||||
Ok(Self {
|
||||
run_branch: internal_run_branch_name(&run_id),
|
||||
run_id,
|
||||
target_branch: target_branch.into(),
|
||||
base_snapshot_id: base_snapshot_id.into(),
|
||||
base_manifest_version,
|
||||
operation_hash,
|
||||
actor_id,
|
||||
status: RunStatus::Running,
|
||||
published_snapshot_id: None,
|
||||
created_at: now,
|
||||
updated_at: now,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn with_status(
|
||||
&self,
|
||||
status: RunStatus,
|
||||
published_snapshot_id: Option<String>,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
run_id: self.run_id.clone(),
|
||||
target_branch: self.target_branch.clone(),
|
||||
run_branch: self.run_branch.clone(),
|
||||
base_snapshot_id: self.base_snapshot_id.clone(),
|
||||
base_manifest_version: self.base_manifest_version,
|
||||
operation_hash: self.operation_hash.clone(),
|
||||
actor_id: self.actor_id.clone(),
|
||||
status,
|
||||
published_snapshot_id,
|
||||
created_at: self.created_at,
|
||||
updated_at: now_micros()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub struct RunRegistry {
|
||||
dataset: Dataset,
|
||||
actor_dataset: Option<Dataset>,
|
||||
latest_by_id: HashMap<String, RunRecord>,
|
||||
actor_by_run_id: HashMap<String, String>,
|
||||
root_uri: String,
|
||||
}
|
||||
|
||||
impl RunRegistry {
|
||||
pub async fn init(root_uri: &str) -> Result<Self> {
|
||||
let uri = graph_runs_uri(root_uri);
|
||||
let batch = RecordBatch::new_empty(run_registry_schema());
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], run_registry_schema());
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
..Default::default()
|
||||
};
|
||||
let dataset = Dataset::write(reader, &uri as &str, Some(params))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let actor_dataset = create_run_actor_dataset(root_uri).await?;
|
||||
Ok(Self {
|
||||
dataset,
|
||||
actor_dataset: Some(actor_dataset),
|
||||
latest_by_id: HashMap::new(),
|
||||
actor_by_run_id: HashMap::new(),
|
||||
root_uri: root_uri.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn open(root_uri: &str) -> Result<Self> {
|
||||
let dataset = Dataset::open(&graph_runs_uri(root_uri))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
let actor_dataset = Dataset::open(&graph_run_actors_uri(root_uri)).await.ok();
|
||||
let actor_by_run_id = match &actor_dataset {
|
||||
Some(dataset) => load_run_actor_cache(dataset).await?,
|
||||
None => HashMap::new(),
|
||||
};
|
||||
let latest_by_id = load_run_cache(&dataset, &actor_by_run_id).await?;
|
||||
Ok(Self {
|
||||
dataset,
|
||||
actor_dataset,
|
||||
latest_by_id,
|
||||
actor_by_run_id,
|
||||
root_uri: root_uri.to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn refresh(&mut self, root_uri: &str) -> Result<()> {
|
||||
self.dataset = Dataset::open(&graph_runs_uri(root_uri))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.actor_dataset = Dataset::open(&graph_run_actors_uri(root_uri)).await.ok();
|
||||
self.actor_by_run_id = match &self.actor_dataset {
|
||||
Some(dataset) => load_run_actor_cache(dataset).await?,
|
||||
None => HashMap::new(),
|
||||
};
|
||||
self.latest_by_id = load_run_cache(&self.dataset, &self.actor_by_run_id).await?;
|
||||
self.root_uri = root_uri.to_string();
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn append_record(&mut self, record: &RunRecord) -> Result<()> {
|
||||
let batch = runs_to_batch(&[record.clone()])?;
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], run_registry_schema());
|
||||
let mut ds = self.dataset.clone();
|
||||
ds.append(reader, None)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.dataset = ds;
|
||||
if let Some(actor_id) = &record.actor_id {
|
||||
self.append_actor(record.run_id.as_str(), actor_id).await?;
|
||||
}
|
||||
let mut record = record.clone();
|
||||
if record.actor_id.is_none() {
|
||||
record.actor_id = self.actor_by_run_id.get(record.run_id.as_str()).cloned();
|
||||
}
|
||||
merge_latest_run(&mut self.latest_by_id, record);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn get_run(&self, run_id: &RunId) -> Result<Option<RunRecord>> {
|
||||
Ok(self.latest_by_id.get(run_id.as_str()).cloned())
|
||||
}
|
||||
|
||||
pub async fn list_runs(&self) -> Result<Vec<RunRecord>> {
|
||||
self.load_runs().await
|
||||
}
|
||||
|
||||
pub async fn load_runs(&self) -> Result<Vec<RunRecord>> {
|
||||
let mut runs = self.latest_by_id.values().cloned().collect::<Vec<_>>();
|
||||
runs.sort_by(|a, b| {
|
||||
a.created_at
|
||||
.cmp(&b.created_at)
|
||||
.then_with(|| a.run_id.as_str().cmp(b.run_id.as_str()))
|
||||
});
|
||||
Ok(runs)
|
||||
}
|
||||
|
||||
async fn append_actor(&mut self, run_id: &str, actor_id: &str) -> Result<()> {
|
||||
if self
|
||||
.actor_by_run_id
|
||||
.get(run_id)
|
||||
.is_some_and(|existing| existing == actor_id)
|
||||
{
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let record = RunActorRecord {
|
||||
run_id: run_id.to_string(),
|
||||
actor_id: actor_id.to_string(),
|
||||
created_at: now_micros()?,
|
||||
};
|
||||
let batch = run_actors_to_batch(&[record])?;
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], run_actor_schema());
|
||||
let mut dataset = match self.actor_dataset.take() {
|
||||
Some(dataset) => dataset,
|
||||
None => create_run_actor_dataset(&self.root_uri).await?,
|
||||
};
|
||||
dataset
|
||||
.append(reader, None)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.actor_by_run_id
|
||||
.insert(run_id.to_string(), actor_id.to_string());
|
||||
self.actor_dataset = Some(dataset);
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn is_internal_run_branch(name: &str) -> bool {
|
||||
name.trim_start_matches('/')
|
||||
.starts_with(INTERNAL_RUN_BRANCH_PREFIX)
|
||||
}
|
||||
|
||||
pub(crate) fn internal_run_branch_name(run_id: &RunId) -> String {
|
||||
format!("{}{}", INTERNAL_RUN_BRANCH_PREFIX, run_id.as_str())
|
||||
}
|
||||
|
||||
pub(crate) fn graph_runs_uri(root_uri: &str) -> String {
|
||||
format!("{}/{}", root_uri.trim_end_matches('/'), GRAPH_RUNS_DIR)
|
||||
}
|
||||
|
||||
fn graph_run_actors_uri(root_uri: &str) -> String {
|
||||
format!(
|
||||
"{}/{}",
|
||||
root_uri.trim_end_matches('/'),
|
||||
GRAPH_RUN_ACTORS_DIR
|
||||
)
|
||||
}
|
||||
|
||||
fn run_registry_schema() -> SchemaRef {
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("run_id", DataType::Utf8, false),
|
||||
Field::new("target_branch", DataType::Utf8, false),
|
||||
Field::new("run_branch", DataType::Utf8, false),
|
||||
Field::new("base_snapshot_id", DataType::Utf8, false),
|
||||
Field::new("base_manifest_version", DataType::UInt64, false),
|
||||
Field::new("operation_hash", DataType::Utf8, true),
|
||||
Field::new("status", DataType::Utf8, false),
|
||||
Field::new("published_snapshot_id", DataType::Utf8, true),
|
||||
Field::new(
|
||||
"created_at",
|
||||
DataType::Timestamp(TimeUnit::Microsecond, None),
|
||||
false,
|
||||
),
|
||||
Field::new(
|
||||
"updated_at",
|
||||
DataType::Timestamp(TimeUnit::Microsecond, None),
|
||||
false,
|
||||
),
|
||||
]))
|
||||
}
|
||||
|
||||
fn run_actor_schema() -> SchemaRef {
|
||||
Arc::new(Schema::new(vec![
|
||||
Field::new("run_id", DataType::Utf8, false),
|
||||
Field::new("actor_id", DataType::Utf8, false),
|
||||
Field::new(
|
||||
"created_at",
|
||||
DataType::Timestamp(TimeUnit::Microsecond, None),
|
||||
false,
|
||||
),
|
||||
]))
|
||||
}
|
||||
|
||||
async fn create_run_actor_dataset(root_uri: &str) -> Result<Dataset> {
|
||||
let batch = RecordBatch::new_empty(run_actor_schema());
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], run_actor_schema());
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
..Default::default()
|
||||
};
|
||||
Dataset::write(
|
||||
reader,
|
||||
&graph_run_actors_uri(root_uri) as &str,
|
||||
Some(params),
|
||||
)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
async fn load_run_cache(
|
||||
dataset: &Dataset,
|
||||
actor_by_run_id: &HashMap<String, String>,
|
||||
) -> Result<HashMap<String, RunRecord>> {
|
||||
let batches: Vec<RecordBatch> = dataset
|
||||
.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
let mut latest_by_id = HashMap::new();
|
||||
for mut record in load_runs_from_batches(&batches)? {
|
||||
record.actor_id = actor_by_run_id.get(record.run_id.as_str()).cloned();
|
||||
merge_latest_run(&mut latest_by_id, record);
|
||||
}
|
||||
Ok(latest_by_id)
|
||||
}
|
||||
|
||||
async fn load_run_actor_cache(dataset: &Dataset) -> Result<HashMap<String, String>> {
|
||||
let batches: Vec<RecordBatch> = dataset
|
||||
.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
let mut actors = HashMap::new();
|
||||
for batch in batches {
|
||||
let run_ids = string_column(&batch, "run_id", "run actor registry")?;
|
||||
let actor_ids = string_column(&batch, "actor_id", "run actor registry")?;
|
||||
for row in 0..batch.num_rows() {
|
||||
actors.insert(
|
||||
run_ids.value(row).to_string(),
|
||||
actor_ids.value(row).to_string(),
|
||||
);
|
||||
}
|
||||
}
|
||||
Ok(actors)
|
||||
}
|
||||
|
||||
fn load_runs_from_batches(batches: &[RecordBatch]) -> Result<Vec<RunRecord>> {
|
||||
let mut runs = Vec::new();
|
||||
for batch in batches {
|
||||
let run_ids = string_column(batch, "run_id", "run registry")?;
|
||||
let target_branches = string_column(batch, "target_branch", "run registry")?;
|
||||
let run_branches = string_column(batch, "run_branch", "run registry")?;
|
||||
let base_snapshot_ids = string_column(batch, "base_snapshot_id", "run registry")?;
|
||||
let base_manifest_versions = u64_column(batch, "base_manifest_version", "run registry")?;
|
||||
let operation_hashes = string_column(batch, "operation_hash", "run registry")?;
|
||||
let statuses = string_column(batch, "status", "run registry")?;
|
||||
let published_snapshot_ids = string_column(batch, "published_snapshot_id", "run registry")?;
|
||||
let created_ats = timestamp_micros_column(batch, "created_at", "run registry")?;
|
||||
let updated_ats = timestamp_micros_column(batch, "updated_at", "run registry")?;
|
||||
|
||||
for row in 0..batch.num_rows() {
|
||||
runs.push(RunRecord {
|
||||
run_id: RunId::new(run_ids.value(row)),
|
||||
target_branch: target_branches.value(row).to_string(),
|
||||
run_branch: run_branches.value(row).to_string(),
|
||||
base_snapshot_id: base_snapshot_ids.value(row).to_string(),
|
||||
base_manifest_version: base_manifest_versions.value(row),
|
||||
operation_hash: if operation_hashes.is_null(row) {
|
||||
None
|
||||
} else {
|
||||
Some(operation_hashes.value(row).to_string())
|
||||
},
|
||||
actor_id: None,
|
||||
status: RunStatus::parse(statuses.value(row))?,
|
||||
published_snapshot_id: if published_snapshot_ids.is_null(row) {
|
||||
None
|
||||
} else {
|
||||
Some(published_snapshot_ids.value(row).to_string())
|
||||
},
|
||||
created_at: created_ats.value(row),
|
||||
updated_at: updated_ats.value(row),
|
||||
});
|
||||
}
|
||||
}
|
||||
Ok(runs)
|
||||
}
|
||||
|
||||
/// Inserts `record` into `latest_by_id` unless the map already holds a record
/// for the same run id that is at least as recent.
///
/// The existing entry is kept when its `updated_at` is strictly greater, or when
/// `updated_at` ties and its `created_at` is greater than or equal to the new
/// record's. On a complete tie the entry that arrived first therefore stays.
fn merge_latest_run(latest_by_id: &mut HashMap<String, RunRecord>, record: RunRecord) {
    match latest_by_id.get(record.run_id.as_str()) {
        Some(existing)
            if existing.updated_at > record.updated_at
                || (existing.updated_at == record.updated_at
                    && existing.created_at >= record.created_at) => {}
        _ => {
            latest_by_id.insert(record.run_id.as_str().to_string(), record);
        }
    }
}
|
||||
|
||||
fn string_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a StringArray> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} column '{name}' is not Utf8"))
|
||||
})
|
||||
}
|
||||
|
||||
fn u64_column<'a>(batch: &'a RecordBatch, name: &str, context: &str) -> Result<&'a UInt64Array> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} column '{name}' is not UInt64"))
|
||||
})
|
||||
}
|
||||
|
||||
fn timestamp_micros_column<'a>(
|
||||
batch: &'a RecordBatch,
|
||||
name: &str,
|
||||
context: &str,
|
||||
) -> Result<&'a TimestampMicrosecondArray> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("{context} batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampMicrosecondArray>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!(
|
||||
"{context} column '{name}' is not Timestamp(Microsecond)"
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn runs_to_batch(records: &[RunRecord]) -> Result<RecordBatch> {
|
||||
let run_ids: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.run_id.as_str())
|
||||
.collect();
|
||||
let target_branches: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.target_branch.as_str())
|
||||
.collect();
|
||||
let run_branches: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.run_branch.as_str())
|
||||
.collect();
|
||||
let base_snapshot_ids: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.base_snapshot_id.as_str())
|
||||
.collect();
|
||||
let base_manifest_versions: Vec<u64> = records
|
||||
.iter()
|
||||
.map(|record| record.base_manifest_version)
|
||||
.collect();
|
||||
let operation_hashes: Vec<Option<&str>> = records
|
||||
.iter()
|
||||
.map(|record| record.operation_hash.as_deref())
|
||||
.collect();
|
||||
let statuses: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.status.as_str())
|
||||
.collect();
|
||||
let published_snapshot_ids: Vec<Option<&str>> = records
|
||||
.iter()
|
||||
.map(|record| record.published_snapshot_id.as_deref())
|
||||
.collect();
|
||||
let created_ats: Vec<i64> = records.iter().map(|record| record.created_at).collect();
|
||||
let updated_ats: Vec<i64> = records.iter().map(|record| record.updated_at).collect();
|
||||
|
||||
RecordBatch::try_new(
|
||||
run_registry_schema(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(run_ids)),
|
||||
Arc::new(StringArray::from(target_branches)),
|
||||
Arc::new(StringArray::from(run_branches)),
|
||||
Arc::new(StringArray::from(base_snapshot_ids)),
|
||||
Arc::new(UInt64Array::from(base_manifest_versions)),
|
||||
Arc::new(StringArray::from(operation_hashes)),
|
||||
Arc::new(StringArray::from(statuses)),
|
||||
Arc::new(StringArray::from(published_snapshot_ids)),
|
||||
Arc::new(TimestampMicrosecondArray::from(created_ats)),
|
||||
Arc::new(TimestampMicrosecondArray::from(updated_ats)),
|
||||
],
|
||||
)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
/// One row of the run-actor registry, linking a run to an actor.
#[derive(Debug, Clone, PartialEq, Eq)]
struct RunActorRecord {
    /// Identifier of the run.
    run_id: String,
    /// Identifier of the actor associated with the run.
    actor_id: String,
    /// Creation time; written to a microsecond timestamp column.
    created_at: i64,
}
|
||||
|
||||
fn run_actors_to_batch(records: &[RunActorRecord]) -> Result<RecordBatch> {
|
||||
let run_ids: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.run_id.as_str())
|
||||
.collect();
|
||||
let actor_ids: Vec<&str> = records
|
||||
.iter()
|
||||
.map(|record| record.actor_id.as_str())
|
||||
.collect();
|
||||
let created_ats: Vec<i64> = records.iter().map(|record| record.created_at).collect();
|
||||
|
||||
RecordBatch::try_new(
|
||||
run_actor_schema(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(run_ids)),
|
||||
Arc::new(StringArray::from(actor_ids)),
|
||||
Arc::new(TimestampMicrosecondArray::from(created_ats)),
|
||||
],
|
||||
)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
fn now_micros() -> Result<i64> {
|
||||
let duration = SystemTime::now()
|
||||
.duration_since(UNIX_EPOCH)
|
||||
.map_err(|e| OmniError::manifest(format!("system clock error: {}", e)))?;
|
||||
Ok(duration.as_micros() as i64)
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// A batch whose `run_id` column is UInt64 instead of Utf8 must be rejected
    /// with an error that names the offending column.
    #[test]
    fn load_runs_from_batches_returns_error_for_bad_schema() {
        let batch = RecordBatch::try_new(
            Arc::new(Schema::new(vec![
                // Deliberately the wrong type: the loader expects Utf8 here.
                Field::new("run_id", DataType::UInt64, false),
                Field::new("target_branch", DataType::Utf8, false),
                Field::new("run_branch", DataType::Utf8, false),
                Field::new("base_snapshot_id", DataType::Utf8, false),
                Field::new("base_manifest_version", DataType::UInt64, false),
                Field::new("operation_hash", DataType::Utf8, true),
                Field::new("status", DataType::Utf8, false),
                Field::new("published_snapshot_id", DataType::Utf8, true),
                Field::new(
                    "created_at",
                    DataType::Timestamp(TimeUnit::Microsecond, None),
                    false,
                ),
                Field::new(
                    "updated_at",
                    DataType::Timestamp(TimeUnit::Microsecond, None),
                    false,
                ),
            ])),
            vec![
                Arc::new(UInt64Array::from(vec![1_u64])),
                Arc::new(StringArray::from(vec!["main"])),
                Arc::new(StringArray::from(vec!["__run__1"])),
                Arc::new(StringArray::from(vec!["snap-1"])),
                Arc::new(UInt64Array::from(vec![1_u64])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(StringArray::from(vec!["running"])),
                Arc::new(StringArray::from(vec![None::<&str>])),
                Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
                Arc::new(TimestampMicrosecondArray::from(vec![1_i64])),
            ],
        )
        .unwrap();

        let err = load_runs_from_batches(&[batch]).unwrap_err();
        assert!(err.to_string().contains("run_id"));
    }
}
|
||||
236
crates/omnigraph/src/db/schema_state.rs
Normal file
236
crates/omnigraph/src/db/schema_state.rs
Normal file
|
|
@ -0,0 +1,236 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use omnigraph_compiler::schema::parser::parse_schema;
|
||||
use omnigraph_compiler::{SchemaIR, build_schema_ir, schema_ir_hash, schema_ir_pretty_json};
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
use crate::storage::{StorageAdapter, join_uri};
|
||||
|
||||
pub(crate) const SCHEMA_SOURCE_FILENAME: &str = "_schema.pg";
|
||||
pub(crate) const SCHEMA_IR_FILENAME: &str = "_schema.ir.json";
|
||||
pub(crate) const SCHEMA_STATE_FILENAME: &str = "__schema_state.json";
|
||||
|
||||
const SCHEMA_STATE_FORMAT_VERSION: u32 = 1;
|
||||
const SCHEMA_IDENTITY_VERSION: u32 = 1;
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
|
||||
pub(crate) struct SchemaState {
|
||||
pub(crate) format_version: u32,
|
||||
pub(crate) schema_ir_hash: String,
|
||||
pub(crate) schema_identity_version: u32,
|
||||
}
|
||||
|
||||
impl SchemaState {
|
||||
pub(crate) fn new(schema_ir_hash: String) -> Self {
|
||||
Self {
|
||||
format_version: SCHEMA_STATE_FORMAT_VERSION,
|
||||
schema_ir_hash,
|
||||
schema_identity_version: SCHEMA_IDENTITY_VERSION,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn load_or_bootstrap_schema_contract(
|
||||
root_uri: &str,
|
||||
storage: Arc<dyn StorageAdapter>,
|
||||
public_branches: &[String],
|
||||
current_source_ir: &SchemaIR,
|
||||
) -> Result<(SchemaIR, SchemaState)> {
|
||||
match read_schema_contract(root_uri, storage.as_ref()).await? {
|
||||
SchemaContractRead::Present { ir, state } => {
|
||||
validate_persisted_schema_contract(&ir, &state)?;
|
||||
validate_current_source_matches(&state, current_source_ir)?;
|
||||
Ok((ir, state))
|
||||
}
|
||||
SchemaContractRead::MissingAll => {
|
||||
let public_non_main = public_branches
|
||||
.iter()
|
||||
.filter(|branch| branch.as_str() != "main")
|
||||
.cloned()
|
||||
.collect::<Vec<_>>();
|
||||
if !public_non_main.is_empty() {
|
||||
return Err(schema_lock_conflict(format!(
|
||||
"repo is missing persisted schema state and has public branches ({}); public branches block schema evolution entirely",
|
||||
public_non_main.join(", ")
|
||||
)));
|
||||
}
|
||||
let state =
|
||||
write_schema_contract(root_uri, storage.as_ref(), current_source_ir).await?;
|
||||
Ok((current_source_ir.clone(), state))
|
||||
}
|
||||
SchemaContractRead::PartialMissing => Err(schema_lock_conflict(
|
||||
"repo schema state is incomplete (_schema.ir.json and __schema_state.json must either both exist or both be absent)",
|
||||
)),
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn validate_schema_contract(
|
||||
root_uri: &str,
|
||||
storage: Arc<dyn StorageAdapter>,
|
||||
) -> Result<()> {
|
||||
let current_source_ir = read_current_source_ir(root_uri, storage.as_ref()).await?;
|
||||
let (persisted_ir, state) = match read_schema_contract(root_uri, storage.as_ref()).await? {
|
||||
SchemaContractRead::Present { ir, state } => (ir, state),
|
||||
SchemaContractRead::MissingAll | SchemaContractRead::PartialMissing => {
|
||||
return Err(schema_lock_conflict(
|
||||
"repo is missing persisted schema state; manual coordination is required before schema changes are allowed",
|
||||
));
|
||||
}
|
||||
};
|
||||
|
||||
validate_persisted_schema_contract(&persisted_ir, &state)?;
|
||||
validate_current_source_matches(&state, ¤t_source_ir)
|
||||
}
|
||||
|
||||
pub(crate) async fn write_schema_contract(
|
||||
root_uri: &str,
|
||||
storage: &dyn StorageAdapter,
|
||||
schema_ir: &SchemaIR,
|
||||
) -> Result<SchemaState> {
|
||||
let ir_json = schema_ir_pretty_json(schema_ir)
|
||||
.map_err(|err| OmniError::manifest_internal(err.to_string()))?;
|
||||
let state = SchemaState::new(
|
||||
schema_ir_hash(schema_ir).map_err(|err| OmniError::manifest_internal(err.to_string()))?,
|
||||
);
|
||||
let state_json = serde_json::to_string_pretty(&state).map_err(|err| {
|
||||
OmniError::manifest_internal(format!("serialize schema state error: {}", err))
|
||||
})?;
|
||||
|
||||
storage
|
||||
.write_text(&schema_ir_uri(root_uri), &ir_json)
|
||||
.await?;
|
||||
storage
|
||||
.write_text(&schema_state_uri(root_uri), &state_json)
|
||||
.await?;
|
||||
Ok(state)
|
||||
}
|
||||
|
||||
pub(crate) async fn read_current_source_ir(
|
||||
root_uri: &str,
|
||||
storage: &dyn StorageAdapter,
|
||||
) -> Result<SchemaIR> {
|
||||
let source = storage.read_text(&schema_source_uri(root_uri)).await?;
|
||||
compile_schema_source(&source)
|
||||
}
|
||||
|
||||
pub(crate) async fn read_accepted_schema_ir(
|
||||
root_uri: &str,
|
||||
storage: Arc<dyn StorageAdapter>,
|
||||
) -> Result<SchemaIR> {
|
||||
match read_schema_contract(root_uri, storage.as_ref()).await? {
|
||||
SchemaContractRead::Present { ir, state } => {
|
||||
validate_persisted_schema_contract(&ir, &state)?;
|
||||
Ok(ir)
|
||||
}
|
||||
SchemaContractRead::MissingAll | SchemaContractRead::PartialMissing => {
|
||||
Err(schema_lock_conflict(
|
||||
"repo is missing persisted schema state; manual coordination is required before schema changes are allowed",
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn schema_source_uri(root_uri: &str) -> String {
|
||||
join_uri(root_uri, SCHEMA_SOURCE_FILENAME)
|
||||
}
|
||||
|
||||
pub(crate) fn schema_ir_uri(root_uri: &str) -> String {
|
||||
join_uri(root_uri, SCHEMA_IR_FILENAME)
|
||||
}
|
||||
|
||||
pub(crate) fn schema_state_uri(root_uri: &str) -> String {
|
||||
join_uri(root_uri, SCHEMA_STATE_FILENAME)
|
||||
}
|
||||
|
||||
enum SchemaContractRead {
|
||||
Present { ir: SchemaIR, state: SchemaState },
|
||||
MissingAll,
|
||||
PartialMissing,
|
||||
}
|
||||
|
||||
async fn read_schema_contract(
|
||||
root_uri: &str,
|
||||
storage: &dyn StorageAdapter,
|
||||
) -> Result<SchemaContractRead> {
|
||||
let ir_uri = schema_ir_uri(root_uri);
|
||||
let state_uri = schema_state_uri(root_uri);
|
||||
let ir_exists = storage.exists(&ir_uri).await?;
|
||||
let state_exists = storage.exists(&state_uri).await?;
|
||||
|
||||
match (ir_exists, state_exists) {
|
||||
(false, false) => Ok(SchemaContractRead::MissingAll),
|
||||
(true, true) => {
|
||||
let ir_json = storage.read_text(&ir_uri).await?;
|
||||
let state_json = storage.read_text(&state_uri).await?;
|
||||
let ir = serde_json::from_str::<SchemaIR>(&ir_json).map_err(|err| {
|
||||
schema_lock_conflict(format!(
|
||||
"accepted compiled schema contract in {} is invalid: {}",
|
||||
SCHEMA_IR_FILENAME, err
|
||||
))
|
||||
})?;
|
||||
let state = serde_json::from_str::<SchemaState>(&state_json).map_err(|err| {
|
||||
schema_lock_conflict(format!(
|
||||
"repo schema state in {} is invalid: {}",
|
||||
SCHEMA_STATE_FILENAME, err
|
||||
))
|
||||
})?;
|
||||
Ok(SchemaContractRead::Present { ir, state })
|
||||
}
|
||||
_ => Ok(SchemaContractRead::PartialMissing),
|
||||
}
|
||||
}
|
||||
|
||||
fn validate_persisted_schema_contract(ir: &SchemaIR, state: &SchemaState) -> Result<()> {
|
||||
if state.format_version != SCHEMA_STATE_FORMAT_VERSION {
|
||||
return Err(schema_lock_conflict(format!(
|
||||
"repo schema state format {} is unsupported",
|
||||
state.format_version
|
||||
)));
|
||||
}
|
||||
|
||||
let actual_hash = schema_ir_hash(ir).map_err(|err| schema_lock_conflict(err.to_string()))?;
|
||||
if actual_hash != state.schema_ir_hash {
|
||||
return Err(schema_lock_conflict(
|
||||
"accepted compiled schema does not match the recorded schema state",
|
||||
));
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn validate_current_source_matches(
|
||||
state: &SchemaState,
|
||||
current_source_ir: &SchemaIR,
|
||||
) -> Result<()> {
|
||||
let current_hash =
|
||||
schema_ir_hash(current_source_ir).map_err(|err| schema_lock_conflict(err.to_string()))?;
|
||||
if current_hash != state.schema_ir_hash {
|
||||
return Err(schema_lock_conflict(
|
||||
"current _schema.pg no longer matches the accepted compiled schema",
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn compile_schema_source(source: &str) -> Result<SchemaIR> {
|
||||
let schema = parse_schema(source).map_err(|err| {
|
||||
schema_lock_conflict(format!(
|
||||
"current _schema.pg is not a valid accepted schema definition: {}",
|
||||
err
|
||||
))
|
||||
})?;
|
||||
build_schema_ir(&schema).map_err(|err| {
|
||||
schema_lock_conflict(format!(
|
||||
"current _schema.pg could not be compiled into the accepted schema contract: {}",
|
||||
err
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn schema_lock_conflict(detail: impl Into<String>) -> OmniError {
|
||||
OmniError::manifest_conflict(format!(
|
||||
"schema evolution is locked down in phase 1: {}; manual coordination is required",
|
||||
detail.into()
|
||||
))
|
||||
}
|
||||
489
crates/omnigraph/src/embedding.rs
Normal file
489
crates/omnigraph/src/embedding.rs
Normal file
|
|
@ -0,0 +1,489 @@
|
|||
use std::future::Future;
|
||||
use std::time::Duration;
|
||||
|
||||
use reqwest::Client;
|
||||
use serde::Deserialize;
|
||||
use serde_json::{Value, json};
|
||||
use tokio::time::sleep;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
const GEMINI_EMBED_MODEL: &str = "gemini-embedding-2-preview";
|
||||
const DEFAULT_GEMINI_BASE_URL: &str = "https://generativelanguage.googleapis.com/v1beta";
|
||||
const DEFAULT_TIMEOUT_MS: u64 = 30_000;
|
||||
const DEFAULT_RETRY_ATTEMPTS: usize = 4;
|
||||
const DEFAULT_RETRY_BACKOFF_MS: u64 = 200;
|
||||
const QUERY_TASK_TYPE: &str = "RETRIEVAL_QUERY";
|
||||
const DOCUMENT_TASK_TYPE: &str = "RETRIEVAL_DOCUMENT";
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
enum EmbeddingTransport {
|
||||
Mock,
|
||||
Gemini {
|
||||
api_key: String,
|
||||
base_url: String,
|
||||
http: Client,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct EmbeddingClient {
|
||||
retry_attempts: usize,
|
||||
retry_backoff_ms: u64,
|
||||
transport: EmbeddingTransport,
|
||||
}
|
||||
|
||||
/// Failure of a single embedding attempt, tagged with whether it may be retried.
struct EmbedCallError {
    /// Human-readable description, surfaced to the caller once retries stop.
    message: String,
    /// `true` when another attempt may succeed.
    retryable: bool,
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GeminiEmbedResponse {
|
||||
embedding: GeminiContentEmbedding,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GeminiContentEmbedding {
|
||||
values: Vec<f32>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GoogleErrorEnvelope {
|
||||
error: GoogleErrorBody,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct GoogleErrorBody {
|
||||
message: String,
|
||||
}
|
||||
|
||||
impl EmbeddingClient {
|
||||
pub fn from_env() -> Result<Self> {
|
||||
let retry_attempts =
|
||||
parse_env_usize("OMNIGRAPH_EMBED_RETRY_ATTEMPTS", DEFAULT_RETRY_ATTEMPTS);
|
||||
let retry_backoff_ms =
|
||||
parse_env_u64("OMNIGRAPH_EMBED_RETRY_BACKOFF_MS", DEFAULT_RETRY_BACKOFF_MS);
|
||||
|
||||
if env_flag("OMNIGRAPH_EMBEDDINGS_MOCK") {
|
||||
return Ok(Self {
|
||||
retry_attempts,
|
||||
retry_backoff_ms,
|
||||
transport: EmbeddingTransport::Mock,
|
||||
});
|
||||
}
|
||||
|
||||
let api_key = std::env::var("GEMINI_API_KEY")
|
||||
.ok()
|
||||
.map(|v| v.trim().to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(
|
||||
"GEMINI_API_KEY is required when nearest() needs a string embedding",
|
||||
)
|
||||
})?;
|
||||
let base_url = std::env::var("OMNIGRAPH_GEMINI_BASE_URL")
|
||||
.ok()
|
||||
.map(|v| v.trim_end_matches('/').to_string())
|
||||
.filter(|v| !v.is_empty())
|
||||
.unwrap_or_else(|| DEFAULT_GEMINI_BASE_URL.to_string());
|
||||
let timeout_ms = parse_env_u64("OMNIGRAPH_EMBED_TIMEOUT_MS", DEFAULT_TIMEOUT_MS);
|
||||
let http = Client::builder()
|
||||
.timeout(Duration::from_millis(timeout_ms))
|
||||
.build()
|
||||
.map_err(|e| {
|
||||
OmniError::manifest_internal(format!("failed to initialize HTTP client: {}", e))
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
retry_attempts,
|
||||
retry_backoff_ms,
|
||||
transport: EmbeddingTransport::Gemini {
|
||||
api_key,
|
||||
base_url,
|
||||
http,
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
fn mock_for_tests() -> Self {
|
||||
Self {
|
||||
retry_attempts: DEFAULT_RETRY_ATTEMPTS,
|
||||
retry_backoff_ms: DEFAULT_RETRY_BACKOFF_MS,
|
||||
transport: EmbeddingTransport::Mock,
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn embed_query_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
|
||||
self.embed_text(input, expected_dim, QUERY_TASK_TYPE).await
|
||||
}
|
||||
|
||||
pub async fn embed_document_text(&self, input: &str, expected_dim: usize) -> Result<Vec<f32>> {
|
||||
self.embed_text(input, expected_dim, DOCUMENT_TASK_TYPE)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn embed_text(
|
||||
&self,
|
||||
input: &str,
|
||||
expected_dim: usize,
|
||||
task_type: &'static str,
|
||||
) -> Result<Vec<f32>> {
|
||||
if expected_dim == 0 {
|
||||
return Err(OmniError::manifest_internal(
|
||||
"embedding dimension must be greater than zero",
|
||||
));
|
||||
}
|
||||
|
||||
match &self.transport {
|
||||
EmbeddingTransport::Mock => Ok(mock_embedding(input, expected_dim)),
|
||||
EmbeddingTransport::Gemini { .. } => {
|
||||
self.with_retry(|| self.embed_text_gemini_once(input, expected_dim, task_type))
|
||||
.await
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn with_retry<T, F, Fut>(&self, mut operation: F) -> Result<T>
|
||||
where
|
||||
F: FnMut() -> Fut,
|
||||
Fut: Future<Output = std::result::Result<T, EmbedCallError>>,
|
||||
{
|
||||
let max_attempt = self.retry_attempts.max(1);
|
||||
let mut attempt = 0usize;
|
||||
loop {
|
||||
attempt += 1;
|
||||
match operation().await {
|
||||
Ok(value) => return Ok(value),
|
||||
Err(err) => {
|
||||
if !err.retryable || attempt >= max_attempt {
|
||||
return Err(OmniError::manifest_internal(err.message));
|
||||
}
|
||||
let shift = (attempt - 1).min(10) as u32;
|
||||
let delay = self.retry_backoff_ms.saturating_mul(1u64 << shift);
|
||||
sleep(Duration::from_millis(delay)).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn embed_text_gemini_once(
|
||||
&self,
|
||||
input: &str,
|
||||
expected_dim: usize,
|
||||
task_type: &'static str,
|
||||
) -> std::result::Result<Vec<f32>, EmbedCallError> {
|
||||
let (api_key, base_url, http) = match &self.transport {
|
||||
EmbeddingTransport::Gemini {
|
||||
api_key,
|
||||
base_url,
|
||||
http,
|
||||
} => (api_key, base_url, http),
|
||||
EmbeddingTransport::Mock => unreachable!("mock transport should not call Gemini"),
|
||||
};
|
||||
|
||||
let response = http
|
||||
.post(gemini_endpoint(base_url))
|
||||
.header("x-goog-api-key", api_key)
|
||||
.json(&build_gemini_request(input, expected_dim, task_type))
|
||||
.send()
|
||||
.await;
|
||||
let response = match response {
|
||||
Ok(response) => response,
|
||||
Err(err) => {
|
||||
let retryable = err.is_timeout() || err.is_connect() || err.is_request();
|
||||
return Err(EmbedCallError {
|
||||
message: format!("embedding request failed: {}", err),
|
||||
retryable,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
let status = response.status();
|
||||
let body = match response.text().await {
|
||||
Ok(body) => body,
|
||||
Err(err) => {
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding response read failed (status {}): {}",
|
||||
status, err
|
||||
),
|
||||
retryable: status.is_server_error() || status.as_u16() == 429,
|
||||
});
|
||||
}
|
||||
};
|
||||
|
||||
if !status.is_success() {
|
||||
let message = parse_google_error_message(&body).unwrap_or(body);
|
||||
return Err(EmbedCallError {
|
||||
message: format!(
|
||||
"embedding request failed with status {}: {}",
|
||||
status, message
|
||||
),
|
||||
retryable: status.is_server_error() || status.as_u16() == 429,
|
||||
});
|
||||
}
|
||||
|
||||
let parsed: GeminiEmbedResponse =
|
||||
serde_json::from_str(&body).map_err(|err| EmbedCallError {
|
||||
message: format!("embedding response decode failed: {}", err),
|
||||
retryable: false,
|
||||
})?;
|
||||
|
||||
validate_and_normalize_embedding(parsed.embedding.values, expected_dim).map_err(|message| {
|
||||
EmbedCallError {
|
||||
message,
|
||||
retryable: false,
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
fn gemini_endpoint(base_url: &str) -> String {
|
||||
format!(
|
||||
"{}/models/{}:embedContent",
|
||||
base_url.trim_end_matches('/'),
|
||||
GEMINI_EMBED_MODEL
|
||||
)
|
||||
}
|
||||
|
||||
fn build_gemini_request(input: &str, expected_dim: usize, task_type: &'static str) -> Value {
|
||||
json!({
|
||||
"model": format!("models/{}", GEMINI_EMBED_MODEL),
|
||||
"content": {
|
||||
"parts": [
|
||||
{
|
||||
"text": input
|
||||
}
|
||||
]
|
||||
},
|
||||
"taskType": task_type,
|
||||
"outputDimensionality": expected_dim,
|
||||
})
|
||||
}
|
||||
|
||||
fn validate_and_normalize_embedding(
|
||||
values: Vec<f32>,
|
||||
expected_dim: usize,
|
||||
) -> std::result::Result<Vec<f32>, String> {
|
||||
if values.len() != expected_dim {
|
||||
return Err(format!(
|
||||
"embedding dimension mismatch: expected {}, got {}",
|
||||
expected_dim,
|
||||
values.len()
|
||||
));
|
||||
}
|
||||
Ok(normalize_vector(values))
|
||||
}
|
||||
|
||||
/// Scales `values` to unit Euclidean length and returns it.
///
/// The sum of squares is accumulated in `f64` to limit rounding error, then the
/// norm is narrowed to `f32`. A vector whose norm is not greater than
/// `f32::EPSILON` (for example all zeros, or empty) is returned unchanged so the
/// function never divides by zero.
fn normalize_vector(mut values: Vec<f32>) -> Vec<f32> {
    let norm = values
        .iter()
        .map(|v| (*v as f64) * (*v as f64))
        .sum::<f64>()
        .sqrt() as f32;
    if norm > f32::EPSILON {
        for value in &mut values {
            *value /= norm;
        }
    }
    values
}
|
||||
|
||||
fn parse_google_error_message(body: &str) -> Option<String> {
|
||||
serde_json::from_str::<GoogleErrorEnvelope>(body)
|
||||
.ok()
|
||||
.map(|e| e.error.message)
|
||||
.filter(|msg| !msg.trim().is_empty())
|
||||
}
|
||||
|
||||
/// Reads environment variable `name` as a positive `usize`.
///
/// Surrounding whitespace is ignored. Returns `default` when the variable is
/// unset, is not valid Unicode, does not parse, or is `0`.
fn parse_env_usize(name: &str, default: usize) -> usize {
    std::env::var(name)
        .ok()
        .and_then(|v| v.trim().parse::<usize>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}
|
||||
|
||||
/// Reads environment variable `name` as a positive `u64`.
///
/// Surrounding whitespace is ignored. Returns `default` when the variable is
/// unset, is not valid Unicode, does not parse, or is `0`.
fn parse_env_u64(name: &str, default: u64) -> u64 {
    std::env::var(name)
        .ok()
        .and_then(|v| v.trim().parse::<u64>().ok())
        .filter(|v| *v > 0)
        .unwrap_or(default)
}
|
||||
|
||||
/// Returns `true` when environment variable `name` holds a truthy value.
///
/// After trimming and ASCII-lowercasing, the accepted values are `1`, `true`,
/// `yes` and `on`. Anything else, including an unset variable, is `false`.
fn env_flag(name: &str) -> bool {
    std::env::var(name).is_ok_and(|raw| {
        matches!(
            raw.trim().to_ascii_lowercase().as_str(),
            "1" | "true" | "yes" | "on"
        )
    })
}
|
||||
|
||||
fn mock_embedding(input: &str, dim: usize) -> Vec<f32> {
|
||||
let mut seed = fnv1a64(input.as_bytes());
|
||||
let mut out = Vec::with_capacity(dim);
|
||||
for _ in 0..dim {
|
||||
seed = xorshift64(seed);
|
||||
let ratio = (seed as f64 / u64::MAX as f64) as f32;
|
||||
out.push((ratio * 2.0) - 1.0);
|
||||
}
|
||||
normalize_vector(out)
|
||||
}
|
||||
|
||||
/// Computes the 64-bit FNV-1a hash of `bytes`.
///
/// Starts from the offset basis `14695981039346656037`; each byte is XORed into
/// the hash, which is then multiplied (wrapping) by the prime `1099511628211`.
fn fnv1a64(bytes: &[u8]) -> u64 {
    let mut hash = 14695981039346656037u64;
    for byte in bytes {
        hash ^= *byte as u64;
        hash = hash.wrapping_mul(1099511628211u64);
    }
    hash
}
|
||||
|
||||
/// Advances a 64-bit xorshift state by one step, using shifts of 13 (left),
/// 7 (right) and 17 (left).
///
/// A state of `0` maps to `0`.
fn xorshift64(mut x: u64) -> u64 {
    x ^= x << 13;
    x ^= x >> 7;
    x ^= x << 17;
    x
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serial_test::serial;

    use super::*;

    /// Sets or clears environment variables for the duration of a test and
    /// restores the previous values when dropped.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<String>)>,
    }

    impl EnvGuard {
        /// Applies `vars` (`None` removes the variable) after recording the
        /// values they replace.
        fn set(vars: &[(&'static str, Option<&str>)]) -> Self {
            let saved = vars
                .iter()
                .map(|(name, _)| (*name, std::env::var(name).ok()))
                .collect::<Vec<_>>();
            for (name, value) in vars {
                // SAFETY: mutating the process environment is only sound while no
                // other thread reads or writes it. NOTE(review): this relies on
                // every test that uses `EnvGuard` being marked `#[serial]`.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
            Self { saved }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (name, value) in self.saved.drain(..) {
                // SAFETY: same requirement as in `EnvGuard::set`; the guard is
                // dropped inside the test that created it.
                unsafe {
                    match value {
                        Some(value) => std::env::set_var(name, value),
                        None => std::env::remove_var(name),
                    }
                }
            }
        }
    }

    #[tokio::test]
    async fn mock_embeddings_are_deterministic() {
        let client = EmbeddingClient::mock_for_tests();
        let a = client.embed_query_text("alpha", 8).await.unwrap();
        let b = client.embed_query_text("alpha", 8).await.unwrap();
        let c = client.embed_query_text("beta", 8).await.unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
        assert_eq!(a.len(), 8);
    }

    #[test]
    fn gemini_request_uses_preview_model_retrieval_query_and_dimension() {
        let request = build_gemini_request("alpha", 4, QUERY_TASK_TYPE);
        assert_eq!(request["model"], "models/gemini-embedding-2-preview");
        assert_eq!(request["taskType"], QUERY_TASK_TYPE);
        assert_eq!(request["outputDimensionality"], 4);
        assert_eq!(request["content"]["parts"][0]["text"], "alpha");
    }

    #[test]
    fn gemini_document_request_uses_retrieval_document_task_type() {
        let request = build_gemini_request("alpha", 4, DOCUMENT_TASK_TYPE);
        assert_eq!(request["taskType"], DOCUMENT_TASK_TYPE);
    }

    #[test]
    fn validate_and_normalize_embedding_enforces_dimension() {
        let normalized = validate_and_normalize_embedding(vec![3.0, 4.0], 2).unwrap();
        assert!((normalized[0] - 0.6).abs() < 1e-6);
        assert!((normalized[1] - 0.8).abs() < 1e-6);

        let err = validate_and_normalize_embedding(vec![1.0, 2.0], 3).unwrap_err();
        assert!(err.contains("expected 3, got 2"));
    }

    #[tokio::test]
    async fn with_retry_retries_retryable_failures() {
        let client = EmbeddingClient::mock_for_tests();
        let attempts = Arc::new(AtomicUsize::new(0));
        let attempts_for_call = Arc::clone(&attempts);

        let value = client
            .with_retry(|| {
                let attempts_for_call = Arc::clone(&attempts_for_call);
                async move {
                    // The first call fails with a retryable error, the second succeeds.
                    let attempt = attempts_for_call.fetch_add(1, Ordering::SeqCst);
                    if attempt == 0 {
                        Err(EmbedCallError {
                            message: "retry me".to_string(),
                            retryable: true,
                        })
                    } else {
                        Ok("ok")
                    }
                }
            })
            .await
            .unwrap();

        assert_eq!(value, "ok");
        assert_eq!(attempts.load(Ordering::SeqCst), 2);
    }

    #[tokio::test]
    async fn with_retry_stops_on_non_retryable_failures() {
        let client = EmbeddingClient::mock_for_tests();
        let err = client
            .with_retry(|| async {
                Err::<(), _>(EmbedCallError {
                    message: "do not retry".to_string(),
                    retryable: false,
                })
            })
            .await
            .unwrap_err();

        assert!(err.to_string().contains("do not retry"));
    }

    #[test]
    #[serial]
    fn from_env_requires_gemini_api_key_when_not_mocking() {
        let _guard = EnvGuard::set(&[
            ("OMNIGRAPH_EMBEDDINGS_MOCK", None),
            ("GEMINI_API_KEY", None),
        ]);

        let err = EmbeddingClient::from_env().unwrap_err();
        assert!(err.to_string().contains("GEMINI_API_KEY"));
    }
}
|
||||
80
crates/omnigraph/src/error.rs
Normal file
80
crates/omnigraph/src/error.rs
Normal file
|
|
@ -0,0 +1,80 @@
|
|||
use thiserror::Error;
|
||||
|
||||
pub type Result<T> = std::result::Result<T, OmniError>;
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||
pub enum ManifestErrorKind {
|
||||
BadRequest,
|
||||
NotFound,
|
||||
Conflict,
|
||||
Internal,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, Error)]
|
||||
#[error("{message}")]
|
||||
pub struct ManifestError {
|
||||
pub kind: ManifestErrorKind,
|
||||
pub message: String,
|
||||
}
|
||||
|
||||
impl ManifestError {
|
||||
pub fn new(kind: ManifestErrorKind, message: impl Into<String>) -> Self {
|
||||
Self {
|
||||
kind,
|
||||
message: message.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// One conflict reported by a merge; a list of these is carried by
/// `OmniError::MergeConflicts`.
#[derive(Debug, Clone)]
pub struct MergeConflict {
    /// Key of the table the conflict refers to.
    pub table_key: String,
    /// Identifier of the affected row, when one is known.
    pub row_id: Option<String>,
    /// Category of the conflict.
    pub kind: MergeConflictKind,
    /// Human-readable description.
    pub message: String,
}

/// Category of a [`MergeConflict`].
///
/// NOTE(review): the precise condition behind each variant is decided by the
/// merge code that produces it, which is not part of this file.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MergeConflictKind {
    DivergentInsert,
    DivergentUpdate,
    DeleteVsUpdate,
    OrphanEdge,
    UniqueViolation,
    CardinalityViolation,
    ValueConstraintViolation,
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
pub enum OmniError {
|
||||
#[error("{0}")]
|
||||
Compiler(#[from] omnigraph_compiler::error::NanoError),
|
||||
#[error("storage: {0}")]
|
||||
Lance(String),
|
||||
#[error("query: {0}")]
|
||||
DataFusion(String),
|
||||
#[error("io: {0}")]
|
||||
Io(#[from] std::io::Error),
|
||||
#[error("{0}")]
|
||||
Manifest(ManifestError),
|
||||
#[error("merge conflicts: {0:?}")]
|
||||
MergeConflicts(Vec<MergeConflict>),
|
||||
}
|
||||
|
||||
impl OmniError {
|
||||
pub fn manifest(message: impl Into<String>) -> Self {
|
||||
Self::Manifest(ManifestError::new(ManifestErrorKind::BadRequest, message))
|
||||
}
|
||||
|
||||
pub fn manifest_not_found(message: impl Into<String>) -> Self {
|
||||
Self::Manifest(ManifestError::new(ManifestErrorKind::NotFound, message))
|
||||
}
|
||||
|
||||
pub fn manifest_conflict(message: impl Into<String>) -> Self {
|
||||
Self::Manifest(ManifestError::new(ManifestErrorKind::Conflict, message))
|
||||
}
|
||||
|
||||
pub fn manifest_internal(message: impl Into<String>) -> Self {
|
||||
Self::Manifest(ManifestError::new(ManifestErrorKind::Internal, message))
|
||||
}
|
||||
}
|
||||
4011
crates/omnigraph/src/exec/mod.rs
Normal file
4011
crates/omnigraph/src/exec/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
37
crates/omnigraph/src/failpoints.rs
Normal file
37
crates/omnigraph/src/failpoints.rs
Normal file
|
|
@ -0,0 +1,37 @@
|
|||
use crate::error::Result;
|
||||
|
||||
pub(crate) fn maybe_fail(_name: &str) -> Result<()> {
|
||||
#[cfg(feature = "failpoints")]
|
||||
{
|
||||
let name = _name;
|
||||
fail::fail_point!(name, |_| {
|
||||
return Err(crate::error::OmniError::manifest(format!(
|
||||
"injected failpoint triggered: {}",
|
||||
name
|
||||
)));
|
||||
});
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Guard that configures a failpoint on creation and removes it when dropped.
#[cfg(feature = "failpoints")]
pub struct ScopedFailPoint {
    name: String,
}

#[cfg(feature = "failpoints")]
impl ScopedFailPoint {
    /// Configures failpoint `name` with `action` and returns the guard.
    ///
    /// # Panics
    /// Panics with "configure failpoint" if `fail::cfg` rejects the configuration.
    pub fn new(name: &str, action: &str) -> Self {
        fail::cfg(name, action).expect("configure failpoint");
        let name = name.to_owned();
        Self { name }
    }
}

#[cfg(feature = "failpoints")]
impl Drop for ScopedFailPoint {
    /// Removes the failpoint registered by [`ScopedFailPoint::new`].
    fn drop(&mut self) {
        fail::remove(self.name.as_str());
    }
}
|
||||
315
crates/omnigraph/src/graph_index/mod.rs
Normal file
315
crates/omnigraph/src/graph_index/mod.rs
Normal file
|
|
@ -0,0 +1,315 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use arrow_array::StringArray;
|
||||
use futures::TryStreamExt;
|
||||
|
||||
use crate::db::Snapshot;
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
/// Dense u32 mapping for a single node type: String ID ↔ dense index.
///
/// Dense indices are handed out in insertion order starting at 0, so
/// `dense_to_id[i]` is the id that received index `i`.
#[derive(Debug, Clone)]
pub struct TypeIndex {
    id_to_dense: HashMap<String, u32>,
    dense_to_id: Vec<String>,
}

impl TypeIndex {
    /// Creates an empty index.
    pub(crate) fn new() -> Self {
        Self {
            id_to_dense: HashMap::new(),
            dense_to_id: Vec::new(),
        }
    }

    /// Get or insert a string ID, returning its dense index.
    ///
    /// # Panics
    /// Panics if a new id would need a dense index that does not fit in `u32`.
    /// The previous `as u32` cast wrapped silently in that case, which would
    /// have mapped two different ids to the same dense index.
    pub(crate) fn get_or_insert(&mut self, id: &str) -> u32 {
        if let Some(&idx) = self.id_to_dense.get(id) {
            return idx;
        }
        let idx = u32::try_from(self.dense_to_id.len())
            .expect("TypeIndex dense index space (u32) exhausted");
        self.dense_to_id.push(id.to_string());
        self.id_to_dense.insert(id.to_string(), idx);
        idx
    }

    /// Dense index of `id`, or `None` if the id was never inserted.
    pub fn to_dense(&self, id: &str) -> Option<u32> {
        self.id_to_dense.get(id).copied()
    }

    /// String id stored at `dense`, or `None` if the index is out of range.
    pub fn to_id(&self, dense: u32) -> Option<&str> {
        self.dense_to_id.get(dense as usize).map(String::as_str)
    }

    /// Number of distinct ids in the index.
    pub fn len(&self) -> usize {
        self.dense_to_id.len()
    }

    /// Returns `true` when no id has been inserted yet.
    pub fn is_empty(&self) -> bool {
        self.dense_to_id.is_empty()
    }
}
|
||||
|
||||
/// CSR (Compressed Sparse Row) adjacency index.
#[derive(Debug, Clone)]
pub struct CsrIndex {
    /// `offsets[i] .. offsets[i + 1]` is node `i`'s range inside `targets`.
    offsets: Vec<u32>,
    /// Dense indices of destination nodes, grouped by source node.
    targets: Vec<u32>,
}

impl CsrIndex {
    /// Builds the index for `num_nodes` source nodes from `(src, dst)` pairs.
    ///
    /// Within one source node, destinations keep the order they have in `edges`.
    ///
    /// # Panics
    /// Panics if any `src` is `>= num_nodes`.
    pub(crate) fn build(num_nodes: usize, edges: &[(u32, u32)]) -> Self {
        // Out-degree of every source node.
        let mut out_degree = vec![0u32; num_nodes];
        for &(src, _) in edges {
            out_degree[src as usize] += 1;
        }

        // Running total of the degrees gives each node's start offset.
        let mut offsets = Vec::with_capacity(num_nodes + 1);
        let mut running = 0u32;
        offsets.push(running);
        for &degree in &out_degree {
            running += degree;
            offsets.push(running);
        }

        // Each node's write position starts at its offset and advances as its
        // edges are placed.
        let mut write_at: Vec<u32> = offsets[..num_nodes].to_vec();
        let mut targets = vec![0u32; edges.len()];
        for &(src, dst) in edges {
            let slot = &mut write_at[src as usize];
            targets[*slot as usize] = dst;
            *slot += 1;
        }

        Self { offsets, targets }
    }

    /// Return the dense indices of neighbors for a given dense node index.
    ///
    /// # Panics
    /// Panics if `node` is outside the range the index was built for.
    pub fn neighbors(&self, node: u32) -> &[u32] {
        let i = node as usize;
        let range = self.offsets[i] as usize..self.offsets[i + 1] as usize;
        &self.targets[range]
    }

    /// Check if a node has any outgoing edges. O(1), no allocation.
    ///
    /// # Panics
    /// Panics if `node` is outside the range the index was built for.
    pub fn has_neighbors(&self, node: u32) -> bool {
        let i = node as usize;
        self.offsets[i] < self.offsets[i + 1]
    }
}
|
||||
|
||||
/// Topology-only graph index. No node data cached — just adjacency.
|
||||
#[derive(Debug, Clone)]
|
||||
pub struct GraphIndex {
|
||||
/// Dense index per node type (built from edge src/dst columns).
|
||||
type_indices: HashMap<String, TypeIndex>,
|
||||
/// Outgoing adjacency per edge type.
|
||||
csr: HashMap<String, CsrIndex>,
|
||||
/// Incoming adjacency per edge type.
|
||||
csc: HashMap<String, CsrIndex>,
|
||||
}
|
||||
|
||||
impl GraphIndex {
|
||||
/// Build a graph index by scanning edge sub-tables from a snapshot.
|
||||
pub async fn build(
|
||||
snapshot: &Snapshot,
|
||||
edge_types: &HashMap<String, (String, String)>, // edge_name → (from_type, to_type)
|
||||
) -> Result<Self> {
|
||||
let mut type_indices: HashMap<String, TypeIndex> = HashMap::new();
|
||||
let mut csr = HashMap::new();
|
||||
let mut csc = HashMap::new();
|
||||
|
||||
// Phase 1: Scan all edges, build TypeIndices and collect edge pairs
|
||||
let mut edge_pairs: HashMap<String, Vec<(u32, u32)>> = HashMap::new();
|
||||
|
||||
for (edge_name, (from_type, to_type)) in edge_types {
|
||||
let table_key = format!("edge:{}", edge_name);
|
||||
if snapshot.entry(&table_key).is_none() {
|
||||
continue;
|
||||
}
|
||||
|
||||
let ds = snapshot.open(&table_key).await?;
|
||||
|
||||
let batches: Vec<arrow_array::RecordBatch> = ds
|
||||
.scan()
|
||||
.project(&["src", "dst"])
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.try_into_stream()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
type_indices
|
||||
.entry(from_type.clone())
|
||||
.or_insert_with(TypeIndex::new);
|
||||
type_indices
|
||||
.entry(to_type.clone())
|
||||
.or_insert_with(TypeIndex::new);
|
||||
|
||||
let mut edges: Vec<(u32, u32)> = Vec::new();
|
||||
for batch in &batches {
|
||||
let srcs = string_column(batch, "src")?;
|
||||
let dsts = string_column(batch, "dst")?;
|
||||
|
||||
for i in 0..batch.num_rows() {
|
||||
let src_dense = type_indices
|
||||
.get_mut(from_type)
|
||||
.unwrap()
|
||||
.get_or_insert(srcs.value(i));
|
||||
let dst_dense = type_indices
|
||||
.get_mut(to_type)
|
||||
.unwrap()
|
||||
.get_or_insert(dsts.value(i));
|
||||
edges.push((src_dense, dst_dense));
|
||||
}
|
||||
}
|
||||
edge_pairs.insert(edge_name.clone(), edges);
|
||||
}
|
||||
|
||||
// Phase 2: Build CSR/CSC using final TypeIndex sizes
|
||||
for (edge_name, (from_type, to_type)) in edge_types {
|
||||
let Some(edges) = edge_pairs.get(edge_name) else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let src_count = type_indices[from_type].len();
|
||||
let dst_count = type_indices[to_type].len();
|
||||
|
||||
csr.insert(edge_name.clone(), CsrIndex::build(src_count, edges));
|
||||
|
||||
let reversed: Vec<(u32, u32)> = edges.iter().map(|&(s, d)| (d, s)).collect();
|
||||
csc.insert(edge_name.clone(), CsrIndex::build(dst_count, &reversed));
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
type_indices,
|
||||
csr,
|
||||
csc,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn type_index(&self, type_name: &str) -> Option<&TypeIndex> {
|
||||
self.type_indices.get(type_name)
|
||||
}
|
||||
|
||||
pub fn csr(&self, edge_type: &str) -> Option<&CsrIndex> {
|
||||
self.csr.get(edge_type)
|
||||
}
|
||||
|
||||
pub fn csc(&self, edge_type: &str) -> Option<&CsrIndex> {
|
||||
self.csc.get(edge_type)
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub(crate) fn empty_for_test() -> Self {
|
||||
Self {
|
||||
type_indices: HashMap::new(),
|
||||
csr: HashMap::new(),
|
||||
csc: HashMap::new(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn string_column<'a>(batch: &'a arrow_array::RecordBatch, name: &str) -> Result<&'a StringArray> {
|
||||
batch
|
||||
.column_by_name(name)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("graph index batch missing '{name}' column"))
|
||||
})?
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!("graph index column '{name}' is not Utf8"))
|
||||
})
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use arrow_array::UInt64Array;
    use arrow_schema::{DataType, Field, Schema};

    use super::*;

    /// Every inserted id maps to a dense index and back again.
    #[test]
    fn type_index_round_trip() {
        let names = ["Alice", "Bob", "Charlie"];
        let mut idx = TypeIndex::new();
        let dense: Vec<u32> = names.iter().map(|name| idx.get_or_insert(name)).collect();

        for (name, &d) in names.iter().zip(&dense) {
            assert_eq!(idx.to_dense(name), Some(d));
        }

        for (name, &d) in names.iter().zip(&dense) {
            assert_eq!(idx.to_id(d), Some(*name));
        }
        assert_eq!(idx.len(), 3);
    }

    /// Inserting the same id twice yields the same index and no growth.
    #[test]
    fn type_index_idempotent_insert() {
        let mut idx = TypeIndex::new();
        let first = idx.get_or_insert("Alice");
        let second = idx.get_or_insert("Alice");
        assert_eq!(first, second);
        assert_eq!(idx.len(), 1);
    }

    /// Lookups on an empty index return `None` in both directions.
    #[test]
    fn type_index_unknown_returns_none() {
        let idx = TypeIndex::new();
        assert_eq!(idx.to_dense("unknown"), None);
        assert_eq!(idx.to_id(999), None);
    }

    #[test]
    fn csr_neighbors_correct() {
        // Graph: 0→1, 0→2, 1→2
        let csr = CsrIndex::build(3, &[(0, 1), (0, 2), (1, 2)]);

        let mut from_zero = csr.neighbors(0).to_vec();
        from_zero.sort();
        assert_eq!(from_zero, vec![1, 2]);

        assert_eq!(csr.neighbors(1), &[2]);
        assert_eq!(csr.neighbors(2), &[] as &[u32]);
    }

    #[test]
    fn csr_empty_graph() {
        let csr = CsrIndex::build(3, &[]);
        for node in 0..3 {
            assert_eq!(csr.neighbors(node), &[] as &[u32]);
        }
        assert!(!csr.has_neighbors(0));
    }

    #[test]
    fn csr_has_neighbors() {
        // 0→1, 1→2
        let csr = CsrIndex::build(3, &[(0, 1), (1, 2)]);
        assert!(csr.has_neighbors(0));
        assert!(csr.has_neighbors(1));
        assert!(!csr.has_neighbors(2));
    }

    /// A `src` column of the wrong Arrow type is reported as an error that
    /// names the column.
    #[test]
    fn string_column_returns_error_for_bad_schema() {
        let schema = Arc::new(Schema::new(vec![Field::new(
            "src",
            DataType::UInt64,
            false,
        )]));
        let batch = arrow_array::RecordBatch::try_new(
            schema,
            vec![Arc::new(UInt64Array::from(vec![1_u64]))],
        )
        .unwrap();

        let message = string_column(&batch, "src").unwrap_err().to_string();
        assert!(message.contains("src"));
    }
}
|
||||
11
crates/omnigraph/src/lib.rs
Normal file
11
crates/omnigraph/src/lib.rs
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
pub mod changes;
|
||||
pub mod db;
|
||||
pub mod embedding;
|
||||
pub mod error;
|
||||
mod exec;
|
||||
pub mod failpoints;
|
||||
pub mod graph_index;
|
||||
pub mod loader;
|
||||
pub mod runtime_cache;
|
||||
pub mod storage;
|
||||
pub mod table_store;
|
||||
476
crates/omnigraph/src/loader/constraints.rs
Normal file
476
crates/omnigraph/src/loader/constraints.rs
Normal file
|
|
@ -0,0 +1,476 @@
|
|||
use std::collections::HashMap;
|
||||
#[cfg(test)]
|
||||
use std::collections::HashSet;
|
||||
|
||||
use arrow_array::{
|
||||
Array, ArrayRef, BooleanArray, Date32Array, Date64Array, Float32Array, Float64Array,
|
||||
Int32Array, Int64Array, StringArray, UInt32Array, UInt64Array,
|
||||
};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
|
||||
use crate::catalog::schema_ir::SchemaIR;
|
||||
use crate::error::{NanoError, Result};
|
||||
|
||||
use super::super::graph::DatasetAccumulator;
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub(crate) struct NodeConstraintAnnotations {
|
||||
pub(crate) key_props: HashMap<String, String>,
|
||||
pub(crate) unique_props: HashMap<String, Vec<String>>,
|
||||
}
|
||||
|
||||
pub(crate) fn load_node_constraint_annotations(
|
||||
schema_ir: &SchemaIR,
|
||||
) -> Result<NodeConstraintAnnotations> {
|
||||
let mut constraints = NodeConstraintAnnotations::default();
|
||||
|
||||
for node in schema_ir.node_types() {
|
||||
let mut node_key_prop: Option<String> = None;
|
||||
let mut node_unique_props: Vec<String> = Vec::new();
|
||||
|
||||
for prop in &node.properties {
|
||||
if prop.key && node_key_prop.replace(prop.name.clone()).is_some() {
|
||||
return Err(NanoError::Storage(format!(
|
||||
"node type {} has multiple @key properties; only one is currently supported",
|
||||
node.name
|
||||
)));
|
||||
}
|
||||
if prop.unique {
|
||||
node_unique_props.push(prop.name.clone());
|
||||
}
|
||||
}
|
||||
|
||||
if let Some(prop_name) = node_key_prop {
|
||||
if !node_unique_props.contains(&prop_name) {
|
||||
node_unique_props.push(prop_name.clone());
|
||||
}
|
||||
constraints.key_props.insert(node.name.clone(), prop_name);
|
||||
}
|
||||
if !node_unique_props.is_empty() {
|
||||
node_unique_props.sort();
|
||||
node_unique_props.dedup();
|
||||
constraints
|
||||
.unique_props
|
||||
.insert(node.name.clone(), node_unique_props);
|
||||
}
|
||||
}
|
||||
|
||||
Ok(constraints)
|
||||
}
|
||||
|
||||
pub(crate) fn enforce_node_unique_constraints(
|
||||
storage: &DatasetAccumulator,
|
||||
unique_props: &HashMap<String, Vec<String>>,
|
||||
) -> Result<()> {
|
||||
for (type_name, properties) in unique_props {
|
||||
let Some(batch) = storage.get_all_nodes(type_name)? else {
|
||||
continue;
|
||||
};
|
||||
|
||||
for property in properties {
|
||||
let prop_idx =
|
||||
node_property_index(batch.schema().as_ref(), property).ok_or_else(|| {
|
||||
NanoError::Storage(format!(
|
||||
"node type {} missing @unique property {}",
|
||||
type_name, property
|
||||
))
|
||||
})?;
|
||||
let arr = batch.column(prop_idx);
|
||||
let mut seen: HashMap<String, usize> = HashMap::new();
|
||||
for row in 0..batch.num_rows() {
|
||||
let Some(value) = unique_value_string(arr, row, type_name, property)? else {
|
||||
continue;
|
||||
};
|
||||
if let Some(prev_row) = seen.insert(value.clone(), row) {
|
||||
return Err(NanoError::UniqueConstraint {
|
||||
type_name: type_name.clone(),
|
||||
property: property.clone(),
|
||||
value,
|
||||
first_row: prev_row,
|
||||
second_row: row,
|
||||
});
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Test-only helper: returns the set of `"type"` values found in a JSONL
/// document.
///
/// Lines are trimmed; empty lines and lines starting with `//` are skipped.
/// Records without a string `"type"` field contribute nothing.
///
/// # Errors
/// Returns `NanoError::Storage` when a remaining line is not valid JSON.
#[cfg(test)]
pub(crate) fn collect_incoming_node_types(data_source: &str) -> Result<HashSet<String>> {
    let records = data_source
        .lines()
        .map(str::trim)
        .filter(|line| !(line.is_empty() || line.starts_with("//")));

    let mut node_types = HashSet::new();
    for record in records {
        let parsed: serde_json::Value = serde_json::from_str(record)
            .map_err(|e| NanoError::Storage(format!("JSON parse error: {}", e)))?;
        if let Some(type_name) = parsed.get("type").and_then(|value| value.as_str()) {
            node_types.insert(type_name.to_string());
        }
    }
    Ok(node_types)
}
|
||||
|
||||
pub(crate) fn build_name_seed_for_keyed_load(
|
||||
storage: &DatasetAccumulator,
|
||||
key_props: &HashMap<String, String>,
|
||||
) -> Result<HashMap<(String, String), u64>> {
|
||||
let mut seed = HashMap::new();
|
||||
|
||||
for (type_name, key_prop) in key_props {
|
||||
let Some(batch) = storage.get_all_nodes(type_name)? else {
|
||||
continue;
|
||||
};
|
||||
|
||||
let key_idx = node_property_index(batch.schema().as_ref(), key_prop).ok_or_else(|| {
|
||||
NanoError::Storage(format!(
|
||||
"node type {} missing @key property {}",
|
||||
type_name, key_prop
|
||||
))
|
||||
})?;
|
||||
let key_arr = batch.column(key_idx).clone();
|
||||
let id_arr = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.ok_or_else(|| {
|
||||
NanoError::Storage(format!("node type {} has non-UInt64 id column", type_name))
|
||||
})?;
|
||||
|
||||
for row in 0..batch.num_rows() {
|
||||
let key = key_value_string(&key_arr, row, key_prop)?;
|
||||
seed.insert((type_name.clone(), key), id_arr.value(row));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(seed)
|
||||
}
|
||||
|
||||
pub(crate) fn build_name_seed_for_append(
|
||||
storage: &DatasetAccumulator,
|
||||
key_props: &HashMap<String, String>,
|
||||
) -> Result<HashMap<(String, String), u64>> {
|
||||
build_name_seed_for_keyed_load(storage, key_props)
|
||||
}
|
||||
|
||||
pub(crate) fn node_property_index(schema: &Schema, prop_name: &str) -> Option<usize> {
|
||||
schema
|
||||
.fields()
|
||||
.iter()
|
||||
.enumerate()
|
||||
.skip(1)
|
||||
.find_map(|(idx, field)| (field.name() == prop_name).then_some(idx))
|
||||
}
|
||||
|
||||
pub(crate) fn node_property_field<'a>(schema: &'a Schema, prop_name: &str) -> Option<&'a Field> {
|
||||
node_property_index(schema, prop_name).map(|idx| schema.field(idx))
|
||||
}
|
||||
|
||||
pub(crate) fn key_value_string(array: &ArrayRef, row: usize, prop_name: &str) -> Result<String> {
|
||||
let value = scalar_value_string(array, row, "key", None, prop_name)?;
|
||||
if let Some(value) = value {
|
||||
return Ok(value);
|
||||
}
|
||||
Err(NanoError::Storage(format!(
|
||||
"@key property {} cannot be null",
|
||||
prop_name
|
||||
)))
|
||||
}
|
||||
|
||||
fn unique_value_string(
|
||||
array: &ArrayRef,
|
||||
row: usize,
|
||||
type_name: &str,
|
||||
prop_name: &str,
|
||||
) -> Result<Option<String>> {
|
||||
scalar_value_string(array, row, "unique", Some(type_name), prop_name)
|
||||
}
|
||||
|
||||
fn scalar_value_string(
|
||||
array: &ArrayRef,
|
||||
row: usize,
|
||||
annotation: &str,
|
||||
type_name: Option<&str>,
|
||||
prop_name: &str,
|
||||
) -> Result<Option<String>> {
|
||||
if array.is_null(row) {
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let value = match array.data_type() {
|
||||
DataType::Utf8 => array
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::Boolean => array
|
||||
.as_any()
|
||||
.downcast_ref::<BooleanArray>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::Int32 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::Int64 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Int64Array>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::UInt32 => array
|
||||
.as_any()
|
||||
.downcast_ref::<UInt32Array>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::UInt64 => array
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::Float32 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Float32Array>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::Float64 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Float64Array>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::Date32 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Date32Array>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
DataType::Date64 => array
|
||||
.as_any()
|
||||
.downcast_ref::<Date64Array>()
|
||||
.map(|a| a.value(row).to_string()),
|
||||
_ => None,
|
||||
};
|
||||
|
||||
let value = value.ok_or_else(|| {
|
||||
let target = match type_name {
|
||||
Some(name) => format!("{}.{}", name, prop_name),
|
||||
None => prop_name.to_string(),
|
||||
};
|
||||
NanoError::Storage(format!(
|
||||
"unsupported @{} data type {:?} for {}",
|
||||
annotation,
|
||||
array.data_type(),
|
||||
target
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(Some(value))
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
mod tests {
|
||||
use std::collections::{HashMap, HashSet};
|
||||
|
||||
use arrow_array::StringArray;
|
||||
|
||||
use crate::catalog::schema_ir::{build_catalog_from_ir, build_schema_ir};
|
||||
use crate::schema::parser::parse_schema;
|
||||
|
||||
use super::super::jsonl::load_jsonl_data;
|
||||
use super::*;
|
||||
|
||||
fn build_schema_ir_and_storage(schema_src: &str) -> (SchemaIR, DatasetAccumulator) {
|
||||
let schema = parse_schema(schema_src).unwrap();
|
||||
let ir = build_schema_ir(&schema).unwrap();
|
||||
let catalog = build_catalog_from_ir(&ir).unwrap();
|
||||
(ir, DatasetAccumulator::new(catalog))
|
||||
}
|
||||
|
||||
// Checks that the `@key` property is reported in `key_props` and that
// `unique_props` lists the key together with both `@unique` properties,
// in sorted order.
#[test]
|
||||
fn load_node_constraint_annotations_collects_key_and_unique() {
|
||||
let schema = r#"node Person {
|
||||
name: String @key
|
||||
email: String @unique
|
||||
alias: String? @unique
|
||||
}"#;
|
||||
let (ir, _) = build_schema_ir_and_storage(schema);
|
||||
let annotations = load_node_constraint_annotations(&ir).unwrap();
|
||||
|
||||
assert_eq!(annotations.key_props.get("Person").unwrap(), "name");
|
||||
assert_eq!(
|
||||
annotations.unique_props.get("Person").unwrap(),
|
||||
&vec!["alias".to_string(), "email".to_string(), "name".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
// Checks that only records carrying a "type" field contribute to the result:
// the `//` line, the blank line and the edge record in the fixture are ignored.
#[test]
|
||||
fn collect_incoming_node_types_ignores_comments_and_blanks() {
|
||||
let data = r#"
|
||||
// comment
|
||||
{"type":"Person","data":{"name":"Alice"}}
|
||||
|
||||
{"edge":"Knows","from":"Alice","to":"Bob"}
|
||||
{"type":"Company","data":{"name":"Acme"}}
|
||||
"#;
|
||||
let types = collect_incoming_node_types(data).unwrap();
|
||||
assert_eq!(
|
||||
types,
|
||||
HashSet::from(["Person".to_string(), "Company".to_string()])
|
||||
);
|
||||
}
|
||||
|
||||
// Two Person rows share the same non-null `email`; enforcement must fail with
// `UniqueConstraint` naming the type, the property and the duplicated value.
#[test]
|
||||
fn enforce_node_unique_constraints_detects_duplicate_non_null() {
|
||||
let schema = r#"node Person {
|
||||
name: String
|
||||
email: String? @unique
|
||||
}"#;
|
||||
let (_, mut storage) = build_schema_ir_and_storage(schema);
|
||||
let key_props = HashMap::new();
|
||||
load_jsonl_data(
|
||||
&mut storage,
|
||||
r#"{"type":"Person","data":{"name":"Alice","email":"dupe@example.com"}}
|
||||
{"type":"Person","data":{"name":"Bob","email":"dupe@example.com"}}"#,
|
||||
&key_props,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let unique_props = HashMap::from([("Person".to_string(), vec!["email".to_string()])]);
|
||||
let err = enforce_node_unique_constraints(&storage, &unique_props).unwrap_err();
|
||||
match err {
|
||||
NanoError::UniqueConstraint {
|
||||
type_name,
|
||||
property,
|
||||
value,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(type_name, "Person");
|
||||
assert_eq!(property, "email");
|
||||
assert_eq!(value, "dupe@example.com");
|
||||
}
|
||||
other => panic!("expected UniqueConstraint, got {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// Two Person rows both have a null `nick`; enforcement must succeed because
// null values are not compared.
#[test]
|
||||
fn enforce_node_unique_constraints_allows_multiple_nulls() {
|
||||
let schema = r#"node Person {
|
||||
name: String
|
||||
nick: String? @unique
|
||||
}"#;
|
||||
let (_, mut storage) = build_schema_ir_and_storage(schema);
|
||||
let key_props = HashMap::new();
|
||||
load_jsonl_data(
|
||||
&mut storage,
|
||||
r#"{"type":"Person","data":{"name":"Alice","nick":null}}
|
||||
{"type":"Person","data":{"name":"Bob","nick":null}}"#,
|
||||
&key_props,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let unique_props = HashMap::from([("Person".to_string(), vec!["nick".to_string()])]);
|
||||
enforce_node_unique_constraints(&storage, &unique_props).unwrap();
|
||||
}
|
||||
|
||||
// A user-declared property called `id` marked `@unique` must be the column
// that is checked: the duplicated value "user-1" has to be reported.
#[test]
|
||||
fn enforce_node_unique_constraints_uses_user_property_named_id() {
|
||||
let schema = r#"node Person {
|
||||
id: String @unique
|
||||
name: String
|
||||
}"#;
|
||||
let (_, mut storage) = build_schema_ir_and_storage(schema);
|
||||
let key_props = HashMap::new();
|
||||
load_jsonl_data(
|
||||
&mut storage,
|
||||
r#"{"type":"Person","data":{"id":"user-1","name":"Alice"}}
|
||||
{"type":"Person","data":{"id":"user-1","name":"Bob"}}"#,
|
||||
&key_props,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let unique_props = HashMap::from([("Person".to_string(), vec!["id".to_string()])]);
|
||||
let err = enforce_node_unique_constraints(&storage, &unique_props).unwrap_err();
|
||||
match err {
|
||||
NanoError::UniqueConstraint {
|
||||
type_name,
|
||||
property,
|
||||
value,
|
||||
..
|
||||
} => {
|
||||
assert_eq!(type_name, "Person");
|
||||
assert_eq!(property, "id");
|
||||
assert_eq!(value, "user-1");
|
||||
}
|
||||
other => panic!("expected UniqueConstraint, got {other}"),
|
||||
}
|
||||
}
|
||||
|
||||
// The seed must contain the Person keyed by its declared `uid` property and
// must not contain Company, which has no entry in `key_props`.
#[test]
|
||||
fn build_name_seed_for_keyed_load_uses_declared_key_property() {
|
||||
let schema = r#"node Person {
|
||||
uid: String @key
|
||||
name: String
|
||||
}
|
||||
node Company {
|
||||
name: String
|
||||
}"#;
|
||||
let (_, mut storage) = build_schema_ir_and_storage(schema);
|
||||
let key_props = HashMap::from([("Person".to_string(), "uid".to_string())]);
|
||||
load_jsonl_data(
|
||||
&mut storage,
|
||||
r#"{"type":"Person","data":{"uid":"u1","name":"Alice"}}
|
||||
{"type":"Company","data":{"name":"Acme"}}"#,
|
||||
&key_props,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let seed = build_name_seed_for_keyed_load(&storage, &key_props).unwrap();
|
||||
|
||||
assert!(seed.contains_key(&("Person".to_string(), "u1".to_string())));
|
||||
assert!(!seed.contains_key(&("Company".to_string(), "Acme".to_string())));
|
||||
}
|
||||
|
||||
// A user-declared `@key` property called `id` must supply the seed key:
// the entry ("Person", "user-1") has to be present.
#[test]
|
||||
fn build_name_seed_for_keyed_load_uses_user_property_named_id() {
|
||||
let schema = r#"node Person {
|
||||
id: String @key
|
||||
name: String
|
||||
}"#;
|
||||
let (_, mut storage) = build_schema_ir_and_storage(schema);
|
||||
let key_props = HashMap::from([("Person".to_string(), "id".to_string())]);
|
||||
load_jsonl_data(
|
||||
&mut storage,
|
||||
r#"{"type":"Person","data":{"id":"user-1","name":"Alice"}}"#,
|
||||
&key_props,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let seed = build_name_seed_for_keyed_load(&storage, &key_props).unwrap();
|
||||
assert!(seed.contains_key(&("Person".to_string(), "user-1".to_string())));
|
||||
}
|
||||
|
||||
// Same expectations as the keyed-load seed test, exercised through
// `build_name_seed_for_append`: Person "u1" is present, Company is not.
#[test]
|
||||
fn build_name_seed_for_append_keeps_all_existing_keyed_nodes() {
|
||||
let schema = r#"node Person {
|
||||
uid: String @key
|
||||
name: String
|
||||
}
|
||||
node Company {
|
||||
name: String
|
||||
}"#;
|
||||
let (_, mut storage) = build_schema_ir_and_storage(schema);
|
||||
let key_props = HashMap::from([("Person".to_string(), "uid".to_string())]);
|
||||
load_jsonl_data(
|
||||
&mut storage,
|
||||
r#"{"type":"Person","data":{"uid":"u1","name":"Alice"}}
|
||||
{"type":"Company","data":{"name":"Acme"}}"#,
|
||||
&key_props,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let seed = build_name_seed_for_append(&storage, &key_props).unwrap();
|
||||
assert!(seed.contains_key(&("Person".to_string(), "u1".to_string())));
|
||||
assert!(!seed.contains_key(&("Company".to_string(), "Acme".to_string())));
|
||||
}
|
||||
|
||||
/// A present key value is returned as-is; a null key value is an error whose
/// message says the key cannot be null.
#[test]
fn key_value_string_rejects_null() {
    let column: ArrayRef = std::sync::Arc::new(StringArray::from(vec![Some("x"), None]));

    let present = key_value_string(&column, 0, "name").unwrap();
    assert_eq!(present, "x");

    let message = key_value_string(&column, 1, "name").unwrap_err().to_string();
    assert!(message.contains("cannot be null"));
}
|
||||
}
|
||||
1732
crates/omnigraph/src/loader/embeddings.rs
Normal file
1732
crates/omnigraph/src/loader/embeddings.rs
Normal file
File diff suppressed because it is too large
Load diff
1532
crates/omnigraph/src/loader/jsonl.rs
Normal file
1532
crates/omnigraph/src/loader/jsonl.rs
Normal file
File diff suppressed because it is too large
Load diff
1631
crates/omnigraph/src/loader/mod.rs
Normal file
1631
crates/omnigraph/src/loader/mod.rs
Normal file
File diff suppressed because it is too large
Load diff
159
crates/omnigraph/src/runtime_cache.rs
Normal file
159
crates/omnigraph/src/runtime_cache.rs
Normal file
|
|
@ -0,0 +1,159 @@
|
|||
use std::collections::{HashMap, VecDeque};
|
||||
use std::sync::Arc;
|
||||
|
||||
use omnigraph_compiler::catalog::Catalog;
|
||||
use tokio::sync::Mutex;
|
||||
|
||||
use crate::db::ResolvedTarget;
|
||||
use crate::error::Result;
|
||||
use crate::graph_index::GraphIndex;
|
||||
|
||||
/// Cache key of a built graph index: the snapshot id plus the state of every
/// edge table present in that snapshot.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GraphIndexCacheKey {
    snapshot_id: String,
    /// Edge table states, kept sorted by `table_key` by the key builder so
    /// that equal states compare and hash equal.
    edge_tables: Vec<GraphIndexTableState>,
}

/// Version and branch of one edge table as recorded in a snapshot entry.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct GraphIndexTableState {
    table_key: String,
    table_version: u64,
    table_branch: Option<String>,
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
pub struct RuntimeCache {
|
||||
graph_indices: Mutex<GraphIndexCache>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Default)]
|
||||
struct GraphIndexCache {
|
||||
entries: HashMap<GraphIndexCacheKey, Arc<GraphIndex>>,
|
||||
lru: VecDeque<GraphIndexCacheKey>,
|
||||
}
|
||||
|
||||
impl RuntimeCache {
|
||||
pub async fn invalidate_all(&self) {
|
||||
let mut cache = self.graph_indices.lock().await;
|
||||
cache.entries.clear();
|
||||
cache.lru.clear();
|
||||
}
|
||||
|
||||
pub async fn graph_index(
|
||||
&self,
|
||||
resolved: &ResolvedTarget,
|
||||
catalog: &Catalog,
|
||||
) -> Result<Arc<GraphIndex>> {
|
||||
let key = graph_index_cache_key(resolved, catalog);
|
||||
{
|
||||
let mut cache = self.graph_indices.lock().await;
|
||||
if let Some(index) = cache.entries.get(&key).cloned() {
|
||||
cache.touch(key.clone());
|
||||
return Ok(index);
|
||||
}
|
||||
}
|
||||
|
||||
let edge_types = catalog
|
||||
.edge_types
|
||||
.iter()
|
||||
.map(|(name, et)| (name.clone(), (et.from_type.clone(), et.to_type.clone())))
|
||||
.collect();
|
||||
|
||||
let index = Arc::new(GraphIndex::build(&resolved.snapshot, &edge_types).await?);
|
||||
let mut cache = self.graph_indices.lock().await;
|
||||
if let Some(existing) = cache.entries.get(&key).cloned() {
|
||||
cache.touch(key);
|
||||
return Ok(existing);
|
||||
}
|
||||
cache.insert(key, Arc::clone(&index));
|
||||
Ok(index)
|
||||
}
|
||||
}
|
||||
|
||||
impl GraphIndexCache {
|
||||
fn insert(&mut self, key: GraphIndexCacheKey, value: Arc<GraphIndex>) {
|
||||
self.entries.insert(key.clone(), value);
|
||||
self.touch(key);
|
||||
while self.entries.len() > 8 {
|
||||
let Some(oldest) = self.lru.pop_front() else {
|
||||
break;
|
||||
};
|
||||
if self.entries.remove(&oldest).is_some() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn touch(&mut self, key: GraphIndexCacheKey) {
|
||||
self.lru.retain(|existing| existing != &key);
|
||||
self.lru.push_back(key);
|
||||
}
|
||||
}
|
||||
|
||||
fn graph_index_cache_key(resolved: &ResolvedTarget, catalog: &Catalog) -> GraphIndexCacheKey {
|
||||
let mut edge_tables: Vec<GraphIndexTableState> = catalog
|
||||
.edge_types
|
||||
.keys()
|
||||
.filter_map(|edge_name| {
|
||||
let table_key = format!("edge:{}", edge_name);
|
||||
resolved
|
||||
.snapshot
|
||||
.entry(&table_key)
|
||||
.map(|entry| GraphIndexTableState {
|
||||
table_key,
|
||||
table_version: entry.table_version,
|
||||
table_branch: entry.table_branch.clone(),
|
||||
})
|
||||
})
|
||||
.collect();
|
||||
edge_tables.sort_by(|a, b| a.table_key.cmp(&b.table_key));
|
||||
|
||||
GraphIndexCacheKey {
|
||||
snapshot_id: resolved.snapshot_id.as_str().to_string(),
|
||||
edge_tables,
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use std::sync::Arc;

    use super::*;

    /// Cache key distinguished only by its snapshot id.
    fn key(id: usize) -> GraphIndexCacheKey {
        let snapshot_id = format!("snap-{id}");
        GraphIndexCacheKey {
            snapshot_id,
            edge_tables: Vec::new(),
        }
    }

    /// Placeholder index value for cache bookkeeping tests.
    fn empty_index() -> Arc<GraphIndex> {
        let index = GraphIndex::empty_for_test();
        Arc::new(index)
    }

    /// Inserting a ninth entry evicts the first one inserted.
    #[test]
    fn graph_index_cache_evicts_oldest_entry() {
        let mut cache = GraphIndexCache::default();
        (0..9).for_each(|idx| cache.insert(key(idx), empty_index()));

        assert_eq!(cache.entries.len(), 8);
        assert!(cache.entries.contains_key(&key(8)));
        assert!(!cache.entries.contains_key(&key(0)));
    }

    /// Touching an entry protects it from the next eviction.
    #[test]
    fn graph_index_cache_touch_keeps_recent_entry() {
        let mut cache = GraphIndexCache::default();
        (0..8).for_each(|idx| cache.insert(key(idx), empty_index()));

        cache.touch(key(0));
        cache.insert(key(8), empty_index());

        assert!(!cache.entries.contains_key(&key(1)));
        assert!(cache.entries.contains_key(&key(0)));
    }
}
|
||||
325
crates/omnigraph/src/storage.rs
Normal file
325
crates/omnigraph/src/storage.rs
Normal file
|
|
@ -0,0 +1,325 @@
|
|||
use std::env;
|
||||
use std::fmt::Debug;
|
||||
use std::path::{Path, PathBuf};
|
||||
use std::sync::Arc;
|
||||
|
||||
use async_trait::async_trait;
|
||||
use futures::TryStreamExt;
|
||||
use object_store::aws::AmazonS3Builder;
|
||||
use object_store::path::Path as ObjectPath;
|
||||
use object_store::{DynObjectStore, ObjectStore, PutPayload};
|
||||
use url::Url;
|
||||
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
/// URI prefix that marks an explicit local-file URI.
const FILE_SCHEME_PREFIX: &str = "file://";
/// URI prefix that selects the S3 storage backend.
const S3_SCHEME_PREFIX: &str = "s3://";
|
||||
|
||||
#[async_trait]
|
||||
pub trait StorageAdapter: Debug + Send + Sync {
|
||||
async fn read_text(&self, uri: &str) -> Result<String>;
|
||||
async fn write_text(&self, uri: &str, contents: &str) -> Result<()>;
|
||||
async fn exists(&self, uri: &str) -> Result<bool>;
|
||||
}
|
||||
|
||||
/// Storage backend selected for a URI.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum StorageKind {
    /// Local filesystem (plain paths and `file://` URIs).
    Local,
    /// S3 object storage (`s3://` URIs).
    S3,
}
|
||||
|
||||
/// Stateless storage adapter backed by the local filesystem.
#[derive(Debug, Default)]
pub struct LocalStorageAdapter;
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct S3StorageAdapter {
|
||||
bucket: String,
|
||||
store: Arc<DynObjectStore>,
|
||||
}
|
||||
|
||||
/// Bucket and object key parsed out of an `s3://bucket/key` URI.
#[derive(Debug, Clone, PartialEq, Eq)]
struct S3Location {
    /// Bucket name (the URI host component).
    bucket: String,
    /// Object key (the URI path with leading `/` characters removed).
    key: String,
}
|
||||
|
||||
#[async_trait]
|
||||
impl StorageAdapter for LocalStorageAdapter {
|
||||
async fn read_text(&self, uri: &str) -> Result<String> {
|
||||
let path = local_path_from_uri(uri)?;
|
||||
Ok(tokio::fs::read_to_string(&path).await?)
|
||||
}
|
||||
|
||||
async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
|
||||
let path = local_path_from_uri(uri)?;
|
||||
tokio::fs::write(&path, contents).await?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn exists(&self, uri: &str) -> Result<bool> {
|
||||
Ok(local_path_from_uri(uri)?.exists())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl StorageAdapter for S3StorageAdapter {
|
||||
async fn read_text(&self, uri: &str) -> Result<String> {
|
||||
let location = self.object_path(uri)?;
|
||||
let bytes = self
|
||||
.store
|
||||
.get(&location)
|
||||
.await
|
||||
.map_err(|err| storage_backend_error("read", uri, err))?
|
||||
.bytes()
|
||||
.await
|
||||
.map_err(|err| storage_backend_error("read", uri, err))?;
|
||||
|
||||
String::from_utf8(bytes.to_vec()).map_err(|err| {
|
||||
OmniError::manifest_internal(format!("storage read failed for '{}': {}", uri, err))
|
||||
})
|
||||
}
|
||||
|
||||
async fn write_text(&self, uri: &str, contents: &str) -> Result<()> {
|
||||
let location = self.object_path(uri)?;
|
||||
self.store
|
||||
.put(&location, PutPayload::from(contents.as_bytes().to_vec()))
|
||||
.await
|
||||
.map_err(|err| storage_backend_error("write", uri, err))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn exists(&self, uri: &str) -> Result<bool> {
|
||||
let location = self.object_path(uri)?;
|
||||
match self.store.head(&location).await {
|
||||
Ok(_) => Ok(true),
|
||||
Err(object_store::Error::NotFound { .. }) => {
|
||||
let mut entries = self.store.list(Some(&location));
|
||||
let has_prefix_entries = entries
|
||||
.try_next()
|
||||
.await
|
||||
.map_err(|err| storage_backend_error("exists", uri, err))?
|
||||
.is_some();
|
||||
Ok(has_prefix_entries)
|
||||
}
|
||||
Err(err) => Err(storage_backend_error("exists", uri, err)),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl S3StorageAdapter {
|
||||
fn from_root_uri(root_uri: &str) -> Result<Self> {
|
||||
let location = parse_s3_uri(root_uri)?;
|
||||
let mut builder = AmazonS3Builder::from_env().with_bucket_name(&location.bucket);
|
||||
|
||||
if let Some(endpoint) = env::var("AWS_ENDPOINT_URL_S3")
|
||||
.ok()
|
||||
.or_else(|| env::var("AWS_ENDPOINT_URL").ok())
|
||||
{
|
||||
builder = builder.with_endpoint(&endpoint);
|
||||
if endpoint.starts_with("http://") || env_var_truthy("AWS_ALLOW_HTTP") {
|
||||
builder = builder.with_allow_http(true);
|
||||
}
|
||||
}
|
||||
|
||||
if env_var_truthy("AWS_S3_FORCE_PATH_STYLE") {
|
||||
builder = builder.with_virtual_hosted_style_request(false);
|
||||
}
|
||||
|
||||
let store = builder.build().map_err(|err| {
|
||||
OmniError::manifest_internal(format!(
|
||||
"failed to initialize s3 storage for '{}': {}",
|
||||
root_uri, err
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
bucket: location.bucket,
|
||||
store: Arc::new(store),
|
||||
})
|
||||
}
|
||||
|
||||
fn object_path(&self, uri: &str) -> Result<ObjectPath> {
|
||||
let location = parse_s3_uri(uri)?;
|
||||
if location.bucket != self.bucket {
|
||||
return Err(OmniError::manifest_internal(format!(
|
||||
"s3 storage bucket mismatch for '{}': expected '{}', found '{}'",
|
||||
uri, self.bucket, location.bucket
|
||||
)));
|
||||
}
|
||||
if location.key.is_empty() {
|
||||
return Err(OmniError::manifest_internal(format!(
|
||||
"s3 storage path is empty for '{}'",
|
||||
uri
|
||||
)));
|
||||
}
|
||||
ObjectPath::parse(&location.key).map_err(|err| {
|
||||
OmniError::manifest_internal(format!("invalid s3 object path for '{}': {}", uri, err))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pub fn storage_kind_for_uri(uri: &str) -> StorageKind {
|
||||
if uri.starts_with(S3_SCHEME_PREFIX) {
|
||||
StorageKind::S3
|
||||
} else {
|
||||
StorageKind::Local
|
||||
}
|
||||
}
|
||||
|
||||
pub fn storage_for_uri(uri: &str) -> Result<Arc<dyn StorageAdapter>> {
|
||||
match storage_kind_for_uri(uri) {
|
||||
StorageKind::Local => Ok(Arc::new(LocalStorageAdapter)),
|
||||
StorageKind::S3 => Ok(Arc::new(S3StorageAdapter::from_root_uri(uri)?)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn normalize_root_uri(uri: &str) -> Result<String> {
|
||||
match storage_kind_for_uri(uri) {
|
||||
StorageKind::Local => {
|
||||
let path = local_path_from_uri(uri)?;
|
||||
Ok(normalize_local_path(&path))
|
||||
}
|
||||
StorageKind::S3 => Ok(trim_trailing_slashes(uri)),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn join_uri(root_uri: &str, relative_path: &str) -> String {
|
||||
let relative_path = relative_path.trim_start_matches('/');
|
||||
match storage_kind_for_uri(root_uri) {
|
||||
StorageKind::S3 => {
|
||||
let root = trim_trailing_slashes(root_uri);
|
||||
if root.is_empty() {
|
||||
relative_path.to_string()
|
||||
} else {
|
||||
format!("{}/{}", root, relative_path)
|
||||
}
|
||||
}
|
||||
StorageKind::Local => {
|
||||
let root = if root_uri.starts_with(FILE_SCHEME_PREFIX) {
|
||||
local_path_from_file_uri(root_uri)
|
||||
.map(|path| normalize_local_path(&path))
|
||||
.unwrap_or_else(|_| trim_trailing_slashes(root_uri))
|
||||
} else {
|
||||
normalize_local_path(Path::new(root_uri))
|
||||
};
|
||||
let joined = Path::new(&root).join(relative_path);
|
||||
normalize_local_path(&joined)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn local_path_from_uri(uri: &str) -> Result<PathBuf> {
|
||||
if uri.starts_with(FILE_SCHEME_PREFIX) {
|
||||
return local_path_from_file_uri(uri);
|
||||
}
|
||||
Ok(PathBuf::from(uri))
|
||||
}
|
||||
|
||||
fn local_path_from_file_uri(uri: &str) -> Result<PathBuf> {
|
||||
let url = Url::parse(uri).map_err(|err| {
|
||||
OmniError::manifest_internal(format!("invalid file uri '{}': {}", uri, err))
|
||||
})?;
|
||||
url.to_file_path()
|
||||
.map_err(|_| OmniError::manifest_internal(format!("invalid file uri '{}'", uri)))
|
||||
}
|
||||
|
||||
fn parse_s3_uri(uri: &str) -> Result<S3Location> {
|
||||
let url = Url::parse(uri).map_err(|err| {
|
||||
OmniError::manifest_internal(format!("invalid s3 uri '{}': {}", uri, err))
|
||||
})?;
|
||||
if url.scheme() != "s3" {
|
||||
return Err(OmniError::manifest_internal(format!(
|
||||
"unsupported s3 uri '{}'",
|
||||
uri
|
||||
)));
|
||||
}
|
||||
let bucket = url
|
||||
.host_str()
|
||||
.ok_or_else(|| OmniError::manifest_internal(format!("missing s3 bucket in '{}'", uri)))?;
|
||||
Ok(S3Location {
|
||||
bucket: bucket.to_string(),
|
||||
key: url.path().trim_start_matches('/').to_string(),
|
||||
})
|
||||
}
|
||||
|
||||
fn storage_backend_error(action: &str, uri: &str, err: impl std::fmt::Display) -> OmniError {
|
||||
OmniError::manifest_internal(format!("storage {} failed for '{}': {}", action, uri, err))
|
||||
}
|
||||
|
||||
fn normalize_local_path(path: &Path) -> String {
|
||||
let raw = path.as_os_str().to_string_lossy();
|
||||
if raw == "/" {
|
||||
return raw.to_string();
|
||||
}
|
||||
trim_trailing_slashes(&raw)
|
||||
}
|
||||
|
||||
/// Removes trailing `/` characters from `value`.
///
/// If that would leave nothing (the input is empty or consists only of
/// slashes), the input is returned unchanged instead.
fn trim_trailing_slashes(value: &str) -> String {
    match value.trim_end_matches('/') {
        "" => value.to_string(),
        kept => kept.to_string(),
    }
}
|
||||
|
||||
/// Reports whether the environment variable `key` is set to a truthy value.
///
/// Truthy values are `1`, `true`, `yes` and `on`, compared case-insensitively
/// after trimming surrounding whitespace. An unset variable, a value that is
/// not valid Unicode, or any other value yields `false`.
fn env_var_truthy(key: &str) -> bool {
    const TRUTHY: [&str; 4] = ["1", "true", "yes", "on"];

    env::var(key)
        .map(|raw| {
            let value = raw.trim();
            TRUTHY
                .iter()
                .any(|candidate| value.eq_ignore_ascii_case(candidate))
        })
        .unwrap_or(false)
}
|
||||
|
||||
#[cfg(test)]
mod tests {
    use super::*;

    /// Only `s3://` URIs select the S3 backend.
    #[test]
    fn storage_backend_selection_is_scheme_aware() {
        let cases = [
            ("/tmp/repo", StorageKind::Local),
            ("file:///tmp/repo", StorageKind::Local),
            ("s3://omnigraph-preview/repo", StorageKind::S3),
        ];
        for (uri, expected) in cases.iter().copied() {
            assert_eq!(storage_kind_for_uri(uri), expected);
        }
    }

    /// Trailing slashes are dropped; local roots lose their `file://` prefix.
    #[test]
    fn normalize_root_uri_preserves_local_and_s3_shapes() {
        let cases = [
            ("/tmp/omnigraph/", "/tmp/omnigraph"),
            ("file:///tmp/omnigraph/", "/tmp/omnigraph"),
            ("s3://bucket/prefix/", "s3://bucket/prefix"),
        ];
        for (input, expected) in cases.iter().copied() {
            assert_eq!(normalize_root_uri(input).unwrap(), expected);
        }
    }

    /// Joining works for plain paths, `file://` roots and S3 roots.
    #[test]
    fn join_uri_handles_local_file_and_s3_roots() {
        let cases = [
            ("/tmp/omnigraph", "/tmp/omnigraph/_schema.pg"),
            ("file:///tmp/omnigraph", "/tmp/omnigraph/_schema.pg"),
            ("s3://bucket/prefix", "s3://bucket/prefix/_schema.pg"),
        ];
        for (root, expected) in cases.iter().copied() {
            assert_eq!(join_uri(root, "_schema.pg"), expected);
        }
    }

    /// The host becomes the bucket and the path becomes the key.
    #[test]
    fn parse_s3_uri_splits_bucket_and_key() {
        let S3Location { bucket, key } = parse_s3_uri("s3://bucket/repo/_schema.pg").unwrap();
        assert_eq!(bucket, "bucket");
        assert_eq!(key, "repo/_schema.pg");
    }
}
|
||||
603
crates/omnigraph/src/table_store.rs
Normal file
603
crates/omnigraph/src/table_store.rs
Normal file
|
|
@ -0,0 +1,603 @@
|
|||
use arrow_array::{RecordBatch, UInt64Array};
|
||||
use arrow_schema::SchemaRef;
|
||||
use arrow_select::concat::concat_batches;
|
||||
use futures::TryStreamExt;
|
||||
use lance::Dataset;
|
||||
use lance::dataset::scanner::{ColumnOrdering, DatasetRecordBatchStream, Scanner};
|
||||
use lance::dataset::{MergeInsertBuilder, WhenMatched, WhenNotMatched, WriteMode, WriteParams};
|
||||
use lance::datatypes::BlobHandling;
|
||||
use lance::index::scalar::IndexDetails;
|
||||
use lance_file::version::LanceFileVersion;
|
||||
use lance_index::scalar::{InvertedIndexParams, ScalarIndexParams};
|
||||
use lance_index::{DatasetIndexExt, IndexType, is_system_index};
|
||||
use lance_linalg::distance::MetricType;
|
||||
use lance_table::format::IndexMetadata;
|
||||
use std::sync::Arc;
|
||||
|
||||
use crate::db::manifest::{TableVersionMetadata, open_table_head_for_write};
|
||||
use crate::db::{Snapshot, SubTableEntry};
|
||||
use crate::error::{OmniError, Result};
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct TableState {
|
||||
pub version: u64,
|
||||
pub row_count: u64,
|
||||
pub(crate) version_metadata: TableVersionMetadata,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct DeleteState {
|
||||
pub version: u64,
|
||||
pub row_count: u64,
|
||||
pub deleted_rows: usize,
|
||||
pub(crate) version_metadata: TableVersionMetadata,
|
||||
}
|
||||
|
||||
/// Helper for opening, scanning and mutating datasets under one root URI.
#[derive(Debug, Clone)]
pub struct TableStore {
    /// Root URI under which all table datasets live.
    root_uri: String,
}
|
||||
|
||||
impl TableStore {
|
||||
pub fn new(root_uri: &str) -> Self {
|
||||
Self {
|
||||
root_uri: root_uri.trim_end_matches('/').to_string(),
|
||||
}
|
||||
}
|
||||
|
||||
pub fn root_uri(&self) -> &str {
|
||||
&self.root_uri
|
||||
}
|
||||
|
||||
pub fn dataset_uri(&self, table_path: &str) -> String {
|
||||
format!("{}/{}", self.root_uri, table_path)
|
||||
}
|
||||
|
||||
fn table_path_from_dataset_uri(&self, dataset_uri: &str) -> Result<String> {
|
||||
let prefix = format!("{}/", self.root_uri.trim_end_matches('/'));
|
||||
let table_path = dataset_uri
|
||||
.strip_prefix(&prefix)
|
||||
.map(|path| path.to_string())
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!(
|
||||
"dataset uri '{}' is not under root '{}'",
|
||||
dataset_uri, self.root_uri
|
||||
))
|
||||
})?;
|
||||
Ok(table_path
|
||||
.split_once("/tree/")
|
||||
.map(|(path, _)| path.to_string())
|
||||
.unwrap_or(table_path))
|
||||
}
|
||||
|
||||
fn dataset_version_metadata(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
ds: &Dataset,
|
||||
) -> Result<TableVersionMetadata> {
|
||||
let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
|
||||
TableVersionMetadata::from_dataset(&self.root_uri, &table_path, ds)
|
||||
}
|
||||
|
||||
pub async fn open_snapshot_table(
|
||||
&self,
|
||||
snapshot: &Snapshot,
|
||||
table_key: &str,
|
||||
) -> Result<Dataset> {
|
||||
snapshot.open(table_key).await
|
||||
}
|
||||
|
||||
pub async fn open_at_entry(&self, entry: &SubTableEntry) -> Result<Dataset> {
|
||||
entry.open(&self.root_uri).await
|
||||
}
|
||||
|
||||
pub async fn open_dataset_head(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
branch: Option<&str>,
|
||||
) -> Result<Dataset> {
|
||||
let ds = Dataset::open(dataset_uri)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
match branch {
|
||||
Some(branch) if branch != "main" => ds
|
||||
.checkout_branch(branch)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string())),
|
||||
_ => Ok(ds),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn open_dataset_head_for_write(
|
||||
&self,
|
||||
table_key: &str,
|
||||
dataset_uri: &str,
|
||||
branch: Option<&str>,
|
||||
) -> Result<Dataset> {
|
||||
let table_path = self.table_path_from_dataset_uri(dataset_uri)?;
|
||||
open_table_head_for_write(&self.root_uri, table_key, &table_path, branch).await
|
||||
}
|
||||
|
||||
pub async fn delete_branch(&self, dataset_uri: &str, branch: &str) -> Result<()> {
|
||||
let mut ds = Dataset::open(dataset_uri)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
ds.delete_branch(branch)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn open_dataset_at_state(
|
||||
&self,
|
||||
table_path: &str,
|
||||
branch: Option<&str>,
|
||||
version: u64,
|
||||
) -> Result<Dataset> {
|
||||
let ds = self
|
||||
.open_dataset_head(&self.dataset_uri(table_path), branch)
|
||||
.await?;
|
||||
ds.checkout_version(version)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn ensure_expected_version(
|
||||
&self,
|
||||
ds: &Dataset,
|
||||
table_key: &str,
|
||||
expected_version: u64,
|
||||
) -> Result<()> {
|
||||
if ds.version().version != expected_version {
|
||||
return Err(OmniError::manifest_conflict(format!(
|
||||
"version drift on {}: snapshot pinned v{} but dataset is at v{} — call sync_branch() and retry",
|
||||
table_key,
|
||||
expected_version,
|
||||
ds.version().version
|
||||
)));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub async fn reopen_for_mutation(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
branch: Option<&str>,
|
||||
table_key: &str,
|
||||
expected_version: u64,
|
||||
) -> Result<Dataset> {
|
||||
let ds = self
|
||||
.open_dataset_head_for_write(table_key, dataset_uri, branch)
|
||||
.await?;
|
||||
self.ensure_expected_version(&ds, table_key, expected_version)?;
|
||||
Ok(ds)
|
||||
}
|
||||
|
||||
pub async fn fork_branch_from_state(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
source_branch: Option<&str>,
|
||||
table_key: &str,
|
||||
source_version: u64,
|
||||
target_branch: &str,
|
||||
) -> Result<Dataset> {
|
||||
let mut source_ds = self
|
||||
.open_dataset_head(dataset_uri, source_branch)
|
||||
.await?
|
||||
.checkout_version(source_version)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.ensure_expected_version(&source_ds, table_key, source_version)?;
|
||||
|
||||
match source_ds
|
||||
.create_branch(target_branch, source_version, None)
|
||||
.await
|
||||
{
|
||||
Ok(_) => {}
|
||||
Err(create_err) => match self
|
||||
.open_dataset_head(dataset_uri, Some(target_branch))
|
||||
.await
|
||||
{
|
||||
Ok(ds) => {
|
||||
self.ensure_expected_version(&ds, table_key, source_version)?;
|
||||
return Ok(ds);
|
||||
}
|
||||
Err(_) => return Err(OmniError::Lance(create_err.to_string())),
|
||||
},
|
||||
}
|
||||
|
||||
let ds = self
|
||||
.open_dataset_head(dataset_uri, Some(target_branch))
|
||||
.await?;
|
||||
self.ensure_expected_version(&ds, table_key, source_version)?;
|
||||
Ok(ds)
|
||||
}
|
||||
|
||||
pub async fn scan_batches(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
|
||||
self.scan(ds, None, None, None).await
|
||||
}
|
||||
|
||||
pub async fn scan_batches_for_rewrite(&self, ds: &Dataset) -> Result<Vec<RecordBatch>> {
|
||||
let has_blob_columns = ds.schema().fields_pre_order().any(|field| field.is_blob());
|
||||
if !has_blob_columns {
|
||||
return self.scan_batches(ds).await;
|
||||
}
|
||||
|
||||
let mut scanner = ds.scan();
|
||||
scanner.blob_handling(BlobHandling::AllBinary);
|
||||
scanner
|
||||
.try_into_stream()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn scan_stream(
|
||||
ds: &Dataset,
|
||||
projection: Option<&[&str]>,
|
||||
filter: Option<&str>,
|
||||
order_by: Option<Vec<ColumnOrdering>>,
|
||||
with_row_id: bool,
|
||||
) -> Result<DatasetRecordBatchStream> {
|
||||
Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, |_| Ok(())).await
|
||||
}
|
||||
|
||||
pub async fn scan_stream_with<F>(
|
||||
ds: &Dataset,
|
||||
projection: Option<&[&str]>,
|
||||
filter: Option<&str>,
|
||||
order_by: Option<Vec<ColumnOrdering>>,
|
||||
with_row_id: bool,
|
||||
configure: F,
|
||||
) -> Result<DatasetRecordBatchStream>
|
||||
where
|
||||
F: FnOnce(&mut Scanner) -> Result<()>,
|
||||
{
|
||||
let mut scanner = ds.scan();
|
||||
if with_row_id {
|
||||
scanner.with_row_id();
|
||||
}
|
||||
if let Some(columns) = projection {
|
||||
scanner
|
||||
.project(columns)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
}
|
||||
if let Some(filter_sql) = filter {
|
||||
scanner
|
||||
.filter(filter_sql)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
}
|
||||
if let Some(ordering) = order_by {
|
||||
scanner
|
||||
.order_by(Some(ordering))
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
}
|
||||
configure(&mut scanner)?;
|
||||
scanner
|
||||
.try_into_stream()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn scan(
|
||||
&self,
|
||||
ds: &Dataset,
|
||||
projection: Option<&[&str]>,
|
||||
filter: Option<&str>,
|
||||
order_by: Option<Vec<ColumnOrdering>>,
|
||||
) -> Result<Vec<RecordBatch>> {
|
||||
Self::scan_stream(ds, projection, filter, order_by, false)
|
||||
.await?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn scan_with<F>(
|
||||
&self,
|
||||
ds: &Dataset,
|
||||
projection: Option<&[&str]>,
|
||||
filter: Option<&str>,
|
||||
order_by: Option<Vec<ColumnOrdering>>,
|
||||
with_row_id: bool,
|
||||
configure: F,
|
||||
) -> Result<Vec<RecordBatch>>
|
||||
where
|
||||
F: FnOnce(&mut Scanner) -> Result<()>,
|
||||
{
|
||||
Self::scan_stream_with(ds, projection, filter, order_by, with_row_id, configure)
|
||||
.await?
|
||||
.try_collect()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn count_rows(&self, ds: &Dataset, filter: Option<String>) -> Result<usize> {
|
||||
ds.count_rows(filter)
|
||||
.await
|
||||
.map(|count| count as usize)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub fn dataset_version(&self, ds: &Dataset) -> u64 {
|
||||
ds.version().version
|
||||
}
|
||||
|
||||
pub async fn table_state(&self, dataset_uri: &str, ds: &Dataset) -> Result<TableState> {
|
||||
Ok(TableState {
|
||||
version: self.dataset_version(ds),
|
||||
row_count: self.count_rows(ds, None).await? as u64,
|
||||
version_metadata: self.dataset_version_metadata(dataset_uri, ds)?,
|
||||
})
|
||||
}
|
||||
|
||||
pub async fn append_batch(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
ds: &mut Dataset,
|
||||
batch: RecordBatch,
|
||||
) -> Result<TableState> {
|
||||
if batch.num_rows() == 0 {
|
||||
return self.table_state(dataset_uri, ds).await;
|
||||
}
|
||||
let schema = batch.schema();
|
||||
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Append,
|
||||
allow_external_blob_outside_bases: true,
|
||||
..Default::default()
|
||||
};
|
||||
ds.append(reader, Some(params))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.table_state(dataset_uri, ds).await
|
||||
}
|
||||
|
||||
pub async fn append_or_create_batch(
|
||||
dataset_uri: &str,
|
||||
dataset: Option<Dataset>,
|
||||
batch: RecordBatch,
|
||||
) -> Result<Dataset> {
|
||||
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
|
||||
match dataset {
|
||||
Some(mut ds) => {
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Append,
|
||||
allow_external_blob_outside_bases: true,
|
||||
..Default::default()
|
||||
};
|
||||
ds.append(reader, Some(params))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
Ok(ds)
|
||||
}
|
||||
None => {
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
allow_external_blob_outside_bases: true,
|
||||
..Default::default()
|
||||
};
|
||||
Dataset::write(reader, dataset_uri, Some(params))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn overwrite_batch(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
ds: &mut Dataset,
|
||||
batch: RecordBatch,
|
||||
) -> Result<TableState> {
|
||||
ds.truncate_table()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.append_batch(dataset_uri, ds, batch).await
|
||||
}
|
||||
|
||||
pub async fn merge_insert_batch(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
ds: Dataset,
|
||||
batch: RecordBatch,
|
||||
key_columns: Vec<String>,
|
||||
when_matched: WhenMatched,
|
||||
when_not_matched: WhenNotMatched,
|
||||
) -> Result<TableState> {
|
||||
if batch.num_rows() == 0 {
|
||||
return self.table_state(dataset_uri, &ds).await;
|
||||
}
|
||||
|
||||
// TODO(lance-upstream): MergeInsertBuilder does not accept WriteParams,
|
||||
// so allow_external_blob_outside_bases cannot be set here. External URI
|
||||
// blobs via merge_insert (LoadMode::Merge, mutations) are unsupported
|
||||
// until Lance exposes WriteParams on MergeInsertBuilder.
|
||||
let ds = Arc::new(ds);
|
||||
let job = MergeInsertBuilder::try_new(ds, key_columns)
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
.when_matched(when_matched)
|
||||
.when_not_matched(when_not_matched)
|
||||
.try_build()
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
|
||||
let schema = batch.schema();
|
||||
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let (new_ds, _stats) = job
|
||||
.execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
self.table_state(dataset_uri, &new_ds).await
|
||||
}
|
||||
|
||||
pub async fn merge_insert_batches(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
ds: Dataset,
|
||||
batches: Vec<RecordBatch>,
|
||||
key_columns: Vec<String>,
|
||||
when_matched: WhenMatched,
|
||||
when_not_matched: WhenNotMatched,
|
||||
) -> Result<TableState> {
|
||||
if batches.is_empty() {
|
||||
return self.table_state(dataset_uri, &ds).await;
|
||||
}
|
||||
let batch = if batches.len() == 1 {
|
||||
batches.into_iter().next().unwrap()
|
||||
} else {
|
||||
let schema = batches[0].schema();
|
||||
concat_batches(&schema, &batches).map_err(|e| OmniError::Lance(e.to_string()))?
|
||||
};
|
||||
self.merge_insert_batch(
|
||||
dataset_uri,
|
||||
ds,
|
||||
batch,
|
||||
key_columns,
|
||||
when_matched,
|
||||
when_not_matched,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn delete_where(
|
||||
&self,
|
||||
dataset_uri: &str,
|
||||
ds: &mut Dataset,
|
||||
filter: &str,
|
||||
) -> Result<DeleteState> {
|
||||
let delete_result = ds
|
||||
.delete(filter)
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
Ok(DeleteState {
|
||||
version: delete_result.new_dataset.version().version,
|
||||
row_count: self.count_rows(&delete_result.new_dataset, None).await? as u64,
|
||||
deleted_rows: delete_result.num_deleted_rows as usize,
|
||||
version_metadata: self
|
||||
.dataset_version_metadata(dataset_uri, &delete_result.new_dataset)?,
|
||||
})
|
||||
}
|
||||
|
||||
async fn user_indices_for_column(
|
||||
&self,
|
||||
ds: &Dataset,
|
||||
column: &str,
|
||||
) -> Result<Vec<IndexMetadata>> {
|
||||
let field_id = ds
|
||||
.schema()
|
||||
.field(column)
|
||||
.map(|field| field.id)
|
||||
.ok_or_else(|| {
|
||||
OmniError::manifest_internal(format!(
|
||||
"dataset is missing expected index column '{}'",
|
||||
column
|
||||
))
|
||||
})?;
|
||||
let indices = ds
|
||||
.load_indices()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
Ok(indices
|
||||
.iter()
|
||||
.filter(|index| !is_system_index(index))
|
||||
.filter(|index| index.fields.len() == 1 && index.fields[0] == field_id)
|
||||
.cloned()
|
||||
.collect())
|
||||
}
|
||||
|
||||
pub async fn has_btree_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
|
||||
let indices = self.user_indices_for_column(ds, column).await?;
|
||||
Ok(indices.iter().any(|index| {
|
||||
index
|
||||
.index_details
|
||||
.as_ref()
|
||||
.map(|details| details.type_url.ends_with("BTreeIndexDetails"))
|
||||
.unwrap_or(false)
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn has_fts_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
|
||||
let indices = self.user_indices_for_column(ds, column).await?;
|
||||
Ok(indices.iter().any(|index| {
|
||||
index
|
||||
.index_details
|
||||
.as_ref()
|
||||
.map(|details| IndexDetails(details.clone()).supports_fts())
|
||||
.unwrap_or(false)
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn has_vector_index(&self, ds: &Dataset, column: &str) -> Result<bool> {
|
||||
let indices = self.user_indices_for_column(ds, column).await?;
|
||||
Ok(indices.iter().any(|index| {
|
||||
index
|
||||
.index_details
|
||||
.as_ref()
|
||||
.map(|details| IndexDetails(details.clone()).is_vector())
|
||||
.unwrap_or(false)
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn create_btree_index(&self, ds: &mut Dataset, columns: &[&str]) -> Result<()> {
|
||||
let params = ScalarIndexParams::default();
|
||||
ds.create_index_builder(columns, IndexType::BTree, ¶ms)
|
||||
.replace(true)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn create_inverted_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
|
||||
let params = InvertedIndexParams::default();
|
||||
ds.create_index_builder(&[column], IndexType::Inverted, ¶ms)
|
||||
.replace(true)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn create_vector_index(&self, ds: &mut Dataset, column: &str) -> Result<()> {
|
||||
let params = lance::index::vector::VectorIndexParams::ivf_flat(1, MetricType::L2);
|
||||
ds.create_index_builder(&[column], IndexType::Vector, ¶ms)
|
||||
.replace(true)
|
||||
.await
|
||||
.map(|_| ())
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
|
||||
pub async fn create_empty_dataset(dataset_uri: &str, schema: &SchemaRef) -> Result<Dataset> {
|
||||
let batch = RecordBatch::new_empty(schema.clone());
|
||||
Self::write_dataset(dataset_uri, batch).await
|
||||
}
|
||||
|
||||
pub async fn first_row_id_for_filter(&self, ds: &Dataset, filter: &str) -> Result<Option<u64>> {
|
||||
let batches = Self::scan_stream(ds, Some(&["id"]), Some(filter), None, true)
|
||||
.await?
|
||||
.try_collect::<Vec<RecordBatch>>()
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))?;
|
||||
Ok(batches.iter().find_map(|batch| {
|
||||
batch
|
||||
.column_by_name("_rowid")
|
||||
.and_then(|col| col.as_any().downcast_ref::<UInt64Array>())
|
||||
.and_then(|arr| (arr.len() > 0).then(|| arr.value(0)))
|
||||
}))
|
||||
}
|
||||
|
||||
pub async fn write_dataset(dataset_uri: &str, batch: RecordBatch) -> Result<Dataset> {
|
||||
let reader = arrow_array::RecordBatchIterator::new(vec![Ok(batch.clone())], batch.schema());
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
allow_external_blob_outside_bases: true,
|
||||
..Default::default()
|
||||
};
|
||||
Dataset::write(reader, dataset_uri, Some(params))
|
||||
.await
|
||||
.map_err(|e| OmniError::Lance(e.to_string()))
|
||||
}
|
||||
}
|
||||
1481
crates/omnigraph/tests/branching.rs
Normal file
1481
crates/omnigraph/tests/branching.rs
Normal file
File diff suppressed because it is too large
Load diff
677
crates/omnigraph/tests/changes.rs
Normal file
677
crates/omnigraph/tests/changes.rs
Normal file
|
|
@ -0,0 +1,677 @@
|
|||
mod helpers;
|
||||
|
||||
use omnigraph::changes::{ChangeFilter, ChangeOp, EntityKind};
|
||||
use omnigraph::db::commit_graph::CommitGraph;
|
||||
use omnigraph::db::{MergeOutcome, Omnigraph, ReadTarget};
|
||||
|
||||
use helpers::*;
|
||||
|
||||
async fn head_commit_id(uri: &str, branch: Option<&str>) -> String {
|
||||
let commit_graph = match branch {
|
||||
Some(branch) => CommitGraph::open_at_branch(uri, branch).await.unwrap(),
|
||||
None => CommitGraph::open(uri).await.unwrap(),
|
||||
};
|
||||
commit_graph.head_commit_id().await.unwrap().unwrap()
|
||||
}
|
||||
|
||||
fn change_tuples(change_set: &omnigraph::changes::ChangeSet) -> Vec<(String, String, ChangeOp)> {
|
||||
let mut tuples: Vec<_> = change_set
|
||||
.changes
|
||||
.iter()
|
||||
.map(|change| (change.table_key.clone(), change.id.clone(), change.op))
|
||||
.collect();
|
||||
tuples.sort_by(|a, b| {
|
||||
a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)).then_with(|| {
|
||||
let a_op = match a.2 {
|
||||
ChangeOp::Insert => 0,
|
||||
ChangeOp::Update => 1,
|
||||
ChangeOp::Delete => 2,
|
||||
};
|
||||
let b_op = match b.2 {
|
||||
ChangeOp::Insert => 0,
|
||||
ChangeOp::Update => 1,
|
||||
ChangeOp::Delete => 2,
|
||||
};
|
||||
a_op.cmp(&b_op)
|
||||
})
|
||||
});
|
||||
tuples
|
||||
}
|
||||
|
||||
// ─── Same-branch diff tests ────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_empty_when_nothing_changed() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db = init_and_load(&dir).await;
|
||||
let v = snapshot_id(&db, "main").await.unwrap();
|
||||
let cs = db
|
||||
.diff_between(
|
||||
ReadTarget::Snapshot(v.clone()),
|
||||
ReadTarget::Snapshot(v),
|
||||
&ChangeFilter::default(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(cs.changes.is_empty());
|
||||
assert_eq!(cs.stats.inserts, 0);
|
||||
assert_eq!(cs.stats.updates, 0);
|
||||
assert_eq!(cs.stats.deletes, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_detects_node_insert() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v_before = snapshot_id(&db, "main").await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cs = diff_since_branch(&db, "main", v_before, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let inserts: Vec<_> = cs
|
||||
.changes
|
||||
.iter()
|
||||
.filter(|c| c.op == ChangeOp::Insert && c.table_key == "node:Person")
|
||||
.collect();
|
||||
assert!(
|
||||
!inserts.is_empty(),
|
||||
"Should detect the Person insert. Got changes: {:?}",
|
||||
cs.changes
|
||||
.iter()
|
||||
.map(|c| (&c.table_key, &c.id, c.op))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
assert!(
|
||||
inserts.iter().any(|c| c.id == "Eve"),
|
||||
"Insert should contain Eve. Got: {:?}",
|
||||
inserts.iter().map(|c| &c.id).collect::<Vec<_>>()
|
||||
);
|
||||
assert_eq!(inserts[0].kind, EntityKind::Node);
|
||||
assert_eq!(inserts[0].endpoints, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_detects_node_update() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v_before = snapshot_id(&db, "main").await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 99)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cs = diff_since_branch(&db, "main", v_before, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let updates: Vec<_> = cs
|
||||
.changes
|
||||
.iter()
|
||||
.filter(|c| c.op == ChangeOp::Update && c.table_key == "node:Person")
|
||||
.collect();
|
||||
assert!(
|
||||
!updates.is_empty(),
|
||||
"Should detect the Person update. Got changes: {:?}",
|
||||
cs.changes
|
||||
.iter()
|
||||
.map(|c| (&c.table_key, &c.id, c.op))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_detects_node_delete_with_cascade() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v_before = snapshot_id(&db, "main").await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"remove_person",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cs = diff_since_branch(&db, "main", v_before, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should have node:Person delete
|
||||
let person_deletes: Vec<_> = cs
|
||||
.changes
|
||||
.iter()
|
||||
.filter(|c| c.op == ChangeOp::Delete && c.table_key == "node:Person")
|
||||
.collect();
|
||||
assert!(
|
||||
!person_deletes.is_empty(),
|
||||
"Should detect Person delete. Changes: {:?}",
|
||||
cs.changes
|
||||
.iter()
|
||||
.map(|c| (&c.table_key, &c.id, c.op))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
// Should also have edge:Knows cascade deletes
|
||||
let edge_deletes: Vec<_> = cs
|
||||
.changes
|
||||
.iter()
|
||||
.filter(|c| c.op == ChangeOp::Delete && c.table_key == "edge:Knows")
|
||||
.collect();
|
||||
assert!(
|
||||
!edge_deletes.is_empty(),
|
||||
"Should detect cascaded Knows edge deletes. Changes: {:?}",
|
||||
cs.changes
|
||||
.iter()
|
||||
.map(|c| (&c.table_key, &c.id, c.op))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
// Cascaded edge deletes should have endpoints
|
||||
for edge_del in &edge_deletes {
|
||||
assert!(
|
||||
edge_del.endpoints.is_some(),
|
||||
"Deleted edge should have endpoint context"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_detects_edge_insert_with_endpoints() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v_before = snapshot_id(&db, "main").await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"add_friend",
|
||||
¶ms(&[("$from", "Bob"), ("$to", "Charlie")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let cs = diff_since_branch(&db, "main", v_before, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let edge_inserts: Vec<_> = cs
|
||||
.changes
|
||||
.iter()
|
||||
.filter(|c| c.op == ChangeOp::Insert && c.table_key == "edge:Knows")
|
||||
.collect();
|
||||
assert!(
|
||||
!edge_inserts.is_empty(),
|
||||
"Should detect Knows edge insert. Changes: {:?}",
|
||||
cs.changes
|
||||
.iter()
|
||||
.map(|c| (&c.table_key, &c.id, c.op))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
let e = &edge_inserts[0];
|
||||
assert_eq!(e.kind, EntityKind::Edge);
|
||||
let ep = e
|
||||
.endpoints
|
||||
.as_ref()
|
||||
.expect("Edge insert should have endpoints");
|
||||
assert!(!ep.src.is_empty(), "src should not be empty");
|
||||
assert!(!ep.dst.is_empty(), "dst should not be empty");
|
||||
}
|
||||
|
||||
// ─── Filter tests ──────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn filter_by_type_name_skips_non_matching() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v_before = snapshot_id(&db, "main").await.unwrap();
|
||||
|
||||
// Insert a person (node:Person) and add a friend (edge:Knows)
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "FilterTest")], &[("$age", 30)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Filter to Company only — should not see Person changes
|
||||
let filter = ChangeFilter {
|
||||
type_names: Some(vec!["Company".to_string()]),
|
||||
..Default::default()
|
||||
};
|
||||
let cs = diff_since_branch(&db, "main", v_before, &filter)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
cs.changes.is_empty(),
|
||||
"Filter to Company should skip Person changes. Got: {:?}",
|
||||
cs.changes
|
||||
.iter()
|
||||
.map(|c| (&c.table_key, &c.id, c.op))
|
||||
.collect::<Vec<_>>()
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn filter_by_op_skips_unwanted_operations() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v_before = snapshot_id(&db, "main").await.unwrap();
|
||||
|
||||
// Insert Eve, update Bob, delete Alice
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 99)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Filter to Insert only
|
||||
let filter = ChangeFilter {
|
||||
ops: Some(vec![ChangeOp::Insert]),
|
||||
..Default::default()
|
||||
};
|
||||
let cs = diff_since_branch(&db, "main", v_before, &filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should only have inserts, no updates or deletes
|
||||
for c in &cs.changes {
|
||||
assert_eq!(
|
||||
c.op,
|
||||
ChangeOp::Insert,
|
||||
"Filter for Insert-only should not include {:?} for {} ({})",
|
||||
c.op,
|
||||
c.table_key,
|
||||
c.id
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Cross-branch diff tests ──────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_after_merge_reports_actual_changes() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
main.ensure_indices().await.unwrap();
|
||||
let v_before_branch = snapshot_id(&main, "main").await.unwrap();
|
||||
|
||||
main.branch_create("feature").await.unwrap();
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
|
||||
// Main updates Bob
|
||||
mutate_main(
|
||||
&mut main,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 26)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Feature inserts Eve
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let outcome = main.branch_merge("feature", "main").await.unwrap();
|
||||
assert_eq!(outcome, MergeOutcome::Merged);
|
||||
|
||||
// Diff from pre-branch to post-merge on main
|
||||
let cs = diff_since_branch(&main, "main", v_before_branch, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should have:
|
||||
// - Person insert (Eve) — from the merge
|
||||
// - Person update (Bob) — from the main write
|
||||
// Should NOT have: all original persons re-reported as inserts
|
||||
let person_changes: Vec<_> = cs
|
||||
.changes
|
||||
.iter()
|
||||
.filter(|c| c.table_key == "node:Person")
|
||||
.collect();
|
||||
|
||||
let person_inserts: Vec<_> = person_changes
|
||||
.iter()
|
||||
.filter(|c| c.op == ChangeOp::Insert)
|
||||
.collect();
|
||||
let person_updates: Vec<_> = person_changes
|
||||
.iter()
|
||||
.filter(|c| c.op == ChangeOp::Update)
|
||||
.collect();
|
||||
|
||||
// There should be exactly 1 insert (Eve) not all persons
|
||||
assert!(
|
||||
person_inserts.len() <= 2,
|
||||
"After surgical merge, should not re-report all persons as inserts. \
|
||||
Got {} inserts: {:?}",
|
||||
person_inserts.len(),
|
||||
person_inserts.iter().map(|c| &c.id).collect::<Vec<_>>()
|
||||
);
|
||||
|
||||
// Bob's update should be detected
|
||||
assert!(
|
||||
!person_updates.is_empty() || person_inserts.len() > 0,
|
||||
"Should detect Bob's age update or Eve's insert"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_commits_resolves_feature_commit_from_main_handle() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
main.branch_create("feature").await.unwrap();
|
||||
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let main_head = CommitGraph::open(uri)
|
||||
.await
|
||||
.unwrap()
|
||||
.head_commit()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.graph_commit_id;
|
||||
let feature_head = CommitGraph::open_at_branch(uri, "feature")
|
||||
.await
|
||||
.unwrap()
|
||||
.head_commit()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.graph_commit_id;
|
||||
|
||||
let cs = main
|
||||
.diff_commits(&main_head, &feature_head, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(
|
||||
cs.changes
|
||||
.iter()
|
||||
.any(|change| change.op == ChangeOp::Insert && change.id == "Eve"),
|
||||
"expected feature-only insert to be diffable from a main handle"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cross_branch_diff_honors_insert_only_filter() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
main.branch_create("feature").await.unwrap();
|
||||
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let main_head = CommitGraph::open(uri)
|
||||
.await
|
||||
.unwrap()
|
||||
.head_commit()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.graph_commit_id;
|
||||
let feature_head = CommitGraph::open_at_branch(uri, "feature")
|
||||
.await
|
||||
.unwrap()
|
||||
.head_commit()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap()
|
||||
.graph_commit_id;
|
||||
|
||||
let filter = ChangeFilter {
|
||||
ops: Some(vec![ChangeOp::Insert]),
|
||||
..Default::default()
|
||||
};
|
||||
let cs = main
|
||||
.diff_commits(&main_head, &feature_head, &filter)
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(!cs.changes.is_empty());
|
||||
assert!(
|
||||
cs.changes
|
||||
.iter()
|
||||
.all(|change| change.op == ChangeOp::Insert)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_commits_resolves_commits_across_branches_from_any_handle() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
let base_commit = head_commit_id(uri, None).await;
|
||||
|
||||
main.branch_create("feature").await.unwrap();
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let feature_commit = head_commit_id(uri, Some("feature")).await;
|
||||
|
||||
let from_main = main
|
||||
.diff_commits(&base_commit, &feature_commit, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
let from_feature = feature
|
||||
.diff_commits(&base_commit, &feature_commit, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(change_tuples(&from_main), change_tuples(&from_feature));
|
||||
assert!(from_main.changes.iter().any(|change| {
|
||||
change.table_key == "node:Person" && change.id == "Eve" && change.op == ChangeOp::Insert
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cross_lineage_diff_honors_delete_only_filter() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
main.branch_create("feature").await.unwrap();
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
let before = snapshot_id(&feature, "feature").await.unwrap();
|
||||
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 99)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"remove_person",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let filter = ChangeFilter {
|
||||
ops: Some(vec![ChangeOp::Delete]),
|
||||
..Default::default()
|
||||
};
|
||||
let change_set = diff_since_branch(&feature, "feature", before, &filter)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
!change_set.changes.is_empty(),
|
||||
"expected delete changes after removing Alice"
|
||||
);
|
||||
assert!(
|
||||
change_set
|
||||
.changes
|
||||
.iter()
|
||||
.all(|change| change.op == ChangeOp::Delete)
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn same_branch_diff_across_first_lazy_fork_detects_update() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
main.branch_create("feature").await.unwrap();
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
let before = snapshot_id(&feature, "feature").await.unwrap();
|
||||
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 77)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let change_set = diff_since_branch(&feature, "feature", before, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
assert!(change_set.changes.iter().any(|change| {
|
||||
change.table_key == "node:Person" && change.id == "Bob" && change.op == ChangeOp::Update
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_commits_cross_branch_reports_property_only_updates() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
let base_commit = head_commit_id(uri, None).await;
|
||||
|
||||
main.branch_create("feature").await.unwrap();
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 55)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let feature_commit = head_commit_id(uri, Some("feature")).await;
|
||||
|
||||
let change_set = main
|
||||
.diff_commits(&base_commit, &feature_commit, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(change_set.changes.iter().any(|change| {
|
||||
change.table_key == "node:Person" && change.id == "Bob" && change.op == ChangeOp::Update
|
||||
}));
|
||||
assert!(!change_set.changes.iter().any(|change| {
|
||||
change.table_key == "node:Person" && change.id == "Bob" && change.op == ChangeOp::Insert
|
||||
}));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn diff_commits_ignores_row_version_only_differences() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
|
||||
main.branch_create("feature").await.unwrap();
|
||||
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 55)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let feature_commit = head_commit_id(uri, Some("feature")).await;
|
||||
|
||||
mutate_main(
|
||||
&mut main,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 55)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let main_commit = head_commit_id(uri, None).await;
|
||||
|
||||
let change_set = main
|
||||
.diff_commits(&main_commit, &feature_commit, &ChangeFilter::default())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(
|
||||
change_set.changes.is_empty(),
|
||||
"identical user-visible state should not produce diff entries: {:?}",
|
||||
change_set.changes
|
||||
);
|
||||
}
|
||||
574
crates/omnigraph/tests/consistency.rs
Normal file
574
crates/omnigraph/tests/consistency.rs
Normal file
|
|
@ -0,0 +1,574 @@
|
|||
mod helpers;
|
||||
|
||||
use arrow_array::{Array, Date32Array, Int32Array, StringArray};
|
||||
use futures::TryStreamExt;
|
||||
|
||||
use omnigraph::db::Omnigraph;
|
||||
use omnigraph::loader::{LoadMode, load_jsonl};
|
||||
use omnigraph_compiler::ir::ParamMap;
|
||||
use omnigraph_compiler::query::ast::Literal;
|
||||
|
||||
use helpers::*;
|
||||
|
||||
// ─── Snapshot data-level isolation ──────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn snapshot_returns_stale_data_after_write() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// Snapshot BEFORE mutation
|
||||
let snap_before = snapshot_main(&db).await.unwrap();
|
||||
|
||||
// Insert a new person
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Snapshot AFTER mutation
|
||||
let snap_after = snapshot_main(&db).await.unwrap();
|
||||
|
||||
// Old snapshot should still see 4 persons
|
||||
let ds_before = snap_before.open("node:Person").await.unwrap();
|
||||
assert_eq!(ds_before.count_rows(None).await.unwrap(), 4);
|
||||
|
||||
// New snapshot should see 5 persons
|
||||
let ds_after = snap_after.open("node:Person").await.unwrap();
|
||||
assert_eq!(ds_after.count_rows(None).await.unwrap(), 5);
|
||||
|
||||
// Verify Eve is NOT in old snapshot's data
|
||||
let batches_before: Vec<arrow_array::RecordBatch> = ds_before
|
||||
.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect()
|
||||
.await
|
||||
.unwrap();
|
||||
let ids_before = collect_column_strings(&batches_before, "id");
|
||||
assert!(!ids_before.contains(&"Eve".to_string()));
|
||||
|
||||
// Verify Eve IS in new snapshot's data
|
||||
let batches_after: Vec<arrow_array::RecordBatch> = ds_after
|
||||
.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect()
|
||||
.await
|
||||
.unwrap();
|
||||
let ids_after = collect_column_strings(&batches_after, "id");
|
||||
assert!(ids_after.contains(&"Eve".to_string()));
|
||||
}
|
||||
|
||||
// ─── LoadMode::Merge ────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn load_merge_upserts_existing_and_inserts_new() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
||||
|
||||
// Load Alice(30) and Bob(25) via Overwrite
|
||||
let initial = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
|
||||
{"type": "Person", "data": {"name": "Bob", "age": 25}}"#;
|
||||
load_jsonl(&mut db, initial, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(count_rows(&db, "node:Person").await, 2);
|
||||
|
||||
// Merge: Alice updated to age=31, Charlie is new
|
||||
let merge_data = r#"{"type": "Person", "data": {"name": "Alice", "age": 31}}
|
||||
{"type": "Person", "data": {"name": "Charlie", "age": 35}}"#;
|
||||
load_jsonl(&mut db, merge_data, LoadMode::Merge)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Should have 3 persons total (not 4)
|
||||
assert_eq!(count_rows(&db, "node:Person").await, 3);
|
||||
|
||||
// Verify individual values
|
||||
let batches = read_table(&db, "node:Person").await;
|
||||
let batch = &batches[0];
|
||||
let ids = batch
|
||||
.column_by_name("id")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let ages = batch
|
||||
.column_by_name("age")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap();
|
||||
|
||||
for i in 0..batch.num_rows() {
|
||||
match ids.value(i) {
|
||||
"Alice" => assert_eq!(ages.value(i), 31, "Alice should be updated to 31"),
|
||||
"Bob" => assert_eq!(ages.value(i), 25, "Bob should be unchanged"),
|
||||
"Charlie" => assert_eq!(ages.value(i), 35, "Charlie should be inserted"),
|
||||
other => panic!("unexpected person: {}", other),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cross_type_traversal_deduplicates_duplicate_edges() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let schema = r#"
|
||||
node Person { name: String @key }
|
||||
node Company { name: String @key }
|
||||
edge WorksAt: Person -> Company
|
||||
"#;
|
||||
let data = r#"{"type":"Person","data":{"name":"Alice"}}
|
||||
{"type":"Company","data":{"name":"Acme"}}
|
||||
{"edge":"WorksAt","from":"Alice","to":"Acme"}
|
||||
{"edge":"WorksAt","from":"Alice","to":"Acme"}"#;
|
||||
let query = r#"
|
||||
query company($name: String) {
|
||||
match {
|
||||
$p: Person { name: $name }
|
||||
$p worksAt $c
|
||||
}
|
||||
return { $c.name }
|
||||
}
|
||||
"#;
|
||||
|
||||
let mut db = Omnigraph::init(uri, schema).await.unwrap();
|
||||
load_jsonl(&mut db, data, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = query_main(&mut db, query, "company", ¶ms(&[("$name", "Alice")]))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.num_rows(), 1);
|
||||
}
|
||||
|
||||
// ─── Multi-writer refresh ───────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn explicit_target_query_sees_other_writer_commits_without_refresh() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let _db = init_and_load(&dir).await;
|
||||
drop(_db);
|
||||
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
|
||||
// Two independent handles to the same repo
|
||||
let mut db1 = Omnigraph::open(uri).await.unwrap();
|
||||
let mut db2 = Omnigraph::open(uri).await.unwrap();
|
||||
|
||||
// Writer 1 inserts Eve
|
||||
mutate_main(
|
||||
&mut db1,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Explicit-target reads resolve the latest branch head and should see Eve
|
||||
let qr = query_main(
|
||||
&mut db2,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(qr.num_rows(), 1, "explicit target reads should see Eve");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn explicit_target_query_rebuilds_graph_index_after_external_edge_write() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let _db = init_and_load(&dir).await;
|
||||
drop(_db);
|
||||
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db1 = Omnigraph::open(uri).await.unwrap();
|
||||
let mut db2 = Omnigraph::open(uri).await.unwrap();
|
||||
|
||||
let warm = query_main(
|
||||
&mut db2,
|
||||
TEST_QUERIES,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(warm.num_rows(), 2);
|
||||
|
||||
mutate_main(
|
||||
&mut db1,
|
||||
MUTATION_QUERIES,
|
||||
"add_friend",
|
||||
¶ms(&[("$from", "Alice"), ("$to", "Diana")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let refreshed = query_main(
|
||||
&mut db2,
|
||||
TEST_QUERIES,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
refreshed.num_rows(),
|
||||
3,
|
||||
"explicit target reads should rebuild topology after edge change"
|
||||
);
|
||||
|
||||
let batch = refreshed.concat_batches().unwrap();
|
||||
let names = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let values: Vec<&str> = (0..names.len()).map(|i| names.value(i)).collect();
|
||||
assert!(values.contains(&"Bob"));
|
||||
assert!(values.contains(&"Diana"));
|
||||
}
|
||||
|
||||
// ─── Null handling ──────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn null_values_in_filter_and_projection() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
||||
|
||||
// Load data: Alice has age, Bob has null age, Charlie has age
|
||||
let data = r#"{"type": "Person", "data": {"name": "Alice", "age": 30}}
|
||||
{"type": "Person", "data": {"name": "Bob"}}
|
||||
{"type": "Person", "data": {"name": "Charlie", "age": 35}}"#;
|
||||
load_jsonl(&mut db, data, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Filter: age > 30 should exclude Bob (null) and Alice (30), keep Charlie (35)
|
||||
let queries = r#"
|
||||
query older_than_30() {
|
||||
match {
|
||||
$p: Person
|
||||
$p.age > 30
|
||||
}
|
||||
return { $p.name, $p.age }
|
||||
order { $p.age desc }
|
||||
}
|
||||
|
||||
query all_persons() {
|
||||
match { $p: Person }
|
||||
return { $p.name, $p.age }
|
||||
order { $p.age desc }
|
||||
}
|
||||
"#;
|
||||
|
||||
let result = query_main(&mut db, queries, "older_than_30", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.num_rows(), 1);
|
||||
let batch = &result.batches()[0];
|
||||
let names = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
assert_eq!(names.value(0), "Charlie");
|
||||
|
||||
// Projection: Bob's age should be null
|
||||
let all = query_main(&mut db, queries, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let batch = &all.batches()[0];
|
||||
let ids = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let ages = batch
|
||||
.column(1)
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap();
|
||||
|
||||
for i in 0..batch.num_rows() {
|
||||
if ids.value(i) == "Bob" {
|
||||
assert!(ages.is_null(i), "Bob's age should be null");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// ─── Graph index after node+edge insert ─────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn traversal_works_after_node_then_edge_insert() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// Warm up the graph index cache by running a traversal
|
||||
let _ = query_main(
|
||||
&mut db,
|
||||
TEST_QUERIES,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert a new node (does NOT invalidate graph index)
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Frank")], &[("$age", 40)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert an edge from Frank → Alice (DOES invalidate graph index)
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"add_friend",
|
||||
¶ms(&[("$from", "Frank"), ("$to", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Traversal should work: Frank → Alice
|
||||
let result = query_main(
|
||||
&mut db,
|
||||
TEST_QUERIES,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Frank")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.num_rows(), 1);
|
||||
let batch = result.concat_batches().unwrap();
|
||||
let names = batch
|
||||
.column(0)
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
assert_eq!(names.value(0), "Alice");
|
||||
}
|
||||
|
||||
// ─── Edge property insert ───────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn insert_edge_with_property() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// Knows has `since: Date?` property
|
||||
let queries = r#"
|
||||
query add_friend_since($from: String, $to: String, $since: Date) {
|
||||
insert Knows { from: $from, to: $to, since: $since }
|
||||
}
|
||||
"#;
|
||||
let mut p = params(&[("$from", "Diana"), ("$to", "Bob")]);
|
||||
p.insert("since".to_string(), Literal::Date("2024-06-15".to_string()));
|
||||
|
||||
let result = mutate_main(&mut db, queries, "add_friend_since", &p)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.affected_edges, 1);
|
||||
|
||||
// Verify the edge property was stored
|
||||
let batches = read_table(&db, "edge:Knows").await;
|
||||
let mut found = false;
|
||||
for batch in &batches {
|
||||
let srcs = batch
|
||||
.column_by_name("src")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let dsts = batch
|
||||
.column_by_name("dst")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let since = batch
|
||||
.column_by_name("since")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Date32Array>()
|
||||
.unwrap();
|
||||
for i in 0..batch.num_rows() {
|
||||
if srcs.value(i) == "Diana" && dsts.value(i) == "Bob" {
|
||||
assert!(!since.is_null(i), "since should not be null");
|
||||
found = true;
|
||||
}
|
||||
}
|
||||
}
|
||||
assert!(found, "should find Diana→Bob edge");
|
||||
}
|
||||
|
||||
// ─── Update / delete no-match ───────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn update_nonexistent_returns_zero_affected() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
let result = mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Nobody")], &[("$age", 99)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.affected_nodes, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn delete_nonexistent_returns_zero_affected() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
let result = mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"remove_person",
|
||||
¶ms(&[("$name", "Nobody")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(result.affected_nodes, 0);
|
||||
assert_eq!(result.affected_edges, 0);
|
||||
|
||||
// All 4 persons still intact
|
||||
assert_eq!(count_rows(&db, "node:Person").await, 4);
|
||||
}
|
||||
|
||||
// ─── Large batch load ───────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn large_batch_load_and_query() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
|
||||
let schema = r#"
|
||||
node Item {
|
||||
name: String @key
|
||||
value: I32
|
||||
}
|
||||
"#;
|
||||
let mut db = Omnigraph::init(uri, schema).await.unwrap();
|
||||
|
||||
// Generate 500 items
|
||||
let mut lines = Vec::with_capacity(500);
|
||||
for i in 0..500 {
|
||||
lines.push(format!(
|
||||
r#"{{"type": "Item", "data": {{"name": "item_{:04}", "value": {}}}}}"#,
|
||||
i, i
|
||||
));
|
||||
}
|
||||
let data = lines.join("\n");
|
||||
load_jsonl(&mut db, &data, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(count_rows(&db, "node:Item").await, 500);
|
||||
|
||||
// Query with filter — value > 490
|
||||
let queries = r#"
|
||||
query high_value() {
|
||||
match {
|
||||
$i: Item
|
||||
$i.value > 490
|
||||
}
|
||||
return { $i.name, $i.value }
|
||||
order { $i.value asc }
|
||||
}
|
||||
"#;
|
||||
let result = query_main(&mut db, queries, "high_value", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Items 491..499 = 9 items
|
||||
assert_eq!(result.num_rows(), 9);
|
||||
let batch = &result.batches()[0];
|
||||
let values = batch
|
||||
.column(1)
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap();
|
||||
assert_eq!(values.value(0), 491);
|
||||
assert_eq!(values.value(8), 499);
|
||||
}
|
||||
|
||||
// ─── Regression: public mutation on stale handle still applies to latest head ──────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn stale_handle_public_mutation_uses_latest_target_head() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let _db = init_and_load(&dir).await;
|
||||
drop(_db);
|
||||
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db1 = Omnigraph::open(uri).await.unwrap();
|
||||
let mut db2 = Omnigraph::open(uri).await.unwrap();
|
||||
|
||||
// Writer 1 inserts — advances the Person sub-table version
|
||||
mutate_main(
|
||||
&mut db1,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Writer 2 (stale) mutates through the public transactional path.
|
||||
// It should stage from the latest target head rather than replaying a stale write.
|
||||
mutate_main(
|
||||
&mut db2,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Alice")], &[("$age", 99)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let result = query_main(
|
||||
&mut db2,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.num_rows(), 1);
|
||||
assert_eq!(result.to_rust_json()[0]["p.age"], serde_json::json!(99));
|
||||
|
||||
let eve = query_main(
|
||||
&mut db2,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(eve.num_rows(), 1, "concurrent insert should be preserved");
|
||||
}
|
||||
1831
crates/omnigraph/tests/end_to_end.rs
Normal file
1831
crates/omnigraph/tests/end_to_end.rs
Normal file
File diff suppressed because it is too large
Load diff
183
crates/omnigraph/tests/export.rs
Normal file
183
crates/omnigraph/tests/export.rs
Normal file
|
|
@ -0,0 +1,183 @@
|
|||
mod helpers;
|
||||
|
||||
use arrow_array::{Array, StringArray};
|
||||
|
||||
use omnigraph::db::{Omnigraph, ReadTarget};
|
||||
use omnigraph::loader::{LoadMode, load_jsonl};
|
||||
|
||||
use helpers::*;
|
||||
|
||||
/// Mutation queries used by the export tests: one inserts a `Person`, the
/// other inserts a `Knows` edge between two persons.
const EXPORT_MUTATIONS: &str = r#"
query insert_person($name: String, $age: I32) {
    insert Person { name: $name, age: $age }
}

query add_friend($from: String, $to: String) {
    insert Knows { from: $from, to: $to }
}
"#;

/// Schema with a node type that declares no `@key` property, plus a
/// self-referencing edge type.
const NOTE_SCHEMA: &str = r#"
node Note {
    text: String
}

edge References: Note -> Note
"#;

/// JSONL payload for `NOTE_SCHEMA`: two notes and one edge, each carrying an
/// explicit `id`. One JSON object per line.
const NOTE_DATA: &str = r#"
{"type":"Note","data":{"id":"note-1","text":"Alpha"}}
{"type":"Note","data":{"id":"note-2","text":"Beta"}}
{"edge":"References","from":"note-1","to":"note-2","data":{"id":"edge-1"}}
"#;
|
||||
|
||||
#[tokio::test]
|
||||
async fn export_jsonl_round_trips_branch_snapshot() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
db.branch_create_from(ReadTarget::branch("main"), "feature")
|
||||
.await
|
||||
.unwrap();
|
||||
db.mutate(
|
||||
"feature",
|
||||
EXPORT_MUTATIONS,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 29)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.mutate(
|
||||
"feature",
|
||||
EXPORT_MUTATIONS,
|
||||
"add_friend",
|
||||
¶ms(&[("$from", "Eve"), ("$to", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let main_jsonl = db.export_jsonl("main", &[], &[]).await.unwrap();
|
||||
let feature_jsonl = db.export_jsonl("feature", &[], &[]).await.unwrap();
|
||||
|
||||
let imported_main_dir = tempfile::tempdir().unwrap();
|
||||
let imported_feature_dir = tempfile::tempdir().unwrap();
|
||||
let mut imported_main =
|
||||
Omnigraph::init(imported_main_dir.path().to_str().unwrap(), TEST_SCHEMA)
|
||||
.await
|
||||
.unwrap();
|
||||
let mut imported_feature =
|
||||
Omnigraph::init(imported_feature_dir.path().to_str().unwrap(), TEST_SCHEMA)
|
||||
.await
|
||||
.unwrap();
|
||||
load_jsonl(&mut imported_main, &main_jsonl, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
load_jsonl(&mut imported_feature, &feature_jsonl, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(count_rows(&db, "node:Person").await, 4);
|
||||
assert_eq!(count_rows_branch(&db, "feature", "node:Person").await, 5);
|
||||
assert_eq!(count_rows(&imported_main, "node:Person").await, 4);
|
||||
assert_eq!(count_rows(&imported_feature, "node:Person").await, 5);
|
||||
assert_eq!(count_rows(&imported_main, "edge:Knows").await, 3);
|
||||
assert_eq!(count_rows(&imported_feature, "edge:Knows").await, 4);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn export_jsonl_preserves_explicit_ids_for_non_key_graphs() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = Omnigraph::init(dir.path().to_str().unwrap(), NOTE_SCHEMA)
|
||||
.await
|
||||
.unwrap();
|
||||
load_jsonl(&mut db, NOTE_DATA, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let exported = db.export_jsonl("main", &[], &[]).await.unwrap();
|
||||
|
||||
let imported_dir = tempfile::tempdir().unwrap();
|
||||
let mut imported = Omnigraph::init(imported_dir.path().to_str().unwrap(), NOTE_SCHEMA)
|
||||
.await
|
||||
.unwrap();
|
||||
load_jsonl(&mut imported, &exported, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let node_batches = read_table(&imported, "node:Note").await;
|
||||
let node_ids = collect_column_strings(&node_batches, "id");
|
||||
assert_eq!(node_ids, vec!["note-1".to_string(), "note-2".to_string()]);
|
||||
|
||||
let edge_batches = read_table(&imported, "edge:References").await;
|
||||
let edge_ids = collect_column_strings(&edge_batches, "id");
|
||||
assert_eq!(edge_ids, vec!["edge-1".to_string()]);
|
||||
|
||||
let srcs = edge_batches[0]
|
||||
.column_by_name("src")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let dsts = edge_batches[0]
|
||||
.column_by_name("dst")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
assert_eq!(srcs.value(0), "note-1");
|
||||
assert_eq!(dsts.value(0), "note-2");
|
||||
}
|
||||
|
||||
// ─── Regression: export with blob columns ────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn export_jsonl_with_blob_type() {
|
||||
// Regression: export on types with blob columns failed with
|
||||
// "Schema error: Can not append column _rowaddr on schema" because
|
||||
// Lance 4's take_blobs duplicated _rowaddr on the unsorted path.
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
|
||||
const BLOB_SCHEMA: &str = r#"
|
||||
node Document {
|
||||
title: String @key
|
||||
content: Blob?
|
||||
}
|
||||
"#;
|
||||
|
||||
let mut db = Omnigraph::init(uri, BLOB_SCHEMA).await.unwrap();
|
||||
let data = concat!(
|
||||
"{\"type\": \"Document\", \"data\": {\"title\": \"readme\", \"content\": \"base64:SGVsbG8=\"}}\n",
|
||||
"{\"type\": \"Document\", \"data\": {\"title\": \"empty\"}}\n",
|
||||
);
|
||||
load_jsonl(&mut db, data, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Export should succeed
|
||||
let exported = db.export_jsonl("main", &[], &[]).await.unwrap();
|
||||
assert!(
|
||||
exported.contains("readme"),
|
||||
"export should contain readme doc"
|
||||
);
|
||||
|
||||
// Verify blob value is in the export
|
||||
assert!(
|
||||
exported.contains("base64:") || exported.contains("SGVsbG8"),
|
||||
"export should contain blob data as base64"
|
||||
);
|
||||
|
||||
// Round-trip: re-import and verify blob data survives
|
||||
let imported_dir = tempfile::tempdir().unwrap();
|
||||
let imported_uri = imported_dir.path().to_str().unwrap();
|
||||
let mut imported = Omnigraph::init(imported_uri, BLOB_SCHEMA).await.unwrap();
|
||||
load_jsonl(&mut imported, &exported, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let blob = imported
|
||||
.read_blob("Document", "readme", "content")
|
||||
.await
|
||||
.unwrap();
|
||||
let bytes = blob.read().await.unwrap();
|
||||
assert_eq!(&bytes[..], b"Hello");
|
||||
}
|
||||
47
crates/omnigraph/tests/failpoints.rs
Normal file
47
crates/omnigraph/tests/failpoints.rs
Normal file
|
|
@ -0,0 +1,47 @@
|
|||
#![cfg(feature = "failpoints")]
|
||||
|
||||
mod helpers;
|
||||
|
||||
use fail::FailScenario;
|
||||
use omnigraph::db::Omnigraph;
|
||||
use omnigraph::failpoints::ScopedFailPoint;
|
||||
|
||||
use helpers::{MUTATION_QUERIES, mixed_params};
|
||||
|
||||
#[tokio::test]
|
||||
async fn branch_create_failpoint_triggers() {
|
||||
let _scenario = FailScenario::setup();
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = Omnigraph::init(uri, helpers::TEST_SCHEMA).await.unwrap();
|
||||
let _failpoint = ScopedFailPoint::new("branch_create.after_manifest_branch_create", "return");
|
||||
|
||||
let err = db.branch_create("feature").await.unwrap_err();
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("injected failpoint triggered: branch_create.after_manifest_branch_create")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn graph_publish_failpoint_triggers_before_commit_append() {
|
||||
let _scenario = FailScenario::setup();
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = Omnigraph::init(dir.path().to_str().unwrap(), helpers::TEST_SCHEMA)
|
||||
.await
|
||||
.unwrap();
|
||||
let _failpoint = ScopedFailPoint::new("graph_publish.before_commit_append", "return");
|
||||
|
||||
let err = mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
assert!(
|
||||
err.to_string()
|
||||
.contains("injected failpoint triggered: graph_publish.before_commit_append")
|
||||
);
|
||||
}
|
||||
13
crates/omnigraph/tests/fixtures/context.jsonl
vendored
Normal file
13
crates/omnigraph/tests/fixtures/context.jsonl
vendored
Normal file
|
|
@ -0,0 +1,13 @@
|
|||
{"type": "Actor", "data": {"slug": "aaron", "name": "Aaron"}}
|
||||
{"type": "Actor", "data": {"slug": "bruno", "name": "Bruno"}}
|
||||
{"type": "Actor", "data": {"slug": "jorge", "name": "Jorge"}}
|
||||
{"type": "Actor", "data": {"slug": "muneeb", "name": "Muneeb"}}
|
||||
{"type": "Actor", "data": {"slug": "ragnor", "name": "Ragnor"}}
|
||||
{"type": "Actor", "data": {"slug": "andrew", "name": "Andrew"}}
|
||||
{"type": "Signal", "data": {"slug": "zylon-private-ai-platform", "title": "Zylon.ai positions as complete on-premise enterprise AI platform for regulated industries", "body": "Zylon.ai positions itself as a complete, fully private (100% on-premise) enterprise AI platform built explicitly for regulated industries, emphasizing air-gapped deployability, data sovereignty (no external cloud dependency), and predictable fixed-cost economics (no per-token pricing).", "category": "competitor", "strength": "strong", "observed_at": "2026-03-27", "source": "https://www.zylon.ai/"}}
|
||||
{"type": "Decision", "data": {"slug": "create-360-ai-infra", "title": "Create 360 AI Infra offering", "body": "Build a comprehensive 360-degree AI infrastructure offering in response to competitors like Zylon positioning complete on-premise AI platforms for regulated industries.", "status": "proposed", "urgency": "high", "decided_at": "2026-03-27"}}
|
||||
{"type": "Trace", "data": {"slug": "jorge-spots-zylon", "title": "Jorge spots Zylon.ai competitor signal", "body": "Jorge identified Zylon.ai as a new competitor positioning a fully private enterprise AI platform targeting regulated industries with air-gapped deployment and fixed-cost pricing.", "kind": "note", "recorded_at": "2026-03-27", "source": "https://www.zylon.ai/"}}
|
||||
{"edge": "OwnedBy", "from": "create-360-ai-infra", "to": "andrew"}
|
||||
{"edge": "RecordedBy", "from": "jorge-spots-zylon", "to": "jorge"}
|
||||
{"edge": "Triggered", "from": "zylon-private-ai-platform", "to": "create-360-ai-infra"}
|
||||
{"edge": "Supports", "from": "jorge-spots-zylon", "to": "create-360-ai-infra"}
|
||||
78
crates/omnigraph/tests/fixtures/context.pg
vendored
Normal file
78
crates/omnigraph/tests/fixtures/context.pg
vendored
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
// Context graph: decisions, the people behind them,
|
||||
// the evidence trail, and market signals that inform them.
|
||||
|
||||
// ── Nodes ────────────────────────────────────────────
|
||||
|
||||
node Actor {
|
||||
slug: String @key
|
||||
name: String
|
||||
email: String? @unique
|
||||
}
|
||||
|
||||
node Decision {
|
||||
slug: String @key
|
||||
title: String @index
|
||||
body: String?
|
||||
status: enum(proposed, accepted, rejected, superseded)
|
||||
urgency: enum(low, normal, high, critical)
|
||||
decided_at: Date?
|
||||
}
|
||||
|
||||
node Trace {
|
||||
slug: String @key
|
||||
title: String @index
|
||||
body: String?
|
||||
kind: enum(note, discussion, experiment, review, meeting, document)
|
||||
recorded_at: Date
|
||||
source: String?
|
||||
}
|
||||
|
||||
node Signal {
|
||||
slug: String @key
|
||||
title: String @index
|
||||
body: String?
|
||||
category: enum(competitor, market, regulatory, technology, customer)
|
||||
strength: enum(strong, moderate, weak)
|
||||
observed_at: Date
|
||||
source: String?
|
||||
}
|
||||
|
||||
node Artifact {
|
||||
slug: String @key
|
||||
title: String @index
|
||||
kind: enum(doc, presentation, proposal, spec, report, memo)
|
||||
url: String?
|
||||
created_at: Date
|
||||
}
|
||||
|
||||
// ── Ownership / participation ────────────────────────
|
||||
|
||||
edge OwnedBy: Decision -> Actor @card(1..1)
|
||||
|
||||
edge ParticipatedIn: Actor -> Decision
|
||||
|
||||
edge RecordedBy: Trace -> Actor @card(1..1)
|
||||
|
||||
edge AuthoredBy: Artifact -> Actor @card(1..1)
|
||||
|
||||
// ── Evidence trail ───────────────────────────────────
|
||||
|
||||
edge Supports: Trace -> Decision
|
||||
|
||||
edge Attached: Artifact -> Decision
|
||||
|
||||
edge CitedIn: Artifact -> Trace
|
||||
|
||||
// ── Signal linkage ───────────────────────────────────
|
||||
|
||||
edge Triggered: Signal -> Decision
|
||||
|
||||
edge Correlates: Signal -> Signal {
|
||||
@unique(src, dst)
|
||||
}
|
||||
|
||||
// ── Decision lineage ─────────────────────────────────
|
||||
|
||||
edge Supersedes: Decision -> Decision {
|
||||
@unique(src, dst)
|
||||
}
|
||||
48
crates/omnigraph/tests/fixtures/revops_large_signal.md
vendored
Normal file
48
crates/omnigraph/tests/fixtures/revops_large_signal.md
vendored
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
# Enterprise Procurement Risk Memo
|
||||
|
||||
## Situation
|
||||
The buyer entered procurement for annual renewal and asked for a bundled proposal that combines platform licensing, implementation support, and SLA uplift.
|
||||
Legal requested two rounds of redlines and now requires explicit language for data residency, deletion timelines, and subprocessors.
|
||||
Security asked for the full questionnaire, pen-test summary, SOC evidence package, and a named escalation owner for incident response coordination.
|
||||
Finance requested price hold terms through quarter close and requires a clean net amount with no conditional side letters.
|
||||
|
||||
## Current Friction
|
||||
The commercial owner reports that each team is operating on a different timeline.
|
||||
Procurement prefers a single consolidated response packet, but legal and security are still updating separate drafts.
|
||||
The buyer champion is supportive but cannot route final approval until the redline and security sections are complete.
|
||||
Two approvers are out next week, which introduces a calendar risk for final signoff.
|
||||
|
||||
## Evidence From Recent Calls
|
||||
- Buyer said the risk is not product fit, it is internal process load.
|
||||
- Procurement requested one owner for all responses to avoid thread drift.
|
||||
- Legal asked for an explicit breach notification interval in the MSA.
|
||||
- Security flagged third-party dependency disclosure as incomplete.
|
||||
- Finance asked for forecast certainty before they release PO authority.
|
||||
|
||||
## Operational Notes
|
||||
1. The account team should treat this as a coordination problem, not a persuasion problem.
|
||||
2. Every open item needs owner, due date, and blocking dependency.
|
||||
3. Replies should be centralized in one tracker to avoid inconsistent statements.
|
||||
4. Escalation should happen early when legal language depends on security attestations.
|
||||
5. The champion should get a concise status summary after each workday.
|
||||
|
||||
## Risk Register
|
||||
- **Timeline risk:** medium-high due to calendar compression and approver availability.
|
||||
- **Compliance risk:** medium due to unresolved security questionnaire fields.
|
||||
- **Commercial risk:** medium because procurement is requesting fixed pricing through quarter end.
|
||||
- **Execution risk:** high if response ownership remains fragmented across teams.
|
||||
|
||||
## Recommended Plan
|
||||
Create a single response packet and assign one coordinator.
|
||||
Pre-fill all known legal and security answers from existing templates.
|
||||
Schedule a thirty-minute cross-functional triage with legal, security, and sales operations.
|
||||
Lock a daily cutoff time for updates and send one canonical status note to stakeholders.
|
||||
Escalate unresolved blockers to leadership forty-eight hours before target sign date.
|
||||
|
||||
## Success Criteria
|
||||
- Security questionnaire submitted with no unresolved critical fields.
|
||||
- Redline package accepted or narrowed to non-blocking items.
|
||||
- Pricing terms approved by finance for the requested window.
|
||||
- Purchase order process initiated before the internal close date.
|
||||
- Owner confirms all blocker tickets are either resolved or explicitly waived.
|
||||
|
||||
44
crates/omnigraph/tests/fixtures/search.gq
vendored
Normal file
44
crates/omnigraph/tests/fixtures/search.gq
vendored
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
query text_search($q: String) {
|
||||
match {
|
||||
$d: Doc
|
||||
search($d.title, $q)
|
||||
}
|
||||
return { $d.slug, $d.title }
|
||||
}
|
||||
|
||||
query fuzzy_search($q: String) {
|
||||
match {
|
||||
$d: Doc
|
||||
fuzzy($d.title, $q, 2)
|
||||
}
|
||||
return { $d.slug, $d.title }
|
||||
}
|
||||
|
||||
query phrase_search($q: String) {
|
||||
match {
|
||||
$d: Doc
|
||||
match_text($d.body, $q)
|
||||
}
|
||||
return { $d.slug, $d.title }
|
||||
}
|
||||
|
||||
query vector_search($q: Vector(4)) {
|
||||
match { $d: Doc }
|
||||
return { $d.slug, $d.title }
|
||||
order { nearest($d.embedding, $q) }
|
||||
limit 3
|
||||
}
|
||||
|
||||
query bm25_search($q: String) {
|
||||
match { $d: Doc }
|
||||
return { $d.slug, $d.title }
|
||||
order { bm25($d.title, $q) }
|
||||
limit 3
|
||||
}
|
||||
|
||||
query hybrid_search($vq: Vector(4), $tq: String) {
|
||||
match { $d: Doc }
|
||||
return { $d.slug, $d.title }
|
||||
order { rrf(nearest($d.embedding, $vq), bm25($d.title, $tq)) }
|
||||
limit 3
|
||||
}
|
||||
5
crates/omnigraph/tests/fixtures/search.jsonl
vendored
Normal file
5
crates/omnigraph/tests/fixtures/search.jsonl
vendored
Normal file
|
|
@ -0,0 +1,5 @@
|
|||
{"type": "Doc", "data": {"slug": "ml-intro", "title": "Introduction to Machine Learning", "body": "Machine learning is a subset of artificial intelligence that focuses on algorithms", "embedding": [0.1, 0.2, 0.3, 0.4]}}
|
||||
{"type": "Doc", "data": {"slug": "dl-basics", "title": "Deep Learning Basics", "body": "Deep learning uses neural networks with many layers to learn representations", "embedding": [0.5, 0.6, 0.7, 0.8]}}
|
||||
{"type": "Doc", "data": {"slug": "nlp-guide", "title": "Natural Language Processing Guide", "body": "NLP applies machine learning to understand and generate human language", "embedding": [0.2, 0.3, 0.4, 0.5]}}
|
||||
{"type": "Doc", "data": {"slug": "cv-overview", "title": "Computer Vision Overview", "body": "Computer vision enables machines to interpret visual information from images", "embedding": [0.8, 0.7, 0.6, 0.5]}}
|
||||
{"type": "Doc", "data": {"slug": "rl-intro", "title": "Reinforcement Learning Introduction", "body": "Reinforcement learning trains agents through reward and punishment signals", "embedding": [0.3, 0.4, 0.5, 0.6]}}
|
||||
6
crates/omnigraph/tests/fixtures/search.pg
vendored
Normal file
6
crates/omnigraph/tests/fixtures/search.pg
vendored
Normal file
|
|
@ -0,0 +1,6 @@
|
|||
node Doc {
|
||||
slug: String @key
|
||||
title: String @index
|
||||
body: String @index
|
||||
embedding: Vector(4)
|
||||
}
|
||||
46
crates/omnigraph/tests/fixtures/signals.jsonl
vendored
Normal file
46
crates/omnigraph/tests/fixtures/signals.jsonl
vendored
Normal file
|
|
@ -0,0 +1,46 @@
|
|||
{"type": "Company", "data": {"slug": "aws", "name": "AWS", "sector": "hyperscaler"}}
|
||||
{"type": "Company", "data": {"slug": "cerebras", "name": "Cerebras", "sector": "chipmaker"}}
|
||||
{"type": "Company", "data": {"slug": "vast", "name": "VAST Data", "sector": "startup"}}
|
||||
{"type": "Company", "data": {"slug": "oracle", "name": "Oracle", "sector": "hyperscaler"}}
|
||||
{"type": "Company", "data": {"slug": "benchmark", "name": "Benchmark", "sector": "investor"}}
|
||||
{"type": "Company", "data": {"slug": "xai", "name": "xAI", "sector": "lab"}}
|
||||
{"type": "Company", "data": {"slug": "openai", "name": "OpenAI", "sector": "lab"}}
|
||||
{"type": "Company", "data": {"slug": "anthropic", "name": "Anthropic", "sector": "lab"}}
|
||||
{"type": "Company", "data": {"slug": "nvidia", "name": "NVIDIA", "sector": "chipmaker"}}
|
||||
{"type": "Tech", "data": {"slug": "cs3", "name": "CS-3", "kind": "infra", "tier": "growth"}}
|
||||
{"type": "Tech", "data": {"slug": "trainium", "name": "Trainium", "kind": "infra", "tier": "growth"}}
|
||||
{"type": "Tech", "data": {"slug": "grok", "name": "Grok", "kind": "model", "tier": "emerging"}}
|
||||
{"type": "Tech", "data": {"slug": "vast-ai-os", "name": "VAST AI OS", "kind": "platform", "tier": "growth"}}
|
||||
{"type": "Signal", "data": {"slug": "aws-cerebras-inference", "title": "AWS and Cerebras collaborate on disaggregated inference: Trainium for prefill, CS-3 for decode, 5x capacity", "source": "press-release", "strength": "strong", "observed": "2026-03-13"}}
|
||||
{"type": "Signal", "data": {"slug": "vast-1b-raise", "title": "VAST Data raises $1B at $30B, unveils AI OS bundling storage, compute, and agent runtimes into one stack", "source": "funding-round", "strength": "strong", "observed": "2026-03-12"}}
|
||||
{"type": "Signal", "data": {"slug": "oracle-cerebras-mention", "title": "Oracle names Cerebras alongside NVIDIA and AMD as enterprise AI chip option", "source": "earnings-call", "strength": "moderate", "observed": "2026-03-10"}}
|
||||
{"type": "Signal", "data": {"slug": "cerebras-23b-round", "title": "Cerebras valued at $23B after $1B round; Benchmark raises $225M in special vehicles to double down", "source": "funding-round", "strength": "strong", "observed": "2026-02-04"}}
|
||||
{"type": "Signal", "data": {"slug": "xai-field-engineers", "title": "xAI deploys engineers on-site at enterprise clients to win deals from OpenAI and Anthropic", "source": "reporting", "strength": "strong", "observed": "2026-03-20"}}
|
||||
{"type": "Pattern", "data": {"slug": "rise-of-fde", "name": "Rise of FDE", "category": "expansion"}}
|
||||
{"type": "Pattern", "data": {"slug": "alt-chip-breakout", "name": "Alt-chip breakout", "category": "adoption"}}
|
||||
{"type": "Pattern", "data": {"slug": "stack-collapse", "name": "Stack collapse", "category": "convergence"}}
|
||||
{"type": "Pattern", "data": {"slug": "inference-specialization", "name": "Inference specialization", "category": "convergence"}}
|
||||
{"edge": "Builds", "from": "cerebras", "to": "cs3"}
|
||||
{"edge": "Builds", "from": "aws", "to": "trainium"}
|
||||
{"edge": "Builds", "from": "xai", "to": "grok"}
|
||||
{"edge": "Builds", "from": "vast", "to": "vast-ai-os"}
|
||||
{"edge": "FundedBy", "from": "cerebras", "to": "benchmark", "data": {"amount": "$225M"}}
|
||||
{"edge": "PartnersWith", "from": "aws", "to": "cerebras"}
|
||||
{"edge": "PartnersWith", "from": "cerebras", "to": "openai"}
|
||||
{"edge": "Mentions", "from": "aws-cerebras-inference", "to": "cs3"}
|
||||
{"edge": "Mentions", "from": "aws-cerebras-inference", "to": "trainium"}
|
||||
{"edge": "Mentions", "from": "vast-1b-raise", "to": "vast-ai-os"}
|
||||
{"edge": "Mentions", "from": "cerebras-23b-round", "to": "cs3"}
|
||||
{"edge": "Mentions", "from": "xai-field-engineers", "to": "grok"}
|
||||
{"edge": "Indicates", "from": "aws-cerebras-inference", "to": "alt-chip-breakout"}
|
||||
{"edge": "Indicates", "from": "aws-cerebras-inference", "to": "inference-specialization"}
|
||||
{"edge": "Indicates", "from": "vast-1b-raise", "to": "stack-collapse"}
|
||||
{"edge": "Indicates", "from": "oracle-cerebras-mention", "to": "alt-chip-breakout"}
|
||||
{"edge": "Indicates", "from": "cerebras-23b-round", "to": "alt-chip-breakout"}
|
||||
{"edge": "Indicates", "from": "xai-field-engineers", "to": "rise-of-fde"}
|
||||
{"edge": "Involves", "from": "rise-of-fde", "to": "grok"}
|
||||
{"edge": "Involves", "from": "alt-chip-breakout", "to": "cs3"}
|
||||
{"edge": "Involves", "from": "alt-chip-breakout", "to": "trainium"}
|
||||
{"edge": "Involves", "from": "stack-collapse", "to": "vast-ai-os"}
|
||||
{"edge": "Involves", "from": "inference-specialization", "to": "cs3"}
|
||||
{"edge": "Involves", "from": "inference-specialization", "to": "trainium"}
|
||||
44
crates/omnigraph/tests/fixtures/signals.pg
vendored
Normal file
44
crates/omnigraph/tests/fixtures/signals.pg
vendored
Normal file
|
|
@ -0,0 +1,44 @@
|
|||
// Industry signals around AI models and major players.
|
||||
// Branch use case: main tracks confirmed signals, branches
|
||||
// model speculative or emerging interpretations.
|
||||
|
||||
// Signal: one observed item with a title, a source and a three-level strength.
// `slug` carries the `@key` annotation; `observed` is typed `Date?`
// (the `?` presumably marks the field as optional — confirm against the schema grammar).
node Signal {
|
||||
slug: String @key
|
||||
title: String
|
||||
source: String
|
||||
strength: enum(strong, moderate, weak)
|
||||
observed: Date?
|
||||
}
|
||||
|
||||
// Pattern: a named trend, classified by one of five `category` values.
node Pattern {
|
||||
slug: String @key
|
||||
name: String
|
||||
category: enum(adoption, churn, expansion, contraction, convergence)
|
||||
}
|
||||
|
||||
// Tech: a technology with a `kind` and a lifecycle `tier`.
node Tech {
|
||||
slug: String @key
|
||||
name: String
|
||||
kind: enum(model, platform, infra, framework, tool)
|
||||
tier: enum(emerging, growth, mature, declining)
|
||||
}
|
||||
|
||||
// Company: an organisation tagged with a `sector`.
node Company {
|
||||
slug: String @key
|
||||
name: String
|
||||
sector: enum(lab, hyperscaler, chipmaker, investor, startup)
|
||||
}
|
||||
|
||||
// Signal -> Pattern: the signal is evidence for the pattern.
edge Indicates: Signal -> Pattern
|
||||
|
||||
// Signal -> Tech: the signal refers to the technology.
edge Mentions: Signal -> Tech
|
||||
|
||||
// Pattern -> Tech: the pattern concerns the technology.
edge Involves: Pattern -> Tech
|
||||
|
||||
// Company -> Tech: the company produces the technology.
edge Builds: Company -> Tech
|
||||
|
||||
// Company -> Company funding link; the only edge type here that declares a property.
edge FundedBy: Company -> Company {
|
||||
amount: String?
|
||||
}
|
||||
|
||||
// Company -> Company partnership link (declared as a directed edge).
edge PartnersWith: Company -> Company
|
||||
78
crates/omnigraph/tests/fixtures/test.gq
vendored
Normal file
78
crates/omnigraph/tests/fixtures/test.gq
vendored
Normal file
|
|
@ -0,0 +1,78 @@
|
|||
// Basic: find person by name.
// Binds $p to the Person whose `name` equals the $name parameter and
// returns that person's name and age.
|
||||
query get_person($name: String) {
|
||||
match {
|
||||
$p: Person { name: $name }
|
||||
}
|
||||
return { $p.name, $p.age }
|
||||
}
|
||||
|
||||
// Filter by age: persons whose age is strictly greater than 30,
// returned as (name, age) and ordered by age, highest first.
|
||||
query adults() {
|
||||
match {
|
||||
$p: Person
|
||||
$p.age > 30
|
||||
}
|
||||
return { $p.name, $p.age }
|
||||
order { $p.age desc }
|
||||
}
|
||||
|
||||
// One hop traversal: start at the Person named $name, follow `knows`
// once, and return the name and age of each node reached ($f).
|
||||
query friends_of($name: String) {
|
||||
match {
|
||||
$p: Person { name: $name }
|
||||
$p knows $f
|
||||
}
|
||||
return { $f.name, $f.age }
|
||||
}
|
||||
|
||||
// Reverse traversal: who works at a company.
// The Company ($c) is the bound end of the pattern; $p ranges over the
// sources of `worksAt` edges that point at it.
|
||||
query employees_of($company: String) {
|
||||
match {
|
||||
$c: Company { name: $company }
|
||||
$p worksAt $c
|
||||
}
|
||||
return { $p.name }
|
||||
}
|
||||
|
||||
// Two hop: friends of friends.
// Chains two `knows` steps through the intermediate binding $mid and
// returns only the name of the far end ($fof).
|
||||
query friends_of_friends($name: String) {
|
||||
match {
|
||||
$p: Person { name: $name }
|
||||
$p knows $mid
|
||||
$mid knows $fof
|
||||
}
|
||||
return { $fof.name }
|
||||
}
|
||||
|
||||
// Negation: people who don't work anywhere.
// The `not { ... }` block excludes every $p that has a `worksAt` edge;
// `$_` is used for the edge target, which is never referenced again.
|
||||
query unemployed() {
|
||||
match {
|
||||
$p: Person
|
||||
not { $p worksAt $_ }
|
||||
}
|
||||
return { $p.name }
|
||||
}
|
||||
|
||||
// Aggregation: friend count.
// Returns each person's name with count($f) aliased as `friends`,
// ordered by that alias descending and capped at 20 rows.
|
||||
query friend_counts() {
|
||||
match {
|
||||
$p: Person
|
||||
$p knows $f
|
||||
}
|
||||
return {
|
||||
$p.name
|
||||
count($f) as friends
|
||||
}
|
||||
order { friends desc }
|
||||
limit 20
|
||||
}
|
||||
|
||||
// Order and limit: all persons as (name, age), ordered by age
// descending, keeping only the first 2 rows.
|
||||
query top_by_age() {
|
||||
match {
|
||||
$p: Person
|
||||
}
|
||||
return { $p.name, $p.age }
|
||||
order { $p.age desc }
|
||||
limit 2
|
||||
}
|
||||
11
crates/omnigraph/tests/fixtures/test.jsonl
vendored
Normal file
11
crates/omnigraph/tests/fixtures/test.jsonl
vendored
Normal file
|
|
@ -0,0 +1,11 @@
|
|||
{"type": "Person", "data": {"name": "Alice", "age": 30}}
|
||||
{"type": "Person", "data": {"name": "Bob", "age": 25}}
|
||||
{"type": "Person", "data": {"name": "Charlie", "age": 35}}
|
||||
{"type": "Person", "data": {"name": "Diana", "age": 28}}
|
||||
{"type": "Company", "data": {"name": "Acme"}}
|
||||
{"type": "Company", "data": {"name": "Globex"}}
|
||||
{"edge": "Knows", "from": "Alice", "to": "Bob"}
|
||||
{"edge": "Knows", "from": "Alice", "to": "Charlie"}
|
||||
{"edge": "Knows", "from": "Bob", "to": "Diana"}
|
||||
{"edge": "WorksAt", "from": "Alice", "to": "Acme"}
|
||||
{"edge": "WorksAt", "from": "Bob", "to": "Globex"}
|
||||
14
crates/omnigraph/tests/fixtures/test.pg
vendored
Normal file
14
crates/omnigraph/tests/fixtures/test.pg
vendored
Normal file
|
|
@ -0,0 +1,14 @@
|
|||
// Person node: `name` carries the `@key` annotation; `age` is typed `I32?`
// (the `?` presumably marks the field as optional — confirm against the schema grammar).
node Person {
|
||||
name: String @key
|
||||
age: I32?
|
||||
}
|
||||
|
||||
// Company node: a single `name` field annotated with `@key`.
node Company {
|
||||
name: String @key
|
||||
}
|
||||
|
||||
// Person -> Person edge with one declared property, `since`.
edge Knows: Person -> Person {
|
||||
since: Date?
|
||||
}
|
||||
|
||||
// Person -> Company edge with no declared properties.
edge WorksAt: Person -> Company
|
||||
256
crates/omnigraph/tests/helpers/mod.rs
Normal file
256
crates/omnigraph/tests/helpers/mod.rs
Normal file
|
|
@ -0,0 +1,256 @@
|
|||
#![allow(dead_code)]
|
||||
|
||||
use arrow_array::{Array, RecordBatch, StringArray};
|
||||
use futures::TryStreamExt;
|
||||
|
||||
use omnigraph::changes::{ChangeFilter, ChangeSet};
|
||||
use omnigraph::db::{Omnigraph, ReadTarget, Snapshot, SnapshotId};
|
||||
use omnigraph::error::Result;
|
||||
use omnigraph::loader::{LoadMode, load_jsonl};
|
||||
use omnigraph_compiler::ir::ParamMap;
|
||||
use omnigraph_compiler::query::ast::Literal;
|
||||
use omnigraph_compiler::result::{MutationResult, QueryResult};
|
||||
|
||||
/// Schema source embedded at compile time from `fixtures/test.pg`.
pub const TEST_SCHEMA: &str = include_str!("../fixtures/test.pg");
|
||||
/// JSONL fixture rows embedded at compile time from `fixtures/test.jsonl`.
pub const TEST_DATA: &str = include_str!("../fixtures/test.jsonl");
|
||||
/// Query source embedded at compile time from `fixtures/test.gq`.
pub const TEST_QUERIES: &str = include_str!("../fixtures/test.gq");
|
||||
|
||||
/// Inline query source holding the mutation queries shared by the tests:
/// `insert_person`, `add_friend`, `set_age`, `remove_person` and
/// `remove_friendship`. Pass one of those names as the query name.
pub const MUTATION_QUERIES: &str = r#"
|
||||
query insert_person($name: String, $age: I32) {
|
||||
insert Person { name: $name, age: $age }
|
||||
}
|
||||
|
||||
query add_friend($from: String, $to: String) {
|
||||
insert Knows { from: $from, to: $to }
|
||||
}
|
||||
|
||||
query set_age($name: String, $age: I32) {
|
||||
update Person set { age: $age } where name = $name
|
||||
}
|
||||
|
||||
query remove_person($name: String) {
|
||||
delete Person where name = $name
|
||||
}
|
||||
|
||||
query remove_friendship($from: String) {
|
||||
delete Knows where from = $from
|
||||
}
|
||||
"#;
|
||||
|
||||
/// Init a repo and load the standard test data.
|
||||
pub async fn init_and_load(dir: &tempfile::TempDir) -> Omnigraph {
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
||||
load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
db
|
||||
}
|
||||
|
||||
/// Read all rows from a sub-table by table_key.
|
||||
pub async fn read_table(db: &Omnigraph, table_key: &str) -> Vec<RecordBatch> {
|
||||
let snap = snapshot_main(db).await.unwrap();
|
||||
let ds = snap.open(table_key).await.unwrap();
|
||||
ds.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Read all rows from a branch-local sub-table by table_key.
|
||||
pub async fn read_table_branch(db: &Omnigraph, branch: &str, table_key: &str) -> Vec<RecordBatch> {
|
||||
let snap = snapshot_branch(db, branch).await.unwrap();
|
||||
let ds = snap.open(table_key).await.unwrap();
|
||||
ds.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect()
|
||||
.await
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
/// Count rows in a sub-table.
|
||||
pub async fn count_rows(db: &Omnigraph, table_key: &str) -> usize {
|
||||
let snap = snapshot_main(db).await.unwrap();
|
||||
let ds = snap.open(table_key).await.unwrap();
|
||||
ds.count_rows(None).await.unwrap()
|
||||
}
|
||||
|
||||
/// Count rows in a branch-local sub-table.
|
||||
pub async fn count_rows_branch(db: &Omnigraph, branch: &str, table_key: &str) -> usize {
|
||||
let snap = snapshot_branch(db, branch).await.unwrap();
|
||||
let ds = snap.open(table_key).await.unwrap();
|
||||
ds.count_rows(None).await.unwrap()
|
||||
}
|
||||
|
||||
/// Collect all string values from a named column across batches.
|
||||
pub fn collect_column_strings(batches: &[RecordBatch], col: &str) -> Vec<String> {
|
||||
let mut out = Vec::new();
|
||||
for batch in batches {
|
||||
let arr = batch
|
||||
.column_by_name(col)
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
for i in 0..arr.len() {
|
||||
if !arr.is_null(i) {
|
||||
out.push(arr.value(i).to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
out
|
||||
}
|
||||
|
||||
pub async fn query_main(
|
||||
db: &mut Omnigraph,
|
||||
query_source: &str,
|
||||
query_name: &str,
|
||||
params: &ParamMap,
|
||||
) -> Result<QueryResult> {
|
||||
db.query(ReadTarget::branch("main"), query_source, query_name, params)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn query_branch(
|
||||
db: &mut Omnigraph,
|
||||
branch: &str,
|
||||
query_source: &str,
|
||||
query_name: &str,
|
||||
params: &ParamMap,
|
||||
) -> Result<QueryResult> {
|
||||
db.query(ReadTarget::branch(branch), query_source, query_name, params)
|
||||
.await
|
||||
}
|
||||
|
||||
pub async fn mutate_main(
|
||||
db: &mut Omnigraph,
|
||||
query_source: &str,
|
||||
query_name: &str,
|
||||
params: &ParamMap,
|
||||
) -> Result<MutationResult> {
|
||||
db.mutate("main", query_source, query_name, params).await
|
||||
}
|
||||
|
||||
/// Run the mutation `query_name` from `query_source` on the branch `branch`,
/// forwarding all arguments to `Omnigraph::mutate` and returning its result.
pub async fn mutate_branch(
|
||||
db: &mut Omnigraph,
|
||||
branch: &str,
|
||||
query_source: &str,
|
||||
query_name: &str,
|
||||
params: &ParamMap,
|
||||
) -> Result<MutationResult> {
|
||||
db.mutate(branch, query_source, query_name, params).await
|
||||
}
|
||||
|
||||
pub async fn snapshot_main(db: &Omnigraph) -> Result<Snapshot> {
|
||||
db.snapshot_of(ReadTarget::branch("main")).await
|
||||
}
|
||||
|
||||
pub async fn snapshot_branch(db: &Omnigraph, branch: &str) -> Result<Snapshot> {
|
||||
db.snapshot_of(ReadTarget::branch(branch)).await
|
||||
}
|
||||
|
||||
pub async fn version_main(db: &Omnigraph) -> Result<u64> {
|
||||
db.version_of(ReadTarget::branch("main")).await
|
||||
}
|
||||
|
||||
pub async fn version_branch(db: &Omnigraph, branch: &str) -> Result<u64> {
|
||||
db.version_of(ReadTarget::branch(branch)).await
|
||||
}
|
||||
|
||||
pub async fn sync_main(db: &mut Omnigraph) -> Result<()> {
|
||||
db.sync_branch("main").await
|
||||
}
|
||||
|
||||
/// Call `Omnigraph::sync_branch` for the branch named `branch` and return its result.
pub async fn sync_named_branch(db: &mut Omnigraph, branch: &str) -> Result<()> {
|
||||
db.sync_branch(branch).await
|
||||
}
|
||||
|
||||
/// Resolve the [`SnapshotId`] for `branch` via `Omnigraph::resolve_snapshot`.
pub async fn snapshot_id(db: &Omnigraph, branch: &str) -> Result<SnapshotId> {
|
||||
db.resolve_snapshot(branch).await
|
||||
}
|
||||
|
||||
pub async fn diff_since_branch(
|
||||
db: &Omnigraph,
|
||||
branch: &str,
|
||||
from_snapshot: SnapshotId,
|
||||
filter: &ChangeFilter,
|
||||
) -> Result<ChangeSet> {
|
||||
db.diff_between(
|
||||
ReadTarget::Snapshot(from_snapshot),
|
||||
ReadTarget::branch(branch),
|
||||
filter,
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
/// Build a ParamMap from string key-value pairs.
|
||||
pub fn params(pairs: &[(&str, &str)]) -> ParamMap {
|
||||
pairs
|
||||
.iter()
|
||||
.map(|(k, v)| {
|
||||
let key = k.strip_prefix('$').unwrap_or(k);
|
||||
(key.to_string(), Literal::String(v.to_string()))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build a ParamMap from integer key-value pairs.
|
||||
pub fn int_params(pairs: &[(&str, i64)]) -> ParamMap {
|
||||
pairs
|
||||
.iter()
|
||||
.map(|(k, v)| {
|
||||
let key = k.strip_prefix('$').unwrap_or(k);
|
||||
(key.to_string(), Literal::Integer(*v))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
/// Build a ParamMap from mixed string + integer pairs.
|
||||
pub fn mixed_params(str_pairs: &[(&str, &str)], int_pairs: &[(&str, i64)]) -> ParamMap {
|
||||
let mut map = params(str_pairs);
|
||||
for (k, v) in int_pairs {
|
||||
let key = k.strip_prefix('$').unwrap_or(k);
|
||||
map.insert(key.to_string(), Literal::Integer(*v));
|
||||
}
|
||||
map
|
||||
}
|
||||
|
||||
/// Build a ParamMap with a single vector parameter.
|
||||
pub fn vector_param(name: &str, values: &[f32]) -> ParamMap {
|
||||
let key = name.strip_prefix('$').unwrap_or(name).to_string();
|
||||
let lit = Literal::List(values.iter().map(|v| Literal::Float(*v as f64)).collect());
|
||||
let mut map = ParamMap::new();
|
||||
map.insert(key, lit);
|
||||
map
|
||||
}
|
||||
|
||||
/// Build a ParamMap with a vector param and a string param.
|
||||
pub fn vector_and_string_params(
|
||||
vec_name: &str,
|
||||
vec_values: &[f32],
|
||||
str_name: &str,
|
||||
str_value: &str,
|
||||
) -> ParamMap {
|
||||
let mut map = vector_param(vec_name, vec_values);
|
||||
let key = str_name.strip_prefix('$').unwrap_or(str_name).to_string();
|
||||
map.insert(key, Literal::String(str_value.to_string()));
|
||||
map
|
||||
}
|
||||
|
||||
/// Build a unique `s3://bucket/prefix/suite/<nanos>` repo URI for an S3
/// integration test.
///
/// Reads `OMNIGRAPH_S3_TEST_BUCKET` for the bucket and
/// `OMNIGRAPH_S3_TEST_PREFIX` for the key prefix, falling back to
/// `"omnigraph-itests"` when the prefix is unset or blank.
///
/// Returns `None` when the bucket variable is unset or blank, or when the
/// system clock reads earlier than the Unix epoch.
///
/// Both values are trimmed of surrounding whitespace, and the prefix is also
/// trimmed of leading and trailing `/`. This keeps a blank bucket from
/// producing `s3:///...` and a sloppy prefix from adding empty path segments.
pub fn s3_test_repo_uri(suite: &str) -> Option<String> {
    let bucket = std::env::var("OMNIGRAPH_S3_TEST_BUCKET").ok()?;
    let bucket = bucket.trim();
    if bucket.is_empty() {
        return None;
    }
    let prefix = std::env::var("OMNIGRAPH_S3_TEST_PREFIX")
        .ok()
        .map(|value| value.trim().trim_matches('/').to_string())
        .filter(|value| !value.is_empty())
        .unwrap_or_else(|| "omnigraph-itests".to_string());
    // Nanoseconds since the epoch keep repeated runs from sharing a location.
    let unique = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .ok()?
        .as_nanos();
    Some(format!("s3://{}/{}/{}/{}", bucket, prefix, suite, unique))
}
|
||||
268
crates/omnigraph/tests/lance_version_columns.rs
Normal file
268
crates/omnigraph/tests/lance_version_columns.rs
Normal file
|
|
@ -0,0 +1,268 @@
|
|||
//! Investigation test: understand how Lance stamps `_row_created_at_version` and
|
||||
//! `_row_last_updated_at_version` for different write modes (append, merge_insert new,
|
||||
//! merge_insert update).
|
||||
use std::sync::Arc;
|
||||
|
||||
use arrow_array::{Array, Int32Array, RecordBatch, RecordBatchIterator, StringArray, UInt64Array};
|
||||
use arrow_schema::{DataType, Field, Schema};
|
||||
use futures::TryStreamExt;
|
||||
use lance::Dataset;
|
||||
use lance::dataset::{WriteMode, WriteParams};
|
||||
use lance_file::version::LanceFileVersion;
|
||||
|
||||
async fn create_test_dataset(uri: &str) -> Dataset {
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Utf8, false),
|
||||
Field::new("value", DataType::Int32, false),
|
||||
]));
|
||||
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(vec!["alice", "bob"])),
|
||||
Arc::new(Int32Array::from(vec![1, 2])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let params = WriteParams {
|
||||
mode: WriteMode::Create,
|
||||
enable_stable_row_ids: true,
|
||||
data_storage_version: Some(LanceFileVersion::V2_2),
|
||||
..Default::default()
|
||||
};
|
||||
Dataset::write(reader, uri, Some(params)).await.unwrap()
|
||||
}
|
||||
|
||||
fn read_version_columns(batches: &[RecordBatch]) -> Vec<(String, i32, u64, u64)> {
|
||||
let mut rows = Vec::new();
|
||||
for batch in batches {
|
||||
let ids = batch
|
||||
.column_by_name("id")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let vals = batch
|
||||
.column_by_name("value")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap();
|
||||
let created = batch
|
||||
.column_by_name("_row_created_at_version")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.unwrap();
|
||||
let updated = batch
|
||||
.column_by_name("_row_last_updated_at_version")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<UInt64Array>()
|
||||
.unwrap();
|
||||
for i in 0..ids.len() {
|
||||
rows.push((
|
||||
ids.value(i).to_string(),
|
||||
vals.value(i),
|
||||
created.value(i),
|
||||
updated.value(i),
|
||||
));
|
||||
}
|
||||
}
|
||||
rows.sort_by(|a, b| a.0.cmp(&b.0));
|
||||
rows
|
||||
}
|
||||
|
||||
async fn scan_with_versions(ds: &Dataset) -> Vec<(String, i32, u64, u64)> {
|
||||
let mut scanner = ds.scan();
|
||||
scanner
|
||||
.project(&[
|
||||
"id",
|
||||
"value",
|
||||
"_row_created_at_version",
|
||||
"_row_last_updated_at_version",
|
||||
])
|
||||
.unwrap();
|
||||
let batches: Vec<RecordBatch> = scanner
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect()
|
||||
.await
|
||||
.unwrap();
|
||||
read_version_columns(&batches)
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lance_append_stamps_created_at_version_correctly() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().join("test.lance");
|
||||
let uri_str = uri.to_str().unwrap();
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Utf8, false),
|
||||
Field::new("value", DataType::Int32, false),
|
||||
]));
|
||||
|
||||
let ds = create_test_dataset(uri_str).await;
|
||||
let v1 = ds.version().version;
|
||||
|
||||
// Append a new row
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(vec!["charlie"])),
|
||||
Arc::new(Int32Array::from(vec![3])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let mut ds = ds;
|
||||
ds.append(reader, None).await.unwrap();
|
||||
let v2 = ds.version().version;
|
||||
|
||||
let rows = scan_with_versions(&ds).await;
|
||||
eprintln!("After append (v1={}, v2={}):", v1, v2);
|
||||
for (id, val, created, updated) in &rows {
|
||||
eprintln!(
|
||||
" id={:<10} val={:<4} created_v={:<4} updated_v={}",
|
||||
id, val, created, updated
|
||||
);
|
||||
}
|
||||
|
||||
// Alice and Bob: created at v1
|
||||
let alice = rows.iter().find(|r| r.0 == "alice").unwrap();
|
||||
assert_eq!(alice.2, v1, "alice created_at should be v1");
|
||||
|
||||
// Charlie: created at v2 (the append version)
|
||||
let charlie = rows.iter().find(|r| r.0 == "charlie").unwrap();
|
||||
assert_eq!(
|
||||
charlie.2, v2,
|
||||
"charlie created_at should be v2 (append version)"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lance_merge_insert_new_row_stamps_created_at_version() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().join("test.lance");
|
||||
let uri_str = uri.to_str().unwrap();
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Utf8, false),
|
||||
Field::new("value", DataType::Int32, false),
|
||||
]));
|
||||
|
||||
let ds = create_test_dataset(uri_str).await;
|
||||
let v1 = ds.version().version;
|
||||
|
||||
// merge_insert a NEW row (eve doesn't exist)
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(vec!["eve"])),
|
||||
Arc::new(Int32Array::from(vec![4])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let ds_arc = Arc::new(ds);
|
||||
let job = lance::dataset::MergeInsertBuilder::try_new(ds_arc, vec!["id".to_string()])
|
||||
.unwrap()
|
||||
.when_matched(lance::dataset::WhenMatched::UpdateAll)
|
||||
.when_not_matched(lance::dataset::WhenNotMatched::InsertAll)
|
||||
.try_build()
|
||||
.unwrap();
|
||||
let (new_ds, _) = job
|
||||
.execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
|
||||
.await
|
||||
.unwrap();
|
||||
let v2 = new_ds.version().version;
|
||||
|
||||
let rows = scan_with_versions(&new_ds).await;
|
||||
eprintln!("After merge_insert NEW eve (v1={}, v2={}):", v1, v2);
|
||||
for (id, val, created, updated) in &rows {
|
||||
eprintln!(
|
||||
" id={:<10} val={:<4} created_v={:<4} updated_v={}",
|
||||
id, val, created, updated
|
||||
);
|
||||
}
|
||||
|
||||
let eve = rows.iter().find(|r| r.0 == "eve").unwrap();
|
||||
eprintln!("Eve: created_at_version={}, v1={}, v2={}", eve.2, v1, v2);
|
||||
|
||||
// Lance behavior (as of 3.0.1): merge_insert stamps new rows with
|
||||
// _row_created_at_version = dataset_creation_version (v1), NOT the
|
||||
// merge_insert commit version (v2). This is why Omnigraph's change
|
||||
// detection uses _row_last_updated_at_version + ID set membership
|
||||
// to classify inserts vs updates, not _row_created_at_version alone.
|
||||
assert_eq!(
|
||||
eve.2, v1,
|
||||
"Lance merge_insert stamps new rows with created_at = dataset creation version, not commit version"
|
||||
);
|
||||
assert_eq!(
|
||||
eve.3, v2,
|
||||
"Lance merge_insert stamps new rows with last_updated_at = commit version"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn lance_merge_insert_update_preserves_created_at_version() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().join("test.lance");
|
||||
let uri_str = uri.to_str().unwrap();
|
||||
let schema = Arc::new(Schema::new(vec![
|
||||
Field::new("id", DataType::Utf8, false),
|
||||
Field::new("value", DataType::Int32, false),
|
||||
]));
|
||||
|
||||
let ds = create_test_dataset(uri_str).await;
|
||||
let v1 = ds.version().version;
|
||||
|
||||
// merge_insert an EXISTING row (update bob's value)
|
||||
let batch = RecordBatch::try_new(
|
||||
schema.clone(),
|
||||
vec![
|
||||
Arc::new(StringArray::from(vec!["bob"])),
|
||||
Arc::new(Int32Array::from(vec![99])),
|
||||
],
|
||||
)
|
||||
.unwrap();
|
||||
let reader = RecordBatchIterator::new(vec![Ok(batch)], schema);
|
||||
let ds_arc = Arc::new(ds);
|
||||
let job = lance::dataset::MergeInsertBuilder::try_new(ds_arc, vec!["id".to_string()])
|
||||
.unwrap()
|
||||
.when_matched(lance::dataset::WhenMatched::UpdateAll)
|
||||
.when_not_matched(lance::dataset::WhenNotMatched::InsertAll)
|
||||
.try_build()
|
||||
.unwrap();
|
||||
let (new_ds, _) = job
|
||||
.execute(lance_datafusion::utils::reader_to_stream(Box::new(reader)))
|
||||
.await
|
||||
.unwrap();
|
||||
let v2 = new_ds.version().version;
|
||||
|
||||
let rows = scan_with_versions(&new_ds).await;
|
||||
eprintln!("After merge_insert UPDATE bob (v1={}, v2={}):", v1, v2);
|
||||
for (id, val, created, updated) in &rows {
|
||||
eprintln!(
|
||||
" id={:<10} val={:<4} created_v={:<4} updated_v={}",
|
||||
id, val, created, updated
|
||||
);
|
||||
}
|
||||
|
||||
let alice = rows.iter().find(|r| r.0 == "alice").unwrap();
|
||||
let bob = rows.iter().find(|r| r.0 == "bob").unwrap();
|
||||
|
||||
// Alice: untouched, should keep original versions
|
||||
assert_eq!(alice.2, v1, "alice created_at should still be v1");
|
||||
assert_eq!(alice.3, v1, "alice updated_at should still be v1");
|
||||
|
||||
// Bob: updated via merge_insert
|
||||
// created_at should be preserved (v1), updated_at should be bumped (v2)
|
||||
eprintln!(
|
||||
"Bob: created_at={}, updated_at={}, v1={}, v2={}",
|
||||
bob.2, bob.3, v1, v2
|
||||
);
|
||||
assert_eq!(bob.1, 99, "bob's value should be updated to 99");
|
||||
}
|
||||
736
crates/omnigraph/tests/point_in_time.rs
Normal file
736
crates/omnigraph/tests/point_in_time.rs
Normal file
|
|
@ -0,0 +1,736 @@
|
|||
mod helpers;
|
||||
|
||||
use arrow_array::{Array, Int32Array};
|
||||
use helpers::*;
|
||||
use omnigraph::db::Omnigraph;
|
||||
use omnigraph_compiler::ir::ParamMap;
|
||||
|
||||
// ─── Inline queries for point-in-time tests ─────────────────────────────────
|
||||
|
||||
/// Query source defining `all_persons`: every Person as (name, age),
/// ordered by name ascending.
const ALL_PERSONS_QUERY: &str = r#"
|
||||
query all_persons() {
|
||||
match {
|
||||
$p: Person
|
||||
}
|
||||
return { $p.name, $p.age }
|
||||
order { $p.name asc }
|
||||
}
|
||||
"#;
|
||||
|
||||
/// Query source defining `friends_of($name)`: names reached by one `knows`
/// step from the named Person, ordered by name ascending.
const FRIENDS_QUERY: &str = r#"
|
||||
query friends_of($name: String) {
|
||||
match {
|
||||
$p: Person { name: $name }
|
||||
$p knows $f
|
||||
}
|
||||
return { $f.name }
|
||||
order { $f.name asc }
|
||||
}
|
||||
"#;
|
||||
|
||||
/// Query source defining `unemployed`: names of Persons with no `worksAt`
/// edge (negated pattern), ordered by name ascending.
const UNEMPLOYED_QUERY: &str = r#"
|
||||
query unemployed() {
|
||||
match {
|
||||
$p: Person
|
||||
not { $p worksAt $_ }
|
||||
}
|
||||
return { $p.name }
|
||||
order { $p.name asc }
|
||||
}
|
||||
"#;
|
||||
|
||||
/// Query source defining `older_than($min_age)`: Persons whose age is
/// strictly greater than `$min_age`, as (name, age) ordered by name ascending.
const FILTERED_QUERY: &str = r#"
|
||||
query older_than($min_age: I32) {
|
||||
match {
|
||||
$p: Person
|
||||
$p.age > $min_age
|
||||
}
|
||||
return { $p.name, $p.age }
|
||||
order { $p.name asc }
|
||||
}
|
||||
"#;
|
||||
|
||||
/// Query source defining `get_person($name)`: the (name, age) of the Person
/// whose `name` equals `$name`.
const GET_PERSON_QUERY: &str = r#"
|
||||
query get_person($name: String) {
|
||||
match {
|
||||
$p: Person { name: $name }
|
||||
}
|
||||
return { $p.name, $p.age }
|
||||
}
|
||||
"#;
|
||||
|
||||
// ─── Morphological matrix ───────────────────────────────────────────────────
|
||||
//
|
||||
// Dimensions:
|
||||
// Query type: Tabular | Traversal | Negation (AntiJoin) | Filtered | Aggregation
|
||||
// Mutation: Insert | Update | Delete node | Delete edge
|
||||
// Branch: Main | Named branch
|
||||
// Result shape: Empty→non-empty | Non-empty→empty | Count changes | Value changes
|
||||
//
|
||||
// Existing coverage (4 tests):
|
||||
// Tabular × Insert × Main (returns_historical_data)
|
||||
// Traversal × Insert × Main (traversal_uses_historical_graph_index)
|
||||
// Tabular × Update × Main (multiple_versions_sees_correct_state)
|
||||
// Error case (snapshot_at_version_fails_for_nonexistent_version)
|
||||
//
|
||||
// New coverage (9 tests below):
|
||||
// Tabular × Delete node × Main → non-empty becomes smaller
|
||||
// Traversal × Delete edge × Main → edge disappears from historical
|
||||
// Negation × Insert edge × Main → anti-join result shrinks after insert
|
||||
// Negation × Delete edge × Main → anti-join result grows after delete
|
||||
// Filtered × Update × Main → entity enters/exits filter after age change
|
||||
// Multi-hop × Insert(n+e) × Main → friends-of-friends grows after new path
|
||||
// Traversal × Delete node × Main → cascade removes edges from traversal
|
||||
// Tabular × Insert × Branch → branch isolation for point-in-time
|
||||
// Tabular × Multi-step × Main → 4-version chain: insert, update, delete
|
||||
|
||||
// ─── Original tests ────────────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_query_at_returns_historical_data() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let historical = db
|
||||
.run_query_at(v_before, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let current = query_main(&mut db, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(historical.num_rows(), 4, "historical should have 4 persons");
|
||||
assert_eq!(current.num_rows(), 5, "current should have 5 persons");
|
||||
|
||||
let historical_names = collect_column_strings(historical.batches(), "p.name");
|
||||
assert!(!historical_names.contains(&"Eve".to_string()));
|
||||
|
||||
let current_names = collect_column_strings(current.batches(), "p.name");
|
||||
assert!(current_names.contains(&"Eve".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_query_at_traversal_uses_historical_graph_index() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"add_friend",
|
||||
¶ms(&[("$from", "Eve"), ("$to", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let historical = db
|
||||
.run_query_at(
|
||||
v_before,
|
||||
FRIENDS_QUERY,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let current = query_main(
|
||||
&mut db,
|
||||
FRIENDS_QUERY,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(historical.num_rows(), 2);
|
||||
let hist_names = collect_column_strings(historical.batches(), "f.name");
|
||||
assert!(hist_names.contains(&"Bob".to_string()));
|
||||
assert!(hist_names.contains(&"Charlie".to_string()));
|
||||
|
||||
assert_eq!(current.num_rows(), 1);
|
||||
let cur_names = collect_column_strings(current.batches(), "f.name");
|
||||
assert!(cur_names.contains(&"Alice".to_string()));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn snapshot_at_version_fails_for_nonexistent_version() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let db = init_and_load(&dir).await;
|
||||
|
||||
let result = db.snapshot_at_version(99999).await;
|
||||
assert!(result.is_err(), "non-existent version should return error");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn run_query_at_multiple_versions_sees_correct_state() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let v1 = version_main(&db).await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Alice")], &[("$age", 99)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let v2 = version_main(&db).await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Frank")], &[("$age", 40)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let at_v1 = db
|
||||
.run_query_at(v1, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(at_v1.num_rows(), 4, "v1 should have 4 persons");
|
||||
let v1_names = collect_column_strings(at_v1.batches(), "p.name");
|
||||
assert!(!v1_names.contains(&"Frank".to_string()));
|
||||
|
||||
let at_v2 = db
|
||||
.run_query_at(v2, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(at_v2.num_rows(), 4, "v2 should have 4 persons");
|
||||
let v2_names = collect_column_strings(at_v2.batches(), "p.name");
|
||||
assert!(!v2_names.contains(&"Frank".to_string()));
|
||||
|
||||
let current = query_main(&mut db, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(current.num_rows(), 5, "current should have 5 persons");
|
||||
let cur_names = collect_column_strings(current.batches(), "p.name");
|
||||
assert!(cur_names.contains(&"Frank".to_string()));
|
||||
}
|
||||
|
||||
// ─── Tabular × Delete node ─────────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn tabular_delete_node_invisible_at_historical_version() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
// Fixture: Alice, Bob, Charlie, Diana
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"remove_person",
|
||||
¶ms(&[("$name", "Charlie")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Historical: Charlie still exists
|
||||
let historical = db
|
||||
.run_query_at(v_before, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let hist_names = collect_column_strings(historical.batches(), "p.name");
|
||||
assert_eq!(historical.num_rows(), 4);
|
||||
assert!(hist_names.contains(&"Charlie".to_string()));
|
||||
|
||||
// Current: Charlie is gone
|
||||
let current = query_main(&mut db, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let cur_names = collect_column_strings(current.batches(), "p.name");
|
||||
assert_eq!(current.num_rows(), 3);
|
||||
assert!(!cur_names.contains(&"Charlie".to_string()));
|
||||
}
|
||||
|
||||
// ─── Traversal × Delete edge ───────────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn traversal_delete_edge_invisible_at_historical_version() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
// Fixture: Alice knows Bob, Alice knows Charlie
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
// Remove all Knows edges FROM Alice
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"remove_friendship",
|
||||
¶ms(&[("$from", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Historical traversal: Alice's friends at v_before = Bob, Charlie
|
||||
let historical = db
|
||||
.run_query_at(
|
||||
v_before,
|
||||
FRIENDS_QUERY,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(historical.num_rows(), 2);
|
||||
let hist_names = collect_column_strings(historical.batches(), "f.name");
|
||||
assert!(hist_names.contains(&"Bob".to_string()));
|
||||
assert!(hist_names.contains(&"Charlie".to_string()));
|
||||
|
||||
// Current: Alice has no friends (edges deleted)
|
||||
let current = query_main(
|
||||
&mut db,
|
||||
FRIENDS_QUERY,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(
|
||||
current.num_rows(),
|
||||
0,
|
||||
"Alice should have no friends after edge deletion"
|
||||
);
|
||||
}
|
||||
|
||||
// ─── Negation (AntiJoin) × Insert ──────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn negation_insert_shrinks_antijoin_result() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
// Fixture: Alice worksAt Acme, Bob worksAt Globex
|
||||
// Unemployed: Charlie, Diana
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
// Give Charlie a job
|
||||
mutate_main(
|
||||
&mut db,
|
||||
r#"
|
||||
query hire($from: String, $to: String) {
|
||||
insert WorksAt { from: $from, to: $to }
|
||||
}
|
||||
"#,
|
||||
"hire",
|
||||
¶ms(&[("$from", "Charlie"), ("$to", "Acme")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Historical: Charlie and Diana were unemployed
|
||||
let historical = db
|
||||
.run_query_at(v_before, UNEMPLOYED_QUERY, "unemployed", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let hist_names = collect_column_strings(historical.batches(), "p.name");
|
||||
assert_eq!(historical.num_rows(), 2);
|
||||
assert!(hist_names.contains(&"Charlie".to_string()));
|
||||
assert!(hist_names.contains(&"Diana".to_string()));
|
||||
|
||||
// Current: only Diana is unemployed
|
||||
let current = query_main(&mut db, UNEMPLOYED_QUERY, "unemployed", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
let cur_names = collect_column_strings(current.batches(), "p.name");
|
||||
assert_eq!(current.num_rows(), 1);
|
||||
assert!(cur_names.contains(&"Diana".to_string()));
|
||||
assert!(!cur_names.contains(&"Charlie".to_string()));
|
||||
}
|
||||
|
||||
// ─── Negation (AntiJoin) × Delete edge ─────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn negation_delete_edge_grows_antijoin_result() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
// Fixture: Alice worksAt Acme, Bob worksAt Globex
|
||||
// Unemployed at start: Charlie, Diana
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
// Fire Alice (delete WorksAt edge)
|
||||
mutate_main(
|
||||
&mut db,
|
||||
r#"
|
||||
query fire($from: String) {
|
||||
delete WorksAt where from = $from
|
||||
}
|
||||
"#,
|
||||
"fire",
|
||||
¶ms(&[("$from", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Historical: 2 unemployed (Charlie, Diana)
|
||||
let historical = db
|
||||
.run_query_at(v_before, UNEMPLOYED_QUERY, "unemployed", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(historical.num_rows(), 2);
|
||||
let hist_names = collect_column_strings(historical.batches(), "p.name");
|
||||
assert!(!hist_names.contains(&"Alice".to_string()));
|
||||
|
||||
// Current: 3 unemployed (Alice, Charlie, Diana)
|
||||
let current = query_main(&mut db, UNEMPLOYED_QUERY, "unemployed", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(current.num_rows(), 3);
|
||||
let cur_names = collect_column_strings(current.batches(), "p.name");
|
||||
assert!(cur_names.contains(&"Alice".to_string()));
|
||||
}
|
||||
|
||||
// ─── Filtered × Update (value enters/exits filter) ─────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn filtered_update_entity_crosses_filter_boundary() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
// Fixture: Alice(30), Bob(25), Charlie(35), Diana(28)
|
||||
// older_than(30): Charlie(35) only
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
// Update Bob's age from 25 to 40 → enters the filter
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Bob")], &[("$age", 40)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Historical: only Charlie is older than 30
|
||||
let historical = db
|
||||
.run_query_at(
|
||||
v_before,
|
||||
FILTERED_QUERY,
|
||||
"older_than",
|
||||
&int_params(&[("$min_age", 30)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(historical.num_rows(), 1);
|
||||
let hist_names = collect_column_strings(historical.batches(), "p.name");
|
||||
assert_eq!(hist_names, vec!["Charlie"]);
|
||||
|
||||
// Current: Bob(40) and Charlie(35) are older than 30
|
||||
let current = query_main(
|
||||
&mut db,
|
||||
FILTERED_QUERY,
|
||||
"older_than",
|
||||
&int_params(&[("$min_age", 30)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(current.num_rows(), 2);
|
||||
let cur_names = collect_column_strings(current.batches(), "p.name");
|
||||
assert!(cur_names.contains(&"Bob".to_string()));
|
||||
assert!(cur_names.contains(&"Charlie".to_string()));
|
||||
}
|
||||
|
||||
// ─── Multi-hop traversal × Insert ──────────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn multi_hop_traversal_historical_version() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
// Fixture: Alice→Bob, Alice→Charlie, Bob→Diana
|
||||
// friends_of_friends(Alice) = Diana (Alice→Bob→Diana)
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
// Insert Eve and edge: Charlie→Eve
|
||||
// Now friends_of_friends(Alice) = Diana + Eve (Alice→Charlie→Eve)
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"add_friend",
|
||||
¶ms(&[("$from", "Charlie"), ("$to", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let fof_query = r#"
|
||||
query fof($name: String) {
|
||||
match {
|
||||
$p: Person { name: $name }
|
||||
$p knows $mid
|
||||
$mid knows $f
|
||||
}
|
||||
return { $f.name }
|
||||
order { $f.name asc }
|
||||
}
|
||||
"#;
|
||||
|
||||
// Historical: friends-of-friends of Alice = Diana only
|
||||
let historical = db
|
||||
.run_query_at(v_before, fof_query, "fof", ¶ms(&[("$name", "Alice")]))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(historical.num_rows(), 1);
|
||||
let hist_names = collect_column_strings(historical.batches(), "f.name");
|
||||
assert_eq!(hist_names, vec!["Diana"]);
|
||||
|
||||
// Current: friends-of-friends of Alice = Diana + Eve
|
||||
let current = query_main(&mut db, fof_query, "fof", ¶ms(&[("$name", "Alice")]))
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(current.num_rows(), 2);
|
||||
let cur_names = collect_column_strings(current.batches(), "f.name");
|
||||
assert!(cur_names.contains(&"Diana".to_string()));
|
||||
assert!(cur_names.contains(&"Eve".to_string()));
|
||||
}
|
||||
|
||||
// ─── Traversal × Delete node (cascade removes edges) ───────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn traversal_delete_node_cascade_removes_edges() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
// Fixture: Alice knows Bob, Alice knows Charlie, Bob knows Diana
|
||||
let v_before = version_main(&db).await.unwrap();
|
||||
|
||||
// Delete Bob → cascades to Knows edges involving Bob
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"remove_person",
|
||||
¶ms(&[("$name", "Bob")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Historical: Alice's friends = Bob, Charlie
|
||||
let historical = db
|
||||
.run_query_at(
|
||||
v_before,
|
||||
FRIENDS_QUERY,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(historical.num_rows(), 2);
|
||||
let hist_names = collect_column_strings(historical.batches(), "f.name");
|
||||
assert!(hist_names.contains(&"Bob".to_string()));
|
||||
assert!(hist_names.contains(&"Charlie".to_string()));
|
||||
|
||||
// Current: Alice's friends = Charlie only (Bob was deleted, edge cascaded)
|
||||
let current = query_main(
|
||||
&mut db,
|
||||
FRIENDS_QUERY,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(current.num_rows(), 1);
|
||||
let cur_names = collect_column_strings(current.batches(), "f.name");
|
||||
assert_eq!(cur_names, vec!["Charlie"]);
|
||||
}
|
||||
|
||||
// ─── Branch isolation for point-in-time ────────────────────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn branch_point_in_time_isolated_from_main() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut main = init_and_load(&dir).await;
|
||||
|
||||
main.branch_create("feature").await.unwrap();
|
||||
let v_main_before = version_main(&main).await.unwrap();
|
||||
|
||||
// Insert Eve on main
|
||||
mutate_main(
|
||||
&mut main,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Insert Frank on feature branch
|
||||
let mut feature = Omnigraph::open(uri).await.unwrap();
|
||||
mutate_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Frank")], &[("$age", 33)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// Historical main at v_main_before: 4 persons, no Eve, no Frank
|
||||
let hist_main = main
|
||||
.run_query_at(
|
||||
v_main_before,
|
||||
ALL_PERSONS_QUERY,
|
||||
"all_persons",
|
||||
&ParamMap::new(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(hist_main.num_rows(), 4);
|
||||
let hist_names = collect_column_strings(hist_main.batches(), "p.name");
|
||||
assert!(!hist_names.contains(&"Eve".to_string()));
|
||||
assert!(!hist_names.contains(&"Frank".to_string()));
|
||||
|
||||
// Current main: 5 persons (Eve present, Frank not visible on main)
|
||||
let cur_main = query_main(
|
||||
&mut main,
|
||||
ALL_PERSONS_QUERY,
|
||||
"all_persons",
|
||||
&ParamMap::new(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(cur_main.num_rows(), 5);
|
||||
let cur_names = collect_column_strings(cur_main.batches(), "p.name");
|
||||
assert!(cur_names.contains(&"Eve".to_string()));
|
||||
assert!(!cur_names.contains(&"Frank".to_string()));
|
||||
|
||||
// Feature branch: 5 persons (Frank present, Eve not visible on feature)
|
||||
let cur_feature = query_branch(
|
||||
&mut feature,
|
||||
"feature",
|
||||
ALL_PERSONS_QUERY,
|
||||
"all_persons",
|
||||
&ParamMap::new(),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(cur_feature.num_rows(), 5);
|
||||
let feat_names = collect_column_strings(cur_feature.batches(), "p.name");
|
||||
assert!(feat_names.contains(&"Frank".to_string()));
|
||||
assert!(!feat_names.contains(&"Eve".to_string()));
|
||||
}
|
||||
|
||||
// ─── Multi-step version chain: insert → update → delete ────────────────────
|
||||
|
||||
#[tokio::test]
|
||||
async fn four_version_chain_insert_update_delete() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
// v1: baseline (Alice=30, Bob=25, Charlie=35, Diana=28)
|
||||
let v1 = version_main(&db).await.unwrap();
|
||||
|
||||
// v2: insert Eve(22)
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let v2 = version_main(&db).await.unwrap();
|
||||
|
||||
// v3: update Eve's age to 50
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 50)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let v3 = version_main(&db).await.unwrap();
|
||||
|
||||
// v4: delete Eve
|
||||
mutate_main(
|
||||
&mut db,
|
||||
MUTATION_QUERIES,
|
||||
"remove_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
// v1: no Eve, 4 persons
|
||||
let at_v1 = db
|
||||
.run_query_at(v1, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(at_v1.num_rows(), 4);
|
||||
let v1_names = collect_column_strings(at_v1.batches(), "p.name");
|
||||
assert!(!v1_names.contains(&"Eve".to_string()));
|
||||
|
||||
// v2: Eve exists with age 22, 5 persons
|
||||
let at_v2 = db
|
||||
.run_query_at(
|
||||
v2,
|
||||
GET_PERSON_QUERY,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(at_v2.num_rows(), 1);
|
||||
let v2_batch = at_v2.concat_batches().unwrap();
|
||||
let v2_ages = v2_batch
|
||||
.column_by_name("p.age")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap();
|
||||
assert_eq!(v2_ages.value(0), 22);
|
||||
|
||||
// v3: Eve exists with age 50
|
||||
let at_v3 = db
|
||||
.run_query_at(
|
||||
v3,
|
||||
GET_PERSON_QUERY,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(at_v3.num_rows(), 1);
|
||||
let v3_batch = at_v3.concat_batches().unwrap();
|
||||
let v3_ages = v3_batch
|
||||
.column_by_name("p.age")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<Int32Array>()
|
||||
.unwrap();
|
||||
assert_eq!(v3_ages.value(0), 50);
|
||||
|
||||
// v4 (current): Eve is gone, back to 4
|
||||
let current = query_main(&mut db, ALL_PERSONS_QUERY, "all_persons", &ParamMap::new())
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(current.num_rows(), 4);
|
||||
let cur_names = collect_column_strings(current.batches(), "p.name");
|
||||
assert!(!cur_names.contains(&"Eve".to_string()));
|
||||
}
|
||||
533
crates/omnigraph/tests/runs.rs
Normal file
533
crates/omnigraph/tests/runs.rs
Normal file
|
|
@ -0,0 +1,533 @@
|
|||
mod helpers;
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use arrow_array::{Array, RecordBatch, StringArray, TimestampMicrosecondArray};
|
||||
use futures::TryStreamExt;
|
||||
use lance::Dataset;
|
||||
|
||||
use omnigraph::db::commit_graph::CommitGraph;
|
||||
use omnigraph::db::{Omnigraph, ReadTarget, RunStatus};
|
||||
use omnigraph::error::OmniError;
|
||||
use omnigraph::loader::{LoadMode, load_jsonl};
|
||||
|
||||
use helpers::*;
|
||||
|
||||
/// One row of the `_graph_runs.lance` dataset, as read back by `latest_runs`.
#[derive(Clone, Debug)]
struct PersistedRun {
    /// Value of the `run_id` column; `latest_runs` keeps one record per id.
    run_id: String,
    /// Value of the `target_branch` column.
    target_branch: String,
    /// Value of the `run_branch` column.
    run_branch: String,
    /// Value of the `status` column, kept as the raw stored string.
    status: String,
    /// Raw value of the `updated_at` timestamp-microsecond column.
    updated_at: i64,
}
|
||||
|
||||
async fn latest_runs(uri: &str) -> Vec<PersistedRun> {
|
||||
let runs_uri = format!("{}/_graph_runs.lance", uri);
|
||||
let ds = Dataset::open(&runs_uri).await.unwrap();
|
||||
let batches: Vec<RecordBatch> = ds
|
||||
.scan()
|
||||
.try_into_stream()
|
||||
.await
|
||||
.unwrap()
|
||||
.try_collect()
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut latest: HashMap<String, PersistedRun> = HashMap::new();
|
||||
for batch in batches {
|
||||
let run_ids = batch
|
||||
.column_by_name("run_id")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let target_branches = batch
|
||||
.column_by_name("target_branch")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let run_branches = batch
|
||||
.column_by_name("run_branch")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let statuses = batch
|
||||
.column_by_name("status")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<StringArray>()
|
||||
.unwrap();
|
||||
let updated_ats = batch
|
||||
.column_by_name("updated_at")
|
||||
.unwrap()
|
||||
.as_any()
|
||||
.downcast_ref::<TimestampMicrosecondArray>()
|
||||
.unwrap();
|
||||
|
||||
for row in 0..batch.num_rows() {
|
||||
let record = PersistedRun {
|
||||
run_id: run_ids.value(row).to_string(),
|
||||
target_branch: target_branches.value(row).to_string(),
|
||||
run_branch: run_branches.value(row).to_string(),
|
||||
status: statuses.value(row).to_string(),
|
||||
updated_at: updated_ats.value(row),
|
||||
};
|
||||
match latest.get(record.run_id.as_str()) {
|
||||
Some(existing) if existing.updated_at >= record.updated_at => {}
|
||||
_ => {
|
||||
latest.insert(record.run_id.clone(), record);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mut records = latest.into_values().collect::<Vec<_>>();
|
||||
records.sort_by(|a, b| a.run_id.cmp(&b.run_id));
|
||||
records
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn begin_run_creates_hidden_internal_branch_and_isolates_writes() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let base_snapshot = db.resolve_snapshot("main").await.unwrap();
|
||||
|
||||
let run = db.begin_run("main", Some("test-load")).await.unwrap();
|
||||
|
||||
assert!(run.run_branch.starts_with("__run__"));
|
||||
assert_eq!(run.target_branch, "main");
|
||||
assert_eq!(run.base_snapshot_id, base_snapshot.as_str());
|
||||
assert_eq!(run.status, RunStatus::Running);
|
||||
assert_eq!(db.branch_list().await.unwrap(), vec!["main"]);
|
||||
|
||||
db.load(
|
||||
&run.run_branch,
|
||||
r#"{"type":"Person","data":{"name":"Eve","age":22}}"#,
|
||||
LoadMode::Append,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let main_qr = db
|
||||
.query(
|
||||
ReadTarget::branch("main"),
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(main_qr.num_rows(), 0);
|
||||
|
||||
let run_qr = db
|
||||
.query(
|
||||
ReadTarget::branch(run.run_branch.as_str()),
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(run_qr.num_rows(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn publish_run_merges_internal_branch_into_target_and_marks_record() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let run = db.begin_run("main", Some("publish-test")).await.unwrap();
|
||||
|
||||
db.load(
|
||||
&run.run_branch,
|
||||
r#"{"type":"Person","data":{"name":"Eve","age":22}}"#,
|
||||
LoadMode::Append,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let published_snapshot = db.publish_run(&run.run_id).await.unwrap();
|
||||
let record = db.get_run(&run.run_id).await.unwrap();
|
||||
|
||||
assert_eq!(record.status, RunStatus::Published);
|
||||
assert_eq!(
|
||||
record.published_snapshot_id.as_deref(),
|
||||
Some(published_snapshot.as_str())
|
||||
);
|
||||
|
||||
let main_qr = db
|
||||
.query(
|
||||
ReadTarget::branch("main"),
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(main_qr.num_rows(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn abort_run_keeps_target_unchanged_and_preserves_hidden_branch_for_inspection() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let run = db.begin_run("main", Some("abort-test")).await.unwrap();
|
||||
|
||||
db.load(
|
||||
&run.run_branch,
|
||||
r#"{"type":"Person","data":{"name":"Eve","age":22}}"#,
|
||||
LoadMode::Append,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let aborted = db.abort_run(&run.run_id).await.unwrap();
|
||||
assert_eq!(aborted.status, RunStatus::Aborted);
|
||||
|
||||
let main_qr = db
|
||||
.query(
|
||||
ReadTarget::branch("main"),
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(main_qr.num_rows(), 0);
|
||||
|
||||
let run_qr = db
|
||||
.query(
|
||||
ReadTarget::branch(run.run_branch.as_str()),
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(run_qr.num_rows(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn public_branch_apis_reject_internal_run_refs() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
let run = db.begin_run("main", Some("guard-test")).await.unwrap();
|
||||
|
||||
let merge_err = db.branch_merge(&run.run_branch, "main").await.unwrap_err();
|
||||
match merge_err {
|
||||
OmniError::Manifest(message) => assert!(message.message.contains("internal run refs")),
|
||||
other => panic!("unexpected error: {}", other),
|
||||
}
|
||||
|
||||
let create_err = db.branch_create(&run.run_branch).await.unwrap_err();
|
||||
match create_err {
|
||||
OmniError::Manifest(message) => assert!(message.message.contains("internal run ref")),
|
||||
other => panic!("unexpected error: {}", other),
|
||||
}
|
||||
|
||||
let delete_err = db.branch_delete(&run.run_branch).await.unwrap_err();
|
||||
match delete_err {
|
||||
OmniError::Manifest(message) => assert!(message.message.contains("internal run ref")),
|
||||
other => panic!("unexpected error: {}", other),
|
||||
}
|
||||
|
||||
let fork_err = db
|
||||
.branch_create_from(ReadTarget::branch(run.run_branch.as_str()), "child")
|
||||
.await
|
||||
.unwrap_err();
|
||||
match fork_err {
|
||||
OmniError::Manifest(message) => assert!(message.message.contains("internal run ref")),
|
||||
other => panic!("unexpected error: {}", other),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn branch_delete_rejects_target_branches_with_active_runs() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
db.branch_create("feature").await.unwrap();
|
||||
let run = db.begin_run("feature", Some("delete-guard")).await.unwrap();
|
||||
|
||||
let err = db.branch_delete("feature").await.unwrap_err();
|
||||
assert!(err.to_string().contains(run.run_id.as_str()));
|
||||
assert!(err.to_string().contains("targeting it is running"));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn public_load_uses_hidden_transactional_run_and_publishes_it() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
||||
|
||||
let result = load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.nodes_loaded.len(), 2);
|
||||
assert_eq!(result.edges_loaded.len(), 2);
|
||||
assert_eq!(db.branch_list().await.unwrap(), vec!["main"]);
|
||||
|
||||
let runs = latest_runs(uri).await;
|
||||
assert_eq!(runs.len(), 1);
|
||||
assert_eq!(runs[0].target_branch, "main");
|
||||
assert_eq!(runs[0].status, "published");
|
||||
assert!(runs[0].run_branch.starts_with("__run__"));
|
||||
|
||||
let qr = db
|
||||
.query(
|
||||
ReadTarget::branch("main"),
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(qr.num_rows(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn public_load_preserves_staged_edge_ids_on_publish() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
||||
|
||||
load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let runs = latest_runs(uri).await;
|
||||
let run_branch = runs[0].run_branch.clone();
|
||||
|
||||
let mut main_ids = collect_column_strings(&read_table(&db, "edge:Knows").await, "id");
|
||||
let mut run_ids = collect_column_strings(
|
||||
&read_table_branch(&db, run_branch.as_str(), "edge:Knows").await,
|
||||
"id",
|
||||
);
|
||||
main_ids.sort();
|
||||
run_ids.sort();
|
||||
assert_eq!(main_ids, run_ids);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn failed_public_load_marks_run_failed_and_leaves_target_unchanged() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = Omnigraph::init(uri, TEST_SCHEMA).await.unwrap();
|
||||
|
||||
let bad = r#"{"type":"Person","data":{"name":"Alice","age":30}}
|
||||
{"edge":"Knows","from":"Alice","to":"Missing"}"#;
|
||||
let err = load_jsonl(&mut db, bad, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap_err();
|
||||
match err {
|
||||
OmniError::Manifest(message) => assert!(message.message.contains("not found in Person")),
|
||||
other => panic!("unexpected error: {}", other),
|
||||
}
|
||||
|
||||
let runs = latest_runs(uri).await;
|
||||
assert_eq!(runs.len(), 1);
|
||||
assert_eq!(runs[0].status, "failed");
|
||||
assert!(runs[0].run_branch.starts_with("__run__"));
|
||||
|
||||
let snap = snapshot_main(&db).await.unwrap();
|
||||
let person_count = snap
|
||||
.open("node:Person")
|
||||
.await
|
||||
.unwrap()
|
||||
.count_rows(None)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(person_count, 0);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn public_mutation_uses_hidden_transactional_run_and_publishes_it() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
let result = db
|
||||
.mutate(
|
||||
"main",
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(result.affected_nodes, 1);
|
||||
assert_eq!(result.affected_edges, 0);
|
||||
|
||||
let runs = latest_runs(uri).await;
|
||||
assert!(!runs.is_empty());
|
||||
let latest = runs.last().unwrap();
|
||||
assert_eq!(latest.target_branch, "main");
|
||||
assert_eq!(latest.status, "published");
|
||||
assert!(latest.run_branch.starts_with("__run__"));
|
||||
|
||||
let qr = db
|
||||
.query(
|
||||
ReadTarget::branch("main"),
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(qr.num_rows(), 1);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn public_mutation_preserves_staged_edge_ids_on_publish() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
db.mutate(
|
||||
"main",
|
||||
MUTATION_QUERIES,
|
||||
"add_friend",
|
||||
¶ms(&[("$from", "Alice"), ("$to", "Diana")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let runs = latest_runs(uri).await;
|
||||
let latest = runs.last().unwrap();
|
||||
|
||||
let mut main_ids = collect_column_strings(&read_table(&db, "edge:Knows").await, "id");
|
||||
let mut run_ids = collect_column_strings(
|
||||
&read_table_branch(&db, latest.run_branch.as_str(), "edge:Knows").await,
|
||||
"id",
|
||||
);
|
||||
main_ids.sort();
|
||||
run_ids.sort();
|
||||
assert_eq!(main_ids, run_ids);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn failed_public_mutation_marks_run_failed_and_leaves_target_unchanged() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
let err = db
|
||||
.mutate(
|
||||
"main",
|
||||
MUTATION_QUERIES,
|
||||
"add_friend",
|
||||
¶ms(&[("$from", "Alice"), ("$to", "Missing")]),
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
match err {
|
||||
OmniError::Manifest(message) => assert!(message.message.contains("not found")),
|
||||
other => panic!("unexpected error: {}", other),
|
||||
}
|
||||
|
||||
let runs = latest_runs(uri).await;
|
||||
assert!(!runs.is_empty());
|
||||
let latest = runs.last().unwrap();
|
||||
assert_eq!(latest.status, "failed");
|
||||
assert!(latest.run_branch.starts_with("__run__"));
|
||||
|
||||
let qr = db
|
||||
.query(
|
||||
ReadTarget::branch("main"),
|
||||
TEST_QUERIES,
|
||||
"friends_of",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(qr.num_rows(), 2);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn concurrent_conflicting_run_publish_fails_cleanly() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
let run_a = db.begin_run("main", Some("conflict-a")).await.unwrap();
|
||||
let run_b = db.begin_run("main", Some("conflict-b")).await.unwrap();
|
||||
|
||||
db.mutate(
|
||||
run_a.run_branch.as_str(),
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Alice")], &[("$age", 31)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
db.mutate(
|
||||
run_b.run_branch.as_str(),
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Alice")], &[("$age", 32)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
db.publish_run(&run_a.run_id).await.unwrap();
|
||||
let publish_b = db.publish_run(&run_b.run_id).await;
|
||||
assert!(publish_b.is_err(), "second conflicting publish should fail");
|
||||
let err = publish_b.unwrap_err().to_string();
|
||||
assert!(
|
||||
err.contains("conflict") || err.contains("divergent") || err.contains("Alice"),
|
||||
"unexpected conflict error: {}",
|
||||
err
|
||||
);
|
||||
|
||||
let alice = db
|
||||
.query(
|
||||
ReadTarget::branch("main"),
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
let rows = alice.to_rust_json();
|
||||
assert_eq!(alice.num_rows(), 1);
|
||||
assert_eq!(rows[0]["p.age"], serde_json::json!(31));
|
||||
|
||||
let run_a_record = db.get_run(&run_a.run_id).await.unwrap();
|
||||
assert_eq!(run_a_record.status, RunStatus::Published);
|
||||
let run_b_record = db.get_run(&run_b.run_id).await.unwrap();
|
||||
assert_eq!(run_b_record.status, RunStatus::Running);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn public_mutation_records_actor_on_run_and_published_commit() {
|
||||
let dir = tempfile::tempdir().unwrap();
|
||||
let uri = dir.path().to_str().unwrap();
|
||||
let mut db = init_and_load(&dir).await;
|
||||
|
||||
db.mutate_as(
|
||||
"main",
|
||||
MUTATION_QUERIES,
|
||||
"set_age",
|
||||
&mixed_params(&[("$name", "Alice")], &[("$age", 31)]),
|
||||
Some("act-andrew"),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let runs = db.list_runs().await.unwrap();
|
||||
let run = runs
|
||||
.iter()
|
||||
.find(|run| run.operation_hash.as_deref() == Some("mutation:set_age:branch=main"))
|
||||
.expect("published mutation run should exist");
|
||||
assert_eq!(run.actor_id.as_deref(), Some("act-andrew"));
|
||||
assert_eq!(run.status, RunStatus::Published);
|
||||
|
||||
let head = CommitGraph::open(uri)
|
||||
.await
|
||||
.unwrap()
|
||||
.head_commit()
|
||||
.await
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
assert_eq!(head.actor_id.as_deref(), Some("act-andrew"));
|
||||
}
|
||||
187
crates/omnigraph/tests/s3_storage.rs
Normal file
187
crates/omnigraph/tests/s3_storage.rs
Normal file
|
|
@ -0,0 +1,187 @@
|
|||
mod helpers;
|
||||
|
||||
use omnigraph::db::MergeOutcome;
|
||||
use omnigraph::db::Omnigraph;
|
||||
use omnigraph::loader::{LoadMode, load_jsonl};
|
||||
|
||||
use helpers::*;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn s3_compatible_repo_lifecycle_works() {
|
||||
let Some(uri) = s3_test_repo_uri("omnigraph-runtime") else {
|
||||
eprintln!("skipping s3 runtime test: OMNIGRAPH_S3_TEST_BUCKET is not set");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut db = Omnigraph::init(&uri, TEST_SCHEMA).await.unwrap();
|
||||
load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let mut reopened = Omnigraph::open(&uri).await.unwrap();
|
||||
let snapshot = reopened.snapshot_of("main").await.unwrap();
|
||||
assert!(snapshot.entry("node:Person").is_some());
|
||||
assert!(snapshot.entry("edge:Knows").is_some());
|
||||
|
||||
let alice = query_main(
|
||||
&mut reopened,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Alice")]),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.to_rust_json();
|
||||
assert_eq!(alice[0]["p.name"], "Alice");
|
||||
|
||||
reopened
|
||||
.mutate(
|
||||
"main",
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "RustFS-Eve")], &[("$age", 29)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let run = reopened
|
||||
.begin_run("main", Some("s3-runtime-run"))
|
||||
.await
|
||||
.unwrap();
|
||||
reopened
|
||||
.load(
|
||||
&run.run_branch,
|
||||
r#"{"type":"Person","data":{"name":"RunOnly","age":31}}"#,
|
||||
LoadMode::Append,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
reopened.publish_run(&run.run_id).await.unwrap();
|
||||
|
||||
let runs = reopened.list_runs().await.unwrap();
|
||||
assert!(
|
||||
runs.iter()
|
||||
.any(|record| { record.run_id == run.run_id && record.status.as_str() == "published" }),
|
||||
"expected published run record in {:?}",
|
||||
runs
|
||||
);
|
||||
|
||||
let mut reopened_again = Omnigraph::open(&uri).await.unwrap();
|
||||
let eve = query_main(
|
||||
&mut reopened_again,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "RustFS-Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.to_rust_json();
|
||||
assert_eq!(eve[0]["p.name"], "RustFS-Eve");
|
||||
|
||||
let run_only = query_main(
|
||||
&mut reopened_again,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "RunOnly")]),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.to_rust_json();
|
||||
assert_eq!(run_only[0]["p.name"], "RunOnly");
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn s3_branch_change_merge_flow_works() {
|
||||
let Some(uri) = s3_test_repo_uri("omnigraph-branching") else {
|
||||
eprintln!("skipping s3 branch test: OMNIGRAPH_S3_TEST_BUCKET is not set");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut main = Omnigraph::init(&uri, TEST_SCHEMA).await.unwrap();
|
||||
load_jsonl(&mut main, TEST_DATA, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
main.branch_create("feature").await.unwrap();
|
||||
|
||||
let mut feature = Omnigraph::open(&uri).await.unwrap();
|
||||
feature
|
||||
.mutate(
|
||||
"feature",
|
||||
MUTATION_QUERIES,
|
||||
"insert_person",
|
||||
&mixed_params(&[("$name", "Feature-Eve")], &[("$age", 22)]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let before_merge = query_main(
|
||||
&mut main,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Feature-Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
assert_eq!(before_merge.num_rows(), 0);
|
||||
|
||||
let outcome = main.branch_merge("feature", "main").await.unwrap();
|
||||
assert_eq!(outcome, MergeOutcome::FastForward);
|
||||
|
||||
let mut reopened = Omnigraph::open(&uri).await.unwrap();
|
||||
let after_merge = query_main(
|
||||
&mut reopened,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Feature-Eve")]),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.to_rust_json();
|
||||
assert_eq!(after_merge[0]["p.name"], "Feature-Eve");
|
||||
assert_eq!(
|
||||
reopened.branch_list().await.unwrap(),
|
||||
vec!["main".to_string(), "feature".to_string()]
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test(flavor = "multi_thread")]
|
||||
async fn s3_public_load_uses_hidden_run_and_publishes() {
|
||||
let Some(uri) = s3_test_repo_uri("omnigraph-public-load") else {
|
||||
eprintln!("skipping s3 public load test: OMNIGRAPH_S3_TEST_BUCKET is not set");
|
||||
return;
|
||||
};
|
||||
|
||||
let mut db = Omnigraph::init(&uri, TEST_SCHEMA).await.unwrap();
|
||||
load_jsonl(&mut db, TEST_DATA, LoadMode::Overwrite)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
db.load(
|
||||
"main",
|
||||
r#"{"type":"Person","data":{"name":"Loaded-Over-S3","age":34}}"#,
|
||||
LoadMode::Append,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
let runs = db.list_runs().await.unwrap();
|
||||
assert!(
|
||||
runs.iter().any(|record| {
|
||||
record.target_branch == "main" && record.status.as_str() == "published"
|
||||
}),
|
||||
"expected published transactional run in {:?}",
|
||||
runs
|
||||
);
|
||||
|
||||
let mut reopened = Omnigraph::open(&uri).await.unwrap();
|
||||
let loaded = query_main(
|
||||
&mut reopened,
|
||||
TEST_QUERIES,
|
||||
"get_person",
|
||||
¶ms(&[("$name", "Loaded-Over-S3")]),
|
||||
)
|
||||
.await
|
||||
.unwrap()
|
||||
.to_rust_json();
|
||||
assert_eq!(loaded[0]["p.name"], "Loaded-Over-S3");
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue