mirror of
https://github.com/trustgraph-ai/trustgraph.git
synced 2026-04-25 00:16:23 +02:00
Batch embeddings (#668)
Base Service (trustgraph-base/trustgraph/base/embeddings_service.py): - Changed on_request to use request.texts FastEmbed Processor (trustgraph-flow/trustgraph/embeddings/fastembed/processor.py): - on_embeddings(texts, model=None) now processes full batch efficiently - Returns [[v.tolist()] for v in vecs] - list of vector sets Ollama Processor (trustgraph-flow/trustgraph/embeddings/ollama/processor.py): - on_embeddings(texts, model=None) passes list directly to Ollama - Returns [[embedding] for embedding in embeds.embeddings] EmbeddingsClient (trustgraph-base/trustgraph/base/embeddings_client.py): - embed(texts, timeout=300) accepts list of texts Tests Updated: - test_fastembed_dynamic_model.py - 4 tests updated for new interface - test_ollama_dynamic_model.py - 4 tests updated for new interface Updated CLI, SDK and APIs
This commit is contained in:
parent
3bf8a65409
commit
0a2ce47a88
16 changed files with 785 additions and 79 deletions
|
|
@ -544,30 +544,29 @@ class FlowInstance:
|
|||
input
|
||||
)["response"]
|
||||
|
||||
def embeddings(self, text):
|
||||
def embeddings(self, texts):
|
||||
"""
|
||||
Generate vector embeddings for text.
|
||||
Generate vector embeddings for one or more texts.
|
||||
|
||||
Converts text into dense vector representations suitable for semantic
|
||||
Converts texts into dense vector representations suitable for semantic
|
||||
search and similarity comparison.
|
||||
|
||||
Args:
|
||||
text: Input text to embed
|
||||
texts: List of input texts to embed
|
||||
|
||||
Returns:
|
||||
list[float]: Vector embedding
|
||||
list[list[list[float]]]: Vector embeddings, one set per input text
|
||||
|
||||
Example:
|
||||
```python
|
||||
flow = api.flow().id("default")
|
||||
vectors = flow.embeddings("quantum computing")
|
||||
print(f"Embedding dimension: {len(vectors)}")
|
||||
vectors = flow.embeddings(["quantum computing"])
|
||||
print(f"Embedding dimension: {len(vectors[0][0])}")
|
||||
```
|
||||
"""
|
||||
|
||||
# The input consists of a text block
|
||||
input = {
|
||||
"text": text
|
||||
"texts": texts
|
||||
}
|
||||
|
||||
return self.request(
|
||||
|
|
|
|||
|
|
@ -712,27 +712,27 @@ class SocketFlowInstance:
|
|||
|
||||
return self.client._send_request_sync("document-embeddings", self.flow_id, request, False)
|
||||
|
||||
def embeddings(self, text: str, **kwargs: Any) -> Dict[str, Any]:
|
||||
def embeddings(self, texts: list, **kwargs: Any) -> Dict[str, Any]:
|
||||
"""
|
||||
Generate vector embeddings for text.
|
||||
Generate vector embeddings for one or more texts.
|
||||
|
||||
Args:
|
||||
text: Input text to embed
|
||||
texts: List of input texts to embed
|
||||
**kwargs: Additional parameters passed to the service
|
||||
|
||||
Returns:
|
||||
dict: Response containing vectors
|
||||
dict: Response containing vectors (one set per input text)
|
||||
|
||||
Example:
|
||||
```python
|
||||
socket = api.socket()
|
||||
flow = socket.flow("default")
|
||||
|
||||
result = flow.embeddings("quantum computing")
|
||||
result = flow.embeddings(["quantum computing"])
|
||||
vectors = result.get("vectors", [])
|
||||
```
|
||||
"""
|
||||
request = {"text": text}
|
||||
request = {"texts": texts}
|
||||
request.update(kwargs)
|
||||
|
||||
return self.client._send_request_sync("embeddings", self.flow_id, request, False)
|
||||
|
|
|
|||
|
|
@ -3,11 +3,11 @@ from . request_response_spec import RequestResponse, RequestResponseSpec
|
|||
from .. schema import EmbeddingsRequest, EmbeddingsResponse
|
||||
|
||||
class EmbeddingsClient(RequestResponse):
|
||||
async def embed(self, text, timeout=30):
|
||||
async def embed(self, texts, timeout=300):
|
||||
|
||||
resp = await self.request(
|
||||
EmbeddingsRequest(
|
||||
text = text
|
||||
texts = texts
|
||||
),
|
||||
timeout=timeout
|
||||
)
|
||||
|
|
|
|||
|
|
@ -65,7 +65,7 @@ class EmbeddingsService(FlowProcessor):
|
|||
|
||||
# Pass model from request if specified (non-empty), otherwise use default
|
||||
model = flow("model")
|
||||
vectors = await self.on_embeddings(request.text, model=model)
|
||||
vectors = await self.on_embeddings(request.texts, model=model)
|
||||
|
||||
await flow("response").send(
|
||||
EmbeddingsResponse(
|
||||
|
|
@ -94,7 +94,7 @@ class EmbeddingsService(FlowProcessor):
|
|||
type = "embeddings-error",
|
||||
message = str(e),
|
||||
),
|
||||
vectors=None,
|
||||
vectors=[],
|
||||
),
|
||||
properties={"id": id}
|
||||
)
|
||||
|
|
|
|||
|
|
@ -5,15 +5,15 @@ from .base import MessageTranslator
|
|||
|
||||
class EmbeddingsRequestTranslator(MessageTranslator):
|
||||
"""Translator for EmbeddingsRequest schema objects"""
|
||||
|
||||
|
||||
def to_pulsar(self, data: Dict[str, Any]) -> EmbeddingsRequest:
|
||||
return EmbeddingsRequest(
|
||||
text=data["text"]
|
||||
texts=data["texts"]
|
||||
)
|
||||
|
||||
|
||||
def from_pulsar(self, obj: EmbeddingsRequest) -> Dict[str, Any]:
|
||||
return {
|
||||
"text": obj.text
|
||||
"texts": obj.texts
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -29,12 +29,12 @@ class TextCompletionResponse:
|
|||
|
||||
@dataclass
|
||||
class EmbeddingsRequest:
|
||||
text: str = ""
|
||||
texts: list[str] = field(default_factory=list)
|
||||
|
||||
@dataclass
|
||||
class EmbeddingsResponse:
|
||||
error: Error | None = None
|
||||
vectors: list[list[float]] = field(default_factory=list)
|
||||
vectors: list[list[list[float]]] = field(default_factory=list)
|
||||
|
||||
############################################################################
|
||||
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue