185 lines
6.6 KiB
Python
185 lines
6.6 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import secrets
|
|
from datetime import datetime, timedelta, timezone
|
|
from urllib.parse import urlencode
|
|
|
|
import httpx
|
|
from fastapi import Depends, HTTPException, Request, Response, status
|
|
from sqlalchemy import select
|
|
from sqlalchemy.orm import Session
|
|
|
|
from .config import settings
|
|
from .database import get_db
|
|
from .models import UserSession
|
|
|
|
# Module-level logger, named after this module's dotted import path.
logger = logging.getLogger(__name__)
|
|
|
|
|
|
def utcnow() -> datetime:
    """Return the current moment as a timezone-aware datetime in UTC."""
    utc = timezone.utc
    return datetime.now(tz=utc)
|
|
|
|
|
|
def create_session_token() -> str:
    """Generate a new random, URL-safe session token built from 32 random bytes."""
    entropy_bytes = 32
    return secrets.token_urlsafe(nbytes=entropy_bytes)
|
|
|
|
|
|
def set_session_cookie(response: Response, session_token: str) -> None:
    """Attach *session_token* to *response* as the session cookie.

    The cookie name and the ``Secure`` flag come from ``settings``. The cookie
    is ``HttpOnly``, uses ``SameSite=Lax`` and has a fixed 14-day ``Max-Age``.
    """
    fourteen_days_in_seconds = 14 * 24 * 60 * 60
    response.set_cookie(
        key=settings.session_cookie_name,
        value=session_token,
        max_age=fourteen_days_in_seconds,
        httponly=True,
        samesite="lax",
        secure=settings.session_cookie_secure,
    )
|
|
|
|
|
|
def clear_session_cookie(response: Response) -> None:
    """Tell the client to drop the session cookie.

    The deletion ``Set-Cookie`` header repeats the attributes that
    ``set_session_cookie`` uses: the default path, ``Secure`` from settings,
    ``HttpOnly`` and ``SameSite=Lax``. Browsers may ignore a removal whose
    attributes conflict with the stored cookie, for example a non-Secure
    header for a Secure cookie.
    """
    response.delete_cookie(
        key=settings.session_cookie_name,
        secure=settings.session_cookie_secure,
        httponly=True,
        samesite="lax",
    )
|
|
|
|
|
|
def get_current_session(
    request: Request,
    db: Session = Depends(get_db),
    required: bool = False,
) -> UserSession | None:
    """Resolve the ``UserSession`` identified by the request's session cookie.

    Looks the cookie value up against ``UserSession.session_token``.

    Returns:
        The matching session, or ``None`` when there is no cookie or no
        matching row and *required* is false.

    Raises:
        HTTPException: 401 "Not authenticated" when *required* is true and no
            cookie is present; 401 "Invalid session" when *required* is true
            and the cookie matches no session.

    NOTE(review): because this function is used as a FastAPI dependency,
    FastAPI will likely expose ``required`` as an optional query parameter on
    every route that depends on it. Consider splitting it into separate
    optional/required dependencies.
    """
    token = request.cookies.get(settings.session_cookie_name)
    if token:
        found = db.scalar(select(UserSession).where(UserSession.session_token == token))
        if found is not None or not required:
            return found
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid session")

    # No cookie at all.
    if not required:
        return None
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
|
|
|
|
|
def require_session(session: UserSession | None = Depends(get_current_session)) -> UserSession:
    """Dependency that returns the current session or rejects the request.

    Raises:
        HTTPException: 401 "Not authenticated" when ``get_current_session``
            resolved no session.
    """
    if session is not None:
        return session
    raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated")
|
|
|
|
|
|
def require_admin(session: UserSession = Depends(require_session)) -> UserSession:
    """Dependency that returns the current session only if it belongs to an admin.

    Raises:
        HTTPException: 403 "Admin access required" when ``session.is_admin``
            is falsy (authentication itself is enforced by ``require_session``).
    """
    if session.is_admin:
        return session
    raise HTTPException(status_code=status.HTTP_403_FORBIDDEN, detail="Admin access required")
