mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
fix: repair deferred imports to preserve module-level names for test patching (#831)
A previous commit moved the SDK imports into __init__/methods and stashed them on self, which broke the @patch targets in 24 unit tests. This commit corrects that approach: the chunkers and pdf_decoder now use module-level None sentinels with global/if-None guards, so the imports are still deferred but the names remain patchable at module level. Google AI Studio reverts to standard module-level imports, since that module is only loaded when communicating with Gemini. Lazy loading is kept for the other imports.
This commit is contained in:
parent
d7745baab4
commit
cce3acd84f
4 changed files with 48 additions and 36 deletions
|
|
@ -10,6 +10,8 @@ from prometheus_client import Histogram
|
|||
from ... schema import TextDocument, Chunk, Metadata, Triples
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
RecursiveCharacterTextSplitter = None
|
||||
|
||||
from ... provenance import (
|
||||
chunk_uri as make_chunk_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
|
|
@ -41,8 +43,12 @@ class Processor(ChunkingService):
|
|||
self.default_chunk_size = chunk_size
|
||||
self.default_chunk_overlap = chunk_overlap
|
||||
|
||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||
self.RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter
|
||||
global RecursiveCharacterTextSplitter
|
||||
if RecursiveCharacterTextSplitter is None:
|
||||
from langchain_text_splitters import (
|
||||
RecursiveCharacterTextSplitter as _cls,
|
||||
)
|
||||
RecursiveCharacterTextSplitter = _cls
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
|
|
@ -52,7 +58,7 @@ class Processor(ChunkingService):
|
|||
2500, 4000, 6400, 10000, 16000]
|
||||
)
|
||||
|
||||
self.text_splitter = self.RecursiveCharacterTextSplitter(
|
||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
|
|
@ -105,7 +111,7 @@ class Processor(ChunkingService):
|
|||
chunk_overlap = int(chunk_overlap)
|
||||
|
||||
# Create text splitter with effective parameters
|
||||
text_splitter = self.RecursiveCharacterTextSplitter(
|
||||
text_splitter = RecursiveCharacterTextSplitter(
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
length_function=len,
|
||||
|
|
|
|||
|
|
@ -10,6 +10,8 @@ from prometheus_client import Histogram
|
|||
from ... schema import TextDocument, Chunk, Metadata, Triples
|
||||
from ... base import ChunkingService, ConsumerSpec, ProducerSpec
|
||||
|
||||
TokenTextSplitter = None
|
||||
|
||||
from ... provenance import (
|
||||
chunk_uri as make_chunk_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
|
|
@ -41,8 +43,10 @@ class Processor(ChunkingService):
|
|||
self.default_chunk_size = chunk_size
|
||||
self.default_chunk_overlap = chunk_overlap
|
||||
|
||||
from langchain_text_splitters import TokenTextSplitter
|
||||
self.TokenTextSplitter = TokenTextSplitter
|
||||
global TokenTextSplitter
|
||||
if TokenTextSplitter is None:
|
||||
from langchain_text_splitters import TokenTextSplitter as _cls
|
||||
TokenTextSplitter = _cls
|
||||
|
||||
if not hasattr(__class__, "chunk_metric"):
|
||||
__class__.chunk_metric = Histogram(
|
||||
|
|
@ -52,7 +56,7 @@ class Processor(ChunkingService):
|
|||
2500, 4000, 6400, 10000, 16000]
|
||||
)
|
||||
|
||||
self.text_splitter = self.TokenTextSplitter(
|
||||
self.text_splitter = TokenTextSplitter(
|
||||
encoding_name="cl100k_base",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
|
|
@ -104,7 +108,7 @@ class Processor(ChunkingService):
|
|||
chunk_overlap = int(chunk_overlap)
|
||||
|
||||
# Create text splitter with effective parameters
|
||||
text_splitter = self.TokenTextSplitter(
|
||||
text_splitter = TokenTextSplitter(
|
||||
encoding_name="cl100k_base",
|
||||
chunk_size=chunk_size,
|
||||
chunk_overlap=chunk_overlap,
|
||||
|
|
|
|||
|
|
@ -15,6 +15,9 @@ from ... schema import Document, TextDocument, Metadata
|
|||
from ... schema import librarian_request_queue, librarian_response_queue
|
||||
from ... schema import Triples
|
||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
|
||||
|
||||
PyPDFLoader = None
|
||||
|
||||
from ... provenance import (
|
||||
document_uri, page_uri as make_page_uri, derived_entity_triples,
|
||||
set_graph, GRAPH_SOURCE,
|
||||
|
|
@ -128,7 +131,12 @@ class Processor(FlowProcessor):
|
|||
fp.write(base64.b64decode(v.data))
|
||||
fp.close()
|
||||
|
||||
from langchain_community.document_loaders import PyPDFLoader
|
||||
global PyPDFLoader
|
||||
if PyPDFLoader is None:
|
||||
from langchain_community.document_loaders import (
|
||||
PyPDFLoader as _cls,
|
||||
)
|
||||
PyPDFLoader = _cls
|
||||
loader = PyPDFLoader(temp_path)
|
||||
pages = loader.load()
|
||||
|
||||
|
|
|
|||
|
|
@ -18,6 +18,12 @@ import logging
|
|||
# Module logger
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.types import HarmCategory, HarmBlockThreshold
|
||||
from google.genai.errors import ClientError
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
|
||||
from .... exceptions import TooManyRequests
|
||||
from .... base import LlmService, LlmResult, LlmChunk
|
||||
|
||||
|
|
@ -37,18 +43,6 @@ class Processor(LlmService):
|
|||
temperature = params.get("temperature", default_temperature)
|
||||
max_output = params.get("max_output", default_max_output)
|
||||
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
from google.genai.types import HarmCategory, HarmBlockThreshold
|
||||
from google.genai.errors import ClientError
|
||||
from google.api_core.exceptions import ResourceExhausted
|
||||
self.genai = genai
|
||||
self.types = types
|
||||
self.HarmCategory = HarmCategory
|
||||
self.HarmBlockThreshold = HarmBlockThreshold
|
||||
self.ClientError = ClientError
|
||||
self.ResourceExhausted = ResourceExhausted
|
||||
|
||||
if api_key is None:
|
||||
raise RuntimeError("Google AI Studio API key not specified")
|
||||
|
||||
|
|
@ -60,7 +54,7 @@ class Processor(LlmService):
|
|||
}
|
||||
)
|
||||
|
||||
self.client = self.genai.Client(api_key=api_key, vertexai=False)
|
||||
self.client = genai.Client(api_key=api_key, vertexai=False)
|
||||
self.default_model = model
|
||||
self.temperature = temperature
|
||||
self.max_output = max_output
|
||||
|
|
@ -68,23 +62,23 @@ class Processor(LlmService):
|
|||
# Cache for generation configs per model
|
||||
self.generation_configs = {}
|
||||
|
||||
block_level = self.HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||
|
||||
self.safety_settings = [
|
||||
self.types.SafetySetting(
|
||||
category = self.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||
threshold = block_level,
|
||||
),
|
||||
self.types.SafetySetting(
|
||||
category = self.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
self.types.SafetySetting(
|
||||
category = self.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||
threshold = block_level,
|
||||
),
|
||||
self.types.SafetySetting(
|
||||
category = self.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
types.SafetySetting(
|
||||
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||
threshold = block_level,
|
||||
),
|
||||
# There is a documentation conflict on whether or not
|
||||
|
|
@ -104,7 +98,7 @@ class Processor(LlmService):
|
|||
|
||||
if cache_key not in self.generation_configs:
|
||||
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
|
||||
self.generation_configs[cache_key] = self.types.GenerateContentConfig(
|
||||
self.generation_configs[cache_key] = types.GenerateContentConfig(
|
||||
temperature = effective_temperature,
|
||||
top_p = 1,
|
||||
top_k = 40,
|
||||
|
|
@ -153,14 +147,14 @@ class Processor(LlmService):
|
|||
|
||||
return resp
|
||||
|
||||
except self.ResourceExhausted as e:
|
||||
except ResourceExhausted as e:
|
||||
|
||||
logger.warning("Rate limit exceeded")
|
||||
|
||||
# Leave rate limit retries to the default handler
|
||||
raise TooManyRequests()
|
||||
|
||||
except self.ClientError as e:
|
||||
except ClientError as e:
|
||||
# google-genai SDK throws ClientError for 4xx errors
|
||||
if e.code == 429:
|
||||
logger.warning(f"Rate limit exceeded (ClientError 429): {e}")
|
||||
|
|
@ -229,11 +223,11 @@ class Processor(LlmService):
|
|||
|
||||
logger.debug("Streaming complete")
|
||||
|
||||
except self.ResourceExhausted:
|
||||
except ResourceExhausted:
|
||||
logger.warning("Rate limit exceeded during streaming")
|
||||
raise TooManyRequests()
|
||||
|
||||
except self.ClientError as e:
|
||||
except ClientError as e:
|
||||
# google-genai SDK throws ClientError for 4xx errors
|
||||
if e.code == 429:
|
||||
logger.warning(f"Rate limit exceeded during streaming (ClientError 429): {e}")
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue