Files
InsightRadar/backend/app/api/endpoints/delivery.py
T
stardrophere 966bcfbba4 big update
2026-03-11 20:52:58 +08:00

354 lines
10 KiB
Python

# 推送设置 API:管理用户的推送时间表和推送渠道
from datetime import time as dt_time
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.api.dependencies import get_current_user, get_db
from app.models.models import AppUser, UserDeliverySchedule, UserPushEndpoint
from app.schemas.delivery_schema import (
DeliveryScheduleCreate,
DeliveryScheduleResponse,
DeliveryScheduleUpdate,
PushEndpointCreate,
PushEndpointResponse,
PushEndpointUpdate,
UserDeliveryConfigResponse,
)
# Router collecting every delivery-settings endpoint defined in this module.
router: APIRouter = APIRouter()

# Minimum gap between any two of a user's delivery times, in minutes
# (enforced by _check_min_gap, which measures the gap around midnight).
MIN_SCHEDULE_GAP_MINUTES: int = 30
def _ensure_self_access(path_user_id: int, current_user: AppUser) -> None:
    """Reject requests that target another user's resources.

    Every endpoint in this module is scoped by a ``user_id`` path parameter;
    only the authenticated user may act on their own settings.

    Args:
        path_user_id: The ``user_id`` taken from the request path.
        current_user: The authenticated user (resolved via ``get_current_user``).

    Raises:
        HTTPException: 403 when ``path_user_id`` differs from ``current_user.id``.
    """
    if path_user_id == current_user.id:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="You can only operate your own resources",
    )
def _parse_time(time_str: str) -> dt_time:
"""将 HH:MM 字符串解析为 time 对象"""
try:
parts = time_str.split(":")
return dt_time(hour=int(parts[0]), minute=int(parts[1]))
except (ValueError, IndexError):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid time format, expected HH:MM",
)
def _format_time(t: dt_time) -> str:
"""将 time 对象格式化为 HH:MM 字符串"""
return t.strftime("%H:%M")
def _time_to_minutes(t: dt_time) -> int:
return t.hour * 60 + t.minute
def _check_min_gap(
    db: Session,
    user_id: int,
    new_time: dt_time,
    exclude_id: int | None = None,
) -> None:
    """Ensure ``new_time`` keeps the minimum gap from the user's other schedules.

    Distances are measured on a 24-hour circle, so 23:50 and 00:10 count as
    20 minutes apart. Every schedule owned by ``user_id`` is compared, except
    the row whose id equals ``exclude_id``. When updating a schedule, this
    keeps the schedule from being compared against itself.

    Args:
        db: Active database session.
        user_id: Owner of the schedules to compare against.
        new_time: Candidate delivery time.
        exclude_id: Optional schedule id to leave out of the comparison.

    Raises:
        HTTPException: 400 naming the first conflicting time encountered when
            a gap is smaller than ``MIN_SCHEDULE_GAP_MINUTES``.
    """
    criteria = [UserDeliverySchedule.user_id == user_id]
    if exclude_id is not None:
        criteria.append(UserDeliverySchedule.id != exclude_id)
    others = db.query(UserDeliverySchedule).filter(*criteria).all()

    minutes_per_day = 24 * 60
    candidate = _time_to_minutes(new_time)
    for other in others:
        existing_time = other.delivery_time
        gap = abs(candidate - _time_to_minutes(existing_time))
        # Wrap around midnight: the shorter arc of the clock face counts.
        if min(gap, minutes_per_day - gap) >= MIN_SCHEDULE_GAP_MINUTES:
            continue
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"推送时间间隔不能少于 {MIN_SCHEDULE_GAP_MINUTES} 分钟,"
            f"与已有的 {_format_time(existing_time)} 冲突",
        )
# ==========================================
# Aggregate query: return the user's full delivery configuration in one call
# ==========================================
@router.get(
    "/users/{user_id}/delivery-config",
    response_model=UserDeliveryConfigResponse,
)
def get_delivery_config(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Return the caller's delivery schedules and push endpoints together.

    Schedules are ordered by ``delivery_time`` ascending; endpoints by
    ``priority_level`` ascending.

    Raises:
        HTTPException: 403 when ``user_id`` is not the authenticated user.
    """
    _ensure_self_access(user_id, current_user)

    schedule_query = db.query(UserDeliverySchedule).filter(
        UserDeliverySchedule.user_id == user_id
    )
    schedule_rows = schedule_query.order_by(
        UserDeliverySchedule.delivery_time.asc()
    ).all()

    endpoint_query = db.query(UserPushEndpoint).filter(
        UserPushEndpoint.user_id == user_id
    )
    endpoint_rows = endpoint_query.order_by(
        UserPushEndpoint.priority_level.asc()
    ).all()

    def to_response(row) -> DeliveryScheduleResponse:
        # delivery_time is a time object on the row; the API exposes "HH:MM".
        return DeliveryScheduleResponse(
            id=row.id,
            user_id=row.user_id,
            delivery_time=_format_time(row.delivery_time),
            is_active=row.is_active,
            created_at=row.created_at,
        )

    return UserDeliveryConfigResponse(
        schedules=[to_response(row) for row in schedule_rows],
        endpoints=endpoint_rows,
    )
# ==========================================
# Delivery schedule CRUD
# ==========================================
@router.post(
    "/users/{user_id}/delivery-schedules",
    response_model=DeliveryScheduleResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_delivery_schedule(
    user_id: int,
    payload: DeliveryScheduleCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Create a new delivery time for the user.

    ``payload.delivery_time`` is parsed as ``HH:MM``. It must be at least
    ``MIN_SCHEDULE_GAP_MINUTES`` away from the user's existing schedules.

    Returns:
        The stored schedule, with ``delivery_time`` rendered as ``HH:MM``.

    Raises:
        HTTPException: 403 for another user's id; 400 for a malformed time or
            a too-small gap; 409 when the commit raises an integrity error.
    """
    _ensure_self_access(user_id, current_user)

    delivery_time = _parse_time(payload.delivery_time)
    _check_min_gap(db, user_id, delivery_time)

    schedule = UserDeliverySchedule(
        user_id=user_id,
        delivery_time=delivery_time,
        is_active=payload.is_active,
    )
    db.add(schedule)
    try:
        db.commit()
    except IntegrityError:
        # Undo the failed transaction before reporting the conflict.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This delivery time already exists",
        )
    db.refresh(schedule)

    return DeliveryScheduleResponse(
        id=schedule.id,
        user_id=schedule.user_id,
        delivery_time=_format_time(schedule.delivery_time),
        is_active=schedule.is_active,
        created_at=schedule.created_at,
    )
