134 lines
4.8 KiB
Python
134 lines
4.8 KiB
Python
"""
Aaron AI — Graphiti Sidecar Service

Wraps graphiti-core in a FastAPI service to avoid asyncio event loop conflicts.

Listens on 127.0.0.1:8001 when run directly (internal only). No OpenAI
dependency unless LLM_PROVIDER=openai is selected.
"""
|
|
|
|
import os, logging, sys
|
|
from contextlib import asynccontextmanager
|
|
from datetime import datetime
|
|
from pathlib import Path
|
|
|
|
from dotenv import load_dotenv
|
|
from fastapi import FastAPI, HTTPException
|
|
from pydantic import BaseModel
|
|
|
|
# Pull local settings (API keys, FalkorDB location, provider choice) from the
# project's .env file *before* any of the environment lookups below run.
load_dotenv(Path.home() / "aaronai" / ".env")

logging.basicConfig(
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    level=logging.INFO,
)
log = logging.getLogger("graphiti-sidecar")

# --- Runtime configuration: environment variables with local defaults --------
# Default group_id used when a request does not supply one.
GROUP_ID = os.environ.get("GRAPHITI_GROUP_ID", "aaron")
FALKORDB_HOST = os.environ.get("FALKORDB_HOST", "localhost")
FALKORDB_PORT = int(os.environ.get("FALKORDB_PORT", "6379"))
# Selects the client built by get_llm_client(): anthropic, openai, gemini or groq.
LLM_PROVIDER = os.environ.get("LLM_PROVIDER", "anthropic")
LLM_MODEL = os.environ.get("LLM_MODEL", "claude-sonnet-4-6")
# A provider-neutral LLM_API_KEY takes precedence; ANTHROPIC_API_KEY is the fallback.
LLM_API_KEY = os.environ.get("LLM_API_KEY") or os.environ.get("ANTHROPIC_API_KEY")

# Pin the embedding dimension in the environment. All graphiti_core imports in
# this module are deferred into function bodies, so this runs before them.
# NOTE(review): presumably graphiti_core reads EMBEDDING_DIM at import time and
# 384 matches SentenceTransformerEmbedder's output size — confirm.
os.environ.update(EMBEDDING_DIM="384")
|
|
|
|
def get_llm_client():
|
|
from graphiti_core.llm_client.config import LLMConfig
|
|
config = LLMConfig(api_key=LLM_API_KEY, model=LLM_MODEL)
|
|
if LLM_PROVIDER == "anthropic":
|
|
from graphiti_core.llm_client.anthropic_client import AnthropicClient
|
|
return AnthropicClient(config)
|
|
elif LLM_PROVIDER == "openai":
|
|
from graphiti_core.llm_client.openai_client import OpenAIClient
|
|
return OpenAIClient(config)
|
|
elif LLM_PROVIDER == "gemini":
|
|
from graphiti_core.llm_client.gemini_client import GeminiClient
|
|
return GeminiClient(config)
|
|
elif LLM_PROVIDER == "groq":
|
|
from graphiti_core.llm_client.groq_client import GroqClient
|
|
return GroqClient(config)
|
|
raise ValueError(f"Unsupported LLM provider: {LLM_PROVIDER}")
|
|
|
|
graphiti_instance = None
|
|
|
|
async def get_graphiti():
    """Return the shared Graphiti instance or fail the request with HTTP 503.

    The instance lives in the module-level ``graphiti_instance``, which the
    lifespan handler assigns at startup.

    Raises:
        HTTPException: status 503 while no instance is available.
    """
    if graphiti_instance is not None:
        return graphiti_instance
    raise HTTPException(status_code=503, detail="Graphiti not initialized")
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """FastAPI lifespan handler: build Graphiti on startup, close it on shutdown.

    Startup loads the embedding and reranker models, connects to FalkorDB and
    creates indices/constraints. Only then is the instance published through the
    module-level ``graphiti_instance`` read by ``get_graphiti``.

    Once the Graphiti object exists, every exit path clears the global and awaits
    ``close()``. That covers normal shutdown, an index-creation failure and an
    exception thrown into the generator, so the driver is never leaked and a
    closed instance is never handed out.
    """
    global graphiti_instance

    # st_embedder lives in the project's scripts directory; add it to the import
    # path once (lifespan may run more than once in one process, e.g. in tests).
    scripts_dir = str(Path.home() / "aaronai" / "scripts")
    if scripts_dir not in sys.path:
        sys.path.insert(0, scripts_dir)

    log.info("Loading embedding and reranker models...")
    # Deferred imports: graphiti_core must only be imported after EMBEDDING_DIM
    # has been placed in the environment at module load.
    from st_embedder import SentenceTransformerEmbedder
    from graphiti_core.cross_encoder.bge_reranker_client import BGERerankerClient
    from graphiti_core.driver.falkordb_driver import FalkorDriver
    from graphiti_core import Graphiti

    log.info("Connecting to FalkorDB at %s:%s...", FALKORDB_HOST, FALKORDB_PORT)
    instance = Graphiti(
        llm_client=get_llm_client(),
        embedder=SentenceTransformerEmbedder(),
        cross_encoder=BGERerankerClient(),
        graph_driver=FalkorDriver(host=FALKORDB_HOST, port=FALKORDB_PORT),
    )
    try:
        await instance.build_indices_and_constraints()
        graphiti_instance = instance
        log.info("Graphiti ready — provider: %s, group: %s", LLM_PROVIDER, GROUP_ID)
        yield
    finally:
        # Unpublish first so no handler can grab an instance that is closing.
        graphiti_instance = None
        await instance.close()
