big update

This commit is contained in:
stardrophere
2026-03-11 20:52:58 +08:00
parent 8ed819a580
commit 966bcfbba4
44 changed files with 7124 additions and 650 deletions
+45 -2
View File
@@ -1,5 +1,12 @@
# app/api/dependencies.py
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from app.core.security import decode_access_token
from app.database import SessionLocal
from app.models.models import AppUser
bearer_scheme = HTTPBearer(auto_error=False)
def get_db():
"""
@@ -10,4 +17,40 @@ def get_db():
try:
yield db
finally:
db.close()
db.close()
def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
    db: Session = Depends(get_db),
) -> AppUser:
    """
    Resolve and return the currently logged-in user from the Bearer token.

    Requirements:
    1. The request must carry ``Authorization: Bearer <token>``.
    2. ``decode_access_token`` must accept the token (it signals an invalid
       or expired token by raising ``ValueError``).
    3. The user id carried by the token must exist in the database and the
       stored email must equal the email carried by the token.

    Raises:
        HTTPException: 401 when any requirement above is not met. Every 401
            includes a ``WWW-Authenticate: Bearer`` challenge header.
    """
    # A 401 response is expected to advertise the accepted auth scheme.
    challenge = {"WWW-Authenticate": "Bearer"}

    # bearer_scheme is built with auto_error=False, so a missing header shows
    # up here as None and has to be rejected explicitly.
    if credentials is None or credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication credentials were not provided",
            headers=challenge,
        )

    token = credentials.credentials
    try:
        user_id, email = decode_access_token(token)
    except ValueError as exc:
        # Keep the decoding failure as the cause for server-side debugging.
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers=challenge,
        ) from exc

    user = db.query(AppUser).filter(AppUser.id == user_id).first()
    # The email comparison rejects tokens whose email no longer matches the
    # stored account for that id.
    if not user or user.email != email:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token user",
            headers=challenge,
        )
    return user
+51 -3
View File
@@ -1,5 +1,6 @@
import math
import os
from datetime import timedelta
from datetime import timedelta, timezone
from typing import Tuple
from fastapi import APIRouter, Depends, HTTPException, status
@@ -30,8 +31,18 @@ from app.utils.email_utils import send_html_email
router = APIRouter()
REGISTER_CODE_EXPIRE_MINUTES = int(os.getenv("REGISTER_CODE_EXPIRE_MINUTES", "10"))
LOGIN_CODE_EXPIRE_MINUTES = int(os.getenv("LOGIN_CODE_EXPIRE_MINUTES", "10"))
DEFAULT_REGISTER_CODE_EXPIRE_MINUTES = 10
DEFAULT_LOGIN_CODE_EXPIRE_MINUTES = 10
DEFAULT_CODE_SEND_COOLDOWN_SECONDS = 60
REGISTER_CODE_EXPIRE_MINUTES = int(
os.getenv("REGISTER_CODE_EXPIRE_MINUTES", str(DEFAULT_REGISTER_CODE_EXPIRE_MINUTES))
)
LOGIN_CODE_EXPIRE_MINUTES = int(
os.getenv("LOGIN_CODE_EXPIRE_MINUTES", str(DEFAULT_LOGIN_CODE_EXPIRE_MINUTES))
)
CODE_SEND_COOLDOWN_SECONDS = int(
os.getenv("CODE_SEND_COOLDOWN_SECONDS", str(DEFAULT_CODE_SEND_COOLDOWN_SECONDS))
)
def _normalize_email(email: str) -> str:
@@ -78,6 +89,41 @@ def _create_code_record(
return code_record, code
def _enforce_code_send_cooldown(db: Session, email: str, purpose: VerificationPurpose) -> None:
    """
    Debounce verification-code sending for one email and one purpose.

    Looks up the most recently created code record for ``email``/``purpose``
    and rejects the request while fewer than ``CODE_SEND_COOLDOWN_SECONDS``
    seconds have passed since it was created.

    Raises:
        HTTPException: 429 with a ``Retry-After`` header holding the number of
            seconds (at least 1) the caller still has to wait.
    """
    # A non-positive cooldown switches the check off entirely.
    if CODE_SEND_COOLDOWN_SECONDS <= 0:
        return

    matching_codes = db.query(EmailVerificationCode).filter(
        EmailVerificationCode.email == email,
        EmailVerificationCode.purpose == purpose,
    )
    newest = matching_codes.order_by(EmailVerificationCode.created_at.desc()).first()
    if not newest:
        return

    now = utcnow()
    sent_at = newest.created_at
    # Naive timestamps are treated as UTC so they can be subtracted from `now`.
    if sent_at.tzinfo is None:
        sent_at = sent_at.replace(tzinfo=timezone.utc)

    remaining_seconds = CODE_SEND_COOLDOWN_SECONDS - (now - sent_at).total_seconds()
    if remaining_seconds <= 0:
        return

    # Round up and never advertise less than one second.
    retry_after_seconds = max(1, math.ceil(remaining_seconds))
    raise HTTPException(
        status_code=status.HTTP_429_TOO_MANY_REQUESTS,
        detail=f"Please wait {retry_after_seconds}s before requesting another verification code",
        headers={"Retry-After": str(retry_after_seconds)},
    )
def _build_auth_response(user: AppUser) -> AuthTokenResponse:
token, expires_in = create_access_token(user_id=user.id, email=user.email)
return AuthTokenResponse(
@@ -95,6 +141,7 @@ async def send_register_code(payload: RegisterCodeSendRequest, db: Session = Dep
if existing_user:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email is already registered")
_enforce_code_send_cooldown(db, email, VerificationPurpose.REGISTER)
_invalidate_unused_codes(db, email, VerificationPurpose.REGISTER)
code_record, code = _create_code_record(
db,
@@ -128,6 +175,7 @@ async def send_login_code(payload: LoginCodeSendRequest, db: Session = Depends(g
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Email is not registered")
_enforce_code_send_cooldown(db, email, VerificationPurpose.LOGIN)
_invalidate_unused_codes(db, email, VerificationPurpose.LOGIN)
code_record, code = _create_code_record(
db,
+353
View File
@@ -0,0 +1,353 @@
# 推送设置 API:管理用户的推送时间表和推送渠道
from datetime import time as dt_time
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.api.dependencies import get_current_user, get_db
from app.models.models import AppUser, UserDeliverySchedule, UserPushEndpoint
from app.schemas.delivery_schema import (
DeliveryScheduleCreate,
DeliveryScheduleResponse,
DeliveryScheduleUpdate,
PushEndpointCreate,
PushEndpointResponse,
PushEndpointUpdate,
UserDeliveryConfigResponse,
)
router = APIRouter()
# 两条推送时间之间的最小间隔(分钟)
MIN_SCHEDULE_GAP_MINUTES = 30
def _ensure_self_access(path_user_id: int, current_user: AppUser) -> None:
    """Check that the ``user_id`` taken from the path is the logged-in user.

    Raises:
        HTTPException: 403 when ``path_user_id`` differs from
            ``current_user.id``.
    """
    if path_user_id == current_user.id:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="You can only operate your own resources",
    )
def _parse_time(time_str: str) -> dt_time:
    """Parse a strict ``HH:MM`` string into a ``time`` object.

    Exactly two colon-separated integer components are required. Inputs with
    extra components (for example ``"12:30:45"``) are rejected instead of
    being silently truncated to hours and minutes.

    Raises:
        HTTPException: 400 when the string does not have exactly two
            components, a component is not an integer, or the values are
            outside the range accepted by ``time``.
    """
    try:
        # Unpacking raises ValueError when there are not exactly two parts.
        hour_text, minute_text = time_str.split(":")
        return dt_time(hour=int(hour_text), minute=int(minute_text))
    except ValueError:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Invalid time format, expected HH:MM",
        ) from None
def _format_time(t: dt_time) -> str:
    """Render a ``time`` object as an ``HH:MM`` string."""
    # format() on a time object delegates to strftime with the given pattern.
    return format(t, "%H:%M")
def _time_to_minutes(t: dt_time) -> int:
    """Return the whole minutes elapsed since midnight for ``t`` (seconds are ignored)."""
    hours, minutes = t.hour, t.minute
    return 60 * hours + minutes
def _check_min_gap(
    db: Session,
    user_id: int,
    new_time: dt_time,
    exclude_id: int | None = None,
) -> None:
    """
    Check that ``new_time`` keeps the minimum gap from the user's other delivery times.

    Every schedule row of ``user_id`` is compared (optionally skipping the row
    with id ``exclude_id``, used when that row itself is being updated). The
    gap is measured around the clock, so times on either side of midnight are
    compared by their shorter distance.

    Raises:
        HTTPException: 400 as soon as one existing time is closer than
            ``MIN_SCHEDULE_GAP_MINUTES`` minutes.
    """
    conditions = [UserDeliverySchedule.user_id == user_id]
    if exclude_id is not None:
        conditions.append(UserDeliverySchedule.id != exclude_id)
    existing = db.query(UserDeliverySchedule).filter(*conditions).all()

    minutes_per_day = 1440
    new_minutes = _time_to_minutes(new_time)
    for s in existing:
        gap = abs(new_minutes - _time_to_minutes(s.delivery_time))
        # Take the shorter way around the 24-hour clock.
        if min(gap, minutes_per_day - gap) >= MIN_SCHEDULE_GAP_MINUTES:
            continue
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"推送时间间隔不能少于 {MIN_SCHEDULE_GAP_MINUTES} 分钟,"
            f"与已有的 {_format_time(s.delivery_time)} 冲突",
        )
# ==========================================
# 聚合查询:一次性返回用户全部推送配置
# ==========================================
@router.get(
    "/users/{user_id}/delivery-config",
    response_model=UserDeliveryConfigResponse,
)
def get_delivery_config(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Return the user's complete delivery configuration (schedules plus push endpoints)."""
    _ensure_self_access(user_id, current_user)

    # Schedules are listed by delivery time, earliest first.
    schedule_query = db.query(UserDeliverySchedule).filter(
        UserDeliverySchedule.user_id == user_id
    )
    schedules = schedule_query.order_by(UserDeliverySchedule.delivery_time.asc()).all()

    # Endpoints are listed by ascending priority_level.
    endpoint_query = db.query(UserPushEndpoint).filter(
        UserPushEndpoint.user_id == user_id
    )
    endpoints = endpoint_query.order_by(UserPushEndpoint.priority_level.asc()).all()

    # The response carries the time as a string, so convert it by hand.
    schedule_list = []
    for row in schedules:
        schedule_list.append(
            DeliveryScheduleResponse(
                id=row.id,
                user_id=row.user_id,
                delivery_time=_format_time(row.delivery_time),
                is_active=row.is_active,
                created_at=row.created_at,
            )
        )

    return UserDeliveryConfigResponse(schedules=schedule_list, endpoints=endpoints)
