diff --git a/scripts/api.py b/scripts/api.py index a903fc4..abdba8a 100644 --- a/scripts/api.py +++ b/scripts/api.py @@ -9,7 +9,10 @@ from dotenv import load_dotenv import chromadb from sentence_transformers import SentenceTransformer import anthropic -from fastapi import FastAPI, Request +from fastapi import FastAPI, Request, Response, Depends, HTTPException +from fastapi.responses import FileResponse, JSONResponse +import secrets +import hashlib from fastapi.responses import FileResponse, JSONResponse from fastapi.staticfiles import StaticFiles from fastapi.middleware.cors import CORSMiddleware @@ -76,6 +79,26 @@ the memory file as ground truth for his present situation. Use web search automatically when current information is needed. Never re-brief on context that's already in memory or documents.""" +# Auth configuration +import os +SESSION_PASSWORD = os.getenv("AARON_AI_PASSWORD", "changeme") +SESSIONS: set = set() # In-memory session store + +def make_session_token() -> str: + return secrets.token_urlsafe(32) + +def hash_password(password: str) -> str: + return hashlib.sha256(password.encode()).hexdigest() + +def get_session(request: Request) -> str | None: + return request.cookies.get("aaronai_session") + +def require_auth(request: Request): + token = get_session(request) + if not token or token not in SESSIONS: + raise HTTPException(status_code=401, detail="Not authenticated") + return token + CV_SOURCES = ["Aaron Nelson CV 2024.pdf", "Aaron Nelson CV 2025.pdf", "Aaron Nelson - CV.docx"] def init_conversations_db(): @@ -271,16 +294,52 @@ def chat(user_message, conversation_id, settings): app = FastAPI() app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"]) +@app.post("/auth/login") +async def login(request: Request, response: Response): + data = await request.json() + password = data.get("password", "") + if hash_password(password) != hash_password(SESSION_PASSWORD): + raise HTTPException(status_code=401, detail="Invalid password") + token = make_session_token() + SESSIONS.add(token) + response.set_cookie( + key="aaronai_session", + value=token, + httponly=True, + secure=True, + samesite="lax", + max_age=60 * 60 * 24 * 30 + ) + response.body = b'{"ok": true}' + response.status_code = 200 + response.media_type = "application/json" + return response + +@app.post("/auth/logout") +async def logout(request: Request, response: Response): + token = get_session(request) + if token: + SESSIONS.discard(token) + response.delete_cookie("aaronai_session") + return JSONResponse({"ok": True}) + +@app.get("/auth/check") +async def check_auth(request: Request): + token = get_session(request) + if not token or token not in SESSIONS: + return JSONResponse({"authenticated": False}) + return JSONResponse({"authenticated": True}) + @app.get("/", response_class=FileResponse) async def index(): return FileResponse("/home/aaron/aaronai/static/index.html") @app.get("/api/settings") -async def get_settings(): +async def get_settings(auth: str = Depends(require_auth)): return JSONResponse(load_settings()) @app.post("/api/settings") -async def update_settings(request: Request): +async def update_settings(request: Request, auth: str = Depends(require_auth)): data = await request.json() settings = load_settings() settings.update(data) @@ -288,7 +347,7 @@ async def update_settings(request: Request): return JSONResponse(settings) @app.get("/api/conversations") -async def list_conversations(): +async def list_conversations(auth: str = Depends(require_auth)): conn = sqlite3.connect(CONVERSATIONS_DB) c = conn.cursor() c.execute('''SELECT id, title, created_at, updated_at, message_count @@ -301,14 +360,14 @@ async def list_conversations(): } for r in rows]) @app.post("/api/conversations") -async def new_conversation(request: Request): +async def new_conversation(request: Request, auth: str = Depends(require_auth)): data = await request.json() title = data.get("title", "New conversation") conv_id = create_conversation(title) return JSONResponse({"id": conv_id, "title": title}) @app.get("/api/conversations/{conv_id}/messages") -async def get_messages(conv_id: str): +async def get_messages(conv_id: str, auth: str = Depends(require_auth)): conn = sqlite3.connect(CONVERSATIONS_DB) c = conn.cursor() c.execute('''SELECT role, content, sources, timestamp FROM messages @@ -321,7 +380,7 @@ async def get_messages(conv_id: str): } for r in rows]) @app.patch("/api/conversations/{conv_id}") -async def rename_conversation(conv_id: str, request: Request): +async def rename_conversation(conv_id: str, request: Request, auth: str = Depends(require_auth)): data = await request.json() title = data.get("title", "") if not title: @@ -334,7 +393,7 @@ async def rename_conversation(conv_id: str, request: Request): return JSONResponse({"id": conv_id, "title": title}) @app.delete("/api/conversations/{conv_id}") -async def delete_conversation(conv_id: str): +async def delete_conversation(conv_id: str, auth: str = Depends(require_auth)): conn = sqlite3.connect(CONVERSATIONS_DB) c = conn.cursor() c.execute("DELETE FROM messages WHERE conversation_id = ?", (conv_id,)) @@ -344,7 +403,7 @@ async def delete_conversation(conv_id: str): return JSONResponse({"deleted": conv_id}) @app.post("/api/chat") -async def chat_endpoint(request: Request): +async def chat_endpoint(request: Request, auth: str = Depends(require_auth)): data = await request.json() user_message = data.get("message", "").strip() conversation_id = data.get("conversation_id", "") @@ -402,18 +461,18 @@ async def chat_endpoint(request: Request): }) @app.get("/api/memory") -async def get_memory(): +async def get_memory(auth: str = Depends(require_auth)): return JSONResponse({"content": load_memory()}) @app.post("/api/memory") -async def update_memory(request: Request): +async def update_memory(request: Request, auth: str = Depends(require_auth)): data = await request.json() content = data.get("content", "") save_memory(content) return JSONResponse({"saved": True}) @app.get("/api/status") -async def get_status(): +async def get_status(auth: str = Depends(require_auth)): chunk_count = collection.count() # Watcher status @@ -475,7 +534,7 @@ async def get_status(): }) @app.post("/api/reindex") -async def trigger_reindex(): +async def trigger_reindex(auth: str = Depends(require_auth)): try: subprocess.Popen([PYTHON, INGEST_SCRIPT, NEXTCLOUD_PATH]) return JSONResponse({"started": True, "message": "Re-indexing started in background"}) @@ -483,7 +542,7 @@ async def trigger_reindex(): return JSONResponse({"started": False, "error": str(e)}) @app.delete("/api/conversations") -async def clear_all_conversations(): +async def clear_all_conversations(auth: str = Depends(require_auth)): conn = sqlite3.connect(CONVERSATIONS_DB) c = conn.cursor() c.execute("DELETE FROM messages")