Fix for issue #821: defer optional SDK imports to runtime in provider modules (#828)

This commit is contained in:
Syed Ishmum Ahnaf 2026-04-18 16:14:52 +06:00 committed by GitHub
parent 290922858f
commit b341bf5ea1
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 39 additions and 31 deletions

View file

@ -7,8 +7,6 @@ Input is text, output is embeddings vector.
import logging import logging
from ... base import EmbeddingsService from ... base import EmbeddingsService
from langchain_huggingface import HuggingFaceEmbeddings
# Module logger # Module logger
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
@ -38,6 +36,7 @@ class Processor(EmbeddingsService):
def _load_model(self, model_name): def _load_model(self, model_name):
"""Load a model, caching it for reuse""" """Load a model, caching it for reuse"""
if self.cached_model_name != model_name: if self.cached_model_name != model_name:
from langchain_huggingface import HuggingFaceEmbeddings
logger.info(f"Loading HuggingFace embeddings model: {model_name}") logger.info(f"Loading HuggingFace embeddings model: {model_name}")
self.embeddings = HuggingFaceEmbeddings(model_name=model_name) self.embeddings = HuggingFaceEmbeddings(model_name=model_name)
self.cached_model_name = model_name self.cached_model_name = model_name

View file

@ -5,7 +5,6 @@ as text as separate output objects.
""" """
import logging import logging
from langchain_text_splitters import RecursiveCharacterTextSplitter
from prometheus_client import Histogram from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Metadata, Triples from ... schema import TextDocument, Chunk, Metadata, Triples
@ -42,6 +41,9 @@ class Processor(ChunkingService):
self.default_chunk_size = chunk_size self.default_chunk_size = chunk_size
self.default_chunk_overlap = chunk_overlap self.default_chunk_overlap = chunk_overlap
from langchain_text_splitters import RecursiveCharacterTextSplitter
self.RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter
if not hasattr(__class__, "chunk_metric"): if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram( __class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size', 'chunk_size', 'Chunk size',
@ -50,7 +52,7 @@ class Processor(ChunkingService):
2500, 4000, 6400, 10000, 16000] 2500, 4000, 6400, 10000, 16000]
) )
self.text_splitter = RecursiveCharacterTextSplitter( self.text_splitter = self.RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
length_function=len, length_function=len,
@ -103,7 +105,7 @@ class Processor(ChunkingService):
chunk_overlap = int(chunk_overlap) chunk_overlap = int(chunk_overlap)
# Create text splitter with effective parameters # Create text splitter with effective parameters
text_splitter = RecursiveCharacterTextSplitter( text_splitter = self.RecursiveCharacterTextSplitter(
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
length_function=len, length_function=len,

View file

@ -5,7 +5,6 @@ as text as separate output objects.
""" """
import logging import logging
from langchain_text_splitters import TokenTextSplitter
from prometheus_client import Histogram from prometheus_client import Histogram
from ... schema import TextDocument, Chunk, Metadata, Triples from ... schema import TextDocument, Chunk, Metadata, Triples
@ -42,6 +41,9 @@ class Processor(ChunkingService):
self.default_chunk_size = chunk_size self.default_chunk_size = chunk_size
self.default_chunk_overlap = chunk_overlap self.default_chunk_overlap = chunk_overlap
from langchain_text_splitters import TokenTextSplitter
self.TokenTextSplitter = TokenTextSplitter
if not hasattr(__class__, "chunk_metric"): if not hasattr(__class__, "chunk_metric"):
__class__.chunk_metric = Histogram( __class__.chunk_metric = Histogram(
'chunk_size', 'Chunk size', 'chunk_size', 'Chunk size',
@ -50,7 +52,7 @@ class Processor(ChunkingService):
2500, 4000, 6400, 10000, 16000] 2500, 4000, 6400, 10000, 16000]
) )
self.text_splitter = TokenTextSplitter( self.text_splitter = self.TokenTextSplitter(
encoding_name="cl100k_base", encoding_name="cl100k_base",
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,
@ -102,7 +104,7 @@ class Processor(ChunkingService):
chunk_overlap = int(chunk_overlap) chunk_overlap = int(chunk_overlap)
# Create text splitter with effective parameters # Create text splitter with effective parameters
text_splitter = TokenTextSplitter( text_splitter = self.TokenTextSplitter(
encoding_name="cl100k_base", encoding_name="cl100k_base",
chunk_size=chunk_size, chunk_size=chunk_size,
chunk_overlap=chunk_overlap, chunk_overlap=chunk_overlap,

View file

@ -11,13 +11,10 @@ import os
import tempfile import tempfile
import base64 import base64
import logging import logging
from langchain_community.document_loaders import PyPDFLoader
from ... schema import Document, TextDocument, Metadata from ... schema import Document, TextDocument, Metadata
from ... schema import librarian_request_queue, librarian_response_queue from ... schema import librarian_request_queue, librarian_response_queue
from ... schema import Triples from ... schema import Triples
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
from ... provenance import ( from ... provenance import (
document_uri, page_uri as make_page_uri, derived_entity_triples, document_uri, page_uri as make_page_uri, derived_entity_triples,
set_graph, GRAPH_SOURCE, set_graph, GRAPH_SOURCE,
@ -131,6 +128,7 @@ class Processor(FlowProcessor):
fp.write(base64.b64decode(v.data)) fp.write(base64.b64decode(v.data))
fp.close() fp.close()
from langchain_community.document_loaders import PyPDFLoader
loader = PyPDFLoader(temp_path) loader = PyPDFLoader(temp_path)
pages = loader.load() pages = loader.load()

View file

@ -12,11 +12,6 @@ Input is prompt, output is response.
# TrustGraph implements in the trustgraph-vertexai package. # TrustGraph implements in the trustgraph-vertexai package.
# #
from google import genai
from google.genai import types
from google.genai.types import HarmCategory, HarmBlockThreshold
from google.genai.errors import ClientError
from google.api_core.exceptions import ResourceExhausted
import os import os
import logging import logging
@ -42,6 +37,18 @@ class Processor(LlmService):
temperature = params.get("temperature", default_temperature) temperature = params.get("temperature", default_temperature)
max_output = params.get("max_output", default_max_output) max_output = params.get("max_output", default_max_output)
from google import genai
from google.genai import types
from google.genai.types import HarmCategory, HarmBlockThreshold
from google.genai.errors import ClientError
from google.api_core.exceptions import ResourceExhausted
self.genai = genai
self.types = types
self.HarmCategory = HarmCategory
self.HarmBlockThreshold = HarmBlockThreshold
self.ClientError = ClientError
self.ResourceExhausted = ResourceExhausted
if api_key is None: if api_key is None:
raise RuntimeError("Google AI Studio API key not specified") raise RuntimeError("Google AI Studio API key not specified")
@ -53,7 +60,7 @@ class Processor(LlmService):
} }
) )
self.client = genai.Client(api_key=api_key, vertexai=False) self.client = self.genai.Client(api_key=api_key, vertexai=False)
self.default_model = model self.default_model = model
self.temperature = temperature self.temperature = temperature
self.max_output = max_output self.max_output = max_output
@ -61,23 +68,23 @@ class Processor(LlmService):
# Cache for generation configs per model # Cache for generation configs per model
self.generation_configs = {} self.generation_configs = {}
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH block_level = self.HarmBlockThreshold.BLOCK_ONLY_HIGH
self.safety_settings = [ self.safety_settings = [
types.SafetySetting( self.types.SafetySetting(
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH, category = self.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
threshold = block_level, threshold = block_level,
), ),
types.SafetySetting( self.types.SafetySetting(
category = HarmCategory.HARM_CATEGORY_HARASSMENT, category = self.HarmCategory.HARM_CATEGORY_HARASSMENT,
threshold = block_level, threshold = block_level,
), ),
types.SafetySetting( self.types.SafetySetting(
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT, category = self.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
threshold = block_level, threshold = block_level,
), ),
types.SafetySetting( self.types.SafetySetting(
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT, category = self.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
threshold = block_level, threshold = block_level,
), ),
# There is a documentation conflict on whether or not # There is a documentation conflict on whether or not
@ -97,7 +104,7 @@ class Processor(LlmService):
if cache_key not in self.generation_configs: if cache_key not in self.generation_configs:
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}") logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
self.generation_configs[cache_key] = types.GenerateContentConfig( self.generation_configs[cache_key] = self.types.GenerateContentConfig(
temperature = effective_temperature, temperature = effective_temperature,
top_p = 1, top_p = 1,
top_k = 40, top_k = 40,
@ -146,14 +153,14 @@ class Processor(LlmService):
return resp return resp
except ResourceExhausted as e: except self.ResourceExhausted as e:
logger.warning("Rate limit exceeded") logger.warning("Rate limit exceeded")
# Leave rate limit retries to the default handler # Leave rate limit retries to the default handler
raise TooManyRequests() raise TooManyRequests()
except ClientError as e: except self.ClientError as e:
# google-genai SDK throws ClientError for 4xx errors # google-genai SDK throws ClientError for 4xx errors
if e.code == 429: if e.code == 429:
logger.warning(f"Rate limit exceeded (ClientError 429): {e}") logger.warning(f"Rate limit exceeded (ClientError 429): {e}")
@ -222,11 +229,11 @@ class Processor(LlmService):
logger.debug("Streaming complete") logger.debug("Streaming complete")
except ResourceExhausted: except self.ResourceExhausted:
logger.warning("Rate limit exceeded during streaming") logger.warning("Rate limit exceeded during streaming")
raise TooManyRequests() raise TooManyRequests()
except ClientError as e: except self.ClientError as e:
# google-genai SDK throws ClientError for 4xx errors # google-genai SDK throws ClientError for 4xx errors
if e.code == 429: if e.code == 429:
logger.warning(f"Rate limit exceeded during streaming (ClientError 429): {e}") logger.warning(f"Rate limit exceeded during streaming (ClientError 429): {e}")