mirror of
https://github.com/stardrophere/InsightRadar.git
synced 2026-06-06 01:57:51 +08:00
356 lines
10 KiB
Python
356 lines
10 KiB
Python
# 推送设置 API:管理用户的推送时间表和推送渠道
|
|
# 关键约束:同一用户两条推送时间间隔至少 30 分钟
|
|
from datetime import time as dt_time
|
|
from typing import List
|
|
|
|
from fastapi import APIRouter, Depends, HTTPException, status
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.api.dependencies import get_current_user, get_db
|
|
from app.models.models import AppUser, UserDeliverySchedule, UserPushEndpoint
|
|
from app.schemas.delivery_schema import (
|
|
DeliveryScheduleCreate,
|
|
DeliveryScheduleResponse,
|
|
DeliveryScheduleUpdate,
|
|
PushEndpointCreate,
|
|
PushEndpointResponse,
|
|
PushEndpointUpdate,
|
|
UserDeliveryConfigResponse,
|
|
)
|
|
|
|
# Minimum spacing, in minutes, that must separate any two push times belonging
# to the same user (enforced by _check_min_gap).
MIN_SCHEDULE_GAP_MINUTES = 30

router = APIRouter()
|
|
|
|
|
|
def _ensure_self_access(path_user_id: int, current_user: AppUser) -> None:
    """Reject requests whose path ``user_id`` is not the authenticated user.

    Raises:
        HTTPException: 403 when ``path_user_id`` differs from ``current_user.id``.
    """
    if path_user_id == current_user.id:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="You can only operate your own resources",
    )
|
|
|
|
|
|
def _parse_time(time_str: str) -> dt_time:
    """Parse a 24-hour ``HH:MM`` string into a ``datetime.time``.

    Accepts exactly two colon-separated groups of one or two ASCII digits;
    surrounding whitespace is ignored. Examples: ``"08:30"``, ``"8:05"``.
    The following are rejected rather than being silently truncated or
    coerced by ``int()``:

    - extra components, e.g. ``"08:30:45"``
    - signs, underscores or non-ASCII digits
    - out-of-range values
    - non-string input

    Raises:
        HTTPException: 400 if ``time_str`` is not a valid ``HH:MM`` value.
    """
    parts = time_str.strip().split(":") if isinstance(time_str, str) else []
    # int() alone would also accept "+8", "0_8", "008" or full-width digits,
    # so validate the textual shape before converting.
    if len(parts) == 2 and all(
        part.isascii() and part.isdigit() and 1 <= len(part) <= 2 for part in parts
    ):
        hour, minute = int(parts[0]), int(parts[1])
        if hour < 24 and minute < 60:
            return dt_time(hour=hour, minute=minute)
    raise HTTPException(
        status_code=status.HTTP_400_BAD_REQUEST,
        detail="Invalid time format, expected HH:MM",
    )
|
|
|
|
|
|
def _format_time(t: dt_time) -> str:
    """Render a ``datetime.time`` as a zero-padded ``HH:MM`` string (seconds dropped)."""
    return format(t, "%H:%M")
|
|
|
|
|
|
def _time_to_minutes(t: dt_time) -> int:
    """Return how many whole minutes past midnight ``t`` is (seconds ignored)."""
    minutes_from_hours = 60 * t.hour
    return minutes_from_hours + t.minute
|
|
|
|
|
|
def _check_min_gap(
    db: Session,
    user_id: int,
    new_time: dt_time,
    exclude_id: int | None = None,
) -> None:
    """Ensure ``new_time`` keeps the minimum spacing from the user's other push times.

    Every schedule row of ``user_id`` is compared, except the one whose id is
    ``exclude_id`` (when given). Rows are compared regardless of ``is_active``.
    Distances are measured around the 24-hour clock. A pair straddling midnight
    therefore uses the short way round: 23:50 and 00:10 are 20 minutes apart.

    Raises:
        HTTPException: 400 naming the first existing time (in query order) that
        is closer than ``MIN_SCHEDULE_GAP_MINUTES``.
    """
    criteria = [UserDeliverySchedule.user_id == user_id]
    if exclude_id is not None:
        criteria.append(UserDeliverySchedule.id != exclude_id)
    others = db.query(UserDeliverySchedule).filter(*criteria).all()

    target = _time_to_minutes(new_time)
    minutes_per_day = 1440

    def _too_close(schedule) -> bool:
        delta = abs(target - _time_to_minutes(schedule.delivery_time))
        # Take the shorter arc, so gaps that wrap past midnight are not overstated.
        return min(delta, minutes_per_day - delta) < MIN_SCHEDULE_GAP_MINUTES

    clash = next((s for s in others if _too_close(s)), None)
    if clash is not None:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"推送时间间隔不能少于 {MIN_SCHEDULE_GAP_MINUTES} 分钟,"
            f"与已有的 {_format_time(clash.delivery_time)} 冲突",
        )