|
|
|
|
# The ASGI application; `lifespan` owns Graphiti startup and shutdown.
app = FastAPI(lifespan=lifespan, title="Aaron AI Graphiti Sidecar")
|
|
|
|
# JSON body accepted by POST /episodes.
class EpisodeRequest(BaseModel):
    # Episode name forwarded to Graphiti.add_episode.
    name: str
    # Text stored as the episode body.
    content: str
    # Free-text provenance label; empty by default.
    source_description: str = ""
    # ISO-8601 reference time string; the handler substitutes "now" when absent/empty.
    timestamp: str | None = None
    # Target group; the handler falls back to GROUP_ID when absent/empty.
    group_id: str | None = None
|
|
|
|
@app.get("/health")
async def health():
    """Lightweight health check reporting the configured provider and default group.

    Does not touch the Graphiti instance, so it answers even before startup
    finishes wiring Graphiti.
    """
    return dict(ok=True, provider=LLM_PROVIDER, group=GROUP_ID)
|
|
|
|
def _parse_reference_time(timestamp: str | None) -> datetime:
    """Convert the optional request timestamp into an episode reference time.

    A missing or empty value yields the current time as a timezone-aware UTC
    datetime. Otherwise the value is parsed with ``datetime.fromisoformat``.
    A trailing ``Z`` is rewritten to ``+00:00`` first, because ``fromisoformat``
    only accepts ``Z`` from Python 3.11 on.

    Raises:
        HTTPException: status 422 when the value is not a valid ISO-8601 string.
    """
    from datetime import timezone

    if not timestamp:
        # Aware "now", as in graphiti-core's own examples. A naive local now() is
        # ambiguous; NOTE(review): graphiti appears to treat naive datetimes as
        # UTC, which would skew local times by the server's UTC offset.
        return datetime.now(timezone.utc)

    text = timestamp[:-1] + "+00:00" if timestamp.endswith("Z") else timestamp
    try:
        return datetime.fromisoformat(text)
    except ValueError as e:
        raise HTTPException(status_code=422, detail=f"Invalid timestamp: {timestamp!r}") from e


@app.post("/episodes")
async def add_episode(req: EpisodeRequest):
    """Ingest one text episode into the knowledge graph.

    The reference time comes from ``req.timestamp`` (ISO-8601), or the current
    UTC time when it is omitted. The group is ``req.group_id``, falling back to
    ``GROUP_ID``.

    Returns:
        ``{"ok": True}`` once Graphiti has accepted the episode.

    Raises:
        HTTPException: 503 if Graphiti is not initialized, 422 for a malformed
            timestamp, 500 if ingestion fails.
    """
    g = await get_graphiti()
    from graphiti_core.nodes import EpisodeType

    # Parse before the ingestion try-block so bad client input is reported as a
    # client error rather than being swallowed into a generic 500.
    ref_time = _parse_reference_time(req.timestamp)

    try:
        await g.add_episode(
            name=req.name,
            episode_body=req.content,
            source=EpisodeType.text,
            reference_time=ref_time,
            source_description=req.source_description,
            group_id=req.group_id or GROUP_ID,
        )
    except Exception as e:
        # Service boundary: log with traceback, surface the message to the caller.
        log.exception("Episode ingestion failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
    return {"ok": True}
|
|
|
|
def _format_edge_time(value) -> str:
    """Render an optional edge timestamp as text, mapping ``None`` to ``""``.

    A plain ``str(None)`` would send the literal string ``"None"`` to clients
    for edges that have no validity bound.
    """
    return "" if value is None else str(value)


def _edge_to_result(edge) -> dict:
    """Map one object returned by ``Graphiti.search`` to the /search JSON shape.

    Optional attributes are read with ``getattr`` defaults so results lacking
    them still serialize.
    """
    return {
        "fact": edge.fact,
        "source": getattr(edge, "source_node_uuid", ""),
        "score": getattr(edge, "score", 0),
        "valid_at": _format_edge_time(getattr(edge, "valid_at", None)),
        "invalid_at": _format_edge_time(getattr(edge, "invalid_at", None)),
    }


@app.get("/search")
async def search(query: str, limit: int = 8, group_id: str | None = None):
    """Search the knowledge graph and return matching facts.

    Args:
        query: Free-text search query.
        limit: Requested result count, forwarded to Graphiti as ``num_results``.
        group_id: Group to search; falls back to ``GROUP_ID``.

    Returns:
        ``{"results": [...]}``. Each entry has ``fact``, ``source``, ``score``,
        ``valid_at`` and ``invalid_at``; timestamps are text, ``""`` when unset.

    Raises:
        HTTPException: 503 if Graphiti is not initialized, 500 if the search
            or result serialization fails.
    """
    g = await get_graphiti()
    try:
        results = await g.search(
            query=query,
            num_results=limit,
            group_ids=[group_id or GROUP_ID],
        )
        return {"results": [_edge_to_result(r) for r in results]}
    except Exception as e:
        # Service boundary: log with traceback, surface the message to the caller.
        log.exception("Search failed: %s", e)
        raise HTTPException(status_code=500, detail=str(e)) from e
|
|
|
|
if __name__ == "__main__":
    # Direct launch: serve on loopback only so the sidecar is not exposed externally.
    import uvicorn

    uvicorn.run(app, port=8001, host="127.0.0.1", log_level="info")
|