summaryrefslogtreecommitdiff
path: root/backend/app/main.py
blob: 1cc127e5d62b364554a6f5e247ab35c457018cc6 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
"""FastAPI entrypoint for Prism v2."""
from __future__ import annotations

from contextlib import asynccontextmanager
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Query, status
from fastapi.middleware.cors import CORSMiddleware

from app.db import watchlist
from app.schemas import FinancialsResponse, HistoryPoint, MarketIndex, SearchResult, TickerOverview, ValuationResponse, WatchlistResponse
from app.services import data_service

# Populate os.environ from a .env file, if one is found.
# NOTE(review): this call runs after the ``app.*`` imports above, so any of those
# modules that read environment variables at import time will not see values
# coming from .env — confirm they only read configuration lazily.
load_dotenv()

@asynccontextmanager
async def lifespan(_: FastAPI):
    """Prepare the watchlist database before the app starts serving requests.

    Used as the application's lifespan handler. Setup happens before ``yield``;
    there is currently no shutdown work after it.
    """
    # DB_PATH is assigned further down this module; the name is resolved when
    # startup actually runs, after the whole module has been imported.
    # ``data/`` may be absent on a fresh checkout, and an SQLite-style open of
    # ``prism.db`` inside a missing directory fails. exist_ok keeps this a no-op
    # when the directory is already there.
    # NOTE(review): redundant if watchlist.init_db already creates the parent dir.
    DB_PATH.parent.mkdir(parents=True, exist_ok=True)
    watchlist.init_db(DB_PATH)
    yield


# SQLite file backing the watchlist: <backend>/data/prism.db
# (parents[0] is the ``app`` package directory, parents[1] the backend root).
DB_PATH = Path(__file__).resolve().parents[1] / "data" / "prism.db"

# Browser origins (local frontend dev servers) allowed to make credentialed
# cross-origin calls to this API.
_ALLOWED_ORIGINS = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "http://localhost:3001",
    "http://127.0.0.1:3001",
]

app = FastAPI(title="Prism v2 API", version="0.1.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=_ALLOWED_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health() -> dict[str, str]:
    """Report a fixed ``{"status": "ok"}`` payload for liveness checks."""
    payload = {"status": "ok"}
    return payload


@app.get("/api/search", response_model=list[SearchResult])
def search(q: str = Query(default="", min_length=0)) -> list[dict]:
    """Return ticker search matches for the free-text query *q*.

    *q* defaults to the empty string; matching is delegated entirely to
    ``data_service.search_tickers``.
    """
    matches = data_service.search_tickers(q)
    return matches


@app.get("/api/market/indices", response_model=list[MarketIndex])
def market_indices() -> list[dict]:
    """Return the market index snapshot supplied by the data service."""
    indices = data_service.get_market_indices()
    return indices


@app.get("/api/tickers/{symbol}/overview", response_model=TickerOverview)
def ticker_overview(symbol: str) -> dict:
    """Return the overview payload for *symbol*.

    Raises:
        HTTPException: 404 when the data service returns ``None`` for the symbol.
    """
    result = data_service.get_ticker_overview(symbol)
    if result is not None:
        return result
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="ticker data unavailable")


@app.get("/api/tickers/{symbol}/history", response_model=list[HistoryPoint])
def ticker_history(symbol: str, period: str = Query(default="1y", pattern="^(1m|3m|6m|1y|5y)$")) -> list[dict]:
    """Return price history points for *symbol*.

    *period* is restricted by the query pattern to one of 1m, 3m, 6m, 1y or 5y,
    defaulting to 1y; FastAPI rejects any other value before this body runs.
    """
    points = data_service.get_price_history(symbol, period=period)
    return points


@app.get("/api/tickers/{symbol}/financials", response_model=FinancialsResponse)
def ticker_financials(symbol: str, period: str = Query(default="annual", pattern="^(annual|quarterly)$")) -> dict:
    """Return financial statements for *symbol* at the requested reporting period.

    *period* is restricted by the query pattern to ``annual`` (default) or
    ``quarterly``.

    Raises:
        HTTPException: 404 when the data service returns ``None`` for the symbol.
    """
    financials = data_service.get_financials(symbol, period=period)
    if financials is None:
        # Match ticker_overview. Returning None here would fail validation against
        # the FinancialsResponse response_model and reach the client as a 500.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="ticker data unavailable")
    return financials


@app.get("/api/tickers/{symbol}/valuation", response_model=ValuationResponse)
def ticker_valuation(symbol: str) -> dict:
    """Return valuation data for *symbol*.

    Raises:
        HTTPException: 404 when the data service returns ``None`` for the symbol.
    """
    valuation = data_service.get_valuation(symbol)
    if valuation is None:
        # Match ticker_overview. Returning None here would fail validation against
        # the ValuationResponse response_model and reach the client as a 500.
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="ticker data unavailable")
    return valuation


@app.get("/api/watchlist", response_model=WatchlistResponse)
def get_watchlist() -> dict:
    """Return every watchlisted symbol together with a freshly built quote.

    Each stored row is copied into a new dict that gains a ``"quote"`` key.
    ``limit`` echoes ``watchlist.WATCHLIST_LIMIT``.
    """

    def _attach_quote(entry):
        # Look up company info once per symbol and build the quote from it.
        ticker = entry["symbol"]
        company = data_service.get_company_info(ticker)
        quote = data_service.build_quote(company, ticker)
        return {**entry, "quote": quote}

    rows = watchlist.list_symbols(DB_PATH)
    return {"items": [_attach_quote(entry) for entry in rows], "limit": watchlist.WATCHLIST_LIMIT}


@app.post("/api/watchlist/{symbol}", response_model=WatchlistResponse, status_code=status.HTTP_201_CREATED)
def add_watchlist_symbol(symbol: str) -> dict:
    """Add *symbol* to the watchlist and return the refreshed watchlist.

    Raises:
        HTTPException: 409 when ``watchlist.WatchlistFullError`` is raised,
            400 for any other ``ValueError`` from ``watchlist.add_symbol``.
    """
    try:
        watchlist.add_symbol(symbol, DB_PATH)
    except (watchlist.WatchlistFullError, ValueError) as exc:
        # The "full" error takes precedence, so it maps to 409 even if it is
        # also a ValueError subclass.
        is_full = isinstance(exc, watchlist.WatchlistFullError)
        code = status.HTTP_409_CONFLICT if is_full else status.HTTP_400_BAD_REQUEST
        raise HTTPException(status_code=code, detail=str(exc)) from exc
    return get_watchlist()


@app.delete("/api/watchlist/{symbol}", response_model=WatchlistResponse)
def delete_watchlist_symbol(symbol: str) -> dict:
    """Remove *symbol* from the watchlist and return the refreshed watchlist.

    Raises:
        HTTPException: 400 when ``watchlist.remove_symbol`` raises ``ValueError``.
    """
    try:
        watchlist.remove_symbol(symbol, DB_PATH)
    except ValueError as exc:
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail=str(exc)) from exc
    else:
        refreshed = get_watchlist()
        return refreshed