|
|
|
|
|
|
# ==========================================
|
|
# 聚合查询:一次性返回用户全部推送配置
|
|
# ==========================================
|
|
@router.get(
    "/users/{user_id}/delivery-config",
    response_model=UserDeliveryConfigResponse,
)
def get_delivery_config(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Return the user's whole push configuration: schedules plus endpoints.

    Schedules are sorted by delivery time and endpoints by priority level,
    both ascending. Each schedule's ``delivery_time`` is converted by hand to
    an ``HH:MM`` string when building the response.
    """
    _ensure_self_access(user_id, current_user)

    schedule_rows = (
        db.query(UserDeliverySchedule)
        .filter(UserDeliverySchedule.user_id == user_id)
        .order_by(UserDeliverySchedule.delivery_time.asc())
        .all()
    )
    endpoint_rows = (
        db.query(UserPushEndpoint)
        .filter(UserPushEndpoint.user_id == user_id)
        .order_by(UserPushEndpoint.priority_level.asc())
        .all()
    )

    def _to_schedule_response(row):
        return DeliveryScheduleResponse(
            id=row.id,
            user_id=row.user_id,
            delivery_time=_format_time(row.delivery_time),
            is_active=row.is_active,
            created_at=row.created_at,
        )

    return UserDeliveryConfigResponse(
        schedules=list(map(_to_schedule_response, schedule_rows)),
        endpoints=endpoint_rows,
    )
|
|
|
|
|
|
# ==========================================
|
|
# 推送时间表 CRUD
|
|
# ==========================================
|
|
@router.post(
    "/users/{user_id}/delivery-schedules",
    response_model=DeliveryScheduleResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_delivery_schedule(
    user_id: int,
    payload: DeliveryScheduleCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Add one push time for the user.

    The ``HH:MM`` value is parsed first. It is then checked against the
    minimum-gap rule before the row is inserted. An ``IntegrityError`` on
    commit is rolled back and reported as 409.
    """
    _ensure_self_access(user_id, current_user)

    requested_time = _parse_time(payload.delivery_time)
    _check_min_gap(db, user_id, requested_time)

    schedule = UserDeliverySchedule(
        user_id=user_id,
        delivery_time=requested_time,
        is_active=payload.is_active,
    )
    db.add(schedule)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This delivery time already exists",
        )
    else:
        db.refresh(schedule)

    return DeliveryScheduleResponse(
        id=schedule.id,
        user_id=schedule.user_id,
        delivery_time=_format_time(schedule.delivery_time),
        is_active=schedule.is_active,
        created_at=schedule.created_at,
    )