|
|
|
|
|
|
def build_teamsnap_authorize_url(state: str) -> str:
    """Build the TeamSnap OAuth authorization URL for the given *state* value.

    The query string carries ``client_id``, ``redirect_uri``,
    ``response_type=code``, ``scope`` and ``state``, in that order, encoded
    with ``urlencode``. It is appended to ``settings.teamsnap_auth_url``
    after a ``?``.
    """
    query_pairs = [
        ("client_id", settings.teamsnap_client_id),
        ("redirect_uri", settings.teamsnap_redirect_uri),
        ("response_type", "code"),
        ("scope", settings.teamsnap_scope),
        ("state", state),
    ]
    query_string = urlencode(query_pairs)
    return "{}?{}".format(settings.teamsnap_auth_url, query_string)
|
|
|
|
|
|
def _extract_collection_item(payload: dict) -> dict[str, object] | None:
|
|
items = payload.get("collection", {}).get("items", [])
|
|
if not items:
|
|
return None
|
|
|
|
values: dict[str, object] = {}
|
|
for field in items[0].get("data", []):
|
|
name = field.get("name")
|
|
if isinstance(name, str):
|
|
values[name] = field.get("value")
|
|
return values
|
|
|
|
|
|
async def fetch_teamsnap_user_id(access_token: str) -> str | None:
    """Look up the TeamSnap user id that *access_token* belongs to.

    The lookup runs in three steps:

    1. GET ``settings.teamsnap_api_root``.
    2. Follow the Collection+JSON query whose ``rel`` is ``"me"``.
    3. Return the ``id`` field of the first item in that response, as a string.

    Returns:
        The user id, or ``None`` in these cases:

        - either request fails or returns an error status;
        - a response body is not valid JSON;
        - the root lists no usable ``"me"`` query;
        - the ``"me"`` response has no ``id``.

        Request, JSON and missing-query failures are logged as warnings.
    """
    headers = {
        "Accept": "application/vnd.collection+json",
        "Authorization": f"Bearer {access_token}",
    }
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            root_response = await client.get(settings.teamsnap_api_root, headers=headers)
            root_response.raise_for_status()
            me_href = _find_collection_query_href(root_response.json(), "me")
            if me_href is None:
                logger.warning("TeamSnap API root did not advertise a usable 'me' query")
                return None

            me_response = await client.get(me_href, headers=headers)
            me_response.raise_for_status()
            me_payload = me_response.json()
    except httpx.HTTPError:
        logger.warning("TeamSnap user lookup request failed", exc_info=True)
        return None
    except ValueError:
        # Response.json() raises json.JSONDecodeError (a ValueError) for non-JSON bodies.
        logger.warning("TeamSnap user lookup returned a non-JSON response", exc_info=True)
        return None

    # The item extractor expects a mapping at the top level.
    if not isinstance(me_payload, dict):
        logger.warning("TeamSnap 'me' response was not a JSON object")
        return None

    me_item = _extract_collection_item(me_payload)
    if not me_item:
        return None
    user_id = me_item.get("id")
    return str(user_id) if user_id is not None else None


def _find_collection_query_href(payload: object, rel: str) -> str | None:
    """Return the href of the first Collection+JSON query whose ``rel`` equals *rel*.

    Only the first matching query is considered. ``None`` is returned when
    there is no match, or when that query's ``href`` is not a non-empty string.
    """
    collection = payload.get("collection") if isinstance(payload, dict) else None
    queries = collection.get("queries") if isinstance(collection, dict) else None
    if not isinstance(queries, list):
        return None
    match = next(
        (query for query in queries if isinstance(query, dict) and query.get("rel") == rel),
        None,
    )
    if match is None:
        return None
    href = match.get("href")
    return href if isinstance(href, str) and href else None
|
|
|
|
|
|
async def exchange_code_for_token(code: str) -> dict:
    """Exchange an OAuth authorization *code* for a TeamSnap token payload.

    Posts an ``authorization_code`` grant to ``settings.teamsnap_token_url``.
    The grant carries the configured redirect URI, client id and client secret.

    Returns:
        The decoded JSON object returned by the token endpoint.

    Raises:
        HTTPException: 502 "TeamSnap token exchange failed" in these cases:

            - the request fails at the transport level;
            - the endpoint answers with status >= 400;
            - the body is not a JSON object.
    """
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            response = await client.post(
                settings.teamsnap_token_url,
                data={
                    "grant_type": "authorization_code",
                    "code": code,
                    "redirect_uri": settings.teamsnap_redirect_uri,
                    "client_id": settings.teamsnap_client_id,
                    "client_secret": settings.teamsnap_client_secret,
                },
                headers={"Accept": "application/json"},
            )
    except httpx.HTTPError as exc:
        logger.exception("TeamSnap token exchange request failed")
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="TeamSnap token exchange failed") from exc

    if response.status_code >= 400:
        # Only a 6-character suffix of the client id is logged, never the full value.
        logger.error(
            "TeamSnap token exchange rejected: status=%s body=%s redirect_uri=%s client_id_suffix=%s",
            response.status_code,
            response.text,
            settings.teamsnap_redirect_uri,
            settings.teamsnap_client_id[-6:] if settings.teamsnap_client_id else "",
        )
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="TeamSnap token exchange failed")

    try:
        payload = response.json()
    except ValueError as exc:
        # A non-error response with a non-JSON body would otherwise escape as a 500.
        logger.error("TeamSnap token exchange returned a non-JSON body: status=%s", response.status_code)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="TeamSnap token exchange failed") from exc
    if not isinstance(payload, dict):
        logger.error("TeamSnap token exchange returned unexpected JSON: type=%s", type(payload).__name__)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="TeamSnap token exchange failed")
    return payload
