mirror of
https://github.com/dograh-hq/dograh.git
synced 2026-06-10 08:05:22 +02:00
chore: remove looptalk (#299)
* chore: remove looptalk Remove looptalk in the current version. We will be rethinking looptalk in a fresh way. * chore: formatting fix
This commit is contained in:
parent
0523dcb079
commit
45b00cd5d0
34 changed files with 214 additions and 4634 deletions
204
api/alembic/versions/4c1f1e3e8ef2_drop_looptalk_tables.py
Normal file
204
api/alembic/versions/4c1f1e3e8ef2_drop_looptalk_tables.py
Normal file
|
|
@ -0,0 +1,204 @@
|
|||
"""drop_looptalk_tables
|
||||
|
||||
Revision ID: 4c1f1e3e8ef2
|
||||
Revises: 6499c608d0f6
|
||||
Create Date: 2026-05-16 14:46:18.296517
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
from alembic import op
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "4c1f1e3e8ef2"
|
||||
down_revision: Union[str, None] = "6499c608d0f6"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop child table first so its FK to looptalk_test_sessions is removed before the parent is dropped.
|
||||
op.drop_index(
|
||||
op.f("ix_looptalk_conversations_session_id"),
|
||||
table_name="looptalk_conversations",
|
||||
)
|
||||
op.drop_table("looptalk_conversations")
|
||||
op.drop_index(
|
||||
op.f("ix_looptalk_test_sessions_group_id"), table_name="looptalk_test_sessions"
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_looptalk_test_sessions_load_test_group_id"),
|
||||
table_name="looptalk_test_sessions",
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_looptalk_test_sessions_org_id"), table_name="looptalk_test_sessions"
|
||||
)
|
||||
op.drop_index(
|
||||
op.f("ix_looptalk_test_sessions_status"), table_name="looptalk_test_sessions"
|
||||
)
|
||||
op.drop_table("looptalk_test_sessions")
|
||||
sa.Enum(
|
||||
"pending", "running", "completed", "failed", name="test_session_status"
|
||||
).drop(op.get_bind())
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
sa.Enum(
|
||||
"pending", "running", "completed", "failed", name="test_session_status"
|
||||
).create(op.get_bind())
|
||||
op.create_table(
|
||||
"looptalk_conversations",
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("test_session_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("duration_seconds", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
sa.Column(
|
||||
"actor_recording_url", sa.VARCHAR(), autoincrement=False, nullable=True
|
||||
),
|
||||
sa.Column(
|
||||
"adversary_recording_url", sa.VARCHAR(), autoincrement=False, nullable=True
|
||||
),
|
||||
sa.Column(
|
||||
"combined_recording_url", sa.VARCHAR(), autoincrement=False, nullable=True
|
||||
),
|
||||
sa.Column(
|
||||
"transcript",
|
||||
postgresql.JSON(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"metrics",
|
||||
postgresql.JSON(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
postgresql.TIMESTAMP(timezone=True),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"ended_at",
|
||||
postgresql.TIMESTAMP(timezone=True),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["test_session_id"],
|
||||
["looptalk_test_sessions.id"],
|
||||
name=op.f("looptalk_conversations_test_session_id_fkey"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("looptalk_conversations_pkey")),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_looptalk_conversations_session_id"),
|
||||
"looptalk_conversations",
|
||||
["test_session_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_table(
|
||||
"looptalk_test_sessions",
|
||||
sa.Column("id", sa.INTEGER(), autoincrement=True, nullable=False),
|
||||
sa.Column("organization_id", sa.INTEGER(), autoincrement=False, nullable=False),
|
||||
sa.Column("name", sa.VARCHAR(), autoincrement=False, nullable=False),
|
||||
sa.Column(
|
||||
"status",
|
||||
postgresql.ENUM(
|
||||
"pending",
|
||||
"running",
|
||||
"completed",
|
||||
"failed",
|
||||
name="test_session_status",
|
||||
create_type=False,
|
||||
),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"actor_workflow_id", sa.INTEGER(), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"adversary_workflow_id", sa.INTEGER(), autoincrement=False, nullable=False
|
||||
),
|
||||
sa.Column(
|
||||
"load_test_group_id", sa.VARCHAR(), autoincrement=False, nullable=True
|
||||
),
|
||||
sa.Column("test_index", sa.INTEGER(), autoincrement=False, nullable=True),
|
||||
sa.Column(
|
||||
"config",
|
||||
postgresql.JSON(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"results",
|
||||
postgresql.JSON(astext_type=sa.Text()),
|
||||
autoincrement=False,
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("error", sa.VARCHAR(), autoincrement=False, nullable=True),
|
||||
sa.Column(
|
||||
"created_at",
|
||||
postgresql.TIMESTAMP(timezone=True),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"started_at",
|
||||
postgresql.TIMESTAMP(timezone=True),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column(
|
||||
"completed_at",
|
||||
postgresql.TIMESTAMP(timezone=True),
|
||||
autoincrement=False,
|
||||
nullable=True,
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["actor_workflow_id"],
|
||||
["workflows.id"],
|
||||
name=op.f("looptalk_test_sessions_actor_workflow_id_fkey"),
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["adversary_workflow_id"],
|
||||
["workflows.id"],
|
||||
name=op.f("looptalk_test_sessions_adversary_workflow_id_fkey"),
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
name=op.f("looptalk_test_sessions_organization_id_fkey"),
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name=op.f("looptalk_test_sessions_pkey")),
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_looptalk_test_sessions_status"),
|
||||
"looptalk_test_sessions",
|
||||
["status"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_looptalk_test_sessions_org_id"),
|
||||
"looptalk_test_sessions",
|
||||
["organization_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_looptalk_test_sessions_load_test_group_id"),
|
||||
"looptalk_test_sessions",
|
||||
["load_test_group_id"],
|
||||
unique=False,
|
||||
)
|
||||
op.create_index(
|
||||
op.f("ix_looptalk_test_sessions_group_id"),
|
||||
"looptalk_test_sessions",
|
||||
["load_test_group_id"],
|
||||
unique=False,
|
||||
)
|
||||
# ### end Alembic commands ###
|
||||
|
|
@ -4,7 +4,6 @@ from api.db.campaign_client import CampaignClient
|
|||
from api.db.embed_token_client import EmbedTokenClient
|
||||
from api.db.integration_client import IntegrationClient
|
||||
from api.db.knowledge_base_client import KnowledgeBaseClient
|
||||
from api.db.looptalk_client import LoopTalkClient
|
||||
from api.db.organization_client import OrganizationClient
|
||||
from api.db.organization_configuration_client import OrganizationConfigurationClient
|
||||
from api.db.organization_usage_client import OrganizationUsageClient
|
||||
|
|
@ -29,7 +28,6 @@ class DBClient(
|
|||
OrganizationUsageClient,
|
||||
IntegrationClient,
|
||||
WorkflowTemplateClient,
|
||||
LoopTalkClient,
|
||||
CampaignClient,
|
||||
ReportsClient,
|
||||
APIKeyClient,
|
||||
|
|
@ -54,7 +52,6 @@ class DBClient(
|
|||
- OrganizationUsageClient: handles organization usage and quota operations
|
||||
- IntegrationClient: handles integration operations
|
||||
- WorkflowTemplateClient: handles workflow template operations
|
||||
- LoopTalkClient: handles LoopTalk testing operations
|
||||
- CampaignClient: handles campaign operations
|
||||
- ReportsClient: handles reports and analytics operations
|
||||
- APIKeyClient: handles API key operations
|
||||
|
|
|
|||
|
|
@ -1,297 +0,0 @@
|
|||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from sqlalchemy.future import select
|
||||
from sqlalchemy.orm import selectinload
|
||||
|
||||
from api.db.base_client import BaseDBClient
|
||||
from api.db.models import (
|
||||
LoopTalkConversation,
|
||||
LoopTalkTestSession,
|
||||
WorkflowModel,
|
||||
)
|
||||
|
||||
|
||||
class LoopTalkClient(BaseDBClient):
|
||||
"""Database client for LoopTalk testing operations."""
|
||||
|
||||
async def create_test_session(
|
||||
self,
|
||||
organization_id: int,
|
||||
name: str,
|
||||
actor_workflow_id: int,
|
||||
adversary_workflow_id: int,
|
||||
config: Dict[str, Any],
|
||||
load_test_group_id: Optional[str] = None,
|
||||
test_index: Optional[int] = None,
|
||||
) -> LoopTalkTestSession:
|
||||
"""Create a new LoopTalk test session."""
|
||||
async with self.async_session() as session:
|
||||
test_session = LoopTalkTestSession(
|
||||
organization_id=organization_id,
|
||||
name=name,
|
||||
actor_workflow_id=actor_workflow_id,
|
||||
adversary_workflow_id=adversary_workflow_id,
|
||||
config=config,
|
||||
load_test_group_id=load_test_group_id,
|
||||
test_index=test_index,
|
||||
status="pending",
|
||||
)
|
||||
session.add(test_session)
|
||||
await session.commit()
|
||||
await session.refresh(test_session)
|
||||
return test_session
|
||||
|
||||
async def get_test_session(
|
||||
self, test_session_id: int, organization_id: int
|
||||
) -> Optional[LoopTalkTestSession]:
|
||||
"""Get a test session by ID."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(LoopTalkTestSession)
|
||||
.options(
|
||||
selectinload(LoopTalkTestSession.actor_workflow).selectinload(
|
||||
WorkflowModel.released_definition
|
||||
),
|
||||
selectinload(LoopTalkTestSession.adversary_workflow).selectinload(
|
||||
WorkflowModel.released_definition
|
||||
),
|
||||
selectinload(LoopTalkTestSession.conversations),
|
||||
)
|
||||
.where(
|
||||
LoopTalkTestSession.id == test_session_id,
|
||||
LoopTalkTestSession.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
return result.scalar_one_or_none()
|
||||
|
||||
async def list_test_sessions(
|
||||
self,
|
||||
organization_id: int,
|
||||
status: Optional[str] = None,
|
||||
load_test_group_id: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
) -> List[LoopTalkTestSession]:
|
||||
"""List test sessions with optional filtering."""
|
||||
async with self.async_session() as session:
|
||||
query = select(LoopTalkTestSession).where(
|
||||
LoopTalkTestSession.organization_id == organization_id
|
||||
)
|
||||
|
||||
if status:
|
||||
# "active" is a virtual status used by the UI to represent
|
||||
# both "pending" and "running" sessions. Translate it into
|
||||
# the real enum values stored in the database to avoid
|
||||
# invalid enum casting errors (e.g. asyncpg InvalidTextRepresentationError).
|
||||
if status == "active":
|
||||
query = query.where(
|
||||
LoopTalkTestSession.status.in_(["pending", "running"])
|
||||
)
|
||||
else:
|
||||
query = query.where(LoopTalkTestSession.status == status)
|
||||
|
||||
if load_test_group_id:
|
||||
query = query.where(
|
||||
LoopTalkTestSession.load_test_group_id == load_test_group_id
|
||||
)
|
||||
|
||||
query = (
|
||||
query.order_by(LoopTalkTestSession.created_at.desc())
|
||||
.limit(limit)
|
||||
.offset(offset)
|
||||
)
|
||||
|
||||
result = await session.execute(query)
|
||||
return result.scalars().all()
|
||||
|
||||
async def update_test_session_status(
|
||||
self,
|
||||
test_session_id: int,
|
||||
status: str,
|
||||
error: Optional[str] = None,
|
||||
results: Optional[Dict[str, Any]] = None,
|
||||
) -> LoopTalkTestSession:
|
||||
"""Update test session status and related fields."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(LoopTalkTestSession).where(
|
||||
LoopTalkTestSession.id == test_session_id
|
||||
)
|
||||
)
|
||||
test_session = result.scalar_one()
|
||||
|
||||
test_session.status = status
|
||||
|
||||
if status == "running":
|
||||
test_session.started_at = datetime.now(UTC)
|
||||
elif status in ["completed", "failed"]:
|
||||
test_session.completed_at = datetime.now(UTC)
|
||||
|
||||
if error:
|
||||
test_session.error = error
|
||||
|
||||
if results:
|
||||
test_session.results = results
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(test_session)
|
||||
return test_session
|
||||
|
||||
async def create_conversation(self, test_session_id: int) -> LoopTalkConversation:
|
||||
"""Create a new conversation for a test session."""
|
||||
async with self.async_session() as session:
|
||||
conversation = LoopTalkConversation(test_session_id=test_session_id)
|
||||
session.add(conversation)
|
||||
await session.commit()
|
||||
await session.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
async def update_conversation(
|
||||
self,
|
||||
conversation_id: int,
|
||||
duration_seconds: Optional[int] = None,
|
||||
actor_recording_url: Optional[str] = None,
|
||||
adversary_recording_url: Optional[str] = None,
|
||||
combined_recording_url: Optional[str] = None,
|
||||
transcript: Optional[Dict[str, Any]] = None,
|
||||
metrics: Optional[Dict[str, Any]] = None,
|
||||
ended_at: Optional[datetime] = None,
|
||||
) -> LoopTalkConversation:
|
||||
"""Update conversation details."""
|
||||
async with self.async_session() as session:
|
||||
result = await session.execute(
|
||||
select(LoopTalkConversation).where(
|
||||
LoopTalkConversation.id == conversation_id
|
||||
)
|
||||
)
|
||||
conversation = result.scalar_one()
|
||||
|
||||
if duration_seconds is not None:
|
||||
conversation.duration_seconds = duration_seconds
|
||||
if actor_recording_url:
|
||||
conversation.actor_recording_url = actor_recording_url
|
||||
if adversary_recording_url:
|
||||
conversation.adversary_recording_url = adversary_recording_url
|
||||
if combined_recording_url:
|
||||
conversation.combined_recording_url = combined_recording_url
|
||||
if transcript:
|
||||
conversation.transcript = transcript
|
||||
if metrics:
|
||||
conversation.metrics = metrics
|
||||
if ended_at:
|
||||
conversation.ended_at = ended_at
|
||||
|
||||
await session.commit()
|
||||
await session.refresh(conversation)
|
||||
return conversation
|
||||
|
||||
# Note: Turn tracking is handled by Langfuse, not stored in our database
|
||||
|
||||
async def create_load_test_group(
|
||||
self,
|
||||
organization_id: int,
|
||||
name_prefix: str,
|
||||
actor_workflow_id: int,
|
||||
adversary_workflow_id: int,
|
||||
config: Dict[str, Any],
|
||||
test_count: int,
|
||||
) -> List[LoopTalkTestSession]:
|
||||
"""Create multiple test sessions for load testing."""
|
||||
load_test_group_id = str(uuid4())
|
||||
test_sessions = []
|
||||
|
||||
async with self.async_session() as session:
|
||||
for i in range(test_count):
|
||||
test_session = LoopTalkTestSession(
|
||||
organization_id=organization_id,
|
||||
name=f"{name_prefix} - Test {i + 1}",
|
||||
actor_workflow_id=actor_workflow_id,
|
||||
adversary_workflow_id=adversary_workflow_id,
|
||||
config=config,
|
||||
load_test_group_id=load_test_group_id,
|
||||
test_index=i,
|
||||
status="pending",
|
||||
)
|
||||
session.add(test_session)
|
||||
test_sessions.append(test_session)
|
||||
|
||||
await session.commit()
|
||||
|
||||
# Refresh all sessions
|
||||
for test_session in test_sessions:
|
||||
await session.refresh(test_session)
|
||||
|
||||
return test_sessions
|
||||
|
||||
async def get_load_test_group_stats(
|
||||
self, load_test_group_id: str, organization_id: int
|
||||
) -> Dict[str, Any]:
|
||||
"""Get statistics for a load test group."""
|
||||
from sqlalchemy import case, func
|
||||
|
||||
async with self.async_session() as session:
|
||||
# Get status counts using SQL aggregation
|
||||
counts_result = await session.execute(
|
||||
select(
|
||||
func.count().label("total"),
|
||||
func.sum(
|
||||
case((LoopTalkTestSession.status == "pending", 1), else_=0)
|
||||
).label("pending"),
|
||||
func.sum(
|
||||
case((LoopTalkTestSession.status == "running", 1), else_=0)
|
||||
).label("running"),
|
||||
func.sum(
|
||||
case((LoopTalkTestSession.status == "completed", 1), else_=0)
|
||||
).label("completed"),
|
||||
func.sum(
|
||||
case((LoopTalkTestSession.status == "failed", 1), else_=0)
|
||||
).label("failed"),
|
||||
).where(
|
||||
LoopTalkTestSession.load_test_group_id == load_test_group_id,
|
||||
LoopTalkTestSession.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
counts = counts_result.one()
|
||||
|
||||
# Get session details (still needed for the sessions list)
|
||||
sessions_result = await session.execute(
|
||||
select(
|
||||
LoopTalkTestSession.id,
|
||||
LoopTalkTestSession.name,
|
||||
LoopTalkTestSession.status,
|
||||
LoopTalkTestSession.test_index,
|
||||
LoopTalkTestSession.created_at,
|
||||
LoopTalkTestSession.started_at,
|
||||
LoopTalkTestSession.completed_at,
|
||||
LoopTalkTestSession.error,
|
||||
).where(
|
||||
LoopTalkTestSession.load_test_group_id == load_test_group_id,
|
||||
LoopTalkTestSession.organization_id == organization_id,
|
||||
)
|
||||
)
|
||||
sessions = sessions_result.all()
|
||||
|
||||
stats = {
|
||||
"total": counts.total or 0,
|
||||
"pending": counts.pending or 0,
|
||||
"running": counts.running or 0,
|
||||
"completed": counts.completed or 0,
|
||||
"failed": counts.failed or 0,
|
||||
"sessions": [
|
||||
{
|
||||
"id": s.id,
|
||||
"name": s.name,
|
||||
"status": s.status,
|
||||
"test_index": s.test_index,
|
||||
"created_at": s.created_at,
|
||||
"started_at": s.started_at,
|
||||
"completed_at": s.completed_at,
|
||||
"error": s.error,
|
||||
}
|
||||
for s in sessions
|
||||
],
|
||||
}
|
||||
|
||||
return stats
|
||||
|
|
@ -501,87 +501,6 @@ class WorkflowRunModel(Base):
|
|||
)
|
||||
|
||||
|
||||
# LoopTalk Testing Models
|
||||
class LoopTalkTestSession(Base):
|
||||
__tablename__ = "looptalk_test_sessions"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
organization_id = Column(Integer, ForeignKey("organizations.id"), nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
status = Column(
|
||||
Enum("pending", "running", "completed", "failed", name="test_session_status"),
|
||||
nullable=False,
|
||||
default="pending",
|
||||
)
|
||||
|
||||
# Workflow configuration
|
||||
actor_workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
|
||||
adversary_workflow_id = Column(Integer, ForeignKey("workflows.id"), nullable=False)
|
||||
|
||||
# Load testing configuration
|
||||
load_test_group_id = Column(String, nullable=True, index=True)
|
||||
test_index = Column(Integer, nullable=True)
|
||||
|
||||
# Test metadata
|
||||
config = Column(JSON, nullable=False, default=dict)
|
||||
results = Column(JSON, nullable=False, default=dict)
|
||||
error = Column(String, nullable=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
started_at = Column(DateTime(timezone=True), nullable=True)
|
||||
completed_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
organization = relationship("OrganizationModel")
|
||||
actor_workflow = relationship("WorkflowModel", foreign_keys=[actor_workflow_id])
|
||||
adversary_workflow = relationship(
|
||||
"WorkflowModel", foreign_keys=[adversary_workflow_id]
|
||||
)
|
||||
conversations = relationship("LoopTalkConversation", back_populates="test_session")
|
||||
|
||||
# Indexes for performance
|
||||
__table_args__ = (
|
||||
Index("ix_looptalk_test_sessions_org_id", "organization_id"),
|
||||
Index("ix_looptalk_test_sessions_group_id", "load_test_group_id"),
|
||||
Index("ix_looptalk_test_sessions_status", "status"),
|
||||
)
|
||||
|
||||
|
||||
class LoopTalkConversation(Base):
|
||||
__tablename__ = "looptalk_conversations"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
test_session_id = Column(
|
||||
Integer, ForeignKey("looptalk_test_sessions.id"), nullable=False
|
||||
)
|
||||
|
||||
# Conversation metadata
|
||||
duration_seconds = Column(Integer, nullable=True)
|
||||
# Note: Turn tracking is handled by Langfuse, not stored here
|
||||
|
||||
# Audio recording URLs
|
||||
actor_recording_url = Column(String, nullable=True)
|
||||
adversary_recording_url = Column(String, nullable=True)
|
||||
combined_recording_url = Column(String, nullable=True)
|
||||
|
||||
# Transcripts (if needed for quick access)
|
||||
transcript = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Metrics
|
||||
metrics = Column(JSON, nullable=False, default=dict)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), default=lambda: datetime.now(UTC))
|
||||
ended_at = Column(DateTime(timezone=True), nullable=True)
|
||||
|
||||
# Relationships
|
||||
test_session = relationship("LoopTalkTestSession", back_populates="conversations")
|
||||
|
||||
# Indexes
|
||||
__table_args__ = (Index("ix_looptalk_conversations_session_id", "test_session_id"),)
|
||||
|
||||
|
||||
class OrganizationUsageCycleModel(Base):
|
||||
"""
|
||||
This model is used to track the usage of Dograh tokens for an organization for a given usage
|
||||
|
|
|
|||
|
|
@ -1,316 +0,0 @@
|
|||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
BackgroundTasks,
|
||||
Depends,
|
||||
HTTPException,
|
||||
WebSocket,
|
||||
)
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from api.db import db_client
|
||||
from api.db.models import UserModel
|
||||
from api.services.auth.depends import get_user
|
||||
from api.services.looptalk.orchestrator import LoopTalkTestOrchestrator
|
||||
|
||||
router = APIRouter(prefix="/looptalk")
|
||||
|
||||
|
||||
# Request/Response Models
|
||||
class CreateTestSessionRequest(BaseModel):
|
||||
name: str
|
||||
actor_workflow_id: int
|
||||
adversary_workflow_id: int
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class StartTestSessionRequest(BaseModel):
|
||||
test_session_id: int
|
||||
|
||||
|
||||
class CreateLoadTestRequest(BaseModel):
|
||||
name_prefix: str
|
||||
actor_workflow_id: int
|
||||
adversary_workflow_id: int
|
||||
test_count: int = Field(ge=1, le=10)
|
||||
config: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class TestSessionResponse(BaseModel):
|
||||
id: int
|
||||
name: str
|
||||
status: str
|
||||
actor_workflow_id: int
|
||||
adversary_workflow_id: int
|
||||
load_test_group_id: Optional[str]
|
||||
test_index: Optional[int]
|
||||
config: Dict[str, Any]
|
||||
results: Optional[Dict[str, Any]]
|
||||
error: Optional[str]
|
||||
created_at: datetime
|
||||
started_at: Optional[datetime]
|
||||
completed_at: Optional[datetime]
|
||||
|
||||
|
||||
class ConversationResponse(BaseModel):
|
||||
id: int
|
||||
test_session_id: int
|
||||
duration_seconds: Optional[int]
|
||||
actor_recording_url: Optional[str]
|
||||
adversary_recording_url: Optional[str]
|
||||
combined_recording_url: Optional[str]
|
||||
transcript: Optional[Dict[str, Any]]
|
||||
metrics: Optional[Dict[str, Any]]
|
||||
created_at: datetime
|
||||
ended_at: Optional[datetime]
|
||||
|
||||
|
||||
# Note: Turn tracking is handled by Langfuse, not exposed via API
|
||||
|
||||
|
||||
class LoadTestStatsResponse(BaseModel):
|
||||
total: int
|
||||
pending: int
|
||||
running: int
|
||||
completed: int
|
||||
failed: int
|
||||
sessions: List[Dict[str, Any]]
|
||||
|
||||
|
||||
# Singleton orchestrator instance
|
||||
_orchestrator: Optional[LoopTalkTestOrchestrator] = None
|
||||
|
||||
|
||||
def get_orchestrator() -> LoopTalkTestOrchestrator:
|
||||
"""Get or create the LoopTalk orchestrator instance."""
|
||||
global _orchestrator
|
||||
if _orchestrator is None:
|
||||
_orchestrator = LoopTalkTestOrchestrator(db_client=db_client)
|
||||
return _orchestrator
|
||||
|
||||
|
||||
@router.post("/test-sessions", response_model=TestSessionResponse)
|
||||
async def create_test_session(
|
||||
request: CreateTestSessionRequest, user: UserModel = Depends(get_user)
|
||||
):
|
||||
"""Create a new LoopTalk test session."""
|
||||
|
||||
# Verify user has access to both workflows
|
||||
actor_workflow = await db_client.get_workflow(request.actor_workflow_id, user.id)
|
||||
if not actor_workflow:
|
||||
raise HTTPException(status_code=404, detail="Actor workflow not found")
|
||||
|
||||
adversary_workflow = await db_client.get_workflow(
|
||||
request.adversary_workflow_id, user.id
|
||||
)
|
||||
if not adversary_workflow:
|
||||
raise HTTPException(status_code=404, detail="Adversary workflow not found")
|
||||
|
||||
# Create test session
|
||||
test_session = await db_client.create_test_session(
|
||||
organization_id=user.selected_organization_id,
|
||||
name=request.name,
|
||||
actor_workflow_id=request.actor_workflow_id,
|
||||
adversary_workflow_id=request.adversary_workflow_id,
|
||||
config=request.config,
|
||||
)
|
||||
|
||||
return test_session
|
||||
|
||||
|
||||
@router.get("/test-sessions", response_model=List[TestSessionResponse])
|
||||
async def list_test_sessions(
|
||||
status: Optional[str] = None,
|
||||
load_test_group_id: Optional[str] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""List LoopTalk test sessions."""
|
||||
|
||||
test_sessions = await db_client.list_test_sessions(
|
||||
organization_id=user.selected_organization_id,
|
||||
status=status,
|
||||
load_test_group_id=load_test_group_id,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
)
|
||||
|
||||
return test_sessions
|
||||
|
||||
|
||||
@router.get("/test-sessions/{test_session_id}", response_model=TestSessionResponse)
|
||||
async def get_test_session(test_session_id: int, user: UserModel = Depends(get_user)):
|
||||
"""Get a specific test session."""
|
||||
|
||||
test_session = await db_client.get_test_session(
|
||||
test_session_id=test_session_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
||||
if not test_session:
|
||||
raise HTTPException(status_code=404, detail="Test session not found")
|
||||
|
||||
return test_session
|
||||
|
||||
|
||||
@router.post("/test-sessions/{test_session_id}/start")
|
||||
async def start_test_session(
|
||||
test_session_id: int,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: UserModel = Depends(get_user),
|
||||
orchestrator: LoopTalkTestOrchestrator = Depends(get_orchestrator),
|
||||
):
|
||||
"""Start a LoopTalk test session."""
|
||||
|
||||
# Verify test session exists and user has access
|
||||
test_session = await db_client.get_test_session(
|
||||
test_session_id=test_session_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
||||
if not test_session:
|
||||
raise HTTPException(status_code=404, detail="Test session not found")
|
||||
|
||||
if test_session.status != "pending":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Test session is {test_session.status}, not pending",
|
||||
)
|
||||
|
||||
# Start test session in background
|
||||
background_tasks.add_task(
|
||||
orchestrator.start_test_session,
|
||||
test_session_id=test_session_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
return {"message": "Test session starting", "test_session_id": test_session_id}
|
||||
|
||||
|
||||
@router.post("/test-sessions/{test_session_id}/stop")
|
||||
async def stop_test_session(
|
||||
test_session_id: int,
|
||||
user: UserModel = Depends(get_user),
|
||||
orchestrator: LoopTalkTestOrchestrator = Depends(get_orchestrator),
|
||||
):
|
||||
"""Stop a running test session."""
|
||||
|
||||
# Verify test session exists and user has access
|
||||
test_session = await db_client.get_test_session(
|
||||
test_session_id=test_session_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
||||
if not test_session:
|
||||
raise HTTPException(status_code=404, detail="Test session not found")
|
||||
|
||||
if test_session.status != "running":
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=f"Test session is {test_session.status}, not running",
|
||||
)
|
||||
|
||||
# Stop test session
|
||||
result = await orchestrator.stop_test_session(test_session_id=test_session_id)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get("/test-sessions/{test_session_id}/conversation")
|
||||
async def get_test_session_conversation(
|
||||
test_session_id: int, user: UserModel = Depends(get_user)
|
||||
):
|
||||
"""Get conversation details for a test session."""
|
||||
|
||||
# Verify test session exists and user has access
|
||||
test_session = await db_client.get_test_session(
|
||||
test_session_id=test_session_id, organization_id=user.selected_organization_id
|
||||
)
|
||||
|
||||
if not test_session:
|
||||
raise HTTPException(status_code=404, detail="Test session not found")
|
||||
|
||||
# Get conversation
|
||||
if test_session.conversations:
|
||||
conversation = test_session.conversations[
|
||||
0
|
||||
] # For now, one conversation per session
|
||||
|
||||
# Note: Turn details are available in Langfuse, not here
|
||||
return {
|
||||
"conversation": conversation,
|
||||
"message": "Turn details are tracked in Langfuse",
|
||||
}
|
||||
|
||||
return {"conversation": None}
|
||||
|
||||
|
||||
@router.post("/load-tests", response_model=Dict[str, Any])
|
||||
async def create_load_test(
|
||||
request: CreateLoadTestRequest,
|
||||
background_tasks: BackgroundTasks,
|
||||
user: UserModel = Depends(get_user),
|
||||
orchestrator: LoopTalkTestOrchestrator = Depends(get_orchestrator),
|
||||
):
|
||||
"""Create and start a load test."""
|
||||
|
||||
# Verify user has access to both workflows
|
||||
actor_workflow = await db_client.get_workflow(request.actor_workflow_id, user.id)
|
||||
if not actor_workflow:
|
||||
raise HTTPException(status_code=404, detail="Actor workflow not found")
|
||||
|
||||
adversary_workflow = await db_client.get_workflow(
|
||||
request.adversary_workflow_id, user.id
|
||||
)
|
||||
if not adversary_workflow:
|
||||
raise HTTPException(status_code=404, detail="Adversary workflow not found")
|
||||
|
||||
# Start load test in background
|
||||
result = await orchestrator.start_load_test(
|
||||
organization_id=user.selected_organization_id,
|
||||
name_prefix=request.name_prefix,
|
||||
actor_workflow_id=request.actor_workflow_id,
|
||||
adversary_workflow_id=request.adversary_workflow_id,
|
||||
config=request.config,
|
||||
test_count=request.test_count,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@router.get(
|
||||
"/load-tests/{load_test_group_id}/stats", response_model=LoadTestStatsResponse
|
||||
)
|
||||
async def get_load_test_stats(
|
||||
load_test_group_id: str, user: UserModel = Depends(get_user)
|
||||
):
|
||||
"""Get statistics for a load test group."""
|
||||
|
||||
stats = await db_client.get_load_test_group_stats(
|
||||
load_test_group_id=load_test_group_id,
|
||||
organization_id=user.selected_organization_id,
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/active-tests")
|
||||
async def get_active_tests(
|
||||
orchestrator: LoopTalkTestOrchestrator = Depends(get_orchestrator),
|
||||
user: UserModel = Depends(get_user),
|
||||
):
|
||||
"""Get information about currently active test sessions."""
|
||||
|
||||
return orchestrator.get_active_test_info()
|
||||
|
||||
|
||||
@router.websocket("/test-sessions/{test_session_id}/audio-stream")
|
||||
async def audio_stream_websocket(
|
||||
websocket: WebSocket,
|
||||
test_session_id: int,
|
||||
role: str = "mixed", # "actor", "adversary", or "mixed"
|
||||
token: Optional[str] = None,
|
||||
):
|
||||
"""WebSocket endpoint for real-time audio streaming from LoopTalk test sessions."""
|
||||
# TODO: to be implemented
|
||||
pass
|
||||
|
|
@ -8,7 +8,6 @@ from api.routes.campaign import router as campaign_router
|
|||
from api.routes.credentials import router as credentials_router
|
||||
from api.routes.integration import router as integration_router
|
||||
from api.routes.knowledge_base import router as knowledge_base_router
|
||||
from api.routes.looptalk import router as looptalk_router
|
||||
from api.routes.node_types import router as node_types_router
|
||||
from api.routes.organization import router as organization_router
|
||||
from api.routes.organization_usage import router as organization_usage_router
|
||||
|
|
@ -44,7 +43,6 @@ router.include_router(integration_router)
|
|||
router.include_router(organization_router)
|
||||
router.include_router(s3_router)
|
||||
router.include_router(service_keys_router)
|
||||
router.include_router(looptalk_router)
|
||||
router.include_router(organization_usage_router)
|
||||
router.include_router(reports_router)
|
||||
router.include_router(webrtc_signaling_router)
|
||||
|
|
|
|||
|
|
@ -79,7 +79,7 @@ async def _validate_and_extract_workflow_run_id(
|
|||
|
||||
Args:
|
||||
key: S3 object key
|
||||
allow_special_paths: If True, allows looptalk/voicemail paths
|
||||
allow_special_paths: If True, allows voicemail paths
|
||||
|
||||
Returns:
|
||||
workflow_run_id if found, None for special paths (when allowed)
|
||||
|
|
@ -91,10 +91,7 @@ async def _validate_and_extract_workflow_run_id(
|
|||
run_id_str = key[len("transcripts/") : -4] # strip prefix & suffix
|
||||
elif key.startswith("recordings/") and key.endswith(".wav"):
|
||||
run_id_str = key[len("recordings/") : -4]
|
||||
elif allow_special_paths and (
|
||||
key.startswith("looptalk/") or key.startswith("voicemail_detections/")
|
||||
):
|
||||
# Allow looptalk and voicemail paths for debugging (only if explicitly allowed)
|
||||
elif allow_special_paths and key.startswith("voicemail_detections/"):
|
||||
return None # Skip validation for these paths
|
||||
else:
|
||||
raise HTTPException(status_code=400, detail="Invalid key format")
|
||||
|
|
@ -258,7 +255,7 @@ async def get_file_metadata(
|
|||
f"METADATA: Using stored {backend} for metadata request - key: {key}"
|
||||
)
|
||||
else:
|
||||
# Fallback to current storage for legacy records or looptalk/voicemail files
|
||||
# Fallback to current storage for legacy records or voicemail files
|
||||
storage = storage_fs
|
||||
current_backend = StorageBackend.get_current_backend()
|
||||
logger.warning(
|
||||
|
|
|
|||
|
|
@ -1,3 +0,0 @@
|
|||
from .orchestrator import LoopTalkTestOrchestrator
|
||||
|
||||
__all__ = ["LoopTalkTestOrchestrator"]
|
||||
|
|
@ -1,220 +0,0 @@
|
|||
"""
|
||||
Audio streaming processor for LoopTalk real-time audio monitoring.
|
||||
|
||||
This processor captures audio from both actor and adversary agents and streams
|
||||
it to connected WebRTC clients for real-time monitoring.
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from typing import Dict, Set
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.audio.utils import mix_audio
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection, FrameProcessor
|
||||
|
||||
|
||||
class LoopTalkAudioStreamer(FrameProcessor):
|
||||
"""
|
||||
Processes audio frames from LoopTalk conversations and streams to WebRTC clients.
|
||||
|
||||
This processor sits in the pipeline and captures all audio frames, then
|
||||
forwards them to connected WebRTC clients for real-time monitoring.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
test_session_id: str,
|
||||
role: str, # "actor" or "adversary"
|
||||
**kwargs,
|
||||
):
|
||||
super().__init__(**kwargs)
|
||||
self._test_session_id = test_session_id
|
||||
self._role = role
|
||||
self._listeners: Set[asyncio.Queue] = set()
|
||||
self._sample_rate = 16000 # Default sample rate
|
||||
self._num_channels = 1
|
||||
|
||||
def add_listener(self, queue: asyncio.Queue):
|
||||
"""Add a listener queue for streaming audio."""
|
||||
self._listeners.add(queue)
|
||||
|
||||
def remove_listener(self, queue: asyncio.Queue):
|
||||
"""Remove a listener queue."""
|
||||
self._listeners.discard(queue)
|
||||
|
||||
async def process_frame(self, frame: Frame, direction: FrameDirection):
|
||||
"""Process audio frames and stream to listeners."""
|
||||
await super().process_frame(frame, direction)
|
||||
|
||||
# Capture both input and output audio
|
||||
if isinstance(frame, (InputAudioRawFrame, OutputAudioRawFrame)):
|
||||
# Extract audio data
|
||||
audio_data = frame.audio
|
||||
sample_rate = frame.sample_rate
|
||||
num_channels = frame.num_channels
|
||||
|
||||
# Store sample rate for reference
|
||||
if sample_rate:
|
||||
self._sample_rate = sample_rate
|
||||
if num_channels:
|
||||
self._num_channels = num_channels
|
||||
|
||||
# Stream to all listeners
|
||||
if self._listeners and audio_data:
|
||||
# Create a packet with metadata
|
||||
packet = {
|
||||
"test_session_id": self._test_session_id,
|
||||
"role": self._role,
|
||||
"audio": audio_data,
|
||||
"sample_rate": sample_rate,
|
||||
"num_channels": num_channels,
|
||||
"is_input": isinstance(frame, InputAudioRawFrame),
|
||||
}
|
||||
|
||||
# Send to all listeners without blocking
|
||||
for queue in list(self._listeners):
|
||||
try:
|
||||
queue.put_nowait(packet)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Audio queue full for session {self._test_session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming audio: {e}")
|
||||
self._listeners.discard(queue)
|
||||
elif self._listeners and not audio_data:
|
||||
logger.warning(
|
||||
f"Audio streamer {self._role} received frame with no audio data"
|
||||
)
|
||||
elif audio_data and not self._listeners:
|
||||
# This is expected early in the session before WebSocket connects
|
||||
pass
|
||||
|
||||
# Always forward the frame
|
||||
await self.push_frame(frame, direction)
|
||||
|
||||
|
||||
class LoopTalkAudioMixer:
|
||||
"""
|
||||
Mixes audio from actor and adversary streams for combined playback.
|
||||
|
||||
This class manages the mixing of two audio streams (actor and adversary)
|
||||
to create a combined audio stream for monitoring.
|
||||
"""
|
||||
|
||||
def __init__(self, test_session_id: str):
|
||||
self._test_session_id = test_session_id
|
||||
self._actor_buffer = bytearray()
|
||||
self._adversary_buffer = bytearray()
|
||||
self._listeners: Set[asyncio.Queue] = set()
|
||||
self._sample_rate = 16000
|
||||
self._num_channels = 1
|
||||
self._buffer_size = 8000 # 0.5 seconds at 16kHz
|
||||
|
||||
def add_listener(self, queue: asyncio.Queue):
|
||||
"""Add a listener for mixed audio."""
|
||||
self._listeners.add(queue)
|
||||
|
||||
def remove_listener(self, queue: asyncio.Queue):
|
||||
"""Remove a listener."""
|
||||
self._listeners.discard(queue)
|
||||
|
||||
async def add_audio(
|
||||
self, role: str, audio_data: bytes, sample_rate: int, num_channels: int
|
||||
):
|
||||
"""Add audio data from actor or adversary."""
|
||||
if role == "actor":
|
||||
self._actor_buffer.extend(audio_data)
|
||||
elif role == "adversary":
|
||||
self._adversary_buffer.extend(audio_data)
|
||||
|
||||
# Update audio parameters
|
||||
self._sample_rate = sample_rate
|
||||
self._num_channels = num_channels
|
||||
|
||||
# Check if we have enough data to mix
|
||||
await self._check_and_mix()
|
||||
|
||||
async def _check_and_mix(self):
|
||||
"""Check buffers and mix audio when enough data is available."""
|
||||
# Mix when we have at least buffer_size in both buffers
|
||||
while (
|
||||
len(self._actor_buffer) >= self._buffer_size
|
||||
and len(self._adversary_buffer) >= self._buffer_size
|
||||
):
|
||||
# Extract chunks
|
||||
actor_chunk = bytes(self._actor_buffer[: self._buffer_size])
|
||||
adversary_chunk = bytes(self._adversary_buffer[: self._buffer_size])
|
||||
|
||||
# Remove from buffers
|
||||
del self._actor_buffer[: self._buffer_size]
|
||||
del self._adversary_buffer[: self._buffer_size]
|
||||
|
||||
# Mix audio
|
||||
mixed_audio = mix_audio(actor_chunk, adversary_chunk)
|
||||
|
||||
# Stream to listeners
|
||||
if self._listeners and mixed_audio:
|
||||
packet = {
|
||||
"test_session_id": self._test_session_id,
|
||||
"role": "mixed",
|
||||
"audio": mixed_audio,
|
||||
"sample_rate": self._sample_rate,
|
||||
"num_channels": self._num_channels,
|
||||
"is_input": False,
|
||||
}
|
||||
|
||||
for queue in list(self._listeners):
|
||||
try:
|
||||
queue.put_nowait(packet)
|
||||
except asyncio.QueueFull:
|
||||
logger.warning(
|
||||
f"Mixed audio queue full for session {self._test_session_id}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Error streaming mixed audio: {e}")
|
||||
self._listeners.discard(queue)
|
||||
|
||||
|
||||
# Global registry for audio streamers and mixers
|
||||
_audio_streamers: Dict[str, Dict[str, LoopTalkAudioStreamer]] = {}
|
||||
_audio_mixers: Dict[str, LoopTalkAudioMixer] = {}
|
||||
|
||||
|
||||
def get_or_create_audio_streamer(
|
||||
test_session_id: str, role: str
|
||||
) -> LoopTalkAudioStreamer:
|
||||
"""Get or create an audio streamer for a test session and role."""
|
||||
if test_session_id not in _audio_streamers:
|
||||
_audio_streamers[test_session_id] = {}
|
||||
|
||||
if role not in _audio_streamers[test_session_id]:
|
||||
_audio_streamers[test_session_id][role] = LoopTalkAudioStreamer(
|
||||
test_session_id=test_session_id, role=role
|
||||
)
|
||||
|
||||
return _audio_streamers[test_session_id][role]
|
||||
|
||||
|
||||
def get_or_create_audio_mixer(test_session_id: str) -> LoopTalkAudioMixer:
|
||||
"""Get or create an audio mixer for a test session."""
|
||||
if test_session_id not in _audio_mixers:
|
||||
_audio_mixers[test_session_id] = LoopTalkAudioMixer(test_session_id)
|
||||
|
||||
return _audio_mixers[test_session_id]
|
||||
|
||||
|
||||
def cleanup_audio_streamers(test_session_id: str):
|
||||
"""Clean up audio streamers and mixers for a test session."""
|
||||
if test_session_id in _audio_streamers:
|
||||
del _audio_streamers[test_session_id]
|
||||
|
||||
if test_session_id in _audio_mixers:
|
||||
del _audio_mixers[test_session_id]
|
||||
|
||||
logger.info(f"Cleaned up audio streamers for test session {test_session_id}")
|
||||
|
|
@ -1 +0,0 @@
|
|||
"""Core modules for LoopTalk orchestration."""
|
||||
|
|
@ -1,170 +0,0 @@
|
|||
"""Pipeline building logic for LoopTalk agents."""
|
||||
|
||||
from typing import Any, Dict
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.pipeline.pipeline import Pipeline
|
||||
from pipecat.processors.aggregators.llm_response_universal import (
|
||||
LLMContextAggregatorPair,
|
||||
)
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.services.looptalk.audio_streamer import get_or_create_audio_streamer
|
||||
from api.services.looptalk.internal_transport import InternalTransport
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
from api.services.pipecat.pipeline_builder import (
|
||||
create_pipeline_components,
|
||||
create_pipeline_task,
|
||||
)
|
||||
from api.services.pipecat.pipeline_engine_callbacks_processor import (
|
||||
PipelineEngineCallbacksProcessor,
|
||||
)
|
||||
from api.services.pipecat.service_factory import (
|
||||
create_llm_service,
|
||||
create_stt_service,
|
||||
create_tts_service,
|
||||
)
|
||||
from api.services.workflow.dto import ReactFlowDTO
|
||||
from api.services.workflow.pipecat_engine import PipecatEngine
|
||||
from api.services.workflow.workflow_graph import WorkflowGraph
|
||||
|
||||
|
||||
class LoopTalkPipelineBuilder:
|
||||
"""Builds pipelines for LoopTalk agents."""
|
||||
|
||||
def __init__(self, db_client: DBClient):
|
||||
"""Initialize the pipeline builder.
|
||||
|
||||
Args:
|
||||
db_client: Database client for fetching user configurations
|
||||
"""
|
||||
self.db_client = db_client
|
||||
|
||||
async def create_agent_pipeline(
|
||||
self,
|
||||
transport: InternalTransport,
|
||||
workflow: Any,
|
||||
test_session_id: int,
|
||||
agent_id: str,
|
||||
role: str,
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a pipeline for an agent (actor or adversary).
|
||||
|
||||
Args:
|
||||
transport: Internal transport for the agent
|
||||
workflow: Workflow model from database
|
||||
test_session_id: ID of the test session
|
||||
agent_id: Unique identifier for the agent
|
||||
role: Either "actor" or "adversary"
|
||||
|
||||
Returns:
|
||||
Dictionary containing pipeline task, engine, and components
|
||||
"""
|
||||
# Get user configuration from database
|
||||
user_config = await self.db_client.get_user_configurations(workflow.user_id)
|
||||
|
||||
# Create pipeline components
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000,
|
||||
transport_out_sample_rate=16000,
|
||||
vad_sample_rate=16000,
|
||||
pipeline_sample_rate=16000,
|
||||
)
|
||||
|
||||
# Use published definition for graph + configs
|
||||
released_def = workflow.released_definition
|
||||
wf_json = released_def.workflow_json
|
||||
wf_configs = released_def.workflow_configurations or {}
|
||||
|
||||
# Extract keyterms from workflow configurations
|
||||
keyterms = None
|
||||
if wf_configs and "dictionary" in wf_configs:
|
||||
dictionary = wf_configs["dictionary"]
|
||||
if dictionary and isinstance(dictionary, str):
|
||||
keyterms = [
|
||||
term.strip() for term in dictionary.split(",") if term.strip()
|
||||
]
|
||||
if keyterms:
|
||||
logger.info(f"Using {len(keyterms)} keyterms for STT: {keyterms}")
|
||||
|
||||
# Resolve model overrides from the version onto global user config
|
||||
from api.services.configuration.resolve import resolve_effective_config
|
||||
|
||||
model_overrides = wf_configs.get("model_overrides")
|
||||
user_config = resolve_effective_config(user_config, model_overrides)
|
||||
|
||||
# Create services
|
||||
stt = create_stt_service(user_config, audio_config, keyterms=keyterms)
|
||||
llm = create_llm_service(user_config)
|
||||
tts = create_tts_service(user_config, audio_config)
|
||||
|
||||
logger.debug(f"Created services for {role}: STT={stt}, LLM={llm}, TTS={tts}")
|
||||
|
||||
# Get workflow graph
|
||||
workflow_graph = WorkflowGraph(ReactFlowDTO.model_validate(wf_json))
|
||||
|
||||
# Create engine first (needed for create_pipeline_components)
|
||||
engine = PipecatEngine(
|
||||
llm=llm,
|
||||
workflow=workflow_graph,
|
||||
call_context_vars={},
|
||||
workflow_run_id=None, # LoopTalk doesn't have workflow runs
|
||||
)
|
||||
|
||||
# Create pipeline components with audio configuration and engine
|
||||
audio_buffer, transcript, context = create_pipeline_components(
|
||||
audio_config, engine
|
||||
)
|
||||
|
||||
# Set the context and audio_buffer after creation
|
||||
engine.set_context(context)
|
||||
|
||||
context_aggregator = LLMContextAggregatorPair(context)
|
||||
|
||||
# Create pipeline engine callback processor
|
||||
pipeline_engine_callback_processor = PipelineEngineCallbacksProcessor(
|
||||
max_call_duration_seconds=300,
|
||||
max_duration_end_task_callback=engine.create_max_duration_callback(),
|
||||
generation_started_callback=engine.create_generation_started_callback(),
|
||||
)
|
||||
|
||||
# Get aggregators
|
||||
user_context_aggregator = context_aggregator.user()
|
||||
assistant_context_aggregator = context_aggregator.assistant()
|
||||
|
||||
# Get audio streamer for real-time streaming
|
||||
audio_streamer = get_or_create_audio_streamer(str(test_session_id), role)
|
||||
|
||||
# Create pipeline with AudioBufferProcessor after transport.output()
|
||||
pipeline = Pipeline(
|
||||
[
|
||||
transport.input(),
|
||||
audio_streamer, # Stream audio to connected clients
|
||||
stt,
|
||||
transcript.user(),
|
||||
user_context_aggregator,
|
||||
llm,
|
||||
pipeline_engine_callback_processor,
|
||||
tts,
|
||||
transport.output(),
|
||||
audio_buffer, # AudioBufferProcessor - records both input and output audio
|
||||
transcript.assistant(),
|
||||
assistant_context_aggregator,
|
||||
]
|
||||
)
|
||||
|
||||
# Create pipeline task with unique conversation ID for tracing
|
||||
conversation_id = f"{test_session_id}-{role}-{agent_id}"
|
||||
task = create_pipeline_task(pipeline, conversation_id, audio_config)
|
||||
|
||||
# Set the task on the engine
|
||||
engine.set_task(task)
|
||||
|
||||
return {
|
||||
"task": task,
|
||||
"engine": engine,
|
||||
"audio_buffer": audio_buffer,
|
||||
"transcript": transcript,
|
||||
"assistant_context_aggregator": assistant_context_aggregator,
|
||||
"audio_streamer": audio_streamer,
|
||||
}
|
||||
|
|
@ -1,216 +0,0 @@
|
|||
"""Recording management for LoopTalk sessions."""
|
||||
|
||||
import wave
|
||||
from pathlib import Path
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
|
||||
from api.enums import StorageBackend
|
||||
from api.services.storage import storage_fs
|
||||
|
||||
|
||||
class RecordingManager:
|
||||
"""Manages audio recording and transcript files for LoopTalk sessions."""
|
||||
|
||||
def __init__(self, base_dir: Path):
|
||||
"""Initialize the recording manager.
|
||||
|
||||
Args:
|
||||
base_dir: Base directory for temporary recordings
|
||||
"""
|
||||
self.base_dir = base_dir
|
||||
self.base_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
def get_recording_paths(self, test_session_id: int, role: str) -> Dict[str, Path]:
|
||||
"""Get file paths for recordings.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
role: Either "actor" or "adversary"
|
||||
|
||||
Returns:
|
||||
Dictionary with paths for audio, transcript, and temp audio files
|
||||
"""
|
||||
session_dir = self.base_dir / f"session_{test_session_id}"
|
||||
session_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
return {
|
||||
"audio": session_dir / f"{role}_audio.wav",
|
||||
"transcript": session_dir / f"{role}_transcript.txt",
|
||||
"temp_audio": session_dir / f"{role}_audio_temp.pcm",
|
||||
}
|
||||
|
||||
def convert_pcm_to_wav(
|
||||
self,
|
||||
test_session_id: int,
|
||||
role: str,
|
||||
sample_rate: int = 16000,
|
||||
num_channels: int = 1,
|
||||
) -> Optional[Path]:
|
||||
"""Convert PCM audio file to WAV format.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
role: Either "actor" or "adversary"
|
||||
sample_rate: Sample rate of the audio
|
||||
num_channels: Number of audio channels
|
||||
|
||||
Returns:
|
||||
Path to the WAV file if successful, None otherwise
|
||||
"""
|
||||
paths = self.get_recording_paths(test_session_id, role)
|
||||
|
||||
# Check if PCM file exists
|
||||
if not paths["temp_audio"].exists():
|
||||
logger.warning(f"No audio recorded for {role} in session {test_session_id}")
|
||||
return None
|
||||
|
||||
try:
|
||||
# Read PCM data
|
||||
with open(paths["temp_audio"], "rb") as f:
|
||||
pcm_data = f.read()
|
||||
|
||||
# Write WAV file
|
||||
with wave.open(str(paths["audio"]), "wb") as wav_file:
|
||||
wav_file.setnchannels(num_channels)
|
||||
wav_file.setsampwidth(2) # 16-bit audio
|
||||
wav_file.setframerate(sample_rate)
|
||||
wav_file.writeframes(pcm_data)
|
||||
|
||||
# Remove temporary PCM file
|
||||
paths["temp_audio"].unlink()
|
||||
|
||||
logger.info(
|
||||
f"Converted audio to WAV for {role} in session {test_session_id}: {paths['audio']}"
|
||||
)
|
||||
return paths["audio"]
|
||||
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to convert audio to WAV for {role} in session {test_session_id}: {e}"
|
||||
)
|
||||
return None
|
||||
|
||||
async def upload_recording_to_s3(
|
||||
self, test_session_id: int, role: str
|
||||
) -> Tuple[Optional[str], Optional[str]]:
|
||||
"""Upload recording and transcript to S3.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
role: Either "actor" or "adversary"
|
||||
|
||||
Returns:
|
||||
Tuple of (audio_url, transcript_url) or (None, None) if failed
|
||||
"""
|
||||
paths = self.get_recording_paths(test_session_id, role)
|
||||
audio_url = None
|
||||
transcript_url = None
|
||||
|
||||
# Import here to avoid circular imports
|
||||
|
||||
current_backend = StorageBackend.get_current_backend()
|
||||
logger.info(
|
||||
f"LOOPTALK UPLOAD: Using {current_backend.label} (code: {current_backend.code}) for session {test_session_id}, role: {role}"
|
||||
)
|
||||
|
||||
# Upload audio if exists
|
||||
if paths["audio"].exists():
|
||||
audio_key = f"looptalk/recordings/{test_session_id}/{role}_audio.wav"
|
||||
try:
|
||||
success = await storage_fs.aupload_file(str(paths["audio"]), audio_key)
|
||||
if success:
|
||||
audio_url = audio_key
|
||||
logger.info(
|
||||
f"Uploaded {role} audio to {current_backend.label}: {audio_key}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to upload {role} audio to {current_backend.label}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading {role} audio to {current_backend.label}: {e}"
|
||||
)
|
||||
|
||||
# Upload transcript if exists
|
||||
if paths["transcript"].exists():
|
||||
transcript_key = (
|
||||
f"looptalk/transcripts/{test_session_id}/{role}_transcript.txt"
|
||||
)
|
||||
try:
|
||||
success = await storage_fs.aupload_file(
|
||||
str(paths["transcript"]), transcript_key
|
||||
)
|
||||
if success:
|
||||
transcript_url = transcript_key
|
||||
logger.info(
|
||||
f"Uploaded {role} transcript to {current_backend.label}: {transcript_key}"
|
||||
)
|
||||
else:
|
||||
logger.error(
|
||||
f"Failed to upload {role} transcript to {current_backend.label}"
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error uploading {role} transcript to {current_backend.label}: {e}"
|
||||
)
|
||||
|
||||
return audio_url, transcript_url
|
||||
|
||||
def cleanup_session_files(self, test_session_id: int):
|
||||
"""Clean up local files for a session.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
"""
|
||||
session_dir = self.base_dir / f"session_{test_session_id}"
|
||||
if session_dir.exists():
|
||||
try:
|
||||
import shutil
|
||||
|
||||
shutil.rmtree(session_dir)
|
||||
logger.debug(f"Cleaned up local files for session {test_session_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to clean up session files: {e}")
|
||||
|
||||
def get_recording_info(self, test_session_id: int) -> Dict[str, any]:
|
||||
"""Get information about recordings for a test session.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
|
||||
Returns:
|
||||
Dictionary with recording information
|
||||
"""
|
||||
session_dir = self.base_dir / f"session_{test_session_id}"
|
||||
|
||||
info = {
|
||||
"test_session_id": test_session_id,
|
||||
"recording_dir": str(session_dir),
|
||||
"files": {},
|
||||
}
|
||||
|
||||
for role in ["actor", "adversary"]:
|
||||
paths = self.get_recording_paths(test_session_id, role)
|
||||
role_info = {}
|
||||
|
||||
# Check audio file
|
||||
if paths["audio"].exists():
|
||||
role_info["audio"] = {
|
||||
"path": str(paths["audio"]),
|
||||
"size_bytes": paths["audio"].stat().st_size,
|
||||
}
|
||||
|
||||
# Check transcript file
|
||||
if paths["transcript"].exists():
|
||||
role_info["transcript"] = {
|
||||
"path": str(paths["transcript"]),
|
||||
"size_bytes": paths["transcript"].stat().st_size,
|
||||
}
|
||||
|
||||
if role_info:
|
||||
info["files"][role] = role_info
|
||||
|
||||
return info
|
||||
|
|
@ -1,184 +0,0 @@
|
|||
"""Session management for LoopTalk test sessions."""
|
||||
|
||||
import asyncio
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
|
||||
|
||||
class SessionManager:
|
||||
"""Manages running LoopTalk test sessions."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the session manager."""
|
||||
self._running_sessions: Dict[int, Dict[str, Any]] = {}
|
||||
self._disconnect_handlers: Dict[int, asyncio.Task] = {}
|
||||
|
||||
def add_session(self, test_session_id: int, session_info: Dict[str, Any]):
|
||||
"""Add a new session to the manager.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
session_info: Dictionary containing session information
|
||||
"""
|
||||
self._running_sessions[test_session_id] = session_info
|
||||
|
||||
def get_session(self, test_session_id: int) -> Optional[Dict[str, Any]]:
|
||||
"""Get session information.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
|
||||
Returns:
|
||||
Session information dictionary or None if not found
|
||||
"""
|
||||
return self._running_sessions.get(test_session_id)
|
||||
|
||||
def remove_session(self, test_session_id: int):
|
||||
"""Remove a session from the manager.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
"""
|
||||
if test_session_id in self._running_sessions:
|
||||
del self._running_sessions[test_session_id]
|
||||
|
||||
# Cancel any disconnect handler for this session
|
||||
if test_session_id in self._disconnect_handlers:
|
||||
handler = self._disconnect_handlers.pop(test_session_id)
|
||||
if not handler.done():
|
||||
handler.cancel()
|
||||
|
||||
def get_active_count(self) -> int:
|
||||
"""Get the number of currently active sessions."""
|
||||
return len(self._running_sessions)
|
||||
|
||||
def get_active_info(self) -> Dict[str, Any]:
|
||||
"""Get information about all active sessions."""
|
||||
return {
|
||||
"count": len(self._running_sessions),
|
||||
"sessions": [
|
||||
{
|
||||
"test_session_id": session_id,
|
||||
"conversation_id": info["conversation"].id,
|
||||
"start_time": info["start_time"],
|
||||
"duration_seconds": int(
|
||||
(datetime.now(UTC) - info["start_time"]).total_seconds()
|
||||
),
|
||||
}
|
||||
for session_id, info in self._running_sessions.items()
|
||||
],
|
||||
}
|
||||
|
||||
async def handle_agent_disconnect(
|
||||
self, test_session_id: int, disconnected_role: str, stop_callback: callable
|
||||
):
|
||||
"""Handle when one agent disconnects.
|
||||
|
||||
This will cancel the other agent as well to ensure clean shutdown.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
disconnected_role: Role that disconnected ("actor" or "adversary")
|
||||
stop_callback: Callback to stop the session
|
||||
"""
|
||||
logger.info(
|
||||
f"Handling {disconnected_role} disconnect for session {test_session_id}"
|
||||
)
|
||||
|
||||
# Check if we already have a disconnect handler running
|
||||
if test_session_id in self._disconnect_handlers:
|
||||
logger.debug(
|
||||
f"Disconnect handler already running for session {test_session_id}"
|
||||
)
|
||||
return
|
||||
|
||||
# Create a task to handle the disconnect
|
||||
async def _handle_disconnect():
|
||||
try:
|
||||
# Wait a short time to avoid race conditions
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
# Check if session still exists
|
||||
session_info = self.get_session(test_session_id)
|
||||
if not session_info:
|
||||
logger.debug(f"Session {test_session_id} already stopped")
|
||||
return
|
||||
|
||||
# Stop the session (which will cancel both agents)
|
||||
logger.info(
|
||||
f"Stopping session {test_session_id} due to {disconnected_role} disconnect"
|
||||
)
|
||||
await stop_callback(test_session_id)
|
||||
|
||||
except asyncio.CancelledError:
|
||||
logger.debug(
|
||||
f"Disconnect handler cancelled for session {test_session_id}"
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Error handling disconnect for session {test_session_id}: {e}"
|
||||
)
|
||||
|
||||
# Store the task so we can cancel it if needed
|
||||
self._disconnect_handlers[test_session_id] = asyncio.create_task(
|
||||
_handle_disconnect()
|
||||
)
|
||||
|
||||
def update_audio_metadata(
|
||||
self,
|
||||
test_session_id: int,
|
||||
role: str,
|
||||
sample_rate: Optional[int] = None,
|
||||
num_channels: Optional[int] = None,
|
||||
):
|
||||
"""Update audio metadata for a role in a session.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
role: Either "actor" or "adversary"
|
||||
sample_rate: Sample rate of the audio
|
||||
num_channels: Number of audio channels
|
||||
"""
|
||||
if test_session_id not in self._running_sessions:
|
||||
return
|
||||
|
||||
if "audio_metadata" not in self._running_sessions[test_session_id]:
|
||||
self._running_sessions[test_session_id]["audio_metadata"] = {}
|
||||
|
||||
if role not in self._running_sessions[test_session_id]["audio_metadata"]:
|
||||
self._running_sessions[test_session_id]["audio_metadata"][role] = {}
|
||||
|
||||
metadata = self._running_sessions[test_session_id]["audio_metadata"][role]
|
||||
if sample_rate is not None:
|
||||
metadata["sample_rate"] = sample_rate
|
||||
if num_channels is not None:
|
||||
metadata["num_channels"] = num_channels
|
||||
|
||||
def get_audio_metadata(self, test_session_id: int, role: str) -> Dict[str, Any]:
|
||||
"""Get audio metadata for a role in a session.
|
||||
|
||||
Args:
|
||||
test_session_id: ID of the test session
|
||||
role: Either "actor" or "adversary"
|
||||
|
||||
Returns:
|
||||
Dictionary with sample_rate and num_channels
|
||||
"""
|
||||
default = {"sample_rate": 16000, "num_channels": 1}
|
||||
|
||||
if test_session_id not in self._running_sessions:
|
||||
return default
|
||||
|
||||
metadata = (
|
||||
self._running_sessions.get(test_session_id, {})
|
||||
.get("audio_metadata", {})
|
||||
.get(role, {})
|
||||
)
|
||||
|
||||
return {
|
||||
"sample_rate": metadata.get("sample_rate", 16000),
|
||||
"num_channels": metadata.get("num_channels", 1),
|
||||
}
|
||||
|
|
@ -1,75 +0,0 @@
|
|||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Internal frame serializer for agent-to-agent communication."""
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
Frame,
|
||||
InputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
)
|
||||
from pipecat.serializers.base_serializer import FrameSerializer
|
||||
|
||||
|
||||
class InternalFrameSerializer(FrameSerializer):
|
||||
"""Serializer for InternalTransport that filters frames between agents.
|
||||
|
||||
This serializer ensures only audio frames are passed between agents,
|
||||
preventing control frames from creating infinite loops.
|
||||
"""
|
||||
|
||||
async def serialize(self, frame: Frame) -> bytes | None:
|
||||
"""Only serialize audio frames for transmission between agents."""
|
||||
# Only pass audio frames between agents
|
||||
if isinstance(frame, OutputAudioRawFrame):
|
||||
# Use a fixed-size header to avoid parsing issues with binary data
|
||||
# Format: "AUDIO" (5 bytes) + sample_rate (4 bytes) + num_channels (2 bytes) + audio data
|
||||
header = b"AUDIO"
|
||||
sample_rate_bytes = frame.sample_rate.to_bytes(4, byteorder="big")
|
||||
num_channels_bytes = frame.num_channels.to_bytes(2, byteorder="big")
|
||||
|
||||
serialized = header + sample_rate_bytes + num_channels_bytes + frame.audio
|
||||
return serialized
|
||||
|
||||
# Don't pass control frames between agents
|
||||
return None
|
||||
|
||||
async def deserialize(self, data: bytes) -> Frame | None:
|
||||
"""Deserialize audio frames from partner agent."""
|
||||
if data.startswith(b"AUDIO"):
|
||||
try:
|
||||
# Fixed-size header parsing
|
||||
# Header: "AUDIO" (5 bytes) + sample_rate (4 bytes) + num_channels (2 bytes)
|
||||
if len(data) < 11: # Minimum size for header
|
||||
logger.error(
|
||||
f"InternalSerializer: Data too short for header: {len(data)} bytes"
|
||||
)
|
||||
return None
|
||||
|
||||
# Extract fixed-size fields
|
||||
# Skip header validation - we already checked startswith(b"AUDIO")
|
||||
sample_rate = int.from_bytes(data[5:9], byteorder="big")
|
||||
num_channels = int.from_bytes(data[9:11], byteorder="big")
|
||||
|
||||
# Extract audio data - everything after the header
|
||||
audio_data = data[11:]
|
||||
|
||||
# Check if audio data length is valid
|
||||
if len(audio_data) % 2 != 0:
|
||||
logger.warning(
|
||||
f"InternalSerializer: Audio data has odd length: {len(audio_data)}"
|
||||
)
|
||||
|
||||
# Convert to InputAudioRawFrame for the receiving agent
|
||||
return InputAudioRawFrame(
|
||||
audio=audio_data, num_channels=num_channels, sample_rate=sample_rate
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to deserialize audio frame: {e}")
|
||||
return None
|
||||
|
||||
return None
|
||||
|
|
@ -1,405 +0,0 @@
|
|||
#
|
||||
# Copyright (c) 2024–2025, Daily
|
||||
#
|
||||
# SPDX-License-Identifier: BSD 2-Clause License
|
||||
#
|
||||
|
||||
"""Internal transport for in-memory agent-to-agent communication."""
|
||||
|
||||
import asyncio
|
||||
import time
|
||||
from typing import Dict, Optional, Tuple
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.frames.frames import (
|
||||
CancelFrame,
|
||||
EndFrame,
|
||||
InputAudioRawFrame,
|
||||
OutputAudioRawFrame,
|
||||
OutputDTMFFrame,
|
||||
OutputDTMFUrgentFrame,
|
||||
OutputImageRawFrame,
|
||||
StartFrame,
|
||||
StopFrame,
|
||||
)
|
||||
from pipecat.processors.frame_processor import FrameDirection
|
||||
from pipecat.transports.base_input import BaseInputTransport
|
||||
from pipecat.transports.base_output import BaseOutputTransport
|
||||
from pipecat.transports.base_transport import BaseTransport, TransportParams
|
||||
|
||||
from api.services.looptalk.internal_serializer import InternalFrameSerializer
|
||||
|
||||
|
||||
class InternalInputTransport(BaseInputTransport):
|
||||
"""Input side of internal transport for agent-to-agent communication."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
transport: Optional["InternalTransport"],
|
||||
params: TransportParams,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize internal input transport.
|
||||
|
||||
Args:
|
||||
transport: The parent InternalTransport instance.
|
||||
params: Transport parameters for configuration.
|
||||
**kwargs: Additional keyword arguments including latency_seconds.
|
||||
"""
|
||||
# Extract latency configuration before passing to parent
|
||||
self._latency_seconds = kwargs.pop("latency_seconds", 0.0)
|
||||
|
||||
super().__init__(params, **kwargs)
|
||||
self._transport = transport
|
||||
self._queue: asyncio.Queue[bytes] = asyncio.Queue()
|
||||
self._partner: Optional["InternalOutputTransport"] = None
|
||||
self._running = False
|
||||
self._connected = False
|
||||
self._serializer = InternalFrameSerializer()
|
||||
# Queue for delayed packets (timestamp, data)
|
||||
self._delayed_queue: asyncio.Queue[Tuple[float, bytes]] = asyncio.Queue()
|
||||
self._latency_task: Optional[asyncio.Task] = None
|
||||
|
||||
def set_partner(self, partner: "InternalOutputTransport"):
|
||||
"""Connect this input transport to an output transport."""
|
||||
self._partner = partner
|
||||
|
||||
async def receive_data(self, data: bytes):
|
||||
"""Receive serialized data from the partner output transport."""
|
||||
# logger.debug("received data in input transport")
|
||||
if self._latency_seconds > 0:
|
||||
# Add to delayed queue with delivery timestamp
|
||||
delivery_time = time.monotonic() + self._latency_seconds
|
||||
await self._delayed_queue.put((delivery_time, data))
|
||||
else:
|
||||
# No latency, put directly in the main queue
|
||||
await self._queue.put(data)
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the input transport."""
|
||||
self._running = True
|
||||
await super().start(frame)
|
||||
await self._serializer.setup(frame)
|
||||
|
||||
# Set transport ready to initialize audio task for VAD processing
|
||||
await self.set_transport_ready(frame)
|
||||
|
||||
# Trigger on_client_connected event for InternalTransport (only once)
|
||||
if hasattr(self, "_transport") and self._transport and not self._connected:
|
||||
self._connected = True
|
||||
await self._transport._call_event_handler(
|
||||
"on_client_connected", self._transport
|
||||
)
|
||||
|
||||
# Start latency processor if latency is configured
|
||||
if self._latency_seconds > 0:
|
||||
self._latency_task = asyncio.create_task(self._latency_processor())
|
||||
|
||||
asyncio.create_task(self._run())
|
||||
|
||||
async def stop(self, frame: EndFrame | StopFrame | None = None):
|
||||
"""Stop the input transport."""
|
||||
self._running = False
|
||||
|
||||
# Stop latency processor
|
||||
if self._latency_task:
|
||||
self._latency_task.cancel()
|
||||
try:
|
||||
await self._latency_task
|
||||
except asyncio.CancelledError:
|
||||
pass
|
||||
self._latency_task = None
|
||||
|
||||
await super().stop(frame)
|
||||
|
||||
# Trigger on_client_disconnected event for InternalTransport
|
||||
if hasattr(self, "_transport") and self._transport:
|
||||
await self._transport._call_event_handler(
|
||||
"on_client_disconnected", self._transport
|
||||
)
|
||||
|
||||
async def _run(self):
|
||||
"""Main loop to process incoming data."""
|
||||
while self._running:
|
||||
try:
|
||||
data = await asyncio.wait_for(self._queue.get(), timeout=0.1)
|
||||
|
||||
# Deserialize the data
|
||||
frame = await self._serializer.deserialize(data)
|
||||
if frame:
|
||||
if isinstance(frame, InputAudioRawFrame):
|
||||
# Debug received audio
|
||||
try:
|
||||
import numpy as np
|
||||
|
||||
# Check if audio length is valid for int16
|
||||
if len(frame.audio) % 2 != 0:
|
||||
logger.error(
|
||||
f"InternalInput: Audio buffer has odd length: {len(frame.audio)}"
|
||||
)
|
||||
else:
|
||||
audio_array = np.frombuffer(frame.audio, dtype=np.int16)
|
||||
# logger.debug(f"InternalInput: Received audio - size: {len(frame.audio)} bytes, "
|
||||
# f"samples: {len(audio_array)}, min: {audio_array.min()}, max: {audio_array.max()}, "
|
||||
# f"sample_rate: {frame.sample_rate}")
|
||||
except Exception as e:
|
||||
logger.error(f"InternalInput: Error analyzing audio: {e}")
|
||||
|
||||
# Use the base class's audio processing which includes VAD
|
||||
await self.push_audio_frame(frame)
|
||||
else:
|
||||
# For non-audio frames, push directly
|
||||
await self.push_frame(frame, FrameDirection.DOWNSTREAM)
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
continue
|
||||
except Exception as e:
|
||||
logger.error(f"Error in internal input transport: {e}")
|
||||
|
||||
async def _latency_processor(self):
|
||||
"""Process delayed packets and deliver them after the configured latency."""
|
||||
logger.info(
|
||||
f"InternalInput: Started latency processor with {self._latency_seconds}s delay"
|
||||
)
|
||||
|
||||
# Use a list to maintain order (we'll process in FIFO order)
|
||||
pending_packets = []
|
||||
|
||||
while self._running:
|
||||
try:
|
||||
# Get all new packets from the delayed queue (non-blocking)
|
||||
while True:
|
||||
try:
|
||||
packet = self._delayed_queue.get_nowait()
|
||||
pending_packets.append(packet)
|
||||
except asyncio.QueueEmpty:
|
||||
break
|
||||
|
||||
# Process packets that are ready
|
||||
current_time = time.monotonic()
|
||||
delivered = []
|
||||
|
||||
for i, (delivery_time, data) in enumerate(pending_packets):
|
||||
if current_time >= delivery_time:
|
||||
# Time to deliver this packet
|
||||
await self._queue.put(data)
|
||||
delivered.append(i)
|
||||
|
||||
# Remove delivered packets (in reverse order to maintain indices)
|
||||
for i in reversed(delivered):
|
||||
pending_packets.pop(i)
|
||||
|
||||
# Sleep briefly before next check
|
||||
await asyncio.sleep(0.005) # 5ms for more responsive delivery
|
||||
|
||||
except asyncio.CancelledError:
|
||||
# Deliver any remaining packets immediately on shutdown
|
||||
for _, data in pending_packets:
|
||||
await self._queue.put(data)
|
||||
break
|
||||
except Exception as e:
|
||||
logger.error(f"Error in latency processor: {e}")
|
||||
await asyncio.sleep(0.01)
|
||||
|
||||
logger.info("InternalInput: Stopped latency processor")
|
||||
|
||||
|
||||
class InternalOutputTransport(BaseOutputTransport):
|
||||
"""Output side of internal transport for agent-to-agent communication."""
|
||||
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
"""Initialize internal output transport.
|
||||
|
||||
Args:
|
||||
params: Transport parameters for configuration.
|
||||
**kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(params, **kwargs)
|
||||
self._partner: Optional[InternalInputTransport] = None
|
||||
self._serializer = InternalFrameSerializer()
|
||||
|
||||
# Audio timing synchronization (similar to WebsocketServerOutputTransport)
|
||||
# _send_interval is the time interval between audio chunks in seconds
|
||||
self._send_interval = 0
|
||||
self._next_send_time = 0
|
||||
|
||||
def set_partner(self, partner: InternalInputTransport):
|
||||
"""Connect this output transport to an input transport."""
|
||||
self._partner = partner
|
||||
|
||||
async def start(self, frame: StartFrame):
|
||||
"""Start the output transport."""
|
||||
await super().start(frame)
|
||||
await self._serializer.setup(frame)
|
||||
# Calculate the send interval based on audio chunk size (like WebsocketServerOutputTransport)
|
||||
self._send_interval = (
|
||||
self._params.audio_out_10ms_chunks * 10 / 1000
|
||||
) # Convert ms to seconds
|
||||
await self.set_transport_ready(frame)
|
||||
|
||||
async def write_audio_frame(self, frame: OutputAudioRawFrame):
|
||||
"""Write audio frame to partner through serializer with proper timing."""
|
||||
# Debug audio characteristics
|
||||
# import numpy as np
|
||||
# audio_array = np.frombuffer(frame.audio, dtype=np.int16)
|
||||
# logger.debug(f"InternalOutput: Sending audio - type: {type(frame).__name__}, size: {len(frame.audio)} bytes, "
|
||||
# f"samples: {len(audio_array)}, min: {audio_array.min()}, max: {audio_array.max()}, "
|
||||
# f"sample_rate: {frame.sample_rate}")
|
||||
|
||||
# Serialize and send the audio first
|
||||
data = await self._serializer.serialize(frame)
|
||||
if data and self._partner:
|
||||
await self._partner.receive_data(data)
|
||||
|
||||
# logger.debug(f"InternalOutput: Sent audio frame to partner")
|
||||
|
||||
# Then simulate audio playback timing (following WebsocketServerOutputTransport pattern)
|
||||
await self._write_audio_sleep()
|
||||
|
||||
async def write_video_frame(self, _frame: OutputImageRawFrame):
|
||||
"""Internal transport doesn't support video."""
|
||||
pass
|
||||
|
||||
async def write_dtmf(self, _frame: OutputDTMFFrame | OutputDTMFUrgentFrame):
|
||||
"""Internal transport doesn't support DTMF."""
|
||||
pass
|
||||
|
||||
async def stop(self, frame: EndFrame):
|
||||
"""Stop the output transport and reset timing."""
|
||||
await super().stop(frame)
|
||||
self._next_send_time = 0
|
||||
|
||||
async def cancel(self, frame: CancelFrame):
|
||||
"""Cancel the output transport and reset timing."""
|
||||
await super().cancel(frame)
|
||||
self._next_send_time = 0
|
||||
|
||||
async def _write_audio_sleep(self):
|
||||
"""Simulate audio playback timing (following WebsocketServerOutputTransport pattern)."""
|
||||
# Simulate a clock to ensure audio is sent at real-time pace
|
||||
current_time = time.monotonic()
|
||||
sleep_duration = max(0, self._next_send_time - current_time)
|
||||
await asyncio.sleep(sleep_duration)
|
||||
if sleep_duration == 0:
|
||||
self._next_send_time = time.monotonic() + self._send_interval
|
||||
else:
|
||||
self._next_send_time += self._send_interval
|
||||
|
||||
|
||||
class InternalTransport(BaseTransport):
|
||||
"""Internal transport for in-memory agent-to-agent communication."""
|
||||
|
||||
def __init__(self, params: TransportParams, **kwargs):
|
||||
"""Initialize internal transport.
|
||||
|
||||
Args:
|
||||
params: Transport parameters for configuration.
|
||||
**kwargs: Additional keyword arguments including latency_seconds.
|
||||
"""
|
||||
# Extract latency configuration before passing to parent
|
||||
self._latency_seconds = kwargs.pop("latency_seconds", 0.0)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
self._params = params
|
||||
|
||||
# Create input and output transports
|
||||
self._input = InternalInputTransport(
|
||||
self,
|
||||
params,
|
||||
name=self._input_name or f"{self.name}#input",
|
||||
latency_seconds=self._latency_seconds,
|
||||
)
|
||||
self._output = InternalOutputTransport(
|
||||
params, name=self._output_name or f"{self.name}#output"
|
||||
)
|
||||
|
||||
# Register supported event handlers
|
||||
self._register_event_handler("on_client_connected")
|
||||
self._register_event_handler("on_client_disconnected")
|
||||
|
||||
def input(self) -> InternalInputTransport:
|
||||
"""Get the input transport."""
|
||||
return self._input
|
||||
|
||||
def output(self) -> InternalOutputTransport:
|
||||
"""Get the output transport."""
|
||||
return self._output
|
||||
|
||||
def connect_partner(self, partner: "InternalTransport"):
|
||||
"""Connect this transport to another internal transport."""
|
||||
# Connect output of this transport to input of partner
|
||||
self._output.set_partner(partner._input)
|
||||
# Connect output of partner to input of this transport
|
||||
partner._output.set_partner(self._input)
|
||||
|
||||
|
||||
class InternalTransportManager:
|
||||
"""Manages multiple internal transport pairs for load testing."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize internal transport manager."""
|
||||
self._transport_pairs: Dict[
|
||||
str, Tuple[InternalTransport, InternalTransport]
|
||||
] = {}
|
||||
|
||||
def create_transport_pair(
|
||||
self,
|
||||
test_session_id: str,
|
||||
actor_params: TransportParams,
|
||||
adversary_params: TransportParams,
|
||||
latency_seconds: float = 0.0,
|
||||
) -> Tuple[InternalTransport, InternalTransport]:
|
||||
"""Create a connected pair of internal transports.
|
||||
|
||||
Args:
|
||||
test_session_id: Unique identifier for the test session.
|
||||
actor_params: Transport parameters for the actor.
|
||||
adversary_params: Transport parameters for the adversary.
|
||||
latency_seconds: Simulated network latency in seconds (default: 0.0).
|
||||
|
||||
Returns:
|
||||
Tuple of (actor_transport, adversary_transport).
|
||||
"""
|
||||
# Create actor transport with latency
|
||||
actor_transport = InternalTransport(
|
||||
params=actor_params,
|
||||
name=f"actor-{test_session_id}",
|
||||
latency_seconds=latency_seconds,
|
||||
)
|
||||
|
||||
# Create adversary transport with latency
|
||||
adversary_transport = InternalTransport(
|
||||
params=adversary_params,
|
||||
name=f"adversary-{test_session_id}",
|
||||
latency_seconds=latency_seconds,
|
||||
)
|
||||
|
||||
# Connect them
|
||||
actor_transport.connect_partner(adversary_transport)
|
||||
|
||||
# Store the pair
|
||||
self._transport_pairs[test_session_id] = (actor_transport, adversary_transport)
|
||||
|
||||
logger.info(
|
||||
f"Created internal transport pair for test session: {test_session_id} with {latency_seconds}s latency"
|
||||
)
|
||||
|
||||
return actor_transport, adversary_transport
|
||||
|
||||
def get_transport_pair(
|
||||
self, test_session_id: str
|
||||
) -> Optional[Tuple[InternalTransport, InternalTransport]]:
|
||||
"""Get an existing transport pair."""
|
||||
return self._transport_pairs.get(test_session_id)
|
||||
|
||||
def remove_transport_pair(self, test_session_id: str):
|
||||
"""Remove a transport pair."""
|
||||
if test_session_id in self._transport_pairs:
|
||||
del self._transport_pairs[test_session_id]
|
||||
logger.info(
|
||||
f"Removed internal transport pair for test session: {test_session_id}"
|
||||
)
|
||||
|
||||
def get_active_test_count(self) -> int:
|
||||
"""Get the number of active test sessions."""
|
||||
return len(self._transport_pairs)
|
||||
|
|
@ -1,542 +0,0 @@
|
|||
import asyncio
|
||||
import os
|
||||
import uuid
|
||||
from datetime import UTC, datetime
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from loguru import logger
|
||||
from pipecat.pipeline.task import PipelineTask
|
||||
from pipecat.utils.run_context import set_current_run_id
|
||||
|
||||
from api.db.db_client import DBClient
|
||||
from api.services.looptalk.internal_transport import (
|
||||
InternalTransport,
|
||||
InternalTransportManager,
|
||||
)
|
||||
from api.services.pipecat.transport_setup import create_internal_transport
|
||||
|
||||
from .core.pipeline_builder import LoopTalkPipelineBuilder
|
||||
from .core.recording_manager import RecordingManager
|
||||
from .core.session_manager import SessionManager
|
||||
|
||||
|
||||
class LoopTalkTestOrchestrator:
|
||||
"""Orchestrates LoopTalk testing sessions with agent-to-agent conversations."""
|
||||
|
||||
def __init__(
|
||||
self, db_client: DBClient, network_latency_seconds: Optional[float] = None
|
||||
):
|
||||
self.db_client = db_client
|
||||
self.transport_manager = InternalTransportManager()
|
||||
self.session_manager = SessionManager()
|
||||
self.pipeline_builder = LoopTalkPipelineBuilder(db_client)
|
||||
self.recording_manager = RecordingManager(Path("/tmp/looptalk_recordings"))
|
||||
|
||||
# Default network latency (can be overridden per session)
|
||||
# Priority: constructor param > env var > default (100ms)
|
||||
if network_latency_seconds is not None:
|
||||
self._default_network_latency = network_latency_seconds
|
||||
else:
|
||||
env_latency = os.environ.get("LOOPTALK_NETWORK_LATENCY_MS")
|
||||
if env_latency:
|
||||
try:
|
||||
self._default_network_latency = (
|
||||
float(env_latency) / 1000.0
|
||||
) # Convert ms to seconds
|
||||
except ValueError:
|
||||
logger.warning(
|
||||
f"Invalid LOOPTALK_NETWORK_LATENCY_MS value: {env_latency}, using default 100ms"
|
||||
)
|
||||
self._default_network_latency = 0.1
|
||||
else:
|
||||
self._default_network_latency = 0.1 # 100ms default
|
||||
|
||||
async def start_test_session(
|
||||
self,
|
||||
test_session_id: int,
|
||||
organization_id: int,
|
||||
network_latency_seconds: Optional[float] = None,
|
||||
) -> Dict[str, Any]:
|
||||
"""Start a LoopTalk test session."""
|
||||
|
||||
# Get test session details
|
||||
test_session = await self.db_client.get_test_session(
|
||||
test_session_id=test_session_id, organization_id=organization_id
|
||||
)
|
||||
|
||||
if not test_session:
|
||||
raise ValueError(f"Test session {test_session_id} not found")
|
||||
|
||||
if test_session.status != "pending":
|
||||
raise ValueError(f"Test session {test_session_id} is not in pending state")
|
||||
|
||||
try:
|
||||
# Update status to running
|
||||
await self.db_client.update_test_session_status(
|
||||
test_session_id=test_session_id, status="running"
|
||||
)
|
||||
|
||||
# Create conversation record
|
||||
conversation = await self.db_client.create_conversation(
|
||||
test_session_id=test_session_id
|
||||
)
|
||||
|
||||
# Create audio configuration for LoopTalk
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
||||
audio_config = AudioConfig(
|
||||
transport_in_sample_rate=16000,
|
||||
transport_out_sample_rate=16000,
|
||||
pipeline_sample_rate=16000,
|
||||
)
|
||||
|
||||
# Use provided latency or fall back to default
|
||||
latency = (
|
||||
network_latency_seconds
|
||||
if network_latency_seconds is not None
|
||||
else self._default_network_latency
|
||||
)
|
||||
logger.info(
|
||||
f"Using network latency of {latency}s for test session {test_session_id}"
|
||||
)
|
||||
|
||||
# Generate unique workflow run IDs for each agent
|
||||
actor_workflow_run_id = int(str(test_session_id) + "1")
|
||||
adversary_workflow_run_id = int(str(test_session_id) + "2")
|
||||
|
||||
# Create transports using the new method with turn analyzer
|
||||
actor_transport = create_internal_transport(
|
||||
workflow_run_id=actor_workflow_run_id,
|
||||
audio_config=audio_config,
|
||||
latency_seconds=latency,
|
||||
)
|
||||
adversary_transport = create_internal_transport(
|
||||
workflow_run_id=adversary_workflow_run_id,
|
||||
audio_config=audio_config,
|
||||
latency_seconds=latency,
|
||||
)
|
||||
|
||||
# Connect the transports
|
||||
actor_transport.connect_partner(adversary_transport)
|
||||
|
||||
# Store the transport pair in the manager
|
||||
self.transport_manager._transport_pairs[str(test_session_id)] = (
|
||||
actor_transport,
|
||||
adversary_transport,
|
||||
)
|
||||
|
||||
# Generate unique identifiers for actor and adversary
|
||||
actor_id = f"actor_{test_session_id}_{str(uuid.uuid4())[:8]}"
|
||||
adversary_id = f"adversary_{test_session_id}_{str(uuid.uuid4())[:8]}"
|
||||
|
||||
# Create pipelines for both agents
|
||||
actor_pipeline_info = await self.pipeline_builder.create_agent_pipeline(
|
||||
transport=actor_transport,
|
||||
workflow=test_session.actor_workflow,
|
||||
test_session_id=test_session_id,
|
||||
agent_id=actor_id,
|
||||
role="actor",
|
||||
)
|
||||
actor_pipeline_task = actor_pipeline_info["task"]
|
||||
|
||||
adversary_pipeline_info = await self.pipeline_builder.create_agent_pipeline(
|
||||
transport=adversary_transport,
|
||||
workflow=test_session.adversary_workflow,
|
||||
test_session_id=test_session_id,
|
||||
agent_id=adversary_id,
|
||||
role="adversary",
|
||||
)
|
||||
|
||||
adversary_pipeline_task = adversary_pipeline_info["task"]
|
||||
|
||||
# Register event handlers for both pipelines
|
||||
await self._register_transport_handlers(
|
||||
actor_transport, actor_pipeline_info, test_session_id, "actor"
|
||||
)
|
||||
await self._register_transport_handlers(
|
||||
adversary_transport,
|
||||
adversary_pipeline_info,
|
||||
test_session_id,
|
||||
"adversary",
|
||||
)
|
||||
|
||||
# Store session info
|
||||
session_info = {
|
||||
"test_session": test_session,
|
||||
"conversation": conversation,
|
||||
"actor_task": actor_pipeline_task,
|
||||
"adversary_task": adversary_pipeline_task,
|
||||
"actor_transport": actor_transport,
|
||||
"adversary_transport": adversary_transport,
|
||||
"start_time": datetime.now(UTC),
|
||||
}
|
||||
self.session_manager.add_session(test_session_id, session_info)
|
||||
|
||||
# Start both pipelines in background tasks
|
||||
from pipecat.pipeline.base_task import PipelineTaskParams
|
||||
|
||||
params = PipelineTaskParams(loop=asyncio.get_event_loop())
|
||||
|
||||
# Start the pipelines - this will trigger initialization through the normal pipeline start process
|
||||
# The workflow engines will be initialized when the pipeline starts
|
||||
|
||||
# Create conversation IDs for tracing
|
||||
actor_conversation_id = f"{test_session_id}-actor-{actor_id}"
|
||||
adversary_conversation_id = f"{test_session_id}-adversary-{adversary_id}"
|
||||
|
||||
# Create tasks but don't await them - they'll run in the background
|
||||
logger.debug(f"Running actor task with ID: {actor_id}")
|
||||
actor_task_future = asyncio.create_task(
|
||||
self._run_pipeline_with_context(
|
||||
actor_pipeline_task,
|
||||
params,
|
||||
actor_id,
|
||||
actor_conversation_id,
|
||||
"actor",
|
||||
)
|
||||
)
|
||||
|
||||
logger.debug(f"Running adversary task with ID: {adversary_id}")
|
||||
adversary_task_future = asyncio.create_task(
|
||||
self._run_pipeline_with_context(
|
||||
adversary_pipeline_task,
|
||||
params,
|
||||
adversary_id,
|
||||
adversary_conversation_id,
|
||||
"adversary",
|
||||
)
|
||||
)
|
||||
|
||||
# Store the futures so we can monitor them
|
||||
session_info["actor_task_future"] = actor_task_future
|
||||
session_info["adversary_task_future"] = adversary_task_future
|
||||
|
||||
logger.info(f"Started LoopTalk test session {test_session_id}")
|
||||
|
||||
return {
|
||||
"test_session_id": test_session_id,
|
||||
"conversation_id": conversation.id,
|
||||
"status": "running",
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to start test session {test_session_id}: {e}")
|
||||
await self.db_client.update_test_session_status(
|
||||
test_session_id=test_session_id, status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
async def _register_transport_handlers(
|
||||
self,
|
||||
transport: InternalTransport,
|
||||
pipeline_info: Dict[str, Any],
|
||||
test_session_id: int,
|
||||
role: str,
|
||||
):
|
||||
"""Register transport event handlers for a pipeline.
|
||||
|
||||
Args:
|
||||
transport: The transport to register handlers on
|
||||
pipeline_info: Dictionary containing pipeline components
|
||||
test_session_id: ID of the test session
|
||||
role: Either "actor" or "adversary"
|
||||
"""
|
||||
engine = pipeline_info["engine"]
|
||||
task = pipeline_info["task"]
|
||||
audio_buffer = pipeline_info["audio_buffer"]
|
||||
transcript = pipeline_info["transcript"]
|
||||
assistant_context_aggregator = pipeline_info["assistant_context_aggregator"]
|
||||
|
||||
# Register transport event handlers
|
||||
@transport.event_handler("on_client_connected")
|
||||
async def on_client_connected(transport, participant):
|
||||
logger.debug(f"LoopTalk {role} client connected - initializing workflow")
|
||||
# Start audio recording
|
||||
await audio_buffer.start_recording()
|
||||
await engine.initialize()
|
||||
|
||||
@transport.event_handler("on_client_disconnected")
|
||||
async def on_client_disconnected(transport, participant):
|
||||
logger.debug(f"LoopTalk {role} client disconnected")
|
||||
# Stop audio recording
|
||||
await audio_buffer.stop_recording()
|
||||
|
||||
# Handle disconnect propagation - stop the other agent too
|
||||
await self.session_manager.handle_agent_disconnect(
|
||||
test_session_id, role, self.stop_test_session
|
||||
)
|
||||
|
||||
await task.cancel()
|
||||
|
||||
# Register custom audio and transcript handlers for LoopTalk
|
||||
await self._register_looptalk_handlers(
|
||||
audio_buffer, transcript, test_session_id, role
|
||||
)
|
||||
|
||||
async def _register_looptalk_handlers(
|
||||
self, audio_buffer, transcript, test_session_id: int, role: str
|
||||
):
|
||||
"""Register LoopTalk-specific handlers for audio and transcript recording"""
|
||||
|
||||
paths = self.recording_manager.get_recording_paths(test_session_id, role)
|
||||
|
||||
# Store audio metadata for later WAV conversion
|
||||
audio_metadata = {"sample_rate": None, "num_channels": None}
|
||||
|
||||
# Audio handler - writes directly to PCM file
|
||||
@audio_buffer.event_handler("on_audio_data")
|
||||
async def on_audio_data(buffer, audio, sample_rate, num_channels):
|
||||
if not audio:
|
||||
return
|
||||
|
||||
# Store metadata on first write
|
||||
if audio_metadata["sample_rate"] is None:
|
||||
audio_metadata["sample_rate"] = sample_rate
|
||||
audio_metadata["num_channels"] = num_channels
|
||||
|
||||
# Append PCM data to temporary file
|
||||
try:
|
||||
with open(paths["temp_audio"], "ab") as f:
|
||||
f.write(audio)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to write audio for {role} in session {test_session_id}: {e}"
|
||||
)
|
||||
|
||||
# Transcript handler - writes directly to text file
|
||||
@transcript.event_handler("on_transcript_update")
|
||||
async def on_transcript_update(processor, frame):
|
||||
transcript_text = ""
|
||||
for msg in frame.messages:
|
||||
timestamp = f"[{msg.timestamp}] " if msg.timestamp else ""
|
||||
line = f"{timestamp}{msg.role}: {msg.content}\n"
|
||||
transcript_text += line
|
||||
|
||||
# Append transcript to file
|
||||
try:
|
||||
with open(paths["transcript"], "a") as f:
|
||||
f.write(transcript_text)
|
||||
except Exception as e:
|
||||
logger.error(
|
||||
f"Failed to write transcript for {role} in session {test_session_id}: {e}"
|
||||
)
|
||||
|
||||
# Store metadata in session info for later WAV conversion
|
||||
# Set default values if not yet captured
|
||||
if audio_metadata["sample_rate"] is None:
|
||||
audio_metadata["sample_rate"] = 16000 # Default sample rate
|
||||
audio_metadata["num_channels"] = 1 # Default channels
|
||||
|
||||
self.session_manager.update_audio_metadata(
|
||||
test_session_id,
|
||||
role,
|
||||
sample_rate=audio_metadata["sample_rate"],
|
||||
num_channels=audio_metadata["num_channels"],
|
||||
)
|
||||
|
||||
async def _run_pipeline_with_context(
|
||||
self,
|
||||
pipeline_task: PipelineTask,
|
||||
params,
|
||||
agent_id: str,
|
||||
conversation_id: str,
|
||||
role: str,
|
||||
):
|
||||
"""Run a pipeline task with the agent_id set in context"""
|
||||
set_current_run_id(agent_id)
|
||||
return await pipeline_task.run(params)
|
||||
|
||||
async def stop_test_session(self, test_session_id: int) -> Dict[str, Any]:
|
||||
"""Stop a running test session."""
|
||||
|
||||
session_info = self.session_manager.get_session(test_session_id)
|
||||
if not session_info:
|
||||
raise ValueError(f"Test session {test_session_id} is not running")
|
||||
|
||||
try:
|
||||
# Cancel both pipeline tasks
|
||||
await session_info["actor_task"].cancel()
|
||||
await session_info["adversary_task"].cancel()
|
||||
|
||||
# Also cancel the task futures if they exist
|
||||
if "actor_task_future" in session_info:
|
||||
session_info["actor_task_future"].cancel()
|
||||
if "adversary_task_future" in session_info:
|
||||
session_info["adversary_task_future"].cancel()
|
||||
|
||||
# Calculate duration
|
||||
duration_seconds = int(
|
||||
(datetime.now(UTC) - session_info["start_time"]).total_seconds()
|
||||
)
|
||||
|
||||
# Update conversation
|
||||
await self.db_client.update_conversation(
|
||||
conversation_id=session_info["conversation"].id,
|
||||
duration_seconds=duration_seconds,
|
||||
ended_at=datetime.now(UTC),
|
||||
)
|
||||
|
||||
# Update test session status
|
||||
await self.db_client.update_test_session_status(
|
||||
test_session_id=test_session_id,
|
||||
status="completed",
|
||||
results={
|
||||
"duration_seconds": duration_seconds,
|
||||
"conversation_id": session_info["conversation"].id,
|
||||
},
|
||||
)
|
||||
|
||||
# Finalize recordings for both actor and adversary
|
||||
# Convert PCM files to WAV
|
||||
actor_metadata = self.session_manager.get_audio_metadata(
|
||||
test_session_id, "actor"
|
||||
)
|
||||
adversary_metadata = self.session_manager.get_audio_metadata(
|
||||
test_session_id, "adversary"
|
||||
)
|
||||
|
||||
self.recording_manager.convert_pcm_to_wav(
|
||||
test_session_id,
|
||||
"actor",
|
||||
sample_rate=actor_metadata["sample_rate"],
|
||||
num_channels=actor_metadata["num_channels"],
|
||||
)
|
||||
self.recording_manager.convert_pcm_to_wav(
|
||||
test_session_id,
|
||||
"adversary",
|
||||
sample_rate=adversary_metadata["sample_rate"],
|
||||
num_channels=adversary_metadata["num_channels"],
|
||||
)
|
||||
|
||||
# Upload recordings to S3 (synchronously for load testing)
|
||||
(
|
||||
actor_audio_url,
|
||||
actor_transcript_url,
|
||||
) = await self.recording_manager.upload_recording_to_s3(
|
||||
test_session_id, "actor"
|
||||
)
|
||||
(
|
||||
adversary_audio_url,
|
||||
adversary_transcript_url,
|
||||
) = await self.recording_manager.upload_recording_to_s3(
|
||||
test_session_id, "adversary"
|
||||
)
|
||||
|
||||
# Update conversation with recording URLs
|
||||
await self.db_client.update_conversation(
|
||||
conversation_id=session_info["conversation"].id,
|
||||
actor_recording_url=actor_audio_url,
|
||||
adversary_recording_url=adversary_audio_url,
|
||||
transcript={
|
||||
"actor_transcript_url": actor_transcript_url,
|
||||
"adversary_transcript_url": adversary_transcript_url,
|
||||
},
|
||||
)
|
||||
|
||||
# Log recording locations
|
||||
logger.info(f"LoopTalk recordings uploaded to S3:")
|
||||
if actor_audio_url:
|
||||
logger.info(f" - Actor audio: {actor_audio_url}")
|
||||
if actor_transcript_url:
|
||||
logger.info(f" - Actor transcript: {actor_transcript_url}")
|
||||
if adversary_audio_url:
|
||||
logger.info(f" - Adversary audio: {adversary_audio_url}")
|
||||
if adversary_transcript_url:
|
||||
logger.info(f" - Adversary transcript: {adversary_transcript_url}")
|
||||
|
||||
# Clean up local files after successful upload
|
||||
self.recording_manager.cleanup_session_files(test_session_id)
|
||||
|
||||
# Clean up
|
||||
self.transport_manager.remove_transport_pair(str(test_session_id))
|
||||
self.session_manager.remove_session(test_session_id)
|
||||
|
||||
# Clean up audio streamers
|
||||
from api.services.looptalk.audio_streamer import cleanup_audio_streamers
|
||||
|
||||
cleanup_audio_streamers(str(test_session_id))
|
||||
|
||||
logger.info(f"Stopped LoopTalk test session {test_session_id}")
|
||||
|
||||
return {
|
||||
"test_session_id": test_session_id,
|
||||
"status": "completed",
|
||||
"duration_seconds": duration_seconds,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to stop test session {test_session_id}: {e}")
|
||||
await self.db_client.update_test_session_status(
|
||||
test_session_id=test_session_id, status="failed", error=str(e)
|
||||
)
|
||||
raise
|
||||
|
||||
async def start_load_test(
|
||||
self,
|
||||
organization_id: int,
|
||||
name_prefix: str,
|
||||
actor_workflow_id: int,
|
||||
adversary_workflow_id: int,
|
||||
config: Dict[str, Any],
|
||||
test_count: int,
|
||||
) -> Dict[str, Any]:
|
||||
"""Start a load test with multiple concurrent test sessions."""
|
||||
|
||||
# Validate test count
|
||||
if test_count < 1 or test_count > 10:
|
||||
raise ValueError("Test count must be between 1 and 10")
|
||||
|
||||
# Create test sessions
|
||||
test_sessions = await self.db_client.create_load_test_group(
|
||||
organization_id=organization_id,
|
||||
name_prefix=name_prefix,
|
||||
actor_workflow_id=actor_workflow_id,
|
||||
adversary_workflow_id=adversary_workflow_id,
|
||||
config=config,
|
||||
test_count=test_count,
|
||||
)
|
||||
|
||||
# Start all test sessions concurrently
|
||||
tasks = []
|
||||
for test_session in test_sessions:
|
||||
task = asyncio.create_task(
|
||||
self.start_test_session(
|
||||
test_session_id=test_session.id, organization_id=organization_id
|
||||
)
|
||||
)
|
||||
tasks.append(task)
|
||||
|
||||
# Wait for all to start
|
||||
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Count successes and failures
|
||||
started = sum(1 for r in results if not isinstance(r, Exception))
|
||||
failed = sum(1 for r in results if isinstance(r, Exception))
|
||||
|
||||
load_test_group_id = test_sessions[0].load_test_group_id
|
||||
|
||||
logger.info(
|
||||
f"Started load test {load_test_group_id}: "
|
||||
f"{started} started, {failed} failed out of {test_count}"
|
||||
)
|
||||
|
||||
return {
|
||||
"load_test_group_id": load_test_group_id,
|
||||
"total": test_count,
|
||||
"started": started,
|
||||
"failed": failed,
|
||||
"test_session_ids": [ts.id for ts in test_sessions],
|
||||
}
|
||||
|
||||
def get_active_test_count(self) -> int:
|
||||
"""Get the number of currently active test sessions."""
|
||||
return self.session_manager.get_active_count()
|
||||
|
||||
def get_active_test_info(self) -> Dict[str, Any]:
|
||||
"""Get information about all active test sessions."""
|
||||
return self.session_manager.get_active_info()
|
||||
|
||||
def get_recording_info(self, test_session_id: int) -> Dict[str, Any]:
|
||||
"""Get information about recordings for a test session"""
|
||||
return self.recording_manager.get_recording_info(test_session_id)
|
||||
|
|
@ -1,7 +1,7 @@
|
|||
"""Transport factories for non-telephony pipelines.
|
||||
|
||||
Telephony transports live in their respective ``api.services.telephony.providers/<name>/transport.py``.
|
||||
This module hosts only the shared, non-telephony transports (WebRTC, internal/LoopTalk).
|
||||
This module hosts only the shared, non-telephony transports (WebRTC).
|
||||
"""
|
||||
|
||||
from api.services.pipecat.audio_config import AudioConfig
|
||||
|
|
@ -32,24 +32,3 @@ async def create_webrtc_transport(
|
|||
audio_out_mixer=mixer,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def create_internal_transport(
|
||||
workflow_run_id: int,
|
||||
audio_config: AudioConfig,
|
||||
latency_seconds: float = 0.0,
|
||||
ambient_noise_config: dict | None = None,
|
||||
):
|
||||
"""Create an internal transport for agent-to-agent connections (LoopTalk).
|
||||
|
||||
Args:
|
||||
workflow_run_id: ID of the workflow run for turn analyzer context
|
||||
audio_config: Audio configuration for the transport
|
||||
latency_seconds: Network latency to simulate
|
||||
|
||||
Returns:
|
||||
InternalTransport instance configured with turn analyzer
|
||||
"""
|
||||
pass
|
||||
# Commented out because looptalk coming in the regular import flow
|
||||
# was causing issue. May be move this to looptalk/orchestrator.py
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue