summaryrefslogtreecommitdiff
path: root/backend/tests/test_watchlist.py
blob: 4d63428350c6c754be10e2b94f574fd355a5935f (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from pathlib import Path

import pytest

from app.db import watchlist


@pytest.fixture
def db_path(tmp_path: Path) -> Path:
    """Yield a fresh, per-test database file location under pytest's tmp dir.

    The file itself is not created here; each test passes this path to the
    ``watchlist`` functions, which operate on it.
    """
    database_file = tmp_path / "prism.db"
    return database_file


def test_seed_default_profile(db_path: Path) -> None:
    """``init_db`` must leave a row named 'default' in the ``profiles`` table."""
    watchlist.init_db(db_path)
    with watchlist.connect(db_path) as conn:
        row = conn.execute("SELECT name FROM profiles WHERE name = 'default'").fetchone()
    # A missing seed row yields None; fail with a clear assertion rather than a
    # TypeError from subscripting None.
    assert row is not None, "default profile was not seeded by init_db"
    assert row["name"] == "default"


def test_add_remove_and_uppercase(db_path: Path) -> None:
    """A lowercase ticker is stored upper-cased, listed once, and removable."""
    added = watchlist.add_symbol("aapl", db_path)
    assert added["symbol"] == "AAPL"
    # Compare the full listing: an empty result fails as an assertion (not an
    # IndexError), and any unexpected extra rows are caught as well.
    assert [row["symbol"] for row in watchlist.list_symbols(db_path)] == ["AAPL"]
    assert watchlist.remove_symbol("AAPL", db_path) is True
    assert watchlist.list_symbols(db_path) == []


def test_duplicate_prevention(db_path: Path) -> None:
    """Re-adding a ticker in different case returns the same entry, not a new row."""
    original_entry = watchlist.add_symbol("msft", db_path)
    repeated_entry = watchlist.add_symbol("MSFT", db_path)
    listing = watchlist.list_symbols(db_path)
    assert original_entry == repeated_entry
    stored_symbols = [entry["symbol"] for entry in listing]
    assert stored_symbols == ["MSFT"]


def test_ten_symbol_cap(db_path: Path) -> None:
    """The watchlist holds ten symbols; an eleventh is rejected and not stored."""
    cap = 10
    for idx in range(cap):
        watchlist.add_symbol(f"T{idx}", db_path)
    with pytest.raises(watchlist.WatchlistFullError):
        watchlist.add_symbol("OVER", db_path)
    # The rejected add must leave the stored list unchanged: all ten earlier
    # symbols are still present and the overflow symbol was not persisted.
    symbols = [row["symbol"] for row in watchlist.list_symbols(db_path)]
    assert len(symbols) == cap
    assert "OVER" not in symbols