"""SQLite watchlist persistence for one local default profile.""" from __future__ import annotations import sqlite3 from datetime import datetime, timezone from pathlib import Path DEFAULT_DB_PATH = Path(__file__).resolve().parents[2] / "data" / "prism.db" WATCHLIST_LIMIT = 10 class WatchlistFullError(ValueError): """Raised when the local watchlist already has the maximum number of symbols.""" def normalize_symbol(symbol: str) -> str: cleaned = str(symbol or "").strip().upper() if not cleaned: raise ValueError("symbol is required") if len(cleaned) > 16: raise ValueError("symbol is too long") return cleaned def connect(db_path: Path | str = DEFAULT_DB_PATH) -> sqlite3.Connection: path = Path(db_path) path.parent.mkdir(parents=True, exist_ok=True) conn = sqlite3.connect(path) conn.row_factory = sqlite3.Row return conn def init_db(db_path: Path | str = DEFAULT_DB_PATH) -> None: with connect(db_path) as conn: conn.execute( """ CREATE TABLE IF NOT EXISTS profiles ( id INTEGER PRIMARY KEY, name TEXT UNIQUE NOT NULL ) """ ) conn.execute( """ CREATE TABLE IF NOT EXISTS watchlist ( profile_id INTEGER NOT NULL, symbol TEXT NOT NULL, created_at TEXT NOT NULL, UNIQUE(profile_id, symbol), FOREIGN KEY(profile_id) REFERENCES profiles(id) ) """ ) conn.execute("INSERT OR IGNORE INTO profiles(name) VALUES (?)", ("default",)) def get_default_profile_id(conn: sqlite3.Connection) -> int: row = conn.execute("SELECT id FROM profiles WHERE name = ?", ("default",)).fetchone() if row is None: conn.execute("INSERT INTO profiles(name) VALUES (?)", ("default",)) row = conn.execute("SELECT id FROM profiles WHERE name = ?", ("default",)).fetchone() return int(row["id"]) def list_symbols(db_path: Path | str = DEFAULT_DB_PATH) -> list[dict[str, str]]: init_db(db_path) with connect(db_path) as conn: profile_id = get_default_profile_id(conn) rows = conn.execute( "SELECT symbol, created_at FROM watchlist WHERE profile_id = ? ORDER BY created_at ASC", (profile_id,), ).fetchall() return [{"symbol": row["symbol"], "created_at": row["created_at"]} for row in rows] def add_symbol(symbol: str, db_path: Path | str = DEFAULT_DB_PATH) -> dict[str, str]: sym = normalize_symbol(symbol) init_db(db_path) with connect(db_path) as conn: profile_id = get_default_profile_id(conn) existing = conn.execute( "SELECT symbol, created_at FROM watchlist WHERE profile_id = ? AND symbol = ?", (profile_id, sym), ).fetchone() if existing: return {"symbol": existing["symbol"], "created_at": existing["created_at"]} count = conn.execute("SELECT COUNT(*) AS c FROM watchlist WHERE profile_id = ?", (profile_id,)).fetchone()["c"] if int(count) >= WATCHLIST_LIMIT: raise WatchlistFullError("watchlist limit reached") created_at = datetime.now(timezone.utc).isoformat() conn.execute( "INSERT INTO watchlist(profile_id, symbol, created_at) VALUES (?, ?, ?)", (profile_id, sym, created_at), ) return {"symbol": sym, "created_at": created_at} def remove_symbol(symbol: str, db_path: Path | str = DEFAULT_DB_PATH) -> bool: sym = normalize_symbol(symbol) init_db(db_path) with connect(db_path) as conn: profile_id = get_default_profile_id(conn) cur = conn.execute("DELETE FROM watchlist WHERE profile_id = ? AND symbol = ?", (profile_id, sym)) return cur.rowcount > 0