|
|
|
|
|
|
@router.patch(
    "/users/{user_id}/delivery-schedules/{schedule_id}",
    response_model=DeliveryScheduleResponse,
)
def update_delivery_schedule(
    user_id: int,
    schedule_id: int,
    payload: DeliveryScheduleUpdate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Partially update one of the user's push times.

    Only fields that are not ``None`` in ``payload`` are applied. A new
    ``delivery_time`` is re-checked against the minimum-gap rule, ignoring
    this schedule itself. An ``IntegrityError`` on commit is rolled back and
    reported as 409.
    """
    _ensure_self_access(user_id, current_user)

    schedule = (
        db.query(UserDeliverySchedule)
        .filter_by(id=schedule_id, user_id=user_id)
        .first()
    )
    if not schedule:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")

    if payload.delivery_time is not None:
        requested_time = _parse_time(payload.delivery_time)
        _check_min_gap(db, user_id, requested_time, exclude_id=schedule_id)
        schedule.delivery_time = requested_time
    if payload.is_active is not None:
        schedule.is_active = payload.is_active

    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This delivery time already exists",
        )
    else:
        db.refresh(schedule)

    return DeliveryScheduleResponse(
        id=schedule.id,
        user_id=schedule.user_id,
        delivery_time=_format_time(schedule.delivery_time),
        is_active=schedule.is_active,
        created_at=schedule.created_at,
    )
|
|
|
|
|
|
@router.delete(
    "/users/{user_id}/delivery-schedules/{schedule_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_delivery_schedule(
    user_id: int,
    schedule_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Remove one push time belonging to the user.

    Responds 404 when no schedule with this id exists for ``user_id``.
    """
    _ensure_self_access(user_id, current_user)

    target = (
        db.query(UserDeliverySchedule)
        .filter_by(id=schedule_id, user_id=user_id)
        .first()
    )
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")

    db.delete(target)
    db.commit()
    return None
|
|
|
|
|
|
# ==========================================
|
|
# 推送渠道 CRUD
|
|
# ==========================================
|
|
@router.post(
    "/users/{user_id}/push-endpoints",
    response_model=PushEndpointResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_push_endpoint(
    user_id: int,
    payload: PushEndpointCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Register a new push channel for the user.

    ``channel_type`` is upper-cased and stripped, and ``channel_account`` is
    stripped. An ``IntegrityError`` on commit is rolled back and reported as
    409 (duplicate channel type).
    """
    _ensure_self_access(user_id, current_user)

    normalized_type = payload.channel_type.upper().strip()
    normalized_account = payload.channel_account.strip()
    endpoint = UserPushEndpoint(
        user_id=user_id,
        channel_type=normalized_type,
        channel_account=normalized_account,
        is_active=payload.is_active,
        priority_level=payload.priority_level,
    )
    db.add(endpoint)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This channel type already exists for the user",
        )
    else:
        db.refresh(endpoint)
    return endpoint
|
|
|
|
|
|
@router.patch(
    "/users/{user_id}/push-endpoints/{endpoint_id}",
    response_model=PushEndpointResponse,
)
def update_push_endpoint(
    user_id: int,
    endpoint_id: int,
    payload: PushEndpointUpdate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Partially update one of the user's push channels.

    Only fields that are not ``None`` in ``payload`` are applied.
    ``channel_account`` is stripped of surrounding whitespace, and
    ``channel_type`` is not editable here.

    Raises:
        HTTPException: 403 for another user's resources, 404 if the endpoint
        does not exist for ``user_id``, 409 if the commit violates a database
        integrity constraint.
    """
    _ensure_self_access(user_id, current_user)

    db_obj = (
        db.query(UserPushEndpoint)
        .filter(
            UserPushEndpoint.id == endpoint_id,
            UserPushEndpoint.user_id == user_id,
        )
        .first()
    )
    if not db_obj:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")

    if payload.channel_account is not None:
        db_obj.channel_account = payload.channel_account.strip()
    if payload.is_active is not None:
        db_obj.is_active = payload.is_active
    if payload.priority_level is not None:
        db_obj.priority_level = payload.priority_level

    try:
        db.commit()
    except IntegrityError:
        # Match the other write handlers: roll back so the session stays usable
        # and answer 409 instead of letting the raw IntegrityError escape.
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Push endpoint update conflicts with an existing record",
        )
    db.refresh(db_obj)
    return db_obj
|
|
|
|
|
|
@router.delete(
    "/users/{user_id}/push-endpoints/{endpoint_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_push_endpoint(
    user_id: int,
    endpoint_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Remove one push channel belonging to the user.

    Responds 404 when no endpoint with this id exists for ``user_id``.
    """
    _ensure_self_access(user_id, current_user)

    endpoint = (
        db.query(UserPushEndpoint)
        .filter_by(id=endpoint_id, user_id=user_id)
        .first()
    )
    if not endpoint:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")

    db.delete(endpoint)
    db.commit()
    return None
|