Add session-based auth — replace Basic Auth with httpOnly cookie, 30-day expiry
This commit is contained in:
+73
-14
@@ -9,7 +9,10 @@ from dotenv import load_dotenv
|
||||
import chromadb
|
||||
from sentence_transformers import SentenceTransformer
|
||||
import anthropic
|
||||
from fastapi import FastAPI, Request
|
||||
from fastapi import FastAPI, Request, Response, Depends, HTTPException
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
import secrets
|
||||
import hashlib
|
||||
from fastapi.responses import FileResponse, JSONResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
@@ -76,6 +79,26 @@ the memory file as ground truth for his present situation. Use web
|
||||
search automatically when current information is needed. Never
|
||||
re-brief on context that's already in memory or documents."""
|
||||
|
||||
# Auth configuration
|
||||
import os
|
||||
SESSION_PASSWORD = os.getenv("AARON_AI_PASSWORD", "changeme")
|
||||
SESSIONS: set = set() # In-memory session store
|
||||
|
||||
def make_session_token() -> str:
|
||||
return secrets.token_urlsafe(32)
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
return hashlib.sha256(password.encode()).hexdigest()
|
||||
|
||||
def get_session(request: Request) -> str | None:
|
||||
return request.cookies.get("aaronai_session")
|
||||
|
||||
def require_auth(request: Request):
|
||||
token = get_session(request)
|
||||
if not token or token not in SESSIONS:
|
||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||
return token
|
||||
|
||||
CV_SOURCES = ["Aaron Nelson CV 2024.pdf", "Aaron Nelson CV 2025.pdf", "Aaron Nelson - CV.docx"]
|
||||
|
||||
def init_conversations_db():
|
||||
@@ -271,16 +294,52 @@ def chat(user_message, conversation_id, settings):
|
||||
app = FastAPI()
|
||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||
|
||||
@app.post("/auth/login")
|
||||
async def login(request: Request, response: Response):
|
||||
data = await request.json()
|
||||
password = data.get("password", "")
|
||||
if hash_password(password) != hash_password(SESSION_PASSWORD):
|
||||
raise HTTPException(status_code=401, detail="Invalid password")
|
||||
token = make_session_token()
|
||||
SESSIONS.add(token)
|
||||
response.set_cookie(
|
||||
key="aaronai_session",
|
||||
value=token,
|
||||
httponly=True,
|
||||
secure=True,
|
||||
samesite="lax",
|
||||
max_age=60 * 60 * 24 * 30
|
||||
)
|
||||
response.body = b'{"ok": true}'
|
||||
response.status_code = 200
|
||||
response.media_type = "application/json"
|
||||
return response
|
||||
|
||||
@app.post("/auth/logout")
|
||||
async def logout(request: Request, response: Response):
|
||||
token = get_session(request)
|
||||
if token:
|
||||
SESSIONS.discard(token)
|
||||
response.delete_cookie("aaronai_session")
|
||||
return JSONResponse({"ok": True})
|
||||
|
||||
@app.get("/auth/check")
|
||||
async def check_auth(request: Request):
|
||||
token = get_session(request)
|
||||
if not token or token not in SESSIONS:
|
||||
return JSONResponse({"authenticated": False})
|
||||
return JSONResponse({"authenticated": True})
|
||||
|
||||
@app.get("/", response_class=FileResponse)
|
||||
async def index():
|
||||
return FileResponse("/home/aaron/aaronai/static/index.html")
|
||||
|
||||
@app.get("/api/settings")
|
||||
async def get_settings():
|
||||
async def get_settings(auth: str = Depends(require_auth)):
|
||||
return JSONResponse(load_settings())
|
||||
|
||||
@app.post("/api/settings")
|
||||
async def update_settings(request: Request):
|
||||
async def update_settings(request: Request, auth: str = Depends(require_auth)):
|
||||
data = await request.json()
|
||||
settings = load_settings()
|
||||
settings.update(data)
|
||||
@@ -288,7 +347,7 @@ async def update_settings(request: Request):
|
||||
return JSONResponse(settings)
|
||||
|
||||
@app.get("/api/conversations")
|
||||
async def list_conversations():
|
||||
async def list_conversations(auth: str = Depends(require_auth)):
|
||||
conn = sqlite3.connect(CONVERSATIONS_DB)
|
||||
c = conn.cursor()
|
||||
c.execute('''SELECT id, title, created_at, updated_at, message_count
|
||||
@@ -301,14 +360,14 @@ async def list_conversations():
|
||||
} for r in rows])
|
||||
|
||||
@app.post("/api/conversations")
|
||||
async def new_conversation(request: Request):
|
||||
async def new_conversation(request: Request, auth: str = Depends(require_auth)):
|
||||
data = await request.json()
|
||||
title = data.get("title", "New conversation")
|
||||
conv_id = create_conversation(title)
|
||||
return JSONResponse({"id": conv_id, "title": title})
|
||||
|
||||
@app.get("/api/conversations/{conv_id}/messages")
|
||||
async def get_messages(conv_id: str):
|
||||
async def get_messages(conv_id: str, auth: str = Depends(require_auth)):
|
||||
conn = sqlite3.connect(CONVERSATIONS_DB)
|
||||
c = conn.cursor()
|
||||
c.execute('''SELECT role, content, sources, timestamp FROM messages
|
||||
@@ -321,7 +380,7 @@ async def get_messages(conv_id: str):
|
||||
} for r in rows])
|
||||
|
||||
@app.patch("/api/conversations/{conv_id}")
|
||||
async def rename_conversation(conv_id: str, request: Request):
|
||||
async def rename_conversation(conv_id: str, request: Request, auth: str = Depends(require_auth)):
|
||||
data = await request.json()
|
||||
title = data.get("title", "")
|
||||
if not title:
|
||||
@@ -334,7 +393,7 @@ async def rename_conversation(conv_id: str, request: Request):
|
||||
return JSONResponse({"id": conv_id, "title": title})
|
||||
|
||||
@app.delete("/api/conversations/{conv_id}")
|
||||
async def delete_conversation(conv_id: str):
|
||||
async def delete_conversation(conv_id: str, auth: str = Depends(require_auth)):
|
||||
conn = sqlite3.connect(CONVERSATIONS_DB)
|
||||
c = conn.cursor()
|
||||
c.execute("DELETE FROM messages WHERE conversation_id = ?", (conv_id,))
|
||||
@@ -344,7 +403,7 @@ async def delete_conversation(conv_id: str):
|
||||
return JSONResponse({"deleted": conv_id})
|
||||
|
||||
@app.post("/api/chat")
|
||||
async def chat_endpoint(request: Request):
|
||||
async def chat_endpoint(request: Request, auth: str = Depends(require_auth)):
|
||||
data = await request.json()
|
||||
user_message = data.get("message", "").strip()
|
||||
conversation_id = data.get("conversation_id", "")
|
||||
@@ -402,18 +461,18 @@ async def chat_endpoint(request: Request):
|
||||
})
|
||||
|
||||
@app.get("/api/memory")
|
||||
async def get_memory():
|
||||
async def get_memory(auth: str = Depends(require_auth)):
|
||||
return JSONResponse({"content": load_memory()})
|
||||
|
||||
@app.post("/api/memory")
|
||||
async def update_memory(request: Request):
|
||||
async def update_memory(request: Request, auth: str = Depends(require_auth)):
|
||||
data = await request.json()
|
||||
content = data.get("content", "")
|
||||
save_memory(content)
|
||||
return JSONResponse({"saved": True})
|
||||
|
||||
@app.get("/api/status")
|
||||
async def get_status():
|
||||
async def get_status(auth: str = Depends(require_auth)):
|
||||
chunk_count = collection.count()
|
||||
|
||||
# Watcher status
|
||||
@@ -475,7 +534,7 @@ async def get_status():
|
||||
})
|
||||
|
||||
@app.post("/api/reindex")
|
||||
async def trigger_reindex():
|
||||
async def trigger_reindex(auth: str = Depends(require_auth)):
|
||||
try:
|
||||
subprocess.Popen([PYTHON, INGEST_SCRIPT, NEXTCLOUD_PATH])
|
||||
return JSONResponse({"started": True, "message": "Re-indexing started in background"})
|
||||
@@ -483,7 +542,7 @@ async def trigger_reindex():
|
||||
return JSONResponse({"started": False, "error": str(e)})
|
||||
|
||||
@app.delete("/api/conversations")
|
||||
async def clear_all_conversations():
|
||||
async def clear_all_conversations(auth: str = Depends(require_auth)):
|
||||
conn = sqlite3.connect(CONVERSATIONS_DB)
|
||||
c = conn.cursor()
|
||||
c.execute("DELETE FROM messages")
|
||||
|
||||
Reference in New Issue
Block a user