summaryrefslogtreecommitdiff
path: root/backend/app/db/watchlist.py
diff options
context:
space:
mode:
Diffstat (limited to 'backend/app/db/watchlist.py')
-rw-r--r--backend/app/db/watchlist.py104
1 files changed, 104 insertions, 0 deletions
diff --git a/backend/app/db/watchlist.py b/backend/app/db/watchlist.py
new file mode 100644
index 0000000..238cf35
--- /dev/null
+++ b/backend/app/db/watchlist.py
@@ -0,0 +1,104 @@
+"""SQLite watchlist persistence for one local default profile."""
+from __future__ import annotations
+
+import sqlite3
+from datetime import datetime, timezone
+from pathlib import Path
+
+DEFAULT_DB_PATH = Path(__file__).resolve().parents[2] / "data" / "prism.db"
+WATCHLIST_LIMIT = 10
+
+
+class WatchlistFullError(ValueError):
+ """Raised when the local watchlist already has the maximum number of symbols."""
+
+
+def normalize_symbol(symbol: str) -> str:
+ cleaned = str(symbol or "").strip().upper()
+ if not cleaned:
+ raise ValueError("symbol is required")
+ if len(cleaned) > 16:
+ raise ValueError("symbol is too long")
+ return cleaned
+
+
+def connect(db_path: Path | str = DEFAULT_DB_PATH) -> sqlite3.Connection:
+ path = Path(db_path)
+ path.parent.mkdir(parents=True, exist_ok=True)
+ conn = sqlite3.connect(path)
+ conn.row_factory = sqlite3.Row
+ return conn
+
+
+def init_db(db_path: Path | str = DEFAULT_DB_PATH) -> None:
+ with connect(db_path) as conn:
+ conn.execute(
+ """
+ CREATE TABLE IF NOT EXISTS profiles (
+ id INTEGER PRIMARY KEY,
+ name TEXT UNIQUE NOT NULL
+ )
+ """
+ )
+ conn.execute(
+ """
+ CREATE TABLE IF NOT EXISTS watchlist (
+ profile_id INTEGER NOT NULL,
+ symbol TEXT NOT NULL,
+ created_at TEXT NOT NULL,
+ UNIQUE(profile_id, symbol),
+ FOREIGN KEY(profile_id) REFERENCES profiles(id)
+ )
+ """
+ )
+ conn.execute("INSERT OR IGNORE INTO profiles(name) VALUES (?)", ("default",))
+
+
+def get_default_profile_id(conn: sqlite3.Connection) -> int:
+ row = conn.execute("SELECT id FROM profiles WHERE name = ?", ("default",)).fetchone()
+ if row is None:
+ conn.execute("INSERT INTO profiles(name) VALUES (?)", ("default",))
+ row = conn.execute("SELECT id FROM profiles WHERE name = ?", ("default",)).fetchone()
+ return int(row["id"])
+
+
+def list_symbols(db_path: Path | str = DEFAULT_DB_PATH) -> list[dict[str, str]]:
+ init_db(db_path)
+ with connect(db_path) as conn:
+ profile_id = get_default_profile_id(conn)
+ rows = conn.execute(
+ "SELECT symbol, created_at FROM watchlist WHERE profile_id = ? ORDER BY created_at ASC",
+ (profile_id,),
+ ).fetchall()
+ return [{"symbol": row["symbol"], "created_at": row["created_at"]} for row in rows]
+
+
+def add_symbol(symbol: str, db_path: Path | str = DEFAULT_DB_PATH) -> dict[str, str]:
+ sym = normalize_symbol(symbol)
+ init_db(db_path)
+ with connect(db_path) as conn:
+ profile_id = get_default_profile_id(conn)
+ existing = conn.execute(
+ "SELECT symbol, created_at FROM watchlist WHERE profile_id = ? AND symbol = ?",
+ (profile_id, sym),
+ ).fetchone()
+ if existing:
+ return {"symbol": existing["symbol"], "created_at": existing["created_at"]}
+ count = conn.execute("SELECT COUNT(*) AS c FROM watchlist WHERE profile_id = ?", (profile_id,)).fetchone()["c"]
+ if int(count) >= WATCHLIST_LIMIT:
+ raise WatchlistFullError("watchlist limit reached")
+ created_at = datetime.now(timezone.utc).isoformat()
+ conn.execute(
+ "INSERT INTO watchlist(profile_id, symbol, created_at) VALUES (?, ?, ?)",
+ (profile_id, sym, created_at),
+ )
+ return {"symbol": sym, "created_at": created_at}
+
+
+def remove_symbol(symbol: str, db_path: Path | str = DEFAULT_DB_PATH) -> bool:
+ sym = normalize_symbol(symbol)
+ init_db(db_path)
+ with connect(db_path) as conn:
+ profile_id = get_default_profile_id(conn)
+ cur = conn.execute("DELETE FROM watchlist WHERE profile_id = ? AND symbol = ?", (profile_id, sym))
+ return cur.rowcount > 0