@router.patch(
    "/users/{user_id}/delivery-schedules/{schedule_id}",
    response_model=DeliveryScheduleResponse,
)
def update_delivery_schedule(
    user_id: int,
    schedule_id: int,
    payload: DeliveryScheduleUpdate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Partially update one of the user's delivery times.

    Only fields that are not ``None`` in ``payload`` are applied. A new
    ``delivery_time`` is parsed as ``HH:MM`` and gap-checked against the
    user's other schedules; this schedule itself is excluded from the check.

    Raises:
        HTTPException: 403 for another user's id; 404 when the schedule does
            not belong to this user; 400 for a malformed time or a too-small
            gap; 409 when the commit raises an integrity error.
    """
    _ensure_self_access(user_id, current_user)

    schedule = (
        db.query(UserDeliverySchedule)
        .filter(
            UserDeliverySchedule.id == schedule_id,
            UserDeliverySchedule.user_id == user_id,
        )
        .first()
    )
    if schedule is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")

    requested_time = payload.delivery_time
    if requested_time is not None:
        parsed = _parse_time(requested_time)
        _check_min_gap(db, user_id, parsed, exclude_id=schedule_id)
        schedule.delivery_time = parsed

    if payload.is_active is not None:
        schedule.is_active = payload.is_active

    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This delivery time already exists",
        )
    db.refresh(schedule)

    return DeliveryScheduleResponse(
        id=schedule.id,
        user_id=schedule.user_id,
        delivery_time=_format_time(schedule.delivery_time),
        is_active=schedule.is_active,
        created_at=schedule.created_at,
    )
@router.delete(
    "/users/{user_id}/delivery-schedules/{schedule_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_delivery_schedule(
    user_id: int,
    schedule_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Delete one of the user's delivery times (204, empty body).

    Raises:
        HTTPException: 403 for another user's id; 404 when no schedule with
            ``schedule_id`` belongs to ``user_id``.
    """
    _ensure_self_access(user_id, current_user)

    owned_schedule = (
        db.query(UserDeliverySchedule)
        .filter(
            UserDeliverySchedule.id == schedule_id,
            UserDeliverySchedule.user_id == user_id,
        )
        .first()
    )
    if owned_schedule is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")

    db.delete(owned_schedule)
    db.commit()
    return None
# ==========================================
# Push endpoint (channel) CRUD
# ==========================================
@router.post(
    "/users/{user_id}/push-endpoints",
    response_model=PushEndpointResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_push_endpoint(
    user_id: int,
    payload: PushEndpointCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Register a new push channel for the user.

    Normalisation before storage:
    - ``channel_type`` is upper-cased and stripped.
    - ``channel_account`` is stripped.

    Both values must still be non-empty after normalisation. Any
    length validation in the request schema presumably runs on the raw
    value (TODO confirm against ``PushEndpointCreate``), so a whitespace-only
    string would otherwise be stored as ``""``.

    Raises:
        HTTPException: 403 for another user's id; 400 when ``channel_type`` or
            ``channel_account`` is blank; 409 on a database integrity error
            (reported as a duplicate channel type for this user).
    """
    _ensure_self_access(user_id, current_user)

    channel_type = payload.channel_type.upper().strip()
    channel_account = payload.channel_account.strip()
    if not channel_type or not channel_account:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="channel_type and channel_account must not be blank",
        )

    db_obj = UserPushEndpoint(
        user_id=user_id,
        channel_type=channel_type,
        channel_account=channel_account,
        is_active=payload.is_active,
        priority_level=payload.priority_level,
    )
    db.add(db_obj)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This channel type already exists for the user",
        )
    db.refresh(db_obj)
    return db_obj
@router.patch(
    "/users/{user_id}/push-endpoints/{endpoint_id}",
    response_model=PushEndpointResponse,
)
def update_push_endpoint(
    user_id: int,
    endpoint_id: int,
    payload: PushEndpointUpdate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Partially update one of the user's push channels.

    Only fields that are not ``None`` in ``payload`` are applied;
    ``channel_type`` cannot be changed here. A supplied ``channel_account`` is
    stripped and must not be blank afterwards.

    Raises:
        HTTPException: 403 for another user's id; 404 when the endpoint does
            not belong to this user; 400 for a blank ``channel_account``;
            409 when the commit raises an integrity error.
    """
    _ensure_self_access(user_id, current_user)
    db_obj = (
        db.query(UserPushEndpoint)
        .filter(
            UserPushEndpoint.id == endpoint_id,
            UserPushEndpoint.user_id == user_id,
        )
        .first()
    )
    if not db_obj:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")

    if payload.channel_account is not None:
        channel_account = payload.channel_account.strip()
        # Validated before any field is mutated, so a rejected request leaves
        # the loaded row untouched.
        if not channel_account:
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="channel_account must not be blank",
            )
        db_obj.channel_account = channel_account
    if payload.is_active is not None:
        db_obj.is_active = payload.is_active
    if payload.priority_level is not None:
        db_obj.priority_level = payload.priority_level

    try:
        db.commit()
    except IntegrityError:
        # Same handling as the create endpoints: roll back so the session is
        # usable again and report a conflict instead of an unhandled 500.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Push endpoint update conflicts with an existing record",
        )
    db.refresh(db_obj)
    return db_obj
@router.delete(
    "/users/{user_id}/push-endpoints/{endpoint_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_push_endpoint(
    user_id: int,
    endpoint_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Delete one of the user's push channels (204, empty body).

    Raises:
        HTTPException: 403 for another user's id; 404 when no endpoint with
            ``endpoint_id`` belongs to ``user_id``.
    """
    _ensure_self_access(user_id, current_user)

    owned_endpoint = (
        db.query(UserPushEndpoint)
        .filter(
            UserPushEndpoint.id == endpoint_id,
            UserPushEndpoint.user_id == user_id,
        )
        .first()
    )
    if owned_endpoint is None:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")

    db.delete(owned_endpoint)
    db.commit()
    return None