diff --git a/scripts/ingest_conversations.py b/scripts/ingest_conversations.py index 3ad7100..694bd57 100644 --- a/scripts/ingest_conversations.py +++ b/scripts/ingest_conversations.py @@ -18,8 +18,14 @@ CONVERSATIONS_DB = str(Path.home() / "aaronai" / "conversations.db") PG_DSN = os.getenv("PG_DSN") MIN_EXCHANGES = 3 -print("Loading embedding model...") -embedder = SentenceTransformer("all-MiniLM-L6-v2") +_embedder = None + +def get_embedder(): + global _embedder + if _embedder is None: + print("Loading embedding model...") + _embedder = SentenceTransformer("all-MiniLM-L6-v2") + return _embedder def get_conversations(): conn = sqlite3.connect(CONVERSATIONS_DB) @@ -123,7 +129,7 @@ def run(): # Embed and insert texts = [c[1] for c in new_chunks] - embeddings = embedder.encode(texts, show_progress_bar=False).tolist() + embeddings = get_embedder().encode(texts, show_progress_bar=False).tolist() for (chunk_id, chunk_text, meta), embedding in zip(new_chunks, embeddings): if not meta.get("type"):