fix: repair deferred imports to preserve module-level names for test patching (#831)

A previous commit moved SDK imports into __init__/methods and
stashed them on self, which broke @patch targets in 24 unit tests.

This fixes the approach: the chunker and pdf_decoder modules now use
module-level sentinels with global/if-None guards, so their imports are
still deferred but remain patchable. Google AI Studio reverts to
standard module-level imports, since that module is only loaded when
communicating with Gemini. Lazy loading is kept for the other imports.
cybermaggedon 2026-04-18 11:43:21 +01:00 committed by GitHub
parent d7745baab4
commit cce3acd84f
4 changed files with 48 additions and 36 deletions
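
In outline, the pattern the chunker and pdf_decoder diffs below apply
looks like the following sketch. The module and class names here are
placeholders, not the project's real ones:

    # chunker.py (sketch): deferred import that stays patchable
    HeavySplitter = None        # module-level sentinel, a valid @patch target

    class Processor:

        def __init__(self):
            global HeavySplitter            # rebind the module name, not a local
            if HeavySplitter is None:       # first construction: do the import
                from heavy_sdk import HeavySplitter as _cls
                HeavySplitter = _cls
            self.splitter = HeavySplitter()

Because the name lives on the module rather than on self, patching it
before constructing Processor replaces what __init__ sees, and the real
import never runs.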

View file

@@ -10,6 +10,8 @@ from prometheus_client import Histogram
 from ... schema import TextDocument, Chunk, Metadata, Triples
 from ... base import ChunkingService, ConsumerSpec, ProducerSpec
+
+RecursiveCharacterTextSplitter = None

 from ... provenance import (
     chunk_uri as make_chunk_uri, derived_entity_triples,
     set_graph, GRAPH_SOURCE,

@@ -41,8 +43,12 @@ class Processor(ChunkingService):
         self.default_chunk_size = chunk_size
         self.default_chunk_overlap = chunk_overlap

-        from langchain_text_splitters import RecursiveCharacterTextSplitter
-        self.RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter
+        global RecursiveCharacterTextSplitter
+        if RecursiveCharacterTextSplitter is None:
+            from langchain_text_splitters import (
+                RecursiveCharacterTextSplitter as _cls,
+            )
+            RecursiveCharacterTextSplitter = _cls

         if not hasattr(__class__, "chunk_metric"):
             __class__.chunk_metric = Histogram(

@@ -52,7 +58,7 @@ class Processor(ChunkingService):
                 2500, 4000, 6400, 10000, 16000]
             )

-        self.text_splitter = self.RecursiveCharacterTextSplitter(
+        self.text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap,
             length_function=len,

@@ -105,7 +111,7 @@ class Processor(ChunkingService):
             chunk_overlap = int(chunk_overlap)

         # Create text splitter with effective parameters
-        text_splitter = self.RecursiveCharacterTextSplitter(
+        text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap,
             length_function=len,
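
To see why this form is patchable, here is a test against the sketch
module above (again placeholder names; the real module paths are not
shown in this view):

    # test_chunker.py (sketch)
    from unittest.mock import patch

    import chunker      # the placeholder module sketched earlier

    @patch("chunker.HeavySplitter")
    def test_uses_patched_splitter(mock_cls):
        # The sentinel is a mock now, not None, so __init__ skips the
        # import and instantiates the mock in place of the real class
        p = chunker.Processor()
        mock_cls.assert_called_once_with()
        assert p.splitter is mock_cls.return_value

Under the previous approach the class existed only as an attribute on
self, so there was no module-level name for @patch to replace, which is
what broke the 24 tests.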

View file

@@ -10,6 +10,8 @@ from prometheus_client import Histogram
 from ... schema import TextDocument, Chunk, Metadata, Triples
 from ... base import ChunkingService, ConsumerSpec, ProducerSpec
+
+TokenTextSplitter = None

 from ... provenance import (
     chunk_uri as make_chunk_uri, derived_entity_triples,
     set_graph, GRAPH_SOURCE,

@@ -41,8 +43,10 @@ class Processor(ChunkingService):
         self.default_chunk_size = chunk_size
         self.default_chunk_overlap = chunk_overlap

-        from langchain_text_splitters import TokenTextSplitter
-        self.TokenTextSplitter = TokenTextSplitter
+        global TokenTextSplitter
+        if TokenTextSplitter is None:
+            from langchain_text_splitters import TokenTextSplitter as _cls
+            TokenTextSplitter = _cls

         if not hasattr(__class__, "chunk_metric"):
             __class__.chunk_metric = Histogram(

@@ -52,7 +56,7 @@ class Processor(ChunkingService):
                 2500, 4000, 6400, 10000, 16000]
             )

-        self.text_splitter = self.TokenTextSplitter(
+        self.text_splitter = TokenTextSplitter(
             encoding_name="cl100k_base",
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap,

@@ -104,7 +108,7 @@ class Processor(ChunkingService):
             chunk_overlap = int(chunk_overlap)

         # Create text splitter with effective parameters
-        text_splitter = self.TokenTextSplitter(
+        text_splitter = TokenTextSplitter(
             encoding_name="cl100k_base",
             chunk_size=chunk_size,
             chunk_overlap=chunk_overlap,

View file

@@ -15,6 +15,9 @@ from ... schema import Document, TextDocument, Metadata
 from ... schema import librarian_request_queue, librarian_response_queue
 from ... schema import Triples
 from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
+
+PyPDFLoader = None
+
 from ... provenance import (
     document_uri, page_uri as make_page_uri, derived_entity_triples,
     set_graph, GRAPH_SOURCE,

@@ -128,7 +131,12 @@ class Processor(FlowProcessor):
                 fp.write(base64.b64decode(v.data))
                 fp.close()

-                from langchain_community.document_loaders import PyPDFLoader
+                global PyPDFLoader
+                if PyPDFLoader is None:
+                    from langchain_community.document_loaders import (
+                        PyPDFLoader as _cls,
+                    )
+                    PyPDFLoader = _cls

                 loader = PyPDFLoader(temp_path)
                 pages = loader.load()
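
One detail worth noting in these hunks: the global statement is what
makes the lazy rebind work at all. Because the name is assigned inside
the function, Python would otherwise compile it as a local, and even the
"is None" check would raise UnboundLocalError before any import ran. A
standalone illustration, not the project's code:

    PyPDFLoader = None

    def load_pages(path):
        global PyPDFLoader          # without this, the check below raises
                                    # UnboundLocalError
        if PyPDFLoader is None:     # import on first call only
            from langchain_community.document_loaders import PyPDFLoader as _cls
            PyPDFLoader = _cls
        return PyPDFLoader(path).load()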

View file

@@ -18,6 +18,12 @@ import logging
 # Module logger
 logger = logging.getLogger(__name__)

+from google import genai
+from google.genai import types
+from google.genai.types import HarmCategory, HarmBlockThreshold
+from google.genai.errors import ClientError
+from google.api_core.exceptions import ResourceExhausted
+
 from .... exceptions import TooManyRequests
 from .... base import LlmService, LlmResult, LlmChunk

@@ -37,18 +43,6 @@ class Processor(LlmService):
         temperature = params.get("temperature", default_temperature)
         max_output = params.get("max_output", default_max_output)

-        from google import genai
-        from google.genai import types
-        from google.genai.types import HarmCategory, HarmBlockThreshold
-        from google.genai.errors import ClientError
-        from google.api_core.exceptions import ResourceExhausted
-
-        self.genai = genai
-        self.types = types
-        self.HarmCategory = HarmCategory
-        self.HarmBlockThreshold = HarmBlockThreshold
-        self.ClientError = ClientError
-        self.ResourceExhausted = ResourceExhausted

         if api_key is None:
             raise RuntimeError("Google AI Studio API key not specified")

@@ -60,7 +54,7 @@ class Processor(LlmService):
             }
         )

-        self.client = self.genai.Client(api_key=api_key, vertexai=False)
+        self.client = genai.Client(api_key=api_key, vertexai=False)
         self.default_model = model
         self.temperature = temperature
         self.max_output = max_output

@@ -68,23 +62,23 @@ class Processor(LlmService):
         # Cache for generation configs per model
         self.generation_configs = {}

-        block_level = self.HarmBlockThreshold.BLOCK_ONLY_HIGH
+        block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH

         self.safety_settings = [
-            self.types.SafetySetting(
-                category = self.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
+            types.SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
                 threshold = block_level,
             ),
-            self.types.SafetySetting(
-                category = self.HarmCategory.HARM_CATEGORY_HARASSMENT,
+            types.SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_HARASSMENT,
                 threshold = block_level,
             ),
-            self.types.SafetySetting(
-                category = self.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
+            types.SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
                 threshold = block_level,
             ),
-            self.types.SafetySetting(
-                category = self.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
+            types.SafetySetting(
+                category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
                 threshold = block_level,
             ),
             # There is a documentation conflict on whether or not

@@ -104,7 +98,7 @@ class Processor(LlmService):
         if cache_key not in self.generation_configs:
             logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")

-            self.generation_configs[cache_key] = self.types.GenerateContentConfig(
+            self.generation_configs[cache_key] = types.GenerateContentConfig(
                 temperature = effective_temperature,
                 top_p = 1,
                 top_k = 40,

@@ -153,14 +147,14 @@ class Processor(LlmService):
            return resp

-        except self.ResourceExhausted as e:
+        except ResourceExhausted as e:

            logger.warning("Rate limit exceeded")

            # Leave rate limit retries to the default handler
            raise TooManyRequests()

-        except self.ClientError as e:
+        except ClientError as e:

            # google-genai SDK throws ClientError for 4xx errors
            if e.code == 429:
                logger.warning(f"Rate limit exceeded (ClientError 429): {e}")

@@ -229,11 +223,11 @@ class Processor(LlmService):
            logger.debug("Streaming complete")

-        except self.ResourceExhausted:
+        except ResourceExhausted:
            logger.warning("Rate limit exceeded during streaming")

            raise TooManyRequests()

-        except self.ClientError as e:
+        except ClientError as e:
            # google-genai SDK throws ClientError for 4xx errors
            if e.code == 429:
                logger.warning(f"Rate limit exceeded during streaming (ClientError 429): {e}")