|
|
|
|
|
|
async def refresh_access_token(refresh_token: str) -> dict:
    """Use *refresh_token* to obtain a fresh TeamSnap token payload.

    Posts a ``refresh_token`` grant, with the configured client id and client
    secret, to ``settings.teamsnap_token_url``.

    Returns:
        The decoded JSON object returned by the token endpoint.

    Raises:
        HTTPException: 502 "TeamSnap token refresh failed" in these cases:

            - the request fails at the transport level;
            - the endpoint answers with status >= 400;
            - the body is not a JSON object.
    """
    try:
        async with httpx.AsyncClient(timeout=15.0) as client:
            response = await client.post(
                settings.teamsnap_token_url,
                data={
                    "grant_type": "refresh_token",
                    "refresh_token": refresh_token,
                    "client_id": settings.teamsnap_client_id,
                    "client_secret": settings.teamsnap_client_secret,
                },
                headers={"Accept": "application/json"},
            )
    except httpx.HTTPError as exc:
        logger.exception("TeamSnap token refresh request failed")
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="TeamSnap token refresh failed") from exc

    if response.status_code >= 400:
        # Only a 6-character suffix of the client id is logged, never the full value.
        logger.error(
            "TeamSnap token refresh rejected: status=%s body=%s client_id_suffix=%s",
            response.status_code,
            response.text,
            settings.teamsnap_client_id[-6:] if settings.teamsnap_client_id else "",
        )
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="TeamSnap token refresh failed")

    try:
        payload = response.json()
    except ValueError as exc:
        # A non-error response with a non-JSON body would otherwise escape as a 500.
        logger.error("TeamSnap token refresh returned a non-JSON body: status=%s", response.status_code)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="TeamSnap token refresh failed") from exc
    if not isinstance(payload, dict):
        logger.error("TeamSnap token refresh returned unexpected JSON: type=%s", type(payload).__name__)
        raise HTTPException(status_code=status.HTTP_502_BAD_GATEWAY, detail="TeamSnap token refresh failed")
    return payload
|
|
|
|
|
|
def update_session_tokens(session: UserSession, token_payload: dict) -> None:
    """Copy the tokens from a TeamSnap OAuth token response onto *session*.

    - ``access_token`` is always overwritten, with ``None`` if it is absent.
    - ``refresh_token`` is replaced only when the payload carries a non-empty
      value. Otherwise the stored refresh token is kept, so a response that
      omits it or sends ``null`` cannot wipe it.
    - ``token_expires_at`` becomes the current UTC time plus ``expires_in``
      seconds. It is ``None`` when ``expires_in`` is missing, empty or not
      numeric. An ``expires_in`` of 0 yields an expiry of "now", not ``None``.

    The session object is only mutated; nothing is persisted here.
    """
    session.access_token = token_payload.get("access_token")
    session.refresh_token = token_payload.get("refresh_token") or session.refresh_token
    lifetime = _parse_expires_in(token_payload.get("expires_in"))
    session.token_expires_at = utcnow() + lifetime if lifetime is not None else None


def _parse_expires_in(value: object) -> timedelta | None:
    """Convert an OAuth ``expires_in`` value (a number or numeric string of seconds) to a timedelta."""
    if value is None or value == "":
        return None
    try:
        # float() accepts both "3600" and "3600.0".
        # NaN raises ValueError and infinity raises OverflowError in timedelta().
        return timedelta(seconds=float(value))
    except (TypeError, ValueError, OverflowError):
        logger.warning("Ignoring unparseable TeamSnap expires_in value: %r", value)
        return None
|