diff options
Diffstat (limited to 'backend/app/db')
| -rw-r--r-- | backend/app/db/__init__.py | 1 | ||||
| -rw-r--r-- | backend/app/db/watchlist.py | 104 |
2 files changed, 105 insertions, 0 deletions
diff --git a/backend/app/db/__init__.py b/backend/app/db/__init__.py new file mode 100644 index 0000000..a6e0ed0 --- /dev/null +++ b/backend/app/db/__init__.py @@ -0,0 +1 @@ +"""SQLite persistence package.""" diff --git a/backend/app/db/watchlist.py b/backend/app/db/watchlist.py new file mode 100644 index 0000000..238cf35 --- /dev/null +++ b/backend/app/db/watchlist.py @@ -0,0 +1,104 @@ +"""SQLite watchlist persistence for one local default profile.""" +from __future__ import annotations + +import sqlite3 +from datetime import datetime, timezone +from pathlib import Path + +DEFAULT_DB_PATH = Path(__file__).resolve().parents[2] / "data" / "prism.db" +WATCHLIST_LIMIT = 10 + + +class WatchlistFullError(ValueError): + """Raised when the local watchlist already has the maximum number of symbols.""" + + +def normalize_symbol(symbol: str) -> str: + cleaned = str(symbol or "").strip().upper() + if not cleaned: + raise ValueError("symbol is required") + if len(cleaned) > 16: + raise ValueError("symbol is too long") + return cleaned + + +def connect(db_path: Path | str = DEFAULT_DB_PATH) -> sqlite3.Connection: + path = Path(db_path) + path.parent.mkdir(parents=True, exist_ok=True) + conn = sqlite3.connect(path) + conn.row_factory = sqlite3.Row + return conn + + +def init_db(db_path: Path | str = DEFAULT_DB_PATH) -> None: + with connect(db_path) as conn: + conn.execute( + """ + CREATE TABLE IF NOT EXISTS profiles ( + id INTEGER PRIMARY KEY, + name TEXT UNIQUE NOT NULL + ) + """ + ) + conn.execute( + """ + CREATE TABLE IF NOT EXISTS watchlist ( + profile_id INTEGER NOT NULL, + symbol TEXT NOT NULL, + created_at TEXT NOT NULL, + UNIQUE(profile_id, symbol), + FOREIGN KEY(profile_id) REFERENCES profiles(id) + ) + """ + ) + conn.execute("INSERT OR IGNORE INTO profiles(name) VALUES (?)", ("default",)) + + +def get_default_profile_id(conn: sqlite3.Connection) -> int: + row = conn.execute("SELECT id FROM profiles WHERE name = ?", ("default",)).fetchone() + if row is None: + conn.execute("INSERT INTO profiles(name) VALUES (?)", ("default",)) + row = conn.execute("SELECT id FROM profiles WHERE name = ?", ("default",)).fetchone() + return int(row["id"]) + + +def list_symbols(db_path: Path | str = DEFAULT_DB_PATH) -> list[dict[str, str]]: + init_db(db_path) + with connect(db_path) as conn: + profile_id = get_default_profile_id(conn) + rows = conn.execute( + "SELECT symbol, created_at FROM watchlist WHERE profile_id = ? ORDER BY created_at ASC", + (profile_id,), + ).fetchall() + return [{"symbol": row["symbol"], "created_at": row["created_at"]} for row in rows] + + +def add_symbol(symbol: str, db_path: Path | str = DEFAULT_DB_PATH) -> dict[str, str]: + sym = normalize_symbol(symbol) + init_db(db_path) + with connect(db_path) as conn: + profile_id = get_default_profile_id(conn) + existing = conn.execute( + "SELECT symbol, created_at FROM watchlist WHERE profile_id = ? AND symbol = ?", + (profile_id, sym), + ).fetchone() + if existing: + return {"symbol": existing["symbol"], "created_at": existing["created_at"]} + count = conn.execute("SELECT COUNT(*) AS c FROM watchlist WHERE profile_id = ?", (profile_id,)).fetchone()["c"] + if int(count) >= WATCHLIST_LIMIT: + raise WatchlistFullError("watchlist limit reached") + created_at = datetime.now(timezone.utc).isoformat() + conn.execute( + "INSERT INTO watchlist(profile_id, symbol, created_at) VALUES (?, ?, ?)", + (profile_id, sym, created_at), + ) + return {"symbol": sym, "created_at": created_at} + + +def remove_symbol(symbol: str, db_path: Path | str = DEFAULT_DB_PATH) -> bool: + sym = normalize_symbol(symbol) + init_db(db_path) + with connect(db_path) as conn: + profile_id = get_default_profile_id(conn) + cur = conn.execute("DELETE FROM watchlist WHERE profile_id = ? AND symbol = ?", (profile_id, sym)) + return cur.rowcount > 0 |
