summaryrefslogtreecommitdiff
path: root/backend/app/main.py
blob: 7e02cbe04aaa593bec78327b549f35d12323315c (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
"""FastAPI entrypoint for Prism v2."""
from __future__ import annotations

from contextlib import asynccontextmanager
from pathlib import Path

from dotenv import load_dotenv
from fastapi import FastAPI, HTTPException, Query, status
from fastapi.middleware.cors import CORSMiddleware

from app.db import watchlist
from app.schemas import FinancialsResponse, HistoryPoint, MarketIndex, RatiosResponse, SearchResult, TickerOverview, ValuationResponse, WatchlistResponse
from app.services import data_service

# Load variables from a local .env file (if one exists) into the process
# environment at import time.
# NOTE(review): this runs after the app.* imports above, so any environment
# lookups those modules perform at import time will not see .env values — confirm.
load_dotenv()


@asynccontextmanager
async def lifespan(_: FastAPI):
    """Initialise the watchlist database before the app begins serving.

    ``DB_PATH`` is defined further down this module; it is only looked up
    when startup actually runs, by which point the module is fully loaded.
    """
    database_file = DB_PATH
    watchlist.init_db(database_file)
    yield


# SQLite file used by the watchlist store: <backend>/data/prism.db
# (parents[1] of backend/app/main.py is the backend directory).
DB_PATH = Path(__file__).resolve().parents[1] / "data" / "prism.db"

# Browser origins allowed through CORS: localhost / 127.0.0.1 on ports 3000 and 3001.
CORS_ALLOW_ORIGINS = [
    "http://localhost:3000",
    "http://127.0.0.1:3000",
    "http://localhost:3001",
    "http://127.0.0.1:3001",
]

app = FastAPI(title="Prism v2 API", version="0.1.0", lifespan=lifespan)
app.add_middleware(
    CORSMiddleware,
    allow_origins=CORS_ALLOW_ORIGINS,
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)


@app.get("/health")
def health() -> dict[str, str]:
    """Report that the API is up; the payload is always ``{"status": "ok"}``."""
    payload: dict[str, str] = {"status": "ok"}
    return payload


@app.get("/api/search", response_model=list[SearchResult])
def search(q: str = Query(default="", min_length=0)) -> list[dict]:
    """Return the tickers that ``data_service.search_tickers`` matches for ``q``.

    An empty query string is accepted; what it yields is decided by the
    data service.
    """
    matches = data_service.search_tickers(q)
    return matches


@app.get("/api/market/indices", response_model=list[MarketIndex])
def market_indices() -> list[dict]:
    """Return the market index snapshot provided by the data service."""
    indices = data_service.get_market_indices()
    return indices


@app.get("/api/tickers/{symbol}/overview", response_model=TickerOverview)
def ticker_overview(symbol: str) -> dict:
    """Return the overview payload for ``symbol``.

    Raises:
        HTTPException: 404 ("ticker data unavailable") when the data service
            returns ``None`` for this symbol.
    """
    found = data_service.get_ticker_overview(symbol)
    if found is not None:
        return found
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="ticker data unavailable")


@app.get("/api/tickers/{symbol}/history", response_model=list[HistoryPoint])
def ticker_history(
    symbol: str,
    period: str = Query(default="1y", pattern="^(1m|3m|6m|1y|5y)$"),
) -> list[dict]:
    """Return price-history points for ``symbol``.

    ``period`` defaults to ``1y``. The query pattern restricts it to one of
    ``1m``, ``3m``, ``6m``, ``1y`` or ``5y``.
    """
    points = data_service.get_price_history(symbol, period=period)
    return points


