Add session-based auth — replace Basic Auth with httpOnly cookie, 30-day expiry
This commit is contained in:
+73
-14
@@ -9,7 +9,10 @@ from dotenv import load_dotenv
|
|||||||
import chromadb
|
import chromadb
|
||||||
from sentence_transformers import SentenceTransformer
|
from sentence_transformers import SentenceTransformer
|
||||||
import anthropic
|
import anthropic
|
||||||
from fastapi import FastAPI, Request
|
from fastapi import FastAPI, Request, Response, Depends, HTTPException
|
||||||
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
|
import secrets
|
||||||
|
import hashlib
|
||||||
from fastapi.responses import FileResponse, JSONResponse
|
from fastapi.responses import FileResponse, JSONResponse
|
||||||
from fastapi.staticfiles import StaticFiles
|
from fastapi.staticfiles import StaticFiles
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
@@ -76,6 +79,26 @@ the memory file as ground truth for his present situation. Use web
|
|||||||
search automatically when current information is needed. Never
|
search automatically when current information is needed. Never
|
||||||
re-brief on context that's already in memory or documents."""
|
re-brief on context that's already in memory or documents."""
|
||||||
|
|
||||||
|
# Auth configuration
|
||||||
|
import os
|
||||||
|
SESSION_PASSWORD = os.getenv("AARON_AI_PASSWORD", "changeme")
|
||||||
|
SESSIONS: set = set() # In-memory session store
|
||||||
|
|
||||||
|
def make_session_token() -> str:
|
||||||
|
return secrets.token_urlsafe(32)
|
||||||
|
|
||||||
|
def hash_password(password: str) -> str:
|
||||||
|
return hashlib.sha256(password.encode()).hexdigest()
|
||||||
|
|
||||||
|
def get_session(request: Request) -> str | None:
|
||||||
|
return request.cookies.get("aaronai_session")
|
||||||
|
|
||||||
|
def require_auth(request: Request):
|
||||||
|
token = get_session(request)
|
||||||
|
if not token or token not in SESSIONS:
|
||||||
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
|
return token
|
||||||
|
|
||||||
CV_SOURCES = ["Aaron Nelson CV 2024.pdf", "Aaron Nelson CV 2025.pdf", "Aaron Nelson - CV.docx"]
|
CV_SOURCES = ["Aaron Nelson CV 2024.pdf", "Aaron Nelson CV 2025.pdf", "Aaron Nelson - CV.docx"]
|
||||||
|
|
||||||
def init_conversations_db():
|
def init_conversations_db():
|
||||||
@@ -271,16 +294,52 @@ def chat(user_message, conversation_id, settings):
|
|||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
app.add_middleware(CORSMiddleware, allow_origins=["*"], allow_methods=["*"], allow_headers=["*"])
|
||||||
|
|
||||||
|
@app.post("/auth/login")
|
||||||
|
async def login(request: Request, response: Response):
|
||||||
|
data = await request.json()
|
||||||
|
password = data.get("password", "")
|
||||||
|
if hash_password(password) != hash_password(SESSION_PASSWORD):
|
||||||
|
raise HTTPException(status_code=401, detail="Invalid password")
|
||||||
|
token = make_session_token()
|
||||||
|
SESSIONS.add(token)
|
||||||
|
response.set_cookie(
|
||||||
|
key="aaronai_session",
|
||||||
|
value=token,
|
||||||
|
httponly=True,
|
||||||
|
secure=True,
|
||||||
|
samesite="lax",
|
||||||
|
max_age=60 * 60 * 24 * 30
|
||||||
|
)
|
||||||
|
response.body = b'{"ok": true}'
|
||||||
|
response.status_code = 200
|
||||||
|
response.media_type = "application/json"
|
||||||
|
return response
|
||||||
|
|
||||||
|
@app.post("/auth/logout")
|
||||||
|
async def logout(request: Request, response: Response):
|
||||||
|
token = get_session(request)
|
||||||
|
if token:
|
||||||
|
SESSIONS.discard(token)
|
||||||
|
response.delete_cookie("aaronai_session")
|
||||||
|
return JSONResponse({"ok": True})
|
||||||
|
|
||||||
|
@app.get("/auth/check")
|
||||||
|
async def check_auth(request: Request):
|
||||||
|
token = get_session(request)
|
||||||
|
if not token or token not in SESSIONS:
|
||||||
|
return JSONResponse({"authenticated": False})
|
||||||
|
return JSONResponse({"authenticated": True})
|
||||||
|
|
||||||
@app.get("/", response_class=FileResponse)
|
@app.get("/", response_class=FileResponse)
|
||||||
async def index():
|
async def index():
|
||||||
return FileResponse("/home/aaron/aaronai/static/index.html")
|
return FileResponse("/home/aaron/aaronai/static/index.html")
|
||||||
|
|
||||||
@app.get("/api/settings")
|
@app.get("/api/settings")
|
||||||
async def get_settings():
|
async def get_settings(auth: str = Depends(require_auth)):
|
||||||
return JSONResponse(load_settings())
|
return JSONResponse(load_settings())
|
||||||
|
|
||||||
@app.post("/api/settings")
|
@app.post("/api/settings")
|
||||||
async def update_settings(request: Request):
|
async def update_settings(request: Request, auth: str = Depends(require_auth)):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
settings = load_settings()
|
settings = load_settings()
|
||||||
settings.update(data)
|
settings.update(data)
|
||||||
@@ -288,7 +347,7 @@ async def update_settings(request: Request):
|
|||||||
return JSONResponse(settings)
|
return JSONResponse(settings)
|
||||||
|
|
||||||
@app.get("/api/conversations")
|
@app.get("/api/conversations")
|
||||||
async def list_conversations():
|
async def list_conversations(auth: str = Depends(require_auth)):
|
||||||
conn = sqlite3.connect(CONVERSATIONS_DB)
|
conn = sqlite3.connect(CONVERSATIONS_DB)
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('''SELECT id, title, created_at, updated_at, message_count
|
c.execute('''SELECT id, title, created_at, updated_at, message_count
|
||||||
@@ -301,14 +360,14 @@ async def list_conversations():
|
|||||||
} for r in rows])
|
} for r in rows])
|
||||||
|
|
||||||
@app.post("/api/conversations")
|
@app.post("/api/conversations")
|
||||||
async def new_conversation(request: Request):
|
async def new_conversation(request: Request, auth: str = Depends(require_auth)):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
title = data.get("title", "New conversation")
|
title = data.get("title", "New conversation")
|
||||||
conv_id = create_conversation(title)
|
conv_id = create_conversation(title)
|
||||||
return JSONResponse({"id": conv_id, "title": title})
|
return JSONResponse({"id": conv_id, "title": title})
|
||||||
|
|
||||||
@app.get("/api/conversations/{conv_id}/messages")
|
@app.get("/api/conversations/{conv_id}/messages")
|
||||||
async def get_messages(conv_id: str):
|
async def get_messages(conv_id: str, auth: str = Depends(require_auth)):
|
||||||
conn = sqlite3.connect(CONVERSATIONS_DB)
|
conn = sqlite3.connect(CONVERSATIONS_DB)
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute('''SELECT role, content, sources, timestamp FROM messages
|
c.execute('''SELECT role, content, sources, timestamp FROM messages
|
||||||
@@ -321,7 +380,7 @@ async def get_messages(conv_id: str):
|
|||||||
} for r in rows])
|
} for r in rows])
|
||||||
|
|
||||||
@app.patch("/api/conversations/{conv_id}")
|
@app.patch("/api/conversations/{conv_id}")
|
||||||
async def rename_conversation(conv_id: str, request: Request):
|
async def rename_conversation(conv_id: str, request: Request, auth: str = Depends(require_auth)):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
title = data.get("title", "")
|
title = data.get("title", "")
|
||||||
if not title:
|
if not title:
|
||||||
@@ -334,7 +393,7 @@ async def rename_conversation(conv_id: str, request: Request):
|
|||||||
return JSONResponse({"id": conv_id, "title": title})
|
return JSONResponse({"id": conv_id, "title": title})
|
||||||
|
|
||||||
@app.delete("/api/conversations/{conv_id}")
|
@app.delete("/api/conversations/{conv_id}")
|
||||||
async def delete_conversation(conv_id: str):
|
async def delete_conversation(conv_id: str, auth: str = Depends(require_auth)):
|
||||||
conn = sqlite3.connect(CONVERSATIONS_DB)
|
conn = sqlite3.connect(CONVERSATIONS_DB)
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute("DELETE FROM messages WHERE conversation_id = ?", (conv_id,))
|
c.execute("DELETE FROM messages WHERE conversation_id = ?", (conv_id,))
|
||||||
@@ -344,7 +403,7 @@ async def delete_conversation(conv_id: str):
|
|||||||
return JSONResponse({"deleted": conv_id})
|
return JSONResponse({"deleted": conv_id})
|
||||||
|
|
||||||
@app.post("/api/chat")
|
@app.post("/api/chat")
|
||||||
async def chat_endpoint(request: Request):
|
async def chat_endpoint(request: Request, auth: str = Depends(require_auth)):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
user_message = data.get("message", "").strip()
|
user_message = data.get("message", "").strip()
|
||||||
conversation_id = data.get("conversation_id", "")
|
conversation_id = data.get("conversation_id", "")
|
||||||
@@ -402,18 +461,18 @@ async def chat_endpoint(request: Request):
|
|||||||
})
|
})
|
||||||
|
|
||||||
@app.get("/api/memory")
|
@app.get("/api/memory")
|
||||||
async def get_memory():
|
async def get_memory(auth: str = Depends(require_auth)):
|
||||||
return JSONResponse({"content": load_memory()})
|
return JSONResponse({"content": load_memory()})
|
||||||
|
|
||||||
@app.post("/api/memory")
|
@app.post("/api/memory")
|
||||||
async def update_memory(request: Request):
|
async def update_memory(request: Request, auth: str = Depends(require_auth)):
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
content = data.get("content", "")
|
content = data.get("content", "")
|
||||||
save_memory(content)
|
save_memory(content)
|
||||||
return JSONResponse({"saved": True})
|
return JSONResponse({"saved": True})
|
||||||
|
|
||||||
@app.get("/api/status")
|
@app.get("/api/status")
|
||||||
async def get_status():
|
async def get_status(auth: str = Depends(require_auth)):
|
||||||
chunk_count = collection.count()
|
chunk_count = collection.count()
|
||||||
|
|
||||||
# Watcher status
|
# Watcher status
|
||||||
@@ -475,7 +534,7 @@ async def get_status():
|
|||||||
})
|
})
|
||||||
|
|
||||||
@app.post("/api/reindex")
|
@app.post("/api/reindex")
|
||||||
async def trigger_reindex():
|
async def trigger_reindex(auth: str = Depends(require_auth)):
|
||||||
try:
|
try:
|
||||||
subprocess.Popen([PYTHON, INGEST_SCRIPT, NEXTCLOUD_PATH])
|
subprocess.Popen([PYTHON, INGEST_SCRIPT, NEXTCLOUD_PATH])
|
||||||
return JSONResponse({"started": True, "message": "Re-indexing started in background"})
|
return JSONResponse({"started": True, "message": "Re-indexing started in background"})
|
||||||
@@ -483,7 +542,7 @@ async def trigger_reindex():
|
|||||||
return JSONResponse({"started": False, "error": str(e)})
|
return JSONResponse({"started": False, "error": str(e)})
|
||||||
|
|
||||||
@app.delete("/api/conversations")
|
@app.delete("/api/conversations")
|
||||||
async def clear_all_conversations():
|
async def clear_all_conversations(auth: str = Depends(require_auth)):
|
||||||
conn = sqlite3.connect(CONVERSATIONS_DB)
|
conn = sqlite3.connect(CONVERSATIONS_DB)
|
||||||
c = conn.cursor()
|
c = conn.cursor()
|
||||||
c.execute("DELETE FROM messages")
|
c.execute("DELETE FROM messages")
|
||||||
|
|||||||
Reference in New Issue
Block a user