mirror of
https://github.com/katanemo/plano.git
synced 2026-06-26 15:39:40 +02:00
Merge branch 'main' into adil/agents_framework
This commit is contained in:
commit
660f8d433f
26 changed files with 2692 additions and 93 deletions
|
|
@ -363,6 +363,31 @@ properties:
|
||||||
model:
|
model:
|
||||||
type: string
|
type: string
|
||||||
additionalProperties: false
|
additionalProperties: false
|
||||||
|
state_storage:
|
||||||
|
type: object
|
||||||
|
properties:
|
||||||
|
type:
|
||||||
|
type: string
|
||||||
|
enum:
|
||||||
|
- memory
|
||||||
|
- postgres
|
||||||
|
connection_string:
|
||||||
|
type: string
|
||||||
|
description: Required when type is postgres. Supports environment variable substitution using $VAR or ${VAR} syntax.
|
||||||
|
additionalProperties: false
|
||||||
|
required:
|
||||||
|
- type
|
||||||
|
# Note: connection_string is conditionally required based on type
|
||||||
|
# If type is 'postgres', connection_string must be provided
|
||||||
|
# If type is 'memory', connection_string is not needed
|
||||||
|
allOf:
|
||||||
|
- if:
|
||||||
|
properties:
|
||||||
|
type:
|
||||||
|
const: postgres
|
||||||
|
then:
|
||||||
|
required:
|
||||||
|
- connection_string
|
||||||
prompt_guards:
|
prompt_guards:
|
||||||
type: object
|
type: object
|
||||||
properties:
|
properties:
|
||||||
|
|
|
||||||
|
|
@ -2,7 +2,7 @@
|
||||||
nodaemon=true
|
nodaemon=true
|
||||||
|
|
||||||
[program:brightstaff]
|
[program:brightstaff]
|
||||||
command=sh -c "RUST_LOG=info /app/brightstaff 2>&1 | tee /var/log/brightstaff.log | while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
|
command=sh -c "envsubst < /app/arch_config_rendered.yaml > /app/arch_config_rendered.env_sub.yaml && RUST_LOG=debug ARCH_CONFIG_PATH_RENDERED=/app/arch_config_rendered.env_sub.yaml /app/brightstaff 2>&1 | tee /var/log/brightstaff.log | while IFS= read -r line; do echo '[brightstaff]' \"$line\"; done"
|
||||||
stdout_logfile=/dev/stdout
|
stdout_logfile=/dev/stdout
|
||||||
redirect_stderr=true
|
redirect_stderr=true
|
||||||
stdout_logfile_maxbytes=0
|
stdout_logfile_maxbytes=0
|
||||||
|
|
|
||||||
|
|
@ -152,6 +152,24 @@ def get_llm_provider_access_keys(arch_config_file):
|
||||||
if access_key is not None:
|
if access_key is not None:
|
||||||
access_key_list.append(access_key)
|
access_key_list.append(access_key)
|
||||||
|
|
||||||
|
# Extract environment variables from state_storage.connection_string
|
||||||
|
state_storage = arch_config_yaml.get("state_storage_v1_responses")
|
||||||
|
if state_storage:
|
||||||
|
connection_string = state_storage.get("connection_string")
|
||||||
|
if connection_string and isinstance(connection_string, str):
|
||||||
|
# Extract all $VAR and ${VAR} patterns from connection string
|
||||||
|
import re
|
||||||
|
|
||||||
|
# Match both $VAR and ${VAR} patterns
|
||||||
|
pattern = r"\$\{?([A-Z_][A-Z0-9_]*)\}?"
|
||||||
|
matches = re.findall(pattern, connection_string)
|
||||||
|
for var in matches:
|
||||||
|
access_key_list.append(f"${var}")
|
||||||
|
else:
|
||||||
|
raise ValueError(
|
||||||
|
"Invalid connection string received in state_storage_v1_responses"
|
||||||
|
)
|
||||||
|
|
||||||
return access_key_list
|
return access_key_list
|
||||||
|
|
||||||
|
|
||||||
|
|
|
||||||
197
crates/Cargo.lock
generated
197
crates/Cargo.lock
generated
|
|
@ -308,11 +308,13 @@ name = "brightstaff"
|
||||||
version = "0.1.0"
|
version = "0.1.0"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"async-openai",
|
"async-openai",
|
||||||
|
"async-trait",
|
||||||
"bytes",
|
"bytes",
|
||||||
"chrono",
|
"chrono",
|
||||||
"common",
|
"common",
|
||||||
"eventsource-client",
|
"eventsource-client",
|
||||||
"eventsource-stream",
|
"eventsource-stream",
|
||||||
|
"flate2",
|
||||||
"futures",
|
"futures",
|
||||||
"futures-util",
|
"futures-util",
|
||||||
"hermesllm",
|
"hermesllm",
|
||||||
|
|
@ -336,6 +338,7 @@ dependencies = [
|
||||||
"thiserror 2.0.12",
|
"thiserror 2.0.12",
|
||||||
"time",
|
"time",
|
||||||
"tokio",
|
"tokio",
|
||||||
|
"tokio-postgres",
|
||||||
"tokio-stream",
|
"tokio-stream",
|
||||||
"tracing",
|
"tracing",
|
||||||
"tracing-opentelemetry",
|
"tracing-opentelemetry",
|
||||||
|
|
@ -360,6 +363,12 @@ version = "3.18.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee"
|
checksum = "793db76d6187cd04dff33004d8e6c9cc4e05cd330500379d2394209271b4aeee"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "byteorder"
|
||||||
|
version = "1.5.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "bytes"
|
name = "bytes"
|
||||||
version = "1.10.1"
|
version = "1.10.1"
|
||||||
|
|
@ -604,6 +613,7 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"block-buffer",
|
"block-buffer",
|
||||||
"crypto-common",
|
"crypto-common",
|
||||||
|
"subtle",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -691,6 +701,12 @@ dependencies = [
|
||||||
"pin-project-lite",
|
"pin-project-lite",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "fallible-iterator"
|
||||||
|
version = "0.2.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "4443176a9f2c162692bd3d352d745ef9413eec5782a80d8fd6f8a1ac692a07f7"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fancy-regex"
|
name = "fancy-regex"
|
||||||
version = "0.12.0"
|
version = "0.12.0"
|
||||||
|
|
@ -707,6 +723,16 @@ version = "2.3.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "flate2"
|
||||||
|
version = "1.1.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "bfe33edd8e85a12a67454e37f8c75e730830d83e313556ab9ebf9ee7fbeb3bfb"
|
||||||
|
dependencies = [
|
||||||
|
"crc32fast",
|
||||||
|
"miniz_oxide",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "fnv"
|
name = "fnv"
|
||||||
version = "1.0.7"
|
version = "1.0.7"
|
||||||
|
|
@ -986,6 +1012,15 @@ version = "0.4.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
checksum = "7f24254aa9a54b5c858eaee2f5bccdb46aaf0e486a595ed5fd8f86ba55232a70"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "hmac"
|
||||||
|
version = "0.12.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e"
|
||||||
|
dependencies = [
|
||||||
|
"digest",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "http"
|
name = "http"
|
||||||
version = "0.2.12"
|
version = "0.2.12"
|
||||||
|
|
@ -1420,6 +1455,17 @@ version = "0.2.172"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
|
checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "libredox"
|
||||||
|
version = "0.1.10"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "416f7e718bdb06000964960ffa43b4335ad4012ae8b99060261aa4a8088d5ccb"
|
||||||
|
dependencies = [
|
||||||
|
"bitflags",
|
||||||
|
"libc",
|
||||||
|
"redox_syscall",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "linux-raw-sys"
|
name = "linux-raw-sys"
|
||||||
version = "0.9.4"
|
version = "0.9.4"
|
||||||
|
|
@ -1492,6 +1538,16 @@ version = "0.7.3"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
checksum = "0e7465ac9959cc2b1404e8e2367b43684a6d13790fe23056cc8c6c5a6b7bcb94"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "md-5"
|
||||||
|
version = "0.10.6"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "d89e7ee0cfbedfc4da3340218492196241d89eefb6dab27de5df917a6d2e78cf"
|
||||||
|
dependencies = [
|
||||||
|
"cfg-if",
|
||||||
|
"digest",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "md5"
|
name = "md5"
|
||||||
version = "0.7.0"
|
version = "0.7.0"
|
||||||
|
|
@ -1533,6 +1589,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a"
|
checksum = "3be647b768db090acb35d5ec5db2b0e1f1de11133ca123b9eacf5137868f892a"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"adler2",
|
"adler2",
|
||||||
|
"simd-adler32",
|
||||||
]
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
|
|
@ -1836,6 +1893,24 @@ version = "2.3.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
|
checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "phf"
|
||||||
|
version = "0.11.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "1fd6780a80ae0c52cc120a26a1a42c1ae51b247a253e4e06113d23d2c2edd078"
|
||||||
|
dependencies = [
|
||||||
|
"phf_shared",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "phf_shared"
|
||||||
|
version = "0.11.3"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "67eabc2ef2a60eb7faa00097bd1ffdb5bd28e62bf39990626a582201b7a754e5"
|
||||||
|
dependencies = [
|
||||||
|
"siphasher",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "pin-project"
|
name = "pin-project"
|
||||||
version = "1.1.10"
|
version = "1.1.10"
|
||||||
|
|
@ -1880,6 +1955,37 @@ version = "1.11.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
|
checksum = "350e9b48cbc6b0e028b0473b114454c6316e57336ee184ceab6e53f72c178b3e"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "postgres-protocol"
|
||||||
|
version = "0.6.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "fbef655056b916eb868048276cfd5d6a7dea4f81560dfd047f97c8c6fe3fcfd4"
|
||||||
|
dependencies = [
|
||||||
|
"base64 0.22.1",
|
||||||
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
|
"fallible-iterator",
|
||||||
|
"hmac",
|
||||||
|
"md-5",
|
||||||
|
"memchr",
|
||||||
|
"rand 0.9.2",
|
||||||
|
"sha2",
|
||||||
|
"stringprep",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "postgres-types"
|
||||||
|
version = "0.2.9"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "613283563cd90e1dfc3518d548caee47e0e725455ed619881f5cf21f36de4b48"
|
||||||
|
dependencies = [
|
||||||
|
"bytes",
|
||||||
|
"fallible-iterator",
|
||||||
|
"postgres-protocol",
|
||||||
|
"serde",
|
||||||
|
"serde_json",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "potential_utf"
|
name = "potential_utf"
|
||||||
version = "0.1.2"
|
version = "0.1.2"
|
||||||
|
|
@ -2109,9 +2215,9 @@ dependencies = [
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "redox_syscall"
|
name = "redox_syscall"
|
||||||
version = "0.5.12"
|
version = "0.5.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "928fca9cf2aa042393a8325b9ead81d2f0df4cb12e1e24cef072922ccd99c5af"
|
checksum = "ed2bf2547551a7053d6fdfafda3f938979645c44812fbfcda098faae3f1a362d"
|
||||||
dependencies = [
|
dependencies = [
|
||||||
"bitflags",
|
"bitflags",
|
||||||
]
|
]
|
||||||
|
|
@ -2650,12 +2756,24 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "simd-adler32"
|
||||||
|
version = "0.3.8"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "e320a6c5ad31d271ad523dcf3ad13e2767ad8b1cb8f047f75a8aeaf8da139da2"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "similar"
|
name = "similar"
|
||||||
version = "2.7.0"
|
version = "2.7.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
|
checksum = "bbbb5d9659141646ae647b42fe094daf6c6192d1620870b449d9557f748b2daa"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "siphasher"
|
||||||
|
version = "1.0.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "56199f7ddabf13fe5074ce809e7d3f42b42ae711800501b5b16ea82ad029c39d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "slab"
|
name = "slab"
|
||||||
version = "0.4.9"
|
version = "0.4.9"
|
||||||
|
|
@ -2696,6 +2814,17 @@ version = "1.2.0"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "stringprep"
|
||||||
|
version = "0.1.5"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7b4df3d392d81bd458a8a621b8bffbd2302a12ffe288a9d931670948749463b1"
|
||||||
|
dependencies = [
|
||||||
|
"unicode-bidi",
|
||||||
|
"unicode-normalization",
|
||||||
|
"unicode-properties",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "strsim"
|
name = "strsim"
|
||||||
version = "0.11.1"
|
version = "0.11.1"
|
||||||
|
|
@ -2954,6 +3083,32 @@ dependencies = [
|
||||||
"tokio",
|
"tokio",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "tokio-postgres"
|
||||||
|
version = "0.7.13"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "6c95d533c83082bb6490e0189acaa0bbeef9084e60471b696ca6988cd0541fb0"
|
||||||
|
dependencies = [
|
||||||
|
"async-trait",
|
||||||
|
"byteorder",
|
||||||
|
"bytes",
|
||||||
|
"fallible-iterator",
|
||||||
|
"futures-channel",
|
||||||
|
"futures-util",
|
||||||
|
"log",
|
||||||
|
"parking_lot",
|
||||||
|
"percent-encoding",
|
||||||
|
"phf",
|
||||||
|
"pin-project-lite",
|
||||||
|
"postgres-protocol",
|
||||||
|
"postgres-types",
|
||||||
|
"rand 0.9.2",
|
||||||
|
"socket2",
|
||||||
|
"tokio",
|
||||||
|
"tokio-util",
|
||||||
|
"whoami",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "tokio-rustls"
|
name = "tokio-rustls"
|
||||||
version = "0.24.1"
|
version = "0.24.1"
|
||||||
|
|
@ -3189,12 +3344,33 @@ version = "2.8.1"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
|
checksum = "75b844d17643ee918803943289730bec8aac480150456169e647ed0b576ba539"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-bidi"
|
||||||
|
version = "0.3.18"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5c1cb5db39152898a79168971543b1cb5020dff7fe43c8dc468b0885f5e29df5"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unicode-ident"
|
name = "unicode-ident"
|
||||||
version = "1.0.18"
|
version = "1.0.18"
|
||||||
source = "registry+https://github.com/rust-lang/crates.io-index"
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512"
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-normalization"
|
||||||
|
version = "0.1.25"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5fd4f6878c9cb28d874b009da9e8d183b5abc80117c40bbd187a1fde336be6e8"
|
||||||
|
dependencies = [
|
||||||
|
"tinyvec",
|
||||||
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "unicode-properties"
|
||||||
|
version = "0.1.4"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "7df058c713841ad818f1dc5d3fd88063241cc61f49f5fbea4b951e8cf5a8d71d"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "unsafe-libyaml"
|
name = "unsafe-libyaml"
|
||||||
version = "0.2.11"
|
version = "0.2.11"
|
||||||
|
|
@ -3290,6 +3466,12 @@ dependencies = [
|
||||||
"wit-bindgen-rt",
|
"wit-bindgen-rt",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "wasite"
|
||||||
|
version = "0.1.0"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "b8dad83b4f25e74f184f64c43b150b91efe7647395b42289f38e50566d82855b"
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "wasm-bindgen"
|
name = "wasm-bindgen"
|
||||||
version = "0.2.100"
|
version = "0.2.100"
|
||||||
|
|
@ -3394,6 +3576,17 @@ dependencies = [
|
||||||
"wasm-bindgen",
|
"wasm-bindgen",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
[[package]]
|
||||||
|
name = "whoami"
|
||||||
|
version = "1.6.1"
|
||||||
|
source = "registry+https://github.com/rust-lang/crates.io-index"
|
||||||
|
checksum = "5d4a4db5077702ca3015d3d02d74974948aba2ad9e12ab7df718ee64ccd7e97d"
|
||||||
|
dependencies = [
|
||||||
|
"libredox",
|
||||||
|
"wasite",
|
||||||
|
"web-sys",
|
||||||
|
]
|
||||||
|
|
||||||
[[package]]
|
[[package]]
|
||||||
name = "winapi"
|
name = "winapi"
|
||||||
version = "0.3.9"
|
version = "0.3.9"
|
||||||
|
|
|
||||||
|
|
@ -5,11 +5,13 @@ edition = "2021"
|
||||||
|
|
||||||
[dependencies]
|
[dependencies]
|
||||||
async-openai = "0.30.1"
|
async-openai = "0.30.1"
|
||||||
|
async-trait = "0.1"
|
||||||
bytes = "1.10.1"
|
bytes = "1.10.1"
|
||||||
chrono = "0.4"
|
chrono = "0.4"
|
||||||
common = { version = "0.1.0", path = "../common", features = ["trace-collection"] }
|
common = { version = "0.1.0", path = "../common", features = ["trace-collection"] }
|
||||||
eventsource-client = "0.15.0"
|
eventsource-client = "0.15.0"
|
||||||
eventsource-stream = "0.2.3"
|
eventsource-stream = "0.2.3"
|
||||||
|
flate2 = "1.0"
|
||||||
futures = "0.3.31"
|
futures = "0.3.31"
|
||||||
futures-util = "0.3.31"
|
futures-util = "0.3.31"
|
||||||
hermesllm = { version = "0.1.0", path = "../hermesllm" }
|
hermesllm = { version = "0.1.0", path = "../hermesllm" }
|
||||||
|
|
@ -31,6 +33,7 @@ serde_with = "3.13.0"
|
||||||
serde_yaml = "0.9.34"
|
serde_yaml = "0.9.34"
|
||||||
thiserror = "2.0.12"
|
thiserror = "2.0.12"
|
||||||
tokio = { version = "1.44.2", features = ["full"] }
|
tokio = { version = "1.44.2", features = ["full"] }
|
||||||
|
tokio-postgres = { version = "0.7", features = ["with-serde_json-1"] }
|
||||||
tokio-stream = "0.1"
|
tokio-stream = "0.1"
|
||||||
time = { version = "0.3", features = ["formatting", "macros"] }
|
time = { version = "0.3", features = ["formatting", "macros"] }
|
||||||
tracing = "0.1"
|
tracing = "0.1"
|
||||||
|
|
|
||||||
|
|
@ -1,8 +1,9 @@
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::{LlmProvider, ModelAlias};
|
use common::configuration::{LlmProvider, ModelAlias};
|
||||||
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER};
|
use common::consts::{ARCH_IS_STREAMING_HEADER, ARCH_PROVIDER_HINT_HEADER, REQUEST_ID_HEADER, TRACE_PARENT_HEADER};
|
||||||
use common::traces::TraceCollector;
|
use common::traces::TraceCollector;
|
||||||
use hermesllm::clients::SupportedAPIsFromClient;
|
use hermesllm::apis::openai_responses::InputParam;
|
||||||
|
use hermesllm::clients::{SupportedAPIsFromClient, SupportedUpstreamAPIs};
|
||||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||||
use http_body_util::combinators::BoxBody;
|
use http_body_util::combinators::BoxBody;
|
||||||
use http_body_util::{BodyExt, Full};
|
use http_body_util::{BodyExt, Full};
|
||||||
|
|
@ -11,11 +12,16 @@ use hyper::{Request, Response, StatusCode};
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
use tokio::sync::RwLock;
|
use tokio::sync::RwLock;
|
||||||
use tracing::{debug, warn};
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
use crate::router::llm_router::RouterService;
|
use crate::router::llm_router::RouterService;
|
||||||
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
|
use crate::handlers::utils::{create_streaming_response, ObservableStreamProcessor, truncate_message};
|
||||||
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
use crate::handlers::router_chat::router_chat_get_upstream_model;
|
||||||
|
use crate::state::response_state_processor::ResponsesStateProcessor;
|
||||||
|
use crate::state::{
|
||||||
|
StateStorage, StateStorageError,
|
||||||
|
extract_input_items, retrieve_and_combine_input
|
||||||
|
};
|
||||||
use crate::tracing::operation_component;
|
use crate::tracing::operation_component;
|
||||||
|
|
||||||
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
fn full<T: Into<Bytes>>(chunk: T) -> BoxBody<Bytes, hyper::Error> {
|
||||||
|
|
@ -31,14 +37,20 @@ pub async fn llm_chat(
|
||||||
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
model_aliases: Arc<Option<HashMap<String, ModelAlias>>>,
|
||||||
llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
|
llm_providers: Arc<RwLock<Vec<LlmProvider>>>,
|
||||||
trace_collector: Arc<TraceCollector>,
|
trace_collector: Arc<TraceCollector>,
|
||||||
|
state_storage: Option<Arc<dyn StateStorage>>,
|
||||||
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, hyper::Error> {
|
||||||
|
|
||||||
let request_path = request.uri().path().to_string();
|
let request_path = request.uri().path().to_string();
|
||||||
let request_headers = request.headers().clone();
|
let request_headers = request.headers().clone();
|
||||||
|
let request_id = request_headers
|
||||||
|
.get(REQUEST_ID_HEADER)
|
||||||
|
.and_then(|h| h.to_str().ok())
|
||||||
|
.map(|s| s.to_string())
|
||||||
|
.unwrap_or_else(|| "unknown".to_string());
|
||||||
|
|
||||||
// Extract or generate traceparent - this establishes the trace context for all spans
|
// Extract or generate traceparent - this establishes the trace context for all spans
|
||||||
let traceparent: String = request_headers
|
let traceparent: String = request_headers
|
||||||
.get("traceparent")
|
.get(TRACE_PARENT_HEADER)
|
||||||
.and_then(|h| h.to_str().ok())
|
.and_then(|h| h.to_str().ok())
|
||||||
.map(|s| s.to_string())
|
.map(|s| s.to_string())
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
|
|
@ -51,7 +63,8 @@ pub async fn llm_chat(
|
||||||
let chat_request_bytes = request.collect().await?.to_bytes();
|
let chat_request_bytes = request.collect().await?.to_bytes();
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"Received request body (raw utf8): {}",
|
"[PLANO_REQ_ID:{}] | REQUEST_BODY (UTF8): {}",
|
||||||
|
request_id,
|
||||||
String::from_utf8_lossy(&chat_request_bytes)
|
String::from_utf8_lossy(&chat_request_bytes)
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -61,14 +74,19 @@ pub async fn llm_chat(
|
||||||
)) {
|
)) {
|
||||||
Ok(request) => request,
|
Ok(request) => request,
|
||||||
Err(err) => {
|
Err(err) => {
|
||||||
warn!("Failed to parse request as ProviderRequestType: {}", err);
|
warn!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request as ProviderRequestType: {}", request_id, err);
|
||||||
let err_msg = format!("Failed to parse request: {}", err);
|
let err_msg = format!("[PLANO_REQ_ID:{}] | FAILURE | Failed to parse request: {}", request_id, err);
|
||||||
let mut bad_request = Response::new(full(err_msg));
|
let mut bad_request = Response::new(full(err_msg));
|
||||||
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
*bad_request.status_mut() = StatusCode::BAD_REQUEST;
|
||||||
return Ok(bad_request);
|
return Ok(bad_request);
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
||||||
|
// === v1/responses state management: Extract input items early ===
|
||||||
|
let mut original_input_items = Vec::new();
|
||||||
|
let client_api = SupportedAPIsFromClient::from_endpoint(request_path.as_str());
|
||||||
|
let is_responses_api_client = matches!(client_api, Some(SupportedAPIsFromClient::OpenAIResponsesAPI(_)));
|
||||||
|
|
||||||
// Model alias resolution: update model field in client_request immediately
|
// Model alias resolution: update model field in client_request immediately
|
||||||
// This ensures all downstream objects use the resolved model
|
// This ensures all downstream objects use the resolved model
|
||||||
let model_from_request = client_request.model().to_string();
|
let model_from_request = client_request.model().to_string();
|
||||||
|
|
@ -83,9 +101,77 @@ pub async fn llm_chat(
|
||||||
|
|
||||||
client_request.set_model(resolved_model.clone());
|
client_request.set_model(resolved_model.clone());
|
||||||
if client_request.remove_metadata_key("archgw_preference_config") {
|
if client_request.remove_metadata_key("archgw_preference_config") {
|
||||||
debug!("Removed archgw_preference_config from metadata");
|
debug!("[PLANO_REQ_ID:{}] Removed archgw_preference_config from metadata", request_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// === v1/responses state management: Determine upstream API and combine input if needed ===
|
||||||
|
// Do this BEFORE routing since routing consumes the request
|
||||||
|
// Only process state if state_storage is configured
|
||||||
|
let mut should_manage_state = false;
|
||||||
|
if is_responses_api_client && state_storage.is_some() {
|
||||||
|
if let ProviderRequestType::ResponsesAPIRequest(ref mut responses_req) = client_request {
|
||||||
|
// Extract original input once
|
||||||
|
original_input_items = extract_input_items(&responses_req.input);
|
||||||
|
|
||||||
|
// Get the upstream path and check if it's ResponsesAPI
|
||||||
|
let upstream_path = get_upstream_path(
|
||||||
|
&llm_providers,
|
||||||
|
&resolved_model,
|
||||||
|
&request_path,
|
||||||
|
&resolved_model,
|
||||||
|
is_streaming_request,
|
||||||
|
).await;
|
||||||
|
|
||||||
|
let upstream_api = SupportedUpstreamAPIs::from_endpoint(&upstream_path);
|
||||||
|
|
||||||
|
// Only manage state if upstream is NOT OpenAIResponsesAPI (needs translation)
|
||||||
|
should_manage_state = !matches!(upstream_api, Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(_)));
|
||||||
|
|
||||||
|
if should_manage_state {
|
||||||
|
// Retrieve and combine conversation history if previous_response_id exists
|
||||||
|
if let Some(ref prev_resp_id) = responses_req.previous_response_id {
|
||||||
|
match retrieve_and_combine_input(
|
||||||
|
state_storage.as_ref().unwrap().clone(),
|
||||||
|
prev_resp_id,
|
||||||
|
original_input_items, // Pass ownership instead of cloning
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
{
|
||||||
|
Ok(combined_input) => {
|
||||||
|
// Update both the request and original_input_items
|
||||||
|
responses_req.input = InputParam::Items(combined_input.clone());
|
||||||
|
original_input_items = combined_input;
|
||||||
|
info!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Updated request with conversation history ({} items)", request_id, original_input_items.len());
|
||||||
|
}
|
||||||
|
Err(StateStorageError::NotFound(_)) => {
|
||||||
|
// Return 409 Conflict when previous_response_id not found
|
||||||
|
warn!("[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Previous response_id not found: {}", request_id, prev_resp_id);
|
||||||
|
let err_msg = format!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Conversation state not found for previous_response_id: {}",
|
||||||
|
request_id, prev_resp_id
|
||||||
|
);
|
||||||
|
let mut conflict_response = Response::new(full(err_msg));
|
||||||
|
*conflict_response.status_mut() = StatusCode::CONFLICT;
|
||||||
|
return Ok(conflict_response);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Log warning but continue on other storage errors
|
||||||
|
warn!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to retrieve conversation state for {}: {}",
|
||||||
|
request_id, prev_resp_id, e
|
||||||
|
);
|
||||||
|
// Restore original_input_items since we passed ownership
|
||||||
|
original_input_items = extract_input_items(&responses_req.input);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
debug!("[PLANO_REQ_ID:{}] | BRIGHT_STAFF | Upstream supports ResponsesAPI natively.", request_id);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Serialize request for upstream BEFORE router consumes it
|
||||||
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
let client_request_bytes_for_upstream = ProviderRequestType::to_bytes(&client_request).unwrap();
|
||||||
|
|
||||||
// Determine routing using the dedicated router_chat module
|
// Determine routing using the dedicated router_chat module
|
||||||
|
|
@ -110,8 +196,8 @@ pub async fn llm_chat(
|
||||||
let model_name = routing_result.model_name;
|
let model_name = routing_result.model_name;
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"[ARCH_ROUTER] URL: {}, Resolved Model: {}",
|
"[PLANO_REQ_ID:{}] | ARCH_ROUTER URL | {}, Resolved Model: {}",
|
||||||
full_qualified_llm_provider_url, model_name
|
request_id, full_qualified_llm_provider_url, model_name
|
||||||
);
|
);
|
||||||
|
|
||||||
request_headers.insert(
|
request_headers.insert(
|
||||||
|
|
@ -173,15 +259,40 @@ pub async fn llm_chat(
|
||||||
&llm_providers,
|
&llm_providers,
|
||||||
).await;
|
).await;
|
||||||
|
|
||||||
// Use PassthroughProcessor to track streaming metrics and finalize the span
|
// Create base processor for metrics and tracing
|
||||||
let processor = ObservableStreamProcessor::new(
|
let base_processor = ObservableStreamProcessor::new(
|
||||||
trace_collector,
|
trace_collector,
|
||||||
operation_component::LLM,
|
operation_component::LLM,
|
||||||
llm_span,
|
llm_span,
|
||||||
request_start_time,
|
request_start_time,
|
||||||
);
|
);
|
||||||
|
|
||||||
let streaming_response = create_streaming_response(byte_stream, processor, 16);
|
// === v1/responses state management: Wrap with ResponsesStateProcessor ===
|
||||||
|
// Only wrap if we need to manage state (client is ResponsesAPI AND upstream is NOT ResponsesAPI AND state_storage is configured)
|
||||||
|
let streaming_response = if should_manage_state && !original_input_items.is_empty() && state_storage.is_some() {
|
||||||
|
// Extract Content-Encoding header to handle decompression for state parsing
|
||||||
|
let content_encoding = response_headers
|
||||||
|
.get("content-encoding")
|
||||||
|
.and_then(|v| v.to_str().ok())
|
||||||
|
.map(|s| s.to_string());
|
||||||
|
|
||||||
|
// Wrap with state management processor to store state after response completes
|
||||||
|
let state_processor = ResponsesStateProcessor::new(
|
||||||
|
base_processor,
|
||||||
|
state_storage.unwrap(),
|
||||||
|
original_input_items,
|
||||||
|
resolved_model.clone(),
|
||||||
|
model_name.clone(),
|
||||||
|
is_streaming_request,
|
||||||
|
false, // Not OpenAI upstream since should_manage_state is true
|
||||||
|
content_encoding,
|
||||||
|
request_id.clone(),
|
||||||
|
);
|
||||||
|
create_streaming_response(byte_stream, state_processor, 16)
|
||||||
|
} else {
|
||||||
|
// Use base processor without state management
|
||||||
|
create_streaming_response(byte_stream, base_processor, 16)
|
||||||
|
};
|
||||||
|
|
||||||
match response.body(streaming_response.body) {
|
match response.body(streaming_response.body) {
|
||||||
Ok(response) => Ok(response),
|
Ok(response) => Ok(response),
|
||||||
|
|
@ -301,35 +412,7 @@ async fn get_upstream_path(
|
||||||
resolved_model: &str,
|
resolved_model: &str,
|
||||||
is_streaming: bool,
|
is_streaming: bool,
|
||||||
) -> String {
|
) -> String {
|
||||||
let providers_lock = llm_providers.read().await;
|
let (provider_id, base_url_path_prefix) = get_provider_info(llm_providers, model_name).await;
|
||||||
|
|
||||||
// First, try to find by model name or provider name
|
|
||||||
let provider = providers_lock.iter().find(|p| {
|
|
||||||
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|
|
||||||
|| p.name == model_name
|
|
||||||
});
|
|
||||||
|
|
||||||
let (provider_id, base_url_path_prefix) = if let Some(provider) = provider {
|
|
||||||
let provider_id = provider.provider_interface.to_provider_id();
|
|
||||||
let prefix = provider.base_url_path_prefix.clone();
|
|
||||||
(provider_id, prefix)
|
|
||||||
} else {
|
|
||||||
let default_provider = providers_lock.iter().find(|p| {
|
|
||||||
p.default.unwrap_or(false)
|
|
||||||
});
|
|
||||||
|
|
||||||
if let Some(provider) = default_provider {
|
|
||||||
let provider_id = provider.provider_interface.to_provider_id();
|
|
||||||
let prefix = provider.base_url_path_prefix.clone();
|
|
||||||
(provider_id, prefix)
|
|
||||||
} else {
|
|
||||||
// Last resort: use OpenAI as hardcoded fallback
|
|
||||||
warn!("No default provider found, falling back to OpenAI");
|
|
||||||
(hermesllm::ProviderId::OpenAI, None)
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
drop(providers_lock);
|
|
||||||
|
|
||||||
// Calculate the upstream path using the proper API
|
// Calculate the upstream path using the proper API
|
||||||
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
|
let client_api = SupportedAPIsFromClient::from_endpoint(request_path)
|
||||||
|
|
@ -343,3 +426,37 @@ async fn get_upstream_path(
|
||||||
base_url_path_prefix.as_deref(),
|
base_url_path_prefix.as_deref(),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/// Helper function to get provider info (ProviderId and base_url_path_prefix)
|
||||||
|
async fn get_provider_info(
|
||||||
|
llm_providers: &Arc<RwLock<Vec<LlmProvider>>>,
|
||||||
|
model_name: &str,
|
||||||
|
) -> (hermesllm::ProviderId, Option<String>) {
|
||||||
|
let providers_lock = llm_providers.read().await;
|
||||||
|
|
||||||
|
// First, try to find by model name or provider name
|
||||||
|
let provider = providers_lock.iter().find(|p| {
|
||||||
|
p.model.as_ref().map(|m| m == model_name).unwrap_or(false)
|
||||||
|
|| p.name == model_name
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(provider) = provider {
|
||||||
|
let provider_id = provider.provider_interface.to_provider_id();
|
||||||
|
let prefix = provider.base_url_path_prefix.clone();
|
||||||
|
return (provider_id, prefix);
|
||||||
|
}
|
||||||
|
|
||||||
|
let default_provider = providers_lock.iter().find(|p| {
|
||||||
|
p.default.unwrap_or(false)
|
||||||
|
});
|
||||||
|
|
||||||
|
if let Some(provider) = default_provider {
|
||||||
|
let provider_id = provider.provider_interface.to_provider_id();
|
||||||
|
let prefix = provider.base_url_path_prefix.clone();
|
||||||
|
(provider_id, prefix)
|
||||||
|
} else {
|
||||||
|
// Last resort: use OpenAI as hardcoded fallback
|
||||||
|
warn!("No default provider found, falling back to OpenAI");
|
||||||
|
(hermesllm::ProviderId::OpenAI, None)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
use common::configuration::ModelUsagePreference;
|
use common::configuration::ModelUsagePreference;
|
||||||
|
use common::consts::{REQUEST_ID_HEADER};
|
||||||
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
|
use common::traces::{TraceCollector, SpanKind, SpanBuilder, parse_traceparent};
|
||||||
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
use hermesllm::clients::endpoints::SupportedUpstreamAPIs;
|
||||||
use hermesllm::{ProviderRequest, ProviderRequestType};
|
use hermesllm::{ProviderRequest, ProviderRequestType};
|
||||||
|
|
@ -43,6 +44,10 @@ pub async fn router_chat_get_upstream_model(
|
||||||
) -> Result<RoutingResult, RoutingError> {
|
) -> Result<RoutingResult, RoutingError> {
|
||||||
// Clone metadata for routing before converting (which consumes client_request)
|
// Clone metadata for routing before converting (which consumes client_request)
|
||||||
let routing_metadata = client_request.metadata().clone();
|
let routing_metadata = client_request.metadata().clone();
|
||||||
|
let request_id = request_headers
|
||||||
|
.get(REQUEST_ID_HEADER)
|
||||||
|
.and_then(|value| value.to_str().ok())
|
||||||
|
.unwrap_or("unknown");
|
||||||
|
|
||||||
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
// Convert to ChatCompletionsRequest for routing (regardless of input type)
|
||||||
let chat_request = match ProviderRequestType::try_from((
|
let chat_request = match ProviderRequestType::try_from((
|
||||||
|
|
@ -73,7 +78,8 @@ pub async fn router_chat_get_upstream_model(
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!(
|
debug!(
|
||||||
"[ARCH_ROUTER REQ]: {}",
|
"[PLANO_REQ_ID: {}]: ROUTER_REQ: {}",
|
||||||
|
request_id,
|
||||||
&serde_json::to_string(&chat_request).unwrap()
|
&serde_json::to_string(&chat_request).unwrap()
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
@ -114,14 +120,13 @@ pub async fn router_chat_get_upstream_model(
|
||||||
};
|
};
|
||||||
|
|
||||||
info!(
|
info!(
|
||||||
"request received, request type: chat_completion, usage preferences from request: {}, request path: {}, latest message: {}",
|
"[PLANO_REQ_ID: {}] | ROUTER_REQ | Usage preferences from request: {}, request_path: {}, latest message: {}",
|
||||||
|
request_id,
|
||||||
usage_preferences.is_some(),
|
usage_preferences.is_some(),
|
||||||
request_path,
|
request_path,
|
||||||
latest_message_for_log
|
latest_message_for_log
|
||||||
);
|
);
|
||||||
|
|
||||||
debug!("usage preferences from request: {:?}", usage_preferences);
|
|
||||||
|
|
||||||
// Capture start time for routing span
|
// Capture start time for routing span
|
||||||
let routing_start_time = std::time::Instant::now();
|
let routing_start_time = std::time::Instant::now();
|
||||||
let routing_start_system_time = std::time::SystemTime::now();
|
let routing_start_system_time = std::time::SystemTime::now();
|
||||||
|
|
@ -153,7 +158,8 @@ pub async fn router_chat_get_upstream_model(
|
||||||
None => {
|
None => {
|
||||||
// No route determined, use default model from request
|
// No route determined, use default model from request
|
||||||
info!(
|
info!(
|
||||||
"No route determined, using default model from request: {}",
|
"[PLANO_REQ_ID: {}] | ROUTER_REQ | No route determined, using default model from request: {}",
|
||||||
|
request_id,
|
||||||
chat_request.model
|
chat_request.model
|
||||||
);
|
);
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -1,4 +1,5 @@
|
||||||
pub mod handlers;
|
pub mod handlers;
|
||||||
pub mod router;
|
pub mod router;
|
||||||
|
pub mod state;
|
||||||
pub mod tracing;
|
pub mod tracing;
|
||||||
pub mod utils;
|
pub mod utils;
|
||||||
|
|
|
||||||
|
|
@ -3,6 +3,9 @@ use brightstaff::handlers::function_calling::function_calling_chat_handler;
|
||||||
use brightstaff::handlers::llm::llm_chat;
|
use brightstaff::handlers::llm::llm_chat;
|
||||||
use brightstaff::handlers::models::list_models;
|
use brightstaff::handlers::models::list_models;
|
||||||
use brightstaff::router::llm_router::RouterService;
|
use brightstaff::router::llm_router::RouterService;
|
||||||
|
use brightstaff::state::StateStorage;
|
||||||
|
use brightstaff::state::postgresql::PostgreSQLConversationStorage;
|
||||||
|
use brightstaff::state::memory::MemoryConversationalStorage;
|
||||||
use brightstaff::utils::tracing::init_tracer;
|
use brightstaff::utils::tracing::init_tracer;
|
||||||
use bytes::Bytes;
|
use bytes::Bytes;
|
||||||
use common::configuration::{Agent, Configuration};
|
use common::configuration::{Agent, Configuration};
|
||||||
|
|
@ -113,6 +116,38 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let trace_collector = Arc::new(TraceCollector::new(tracing_enabled));
|
let trace_collector = Arc::new(TraceCollector::new(tracing_enabled));
|
||||||
let _flusher_handle = trace_collector.clone().start_background_flusher();
|
let _flusher_handle = trace_collector.clone().start_background_flusher();
|
||||||
|
|
||||||
|
// Initialize conversation state storage for v1/responses
|
||||||
|
// Configurable via arch_config.yaml state_storage section
|
||||||
|
// If not configured, state management is disabled
|
||||||
|
// Environment variables are substituted by envsubst before config is read
|
||||||
|
let state_storage: Option<Arc<dyn StateStorage>> = if let Some(storage_config) = &arch_config.state_storage {
|
||||||
|
let storage: Arc<dyn StateStorage> = match storage_config.storage_type {
|
||||||
|
common::configuration::StateStorageType::Memory => {
|
||||||
|
info!("Initialized conversation state storage: Memory");
|
||||||
|
Arc::new(MemoryConversationalStorage::new())
|
||||||
|
}
|
||||||
|
common::configuration::StateStorageType::Postgres => {
|
||||||
|
let connection_string = storage_config
|
||||||
|
.connection_string
|
||||||
|
.as_ref()
|
||||||
|
.expect("connection_string is required for postgres state_storage");
|
||||||
|
|
||||||
|
debug!("Postgres connection string (full): {}", connection_string);
|
||||||
|
info!("Initializing conversation state storage: Postgres");
|
||||||
|
Arc::new(
|
||||||
|
PostgreSQLConversationStorage::new(connection_string.clone())
|
||||||
|
.await
|
||||||
|
.expect("Failed to initialize Postgres state storage"),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
};
|
||||||
|
Some(storage)
|
||||||
|
} else {
|
||||||
|
info!("No state_storage configured - conversation state management disabled");
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
|
|
||||||
loop {
|
loop {
|
||||||
let (stream, _) = listener.accept().await?;
|
let (stream, _) = listener.accept().await?;
|
||||||
let peer_addr = stream.peer_addr()?;
|
let peer_addr = stream.peer_addr()?;
|
||||||
|
|
@ -128,6 +163,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let agents_list = combined_agents_filters_list.clone();
|
let agents_list = combined_agents_filters_list.clone();
|
||||||
let listeners = listeners.clone();
|
let listeners = listeners.clone();
|
||||||
let trace_collector = trace_collector.clone();
|
let trace_collector = trace_collector.clone();
|
||||||
|
let state_storage = state_storage.clone();
|
||||||
let service = service_fn(move |req| {
|
let service = service_fn(move |req| {
|
||||||
let router_service = Arc::clone(&router_service);
|
let router_service = Arc::clone(&router_service);
|
||||||
let parent_cx = extract_context_from_request(&req);
|
let parent_cx = extract_context_from_request(&req);
|
||||||
|
|
@ -137,9 +173,9 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
let agents_list = agents_list.clone();
|
let agents_list = agents_list.clone();
|
||||||
let listeners = listeners.clone();
|
let listeners = listeners.clone();
|
||||||
let trace_collector = trace_collector.clone();
|
let trace_collector = trace_collector.clone();
|
||||||
|
let state_storage = state_storage.clone();
|
||||||
|
|
||||||
async move {
|
async move {
|
||||||
let path = req.uri().path();
|
|
||||||
|
|
||||||
// Check if path starts with /agents
|
// Check if path starts with /agents
|
||||||
if path.starts_with("/agents") {
|
if path.starts_with("/agents") {
|
||||||
|
|
@ -162,26 +198,14 @@ async fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
|
||||||
.await;
|
.await;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
match (req.method(), req.uri().path()) {
|
||||||
match (req.method(), path) {
|
(&Method::POST, CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH) => {
|
||||||
(
|
|
||||||
&Method::POST,
|
|
||||||
CHAT_COMPLETIONS_PATH | MESSAGES_PATH | OPENAI_RESPONSES_API_PATH,
|
|
||||||
) => {
|
|
||||||
let fully_qualified_url =
|
let fully_qualified_url =
|
||||||
format!("{}{}", llm_provider_url, req.uri().path());
|
format!("{}{}", llm_provider_url, req.uri().path());
|
||||||
llm_chat(
|
llm_chat(req, router_service, fully_qualified_url, model_aliases, llm_providers, trace_collector)
|
||||||
req,
|
.with_context(parent_cx)
|
||||||
router_service,
|
.await
|
||||||
fully_qualified_url,
|
|
||||||
model_aliases,
|
|
||||||
llm_providers,
|
|
||||||
trace_collector,
|
|
||||||
)
|
|
||||||
.with_context(parent_cx)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
||||||
(&Method::POST, "/function_calling") => {
|
(&Method::POST, "/function_calling") => {
|
||||||
let fully_qualified_url =
|
let fully_qualified_url =
|
||||||
format!("{}{}", llm_provider_url, "/v1/chat/completions");
|
format!("{}{}", llm_provider_url, "/v1/chat/completions");
|
||||||
|
|
|
||||||
611
crates/brightstaff/src/state/memory.rs
Normal file
611
crates/brightstaff/src/state/memory.rs
Normal file
|
|
@ -0,0 +1,611 @@
|
||||||
|
use super::{OpenAIConversationState, StateStorage, StateStorageError};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use std::collections::HashMap;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::RwLock;
|
||||||
|
use tracing::{debug, warn};
|
||||||
|
|
||||||
|
/// In-memory storage backend for conversation state
|
||||||
|
/// Uses a HashMap wrapped in Arc<RwLock<>> for thread-safe access
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct MemoryConversationalStorage {
|
||||||
|
storage: Arc<RwLock<HashMap<String, OpenAIConversationState>>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl MemoryConversationalStorage {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
Self {
|
||||||
|
storage: Arc::new(RwLock::new(HashMap::new())),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for MemoryConversationalStorage {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self::new()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StateStorage for MemoryConversationalStorage {
|
||||||
|
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
|
||||||
|
let response_id = state.response_id.clone();
|
||||||
|
let mut storage = self.storage.write().await;
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Storing conversation state: model={}, provider={}, input_items={}",
|
||||||
|
response_id, state.model, state.provider, state.input_items.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
storage.insert(response_id, state);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
|
||||||
|
let storage = self.storage.read().await;
|
||||||
|
|
||||||
|
match storage.get(response_id) {
|
||||||
|
Some(state) => {
|
||||||
|
debug!(
|
||||||
|
"[PLANO | MEMORY_STORAGE | RESP_ID:{} | Retrieved conversation state: input_items={}",
|
||||||
|
response_id, state.input_items.len()
|
||||||
|
);
|
||||||
|
Ok(state.clone())
|
||||||
|
}
|
||||||
|
None => {
|
||||||
|
warn!(
|
||||||
|
"[PLANO_RESP_ID:{} | MEMORY_STORAGE | Conversation state not found",
|
||||||
|
response_id
|
||||||
|
);
|
||||||
|
Err(StateStorageError::NotFound(response_id.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
|
||||||
|
let storage = self.storage.read().await;
|
||||||
|
Ok(storage.contains_key(response_id))
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
|
||||||
|
let mut storage = self.storage.write().await;
|
||||||
|
|
||||||
|
if storage.remove(response_id).is_some() {
|
||||||
|
debug!(
|
||||||
|
"[PLANO | BRIGHTSTAFF | MEMORY_STORAGE] RESP_ID:{} | Deleted conversation state",
|
||||||
|
response_id
|
||||||
|
);
|
||||||
|
Ok(())
|
||||||
|
} else {
|
||||||
|
Err(StateStorageError::NotFound(response_id.to_string()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use hermesllm::apis::openai_responses::{InputItem, InputMessage, MessageRole, InputContent, MessageContent};
|
||||||
|
|
||||||
|
fn create_test_state(response_id: &str, num_messages: usize) -> OpenAIConversationState {
|
||||||
|
let mut input_items = Vec::new();
|
||||||
|
for i in 0..num_messages {
|
||||||
|
input_items.push(InputItem::Message(InputMessage {
|
||||||
|
role: if i % 2 == 0 { MessageRole::User } else { MessageRole::Assistant },
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: format!("Message {}", i),
|
||||||
|
}]),
|
||||||
|
}));
|
||||||
|
}
|
||||||
|
|
||||||
|
OpenAIConversationState {
|
||||||
|
response_id: response_id.to_string(),
|
||||||
|
input_items,
|
||||||
|
created_at: 1234567890,
|
||||||
|
model: "claude-3".to_string(),
|
||||||
|
provider: "anthropic".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_put_and_get_success() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
let state: OpenAIConversationState = create_test_state("resp_001", 3);
|
||||||
|
|
||||||
|
// Store
|
||||||
|
storage.put(state.clone()).await.unwrap();
|
||||||
|
|
||||||
|
// Retrieve
|
||||||
|
let retrieved = storage.get("resp_001").await.unwrap();
|
||||||
|
assert_eq!(retrieved.response_id, state.response_id);
|
||||||
|
assert_eq!(retrieved.model, state.model);
|
||||||
|
assert_eq!(retrieved.provider, state.provider);
|
||||||
|
assert_eq!(retrieved.input_items.len(), 3);
|
||||||
|
assert_eq!(retrieved.created_at, state.created_at);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_put_overwrites_existing() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
// First state
|
||||||
|
let state1 = create_test_state("resp_002", 2);
|
||||||
|
storage.put(state1).await.unwrap();
|
||||||
|
|
||||||
|
// Overwrite with new state
|
||||||
|
let state2 = OpenAIConversationState {
|
||||||
|
response_id: "resp_002".to_string(),
|
||||||
|
input_items: vec![],
|
||||||
|
created_at: 9999999999,
|
||||||
|
model: "gpt-4".to_string(),
|
||||||
|
provider: "openai".to_string(),
|
||||||
|
};
|
||||||
|
storage.put(state2.clone()).await.unwrap();
|
||||||
|
|
||||||
|
// Should retrieve the new state
|
||||||
|
let retrieved = storage.get("resp_002").await.unwrap();
|
||||||
|
assert_eq!(retrieved.model, "gpt-4");
|
||||||
|
assert_eq!(retrieved.provider, "openai");
|
||||||
|
assert_eq!(retrieved.input_items.len(), 0);
|
||||||
|
assert_eq!(retrieved.created_at, 9999999999);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_get_not_found() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
let result = storage.get("nonexistent").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
match result.unwrap_err() {
|
||||||
|
StateStorageError::NotFound(id) => {
|
||||||
|
assert_eq!(id, "nonexistent");
|
||||||
|
}
|
||||||
|
_ => panic!("Expected NotFound error"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_exists_returns_false_for_nonexistent() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
assert!(!storage.exists("resp_003").await.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_exists_returns_true_after_put() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
let state = create_test_state("resp_004", 1);
|
||||||
|
|
||||||
|
assert!(!storage.exists("resp_004").await.unwrap());
|
||||||
|
storage.put(state).await.unwrap();
|
||||||
|
assert!(storage.exists("resp_004").await.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_success() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
let state = create_test_state("resp_005", 2);
|
||||||
|
|
||||||
|
storage.put(state).await.unwrap();
|
||||||
|
assert!(storage.exists("resp_005").await.unwrap());
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
storage.delete("resp_005").await.unwrap();
|
||||||
|
|
||||||
|
// Should no longer exist
|
||||||
|
assert!(!storage.exists("resp_005").await.unwrap());
|
||||||
|
assert!(storage.get("resp_005").await.is_err());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_delete_not_found() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
let result = storage.delete("nonexistent").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
|
||||||
|
match result.unwrap_err() {
|
||||||
|
StateStorageError::NotFound(id) => {
|
||||||
|
assert_eq!(id, "nonexistent");
|
||||||
|
}
|
||||||
|
_ => panic!("Expected NotFound error"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_merge_combines_inputs() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
// Create a previous state with 2 messages
|
||||||
|
let prev_state = create_test_state("resp_006", 2);
|
||||||
|
|
||||||
|
// Create current input with 1 message
|
||||||
|
let current_input = vec![InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "New message".to_string(),
|
||||||
|
}]),
|
||||||
|
})];
|
||||||
|
|
||||||
|
// Merge
|
||||||
|
let merged = storage.merge(&prev_state, current_input);
|
||||||
|
|
||||||
|
// Should have 3 messages total (2 from prev + 1 current)
|
||||||
|
assert_eq!(merged.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_merge_preserves_order() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
// Previous state has messages 0 and 1
|
||||||
|
let prev_state = create_test_state("resp_007", 2);
|
||||||
|
|
||||||
|
// Current input has message 2
|
||||||
|
let current_input = vec![InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Message 2".to_string(),
|
||||||
|
}]),
|
||||||
|
})];
|
||||||
|
|
||||||
|
let merged = storage.merge(&prev_state, current_input);
|
||||||
|
|
||||||
|
// Verify order: prev messages first, then current
|
||||||
|
let InputItem::Message(msg) = &merged[0] else { panic!("Expected Message") };
|
||||||
|
match &msg.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => assert_eq!(text, "Message 0"),
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let InputItem::Message(msg) = &merged[2] else { panic!("Expected Message") };
|
||||||
|
match &msg.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => assert_eq!(text, "Message 2"),
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_merge_with_empty_current_input() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
let prev_state = create_test_state("resp_008", 3);
|
||||||
|
|
||||||
|
let merged = storage.merge(&prev_state, vec![]);
|
||||||
|
|
||||||
|
// Should just have the previous state's items
|
||||||
|
assert_eq!(merged.len(), 3);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_merge_with_empty_previous_state() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
let prev_state = OpenAIConversationState {
|
||||||
|
response_id: "resp_009".to_string(),
|
||||||
|
input_items: vec![],
|
||||||
|
created_at: 1234567890,
|
||||||
|
model: "gpt-4".to_string(),
|
||||||
|
provider: "openai".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
let current_input = vec![InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Only message".to_string(),
|
||||||
|
}]),
|
||||||
|
})];
|
||||||
|
|
||||||
|
let merged = storage.merge(&prev_state, current_input);
|
||||||
|
|
||||||
|
// Should just have the current input
|
||||||
|
assert_eq!(merged.len(), 1);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_concurrent_access() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
// Spawn multiple tasks that write concurrently
|
||||||
|
let mut handles = vec![];
|
||||||
|
|
||||||
|
for i in 0..10 {
|
||||||
|
let storage_clone = storage.clone();
|
||||||
|
let handle = tokio::spawn(async move {
|
||||||
|
let state = create_test_state(&format!("resp_{}", i), i % 3);
|
||||||
|
storage_clone.put(state).await.unwrap();
|
||||||
|
});
|
||||||
|
handles.push(handle);
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for all tasks
|
||||||
|
for handle in handles {
|
||||||
|
handle.await.unwrap();
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify all states were stored
|
||||||
|
for i in 0..10 {
|
||||||
|
assert!(storage.exists(&format!("resp_{}", i)).await.unwrap());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_multiple_operations_on_same_id() {
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
let state = create_test_state("resp_010", 1);
|
||||||
|
|
||||||
|
// Put
|
||||||
|
storage.put(state.clone()).await.unwrap();
|
||||||
|
|
||||||
|
// Get
|
||||||
|
let retrieved = storage.get("resp_010").await.unwrap();
|
||||||
|
assert_eq!(retrieved.response_id, "resp_010");
|
||||||
|
|
||||||
|
// Exists
|
||||||
|
assert!(storage.exists("resp_010").await.unwrap());
|
||||||
|
|
||||||
|
// Put again (overwrite)
|
||||||
|
let new_state = create_test_state("resp_010", 5);
|
||||||
|
storage.put(new_state).await.unwrap();
|
||||||
|
|
||||||
|
// Get updated
|
||||||
|
let updated = storage.get("resp_010").await.unwrap();
|
||||||
|
assert_eq!(updated.input_items.len(), 5);
|
||||||
|
|
||||||
|
// Delete
|
||||||
|
storage.delete("resp_010").await.unwrap();
|
||||||
|
|
||||||
|
// Should not exist
|
||||||
|
assert!(!storage.exists("resp_010").await.unwrap());
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_merge_with_tool_call_flow() {
|
||||||
|
// This test simulates a realistic tool call conversation flow:
|
||||||
|
// 1. User sends message: "What's the weather?"
|
||||||
|
// 2. Model responds with function call (converted to assistant message)
|
||||||
|
// 3. User sends function call output in next request with previous_response_id
|
||||||
|
// The merge should combine: user message + assistant function call + function output
|
||||||
|
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
// Step 1: Previous state contains the initial exchange
|
||||||
|
// - User message: "What's the weather in SF?"
|
||||||
|
// - Assistant message (converted from FunctionCall): "Called function: get_weather..."
|
||||||
|
let prev_state = OpenAIConversationState {
|
||||||
|
response_id: "resp_tool_001".to_string(),
|
||||||
|
input_items: vec![
|
||||||
|
// Original user message
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "What's the weather in San Francisco?".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
// Assistant's function call (converted from OutputItem::FunctionCall)
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Called function: get_weather with arguments: {\"location\":\"San Francisco, CA\"}".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
created_at: 1234567890,
|
||||||
|
model: "claude-3".to_string(),
|
||||||
|
provider: "anthropic".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Step 2: Current request includes function call output
|
||||||
|
let current_input = vec![InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Function result: {\"temperature\": 72, \"condition\": \"sunny\"}".to_string(),
|
||||||
|
}]),
|
||||||
|
})];
|
||||||
|
|
||||||
|
// Step 3: Merge should combine all conversation history
|
||||||
|
let merged = storage.merge(&prev_state, current_input);
|
||||||
|
|
||||||
|
// Should have 3 items: user question + assistant function call + function output
|
||||||
|
assert_eq!(merged.len(), 3);
|
||||||
|
|
||||||
|
// Verify the order and content
|
||||||
|
let InputItem::Message(msg1) = &merged[0] else { panic!("Expected Message") };
|
||||||
|
assert!(matches!(msg1.role, MessageRole::User));
|
||||||
|
match &msg1.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => {
|
||||||
|
assert!(text.contains("weather in San Francisco"));
|
||||||
|
}
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let InputItem::Message(msg2) = &merged[1] else { panic!("Expected Message") };
|
||||||
|
assert!(matches!(msg2.role, MessageRole::Assistant));
|
||||||
|
match &msg2.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => {
|
||||||
|
assert!(text.contains("get_weather"));
|
||||||
|
}
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let InputItem::Message(msg3) = &merged[2] else { panic!("Expected Message") };
|
||||||
|
assert!(matches!(msg3.role, MessageRole::User));
|
||||||
|
match &msg3.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => {
|
||||||
|
assert!(text.contains("Function result"));
|
||||||
|
assert!(text.contains("temperature"));
|
||||||
|
}
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_merge_with_multiple_tool_calls() {
|
||||||
|
// Test a more complex scenario with multiple tool calls
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
// Previous state has: user message + 2 function calls from assistant
|
||||||
|
let prev_state = OpenAIConversationState {
|
||||||
|
response_id: "resp_tool_002".to_string(),
|
||||||
|
input_items: vec![
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "What's the weather and time in SF?".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Called function: get_weather with arguments: {\"location\":\"SF\"}".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Called function: get_time with arguments: {\"timezone\":\"America/Los_Angeles\"}".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
created_at: 1234567890,
|
||||||
|
model: "gpt-4".to_string(),
|
||||||
|
provider: "openai".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Current input: function outputs for both calls
|
||||||
|
let current_input = vec![
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Weather result: {\"temp\": 68}".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Time result: {\"time\": \"14:30\"}".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
];
|
||||||
|
|
||||||
|
let merged = storage.merge(&prev_state, current_input);
|
||||||
|
|
||||||
|
// Should have 5 items total: 1 user + 2 assistant calls + 2 function outputs
|
||||||
|
assert_eq!(merged.len(), 5);
|
||||||
|
|
||||||
|
// Verify first item is original user message
|
||||||
|
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
|
||||||
|
assert!(matches!(first.role, MessageRole::User));
|
||||||
|
|
||||||
|
// Verify last two are function outputs
|
||||||
|
let InputItem::Message(second_last) = &merged[3] else { panic!("Expected Message") };
|
||||||
|
assert!(matches!(second_last.role, MessageRole::User));
|
||||||
|
match &second_last.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => assert!(text.contains("Weather result")),
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
|
||||||
|
assert!(matches!(last.role, MessageRole::User));
|
||||||
|
match &last.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => assert!(text.contains("Time result")),
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_merge_preserves_conversation_context_for_multi_turn() {
|
||||||
|
// Simulate a multi-turn conversation with tool calls
|
||||||
|
let storage = MemoryConversationalStorage::new();
|
||||||
|
|
||||||
|
// Previous state: full conversation history up to this point
|
||||||
|
let prev_state = OpenAIConversationState {
|
||||||
|
response_id: "resp_tool_003".to_string(),
|
||||||
|
input_items: vec![
|
||||||
|
// Turn 1: User asks about weather
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "What's the weather?".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
// Turn 1: Assistant calls get_weather
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Called function: get_weather".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
// Turn 2: User provides function output
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Weather: sunny, 72°F".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
// Turn 2: Assistant responds with text
|
||||||
|
InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "It's sunny and 72°F in San Francisco today!".to_string(),
|
||||||
|
}]),
|
||||||
|
}),
|
||||||
|
],
|
||||||
|
created_at: 1234567890,
|
||||||
|
model: "claude-3".to_string(),
|
||||||
|
provider: "anthropic".to_string(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Turn 3: User asks follow-up question
|
||||||
|
let current_input = vec![InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Should I bring an umbrella?".to_string(),
|
||||||
|
}]),
|
||||||
|
})];
|
||||||
|
|
||||||
|
let merged = storage.merge(&prev_state, current_input);
|
||||||
|
|
||||||
|
// Should have all 5 messages in order
|
||||||
|
assert_eq!(merged.len(), 5);
|
||||||
|
|
||||||
|
// Verify the entire conversation flow is preserved
|
||||||
|
let InputItem::Message(first) = &merged[0] else { panic!("Expected Message") };
|
||||||
|
match &first.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => assert!(text.contains("What's the weather")),
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
|
||||||
|
let InputItem::Message(last) = &merged[4] else { panic!("Expected Message") };
|
||||||
|
match &last.content {
|
||||||
|
MessageContent::Items(items) => match &items[0] {
|
||||||
|
InputContent::InputText { text } => assert!(text.contains("umbrella")),
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
},
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
147
crates/brightstaff/src/state/mod.rs
Normal file
147
crates/brightstaff/src/state/mod.rs
Normal file
|
|
@ -0,0 +1,147 @@
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use hermesllm::apis::openai_responses::{InputItem, InputMessage, InputContent, MessageContent, MessageRole, InputParam};
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
use std::error::Error;
|
||||||
|
use std::fmt;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::{debug};
|
||||||
|
|
||||||
|
pub mod memory;
|
||||||
|
pub mod response_state_processor;
|
||||||
|
pub mod postgresql;
|
||||||
|
|
||||||
|
/// Represents the conversational state for a v1/responses request
|
||||||
|
/// Contains the complete input/output history that can be restored
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct OpenAIConversationState {
|
||||||
|
/// The response ID this state is associated with
|
||||||
|
pub response_id: String,
|
||||||
|
|
||||||
|
/// The complete input history (original input + accumulated outputs)
|
||||||
|
/// This is what gets prepended to new requests via previous_response_id
|
||||||
|
pub input_items: Vec<InputItem>,
|
||||||
|
|
||||||
|
/// Timestamp when this state was created
|
||||||
|
pub created_at: i64,
|
||||||
|
|
||||||
|
/// Model used for this response
|
||||||
|
pub model: String,
|
||||||
|
|
||||||
|
/// Provider that generated this response (e.g., "anthropic", "openai")
|
||||||
|
pub provider: String,
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Error types for state storage operations
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum StateStorageError {
|
||||||
|
/// State not found for given response_id
|
||||||
|
NotFound(String),
|
||||||
|
|
||||||
|
/// Storage backend error (network, database, etc.)
|
||||||
|
StorageError(String),
|
||||||
|
|
||||||
|
/// Serialization/deserialization error
|
||||||
|
SerializationError(String),
|
||||||
|
}
|
||||||
|
|
||||||
|
impl fmt::Display for StateStorageError {
|
||||||
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
||||||
|
match self {
|
||||||
|
StateStorageError::NotFound(id) => write!(f, "Conversation state not found for response_id: {}", id),
|
||||||
|
StateStorageError::StorageError(msg) => write!(f, "Storage error: {}", msg),
|
||||||
|
StateStorageError::SerializationError(msg) => write!(f, "Serialization error: {}", msg),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Error for StateStorageError {}
|
||||||
|
|
||||||
|
/// Trait for conversation state storage backends
|
||||||
|
#[async_trait]
|
||||||
|
pub trait StateStorage: Send + Sync {
|
||||||
|
/// Store conversation state for a response
|
||||||
|
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError>;
|
||||||
|
|
||||||
|
/// Retrieve conversation state by response_id
|
||||||
|
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError>;
|
||||||
|
|
||||||
|
/// Check if state exists for a response_id
|
||||||
|
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError>;
|
||||||
|
|
||||||
|
/// Delete state for a response_id (optional, for cleanup)
|
||||||
|
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError>;
|
||||||
|
|
||||||
|
fn merge(
|
||||||
|
&self,
|
||||||
|
prev_state: &OpenAIConversationState,
|
||||||
|
current_input: Vec<InputItem>,
|
||||||
|
) -> Vec<InputItem> {
|
||||||
|
// Default implementation: prepend previous input, append current
|
||||||
|
let prev_count = prev_state.input_items.len();
|
||||||
|
let current_count = current_input.len();
|
||||||
|
|
||||||
|
let mut combined_input = prev_state.input_items.clone();
|
||||||
|
combined_input.extend(current_input);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"PLANO | BRIGHTSTAFF | STATE_STORAGE | RESP_ID:{} | Merged state: prev_items={}, current_items={}, total_items={}, combined_json={}",
|
||||||
|
prev_state.response_id,
|
||||||
|
prev_count,
|
||||||
|
current_count,
|
||||||
|
combined_input.len(),
|
||||||
|
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
combined_input
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
/// Storage backend type enum
|
||||||
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
||||||
|
pub enum StorageBackend {
|
||||||
|
Memory,
|
||||||
|
Supabase,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl StorageBackend {
|
||||||
|
pub fn from_str(s: &str) -> Option<Self> {
|
||||||
|
match s.to_lowercase().as_str() {
|
||||||
|
"memory" => Some(StorageBackend::Memory),
|
||||||
|
"supabase" => Some(StorageBackend::Supabase),
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// === Utility functions for state management ===
|
||||||
|
|
||||||
|
/// Extract input items from InputParam, converting text to structured format
|
||||||
|
pub fn extract_input_items(input: &InputParam) -> Vec<InputItem> {
|
||||||
|
match input {
|
||||||
|
InputParam::Text(text) => {
|
||||||
|
vec![InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: text.clone(),
|
||||||
|
}]),
|
||||||
|
})]
|
||||||
|
}
|
||||||
|
InputParam::Items(items) => items.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Retrieve previous conversation state and combine with current input
|
||||||
|
/// Returns combined input if previous state found, or original input if not found/error
|
||||||
|
pub async fn retrieve_and_combine_input(
|
||||||
|
storage: Arc<dyn StateStorage>,
|
||||||
|
previous_response_id: &str,
|
||||||
|
current_input: Vec<InputItem>,
|
||||||
|
) -> Result<Vec<InputItem>, StateStorageError> {
|
||||||
|
|
||||||
|
// First get the previous state
|
||||||
|
let prev_state = storage.get(previous_response_id).await?;
|
||||||
|
let combined_input = storage.merge(&prev_state, current_input);
|
||||||
|
Ok(combined_input)
|
||||||
|
}
|
||||||
432
crates/brightstaff/src/state/postgresql.rs
Normal file
432
crates/brightstaff/src/state/postgresql.rs
Normal file
|
|
@ -0,0 +1,432 @@
|
||||||
|
use super::{OpenAIConversationState, StateStorage, StateStorageError};
|
||||||
|
use async_trait::async_trait;
|
||||||
|
use serde_json;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tokio::sync::OnceCell;
|
||||||
|
use tokio_postgres::{Client, NoTls};
|
||||||
|
use tracing::{debug, info, warn};
|
||||||
|
|
||||||
|
/// Supabase/PostgreSQL storage backend for conversation state
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct PostgreSQLConversationStorage {
|
||||||
|
client: Arc<Client>,
|
||||||
|
table_verified: Arc<OnceCell<()>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl PostgreSQLConversationStorage {
|
||||||
|
/// Creates a new Supabase storage instance with the given connection string
|
||||||
|
pub async fn new(connection_string: String) -> Result<Self, StateStorageError> {
|
||||||
|
let (client, connection) = tokio_postgres::connect(&connection_string, NoTls)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
StateStorageError::StorageError(format!("Failed to connect to database: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Spawn the connection to run in the background
|
||||||
|
tokio::spawn(async move {
|
||||||
|
if let Err(e) = connection.await {
|
||||||
|
warn!("Database connection error: {}", e);
|
||||||
|
}
|
||||||
|
});
|
||||||
|
|
||||||
|
Ok(Self {
|
||||||
|
client: Arc::new(client),
|
||||||
|
table_verified: Arc::new(OnceCell::new()),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Ensures the conversation_states table exists (checks once, caches result)
|
||||||
|
async fn ensure_ready(&self) -> Result<(), StateStorageError> {
|
||||||
|
self.table_verified
|
||||||
|
.get_or_try_init(|| async {
|
||||||
|
let row = self
|
||||||
|
.client
|
||||||
|
.query_one(
|
||||||
|
"SELECT EXISTS (
|
||||||
|
SELECT FROM pg_tables
|
||||||
|
WHERE tablename = 'conversation_states'
|
||||||
|
)",
|
||||||
|
&[],
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
StateStorageError::StorageError(format!(
|
||||||
|
"Failed to verify table existence: {}",
|
||||||
|
e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let exists: bool = row.get(0);
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return Err(StateStorageError::StorageError(
|
||||||
|
"Table 'conversation_states' does not exist. \
|
||||||
|
Please run the setup SQL from docs/db_setup/conversation_states.sql"
|
||||||
|
.to_string(),
|
||||||
|
));
|
||||||
|
}
|
||||||
|
|
||||||
|
info!("Conversation state storage table verified");
|
||||||
|
Ok(())
|
||||||
|
})
|
||||||
|
.await?;
|
||||||
|
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[async_trait]
|
||||||
|
impl StateStorage for PostgreSQLConversationStorage {
|
||||||
|
async fn put(&self, state: OpenAIConversationState) -> Result<(), StateStorageError> {
|
||||||
|
self.ensure_ready().await?;
|
||||||
|
|
||||||
|
// Serialize input_items to JSONB
|
||||||
|
let input_items_json = serde_json::to_value(&state.input_items).map_err(|e| {
|
||||||
|
StateStorageError::StorageError(format!("Failed to serialize input_items: {}", e))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
// Upsert the conversation state
|
||||||
|
self.client
|
||||||
|
.execute(
|
||||||
|
r#"
|
||||||
|
INSERT INTO conversation_states
|
||||||
|
(response_id, input_items, created_at, model, provider, updated_at)
|
||||||
|
VALUES ($1, $2, $3, $4, $5, NOW())
|
||||||
|
ON CONFLICT (response_id)
|
||||||
|
DO UPDATE SET
|
||||||
|
input_items = EXCLUDED.input_items,
|
||||||
|
model = EXCLUDED.model,
|
||||||
|
provider = EXCLUDED.provider,
|
||||||
|
updated_at = NOW()
|
||||||
|
"#,
|
||||||
|
&[
|
||||||
|
&state.response_id,
|
||||||
|
&input_items_json,
|
||||||
|
&state.created_at,
|
||||||
|
&state.model,
|
||||||
|
&state.provider,
|
||||||
|
],
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
StateStorageError::StorageError(format!(
|
||||||
|
"Failed to store conversation state for {}: {}",
|
||||||
|
state.response_id, e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
debug!("Stored conversation state for {}", state.response_id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn get(&self, response_id: &str) -> Result<OpenAIConversationState, StateStorageError> {
|
||||||
|
self.ensure_ready().await?;
|
||||||
|
|
||||||
|
let row = self
|
||||||
|
.client
|
||||||
|
.query_opt(
|
||||||
|
r#"
|
||||||
|
SELECT response_id, input_items, created_at, model, provider
|
||||||
|
FROM conversation_states
|
||||||
|
WHERE response_id = $1
|
||||||
|
"#,
|
||||||
|
&[&response_id],
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
StateStorageError::StorageError(format!(
|
||||||
|
"Failed to fetch conversation state for {}: {}",
|
||||||
|
response_id, e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
match row {
|
||||||
|
Some(row) => {
|
||||||
|
let response_id: String = row.get("response_id");
|
||||||
|
let input_items_json: serde_json::Value = row.get("input_items");
|
||||||
|
let created_at: i64 = row.get("created_at");
|
||||||
|
let model: String = row.get("model");
|
||||||
|
let provider: String = row.get("provider");
|
||||||
|
|
||||||
|
// Deserialize input_items from JSONB
|
||||||
|
let input_items =
|
||||||
|
serde_json::from_value(input_items_json).map_err(|e| {
|
||||||
|
StateStorageError::StorageError(format!(
|
||||||
|
"Failed to deserialize input_items: {}",
|
||||||
|
e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
Ok(OpenAIConversationState {
|
||||||
|
response_id,
|
||||||
|
input_items,
|
||||||
|
created_at,
|
||||||
|
model,
|
||||||
|
provider,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
None => Err(StateStorageError::NotFound(format!(
|
||||||
|
"Conversation state not found for response_id: {}",
|
||||||
|
response_id
|
||||||
|
))),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn exists(&self, response_id: &str) -> Result<bool, StateStorageError> {
|
||||||
|
self.ensure_ready().await?;
|
||||||
|
|
||||||
|
let row = self
|
||||||
|
.client
|
||||||
|
.query_one(
|
||||||
|
"SELECT EXISTS(SELECT 1 FROM conversation_states WHERE response_id = $1)",
|
||||||
|
&[&response_id],
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
StateStorageError::StorageError(format!(
|
||||||
|
"Failed to check existence for {}: {}",
|
||||||
|
response_id, e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
let exists: bool = row.get(0);
|
||||||
|
Ok(exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
async fn delete(&self, response_id: &str) -> Result<(), StateStorageError> {
|
||||||
|
self.ensure_ready().await?;
|
||||||
|
|
||||||
|
let rows_affected = self
|
||||||
|
.client
|
||||||
|
.execute(
|
||||||
|
"DELETE FROM conversation_states WHERE response_id = $1",
|
||||||
|
&[&response_id],
|
||||||
|
)
|
||||||
|
.await
|
||||||
|
.map_err(|e| {
|
||||||
|
StateStorageError::StorageError(format!(
|
||||||
|
"Failed to delete conversation state for {}: {}",
|
||||||
|
response_id, e
|
||||||
|
))
|
||||||
|
})?;
|
||||||
|
|
||||||
|
if rows_affected == 0 {
|
||||||
|
return Err(StateStorageError::NotFound(format!(
|
||||||
|
"Conversation state not found for response_id: {}",
|
||||||
|
response_id
|
||||||
|
)));
|
||||||
|
}
|
||||||
|
|
||||||
|
debug!("Deleted conversation state for {}", response_id);
|
||||||
|
Ok(())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/*
|
||||||
|
PostgreSQL schema is maintained in docs/db_setup/conversation_states.sql
|
||||||
|
Run that SQL file against your database before using this storage backend.
|
||||||
|
*/
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use hermesllm::apis::openai_responses::{InputContent, InputItem, InputMessage, MessageContent, MessageRole};
|
||||||
|
|
||||||
|
fn create_test_state(response_id: &str) -> OpenAIConversationState {
|
||||||
|
OpenAIConversationState {
|
||||||
|
response_id: response_id.to_string(),
|
||||||
|
input_items: vec![InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Test message".to_string(),
|
||||||
|
}]),
|
||||||
|
})],
|
||||||
|
created_at: 1234567890,
|
||||||
|
model: "gpt-4".to_string(),
|
||||||
|
provider: "openai".to_string(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Note: These tests require a running PostgreSQL database
|
||||||
|
// Set TEST_DATABASE_URL environment variable to run integration tests
|
||||||
|
// Example: TEST_DATABASE_URL=postgresql://user:pass@localhost/test_db
|
||||||
|
|
||||||
|
async fn get_test_storage() -> Option<PostgreSQLConversationStorage> {
|
||||||
|
if let Ok(db_url) = std::env::var("TEST_DATABASE_URL") {
|
||||||
|
match PostgreSQLConversationStorage::new(db_url).await {
|
||||||
|
Ok(storage) => Some(storage),
|
||||||
|
Err(e) => {
|
||||||
|
eprintln!("Failed to create test storage: {}", e);
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
eprintln!("TEST_DATABASE_URL not set, skipping Supabase integration tests");
|
||||||
|
None
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_put_and_get_success() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let state = create_test_state("test_resp_001");
|
||||||
|
storage.put(state.clone()).await.unwrap();
|
||||||
|
|
||||||
|
let retrieved = storage.get("test_resp_001").await.unwrap();
|
||||||
|
assert_eq!(retrieved.response_id, "test_resp_001");
|
||||||
|
assert_eq!(retrieved.input_items.len(), 1);
|
||||||
|
assert_eq!(retrieved.model, "gpt-4");
|
||||||
|
assert_eq!(retrieved.provider, "openai");
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = storage.delete("test_resp_001").await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_put_overwrites_existing() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let state1 = create_test_state("test_resp_002");
|
||||||
|
storage.put(state1).await.unwrap();
|
||||||
|
|
||||||
|
let mut state2 = create_test_state("test_resp_002");
|
||||||
|
state2.model = "gpt-4-turbo".to_string();
|
||||||
|
state2.input_items.push(InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "Response".to_string(),
|
||||||
|
}]),
|
||||||
|
}));
|
||||||
|
storage.put(state2).await.unwrap();
|
||||||
|
|
||||||
|
let retrieved = storage.get("test_resp_002").await.unwrap();
|
||||||
|
assert_eq!(retrieved.model, "gpt-4-turbo");
|
||||||
|
assert_eq!(retrieved.input_items.len(), 2);
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = storage.delete("test_resp_002").await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_get_not_found() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = storage.get("nonexistent_id").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_exists_returns_false() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let exists = storage.exists("nonexistent_id").await.unwrap();
|
||||||
|
assert!(!exists);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_exists_returns_true_after_put() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let state = create_test_state("test_resp_003");
|
||||||
|
storage.put(state).await.unwrap();
|
||||||
|
|
||||||
|
let exists = storage.exists("test_resp_003").await.unwrap();
|
||||||
|
assert!(exists);
|
||||||
|
|
||||||
|
// Cleanup
|
||||||
|
let _ = storage.delete("test_resp_003").await;
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_delete_success() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let state = create_test_state("test_resp_004");
|
||||||
|
storage.put(state).await.unwrap();
|
||||||
|
|
||||||
|
storage.delete("test_resp_004").await.unwrap();
|
||||||
|
|
||||||
|
let exists = storage.exists("test_resp_004").await.unwrap();
|
||||||
|
assert!(!exists);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_delete_not_found() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let result = storage.delete("nonexistent_id").await;
|
||||||
|
assert!(result.is_err());
|
||||||
|
assert!(matches!(result.unwrap_err(), StateStorageError::NotFound(_)));
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_merge_works() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
let prev_state = create_test_state("test_resp_005");
|
||||||
|
let current_input = vec![InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::User,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: "New message".to_string(),
|
||||||
|
}]),
|
||||||
|
})];
|
||||||
|
|
||||||
|
let merged = storage.merge(&prev_state, current_input);
|
||||||
|
|
||||||
|
// Should have 2 messages (1 from prev + 1 current)
|
||||||
|
assert_eq!(merged.len(), 2);
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
async fn test_supabase_table_verification() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// This should trigger table verification
|
||||||
|
let result = storage.ensure_ready().await;
|
||||||
|
assert!(result.is_ok(), "Table verification should succeed");
|
||||||
|
|
||||||
|
// Second call should use cached result
|
||||||
|
let result2 = storage.ensure_ready().await;
|
||||||
|
assert!(result2.is_ok(), "Cached verification should succeed");
|
||||||
|
}
|
||||||
|
|
||||||
|
#[tokio::test]
|
||||||
|
#[ignore] // Run manually with: cargo test test_verify_data_in_supabase -- --ignored
|
||||||
|
async fn test_verify_data_in_supabase() {
|
||||||
|
let Some(storage) = get_test_storage().await else {
|
||||||
|
return;
|
||||||
|
};
|
||||||
|
|
||||||
|
// Create a test record that persists
|
||||||
|
let state = create_test_state("manual_test_verification");
|
||||||
|
storage.put(state).await.unwrap();
|
||||||
|
|
||||||
|
println!("✅ Data written to Supabase!");
|
||||||
|
println!("Check your Supabase dashboard:");
|
||||||
|
println!(" SELECT * FROM conversation_states WHERE response_id = 'manual_test_verification';");
|
||||||
|
println!("\nTo cleanup, run:");
|
||||||
|
println!(" DELETE FROM conversation_states WHERE response_id = 'manual_test_verification';");
|
||||||
|
|
||||||
|
// DON'T cleanup - leave it for manual verification
|
||||||
|
}
|
||||||
|
}
|
||||||
302
crates/brightstaff/src/state/response_state_processor.rs
Normal file
302
crates/brightstaff/src/state/response_state_processor.rs
Normal file
|
|
@ -0,0 +1,302 @@
|
||||||
|
use bytes::Bytes;
|
||||||
|
use flate2::read::GzDecoder;
|
||||||
|
use hermesllm::apis::openai_responses::{
|
||||||
|
InputItem, OutputItem, ResponsesAPIStreamEvent,
|
||||||
|
};
|
||||||
|
use hermesllm::apis::streaming_shapes::sse::SseStreamIter;
|
||||||
|
use hermesllm::transforms::response::output_to_input::outputs_to_inputs;
|
||||||
|
use std::io::Read;
|
||||||
|
use std::sync::Arc;
|
||||||
|
use tracing::{info, debug, warn};
|
||||||
|
|
||||||
|
use crate::handlers::utils::StreamProcessor;
|
||||||
|
use crate::state::{OpenAIConversationState, StateStorage};
|
||||||
|
|
||||||
|
/// Processor that wraps another processor and handles v1/responses state management
|
||||||
|
/// Captures response_id and output from streaming responses, stores state after completion
|
||||||
|
pub struct ResponsesStateProcessor<P: StreamProcessor> {
|
||||||
|
/// The underlying processor (e.g., ObservableStreamProcessor for metrics)
|
||||||
|
inner: P,
|
||||||
|
|
||||||
|
/// State storage backend
|
||||||
|
storage: Arc<dyn StateStorage>,
|
||||||
|
|
||||||
|
/// Original input items from the request
|
||||||
|
original_input: Vec<InputItem>,
|
||||||
|
|
||||||
|
/// Model name
|
||||||
|
model: String,
|
||||||
|
|
||||||
|
/// Provider name
|
||||||
|
provider: String,
|
||||||
|
|
||||||
|
/// Whether this is a streaming request
|
||||||
|
is_streaming: bool,
|
||||||
|
|
||||||
|
/// Whether upstream is OpenAI (skip storage if true)
|
||||||
|
is_openai_upstream: bool,
|
||||||
|
|
||||||
|
/// Content-Encoding header value (e.g., "gzip", "br", None)
|
||||||
|
content_encoding: Option<String>,
|
||||||
|
|
||||||
|
/// Request ID for logging
|
||||||
|
request_id: String,
|
||||||
|
|
||||||
|
/// Buffer for accumulating chunks (needed for non-streaming compressed responses)
|
||||||
|
chunk_buffer: Vec<u8>,
|
||||||
|
|
||||||
|
/// Captured response_id from response.completed event
|
||||||
|
response_id: Option<String>,
|
||||||
|
|
||||||
|
/// Captured output items from response.completed event
|
||||||
|
output_items: Option<Vec<OutputItem>>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P: StreamProcessor> ResponsesStateProcessor<P> {
|
||||||
|
pub fn new(
|
||||||
|
inner: P,
|
||||||
|
storage: Arc<dyn StateStorage>,
|
||||||
|
original_input: Vec<InputItem>,
|
||||||
|
model: String,
|
||||||
|
provider: String,
|
||||||
|
is_streaming: bool,
|
||||||
|
is_openai_upstream: bool,
|
||||||
|
content_encoding: Option<String>,
|
||||||
|
request_id: String,
|
||||||
|
) -> Self {
|
||||||
|
Self {
|
||||||
|
inner,
|
||||||
|
storage,
|
||||||
|
original_input,
|
||||||
|
model,
|
||||||
|
provider,
|
||||||
|
is_streaming,
|
||||||
|
is_openai_upstream,
|
||||||
|
content_encoding,
|
||||||
|
request_id,
|
||||||
|
chunk_buffer: Vec::new(),
|
||||||
|
response_id: None,
|
||||||
|
output_items: None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Decompress accumulated buffer based on Content-Encoding header
|
||||||
|
fn decompress_buffer(&self) -> Vec<u8> {
|
||||||
|
if self.chunk_buffer.is_empty() {
|
||||||
|
return Vec::new();
|
||||||
|
}
|
||||||
|
|
||||||
|
match self.content_encoding.as_deref() {
|
||||||
|
Some("gzip") => {
|
||||||
|
let mut decoder = GzDecoder::new(self.chunk_buffer.as_slice());
|
||||||
|
let mut decompressed = Vec::new();
|
||||||
|
match decoder.read_to_end(&mut decompressed) {
|
||||||
|
Ok(_) => {
|
||||||
|
debug!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully decompressed {} bytes to {} bytes",
|
||||||
|
self.request_id,
|
||||||
|
self.chunk_buffer.len(),
|
||||||
|
decompressed.len()
|
||||||
|
);
|
||||||
|
decompressed
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to decompress gzip buffer: {}",
|
||||||
|
self.request_id,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
self.chunk_buffer.clone()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
Some(encoding) => {
|
||||||
|
warn!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Unsupported Content-Encoding: {}. Only gzip is currently supported.",
|
||||||
|
self.request_id,
|
||||||
|
encoding
|
||||||
|
);
|
||||||
|
self.chunk_buffer.clone()
|
||||||
|
}
|
||||||
|
None => self.chunk_buffer.clone(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse response to extract response_id and output
|
||||||
|
/// For streaming: parse SSE events looking for response.completed (per chunk)
|
||||||
|
/// For non-streaming: buffer all chunks, then decompress and parse on completion
|
||||||
|
fn try_parse_response_chunk(&mut self, chunk: &[u8]) {
|
||||||
|
if self.is_streaming {
|
||||||
|
// Streaming: Try to parse SSE events from this chunk
|
||||||
|
// Note: For compressed streaming, we'd need to buffer and decompress first
|
||||||
|
// but most streaming responses aren't compressed since SSE needs to be readable
|
||||||
|
let sse_iter = match SseStreamIter::try_from(chunk) {
|
||||||
|
Ok(iter) => iter,
|
||||||
|
Err(_) => return, // Not valid SSE format, skip
|
||||||
|
};
|
||||||
|
|
||||||
|
// Process each SSE event in the chunk, looking for data lines with response.completed
|
||||||
|
for event in sse_iter {
|
||||||
|
// Only process data lines (skip event-only lines)
|
||||||
|
if let Some(data_str) = &event.data {
|
||||||
|
// Try to parse as ResponsesAPIStreamEvent
|
||||||
|
if let Ok(stream_event) = serde_json::from_str::<ResponsesAPIStreamEvent>(data_str) {
|
||||||
|
// Check if this is a ResponseCompleted event
|
||||||
|
if let ResponsesAPIStreamEvent::ResponseCompleted { response, .. } = stream_event {
|
||||||
|
info!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured streaming response.completed: response_id={}, output_items={}",
|
||||||
|
self.request_id,
|
||||||
|
response.id,
|
||||||
|
response.output.len()
|
||||||
|
);
|
||||||
|
self.response_id = Some(response.id.clone());
|
||||||
|
self.output_items = Some(response.output.clone());
|
||||||
|
return; // Found what we need, exit early
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// Non-streaming: Buffer chunks, will decompress and parse on completion
|
||||||
|
self.chunk_buffer.extend_from_slice(chunk);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Parse buffered non-streaming response (called on completion)
|
||||||
|
fn try_parse_buffered_response(&mut self) {
|
||||||
|
if self.is_streaming || self.chunk_buffer.is_empty() {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Decompress if needed
|
||||||
|
let decompressed = self.decompress_buffer();
|
||||||
|
|
||||||
|
// Parse complete JSON response
|
||||||
|
match serde_json::from_slice::<hermesllm::apis::openai_responses::ResponsesAPIResponse>(&decompressed) {
|
||||||
|
Ok(response) => {
|
||||||
|
info!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Captured non-streaming response: response_id={}, output_items={}",
|
||||||
|
self.request_id,
|
||||||
|
response.id,
|
||||||
|
response.output.len()
|
||||||
|
);
|
||||||
|
self.response_id = Some(response.id.clone());
|
||||||
|
self.output_items = Some(response.output.clone());
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
// Log parse error with chunk preview for debugging
|
||||||
|
let chunk_preview = String::from_utf8_lossy(&decompressed);
|
||||||
|
let preview_len = chunk_preview.len().min(200);
|
||||||
|
warn!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to parse non-streaming ResponsesAPIResponse: {}. Decompressed preview (first {} bytes): {}",
|
||||||
|
self.request_id,
|
||||||
|
e,
|
||||||
|
preview_len,
|
||||||
|
&chunk_preview[..preview_len]
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
impl<P: StreamProcessor> StreamProcessor for ResponsesStateProcessor<P> {
|
||||||
|
fn process_chunk(&mut self, chunk: Bytes) -> Result<Option<Bytes>, String> {
|
||||||
|
// Buffer/parse chunk for response extraction
|
||||||
|
self.try_parse_response_chunk(&chunk);
|
||||||
|
|
||||||
|
// Forward to inner processor
|
||||||
|
self.inner.process_chunk(chunk)
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_first_bytes(&mut self) {
|
||||||
|
self.inner.on_first_bytes();
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_complete(&mut self) {
|
||||||
|
// For non-streaming, decompress and parse buffered response
|
||||||
|
self.try_parse_buffered_response();
|
||||||
|
|
||||||
|
// First, let the inner processor complete
|
||||||
|
self.inner.on_complete();
|
||||||
|
|
||||||
|
// Skip storage for OpenAI upstream
|
||||||
|
if self.is_openai_upstream {
|
||||||
|
debug!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Skipping state storage for OpenAI upstream provider",
|
||||||
|
self.request_id
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Store state if we captured response_id and output
|
||||||
|
if let (Some(response_id), Some(output_items)) = (&self.response_id, &self.output_items) {
|
||||||
|
// Convert output items to input items for next request
|
||||||
|
let output_as_inputs = outputs_to_inputs(output_items);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Converting outputs to inputs: output_items_count={}, converted_input_items_count={}",
|
||||||
|
self.request_id, output_items.len(), output_as_inputs.len()
|
||||||
|
);
|
||||||
|
|
||||||
|
// Combine original input + output as new input history
|
||||||
|
let mut combined_input = self.original_input.clone();
|
||||||
|
combined_input.extend(output_as_inputs);
|
||||||
|
|
||||||
|
debug!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Storing state: original_input_count={}, combined_input_count={}, combined_json={}",
|
||||||
|
self.request_id,
|
||||||
|
self.original_input.len(),
|
||||||
|
combined_input.len(),
|
||||||
|
serde_json::to_string(&combined_input).unwrap_or_else(|_| "serialization_error".to_string())
|
||||||
|
);
|
||||||
|
|
||||||
|
let state = OpenAIConversationState {
|
||||||
|
response_id: response_id.clone(),
|
||||||
|
input_items: combined_input,
|
||||||
|
created_at: std::time::SystemTime::now()
|
||||||
|
.duration_since(std::time::UNIX_EPOCH)
|
||||||
|
.unwrap_or_default()
|
||||||
|
.as_secs() as i64,
|
||||||
|
model: self.model.clone(),
|
||||||
|
provider: self.provider.clone(),
|
||||||
|
};
|
||||||
|
|
||||||
|
// Store asynchronously (fire and forget with logging)
|
||||||
|
let storage = self.storage.clone();
|
||||||
|
let response_id_clone = response_id.clone();
|
||||||
|
let request_id = self.request_id.clone();
|
||||||
|
let items_count = state.input_items.len();
|
||||||
|
tokio::spawn(async move {
|
||||||
|
match storage.put(state).await {
|
||||||
|
Ok(()) => {
|
||||||
|
info!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Successfully stored conversation state for response_id: {}, items_count={}",
|
||||||
|
request_id,
|
||||||
|
response_id_clone,
|
||||||
|
items_count
|
||||||
|
);
|
||||||
|
}
|
||||||
|
Err(e) => {
|
||||||
|
warn!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | Failed to store conversation state for response_id {}: {}",
|
||||||
|
request_id,
|
||||||
|
response_id_clone,
|
||||||
|
e
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
});
|
||||||
|
} else {
|
||||||
|
warn!(
|
||||||
|
"[PLANO_REQ_ID:{}] | STATE_PROCESSOR | No response_id captured from upstream response - cannot store conversation state. response_id present: {}, output present: {}",
|
||||||
|
self.request_id,
|
||||||
|
self.response_id.is_some(),
|
||||||
|
self.output_items.is_some()
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn on_error(&mut self, error: &str) {
|
||||||
|
self.inner.on_error(error);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -44,6 +44,20 @@ pub struct Listener {
|
||||||
pub port: u16,
|
pub port: u16,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
|
pub struct StateStorageConfig {
|
||||||
|
#[serde(rename = "type")]
|
||||||
|
pub storage_type: StateStorageType,
|
||||||
|
pub connection_string: Option<String>,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
|
||||||
|
#[serde(rename_all = "lowercase")]
|
||||||
|
pub enum StateStorageType {
|
||||||
|
Memory,
|
||||||
|
Postgres,
|
||||||
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize)]
|
#[derive(Debug, Clone, Serialize, Deserialize)]
|
||||||
pub struct Configuration {
|
pub struct Configuration {
|
||||||
pub version: String,
|
pub version: String,
|
||||||
|
|
@ -62,6 +76,7 @@ pub struct Configuration {
|
||||||
pub agents: Option<Vec<Agent>>,
|
pub agents: Option<Vec<Agent>>,
|
||||||
pub filters: Option<Vec<Agent>>,
|
pub filters: Option<Vec<Agent>>,
|
||||||
pub listeners: Vec<Listener>,
|
pub listeners: Vec<Listener>,
|
||||||
|
pub state_storage: Option<StateStorageConfig>,
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
|
||||||
|
|
|
||||||
|
|
@ -59,6 +59,11 @@ pub struct ResponsesAPIStreamBuffer {
|
||||||
model: Option<String>,
|
model: Option<String>,
|
||||||
created_at: Option<i64>,
|
created_at: Option<i64>,
|
||||||
|
|
||||||
|
/// Full response metadata from upstream (tools, temperature, etc.)
|
||||||
|
/// This is extracted from the first upstream event and used to build
|
||||||
|
/// complete response.created and response.in_progress events
|
||||||
|
upstream_response_metadata: Option<ResponsesAPIResponse>,
|
||||||
|
|
||||||
/// Lifecycle state flags
|
/// Lifecycle state flags
|
||||||
created_emitted: bool,
|
created_emitted: bool,
|
||||||
in_progress_emitted: bool,
|
in_progress_emitted: bool,
|
||||||
|
|
@ -88,6 +93,7 @@ impl ResponsesAPIStreamBuffer {
|
||||||
response_id: None,
|
response_id: None,
|
||||||
model: None,
|
model: None,
|
||||||
created_at: None,
|
created_at: None,
|
||||||
|
upstream_response_metadata: None,
|
||||||
created_emitted: false,
|
created_emitted: false,
|
||||||
in_progress_emitted: false,
|
in_progress_emitted: false,
|
||||||
output_items_added: HashMap::new(),
|
output_items_added: HashMap::new(),
|
||||||
|
|
@ -171,6 +177,15 @@ impl ResponsesAPIStreamBuffer {
|
||||||
|
|
||||||
/// Build the base response object with current state
|
/// Build the base response object with current state
|
||||||
fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse {
|
fn build_response(&self, status: ResponseStatus) -> ResponsesAPIResponse {
|
||||||
|
// If we have upstream metadata, use it as a base and update status/output
|
||||||
|
if let Some(upstream) = &self.upstream_response_metadata {
|
||||||
|
let mut response = upstream.clone();
|
||||||
|
response.status = status;
|
||||||
|
// Don't update output here - will be set in finalize()
|
||||||
|
return response;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fallback: build a minimal response from local state
|
||||||
ResponsesAPIResponse {
|
ResponsesAPIResponse {
|
||||||
id: self.response_id.clone().unwrap_or_default(),
|
id: self.response_id.clone().unwrap_or_default(),
|
||||||
object: "response".to_string(),
|
object: "response".to_string(),
|
||||||
|
|
@ -293,24 +308,40 @@ impl ResponsesAPIStreamBuffer {
|
||||||
// Build final response
|
// Build final response
|
||||||
let mut output_items = Vec::new();
|
let mut output_items = Vec::new();
|
||||||
|
|
||||||
// Add tool calls to output
|
// Build complete output array by iterating through all output indices in order
|
||||||
for (item_id, arguments) in &self.function_arguments {
|
let max_output_index = self.output_items_added.keys().max().copied().unwrap_or(-1);
|
||||||
let output_index = self.output_items_added.iter()
|
|
||||||
.find(|(_, id)| *id == item_id)
|
|
||||||
.map(|(idx, _)| *idx)
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
for output_index in 0..=max_output_index {
|
||||||
.cloned()
|
if let Some(item_id) = self.output_items_added.get(&output_index) {
|
||||||
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
// Check if this is a function call
|
||||||
|
if let Some(arguments) = self.function_arguments.get(item_id) {
|
||||||
|
let (call_id, name) = self.tool_call_metadata.get(&output_index)
|
||||||
|
.cloned()
|
||||||
|
.unwrap_or_else(|| (format!("call_{}", uuid::Uuid::new_v4()), "unknown".to_string()));
|
||||||
|
|
||||||
output_items.push(OutputItem::FunctionCall {
|
output_items.push(OutputItem::FunctionCall {
|
||||||
id: item_id.clone(),
|
id: item_id.clone(),
|
||||||
status: OutputItemStatus::Completed,
|
status: OutputItemStatus::Completed,
|
||||||
call_id,
|
call_id,
|
||||||
name: Some(name),
|
name: Some(name),
|
||||||
arguments: Some(arguments.clone()),
|
arguments: Some(arguments.clone()),
|
||||||
});
|
});
|
||||||
|
}
|
||||||
|
// Check if this is a text message
|
||||||
|
else if let Some(text) = self.text_content.get(item_id) {
|
||||||
|
use crate::apis::openai_responses::OutputContent;
|
||||||
|
output_items.push(OutputItem::Message {
|
||||||
|
id: item_id.clone(),
|
||||||
|
status: OutputItemStatus::Completed,
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![OutputContent::OutputText {
|
||||||
|
text: text.clone(),
|
||||||
|
annotations: vec![],
|
||||||
|
logprobs: None,
|
||||||
|
}],
|
||||||
|
});
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut final_response = self.build_response(ResponseStatus::Completed);
|
let mut final_response = self.build_response(ResponseStatus::Completed);
|
||||||
|
|
@ -365,6 +396,24 @@ impl SseStreamBufferTrait for ResponsesAPIStreamBuffer {
|
||||||
|
|
||||||
let mut events = Vec::new();
|
let mut events = Vec::new();
|
||||||
|
|
||||||
|
// Capture upstream metadata from ResponseCreated or ResponseInProgress if present
|
||||||
|
match stream_event {
|
||||||
|
ResponsesAPIStreamEvent::ResponseCreated { response, .. } |
|
||||||
|
ResponsesAPIStreamEvent::ResponseInProgress { response, .. } => {
|
||||||
|
if self.upstream_response_metadata.is_none() {
|
||||||
|
// Store the full upstream response as our metadata template
|
||||||
|
self.upstream_response_metadata = Some(response.clone());
|
||||||
|
// Also extract basic fields
|
||||||
|
self.response_id = Some(response.id.clone());
|
||||||
|
self.model = Some(response.model.clone());
|
||||||
|
self.created_at = Some(response.created_at);
|
||||||
|
}
|
||||||
|
// Don't emit these - we'll generate our own lifecycle events
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
_ => {}
|
||||||
|
}
|
||||||
|
|
||||||
// Emit lifecycle events if not yet emitted
|
// Emit lifecycle events if not yet emitted
|
||||||
if !self.created_emitted {
|
if !self.created_emitted {
|
||||||
// Initialize metadata from first event if needed
|
// Initialize metadata from first event if needed
|
||||||
|
|
|
||||||
|
|
@ -193,6 +193,40 @@ impl SupportedAPIsFromClient {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
impl SupportedUpstreamAPIs {
|
||||||
|
/// Create a SupportedUpstreamApi from an endpoint path
|
||||||
|
pub fn from_endpoint(endpoint: &str) -> Option<Self> {
|
||||||
|
if let Some(openai_api) = OpenAIApi::from_endpoint(endpoint) {
|
||||||
|
// Check if this is the Responses API endpoint
|
||||||
|
if openai_api == OpenAIApi::Responses {
|
||||||
|
return Some(SupportedUpstreamAPIs::OpenAIResponsesAPI(openai_api));
|
||||||
|
}
|
||||||
|
// Otherwise it's ChatCompletions
|
||||||
|
return Some(SupportedUpstreamAPIs::OpenAIChatCompletions(openai_api));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(anthropic_api) = AnthropicApi::from_endpoint(endpoint) {
|
||||||
|
return Some(SupportedUpstreamAPIs::AnthropicMessagesAPI(anthropic_api));
|
||||||
|
}
|
||||||
|
|
||||||
|
if let Some(bedrock_api) = AmazonBedrockApi::from_endpoint(endpoint) {
|
||||||
|
match bedrock_api {
|
||||||
|
AmazonBedrockApi::Converse => {
|
||||||
|
return Some(SupportedUpstreamAPIs::AmazonBedrockConverse(bedrock_api))
|
||||||
|
}
|
||||||
|
AmazonBedrockApi::ConverseStream => {
|
||||||
|
return Some(SupportedUpstreamAPIs::AmazonBedrockConverseStream(bedrock_api))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
None
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
/// Get all supported endpoint paths
|
/// Get all supported endpoint paths
|
||||||
pub fn supported_endpoints() -> Vec<&'static str> {
|
pub fn supported_endpoints() -> Vec<&'static str> {
|
||||||
let mut endpoints = Vec::new();
|
let mut endpoints = Vec::new();
|
||||||
|
|
|
||||||
|
|
@ -1,3 +1,4 @@
|
||||||
//! Response transformation modules
|
//! Response transformation modules
|
||||||
|
pub mod output_to_input;
|
||||||
pub mod to_anthropic;
|
pub mod to_anthropic;
|
||||||
pub mod to_openai;
|
pub mod to_openai;
|
||||||
|
|
|
||||||
178
crates/hermesllm/src/transforms/response/output_to_input.rs
Normal file
178
crates/hermesllm/src/transforms/response/output_to_input.rs
Normal file
|
|
@ -0,0 +1,178 @@
|
||||||
|
//! Conversions from response outputs to request inputs for conversation continuation
|
||||||
|
//!
|
||||||
|
//! This module provides utilities for converting OutputItem types from API responses
|
||||||
|
//! into InputItem types that can be used in subsequent requests. This is primarily used
|
||||||
|
//! for maintaining conversation history in the v1/responses API.
|
||||||
|
|
||||||
|
use crate::apis::openai_responses::{
|
||||||
|
InputContent, InputItem, InputMessage, MessageContent, MessageRole, OutputContent, OutputItem,
|
||||||
|
};
|
||||||
|
|
||||||
|
/// Converts an OutputItem from a response into an InputItem for the next request
|
||||||
|
/// This is used to build conversation history from previous responses
|
||||||
|
pub fn convert_responses_output_to_input_items(output: &OutputItem) -> Option<InputItem> {
|
||||||
|
match output {
|
||||||
|
// Convert output messages to input messages
|
||||||
|
OutputItem::Message {
|
||||||
|
role, content, ..
|
||||||
|
} => {
|
||||||
|
let input_content: Vec<InputContent> = content
|
||||||
|
.iter()
|
||||||
|
.filter_map(|c| match c {
|
||||||
|
OutputContent::OutputText { text, .. } => Some(InputContent::InputText {
|
||||||
|
text: text.clone(),
|
||||||
|
}),
|
||||||
|
OutputContent::OutputAudio {
|
||||||
|
data, ..
|
||||||
|
} => Some(InputContent::InputAudio {
|
||||||
|
data: data.clone(),
|
||||||
|
format: None, // Format not preserved in output
|
||||||
|
}),
|
||||||
|
OutputContent::Refusal { .. } => None, // Skip refusals
|
||||||
|
})
|
||||||
|
.collect();
|
||||||
|
|
||||||
|
if input_content.is_empty() {
|
||||||
|
return None;
|
||||||
|
}
|
||||||
|
|
||||||
|
// Map role string to MessageRole enum
|
||||||
|
let message_role = match role.as_str() {
|
||||||
|
"user" => MessageRole::User,
|
||||||
|
"assistant" => MessageRole::Assistant,
|
||||||
|
"system" => MessageRole::System,
|
||||||
|
"developer" => MessageRole::Developer,
|
||||||
|
_ => MessageRole::Assistant, // Default to assistant
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(InputItem::Message(InputMessage {
|
||||||
|
role: message_role,
|
||||||
|
content: MessageContent::Items(input_content),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
// For function calls, we'll create an assistant message with the tool call info
|
||||||
|
// This matches how conversation history is typically built
|
||||||
|
OutputItem::FunctionCall {
|
||||||
|
name, arguments, ..
|
||||||
|
} => {
|
||||||
|
let tool_call_text = if let (Some(n), Some(args)) = (name, arguments) {
|
||||||
|
format!("Called function: {} with arguments: {}", n, args)
|
||||||
|
} else {
|
||||||
|
"Called a function".to_string()
|
||||||
|
};
|
||||||
|
|
||||||
|
Some(InputItem::Message(InputMessage {
|
||||||
|
role: MessageRole::Assistant,
|
||||||
|
content: MessageContent::Items(vec![InputContent::InputText {
|
||||||
|
text: tool_call_text,
|
||||||
|
}]),
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
// Skip other output types (tool outputs, etc.) as they don't convert to input
|
||||||
|
_ => None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Converts a Vec of OutputItems into InputItems for conversation continuation
|
||||||
|
pub fn outputs_to_inputs(outputs: &[OutputItem]) -> Vec<InputItem> {
|
||||||
|
outputs
|
||||||
|
.iter()
|
||||||
|
.filter_map(convert_responses_output_to_input_items)
|
||||||
|
.collect()
|
||||||
|
}
|
||||||
|
|
||||||
|
#[cfg(test)]
|
||||||
|
mod tests {
|
||||||
|
use super::*;
|
||||||
|
use crate::apis::openai_responses::{OutputItemStatus};
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_output_message_to_input() {
|
||||||
|
let output = OutputItem::Message {
|
||||||
|
id: "msg_123".to_string(),
|
||||||
|
status: OutputItemStatus::Completed,
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![OutputContent::OutputText {
|
||||||
|
text: "Hello!".to_string(),
|
||||||
|
annotations: vec![],
|
||||||
|
logprobs: None,
|
||||||
|
}],
|
||||||
|
};
|
||||||
|
|
||||||
|
let input = convert_responses_output_to_input_items(&output).unwrap();
|
||||||
|
|
||||||
|
match input {
|
||||||
|
InputItem::Message(msg) => {
|
||||||
|
assert!(matches!(msg.role, MessageRole::Assistant));
|
||||||
|
match &msg.content {
|
||||||
|
MessageContent::Items(items) => {
|
||||||
|
assert_eq!(items.len(), 1);
|
||||||
|
match &items[0] {
|
||||||
|
InputContent::InputText { text } => assert_eq!(text, "Hello!"),
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Expected Message variant"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_function_call_to_input() {
|
||||||
|
let output = OutputItem::FunctionCall {
|
||||||
|
id: "fc_123".to_string(),
|
||||||
|
status: OutputItemStatus::Completed,
|
||||||
|
call_id: "call_123".to_string(),
|
||||||
|
name: Some("get_weather".to_string()),
|
||||||
|
arguments: Some(r#"{"location":"SF"}"#.to_string()),
|
||||||
|
};
|
||||||
|
|
||||||
|
let input = convert_responses_output_to_input_items(&output).unwrap();
|
||||||
|
|
||||||
|
match input {
|
||||||
|
InputItem::Message(msg) => {
|
||||||
|
assert!(matches!(msg.role, MessageRole::Assistant));
|
||||||
|
match &msg.content {
|
||||||
|
MessageContent::Items(items) => {
|
||||||
|
match &items[0] {
|
||||||
|
InputContent::InputText { text } => {
|
||||||
|
assert!(text.contains("get_weather"));
|
||||||
|
}
|
||||||
|
_ => panic!("Expected InputText"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Expected MessageContent::Items"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
_ => panic!("Expected Message variant"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[test]
|
||||||
|
fn test_outputs_to_inputs() {
|
||||||
|
let outputs = vec![
|
||||||
|
OutputItem::Message {
|
||||||
|
id: "msg_1".to_string(),
|
||||||
|
status: OutputItemStatus::Completed,
|
||||||
|
role: "assistant".to_string(),
|
||||||
|
content: vec![OutputContent::OutputText {
|
||||||
|
text: "Hello".to_string(),
|
||||||
|
annotations: vec![],
|
||||||
|
logprobs: None,
|
||||||
|
}],
|
||||||
|
},
|
||||||
|
OutputItem::FunctionCall {
|
||||||
|
id: "fc_1".to_string(),
|
||||||
|
status: OutputItemStatus::Completed,
|
||||||
|
call_id: "call_1".to_string(),
|
||||||
|
name: Some("test".to_string()),
|
||||||
|
arguments: Some("{}".to_string()),
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
let inputs = outputs_to_inputs(&outputs);
|
||||||
|
assert_eq!(inputs.len(), 2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -80,8 +80,19 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
|
||||||
// Only add the message item if there's actual content (text, audio, or refusal)
|
// Only add the message item if there's actual content (text, audio, or refusal)
|
||||||
// Don't add empty message items when there are only tool calls
|
// Don't add empty message items when there are only tool calls
|
||||||
if !content.is_empty() {
|
if !content.is_empty() {
|
||||||
|
// Generate message ID: strip common prefixes to avoid double-prefixing
|
||||||
|
let message_id = if resp.id.starts_with("msg_") {
|
||||||
|
resp.id.clone()
|
||||||
|
} else if resp.id.starts_with("resp_") {
|
||||||
|
format!("msg_{}", &resp.id[5..]) // Strip "resp_" prefix
|
||||||
|
} else if resp.id.starts_with("chatcmpl-") {
|
||||||
|
format!("msg_{}", &resp.id[9..]) // Strip "chatcmpl-" prefix
|
||||||
|
} else {
|
||||||
|
format!("msg_{}", resp.id)
|
||||||
|
};
|
||||||
|
|
||||||
items.push(OutputItem::Message {
|
items.push(OutputItem::Message {
|
||||||
id: format!("msg_{}", resp.id),
|
id: message_id,
|
||||||
status: OutputItemStatus::Completed,
|
status: OutputItemStatus::Completed,
|
||||||
role: match choice.message.role {
|
role: match choice.message.role {
|
||||||
Role::User => "user".to_string(),
|
Role::User => "user".to_string(),
|
||||||
|
|
@ -151,7 +162,12 @@ impl TryFrom<ChatCompletionsResponse> for ResponsesAPIResponse {
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(ResponsesAPIResponse {
|
Ok(ResponsesAPIResponse {
|
||||||
id: resp.id,
|
// Generate proper resp_ prefixed ID if not already present
|
||||||
|
id: if resp.id.starts_with("resp_") {
|
||||||
|
resp.id
|
||||||
|
} else {
|
||||||
|
format!("resp_{}", uuid::Uuid::new_v4().to_string().replace("-", ""))
|
||||||
|
},
|
||||||
object: "response".to_string(),
|
object: "response".to_string(),
|
||||||
created_at: resp.created as i64,
|
created_at: resp.created as i64,
|
||||||
status,
|
status,
|
||||||
|
|
@ -942,7 +958,7 @@ mod tests {
|
||||||
use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse};
|
use crate::apis::openai_responses::{OutputContent, OutputItem, ResponsesAPIResponse};
|
||||||
|
|
||||||
let chat_response = ChatCompletionsResponse {
|
let chat_response = ChatCompletionsResponse {
|
||||||
id: "chatcmpl-123".to_string(),
|
id: "resp_6de5512800cf4375a329a473a4f02879".to_string(),
|
||||||
object: Some("chat.completion".to_string()),
|
object: Some("chat.completion".to_string()),
|
||||||
created: 1677652288,
|
created: 1677652288,
|
||||||
model: "gpt-4".to_string(),
|
model: "gpt-4".to_string(),
|
||||||
|
|
@ -974,7 +990,9 @@ mod tests {
|
||||||
|
|
||||||
let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();
|
let responses_api: ResponsesAPIResponse = chat_response.try_into().unwrap();
|
||||||
|
|
||||||
assert_eq!(responses_api.id, "chatcmpl-123");
|
// Response ID should be generated with resp_ prefix
|
||||||
|
assert!(responses_api.id.starts_with("resp_"), "Response ID should start with 'resp_'");
|
||||||
|
assert_eq!(responses_api.id.len(), 37, "Response ID should be resp_ + 32 char UUID");
|
||||||
assert_eq!(responses_api.object, "response");
|
assert_eq!(responses_api.object, "response");
|
||||||
assert_eq!(responses_api.model, "gpt-4");
|
assert_eq!(responses_api.model, "gpt-4");
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -58,11 +58,11 @@ impl TryFrom<MessagesStreamEvent> for ChatCompletionsStreamResponse {
|
||||||
None,
|
None,
|
||||||
)),
|
)),
|
||||||
|
|
||||||
MessagesStreamEvent::ContentBlockStart { content_block, .. } => {
|
MessagesStreamEvent::ContentBlockStart { content_block, index } => {
|
||||||
convert_content_block_start(content_block)
|
convert_content_block_start(content_block, index)
|
||||||
}
|
}
|
||||||
|
|
||||||
MessagesStreamEvent::ContentBlockDelta { delta, .. } => convert_content_delta(delta),
|
MessagesStreamEvent::ContentBlockDelta { delta, index } => convert_content_delta(delta, index),
|
||||||
|
|
||||||
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
MessagesStreamEvent::ContentBlockStop { .. } => Ok(create_empty_openai_chunk()),
|
||||||
|
|
||||||
|
|
@ -272,6 +272,7 @@ impl TryFrom<ConverseStreamEvent> for ChatCompletionsStreamResponse {
|
||||||
/// Convert content block start to OpenAI chunk
|
/// Convert content block start to OpenAI chunk
|
||||||
fn convert_content_block_start(
|
fn convert_content_block_start(
|
||||||
content_block: MessagesContentBlock,
|
content_block: MessagesContentBlock,
|
||||||
|
index: u32,
|
||||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||||
match content_block {
|
match content_block {
|
||||||
MessagesContentBlock::Text { .. } => {
|
MessagesContentBlock::Text { .. } => {
|
||||||
|
|
@ -291,7 +292,7 @@ fn convert_content_block_start(
|
||||||
refusal: None,
|
refusal: None,
|
||||||
function_call: None,
|
function_call: None,
|
||||||
tool_calls: Some(vec![ToolCallDelta {
|
tool_calls: Some(vec![ToolCallDelta {
|
||||||
index: 0,
|
index,
|
||||||
id: Some(id),
|
id: Some(id),
|
||||||
call_type: Some("function".to_string()),
|
call_type: Some("function".to_string()),
|
||||||
function: Some(FunctionCallDelta {
|
function: Some(FunctionCallDelta {
|
||||||
|
|
@ -313,6 +314,7 @@ fn convert_content_block_start(
|
||||||
/// Convert content delta to OpenAI chunk
|
/// Convert content delta to OpenAI chunk
|
||||||
fn convert_content_delta(
|
fn convert_content_delta(
|
||||||
delta: MessagesContentDelta,
|
delta: MessagesContentDelta,
|
||||||
|
index: u32,
|
||||||
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
) -> Result<ChatCompletionsStreamResponse, TransformError> {
|
||||||
match delta {
|
match delta {
|
||||||
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
|
MessagesContentDelta::TextDelta { text } => Ok(create_openai_chunk(
|
||||||
|
|
@ -350,7 +352,7 @@ fn convert_content_delta(
|
||||||
refusal: None,
|
refusal: None,
|
||||||
function_call: None,
|
function_call: None,
|
||||||
tool_calls: Some(vec![ToolCallDelta {
|
tool_calls: Some(vec![ToolCallDelta {
|
||||||
index: 0,
|
index,
|
||||||
id: None,
|
id: None,
|
||||||
call_type: None,
|
call_type: None,
|
||||||
function: Some(FunctionCallDelta {
|
function: Some(FunctionCallDelta {
|
||||||
|
|
|
||||||
109
docs/db_setup/README.md
Normal file
109
docs/db_setup/README.md
Normal file
|
|
@ -0,0 +1,109 @@
|
||||||
|
# Database Setup for Conversation State Storage
|
||||||
|
|
||||||
|
This directory contains SQL scripts needed to set up database tables for storing conversation state when using the OpenAI Responses API.
|
||||||
|
|
||||||
|
## Prerequisites
|
||||||
|
|
||||||
|
- PostgreSQL database (Supabase or self-hosted)
|
||||||
|
- Database connection credentials
|
||||||
|
- `psql` CLI tool or database admin access
|
||||||
|
|
||||||
|
## Setup Instructions
|
||||||
|
|
||||||
|
### Option 1: Using psql
|
||||||
|
|
||||||
|
```bash
|
||||||
|
psql $DATABASE_URL -f docs/db_setup/conversation_states.sql
|
||||||
|
```
|
||||||
|
|
||||||
|
### Option 2: Using Supabase Dashboard
|
||||||
|
|
||||||
|
1. Log in to your Supabase project dashboard
|
||||||
|
2. Navigate to the SQL Editor
|
||||||
|
3. Copy and paste the contents of `conversation_states.sql`
|
||||||
|
4. Run the query
|
||||||
|
|
||||||
|
### Option 3: Direct Database Connection
|
||||||
|
|
||||||
|
Connect to your PostgreSQL database using your preferred client and execute the SQL from `conversation_states.sql`.
|
||||||
|
|
||||||
|
## Verification
|
||||||
|
|
||||||
|
After running the setup, verify the table was created:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
SELECT tablename FROM pg_tables WHERE tablename = 'conversation_states';
|
||||||
|
```
|
||||||
|
|
||||||
|
You should see `conversation_states` in the results.
|
||||||
|
|
||||||
|
## Configuration
|
||||||
|
|
||||||
|
After setting up the database table, configure your application to use Supabase storage by setting the appropriate environment variable or configuration parameter with your database connection string.
|
||||||
|
|
||||||
|
### Supabase Connection String
|
||||||
|
|
||||||
|
**Important:** Supabase requires different connection strings depending on your network:
|
||||||
|
|
||||||
|
- **IPv4 Networks (Most Common)**: Use the **Session Pooler** connection string (port 5432):
|
||||||
|
```
|
||||||
|
postgresql://postgres.[PROJECT-REF]:[PASSWORD]@aws-0-[REGION].pooler.supabase.com:5432/postgres
|
||||||
|
```
|
||||||
|
|
||||||
|
- **IPv6 Networks**: Use the direct connection (port 5432):
|
||||||
|
```
|
||||||
|
postgresql://postgres:[PASSWORD]@db.[PROJECT-REF].supabase.co:5432/postgres
|
||||||
|
```
|
||||||
|
|
||||||
|
**How to get your connection string:**
|
||||||
|
1. Go to your Supabase project dashboard
|
||||||
|
2. Settings → Database → Connection Pooling
|
||||||
|
3. Copy the **Session mode** connection string
|
||||||
|
4. Replace `[YOUR-PASSWORD]` with your actual database password
|
||||||
|
5. URL-encode special characters in the password (e.g., `#` becomes `%23`)
|
||||||
|
|
||||||
|
**Example:**
|
||||||
|
```bash
|
||||||
|
# If your password is "MyPass#123", encode it as "MyPass%23123"
|
||||||
|
export DATABASE_URL="postgresql://postgres.myproject:MyPass%23123@aws-0-us-west-2.pooler.supabase.com:5432/postgres"
|
||||||
|
```
|
||||||
|
|
||||||
|
### Testing the Connection
|
||||||
|
|
||||||
|
To test your connection string works:
|
||||||
|
```bash
|
||||||
|
export TEST_DATABASE_URL="your-connection-string-here"
|
||||||
|
cd crates/brightstaff
|
||||||
|
cargo test supabase -- --nocapture
|
||||||
|
```
|
||||||
|
|
||||||
|
## Table Schema
|
||||||
|
|
||||||
|
The `conversation_states` table stores:
|
||||||
|
- `response_id` (TEXT, PRIMARY KEY): Unique identifier for each conversation
|
||||||
|
- `input_items` (JSONB): Array of conversation messages and context
|
||||||
|
- `created_at` (BIGINT): Unix timestamp when conversation started
|
||||||
|
- `model` (TEXT): Model name used for the conversation
|
||||||
|
- `provider` (TEXT): LLM provider name
|
||||||
|
- `updated_at` (TIMESTAMP): Last update time (auto-managed)
|
||||||
|
|
||||||
|
## Maintenance
|
||||||
|
|
||||||
|
### Cleanup Old Conversations
|
||||||
|
|
||||||
|
To prevent unbounded growth, consider periodically cleaning up old conversation states:
|
||||||
|
|
||||||
|
```sql
|
||||||
|
-- Delete conversations older than 7 days
|
||||||
|
DELETE FROM conversation_states
|
||||||
|
WHERE updated_at < NOW() - INTERVAL '7 days';
|
||||||
|
```
|
||||||
|
|
||||||
|
You can automate this with a cron job or database trigger.
|
||||||
|
|
||||||
|
## Troubleshooting
|
||||||
|
|
||||||
|
If you encounter errors on first use:
|
||||||
|
- **"Table 'conversation_states' does not exist"**: Run the setup SQL
|
||||||
|
- **Connection errors**: Verify your DATABASE_URL is correct
|
||||||
|
- **Permission errors**: Ensure your database user has CREATE TABLE privileges
|
||||||
31
docs/db_setup/conversation_states.sql
Normal file
31
docs/db_setup/conversation_states.sql
Normal file
|
|
@ -0,0 +1,31 @@
|
||||||
|
-- Conversation State Storage Table
|
||||||
|
-- This table stores conversational context for the OpenAI Responses API
|
||||||
|
-- Run this SQL against your PostgreSQL/Supabase database before enabling conversation state storage
|
||||||
|
|
||||||
|
CREATE TABLE IF NOT EXISTS conversation_states (
|
||||||
|
response_id TEXT PRIMARY KEY,
|
||||||
|
input_items JSONB NOT NULL,
|
||||||
|
created_at BIGINT NOT NULL,
|
||||||
|
model TEXT NOT NULL,
|
||||||
|
provider TEXT NOT NULL,
|
||||||
|
updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP
|
||||||
|
);
|
||||||
|
|
||||||
|
-- Indexes for common query patterns
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_conversation_states_created_at
|
||||||
|
ON conversation_states(created_at);
|
||||||
|
|
||||||
|
CREATE INDEX IF NOT EXISTS idx_conversation_states_provider
|
||||||
|
ON conversation_states(provider);
|
||||||
|
|
||||||
|
-- Optional: Add a policy for automatic cleanup of old conversations
|
||||||
|
-- Uncomment and adjust the retention period as needed
|
||||||
|
-- CREATE INDEX IF NOT EXISTS idx_conversation_states_updated_at
|
||||||
|
-- ON conversation_states(updated_at);
|
||||||
|
|
||||||
|
COMMENT ON TABLE conversation_states IS 'Stores conversation history for OpenAI Responses API continuity';
|
||||||
|
COMMENT ON COLUMN conversation_states.response_id IS 'Unique identifier for the conversation state';
|
||||||
|
COMMENT ON COLUMN conversation_states.input_items IS 'JSONB array of conversation messages and context';
|
||||||
|
COMMENT ON COLUMN conversation_states.created_at IS 'Unix timestamp (seconds) when the conversation started';
|
||||||
|
COMMENT ON COLUMN conversation_states.model IS 'Model name used for this conversation';
|
||||||
|
COMMENT ON COLUMN conversation_states.provider IS 'LLM provider (e.g., openai, anthropic, bedrock)';
|
||||||
|
|
@ -0,0 +1,32 @@
|
||||||
|
version: v0.1
|
||||||
|
|
||||||
|
listeners:
|
||||||
|
egress_traffic:
|
||||||
|
address: 0.0.0.0
|
||||||
|
port: 12000
|
||||||
|
message_format: openai
|
||||||
|
timeout: 30s
|
||||||
|
|
||||||
|
llm_providers:
|
||||||
|
|
||||||
|
# OpenAI Models
|
||||||
|
- model: openai/gpt-5-mini-2025-08-07
|
||||||
|
access_key: $OPENAI_API_KEY
|
||||||
|
default: true
|
||||||
|
|
||||||
|
# Anthropic Models
|
||||||
|
- model: anthropic/claude-sonnet-4-20250514
|
||||||
|
access_key: $ANTHROPIC_API_KEY
|
||||||
|
|
||||||
|
# State storage configuration for v1/responses API
|
||||||
|
# Manages conversation state for multi-turn conversations
|
||||||
|
state_storage:
|
||||||
|
# Type: memory | postgres
|
||||||
|
type: postgres
|
||||||
|
|
||||||
|
# Connection string for postgres type
|
||||||
|
# Environment variables are supported using $VAR_NAME or ${VAR_NAME} syntax
|
||||||
|
# Replace [USER] and [HOST] with your actual database credentials
|
||||||
|
# Variables like $DB_PASSWORD MUST be set before running config validation/rendering
|
||||||
|
# Example: Replace [USER] with 'myuser' and [HOST] with 'db.example.com:5432'
|
||||||
|
connection_string: "postgresql://[USER]:$DB_PASSWORD@[HOST]:5432/postgres"
|
||||||
25
tests/e2e/arch_config_memory_state_v1_responses.yaml
Normal file
25
tests/e2e/arch_config_memory_state_v1_responses.yaml
Normal file
|
|
@ -0,0 +1,25 @@
|
||||||
|
version: v0.1
|
||||||
|
|
||||||
|
listeners:
|
||||||
|
egress_traffic:
|
||||||
|
address: 0.0.0.0
|
||||||
|
port: 12000
|
||||||
|
message_format: openai
|
||||||
|
timeout: 30s
|
||||||
|
|
||||||
|
llm_providers:
|
||||||
|
|
||||||
|
# OpenAI Models
|
||||||
|
- model: openai/gpt-5-mini-2025-08-07
|
||||||
|
access_key: $OPENAI_API_KEY
|
||||||
|
default: true
|
||||||
|
|
||||||
|
# Anthropic Models
|
||||||
|
- model: anthropic/claude-sonnet-4-20250514
|
||||||
|
access_key: $ANTHROPIC_API_KEY
|
||||||
|
|
||||||
|
# State storage configuration for v1/responses API
|
||||||
|
# Manages conversation state for multi-turn conversations
|
||||||
|
state_storage:
|
||||||
|
# Type: memory | postgres
|
||||||
|
type: memory
|
||||||
|
|
@ -69,6 +69,14 @@ log running e2e tests for openai responses api client
|
||||||
log ========================================
|
log ========================================
|
||||||
poetry run pytest test_openai_responses_api_client.py
|
poetry run pytest test_openai_responses_api_client.py
|
||||||
|
|
||||||
|
log startup arch gateway with state storage for openai responses api client demo
|
||||||
|
archgw down
|
||||||
|
archgw up arch_config_memory_state_v1_responses.yaml
|
||||||
|
|
||||||
|
log running e2e tests for openai responses api client
|
||||||
|
log ========================================
|
||||||
|
poetry run pytest test_openai_responses_api_client_with_state.py
|
||||||
|
|
||||||
log shutting down the weather_forecast demo
|
log shutting down the weather_forecast demo
|
||||||
log =======================================
|
log =======================================
|
||||||
cd ../../demos/samples_python/weather_forecast
|
cd ../../demos/samples_python/weather_forecast
|
||||||
|
|
|
||||||
218
tests/e2e/test_openai_responses_api_client_with_state.py
Normal file
218
tests/e2e/test_openai_responses_api_client_with_state.py
Normal file
|
|
@ -0,0 +1,218 @@
|
||||||
|
import openai
|
||||||
|
import pytest
|
||||||
|
import os
|
||||||
|
import logging
|
||||||
|
import sys
|
||||||
|
|
||||||
|
# Set up logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.INFO,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
handlers=[logging.StreamHandler(sys.stdout)],
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
LLM_GATEWAY_ENDPOINT = os.getenv(
|
||||||
|
"LLM_GATEWAY_ENDPOINT", "http://localhost:12000/v1/chat/completions"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_state_management_two_turn():
|
||||||
|
"""
|
||||||
|
Test conversation state management across two turns:
|
||||||
|
1. Send initial message to non-OpenAI model via v1/responses
|
||||||
|
2. Capture response_id from first response
|
||||||
|
3. Send second message with previous_response_id
|
||||||
|
4. Verify model receives both messages in correct order
|
||||||
|
"""
|
||||||
|
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||||
|
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||||
|
|
||||||
|
logger.info("\n" + "=" * 80)
|
||||||
|
logger.info("TEST: Conversation State Management - Two Turn Flow")
|
||||||
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
# Turn 1: Send initial message to Anthropic (non-OpenAI model)
|
||||||
|
logger.info("\n[TURN 1] Sending initial message...")
|
||||||
|
resp1 = client.responses.create(
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
input="My name is Alice and I like pizza.",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Extract response_id from first response
|
||||||
|
response_id_1 = resp1.id
|
||||||
|
logger.info(f"[TURN 1] Received response_id: {response_id_1}")
|
||||||
|
logger.info(f"[TURN 1] Model response: {resp1.output_text}")
|
||||||
|
|
||||||
|
assert response_id_1 is not None, "First response should have an id"
|
||||||
|
assert len(resp1.output_text) > 0, "First response should have content"
|
||||||
|
|
||||||
|
# Turn 2: Send follow-up message with previous_response_id
|
||||||
|
# Ask the model to list all messages to verify state was combined
|
||||||
|
logger.info(
|
||||||
|
f"\n[TURN 2] Sending follow-up with previous_response_id={response_id_1}"
|
||||||
|
)
|
||||||
|
resp2 = client.responses.create(
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
input="Please list all the messages you have received in our conversation, numbering each one.",
|
||||||
|
previous_response_id=response_id_1,
|
||||||
|
)
|
||||||
|
|
||||||
|
response_id_2 = resp2.id
|
||||||
|
logger.info(f"[TURN 2] Received response_id: {response_id_2}")
|
||||||
|
logger.info(f"[TURN 2] Model response: {resp2.output_text}")
|
||||||
|
|
||||||
|
assert response_id_2 is not None, "Second response should have an id"
|
||||||
|
assert response_id_2 != response_id_1, "Second response should have different id"
|
||||||
|
|
||||||
|
# Verify the model received the conversation history
|
||||||
|
# The response should reference both the initial message and the follow-up
|
||||||
|
response_lower = resp2.output_text.lower()
|
||||||
|
|
||||||
|
# Check if the model acknowledges receiving multiple messages
|
||||||
|
# Different models might format this differently, so we check for various indicators
|
||||||
|
has_conversation_context = (
|
||||||
|
"alice" in response_lower
|
||||||
|
or "pizza" in response_lower # References the name from turn 1
|
||||||
|
or "two" in response_lower # References the preference from turn 1
|
||||||
|
or "2" in response_lower # Mentions number of messages
|
||||||
|
or "first" in response_lower # Numeric indicator
|
||||||
|
or "second" # References first message
|
||||||
|
in response_lower # References second message
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"\n[VALIDATION] Conversation context preserved: {has_conversation_context}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[VALIDATION] Response contains conversation markers: {has_conversation_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n{'='*80}")
|
||||||
|
print("Conversation State Test Results:")
|
||||||
|
print(f"Turn 1 Response ID: {response_id_1}")
|
||||||
|
print(f"Turn 2 Response ID: {response_id_2}")
|
||||||
|
print(f"Turn 1 Output: {resp1.output_text[:100]}...")
|
||||||
|
print(f"Turn 2 Output: {resp2.output_text}")
|
||||||
|
print(f"Conversation Context Preserved: {has_conversation_context}")
|
||||||
|
print(f"{'='*80}\n")
|
||||||
|
|
||||||
|
assert has_conversation_context, (
|
||||||
|
f"Model should have received conversation history. "
|
||||||
|
f"Response: {resp2.output_text}"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_conversation_state_management_two_turn_streaming():
|
||||||
|
"""
|
||||||
|
Test conversation state management across two turns with streaming:
|
||||||
|
1. Send initial streaming message to non-OpenAI model via v1/responses
|
||||||
|
2. Capture response_id from first response
|
||||||
|
3. Send second streaming message with previous_response_id
|
||||||
|
4. Verify model receives both messages in correct order
|
||||||
|
"""
|
||||||
|
base_url = LLM_GATEWAY_ENDPOINT.replace("/v1/chat/completions", "")
|
||||||
|
client = openai.OpenAI(api_key="test-key", base_url=f"{base_url}/v1")
|
||||||
|
|
||||||
|
logger.info("\n" + "=" * 80)
|
||||||
|
logger.info("TEST: Conversation State Management - Two Turn Streaming Flow")
|
||||||
|
logger.info("=" * 80)
|
||||||
|
|
||||||
|
# Turn 1: Send initial streaming message to Anthropic (non-OpenAI model)
|
||||||
|
logger.info("\n[TURN 1] Sending initial streaming message...")
|
||||||
|
stream1 = client.responses.create(
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
input="My name is Alice and I like pizza.",
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect streamed content and capture response_id
|
||||||
|
text_chunks_1 = []
|
||||||
|
response_id_1 = None
|
||||||
|
|
||||||
|
for event in stream1:
|
||||||
|
if getattr(event, "type", None) == "response.output_text.delta" and getattr(
|
||||||
|
event, "delta", None
|
||||||
|
):
|
||||||
|
text_chunks_1.append(event.delta)
|
||||||
|
|
||||||
|
# Capture response_id from response.completed event
|
||||||
|
if getattr(event, "type", None) == "response.completed" and getattr(
|
||||||
|
event, "response", None
|
||||||
|
):
|
||||||
|
response_id_1 = event.response.id
|
||||||
|
|
||||||
|
output_1 = "".join(text_chunks_1)
|
||||||
|
logger.info(f"[TURN 1] Received response_id: {response_id_1}")
|
||||||
|
logger.info(f"[TURN 1] Model response: {output_1}")
|
||||||
|
|
||||||
|
assert response_id_1 is not None, "First response should have an id"
|
||||||
|
assert len(output_1) > 0, "First response should have content"
|
||||||
|
|
||||||
|
# Turn 2: Send follow-up streaming message with previous_response_id
|
||||||
|
logger.info(
|
||||||
|
f"\n[TURN 2] Sending follow-up streaming request with previous_response_id={response_id_1}"
|
||||||
|
)
|
||||||
|
stream2 = client.responses.create(
|
||||||
|
model="claude-sonnet-4-20250514",
|
||||||
|
input="Please list all the messages you have received in our conversation, numbering each one.",
|
||||||
|
previous_response_id=response_id_1,
|
||||||
|
stream=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Collect streamed content from second response
|
||||||
|
text_chunks_2 = []
|
||||||
|
response_id_2 = None
|
||||||
|
|
||||||
|
for event in stream2:
|
||||||
|
if getattr(event, "type", None) == "response.output_text.delta" and getattr(
|
||||||
|
event, "delta", None
|
||||||
|
):
|
||||||
|
text_chunks_2.append(event.delta)
|
||||||
|
|
||||||
|
# Capture response_id from response.completed event
|
||||||
|
if getattr(event, "type", None) == "response.completed" and getattr(
|
||||||
|
event, "response", None
|
||||||
|
):
|
||||||
|
response_id_2 = event.response.id
|
||||||
|
|
||||||
|
output_2 = "".join(text_chunks_2)
|
||||||
|
logger.info(f"[TURN 2] Received response_id: {response_id_2}")
|
||||||
|
logger.info(f"[TURN 2] Model response: {output_2}")
|
||||||
|
|
||||||
|
assert response_id_2 is not None, "Second response should have an id"
|
||||||
|
assert response_id_2 != response_id_1, "Second response should have different id"
|
||||||
|
|
||||||
|
# Verify the model received the conversation history
|
||||||
|
response_lower = output_2.lower()
|
||||||
|
|
||||||
|
# Check if the model acknowledges receiving multiple messages
|
||||||
|
has_conversation_context = (
|
||||||
|
"alice" in response_lower
|
||||||
|
or "pizza" in response_lower # References the name from turn 1
|
||||||
|
or "two" in response_lower # References the preference from turn 1
|
||||||
|
or "2" in response_lower # Mentions number of messages
|
||||||
|
or "first" in response_lower # Numeric indicator
|
||||||
|
or "second" # References first message
|
||||||
|
in response_lower # References second message
|
||||||
|
)
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f"\n[VALIDATION] Conversation context preserved: {has_conversation_context}"
|
||||||
|
)
|
||||||
|
logger.info(
|
||||||
|
f"[VALIDATION] Response contains conversation markers: {has_conversation_context}"
|
||||||
|
)
|
||||||
|
|
||||||
|
print(f"\n{'='*80}")
|
||||||
|
print("Streaming Conversation State Test Results:")
|
||||||
|
print(f"Turn 1 Response ID: {response_id_1}")
|
||||||
|
print(f"Turn 2 Response ID: {response_id_2}")
|
||||||
|
print(f"Turn 1 Output: {output_1[:100]}...")
|
||||||
|
print(f"Turn 2 Output: {output_2}")
|
||||||
|
print(f"Conversation Context Preserved: {has_conversation_context}")
|
||||||
|
print(f"{'='*80}\n")
|
||||||
|
|
||||||
|
assert has_conversation_context, (
|
||||||
|
f"Model should have received conversation history. " f"Response: {output_2}"
|
||||||
|
)
|
||||||
Loading…
Add table
Add a link
Reference in a new issue