@app.get("/api/tickers/{symbol}/financials", response_model=FinancialsResponse)
def ticker_financials(symbol: str, period: str = Query(default="annual", pattern="^(annual|quarterly)$")) -> dict:
    """Return financial data for ``symbol`` at the requested reporting period.

    ``period`` defaults to ``annual``. The query pattern restricts it to
    ``annual`` or ``quarterly``.

    Raises:
        HTTPException: 404 ("ticker data unavailable") when the data service
            returns ``None``.
    """
    financials = data_service.get_financials(symbol, period=period)
    # Map a missing result to 404, as ticker_overview does, instead of letting
    # response-model validation reject None and turn it into a 500.
    # NOTE(review): confirm whether get_financials can actually return None.
    if financials is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="ticker data unavailable")
    return financials


@app.get("/api/tickers/{symbol}/valuation", response_model=ValuationResponse)
def ticker_valuation(symbol: str) -> dict:
    """Return valuation data for ``symbol``.

    Raises:
        HTTPException: 404 ("ticker data unavailable") when the data service
            returns ``None``.
    """
    valuation = data_service.get_valuation(symbol)
    # Map a missing result to 404, as ticker_overview does, instead of letting
    # response-model validation reject None and turn it into a 500.
    # NOTE(review): confirm whether get_valuation can actually return None.
    if valuation is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="ticker data unavailable")
    return valuation


@app.get("/api/tickers/{symbol}/ratios", response_model=RatiosResponse)
def ticker_ratios(symbol: str) -> dict:
    """Return financial ratios for ``symbol``.

    Raises:
        HTTPException: 404 ("ticker data unavailable") when the data service
            returns ``None``.
    """
    ratios = data_service.get_ratios(symbol)
    # Map a missing result to 404, as ticker_overview does, instead of letting
    # response-model validation reject None and turn it into a 500.
    # NOTE(review): confirm whether get_ratios can actually return None.
    if ratios is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="ticker data unavailable")
    return ratios


@app.get("/api/watchlist", response_model=WatchlistResponse)
def get_watchlist() -> dict:
    """Return every stored watchlist row enriched with a live quote.

    Each row from the database is copied and given a ``"quote"`` key built
    from the symbol's company info. The configured ``WATCHLIST_LIMIT`` is
    returned alongside. Company info is fetched once per row, in row order.
    """

    def _with_quote(entry):
        # Fetch company info first, then build the quote from it.
        ticker = entry["symbol"]
        company = data_service.get_company_info(ticker)
        return {**entry, "quote": data_service.build_quote(company, ticker)}

    return {
        "items": [_with_quote(entry) for entry in watchlist.list_symbols(DB_PATH)],
        "limit": watchlist.WATCHLIST_LIMIT,
    }


@app.post("/api/watchlist/{symbol}", response_model=WatchlistResponse, status_code=status.HTTP_201_CREATED)
def add_watchlist_symbol(symbol: str) -> dict:
    """Add ``symbol`` to the watchlist and return the refreshed watchlist.

    Raises:
        HTTPException: 409 when the store raises ``WatchlistFullError``.
            400 for any other ``ValueError``.
            In both cases the exception text is used as the detail.
    """
    try:
        watchlist.add_symbol(symbol, DB_PATH)
    except (watchlist.WatchlistFullError, ValueError) as err:
        # The full-watchlist case takes precedence over the generic ValueError.
        http_status = (
            status.HTTP_409_CONFLICT
            if isinstance(err, watchlist.WatchlistFullError)
            else status.HTTP_400_BAD_REQUEST
        )
        raise HTTPException(status_code=http_status, detail=str(err)) from err
    return get_watchlist()


@app.delete("/api/watchlist/{symbol}", response_model=WatchlistResponse)
def delete_watchlist_symbol(symbol: str) -> dict:
    """Remove ``symbol`` from the watchlist and return the refreshed watchlist.

    Raises:
        HTTPException: 400, carrying the exception text, when the store
            raises ``ValueError``.
    """
    try:
        watchlist.remove_symbol(symbol, DB_PATH)
    except ValueError as err:
        raise HTTPException(detail=str(err), status_code=status.HTTP_400_BAD_REQUEST) from err
    else:
        return get_watchlist()