"""
Custom SentenceTransformer embedder for Graphiti.

Implements the EmbedderClient interface using a local all-MiniLM-L6-v2 model.
No API cost, no external service dependency, consistent with the existing
ingest pipeline.
"""
import asyncio
from collections.abc import Iterable
from concurrent.futures import ThreadPoolExecutor

from graphiti_core.embedder.client import EmbedderClient, EmbedderConfig


class SentenceTransformerEmbedderConfig(EmbedderConfig):
    """Configuration for :class:`SentenceTransformerEmbedder`."""

    # Name (or local path) handed verbatim to ``SentenceTransformer(...)``.
    model_name: str = "all-MiniLM-L6-v2"
    # Expected to match the output width of ``model_name``; the embedder does
    # not truncate or verify vectors against this value.
    embedding_dim: int = 384


class SentenceTransformerEmbedder(EmbedderClient):
    """Graphiti embedder backed by a locally loaded SentenceTransformer model."""

    def __init__(self, config: SentenceTransformerEmbedderConfig | None = None):
        """Load the model named by ``config`` (defaults are used when omitted).

        Args:
            config: Embedder settings; ``None`` selects
                ``SentenceTransformerEmbedderConfig()``.
        """
        self.config = config or SentenceTransformerEmbedderConfig()
        # Imported lazily so that importing this module does not pull in
        # sentence_transformers until an embedder is actually constructed.
        from sentence_transformers import SentenceTransformer

        self._model = SentenceTransformer(self.config.model_name)
        # ``encode`` is blocking and compute-bound. A dedicated single-worker
        # pool keeps it off the event loop while still serializing access to
        # the model, as the original inline call did.
        self._executor = ThreadPoolExecutor(
            max_workers=1, thread_name_prefix="st-embedder"
        )

    async def _encode(self, texts: list[str]):
        """Run ``self._model.encode(texts)`` on the worker thread and await it."""
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(self._executor, self._model.encode, texts)

    async def create(
        self, input_data: str | list[str] | Iterable[int] | Iterable[Iterable[int]]
    ) -> list[float]:
        """Embed text and return the vector of the first input item.

        Args:
            input_data: A single string or an iterable of strings. The
                token-id forms in the signature are kept for interface
                compatibility but are not supported by this embedder.

        Returns:
            The embedding of the first item as a list of floats.

        Raises:
            ValueError: If ``input_data`` contains no items.
            TypeError: If any item is not a string.
        """
        texts = [input_data] if isinstance(input_data, str) else list(input_data)
        if not texts:
            raise ValueError("input_data must contain at least one string to embed")
        if not all(isinstance(text, str) for text in texts):
            raise TypeError(
                "SentenceTransformerEmbedder only supports text input; "
                "pre-tokenized (integer) input is not supported"
            )
        embeddings = await self._encode(texts)
        return embeddings[0].tolist()

    async def create_batch(self, input_data_list: list[str]) -> list[list[float]]:
        """Embed every string in ``input_data_list``.

        Args:
            input_data_list: Strings to embed.

        Returns:
            One embedding (list of floats) per input, in input order; an empty
            list when ``input_data_list`` is empty.
        """
        if not input_data_list:
            # Nothing to embed; skip the model call entirely.
            return []
        embeddings = await self._encode(input_data_list)
        return [vector.tolist() for vector in embeddings]