Persist sessions to SQLite — survive service restarts
This commit is contained in:
+30
-4
@@ -102,7 +102,15 @@ re-brief on context that's already in memory or documents."""
|
|||||||
# Auth configuration
|
# Auth configuration
|
||||||
import os
|
import os
|
||||||
SESSION_PASSWORD = os.getenv("AARON_AI_PASSWORD", "changeme")
|
SESSION_PASSWORD = os.getenv("AARON_AI_PASSWORD", "changeme")
|
||||||
SESSIONS: set = set() # In-memory session store
|
SESSIONS_DB = str(Path.home() / "aaronai" / "sessions.db")
|
||||||
|
|
||||||
|
def _init_sessions():
|
||||||
|
conn = sqlite3.connect(SESSIONS_DB)
|
||||||
|
conn.execute("CREATE TABLE IF NOT EXISTS sessions (token TEXT PRIMARY KEY, created_at TEXT)")
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
_init_sessions()
|
||||||
|
|
||||||
def make_session_token() -> str:
|
def make_session_token() -> str:
|
||||||
return secrets.token_urlsafe(32)
|
return secrets.token_urlsafe(32)
|
||||||
@@ -110,12 +118,30 @@ def make_session_token() -> str:
|
|||||||
def hash_password(password: str) -> str:
|
def hash_password(password: str) -> str:
|
||||||
return hashlib.sha256(password.encode()).hexdigest()
|
return hashlib.sha256(password.encode()).hexdigest()
|
||||||
|
|
||||||
|
def save_session(token: str):
|
||||||
|
conn = sqlite3.connect(SESSIONS_DB)
|
||||||
|
conn.execute("INSERT OR REPLACE INTO sessions VALUES (?, ?)", (token, datetime.now().isoformat()))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def delete_session(token: str):
|
||||||
|
conn = sqlite3.connect(SESSIONS_DB)
|
||||||
|
conn.execute("DELETE FROM sessions WHERE token = ?", (token,))
|
||||||
|
conn.commit()
|
||||||
|
conn.close()
|
||||||
|
|
||||||
|
def session_exists(token: str) -> bool:
|
||||||
|
conn = sqlite3.connect(SESSIONS_DB)
|
||||||
|
row = conn.execute("SELECT 1 FROM sessions WHERE token = ?", (token,)).fetchone()
|
||||||
|
conn.close()
|
||||||
|
return row is not None
|
||||||
|
|
||||||
def get_session(request: Request) -> str | None:
|
def get_session(request: Request) -> str | None:
|
||||||
return request.cookies.get("aaronai_session")
|
return request.cookies.get("aaronai_session")
|
||||||
|
|
||||||
def require_auth(request: Request):
|
def require_auth(request: Request):
|
||||||
token = get_session(request)
|
token = get_session(request)
|
||||||
if not token or token not in SESSIONS:
|
if not token or not session_exists(token):
|
||||||
raise HTTPException(status_code=401, detail="Not authenticated")
|
raise HTTPException(status_code=401, detail="Not authenticated")
|
||||||
return token
|
return token
|
||||||
|
|
||||||
@@ -321,7 +347,7 @@ async def login(request: Request, response: Response):
|
|||||||
if hash_password(password) != hash_password(SESSION_PASSWORD):
|
if hash_password(password) != hash_password(SESSION_PASSWORD):
|
||||||
raise HTTPException(status_code=401, detail="Invalid password")
|
raise HTTPException(status_code=401, detail="Invalid password")
|
||||||
token = make_session_token()
|
token = make_session_token()
|
||||||
SESSIONS.add(token)
|
save_session(token)
|
||||||
response.set_cookie(
|
response.set_cookie(
|
||||||
key="aaronai_session",
|
key="aaronai_session",
|
||||||
value=token,
|
value=token,
|
||||||
@@ -339,7 +365,7 @@ async def login(request: Request, response: Response):
|
|||||||
async def logout(request: Request, response: Response):
|
async def logout(request: Request, response: Response):
|
||||||
token = get_session(request)
|
token = get_session(request)
|
||||||
if token:
|
if token:
|
||||||
SESSIONS.discard(token)
|
delete_session(token)
|
||||||
response.delete_cookie("aaronai_session")
|
response.delete_cookie("aaronai_session")
|
||||||
return JSONResponse({"ok": True})
|
return JSONResponse({"ok": True})
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user