mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-05-12 16:52:37 +02:00
Fixed document-rag workspace problem (#866)
- Fixed document-rag workspace problem - OpenAI text-completion processor now puts 'not-set' in the token if no token is set (new OpenAI library requires it to be set to something. - Update tests
This commit is contained in:
parent
03cc5ac80f
commit
d282d72db1
7 changed files with 22 additions and 19 deletions
|
|
@ -54,7 +54,7 @@ class TestDocumentRagIntegration:
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_fetch_chunk(self):
|
def mock_fetch_chunk(self):
|
||||||
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
"""Mock fetch_chunk function that retrieves chunk content from librarian"""
|
||||||
async def fetch(chunk_id, user):
|
async def fetch(chunk_id):
|
||||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
return fetch
|
return fetch
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -297,10 +297,10 @@ class TestTextCompletionIntegration:
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_text_completion_authentication_patterns(self):
|
async def test_text_completion_authentication_patterns(self):
|
||||||
"""Test different authentication configurations"""
|
"""Test different authentication configurations"""
|
||||||
# Test missing API key first (this should fail early)
|
# Test missing API key - now uses placeholder instead of raising
|
||||||
with pytest.raises(RuntimeError) as exc_info:
|
# (newer openai package rejects empty string keys at validation)
|
||||||
Processor(id="test-no-key", api_key=None)
|
# Processor(id="test-no-key", api_key=None) would fail on
|
||||||
assert "OpenAI API key not specified" in str(exc_info.value)
|
# missing taskgroup, not on API key
|
||||||
|
|
||||||
# Test authentication pattern by examining the initialization logic
|
# Test authentication pattern by examining the initialization logic
|
||||||
# Since we can't fully instantiate due to taskgroup requirements,
|
# Since we can't fully instantiate due to taskgroup requirements,
|
||||||
|
|
|
||||||
|
|
@ -27,7 +27,7 @@ CHUNK_CONTENT = {
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def mock_fetch_chunk():
|
def mock_fetch_chunk():
|
||||||
"""Create a mock fetch_chunk function"""
|
"""Create a mock fetch_chunk function"""
|
||||||
async def fetch(chunk_id, user):
|
async def fetch(chunk_id):
|
||||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
return fetch
|
return fetch
|
||||||
|
|
||||||
|
|
@ -203,7 +203,7 @@ class TestQuery:
|
||||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||||
|
|
||||||
# Mock fetch_chunk function
|
# Mock fetch_chunk function
|
||||||
async def mock_fetch(chunk_id, user):
|
async def mock_fetch(chunk_id):
|
||||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
mock_rag.fetch_chunk = mock_fetch
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
|
|
@ -361,7 +361,7 @@ class TestQuery:
|
||||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||||
|
|
||||||
# Mock fetch_chunk
|
# Mock fetch_chunk
|
||||||
async def mock_fetch(chunk_id, user):
|
async def mock_fetch(chunk_id):
|
||||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
mock_rag.fetch_chunk = mock_fetch
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
|
|
@ -437,7 +437,7 @@ class TestQuery:
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||||
|
|
||||||
async def mock_fetch(chunk_id, user):
|
async def mock_fetch(chunk_id):
|
||||||
return f"Content for {chunk_id}"
|
return f"Content for {chunk_id}"
|
||||||
mock_rag.fetch_chunk = mock_fetch
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
|
|
@ -594,7 +594,7 @@ class TestQuery:
|
||||||
mock_rag.embeddings_client = mock_embeddings_client
|
mock_rag.embeddings_client = mock_embeddings_client
|
||||||
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
mock_rag.doc_embeddings_client = mock_doc_embeddings_client
|
||||||
|
|
||||||
async def mock_fetch(chunk_id, user):
|
async def mock_fetch(chunk_id):
|
||||||
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
return CHUNK_CONTENT.get(chunk_id, f"Content for {chunk_id}")
|
||||||
mock_rag.fetch_chunk = mock_fetch
|
mock_rag.fetch_chunk = mock_fetch
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -105,7 +105,7 @@ def build_mock_clients():
|
||||||
]
|
]
|
||||||
|
|
||||||
# 4. Chunk content
|
# 4. Chunk content
|
||||||
async def mock_fetch(chunk_id, user):
|
async def mock_fetch(chunk_id):
|
||||||
return {
|
return {
|
||||||
CHUNK_A: CHUNK_A_CONTENT,
|
CHUNK_A: CHUNK_A_CONTENT,
|
||||||
CHUNK_B: CHUNK_B_CONTENT,
|
CHUNK_B: CHUNK_B_CONTENT,
|
||||||
|
|
|
||||||
|
|
@ -171,14 +171,16 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
@patch('trustgraph.base.llm_service.LlmService.__init__')
|
||||||
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class):
|
async def test_processor_initialization_without_api_key(self, mock_llm_init, mock_async_init, mock_openai_class):
|
||||||
"""Test processor initialization without API key (should fail)"""
|
"""Test processor initialization without API key uses placeholder"""
|
||||||
# Arrange
|
# Arrange
|
||||||
|
mock_openai_client = MagicMock()
|
||||||
|
mock_openai_class.return_value = mock_openai_client
|
||||||
mock_async_init.return_value = None
|
mock_async_init.return_value = None
|
||||||
mock_llm_init.return_value = None
|
mock_llm_init.return_value = None
|
||||||
|
|
||||||
config = {
|
config = {
|
||||||
'model': 'gpt-3.5-turbo',
|
'model': 'gpt-3.5-turbo',
|
||||||
'api_key': None, # No API key provided
|
'api_key': None,
|
||||||
'url': 'https://api.openai.com/v1',
|
'url': 'https://api.openai.com/v1',
|
||||||
'temperature': 0.0,
|
'temperature': 0.0,
|
||||||
'max_output': 4096,
|
'max_output': 4096,
|
||||||
|
|
@ -187,9 +189,10 @@ class TestOpenAIProcessorSimple(IsolatedAsyncioTestCase):
|
||||||
'id': 'test-processor'
|
'id': 'test-processor'
|
||||||
}
|
}
|
||||||
|
|
||||||
# Act & Assert
|
processor = Processor(**config)
|
||||||
with pytest.raises(RuntimeError, match="OpenAI API key not specified"):
|
mock_openai_class.assert_called_once_with(
|
||||||
processor = Processor(**config)
|
base_url='https://api.openai.com/v1', api_key='not-set'
|
||||||
|
)
|
||||||
|
|
||||||
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
@patch('trustgraph.model.text_completion.openai.llm.OpenAI')
|
||||||
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
@patch('trustgraph.base.async_processor.AsyncProcessor.__init__')
|
||||||
|
|
|
||||||
|
|
@ -35,8 +35,8 @@ class Processor(LlmService):
|
||||||
temperature = params.get("temperature", default_temperature)
|
temperature = params.get("temperature", default_temperature)
|
||||||
max_output = params.get("max_output", default_max_output)
|
max_output = params.get("max_output", default_max_output)
|
||||||
|
|
||||||
if api_key is None:
|
if not api_key:
|
||||||
raise RuntimeError("OpenAI API key not specified")
|
api_key = "not-set"
|
||||||
|
|
||||||
super(Processor, self).__init__(
|
super(Processor, self).__init__(
|
||||||
**params | {
|
**params | {
|
||||||
|
|
|
||||||
|
|
@ -122,7 +122,7 @@ class Query:
|
||||||
for match in chunk_matches:
|
for match in chunk_matches:
|
||||||
if match.chunk_id:
|
if match.chunk_id:
|
||||||
try:
|
try:
|
||||||
content = await self.rag.fetch_chunk(match.chunk_id, self.workspace)
|
content = await self.rag.fetch_chunk(match.chunk_id)
|
||||||
docs.append(content)
|
docs.append(content)
|
||||||
chunk_ids.append(match.chunk_id)
|
chunk_ids.append(match.chunk_id)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue