30 lines
1.2 KiB
Python
30 lines
1.2 KiB
Python
"""
|
|
Custom SentenceTransformer embedder for Graphiti.
|
|
Implements the EmbedderClient interface using a local SentenceTransformer model (all-MiniLM-L6-v2 by default).
|
|
No API cost and no external service dependency; consistent with the existing ingest pipeline.
|
|
"""
|
|
import asyncio
import threading
from collections.abc import Iterable

from graphiti_core.embedder.client import EmbedderClient, EmbedderConfig
|
|
|
|
class SentenceTransformerEmbedderConfig(EmbedderConfig):
    """Settings for the local SentenceTransformer embedder.

    Attributes:
        model_name: Name (or local path) of the SentenceTransformer model to load.
        embedding_dim: Size of the vectors the model produces. The default of 384
            is paired with the default ``all-MiniLM-L6-v2`` model; change both
            together when switching models.
    """

    # Pydantic v2 (before 2.10) treats the "model_" prefix as a protected
    # namespace and emits a UserWarning for a field called ``model_name``.
    # Opting out keeps the public field name while silencing that warning.
    # A plain dict is equivalent to pydantic's ConfigDict here.
    model_config = {"protected_namespaces": ()}

    model_name: str = "all-MiniLM-L6-v2"

    embedding_dim: int = 384
|
|
|
|
class SentenceTransformerEmbedder(EmbedderClient):
    """Graphiti ``EmbedderClient`` backed by a local SentenceTransformer model.

    ``SentenceTransformer.encode`` is synchronous and compute-bound, so every
    call is run in a worker thread via ``asyncio.to_thread`` to keep the event
    loop responsive. Those calls are serialized with a lock because they all
    share the single model (and tokenizer) instance held by this object.
    """

    def __init__(self, config: SentenceTransformerEmbedderConfig | None = None):
        """Load the model named by ``config.model_name``.

        Args:
            config: Embedder settings; ``SentenceTransformerEmbedderConfig()``
                is used when omitted or ``None``.
        """
        self.config = config or SentenceTransformerEmbedderConfig()

        # Imported lazily so importing this module does not pull in
        # sentence_transformers until an embedder is actually constructed.
        from sentence_transformers import SentenceTransformer

        self._model = SentenceTransformer(self.config.model_name)
        # Guards self._model: encode() runs in worker threads (see _encode_async)
        # and concurrent calls would otherwise share one model/tokenizer.
        self._encode_lock = threading.Lock()

    def _encode(self, texts: list[str]):
        """Encode ``texts`` synchronously while holding the instance lock.

        Returns whatever ``SentenceTransformer.encode`` returns (by default a
        NumPy array with one row per input text).
        """
        with self._encode_lock:
            return self._model.encode(texts)

    async def _encode_async(self, texts: list[str]):
        """Run :meth:`_encode` in a worker thread so the event loop is not blocked."""
        return await asyncio.to_thread(self._encode, texts)

    async def create(
        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        """Embed a single text and return its vector.

        Args:
            input_data: A string, or an iterable of strings of which only the
                first is embedded (a single vector is returned). Token-id
                inputs allowed by the base-class signature are not supported
                by this local model.

        Returns:
            The embedding of the (first) text as a list of floats.

        Raises:
            ValueError: If ``input_data`` is an empty iterable.
            TypeError: If the text to embed is not a string (e.g. token ids).
        """
        if isinstance(input_data, str):
            text = input_data
        else:
            items = list(input_data)
            if not items:
                raise ValueError("create() requires at least one input text")
            # Only the first vector is returned, so encoding the rest is wasted work.
            text = items[0]

        if not isinstance(text, str):
            raise TypeError(
                "SentenceTransformerEmbedder.create() accepts a string or an "
                f"iterable of strings, got an element of type {type(text).__name__}"
            )

        embeddings = await self._encode_async([text])
        return embeddings[0].tolist()

    async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
        """Embed every text in ``input_data_list``.

        Args:
            input_data_list: Texts to embed.

        Returns:
            One list-of-floats vector per input text, in input order; an empty
            list for empty input.
        """
        if not input_data_list:
            # Nothing to encode; skip the worker-thread round trip.
            return []
        embeddings = await self._encode_async(input_data_list)
        return [vector.tolist() for vector in embeddings]
|