# ==========================================
# 推送时间表 CRUD
# ==========================================
@router.post(
    "/users/{user_id}/delivery-schedules",
    response_model=DeliveryScheduleResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_delivery_schedule(
    user_id: int,
    payload: DeliveryScheduleCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Create one delivery time for the user.

    The time string is parsed and checked against the minimum-gap rule before
    the row is inserted. An integrity error on commit is reported as 409.
    """
    _ensure_self_access(user_id, current_user)

    delivery_time = _parse_time(payload.delivery_time)
    _check_min_gap(db, user_id, delivery_time)

    schedule = UserDeliverySchedule(
        user_id=user_id,
        delivery_time=delivery_time,
        is_active=payload.is_active,
    )
    db.add(schedule)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This delivery time already exists",
        )
    else:
        # Reload the row so generated columns are populated.
        db.refresh(schedule)

    return DeliveryScheduleResponse(
        id=schedule.id,
        user_id=schedule.user_id,
        delivery_time=_format_time(schedule.delivery_time),
        is_active=schedule.is_active,
        created_at=schedule.created_at,
    )
@router.patch(
    "/users/{user_id}/delivery-schedules/{schedule_id}",
    response_model=DeliveryScheduleResponse,
)
def update_delivery_schedule(
    user_id: int,
    schedule_id: int,
    payload: DeliveryScheduleUpdate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Update one delivery time of the user.

    Only payload fields that are not ``None`` are applied. A new time is
    checked against the minimum-gap rule, ignoring the row being edited.
    Responds with 404 when the schedule does not belong to the user and 409
    when the commit hits an integrity error.
    """
    _ensure_self_access(user_id, current_user)

    owned_schedules = db.query(UserDeliverySchedule).filter(
        UserDeliverySchedule.id == schedule_id,
        UserDeliverySchedule.user_id == user_id,
    )
    schedule = owned_schedules.first()
    if not schedule:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")

    requested_time = payload.delivery_time
    if requested_time is not None:
        parsed = _parse_time(requested_time)
        # Exclude this row so it is not compared against its own old time.
        _check_min_gap(db, user_id, parsed, exclude_id=schedule_id)
        schedule.delivery_time = parsed

    requested_active = payload.is_active
    if requested_active is not None:
        schedule.is_active = requested_active

    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This delivery time already exists",
        )
    else:
        db.refresh(schedule)

    return DeliveryScheduleResponse(
        id=schedule.id,
        user_id=schedule.user_id,
        delivery_time=_format_time(schedule.delivery_time),
        is_active=schedule.is_active,
        created_at=schedule.created_at,
    )
@router.delete(
    "/users/{user_id}/delivery-schedules/{schedule_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_delivery_schedule(
    user_id: int,
    schedule_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Delete one delivery time of the user; responds 404 when it is not found for that user."""
    _ensure_self_access(user_id, current_user)

    owned_schedules = db.query(UserDeliverySchedule).filter(
        UserDeliverySchedule.id == schedule_id,
        UserDeliverySchedule.user_id == user_id,
    )
    schedule = owned_schedules.first()
    if not schedule:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")

    db.delete(schedule)
    db.commit()
    # 204 No Content: there is no body to return.
    return None
# ==========================================
# 推送渠道 CRUD
# ==========================================
@router.post(
    "/users/{user_id}/push-endpoints",
    response_model=PushEndpointResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_push_endpoint(
    user_id: int,
    payload: PushEndpointCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Create one push endpoint (channel) for the user.

    The channel type is upper-cased and stripped, and the account is stripped,
    before the row is stored. An integrity error on commit is reported as 409.
    """
    _ensure_self_access(user_id, current_user)

    channel_type = payload.channel_type.upper().strip()
    channel_account = payload.channel_account.strip()
    endpoint = UserPushEndpoint(
        user_id=user_id,
        channel_type=channel_type,
        channel_account=channel_account,
        is_active=payload.is_active,
        priority_level=payload.priority_level,
    )
    db.add(endpoint)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This channel type already exists for the user",
        )
    else:
        db.refresh(endpoint)
    return endpoint
@router.patch(
    "/users/{user_id}/push-endpoints/{endpoint_id}",
    response_model=PushEndpointResponse,
)
def update_push_endpoint(
    user_id: int,
    endpoint_id: int,
    payload: PushEndpointUpdate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Update the configuration of one push endpoint.

    Only payload fields that are not ``None`` are applied; the account string
    is stripped before being stored. Responds 404 when the endpoint is not
    found for that user.
    """
    _ensure_self_access(user_id, current_user)

    owned_endpoints = db.query(UserPushEndpoint).filter(
        UserPushEndpoint.id == endpoint_id,
        UserPushEndpoint.user_id == user_id,
    )
    endpoint = owned_endpoints.first()
    if not endpoint:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")

    new_account = payload.channel_account
    changes = {
        "channel_account": new_account.strip() if new_account is not None else None,
        "is_active": payload.is_active,
        "priority_level": payload.priority_level,
    }
    # None means "field not supplied", so those attributes stay untouched.
    for attribute, value in changes.items():
        if value is not None:
            setattr(endpoint, attribute, value)

    db.commit()
    db.refresh(endpoint)
    return endpoint
@router.delete(
    "/users/{user_id}/push-endpoints/{endpoint_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_push_endpoint(
    user_id: int,
    endpoint_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Delete one push endpoint of the user; responds 404 when it is not found for that user."""
    _ensure_self_access(user_id, current_user)

    owned_endpoints = db.query(UserPushEndpoint).filter(
        UserPushEndpoint.id == endpoint_id,
        UserPushEndpoint.user_id == user_id,
    )
    endpoint = owned_endpoints.first()
    if not endpoint:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")

    db.delete(endpoint)
    db.commit()
    # 204 No Content: there is no body to return.
    return None
+192 -46
View File
@@ -1,69 +1,215 @@
# app/api/endpoints/events.py
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from datetime import timedelta
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.api.dependencies import get_db
from app.models.models import UnifiedEvent, TrendingEvent, InfoSource, RankingLog, utcnow
# 导入你上传的 Schema
from app.schemas.event_schema import UnifiedEventResponse, PlatformTrendResponse
from app.models.models import (
ExtractedTopic,
InfoSource,
RankingLog,
TargetType,
TrendingEvent,
UnifiedEvent,
utcnow,
)
from app.schemas.event_schema import (
PaginatedUnifiedEventResponse,
PlatformTrendResponse,
UnifiedEventResponse,
)
router = APIRouter()
# 排名轨迹最多返回多少个点,避免长时间跨度下数据过大
MAX_RANKING_POINTS = 30
@router.get("/unified", response_model=List[UnifiedEventResponse])
@router.get("/unified", response_model=PaginatedUnifiedEventResponse)
def list_unified_events(
min_hot: int = Query(5, description="热度过滤阈值"),
hours: int = Query(24, description="查询过去 X 小时的数据"),
db: Session = Depends(get_db)
min_hot: int = Query(5, ge=0, description="热度阈值,仅返回 hot_score >= 此值的事件"),
hours: int = Query(24, ge=1, le=720, description="查询最近多少小时的数据"),
skip: int = Query(0, ge=0, description="分页偏移量"),
limit: int = Query(10, ge=1, le=50, description="每页返回条数"),
db: Session = Depends(get_db),
):
"""
获取聚合大事件列表,完全适配前端 template.html 所需的数据结构
"""
# 计算时间水位线
"""分页返回统一事件,附带各平台热搜、排名轨迹和标签。"""
time_limit = utcnow() - timedelta(hours=hours)
# 1. 查询大事件(按热度降序,且满足时间范围)
events = db.query(UnifiedEvent).filter(
# 先查总数,用于前端判断是否还有更多
base_query = db.query(UnifiedEvent).filter(
UnifiedEvent.hot_score >= min_hot,
UnifiedEvent.created_at >= time_limit
).order_by(UnifiedEvent.hot_score.desc()).all()
UnifiedEvent.created_at >= time_limit,
)
total = base_query.count()
results = []
# 分页查询
events = (
base_query
.order_by(UnifiedEvent.hot_score.desc())
.offset(skip)
.limit(limit)
.all()
)
if not events:
return PaginatedUnifiedEventResponse(total=total, has_more=False, data=[])
event_ids = [ev.id for ev in events]
# 批量查询所有相关的热搜条目(避免 N+1)
trend_rows = (
db.query(TrendingEvent, InfoSource.source_name)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
.filter(TrendingEvent.unified_event_id.in_(event_ids))
.all()
)
# 按 unified_event_id 分组
trend_map: dict[int, list[tuple]] = {}
trend_ids: list[int] = []
for trend, source_name in trend_rows:
trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
trend_ids.append(trend.id)
# 批量查询排名日志(避免逐条查询)
ranking_map: dict[int, list[int]] = {}
if trend_ids:
ranking_rows = (
db.query(
RankingLog.event_id,
RankingLog.ranking_position,
)
.filter(
RankingLog.event_id.in_(trend_ids),
RankingLog.observed_at >= time_limit,
)
.order_by(RankingLog.event_id, RankingLog.observed_at.asc())
.all()
)
for event_id, position in ranking_rows:
ranking_map.setdefault(event_id, []).append(position)
# 批量查询标签
tag_map: dict[int, list[str]] = {}
tag_rows = (
db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
.filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id.in_(event_ids),
)
.order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
.all()
)
for target_id, keyword in tag_rows:
tag_map.setdefault(target_id, []).append(keyword)
# 组装响应
results: list[UnifiedEventResponse] = []
for ev in events:
# 2. 联表查询:获取该大事件下关联的所有平台及其具体热搜信息
trends = db.query(TrendingEvent, InfoSource.source_name).join(
InfoSource, TrendingEvent.source_id == InfoSource.id
).filter(TrendingEvent.unified_event_id == ev.id).all()
platform_list: list[PlatformTrendResponse] = []
for trend, source_name in trend_map.get(ev.id, []):
history = ranking_map.get(trend.id, [])
# 截取尾部,只保留最近的点
if len(history) > MAX_RANKING_POINTS:
history = history[-MAX_RANKING_POINTS:]
platform_list = []
for trend, s_name in trends:
# 3. 获取排名历史轨迹 (用于前端渲染)
# 这里的排序顺序 asc 保证了数组从旧到新
logs = db.query(RankingLog.ranking_position).filter(
RankingLog.event_id == trend.id,
RankingLog.observed_at >= time_limit
).order_by(RankingLog.observed_at.asc()).all()
platform_list.append(
PlatformTrendResponse(
source_id=trend.source_id,
platform_name=source_name,
headline=trend.current_headline,
url=trend.event_url,
current_ranking=trend.current_ranking,
ranking_history=history,
)
)
# 组装符合 PlatformTrendResponse 结构的字典
platform_list.append(PlatformTrendResponse(
results.append(
UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
platforms=platform_list,
tags=tag_map.get(ev.id, []),
)
)
has_more = (skip + limit) < total
return PaginatedUnifiedEventResponse(total=total, has_more=has_more, data=results)
@router.get("/unified/{event_id}", response_model=UnifiedEventResponse)
def get_unified_event(
event_id: int,
db: Session = Depends(get_db),
):
"""按 ID 查询单个统一事件,用于推荐跳转时的聚光灯展示。"""
ev = db.query(UnifiedEvent).filter(UnifiedEvent.id == event_id).first()
if not ev:
raise HTTPException(status_code=404, detail="Event not found")
time_limit = utcnow() - timedelta(hours=720)
trend_rows = (
db.query(TrendingEvent, InfoSource.source_name)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
.filter(TrendingEvent.unified_event_id == event_id)
.all()
)
trend_ids = [t.id for t, _ in trend_rows]
ranking_map: dict[int, list[int]] = {}
if trend_ids:
ranking_rows = (
db.query(RankingLog.event_id, RankingLog.ranking_position)
.filter(
RankingLog.event_id.in_(trend_ids),
RankingLog.observed_at >= time_limit,
)
.order_by(RankingLog.event_id, RankingLog.observed_at.asc())
.all()
)
for eid, pos in ranking_rows:
ranking_map.setdefault(eid, []).append(pos)
tag_rows = (
db.query(ExtractedTopic.topic_keyword)
.filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id == event_id,
)
.order_by(ExtractedTopic.relevance_score.desc())
.all()
)
tags = [row[0] for row in tag_rows]
platform_list: list[PlatformTrendResponse] = []
for trend, source_name in trend_rows:
history = ranking_map.get(trend.id, [])
if len(history) > MAX_RANKING_POINTS:
history = history[-MAX_RANKING_POINTS:]
platform_list.append(
PlatformTrendResponse(
source_id=trend.source_id,
platform_name=s_name,
platform_name=source_name,
headline=trend.current_headline,
url=trend.event_url,
current_ranking=trend.current_ranking,
ranking_history=[log[0] for log in logs]
))
ranking_history=history,
)
)
# 4. 组装符合 UnifiedEventResponse 结构的字典
results.append(UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
platforms=platform_list
))
return results
return UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
platforms=platform_list,
tags=tags,
)
+158
View File
@@ -0,0 +1,158 @@
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.api.dependencies import get_current_user, get_db
from app.models.models import AppUser, UserTopicPreference
from app.schemas.preference_schema import (
MatchedEventResponse,
UserPreferenceRecommendationResponse,
UserTopicPreferenceCreate,
UserTopicPreferenceResponse,
)
from app.services.matching_service import recommend_events_for_user
router = APIRouter()
def _ensure_self_access(path_user_id: int, current_user: AppUser) -> None:
    """Check that the ``user_id`` taken from the path is the logged-in user.

    Raises:
        HTTPException: 403 when ``path_user_id`` differs from
            ``current_user.id``.
    """
    if path_user_id == current_user.id:
        return
    raise HTTPException(
        status_code=status.HTTP_403_FORBIDDEN,
        detail="You can only operate your own resources",
    )
@router.get(
    "/users/{user_id}/preferences",
    response_model=List[UserTopicPreferenceResponse],
)
def list_user_preferences(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Return the interest keywords the user has configured, newest first."""
    _ensure_self_access(user_id, current_user)

    own_preferences = db.query(UserTopicPreference).filter(
        UserTopicPreference.user_id == user_id
    )
    return own_preferences.order_by(UserTopicPreference.created_at.desc()).all()
@router.post(
    "/users/{user_id}/preferences",
    response_model=UserTopicPreferenceResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_user_preference(
    user_id: int,
    payload: UserTopicPreferenceCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Add one interest keyword for the user.

    The keyword is stripped first; a keyword that is empty after stripping is
    rejected with 400, and an integrity error on commit is reported as 409.
    """
    _ensure_self_access(user_id, current_user)

    keyword = payload.interested_keyword.strip()
    if keyword == "":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Keyword cannot be empty")

    preference = UserTopicPreference(user_id=user_id, interested_keyword=keyword)
    db.add(preference)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Preference keyword already exists for this user",
        )
    else:
        db.refresh(preference)
    return preference
@router.delete(
    "/users/{user_id}/preferences/{preference_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_user_preference(
    user_id: int,
    preference_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Delete one interest keyword of the user; responds 404 when it is not found for that user."""
    _ensure_self_access(user_id, current_user)

    own_preferences = db.query(UserTopicPreference).filter(
        UserTopicPreference.id == preference_id,
        UserTopicPreference.user_id == user_id,
    )
    target = own_preferences.first()
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Preference not found")

    db.delete(target)
    db.commit()
    # 204 No Content: there is no body to return.
    return None
@router.get(
    "/users/{user_id}/recommended-events",
    response_model=UserPreferenceRecommendationResponse,
)
def recommend_events(
    user_id: int,
    min_hot: int = Query(3, ge=1, description="最小热度阈值"),
    hours: int = Query(72, ge=1, le=24 * 30, description="仅匹配最近多少小时的事件"),
    limit: int = Query(20, ge=1, le=50, description="最多返回多少条推荐"),
    semantic_threshold: float = Query(0.78, ge=0.0, le=1.0, description="语义匹配相似度阈值"),
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Recommend events from the user's interest keywords (exact match plus semantic match).

    The matching itself is delegated to ``recommend_events_for_user``; this
    endpoint only converts each match into the response schema.
    """
    _ensure_self_access(user_id, current_user)

    matches = recommend_events_for_user(
        db,
        user_id=user_id,
        min_hot=min_hot,
        hours=hours,
        limit=limit,
        semantic_threshold=semantic_threshold,
    )

    result_data: list[MatchedEventResponse] = [
        MatchedEventResponse(
            event_id=match.event.id,
            unified_title=match.event.unified_title,
            summary=match.event.ai_comprehensive_summary,
            hot_score=match.event.hot_score,
            created_at=match.event.created_at,
            tags=match.tags,
            match_score=match.match_score,
            exact_hits=match.exact_hits,
            semantic_hits=match.semantic_hits,
        )
        for match in matches
    ]

    return UserPreferenceRecommendationResponse(
        user_id=user_id,
        total=len(result_data),
        data=result_data,
    )
+75
View File
@@ -0,0 +1,75 @@
# 公关修改追踪 API:查询热搜标题被偷偷修改的历史记录
from datetime import timedelta
from typing import List, Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.api.dependencies import get_db
from app.models.models import HeadlineRevision, InfoSource, TrendingEvent, utcnow
from pydantic import BaseModel, ConfigDict
from datetime import datetime
router = APIRouter()
class HeadlineRevisionResponse(BaseModel):
    """Response body for one headline revision record."""
    # Primary key of the revision row.
    id: int
    # Id of the trending event the revision belongs to (joined against TrendingEvent.id in the listing endpoint).
    event_id: int
    # Headline text before the change.
    previous_headline: str
    # Headline text after the change.
    revised_headline: str
    # Name of the information source; the listing endpoint fills it from InfoSource.source_name.
    source_name: Optional[str] = None
    # Icon key for the platform; the listing endpoint derives it from source_name.
    platform_icon: Optional[str] = None
    # Timestamp of the revision record.
    created_at: datetime
    # Allow the model to be populated from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)
@router.get("/headline-revisions", response_model=List[HeadlineRevisionResponse])
def list_headline_revisions(
    hours: int = Query(48, ge=1, le=720, description="查询最近多少小时内的修改记录"),
    limit: int = Query(50, ge=1, le=500, description="最多返回条数"),
    db: Session = Depends(get_db),
):
    """
    Return the most recent headline revision records, newest first.

    Used for PR monitoring: it shows which platforms quietly changed a
    trending headline within the last ``hours`` hours.
    """
    cutoff = utcnow() - timedelta(hours=hours)

    # Each revision is joined through its trending event to the source name.
    revision_query = (
        db.query(HeadlineRevision, InfoSource.source_name)
        .join(TrendingEvent, HeadlineRevision.event_id == TrendingEvent.id)
        .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
        .filter(HeadlineRevision.created_at >= cutoff)
    )
    rows = revision_query.order_by(HeadlineRevision.created_at.desc()).limit(limit).all()

    # Simple mapping from platform name to icon key.
    icon_map = {
        "微博热搜": "weibo",
        "知乎热榜": "zhihu",
        "百度热搜": "baidu",
        "今日头条": "toutiao",
        "抖音热榜": "douyin",
        "B站热搜": "bilibili",
    }

    results: list[HeadlineRevisionResponse] = [
        HeadlineRevisionResponse(
            id=revision.id,
            event_id=revision.event_id,
            previous_headline=revision.previous_headline,
            revised_headline=revision.revised_headline,
            source_name=source_name,
            # Unknown platforms fall back to a generic icon.
            platform_icon=icon_map.get(source_name, "newspaper"),
            created_at=revision.created_at,
        )
        for revision, source_name in rows
    ]
    return results
+65
View File
@@ -0,0 +1,65 @@
# 系统状态监控 API:返回爬虫集群运行概况
from datetime import datetime, timedelta
from typing import Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.api.dependencies import get_db
from app.models.models import DataSyncTask, InfoSource, TaskStatus, utcnow
router = APIRouter()
class SystemStatsResponse(BaseModel):
    """Summary of the system's running status."""
    # Number of InfoSource rows whose is_enabled flag is true.
    active_sources: int
    # Total number of InfoSource rows.
    total_sources: int
    # Sum of items_fetched over the sync tasks created today.
    items_today: int
    # Number of today's sync tasks with status SUCCESS.
    success_tasks_today: int
    # Number of today's sync tasks with status ERROR.
    error_tasks_today: int
    # created_at of the most recent SUCCESS task, or None when there is none.
    last_sync_at: Optional[datetime] = None
@router.get("/system/stats", response_model=SystemStatsResponse)
def get_system_stats(db: Session = Depends(get_db)):
    """Return today's running status of the crawler cluster.

    "Today" starts at midnight of the timestamp returned by ``utcnow()``.
    The response holds source counts, today's fetched-item total, today's
    success/error task counts and the creation time of the latest successful
    task (``None`` when no task has succeeded yet).
    """
    today_start = utcnow().replace(hour=0, minute=0, second=0, microsecond=0)

    # Source statistics.
    total_sources = db.query(func.count(InfoSource.id)).scalar() or 0
    active_sources = (
        db.query(func.count(InfoSource.id))
        .filter(InfoSource.is_enabled.is_(True))
        .scalar() or 0
    )

    # Today's task statistics, gathered in a single pass.
    today_tasks = (
        db.query(DataSyncTask)
        .filter(DataSyncTask.created_at >= today_start)
        .all()
    )
    items_today = 0
    success_count = 0
    error_count = 0
    for task in today_tasks:
        # A task without a recorded item count contributes nothing instead of
        # breaking the sum with a None value.
        items_today += task.items_fetched or 0
        if task.task_status == TaskStatus.SUCCESS:
            success_count += 1
        elif task.task_status == TaskStatus.ERROR:
            error_count += 1

    # Time of the last successful sync (not limited to today).
    last_task = (
        db.query(DataSyncTask)
        .filter(DataSyncTask.task_status == TaskStatus.SUCCESS)
        .order_by(DataSyncTask.created_at.desc())
        .first()
    )

    return SystemStatsResponse(
        active_sources=active_sources,
        total_sources=total_sources,
        items_today=items_today,
        success_tasks_today=success_count,
        error_tasks_today=error_count,
        last_sync_at=last_task.created_at if last_task else None,
    )
+15 -1
View File
@@ -1,6 +1,6 @@
# app/api/router.py
from fastapi import APIRouter
from app.api.endpoints import auth, sources, events
from app.api.endpoints import auth, delivery, events, preferences, revisions, sources, stats
api_router = APIRouter()
@@ -9,4 +9,18 @@ api_router.include_router(sources.router, prefix="/sources", tags=["信息源管
# 注册大事件相关的路由
api_router.include_router(events.router, prefix="/events", tags=["Unified Events"])
# 认证
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
# 用户偏好(关键词订阅)
api_router.include_router(preferences.router, tags=["User Preferences"])
# 推送设置(时间表 + 渠道)
api_router.include_router(delivery.router, tags=["Delivery Settings"])
# 公关修改追踪
api_router.include_router(revisions.router, prefix="/events", tags=["Headline Revisions"])
# 系统状态监控
api_router.include_router(stats.router, tags=["System Stats"])