mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
This commit is contained in:
parent
290922858f
commit
b341bf5ea1
5 changed files with 39 additions and 31 deletions
|
|
@ -7,8 +7,6 @@ Input is text, output is embeddings vector.
|
||||||
import logging
|
import logging
|
||||||
from ... base import EmbeddingsService
|
from ... base import EmbeddingsService
|
||||||
|
|
||||||
from langchain_huggingface import HuggingFaceEmbeddings
|
|
||||||
|
|
||||||
# Module logger
|
# Module logger
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -38,6 +36,7 @@ class Processor(EmbeddingsService):
|
||||||
def _load_model(self, model_name):
|
def _load_model(self, model_name):
|
||||||
"""Load a model, caching it for reuse"""
|
"""Load a model, caching it for reuse"""
|
||||||
if self.cached_model_name != model_name:
|
if self.cached_model_name != model_name:
|
||||||
|
from langchain_huggingface import HuggingFaceEmbeddings
|
||||||
logger.info(f"Loading HuggingFace embeddings model: {model_name}")
|
logger.info(f"Loading HuggingFace embeddings model: {model_name}")
|
||||||
self.embeddings = HuggingFaceEmbeddings(model_name=model_name)
|
self.embeddings = HuggingFaceEmbeddings(model_name=model_name)
|
||||||
self.cached_model_name = model_name
|
self.cached_model_name = model_name
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ as text as separate output objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
|
||||||
from prometheus_client import Histogram
|
from prometheus_client import Histogram
|
||||||
|
|
||||||
from ... schema import TextDocument, Chunk, Metadata, Triples
|
from ... schema import TextDocument, Chunk, Metadata, Triples
|
||||||
|
|
@ -42,6 +41,9 @@ class Processor(ChunkingService):
|
||||||
self.default_chunk_size = chunk_size
|
self.default_chunk_size = chunk_size
|
||||||
self.default_chunk_overlap = chunk_overlap
|
self.default_chunk_overlap = chunk_overlap
|
||||||
|
|
||||||
|
from langchain_text_splitters import RecursiveCharacterTextSplitter
|
||||||
|
self.RecursiveCharacterTextSplitter = RecursiveCharacterTextSplitter
|
||||||
|
|
||||||
if not hasattr(__class__, "chunk_metric"):
|
if not hasattr(__class__, "chunk_metric"):
|
||||||
__class__.chunk_metric = Histogram(
|
__class__.chunk_metric = Histogram(
|
||||||
'chunk_size', 'Chunk size',
|
'chunk_size', 'Chunk size',
|
||||||
|
|
@ -50,7 +52,7 @@ class Processor(ChunkingService):
|
||||||
2500, 4000, 6400, 10000, 16000]
|
2500, 4000, 6400, 10000, 16000]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.text_splitter = RecursiveCharacterTextSplitter(
|
self.text_splitter = self.RecursiveCharacterTextSplitter(
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap,
|
chunk_overlap=chunk_overlap,
|
||||||
length_function=len,
|
length_function=len,
|
||||||
|
|
@ -103,7 +105,7 @@ class Processor(ChunkingService):
|
||||||
chunk_overlap = int(chunk_overlap)
|
chunk_overlap = int(chunk_overlap)
|
||||||
|
|
||||||
# Create text splitter with effective parameters
|
# Create text splitter with effective parameters
|
||||||
text_splitter = RecursiveCharacterTextSplitter(
|
text_splitter = self.RecursiveCharacterTextSplitter(
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap,
|
chunk_overlap=chunk_overlap,
|
||||||
length_function=len,
|
length_function=len,
|
||||||
|
|
|
||||||
|
|
@ -5,7 +5,6 @@ as text as separate output objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
from langchain_text_splitters import TokenTextSplitter
|
|
||||||
from prometheus_client import Histogram
|
from prometheus_client import Histogram
|
||||||
|
|
||||||
from ... schema import TextDocument, Chunk, Metadata, Triples
|
from ... schema import TextDocument, Chunk, Metadata, Triples
|
||||||
|
|
@ -42,6 +41,9 @@ class Processor(ChunkingService):
|
||||||
self.default_chunk_size = chunk_size
|
self.default_chunk_size = chunk_size
|
||||||
self.default_chunk_overlap = chunk_overlap
|
self.default_chunk_overlap = chunk_overlap
|
||||||
|
|
||||||
|
from langchain_text_splitters import TokenTextSplitter
|
||||||
|
self.TokenTextSplitter = TokenTextSplitter
|
||||||
|
|
||||||
if not hasattr(__class__, "chunk_metric"):
|
if not hasattr(__class__, "chunk_metric"):
|
||||||
__class__.chunk_metric = Histogram(
|
__class__.chunk_metric = Histogram(
|
||||||
'chunk_size', 'Chunk size',
|
'chunk_size', 'Chunk size',
|
||||||
|
|
@ -50,7 +52,7 @@ class Processor(ChunkingService):
|
||||||
2500, 4000, 6400, 10000, 16000]
|
2500, 4000, 6400, 10000, 16000]
|
||||||
)
|
)
|
||||||
|
|
||||||
self.text_splitter = TokenTextSplitter(
|
self.text_splitter = self.TokenTextSplitter(
|
||||||
encoding_name="cl100k_base",
|
encoding_name="cl100k_base",
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap,
|
chunk_overlap=chunk_overlap,
|
||||||
|
|
@ -102,7 +104,7 @@ class Processor(ChunkingService):
|
||||||
chunk_overlap = int(chunk_overlap)
|
chunk_overlap = int(chunk_overlap)
|
||||||
|
|
||||||
# Create text splitter with effective parameters
|
# Create text splitter with effective parameters
|
||||||
text_splitter = TokenTextSplitter(
|
text_splitter = self.TokenTextSplitter(
|
||||||
encoding_name="cl100k_base",
|
encoding_name="cl100k_base",
|
||||||
chunk_size=chunk_size,
|
chunk_size=chunk_size,
|
||||||
chunk_overlap=chunk_overlap,
|
chunk_overlap=chunk_overlap,
|
||||||
|
|
|
||||||
|
|
@ -11,13 +11,10 @@ import os
|
||||||
import tempfile
|
import tempfile
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
from langchain_community.document_loaders import PyPDFLoader
|
|
||||||
|
|
||||||
from ... schema import Document, TextDocument, Metadata
|
from ... schema import Document, TextDocument, Metadata
|
||||||
from ... schema import librarian_request_queue, librarian_response_queue
|
from ... schema import librarian_request_queue, librarian_response_queue
|
||||||
from ... schema import Triples
|
from ... schema import Triples
|
||||||
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
|
from ... base import FlowProcessor, ConsumerSpec, ProducerSpec, LibrarianClient
|
||||||
|
|
||||||
from ... provenance import (
|
from ... provenance import (
|
||||||
document_uri, page_uri as make_page_uri, derived_entity_triples,
|
document_uri, page_uri as make_page_uri, derived_entity_triples,
|
||||||
set_graph, GRAPH_SOURCE,
|
set_graph, GRAPH_SOURCE,
|
||||||
|
|
@ -131,6 +128,7 @@ class Processor(FlowProcessor):
|
||||||
fp.write(base64.b64decode(v.data))
|
fp.write(base64.b64decode(v.data))
|
||||||
fp.close()
|
fp.close()
|
||||||
|
|
||||||
|
from langchain_community.document_loaders import PyPDFLoader
|
||||||
loader = PyPDFLoader(temp_path)
|
loader = PyPDFLoader(temp_path)
|
||||||
pages = loader.load()
|
pages = loader.load()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -12,11 +12,6 @@ Input is prompt, output is response.
|
||||||
# TrustGraph implements in the trustgraph-vertexai package.
|
# TrustGraph implements in the trustgraph-vertexai package.
|
||||||
#
|
#
|
||||||
|
|
||||||
from google import genai
|
|
||||||
from google.genai import types
|
|
||||||
from google.genai.types import HarmCategory, HarmBlockThreshold
|
|
||||||
from google.genai.errors import ClientError
|
|
||||||
from google.api_core.exceptions import ResourceExhausted
|
|
||||||
import os
|
import os
|
||||||
import logging
|
import logging
|
||||||
|
|
||||||
|
|
@ -42,6 +37,18 @@ class Processor(LlmService):
|
||||||
temperature = params.get("temperature", default_temperature)
|
temperature = params.get("temperature", default_temperature)
|
||||||
max_output = params.get("max_output", default_max_output)
|
max_output = params.get("max_output", default_max_output)
|
||||||
|
|
||||||
|
from google import genai
|
||||||
|
from google.genai import types
|
||||||
|
from google.genai.types import HarmCategory, HarmBlockThreshold
|
||||||
|
from google.genai.errors import ClientError
|
||||||
|
from google.api_core.exceptions import ResourceExhausted
|
||||||
|
self.genai = genai
|
||||||
|
self.types = types
|
||||||
|
self.HarmCategory = HarmCategory
|
||||||
|
self.HarmBlockThreshold = HarmBlockThreshold
|
||||||
|
self.ClientError = ClientError
|
||||||
|
self.ResourceExhausted = ResourceExhausted
|
||||||
|
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
raise RuntimeError("Google AI Studio API key not specified")
|
raise RuntimeError("Google AI Studio API key not specified")
|
||||||
|
|
||||||
|
|
@ -53,7 +60,7 @@ class Processor(LlmService):
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
self.client = genai.Client(api_key=api_key, vertexai=False)
|
self.client = self.genai.Client(api_key=api_key, vertexai=False)
|
||||||
self.default_model = model
|
self.default_model = model
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
self.max_output = max_output
|
self.max_output = max_output
|
||||||
|
|
@ -61,23 +68,23 @@ class Processor(LlmService):
|
||||||
# Cache for generation configs per model
|
# Cache for generation configs per model
|
||||||
self.generation_configs = {}
|
self.generation_configs = {}
|
||||||
|
|
||||||
block_level = HarmBlockThreshold.BLOCK_ONLY_HIGH
|
block_level = self.HarmBlockThreshold.BLOCK_ONLY_HIGH
|
||||||
|
|
||||||
self.safety_settings = [
|
self.safety_settings = [
|
||||||
types.SafetySetting(
|
self.types.SafetySetting(
|
||||||
category = HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
category = self.HarmCategory.HARM_CATEGORY_HATE_SPEECH,
|
||||||
threshold = block_level,
|
threshold = block_level,
|
||||||
),
|
),
|
||||||
types.SafetySetting(
|
self.types.SafetySetting(
|
||||||
category = HarmCategory.HARM_CATEGORY_HARASSMENT,
|
category = self.HarmCategory.HARM_CATEGORY_HARASSMENT,
|
||||||
threshold = block_level,
|
threshold = block_level,
|
||||||
),
|
),
|
||||||
types.SafetySetting(
|
self.types.SafetySetting(
|
||||||
category = HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
category = self.HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT,
|
||||||
threshold = block_level,
|
threshold = block_level,
|
||||||
),
|
),
|
||||||
types.SafetySetting(
|
self.types.SafetySetting(
|
||||||
category = HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
category = self.HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT,
|
||||||
threshold = block_level,
|
threshold = block_level,
|
||||||
),
|
),
|
||||||
# There is a documentation conflict on whether or not
|
# There is a documentation conflict on whether or not
|
||||||
|
|
@ -97,7 +104,7 @@ class Processor(LlmService):
|
||||||
|
|
||||||
if cache_key not in self.generation_configs:
|
if cache_key not in self.generation_configs:
|
||||||
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
|
logger.info(f"Creating generation config for '{model_name}' with temperature {effective_temperature}")
|
||||||
self.generation_configs[cache_key] = types.GenerateContentConfig(
|
self.generation_configs[cache_key] = self.types.GenerateContentConfig(
|
||||||
temperature = effective_temperature,
|
temperature = effective_temperature,
|
||||||
top_p = 1,
|
top_p = 1,
|
||||||
top_k = 40,
|
top_k = 40,
|
||||||
|
|
@ -146,14 +153,14 @@ class Processor(LlmService):
|
||||||
|
|
||||||
return resp
|
return resp
|
||||||
|
|
||||||
except ResourceExhausted as e:
|
except self.ResourceExhausted as e:
|
||||||
|
|
||||||
logger.warning("Rate limit exceeded")
|
logger.warning("Rate limit exceeded")
|
||||||
|
|
||||||
# Leave rate limit retries to the default handler
|
# Leave rate limit retries to the default handler
|
||||||
raise TooManyRequests()
|
raise TooManyRequests()
|
||||||
|
|
||||||
except ClientError as e:
|
except self.ClientError as e:
|
||||||
# google-genai SDK throws ClientError for 4xx errors
|
# google-genai SDK throws ClientError for 4xx errors
|
||||||
if e.code == 429:
|
if e.code == 429:
|
||||||
logger.warning(f"Rate limit exceeded (ClientError 429): {e}")
|
logger.warning(f"Rate limit exceeded (ClientError 429): {e}")
|
||||||
|
|
@ -222,11 +229,11 @@ class Processor(LlmService):
|
||||||
|
|
||||||
logger.debug("Streaming complete")
|
logger.debug("Streaming complete")
|
||||||
|
|
||||||
except ResourceExhausted:
|
except self.ResourceExhausted:
|
||||||
logger.warning("Rate limit exceeded during streaming")
|
logger.warning("Rate limit exceeded during streaming")
|
||||||
raise TooManyRequests()
|
raise TooManyRequests()
|
||||||
|
|
||||||
except ClientError as e:
|
except self.ClientError as e:
|
||||||
# google-genai SDK throws ClientError for 4xx errors
|
# google-genai SDK throws ClientError for 4xx errors
|
||||||
if e.code == 429:
|
if e.code == 429:
|
||||||
logger.warning(f"Rate limit exceeded during streaming (ClientError 429): {e}")
|
logger.warning(f"Rate limit exceeded during streaming (ClientError 429): {e}")
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue