mirror of
https://github.com/stardrophere/InsightRadar.git
synced 2026-06-05 23:07:51 +08:00
big update
This commit is contained in:
@@ -1,5 +1,12 @@
|
||||
# app/api/dependencies.py
|
||||
from fastapi import Depends, HTTPException, status
|
||||
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.core.security import decode_access_token
|
||||
from app.database import SessionLocal
|
||||
from app.models.models import AppUser
|
||||
|
||||
# Bearer-token extractor used as a dependency by get_current_user.
# auto_error=False: get_current_user checks for `credentials is None` itself
# and raises its own 401, so the scheme is presumably meant not to raise on a
# missing header -- TODO confirm against the FastAPI HTTPBearer documentation.
bearer_scheme = HTTPBearer(auto_error=False)
|
||||
|
||||
def get_db():
|
||||
"""
|
||||
@@ -10,4 +17,40 @@ def get_db():
|
||||
try:
|
||||
yield db
|
||||
finally:
|
||||
db.close()
|
||||
db.close()
|
||||
|
||||
|
||||
def get_current_user(
    credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
    db: Session = Depends(get_db),
) -> AppUser:
    """Resolve the logged-in user from the request's Bearer token.

    The request is accepted only when all of the following hold:

    1. An ``Authorization: Bearer <token>`` header is present.
    2. ``decode_access_token`` accepts the token; it signals rejection by
       raising ``ValueError`` (the original note describes this as "signature
       verified and not expired").
    3. A user with the decoded id exists and its stored email equals the
       email carried by the token.

    Returns:
        The matching ``AppUser`` row.

    Raises:
        HTTPException: 401 for every failure above. Each 401 carries a
            ``WWW-Authenticate: Bearer`` challenge header, as HTTP expects
            for 401 responses.
    """
    # Challenge header shared by every 401 raised below.
    challenge = {"WWW-Authenticate": "Bearer"}

    if credentials is None or credentials.scheme.lower() != "bearer":
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Authentication credentials were not provided",
            headers=challenge,
        )

    try:
        user_id, email = decode_access_token(credentials.credentials)
    except ValueError as exc:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid or expired token",
            headers=challenge,
        ) from exc

    user = db.query(AppUser).filter(AppUser.id == user_id).first()
    # Reject the token when the user is gone or the stored email no longer
    # matches the one embedded in the token.
    if not user or user.email != email:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail="Invalid token user",
            headers=challenge,
        )

    return user
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
import math
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from datetime import timedelta, timezone
|
||||
from typing import Tuple
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
@@ -30,8 +31,18 @@ from app.utils.email_utils import send_html_email
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
REGISTER_CODE_EXPIRE_MINUTES = int(os.getenv("REGISTER_CODE_EXPIRE_MINUTES", "10"))
|
||||
LOGIN_CODE_EXPIRE_MINUTES = int(os.getenv("LOGIN_CODE_EXPIRE_MINUTES", "10"))
|
||||
# Built-in defaults, used when the matching environment variable is not set.
DEFAULT_REGISTER_CODE_EXPIRE_MINUTES = 10
|
||||
DEFAULT_LOGIN_CODE_EXPIRE_MINUTES = 10
|
||||
DEFAULT_CODE_SEND_COOLDOWN_SECONDS = 60
|
||||
# Effective settings: each is read from the environment variable of the same
# name and converted with int(); a non-integer value raises ValueError when
# this module is imported.
REGISTER_CODE_EXPIRE_MINUTES = int(
|
||||
os.getenv("REGISTER_CODE_EXPIRE_MINUTES", str(DEFAULT_REGISTER_CODE_EXPIRE_MINUTES))
|
||||
)
|
||||
LOGIN_CODE_EXPIRE_MINUTES = int(
|
||||
os.getenv("LOGIN_CODE_EXPIRE_MINUTES", str(DEFAULT_LOGIN_CODE_EXPIRE_MINUTES))
|
||||
)
|
||||
# Minimum number of seconds between two code sends for the same email and
# purpose; values <= 0 disable the check in _enforce_code_send_cooldown.
CODE_SEND_COOLDOWN_SECONDS = int(
|
||||
os.getenv("CODE_SEND_COOLDOWN_SECONDS", str(DEFAULT_CODE_SEND_COOLDOWN_SECONDS))
|
||||
)
|
||||
|
||||
|
||||
def _normalize_email(email: str) -> str:
|
||||
@@ -78,6 +89,41 @@ def _create_code_record(
|
||||
return code_record, code
|
||||
|
||||
|
||||
def _enforce_code_send_cooldown(db: Session, email: str, purpose: VerificationPurpose) -> None:
    """Throttle how often a verification code is sent to one email for one purpose.

    Looks up the most recently created ``EmailVerificationCode`` for the
    email/purpose pair and raises while that record is younger than
    ``CODE_SEND_COOLDOWN_SECONDS``. A non-positive cooldown disables the
    check, and so does the absence of any earlier record.

    Raises:
        HTTPException: 429 with a ``Retry-After`` header (whole seconds,
            never less than 1) while the cooldown is still running.
    """
    cooldown = CODE_SEND_COOLDOWN_SECONDS
    if cooldown <= 0:
        return

    newest = (
        db.query(EmailVerificationCode)
        .filter(
            EmailVerificationCode.email == email,
            EmailVerificationCode.purpose == purpose,
        )
        .order_by(EmailVerificationCode.created_at.desc())
        .first()
    )
    if not newest:
        return

    current = utcnow()
    sent_at = newest.created_at
    if sent_at.tzinfo is None:
        # A naive timestamp is treated as UTC before it is subtracted.
        sent_at = sent_at.replace(tzinfo=timezone.utc)

    age_seconds = (current - sent_at).total_seconds()
    if age_seconds < cooldown:
        wait_seconds = max(1, math.ceil(cooldown - age_seconds))
        raise HTTPException(
            status_code=status.HTTP_429_TOO_MANY_REQUESTS,
            detail=f"Please wait {wait_seconds}s before requesting another verification code",
            headers={"Retry-After": str(wait_seconds)},
        )
|
||||
|
||||
|
||||
def _build_auth_response(user: AppUser) -> AuthTokenResponse:
|
||||
token, expires_in = create_access_token(user_id=user.id, email=user.email)
|
||||
return AuthTokenResponse(
|
||||
@@ -95,6 +141,7 @@ async def send_register_code(payload: RegisterCodeSendRequest, db: Session = Dep
|
||||
if existing_user:
|
||||
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email is already registered")
|
||||
|
||||
_enforce_code_send_cooldown(db, email, VerificationPurpose.REGISTER)
|
||||
_invalidate_unused_codes(db, email, VerificationPurpose.REGISTER)
|
||||
code_record, code = _create_code_record(
|
||||
db,
|
||||
@@ -128,6 +175,7 @@ async def send_login_code(payload: LoginCodeSendRequest, db: Session = Depends(g
|
||||
if not user:
|
||||
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Email is not registered")
|
||||
|
||||
_enforce_code_send_cooldown(db, email, VerificationPurpose.LOGIN)
|
||||
_invalidate_unused_codes(db, email, VerificationPurpose.LOGIN)
|
||||
code_record, code = _create_code_record(
|
||||
db,
|
||||
|
||||
@@ -0,0 +1,353 @@
|
||||
# 推送设置 API:管理用户的推送时间表和推送渠道
|
||||
from datetime import time as dt_time
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import get_current_user, get_db
|
||||
from app.models.models import AppUser, UserDeliverySchedule, UserPushEndpoint
|
||||
from app.schemas.delivery_schema import (
|
||||
DeliveryScheduleCreate,
|
||||
DeliveryScheduleResponse,
|
||||
DeliveryScheduleUpdate,
|
||||
PushEndpointCreate,
|
||||
PushEndpointResponse,
|
||||
PushEndpointUpdate,
|
||||
UserDeliveryConfigResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Minimum gap, in minutes, required between any two delivery times of a user.
|
||||
MIN_SCHEDULE_GAP_MINUTES = 30
|
||||
|
||||
|
||||
def _ensure_self_access(path_user_id: int, current_user: AppUser) -> None:
|
||||
"""校验路径 user_id 是否为当前登录用户本人。"""
|
||||
if path_user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only operate your own resources",
|
||||
)
|
||||
|
||||
|
||||
def _parse_time(time_str: str) -> dt_time:
|
||||
"""将 HH:MM 字符串解析为 time 对象"""
|
||||
try:
|
||||
parts = time_str.split(":")
|
||||
return dt_time(hour=int(parts[0]), minute=int(parts[1]))
|
||||
except (ValueError, IndexError):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail="Invalid time format, expected HH:MM",
|
||||
)
|
||||
|
||||
|
||||
def _format_time(t: dt_time) -> str:
|
||||
"""将 time 对象格式化为 HH:MM 字符串"""
|
||||
return t.strftime("%H:%M")
|
||||
|
||||
|
||||
def _time_to_minutes(t: dt_time) -> int:
|
||||
return t.hour * 60 + t.minute
|
||||
|
||||
|
||||
def _check_min_gap(
    db: Session,
    user_id: int,
    new_time: dt_time,
    exclude_id: int | None = None,
) -> None:
    """Verify that ``new_time`` keeps the minimum gap to the user's other schedules.

    Every ``UserDeliverySchedule`` row of ``user_id`` is compared, except the
    row whose id equals ``exclude_id`` when one is given. The distance is
    measured on a 24-hour circle, so 23:50 and 00:10 are 20 minutes apart.

    Raises:
        HTTPException: 400 when any compared schedule is closer than
            ``MIN_SCHEDULE_GAP_MINUTES`` minutes.
    """
    minutes_per_day = 1440

    candidates = db.query(UserDeliverySchedule).filter(
        UserDeliverySchedule.user_id == user_id
    )
    if exclude_id is not None:
        candidates = candidates.filter(UserDeliverySchedule.id != exclude_id)

    schedules = candidates.all()
    target = _time_to_minutes(new_time)

    for schedule in schedules:
        straight = abs(target - _time_to_minutes(schedule.delivery_time))
        # Take the shorter way around the clock face.
        wrapped = minutes_per_day - straight
        gap = straight if straight <= wrapped else wrapped
        if gap >= MIN_SCHEDULE_GAP_MINUTES:
            continue
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"推送时间间隔不能少于 {MIN_SCHEDULE_GAP_MINUTES} 分钟,"
                   f"与已有的 {_format_time(schedule.delivery_time)} 冲突",
        )
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 聚合查询:一次性返回用户全部推送配置
|
||||
# ==========================================
|
||||
@router.get(
    "/users/{user_id}/delivery-config",
    response_model=UserDeliveryConfigResponse,
)
def get_delivery_config(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Return the caller's complete delivery configuration in one response.

    Schedules are ordered by ``delivery_time`` ascending and endpoints by
    ``priority_level`` ascending. ``_ensure_self_access`` rejects requests
    for another user's id.
    """
    _ensure_self_access(user_id, current_user)

    schedule_rows = (
        db.query(UserDeliverySchedule)
        .filter(UserDeliverySchedule.user_id == user_id)
        .order_by(UserDeliverySchedule.delivery_time.asc())
        .all()
    )
    endpoint_rows = (
        db.query(UserPushEndpoint)
        .filter(UserPushEndpoint.user_id == user_id)
        .order_by(UserPushEndpoint.priority_level.asc())
        .all()
    )

    # The response carries delivery_time as an HH:MM string, so each schedule
    # row is converted by hand instead of being passed through as-is.
    schedule_items = []
    for row in schedule_rows:
        schedule_items.append(
            DeliveryScheduleResponse(
                id=row.id,
                user_id=row.user_id,
                delivery_time=_format_time(row.delivery_time),
                is_active=row.is_active,
                created_at=row.created_at,
            )
        )

    return UserDeliveryConfigResponse(schedules=schedule_items, endpoints=endpoint_rows)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 推送时间表 CRUD
|
||||
# ==========================================
|
||||
@router.post(
    "/users/{user_id}/delivery-schedules",
    response_model=DeliveryScheduleResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_delivery_schedule(
    user_id: int,
    payload: DeliveryScheduleCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Create one delivery time for the caller.

    The requested time is parsed by ``_parse_time`` and checked by
    ``_check_min_gap`` before the row is inserted.

    Raises:
        HTTPException: 409 when the commit fails with ``IntegrityError``;
            the helpers called here raise their own 400/403 errors.
    """
    _ensure_self_access(user_id, current_user)

    delivery_time = _parse_time(payload.delivery_time)
    _check_min_gap(db, user_id, delivery_time)

    schedule = UserDeliverySchedule(
        user_id=user_id,
        delivery_time=delivery_time,
        is_active=payload.is_active,
    )
    db.add(schedule)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This delivery time already exists",
        )
    else:
        # Reload the row so database-populated columns are available.
        db.refresh(schedule)

    formatted_time = _format_time(schedule.delivery_time)
    return DeliveryScheduleResponse(
        id=schedule.id,
        user_id=schedule.user_id,
        delivery_time=formatted_time,
        is_active=schedule.is_active,
        created_at=schedule.created_at,
    )
|
||||
|
||||
|
||||
@router.patch(
    "/users/{user_id}/delivery-schedules/{schedule_id}",
    response_model=DeliveryScheduleResponse,
)
def update_delivery_schedule(
    user_id: int,
    schedule_id: int,
    payload: DeliveryScheduleUpdate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Partially update one of the caller's delivery times.

    Only payload fields that are not ``None`` are applied. A new delivery
    time is gap-checked against the user's other schedules, excluding the
    row being edited.

    Raises:
        HTTPException: 404 when the schedule does not exist for this user,
            409 when the commit fails with ``IntegrityError``.
    """
    _ensure_self_access(user_id, current_user)

    schedule = (
        db.query(UserDeliverySchedule)
        .filter(
            UserDeliverySchedule.id == schedule_id,
            UserDeliverySchedule.user_id == user_id,
        )
        .first()
    )
    if not schedule:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")

    requested_time = payload.delivery_time
    if requested_time is not None:
        parsed = _parse_time(requested_time)
        _check_min_gap(db, user_id, parsed, exclude_id=schedule_id)
        schedule.delivery_time = parsed

    requested_active = payload.is_active
    if requested_active is not None:
        schedule.is_active = requested_active

    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This delivery time already exists",
        )
    else:
        db.refresh(schedule)

    formatted_time = _format_time(schedule.delivery_time)
    return DeliveryScheduleResponse(
        id=schedule.id,
        user_id=schedule.user_id,
        delivery_time=formatted_time,
        is_active=schedule.is_active,
        created_at=schedule.created_at,
    )
|
||||
|
||||
|
||||
@router.delete(
    "/users/{user_id}/delivery-schedules/{schedule_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_delivery_schedule(
    user_id: int,
    schedule_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Delete one of the caller's delivery times.

    Raises:
        HTTPException: 404 when no schedule with this id belongs to the user.
    """
    _ensure_self_access(user_id, current_user)

    # Filtering on user_id as well keeps one user from deleting another's row.
    owned_by_user = db.query(UserDeliverySchedule).filter(
        UserDeliverySchedule.id == schedule_id,
        UserDeliverySchedule.user_id == user_id,
    )
    schedule = owned_by_user.first()
    if not schedule:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")

    db.delete(schedule)
    db.commit()
    return None
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 推送渠道 CRUD
|
||||
# ==========================================
|
||||
@router.post(
    "/users/{user_id}/push-endpoints",
    response_model=PushEndpointResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_push_endpoint(
    user_id: int,
    payload: PushEndpointCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Create a push endpoint (delivery channel) for the caller.

    ``channel_type`` is stored upper-cased and stripped; ``channel_account``
    is stored stripped.

    Raises:
        HTTPException: 409 when the commit fails with ``IntegrityError``.
    """
    _ensure_self_access(user_id, current_user)

    channel_type = payload.channel_type.upper().strip()
    channel_account = payload.channel_account.strip()

    endpoint = UserPushEndpoint(
        user_id=user_id,
        channel_type=channel_type,
        channel_account=channel_account,
        is_active=payload.is_active,
        priority_level=payload.priority_level,
    )
    db.add(endpoint)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="This channel type already exists for the user",
        )
    else:
        db.refresh(endpoint)

    return endpoint
|
||||
|
||||
|
||||
@router.patch(
    "/users/{user_id}/push-endpoints/{endpoint_id}",
    response_model=PushEndpointResponse,
)
def update_push_endpoint(
    user_id: int,
    endpoint_id: int,
    payload: PushEndpointUpdate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Partially update one of the caller's push endpoints.

    Only payload fields that are not ``None`` are applied;
    ``channel_account`` is stripped before it is stored.

    Raises:
        HTTPException: 404 when no endpoint with this id belongs to the user.
    """
    _ensure_self_access(user_id, current_user)

    endpoint = (
        db.query(UserPushEndpoint)
        .filter(
            UserPushEndpoint.id == endpoint_id,
            UserPushEndpoint.user_id == user_id,
        )
        .first()
    )
    if not endpoint:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")

    new_account = payload.channel_account
    if new_account is not None:
        endpoint.channel_account = new_account.strip()

    # These two fields are copied verbatim whenever the payload provides them.
    for field_name in ("is_active", "priority_level"):
        new_value = getattr(payload, field_name)
        if new_value is not None:
            setattr(endpoint, field_name, new_value)

    db.commit()
    db.refresh(endpoint)
    return endpoint
|
||||
|
||||
|
||||
@router.delete(
    "/users/{user_id}/push-endpoints/{endpoint_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_push_endpoint(
    user_id: int,
    endpoint_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Delete one of the caller's push endpoints.

    Raises:
        HTTPException: 404 when no endpoint with this id belongs to the user.
    """
    _ensure_self_access(user_id, current_user)

    # Filtering on user_id as well keeps one user from deleting another's row.
    owned_by_user = db.query(UserPushEndpoint).filter(
        UserPushEndpoint.id == endpoint_id,
        UserPushEndpoint.user_id == user_id,
    )
    endpoint = owned_by_user.first()
    if not endpoint:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")

    db.delete(endpoint)
    db.commit()
    return None
|
||||
@@ -1,69 +1,215 @@
|
||||
# app/api/endpoints/events.py
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from datetime import timedelta
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import get_db
|
||||
from app.models.models import UnifiedEvent, TrendingEvent, InfoSource, RankingLog, utcnow
|
||||
# 导入你上传的 Schema
|
||||
from app.schemas.event_schema import UnifiedEventResponse, PlatformTrendResponse
|
||||
from app.models.models import (
|
||||
ExtractedTopic,
|
||||
InfoSource,
|
||||
RankingLog,
|
||||
TargetType,
|
||||
TrendingEvent,
|
||||
UnifiedEvent,
|
||||
utcnow,
|
||||
)
|
||||
from app.schemas.event_schema import (
|
||||
PaginatedUnifiedEventResponse,
|
||||
PlatformTrendResponse,
|
||||
UnifiedEventResponse,
|
||||
)
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# Maximum number of points returned per ranking trajectory, so that long time
# spans do not produce oversized responses.
|
||||
MAX_RANKING_POINTS = 30
|
||||
|
||||
@router.get("/unified", response_model=PaginatedUnifiedEventResponse)
def list_unified_events(
    min_hot: int = Query(5, ge=0, description="热度阈值,仅返回 hot_score >= 此值的事件"),
    hours: int = Query(24, ge=1, le=720, description="查询最近多少小时的数据"),
    skip: int = Query(0, ge=0, description="分页偏移量"),
    limit: int = Query(10, ge=1, le=50, description="每页返回条数"),
    db: Session = Depends(get_db),
):
    """Return one page of unified events with their trends, rankings and tags.

    Events are filtered by ``hot_score >= min_hot`` and by creation within
    the last ``hours`` hours, then sorted by ``hot_score`` descending. The
    event id is used as a tiebreaker so that rows with equal scores keep a
    stable order across pages; without it, ``skip``/``limit`` windows over
    equal scores could repeat or drop events between requests.

    Related trending entries, ranking logs and tags are each loaded with one
    batched query for the whole page rather than one query per event.

    Returns:
        ``PaginatedUnifiedEventResponse`` with the total match count, a
        ``has_more`` flag and the page's events.
    """
    time_limit = utcnow() - timedelta(hours=hours)

    # Count first so the client can tell whether more pages exist.
    base_query = db.query(UnifiedEvent).filter(
        UnifiedEvent.hot_score >= min_hot,
        UnifiedEvent.created_at >= time_limit,
    )
    total = base_query.count()

    events = (
        base_query
        .order_by(UnifiedEvent.hot_score.desc(), UnifiedEvent.id.desc())
        .offset(skip)
        .limit(limit)
        .all()
    )
    if not events:
        return PaginatedUnifiedEventResponse(total=total, has_more=False, data=[])

    event_ids = [ev.id for ev in events]

    # All trending entries of the page's events, joined with their source name.
    trend_rows = (
        db.query(TrendingEvent, InfoSource.source_name)
        .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
        .filter(TrendingEvent.unified_event_id.in_(event_ids))
        .all()
    )

    # Group the trending entries by the unified event they belong to.
    trend_map: dict[int, list[tuple]] = {}
    trend_ids: list[int] = []
    for trend, source_name in trend_rows:
        trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
        trend_ids.append(trend.id)

    # Ranking positions per trending entry, oldest observation first.
    ranking_map: dict[int, list[int]] = {}
    if trend_ids:
        ranking_rows = (
            db.query(
                RankingLog.event_id,
                RankingLog.ranking_position,
            )
            .filter(
                RankingLog.event_id.in_(trend_ids),
                RankingLog.observed_at >= time_limit,
            )
            .order_by(RankingLog.event_id, RankingLog.observed_at.asc())
            .all()
        )
        for trend_id, position in ranking_rows:
            ranking_map.setdefault(trend_id, []).append(position)

    # Tag keywords per event, most relevant (then newest) first.
    tag_map: dict[int, list[str]] = {}
    tag_rows = (
        db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
        .filter(
            ExtractedTopic.target_type == TargetType.EVENT,
            ExtractedTopic.target_id.in_(event_ids),
        )
        .order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
        .all()
    )
    for target_id, keyword in tag_rows:
        tag_map.setdefault(target_id, []).append(keyword)

    # Assemble the response objects.
    results: list[UnifiedEventResponse] = []
    for ev in events:
        platform_list: list[PlatformTrendResponse] = []
        for trend, source_name in trend_map.get(ev.id, []):
            # Keep only the most recent MAX_RANKING_POINTS observations.
            history = ranking_map.get(trend.id, [])[-MAX_RANKING_POINTS:]
            platform_list.append(
                PlatformTrendResponse(
                    source_id=trend.source_id,
                    platform_name=source_name,
                    headline=trend.current_headline,
                    url=trend.event_url,
                    current_ranking=trend.current_ranking,
                    ranking_history=history,
                )
            )

        results.append(
            UnifiedEventResponse(
                event_id=ev.id,
                unified_title=ev.unified_title if ev.unified_title else "暂无标题",
                summary=ev.ai_comprehensive_summary,
                hot_score=ev.hot_score,
                created_at=ev.created_at,
                platforms=platform_list,
                tags=tag_map.get(ev.id, []),
            )
        )

    has_more = (skip + limit) < total
    return PaginatedUnifiedEventResponse(total=total, has_more=has_more, data=results)
|
||||
|
||||
|
||||
@router.get("/unified/{event_id}", response_model=UnifiedEventResponse)
def get_unified_event(
    event_id: int,
    db: Session = Depends(get_db),
):
    """Return a single unified event by id, with trends, rankings and tags.

    Ranking history is limited to observations from the last 720 hours and
    to the most recent ``MAX_RANKING_POINTS`` points per trending entry.
    Tags are ordered by relevance and then by creation time, the same
    ordering the list endpoint uses, so both endpoints show an event's tags
    in the same order.

    Raises:
        HTTPException: 404 when no event with ``event_id`` exists.
    """
    # Look-back window for ranking history; equals the upper bound that the
    # list endpoint accepts for its ``hours`` parameter.
    ranking_window_hours = 720

    ev = db.query(UnifiedEvent).filter(UnifiedEvent.id == event_id).first()
    if not ev:
        raise HTTPException(status_code=404, detail="Event not found")

    time_limit = utcnow() - timedelta(hours=ranking_window_hours)

    trend_rows = (
        db.query(TrendingEvent, InfoSource.source_name)
        .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
        .filter(TrendingEvent.unified_event_id == event_id)
        .all()
    )

    # Ranking positions per trending entry, oldest observation first.
    trend_ids = [trend.id for trend, _ in trend_rows]
    ranking_map: dict[int, list[int]] = {}
    if trend_ids:
        ranking_rows = (
            db.query(RankingLog.event_id, RankingLog.ranking_position)
            .filter(
                RankingLog.event_id.in_(trend_ids),
                RankingLog.observed_at >= time_limit,
            )
            .order_by(RankingLog.event_id, RankingLog.observed_at.asc())
            .all()
        )
        for trend_id, position in ranking_rows:
            ranking_map.setdefault(trend_id, []).append(position)

    tag_rows = (
        db.query(ExtractedTopic.topic_keyword)
        .filter(
            ExtractedTopic.target_type == TargetType.EVENT,
            ExtractedTopic.target_id == event_id,
        )
        .order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
        .all()
    )
    tags = [row[0] for row in tag_rows]

    platform_list: list[PlatformTrendResponse] = []
    for trend, source_name in trend_rows:
        # Keep only the most recent MAX_RANKING_POINTS observations.
        history = ranking_map.get(trend.id, [])[-MAX_RANKING_POINTS:]
        platform_list.append(
            PlatformTrendResponse(
                source_id=trend.source_id,
                platform_name=source_name,
                headline=trend.current_headline,
                url=trend.event_url,
                current_ranking=trend.current_ranking,
                ranking_history=history,
            )
        )

    return UnifiedEventResponse(
        event_id=ev.id,
        unified_title=ev.unified_title if ev.unified_title else "暂无标题",
        summary=ev.ai_comprehensive_summary,
        hot_score=ev.hot_score,
        created_at=ev.created_at,
        platforms=platform_list,
        tags=tags,
    )
|
||||
|
||||
@@ -0,0 +1,158 @@
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import get_current_user, get_db
|
||||
from app.models.models import AppUser, UserTopicPreference
|
||||
from app.schemas.preference_schema import (
|
||||
MatchedEventResponse,
|
||||
UserPreferenceRecommendationResponse,
|
||||
UserTopicPreferenceCreate,
|
||||
UserTopicPreferenceResponse,
|
||||
)
|
||||
from app.services.matching_service import recommend_events_for_user
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def _ensure_self_access(path_user_id: int, current_user: AppUser) -> None:
|
||||
"""校验路径 user_id 是否为当前登录用户本人。"""
|
||||
if path_user_id != current_user.id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You can only operate your own resources",
|
||||
)
|
||||
|
||||
|
||||
@router.get(
    "/users/{user_id}/preferences",
    response_model=List[UserTopicPreferenceResponse],
)
def list_user_preferences(
    user_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """List the caller's interest keywords, newest first."""
    _ensure_self_access(user_id, current_user)

    owned = db.query(UserTopicPreference).filter(UserTopicPreference.user_id == user_id)
    newest_first = owned.order_by(UserTopicPreference.created_at.desc())
    return newest_first.all()
|
||||
|
||||
|
||||
@router.post(
    "/users/{user_id}/preferences",
    response_model=UserTopicPreferenceResponse,
    status_code=status.HTTP_201_CREATED,
)
def create_user_preference(
    user_id: int,
    payload: UserTopicPreferenceCreate,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Add one interest keyword for the caller.

    The keyword is stripped of surrounding whitespace before it is stored.

    Raises:
        HTTPException: 400 when the keyword is empty after stripping,
            409 when the commit fails with ``IntegrityError``.
    """
    _ensure_self_access(user_id, current_user)

    cleaned = payload.interested_keyword.strip()
    if cleaned == "":
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Keyword cannot be empty")

    preference = UserTopicPreference(user_id=user_id, interested_keyword=cleaned)
    db.add(preference)
    try:
        db.commit()
    except IntegrityError:
        db.rollback()
        raise HTTPException(
            status_code=status.HTTP_409_CONFLICT,
            detail="Preference keyword already exists for this user",
        )
    else:
        db.refresh(preference)

    return preference
|
||||
|
||||
|
||||
@router.delete(
    "/users/{user_id}/preferences/{preference_id}",
    status_code=status.HTTP_204_NO_CONTENT,
)
def delete_user_preference(
    user_id: int,
    preference_id: int,
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Delete one of the caller's interest keywords.

    Raises:
        HTTPException: 404 when no preference with this id belongs to the user.
    """
    _ensure_self_access(user_id, current_user)

    # Filtering on user_id as well keeps one user from deleting another's row.
    owned_by_user = db.query(UserTopicPreference).filter(
        UserTopicPreference.id == preference_id,
        UserTopicPreference.user_id == user_id,
    )
    target = owned_by_user.first()
    if not target:
        raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Preference not found")

    db.delete(target)
    db.commit()
    return None
|
||||
|
||||
|
||||
@router.get(
    "/users/{user_id}/recommended-events",
    response_model=UserPreferenceRecommendationResponse,
)
def recommend_events(
    user_id: int,
    min_hot: int = Query(3, ge=1, description="最小热度阈值"),
    hours: int = Query(72, ge=1, le=24 * 30, description="仅匹配最近多少小时的事件"),
    limit: int = Query(20, ge=1, le=50, description="最多返回多少条推荐"),
    semantic_threshold: float = Query(0.78, ge=0.0, le=1.0, description="语义匹配相似度阈值"),
    db: Session = Depends(get_db),
    current_user: AppUser = Depends(get_current_user),
):
    """Recommend events for the caller based on their interest keywords.

    Matching is delegated to ``recommend_events_for_user`` (described in the
    original docstring as exact plus semantic matching); this handler only
    forwards the query parameters and reshapes each match into a
    ``MatchedEventResponse``.
    """
    _ensure_self_access(user_id, current_user)

    matches = recommend_events_for_user(
        db,
        user_id=user_id,
        min_hot=min_hot,
        hours=hours,
        limit=limit,
        semantic_threshold=semantic_threshold,
    )

    payload_items: list[MatchedEventResponse] = [
        MatchedEventResponse(
            event_id=match.event.id,
            unified_title=match.event.unified_title,
            summary=match.event.ai_comprehensive_summary,
            hot_score=match.event.hot_score,
            created_at=match.event.created_at,
            tags=match.tags,
            match_score=match.match_score,
            exact_hits=match.exact_hits,
            semantic_hits=match.semantic_hits,
        )
        for match in matches
    ]

    return UserPreferenceRecommendationResponse(
        user_id=user_id,
        total=len(payload_items),
        data=payload_items,
    )
|
||||
@@ -0,0 +1,75 @@
|
||||
# 公关修改追踪 API:查询热搜标题被偷偷修改的历史记录
|
||||
from datetime import timedelta
|
||||
from typing import List, Optional
|
||||
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import get_db
|
||||
from app.models.models import HeadlineRevision, InfoSource, TrendingEvent, utcnow
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
from datetime import datetime
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class HeadlineRevisionResponse(BaseModel):
    """Response body describing one recorded headline revision."""

    # from_attributes lets the model be populated from attribute-bearing
    # objects instead of dicts only.
    model_config = ConfigDict(from_attributes=True)

    id: int
    # Id of the trending entry whose headline changed.
    event_id: int
    previous_headline: str
    revised_headline: str
    # Name of the source the trending entry belongs to, when known.
    source_name: str | None = None
    # Short icon key chosen from the source name by the endpoint.
    platform_icon: str | None = None
    created_at: datetime
|
||||
|
||||
|
||||
@router.get("/headline-revisions", response_model=List[HeadlineRevisionResponse])
def list_headline_revisions(
    hours: int = Query(48, ge=1, le=720, description="查询最近多少小时内的修改记录"),
    limit: int = Query(50, ge=1, le=500, description="最多返回条数"),
    db: Session = Depends(get_db),
):
    """List recent headline revisions, newest first.

    Intended for PR monitoring: it shows which trending headlines were
    changed after first being recorded. Each revision is joined through its
    trending entry to the owning source so the source name can be returned
    alongside it.
    """
    # Source display name -> icon key; unknown sources fall back to "newspaper".
    icons_by_source = {
        "微博热搜": "weibo",
        "知乎热榜": "zhihu",
        "百度热搜": "baidu",
        "今日头条": "toutiao",
        "抖音热榜": "douyin",
        "B站热搜": "bilibili",
    }

    cutoff = utcnow() - timedelta(hours=hours)

    joined = (
        db.query(HeadlineRevision, InfoSource.source_name)
        .join(TrendingEvent, HeadlineRevision.event_id == TrendingEvent.id)
        .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
    )
    recent = (
        joined
        .filter(HeadlineRevision.created_at >= cutoff)
        .order_by(HeadlineRevision.created_at.desc())
        .limit(limit)
        .all()
    )

    return [
        HeadlineRevisionResponse(
            id=revision.id,
            event_id=revision.event_id,
            previous_headline=revision.previous_headline,
            revised_headline=revision.revised_headline,
            source_name=source_name,
            platform_icon=icons_by_source.get(source_name, "newspaper"),
            created_at=revision.created_at,
        )
        for revision, source_name in recent
    ]
|
||||
@@ -0,0 +1,65 @@
|
||||
# 系统状态监控 API:返回爬虫集群运行概况
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Optional
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import func
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.api.dependencies import get_db
|
||||
from app.models.models import DataSyncTask, InfoSource, TaskStatus, utcnow
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
class SystemStatsResponse(BaseModel):
    """Summary of the crawler cluster's running state."""

    active_sources: int
    total_sources: int
    items_today: int
    success_tasks_today: int
    error_tasks_today: int
    last_sync_at: Optional[datetime] = None
|
||||
|
||||
|
||||
@router.get("/system/stats", response_model=SystemStatsResponse)
def get_system_stats(db: Session = Depends(get_db)):
    """Return the crawler cluster's status for the current day.

    "Today" starts at midnight of the timestamp returned by ``utcnow()``.

    Args:
        db: Database session injected by FastAPI.

    Returns:
        A ``SystemStatsResponse`` with source counts, today's fetched-item and
        task counts, and the creation time of the latest successful task
        (``None`` when there is none).
    """
    today_start = utcnow().replace(hour=0, minute=0, second=0, microsecond=0)

    # Source statistics.
    total_sources = db.query(func.count(InfoSource.id)).scalar() or 0
    active_sources = (
        db.query(func.count(InfoSource.id))
        .filter(InfoSource.is_enabled.is_(True))
        .scalar()
        or 0
    )

    # Today's task statistics.
    today_tasks = (
        db.query(DataSyncTask)
        .filter(DataSyncTask.created_at >= today_start)
        .all()
    )

    # A task row without a fetched count contributes 0 instead of breaking
    # the whole sum with a TypeError.
    items_today = sum((t.items_fetched or 0) for t in today_tasks)
    success_count = sum(1 for t in today_tasks if t.task_status == TaskStatus.SUCCESS)
    error_count = sum(1 for t in today_tasks if t.task_status == TaskStatus.ERROR)

    # Most recent successful synchronisation.
    last_task = (
        db.query(DataSyncTask)
        .filter(DataSyncTask.task_status == TaskStatus.SUCCESS)
        .order_by(DataSyncTask.created_at.desc())
        .first()
    )

    return SystemStatsResponse(
        active_sources=active_sources,
        total_sources=total_sources,
        items_today=items_today,
        success_tasks_today=success_count,
        error_tasks_today=error_count,
        last_sync_at=last_task.created_at if last_task else None,
    )
|
||||
@@ -1,6 +1,6 @@
|
||||
# app/api/router.py
|
||||
from fastapi import APIRouter
|
||||
from app.api.endpoints import auth, sources, events
|
||||
from app.api.endpoints import auth, delivery, events, preferences, revisions, sources, stats
|
||||
|
||||
api_router = APIRouter()
|
||||
|
||||
@@ -9,4 +9,18 @@ api_router.include_router(sources.router, prefix="/sources", tags=["信息源管
|
||||
|
||||
# 注册大事件相关的路由
|
||||
api_router.include_router(events.router, prefix="/events", tags=["Unified Events"])
|
||||
|
||||
# 认证
|
||||
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
|
||||
|
||||
# 用户偏好(关键词订阅)
|
||||
api_router.include_router(preferences.router, tags=["User Preferences"])
|
||||
|
||||
# 推送设置(时间表 + 渠道)
|
||||
api_router.include_router(delivery.router, tags=["Delivery Settings"])
|
||||
|
||||
# 公关修改追踪
|
||||
api_router.include_router(revisions.router, prefix="/events", tags=["Headline Revisions"])
|
||||
|
||||
# 系统状态监控
|
||||
api_router.include_router(stats.router, tags=["System Stats"])
|
||||
|
||||
@@ -8,9 +8,15 @@ import time
|
||||
from typing import Tuple
|
||||
|
||||
|
||||
# Defaults are named constants so the fallback values are visible in one place;
# each setting can be overridden through the environment.
DEFAULT_PASSWORD_HASH_ITERATIONS = 120000
PASSWORD_HASH_ITERATIONS = int(
    os.getenv("PASSWORD_HASH_ITERATIONS", str(DEFAULT_PASSWORD_HASH_ITERATIONS))
)
# NOTE(review): the fallback secret is a publicly known placeholder. Tokens
# signed with it can be forged, so AUTH_SECRET_KEY must be set in any real
# deployment.
AUTH_SECRET_KEY = os.getenv("AUTH_SECRET_KEY", "change-this-secret-in-env")
DEFAULT_AUTH_TOKEN_EXPIRE_MINUTES = 10080
AUTH_TOKEN_EXPIRE_MINUTES = int(
    os.getenv("AUTH_TOKEN_EXPIRE_MINUTES", str(DEFAULT_AUTH_TOKEN_EXPIRE_MINUTES))
)
|
||||
|
||||
|
||||
def hash_password(password: str) -> str:
|
||||
@@ -61,6 +67,11 @@ def _urlsafe_b64encode(raw: bytes) -> str:
|
||||
return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=")
|
||||
|
||||
|
||||
def _urlsafe_b64decode(raw: str) -> bytes:
|
||||
padding = "=" * (-len(raw) % 4)
|
||||
return base64.urlsafe_b64decode(raw + padding)
|
||||
|
||||
|
||||
def create_access_token(user_id: int, email: str) -> Tuple[str, int]:
|
||||
expires_in = AUTH_TOKEN_EXPIRE_MINUTES * 60
|
||||
payload = {
|
||||
@@ -77,3 +88,51 @@ def create_access_token(user_id: int, email: str) -> Tuple[str, int]:
|
||||
).digest()
|
||||
token = f"{encoded_payload}.{_urlsafe_b64encode(signature)}"
|
||||
return token, expires_in
|
||||
|
||||
|
||||
def decode_access_token(token: str) -> Tuple[int, str]:
    """Decode and validate an access token, returning ``(user_id, email)``.

    The token is ``<encoded payload>.<encoded signature>``; the signature is
    an HMAC-SHA256 of the encoded payload keyed with ``AUTH_SECRET_KEY``.
    Validation covers structure, signature, required fields, field types and
    expiry.

    Args:
        token: The raw token string.

    Returns:
        A ``(user_id, email)`` tuple taken from the ``sub`` and ``email``
        payload fields.

    Raises:
        ValueError: On any validation failure. Callers rely on this being the
            only exception type, so every failure path is mapped to it.
    """
    try:
        encoded_payload, encoded_signature = token.split(".", 1)
    except ValueError as exc:
        raise ValueError("Invalid token format") from exc

    try:
        provided_signature = _urlsafe_b64decode(encoded_signature)
    except Exception as exc:
        raise ValueError("Invalid token signature encoding") from exc

    expected_signature = hmac.new(
        AUTH_SECRET_KEY.encode("utf-8"),
        encoded_payload.encode("utf-8"),
        hashlib.sha256,
    ).digest()
    # Constant-time comparison, done before the payload is parsed.
    if not hmac.compare_digest(provided_signature, expected_signature):
        raise ValueError("Invalid token signature")

    try:
        payload_bytes = _urlsafe_b64decode(encoded_payload)
        payload = json.loads(payload_bytes.decode("utf-8"))
    except Exception as exc:
        raise ValueError("Invalid token payload") from exc

    # Valid JSON is not necessarily an object; a list or scalar has no
    # ``.get`` and would otherwise escape as AttributeError.
    if not isinstance(payload, dict):
        raise ValueError("Invalid token payload")

    sub = payload.get("sub")
    email = payload.get("email")
    exp = payload.get("exp")

    if not sub or not email or exp is None:
        raise ValueError("Token payload missing required fields")

    try:
        user_id = int(sub)
        exp_ts = int(exp)
    except (TypeError, ValueError, OverflowError) as exc:
        # OverflowError covers an infinite float, which int() cannot convert.
        raise ValueError("Invalid token payload types") from exc

    if time.time() >= exp_ts:
        raise ValueError("Token expired")

    return user_id, str(email)
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
import json
import os

import requests

# Endpoint that creates an information source. Set SOURCES_API_URL to point
# the script at another deployment; the default is the original address.
api_url = os.getenv("SOURCES_API_URL", "http://10.252.130.135:8000/api/v1/sources/")

# Request headers.
headers = {
    "Content-Type": "application/json",
    # "Authorization": "Bearer YOUR_TOKEN"  # uncomment and fill in if the API requires auth
}

# Sources to register: display name plus the identifier sent as ``home_url``.
sources_data = [
    {"name": "今日头条", "url": "toutiao"},
    {"name": "百度热搜", "url": "baidu"},
    {"name": "华尔街见闻", "url": "wallstreetcn-hot"},
    {"name": "澎湃新闻", "url": "thepaper"},
    {"name": "bilibili 热搜", "url": "bilibili-hot-search"},
    {"name": "财联社热门", "url": "cls-hot"},
    {"name": "凤凰网", "url": "ifeng"},
    {"name": "贴吧", "url": "tieba"},
    {"name": "微博", "url": "weibo"},
    {"name": "抖音", "url": "douyin"},
    {"name": "知乎", "url": "zhihu"},
]

# Upper bound per request so an unreachable server cannot hang the script.
REQUEST_TIMEOUT_SECONDS = 10


def main() -> None:
    """POST every entry of ``sources_data`` to ``api_url`` and print the outcome."""
    for item in sources_data:
        payload = {
            "source_name": item["name"],
            "source_type": "HOT_TREND",
            "home_url": item["url"],
            "is_enabled": True,
        }

        try:
            response = requests.post(
                api_url,
                headers=headers,
                data=json.dumps(payload),
                timeout=REQUEST_TIMEOUT_SECONDS,
            )
        except requests.RequestException as e:
            # A failed request must not stop the remaining sources.
            print(f"⚠️ 请求异常: {item['name']} - 错误: {e}")
            continue

        if response.status_code in (200, 201):
            print(f"✅ 成功创建: {item['name']}")
        else:
            print(f"❌ 创建失败: {item['name']} - 状态码: {response.status_code} - 详情: {response.text}")

    print("执行完毕!")


if __name__ == "__main__":
    main()
|
||||
+37
-3
@@ -1,12 +1,24 @@
|
||||
# app/main.py
|
||||
import logging
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from dotenv import load_dotenv
|
||||
|
||||
# 统一配置日志格式和级别,确保 delivery_service 等的 INFO 日志可见
|
||||
logging.basicConfig(
|
||||
level=logging.INFO,
|
||||
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
# 降低 APScheduler 运行心跳日志,避免每分钟刷屏
|
||||
logging.getLogger("apscheduler").setLevel(logging.WARNING)
|
||||
|
||||
from apscheduler.schedulers.asyncio import AsyncIOScheduler
|
||||
from app.services.fetcher_service import fetch_and_save_trending_data
|
||||
from app.services.summary_service import generate_unified_summaries
|
||||
from app.services.delivery_service import check_and_deliver
|
||||
from app.database import engine
|
||||
from app.models.models import Base
|
||||
|
||||
@@ -47,14 +59,24 @@ async def lifespan(app: FastAPI):
|
||||
id='ai_summary_job',
|
||||
replace_existing=True
|
||||
)
|
||||
# 推送调度:每分钟检查是否有用户需要接收邮件推送
|
||||
scheduler.add_job(
|
||||
check_and_deliver,
|
||||
'interval',
|
||||
minutes=1,
|
||||
id='delivery_check_job',
|
||||
replace_existing=True,
|
||||
)
|
||||
|
||||
scheduler.start()
|
||||
print(f"定时抓取任务已启动,每 {CRAWL_INTERVAL} 分钟执行一次")
|
||||
print(f"AI 摘要生成任务已启动,每 {SUMMARY_INTERVAL} 分钟执行一次")
|
||||
print("邮件推送调度已启动,每分钟检查一次")
|
||||
|
||||
# 为了测试方便,启动时立即执行一次
|
||||
await fetch_and_save_trending_data()
|
||||
# await fetch_and_save_trending_data()
|
||||
|
||||
await generate_unified_summaries()
|
||||
# await generate_unified_summaries()
|
||||
|
||||
yield # 此时 FastAPI 开始接受请求
|
||||
|
||||
@@ -67,7 +89,19 @@ async def lifespan(app: FastAPI):
|
||||
app = FastAPI(title="AI 新闻聚合引擎 API", lifespan=lifespan)
|
||||
|
||||
# ==========================================
|
||||
# 2. 挂载路由总线
|
||||
# 2. CORS 中间件:允许前端开发服务器跨域请求
|
||||
# ==========================================
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
# allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],
|
||||
allow_origins=["*"],
|
||||
# allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
# ==========================================
|
||||
# 3. 挂载路由总线
|
||||
# ==========================================
|
||||
# 版本控制
|
||||
app.include_router(api_router, prefix="/api/v1")
|
||||
|
||||
@@ -0,0 +1,264 @@
|
||||
# 推送邮件 HTML 模板
|
||||
# 用于生成定时推送给用户的热点摘要邮件
|
||||
|
||||
# Mail clients do not support Font Awesome, so platform icons are emoji.
# Several platforms are listed under more than one name.
PLATFORM_EMOJI: dict[str, str] = {
    "微博热搜": "🔴",
    "微博": "🔴",
    "知乎热榜": "🔵",
    "知乎": "🔵",
    "百度热搜": "🔍",
    "今日头条": "📰",
    "抖音热榜": "🎵",
    "抖音": "🎵",
    "bilibili 热搜": "📺",
    "B站热搜": "📺",
    "华尔街见闻": "📈",
    "澎湃新闻": "🌊",
    "财联社热门": "💰",
    "凤凰网": "🦅",
    "贴吧": "💬",
}
|
||||
|
||||
# Outer HTML document of the digest e-mail, filled in with ``str.format``.
# Placeholders: delivery_time, event_count, mode_badge_class, mode_label,
# event_cards_html, app_url. Literal CSS braces are doubled for ``format``.
DIGEST_HTML_TEMPLATE = """\
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>

body{{margin:0;padding:0;background:#0d1117;color:#e6edf3;font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;}}
.container{{max-width:640px;margin:0 auto;padding:32px 16px;}}


.header{{text-align:center;padding:10px 0 30px;margin-bottom:24px;border-bottom:1px solid rgba(255,255,255,0.06);}}
.header h1{{font-size:26px;font-weight:800;margin:0 0 10px;color:#ffffff;text-shadow:0 0 16px rgba(139,92,246,0.5);letter-spacing:0.5px;}}
.header p{{font-size:14px;color:#8b949e;margin:0;}}


.mode-badge{{display:inline-block;margin-top:12px;padding:4px 14px;border-radius:20px;font-size:12px;font-weight:600;letter-spacing:0.5px;}}
.mode-default{{background:rgba(59,130,246,0.15);color:#7dd3fc;border:1px solid rgba(59,130,246,0.3);}}
.mode-keyword{{background:rgba(168,85,247,0.15);color:#e879f9;border:1px solid rgba(168,85,247,0.3);}}


.event-card{{background:#161b22;border:1px solid #30363d;border-radius:16px;padding:20px;margin-bottom:20px;box-shadow:0 4px 12px rgba(0,0,0,0.2);}}
.event-card.is-hot{{border-left:4px solid #f85149;background:linear-gradient(90deg, rgba(248,81,73,0.03) 0%, transparent 100%), #161b22;}}


.event-title{{font-size:18px;font-weight:700;margin:0 0 14px;color:#ffffff;line-height:1.5;}}


.event-meta{{margin-bottom:12px;}}
.badge{{display:inline-block;padding:3px 10px;border-radius:6px;font-size:12px;font-weight:600;margin-right:6px;margin-bottom:6px;}}
.badge-hot{{background:rgba(248,81,73,0.15);color:#ff7b72;border:1px solid rgba(248,81,73,0.3);}}
.badge-warm{{background:rgba(210,153,34,0.15);color:#d29922;border:1px solid rgba(210,153,34,0.3);}}
.badge-normal{{background:rgba(56,139,253,0.15);color:#58a6ff;border:1px solid rgba(56,139,253,0.3);}}
.badge-tag{{background:rgba(139,148,158,0.15);color:#8b949e;border:1px solid rgba(139,148,158,0.2);}}


.summary{{font-size:14px;line-height:1.6;color:#c9d1d9;padding:12px 16px;background:rgba(139,92,246,0.06);border-radius:0 8px 8px 0;border-left:3px solid #a78bfa;margin-bottom:16px;}}
.summary strong{{color:#a78bfa;font-weight:600;}}


.platforms-list{{margin:0;padding:0;list-style:none;background:rgba(255,255,255,0.02);border-radius:10px;padding:12px;}}
.platform-item{{padding:8px 0;border-bottom:1px solid rgba(255,255,255,0.05);}}
.platform-item:last-child{{border-bottom:none;padding-bottom:0;}}
.platform-item:first-child{{padding-top:0;}}
.platform-source{{font-size:12px;color:#8b949e;margin-bottom:4px;display:flex;align-items:center;}}
.platform-rank{{display:inline-block;padding:2px 6px;border-radius:4px;background:rgba(210,153,34,0.15);color:#d29922;font-size:10px;font-weight:700;margin-left:6px;}}
.platform-link{{font-size:14px;color:#79c0ff;text-decoration:none;line-height:1.5;display:block;transition:color 0.2s;}}
.platform-link:hover{{text-decoration:underline;color:#a5d6ff;}}
.platform-text{{font-size:14px;color:#e6edf3;line-height:1.5;}}

/* 匹配信息底部栏 */
.match-info{{font-size:12px;color:#8b949e;margin-top:16px;padding-top:12px;border-top:1px dashed #30363d;}}
.hit{{display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;margin-right:4px;margin-top:4px;}}
.hit-exact{{background:rgba(46,160,67,0.15);color:#3fb950;}}
.hit-semantic{{background:rgba(163,113,247,0.15);color:#d2a8ff;}}

/* 页脚 */
.footer{{text-align:center;padding:30px 0 10px;margin-top:20px;font-size:12px;color:#484f58;}}
.footer a{{color:#79c0ff;text-decoration:none;}}
.footer a:hover{{text-decoration:underline;}}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>InsightRadar · 热点快报</h1>
<p>{delivery_time} · 为你精选了 {event_count} 条事件</p>
<span class="mode-badge {mode_badge_class}">{mode_label}</span>
</div>

{event_cards_html}

<div class="footer">
<p>此邮件由 InsightRadar 自动推送。</p>
<p>如需调整推送设置,请登录 <a href="{app_url}">InsightRadar 控制台</a></p>
</div>
</div>
</body>
</html>
"""
|
||||
|
||||
# HTML for one event card, filled in with ``str.format``. Placeholders:
# hot_class, badge_class, hot_label, hot_score, tags_html, title,
# summary_html, platforms_html, match_html.
EVENT_CARD_TEMPLATE = """\
<div class="event-card{hot_class}">
<div class="event-meta">
<span class="badge {badge_class}">{hot_label} {hot_score}</span>
{tags_html}
</div>
<div class="event-title">{title}</div>
{summary_html}
{platforms_html}
{match_html}
</div>
"""
|
||||
|
||||
|
||||
def _hot_level(score: int) -> tuple[str, str, str]:
|
||||
"""返回 (label, badge_class, hot_class)"""
|
||||
if score >= 50:
|
||||
return "全网沸腾", "badge-hot", " is-hot"
|
||||
if score >= 20:
|
||||
return "高度关注", "badge-warm", ""
|
||||
if score >= 10:
|
||||
return "上升中", "badge-normal", ""
|
||||
return "一般关注", "badge-tag", ""
|
||||
|
||||
|
||||
def _get_event_summary(ev) -> str:
|
||||
"""
|
||||
兼容 ORM 字段名(ai_comprehensive_summary)和 schema 字段名(summary)。
|
||||
"""
|
||||
return (
|
||||
getattr(ev, "summary", None)
|
||||
or getattr(ev, "ai_comprehensive_summary", None)
|
||||
or ""
|
||||
)
|
||||
|
||||
|
||||
def _build_platforms_html(platform_list: list[dict]) -> str:
|
||||
"""
|
||||
将平台数据列表渲染为 HTML。
|
||||
每条包含:emoji 图标 + 来源名 + 排名徽章 + 可点击标题链接。
|
||||
"""
|
||||
if not platform_list:
|
||||
return ""
|
||||
|
||||
rows = []
|
||||
seen_sources: set[str] = set()
|
||||
for p in platform_list[:8]:
|
||||
source_name = p.get("source_name", "未知")
|
||||
# 同一来源只显示第一条(通常是排名最靠前的那条)
|
||||
if source_name in seen_sources:
|
||||
continue
|
||||
seen_sources.add(source_name)
|
||||
|
||||
headline = p.get("headline", "")
|
||||
url = p.get("url", "")
|
||||
ranking = p.get("ranking")
|
||||
emoji = PLATFORM_EMOJI.get(source_name, "🔗")
|
||||
|
||||
rank_html = ""
|
||||
if ranking:
|
||||
rank_html = f'<span class="platform-rank">TOP {ranking}</span>'
|
||||
|
||||
if url:
|
||||
title_html = (
|
||||
f'<a href="{url}" class="platform-link">{headline}</a>'
|
||||
)
|
||||
else:
|
||||
title_html = f'<span class="platform-text">{headline}</span>'
|
||||
|
||||
rows.append(
|
||||
f'<li class="platform-item">'
|
||||
f'<div class="platform-source">{emoji} {source_name}{rank_html}</div>'
|
||||
f'{title_html}'
|
||||
f'</li>'
|
||||
)
|
||||
|
||||
if not rows:
|
||||
return ""
|
||||
return '<div class="platforms-list-wrapper"><ul class="platforms-list">' + "".join(rows) + "</ul></div>"
|
||||
|
||||
|
||||
def build_digest_html(
    items: list,
    delivery_time_str: str,
    platforms_map: dict[int, list[dict]] | None = None,
    app_url: str = "http://localhost:5173",
    is_default_push: bool = False,
) -> str:
    """Build the HTML body of a digest e-mail from a list of events.

    Args:
        items: Objects exposing ``.event``, ``.tags``, ``.exact_hits``,
            ``.semantic_hits`` and ``.match_score`` (``MatchedEventResult`` or
            ``_DefaultEventItem``).
        delivery_time_str: Delivery time text shown in the header.
        platforms_map: ``event_id`` -> list of
            ``{source_name, headline, url, ranking}`` dicts.
        app_url: Link target for the console link in the footer.
        is_default_push: True for the general hot-topic digest, False for the
            keyword-personalised digest; selects the header badge.

    Returns:
        The complete HTML document as a string.
    """
    # Titles, summaries, tags and keywords originate from crawled or generated
    # text; they are escaped so they cannot inject markup into the e-mail.
    from html import escape

    if platforms_map is None:
        platforms_map = {}

    mode_label = "全网热点推送" if is_default_push else "个性化关键词匹配"
    mode_badge_class = "mode-default" if is_default_push else "mode-keyword"

    cards = []
    for item in items:
        ev = item.event
        hot_label, badge_class, hot_class = _hot_level(ev.hot_score)

        tags_html = "".join(
            f'<span class="badge badge-tag">{escape(str(t))}</span>'
            for t in item.tags[:4]
        )

        summary_text = _get_event_summary(ev)
        summary_html = ""
        if summary_text:
            summary_html = (
                f'<div class="summary"><strong>AI 洞察:</strong>{escape(str(summary_text))}</div>'
            )

        platform_list = platforms_map.get(ev.id, [])
        platforms_html = _build_platforms_html(platform_list)

        match_parts = []
        # Match details are shown only for personalised items.
        if not getattr(item, "is_default", False):
            for h in item.exact_hits[:3]:
                match_parts.append(f'<span class="hit hit-exact">精确 {escape(str(h))}</span>')
            for s in item.semantic_hits[:2]:
                sim_pct = int(s.get("similarity", 0) * 100)
                topic_keyword = escape(str(s.get("topic_keyword", "")))
                match_parts.append(
                    f'<span class="hit hit-semantic">语义 {topic_keyword} {sim_pct}%</span>'
                )
        match_html = ""
        if match_parts:
            match_html = (
                f'<div class="match-info">匹配度 {item.match_score:.0f} · '
                + " ".join(match_parts)
                + "</div>"
            )

        cards.append(
            EVENT_CARD_TEMPLATE.format(
                hot_class=hot_class,
                badge_class=badge_class,
                hot_label=hot_label,
                hot_score=ev.hot_score,
                tags_html=tags_html,
                title=escape(str(ev.unified_title)),
                summary_html=summary_html,
                platforms_html=platforms_html,
                match_html=match_html,
            )
        )

    return DIGEST_HTML_TEMPLATE.format(
        delivery_time=delivery_time_str,
        event_count=len(items),
        event_cards_html="\n".join(cards),
        app_url=escape(str(app_url)),
        mode_label=mode_label,
        mode_badge_class=mode_badge_class,
    )
|
||||
@@ -1,14 +1,27 @@
|
||||
# System prompt: constrains the model to strict JSON output.
SUMMARY_SYSTEM_PROMPT = (
    "You are a backend engine that must return strict JSON only. "
    "Do not include markdown, explanation, or extra keys."
)

# User prompt, filled in with ``str.format(platform_data_text=...)``.
# The doubled braces are literal braces in the rendered prompt.
SUMMARY_USER_PROMPT_TEMPLATE = """
You are a professional cross-platform news editor.
Based on the following headlines about the same event from different platforms,
return:
1) a neutral unified title
2) a cross-platform comprehensive summary
3) topic tags

Rules:
1. Return strict JSON with exactly these keys:
- "unified_title": string
- "ai_comprehensive_summary": string
- "topic_keywords": array of 3 to 8 objects
2. Each item in "topic_keywords" must be:
{{"keyword": string, "relevance_score": number}}
3. relevance_score must be in [0, 100].
4. keyword should be concise (max 12 chars preferred).
5. The language should follow the dominant language in the input.

Cross-platform headline data:
{platform_data_text}
"""
|
||||
|
||||
@@ -0,0 +1,72 @@
|
||||
# 推送设置相关的请求/响应模型
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 推送时间表 (UserDeliverySchedule)
|
||||
# ==========================================
|
||||
class DeliveryScheduleCreate(BaseModel):
    """Request body for adding a delivery time, formatted as HH:MM."""

    # The pattern accepts only real clock times (00:00-23:59); a plain
    # two-digit pattern would also accept values such as "99:99".
    delivery_time: str = Field(
        ...,
        pattern=r"^(?:[01]\d|2[0-3]):[0-5]\d$",
        description="每天推送的时间,格式 HH:MM",
    )
    is_active: bool = Field(default=True, description="是否启用此时段")
|
||||
|
||||
|
||||
class DeliveryScheduleUpdate(BaseModel):
    """Request body for updating a delivery time; every field is optional."""

    # The pattern accepts only real clock times (00:00-23:59); a plain
    # two-digit pattern would also accept values such as "99:99".
    delivery_time: Optional[str] = Field(None, pattern=r"^(?:[01]\d|2[0-3]):[0-5]\d$")
    is_active: Optional[bool] = None
|
||||
|
||||
|
||||
class DeliveryScheduleResponse(BaseModel):
    """Response body for a delivery time entry."""

    id: int
    user_id: int
    delivery_time: str
    is_active: bool
    created_at: datetime

    # Lets the model be populated from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 推送渠道端点 (UserPushEndpoint)
|
||||
# ==========================================
|
||||
class PushEndpointCreate(BaseModel):
    """Request body for adding a push channel."""

    channel_type: str = Field(..., max_length=50, description="渠道类型,如 EMAIL / WECHAT_BOT / TELEGRAM")
    channel_account: str = Field(..., max_length=255, description="具体接收账号(邮箱地址/Webhook等)")
    is_active: bool = Field(default=True, description="是否启用")
    priority_level: int = Field(default=1, ge=1, le=10, description="优先级,1最高")
|
||||
|
||||
|
||||
class PushEndpointUpdate(BaseModel):
    """Request body for updating a push channel; every field is optional."""

    channel_account: Optional[str] = Field(None, max_length=255)
    is_active: Optional[bool] = None
    priority_level: Optional[int] = Field(None, ge=1, le=10)
|
||||
|
||||
|
||||
class PushEndpointResponse(BaseModel):
    """Response body for a push channel."""

    id: int
    user_id: int
    channel_type: str
    channel_account: str
    is_active: bool
    priority_level: int
    created_at: datetime
    updated_at: datetime

    # Lets the model be populated from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 推送设置聚合响应(一次性返回全部推送配置)
|
||||
# ==========================================
|
||||
class UserDeliveryConfigResponse(BaseModel):
    """A user's complete delivery configuration: schedules plus channels."""

    schedules: List[DeliveryScheduleResponse] = Field(default_factory=list)
    endpoints: List[PushEndpointResponse] = Field(default_factory=list)
|
||||
@@ -1,23 +1,30 @@
|
||||
# app/schemas/event_schema.py
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
from typing import List, Optional
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class PlatformTrendResponse(BaseModel):
    """One platform's trending entry attached to a unified event."""

    source_id: int
    platform_name: str  # platform name, e.g. "微博热搜"
    headline: str  # the platform's own trending headline
    url: Optional[str]  # link to the entry
    current_ranking: Optional[int]  # current rank on the platform
    ranking_history: List[int]  # rank trajectory, e.g. [50, 45, 20, 5, 1], for charting
|
||||
|
||||
|
||||
class UnifiedEventResponse(BaseModel):
    """A unified event together with its per-platform trending entries."""

    event_id: int
    unified_title: str  # AI-generated unified title
    summary: Optional[str]  # AI-generated summary
    hot_score: int  # overall hot score
    created_at: datetime  # when the event was first discovered
    platforms: List[PlatformTrendResponse]  # per-platform entries
    tags: List[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class PaginatedUnifiedEventResponse(BaseModel):
    """Pagination wrapper, so the full data set is not returned in one response."""

    total: int
    has_more: bool
    data: List[UnifiedEventResponse]
|
||||
|
||||
@@ -0,0 +1,46 @@
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class UserTopicPreferenceCreate(BaseModel):
    """Request body for adding a keyword the user is interested in."""

    interested_keyword: str = Field(..., min_length=1, max_length=100, description="用户感兴趣的关键词")
|
||||
|
||||
|
||||
class UserTopicPreferenceResponse(BaseModel):
    """Response body for a user's interest keyword."""

    id: int
    user_id: int
    interested_keyword: str
    created_at: datetime

    # Lets the model be populated from attribute-bearing objects.
    model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
|
||||
class EventMatchSemanticHit(BaseModel):
    """Details of one semantic match between a preference and a topic keyword."""

    preference_keyword: str
    topic_keyword: str
    similarity: float
|
||||
|
||||
|
||||
class MatchedEventResponse(BaseModel):
    """Response body for one recommended event."""

    event_id: int
    unified_title: str
    summary: Optional[str]
    hot_score: int
    created_at: datetime
    tags: List[str] = Field(default_factory=list)
    match_score: float
    exact_hits: List[str] = Field(default_factory=list)
    semantic_hits: List[EventMatchSemanticHit] = Field(default_factory=list)
|
||||
|
||||
|
||||
class UserPreferenceRecommendationResponse(BaseModel):
    """Recommendation result for a user's interest keywords."""

    user_id: int
    total: int
    data: List[MatchedEventResponse] = Field(default_factory=list)
|
||||
@@ -0,0 +1,454 @@
|
||||
# 定时推送调度服务
|
||||
# 由 APScheduler 每分钟调用,检查当前时刻是否有用户需要接收推送,
|
||||
# 如匹配则生成摘要邮件并发送,同时写入 DeliveryHistory 防重复。
|
||||
import logging
|
||||
from logging.handlers import TimedRotatingFileHandler
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import datetime, time as dt_time, timedelta, timezone, tzinfo
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.database import SessionLocal
|
||||
from app.models.models import (
|
||||
AppUser,
|
||||
DeliveryHistory,
|
||||
ExtractedTopic,
|
||||
InfoSource,
|
||||
TargetType,
|
||||
TaskStatus,
|
||||
TrendingEvent,
|
||||
UnifiedEvent,
|
||||
UserDeliverySchedule,
|
||||
UserPushEndpoint,
|
||||
UserTopicPreference,
|
||||
utcnow,
|
||||
)
|
||||
from app.prompts.digest_email_template import build_digest_html
|
||||
from app.services.matching_service import recommend_events_for_user
|
||||
from app.utils.email_utils import send_html_email
|
||||
|
||||
logger = logging.getLogger("delivery_service")

# delivery_service logs are written to their own file instead of the console.
_delivery_log_dir = Path(__file__).resolve().parents[2] / "logs"
_delivery_log_dir.mkdir(parents=True, exist_ok=True)
_delivery_log_file = _delivery_log_dir / "delivery_check.log"
# Attach the file handler only once, so a re-import does not duplicate output.
if not logger.handlers:
    _file_handler = TimedRotatingFileHandler(
        filename=str(_delivery_log_file),
        when="midnight",
        interval=1,
        backupCount=14,
        encoding="utf-8",
    )
    _file_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
    logger.addHandler(_file_handler)
logger.setLevel(logging.INFO)
# Records are not passed on to ancestor loggers' handlers.
logger.propagate = False
|
||||
|
||||
# Delivery window: maximum tolerance, in minutes, between the actual check
# time and the scheduled time.
DELIVERY_WINDOW_MINUTES = 2
# Minimum interval, in minutes, between two pushes to the same user.
MIN_PUSH_INTERVAL_MINUTES = 30
# Maximum number of events carried by a single push.
MAX_EVENTS_PER_PUSH = 12
# Hot-score threshold for default mode (no keywords, or no keyword matches).
DEFAULT_MODE_HOT_THRESHOLD = 3
# Look-back window, in hours, for default mode.
DEFAULT_MODE_HOURS = 48
# Fallback time zone when the user's time zone is invalid.
DEFAULT_FALLBACK_TIMEZONE = "Asia/Shanghai"
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 默认热点事件容器(无关键词时使用)
|
||||
# ==========================================
|
||||
@dataclass
class _DefaultEventItem:
    """Wrapper for a default hot-topic event.

    Used when the user has no keyword subscriptions or none of them matched.
    It exposes the same attributes as ``MatchedEventResult`` so both kinds of
    item can be passed to the e-mail template.
    """

    event: UnifiedEvent
    match_score: float = 0.0
    exact_hits: list[str] = field(default_factory=list)
    semantic_hits: list[dict[str, Any]] = field(default_factory=list)
    tags: list[str] = field(default_factory=list)
    # Marks the item as a default (non-personalised) one.
    is_default: bool = True
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 时区工具
|
||||
# ==========================================
|
||||
def _time_to_minutes(t: dt_time) -> int:
|
||||
return t.hour * 60 + t.minute
|
||||
|
||||
|
||||
def _is_within_window(schedule_time: dt_time, current_time: dt_time, window: int = DELIVERY_WINDOW_MINUTES) -> bool:
    """Return True if ``schedule_time`` is within ``window`` minutes of ``current_time``.

    The comparison wraps around midnight, so 23:59 and 00:01 are two minutes
    apart.

    Args:
        schedule_time: The configured delivery time.
        current_time: The time being checked.
        window: Tolerance in minutes, applied in both directions.
    """
    s = _time_to_minutes(schedule_time)
    c = _time_to_minutes(current_time)
    diff = abs(s - c)
    # 1440 minutes per day: take the shorter way around the clock.
    return min(diff, 1440 - diff) <= window
|
||||
|
||||
|
||||
def _resolve_user_timezone(user_timezone: str | None) -> tzinfo:
    """Resolve a user's time zone name, falling back to defaults on failure.

    Resolution order: the given name, then ``DEFAULT_FALLBACK_TIMEZONE``,
    then UTC when the system has no time zone database.

    Args:
        user_timezone: A time zone key; ``None`` or blank selects the
            fallback time zone.

    Returns:
        A ``tzinfo`` instance; this function does not raise for bad names.
    """
    tz_name = (user_timezone or "").strip() or DEFAULT_FALLBACK_TIMEZONE
    try:
        return ZoneInfo(tz_name)
    except (ZoneInfoNotFoundError, ValueError, OSError):
        # ZoneInfo raises ValueError for malformed keys (absolute paths, keys
        # escaping the database root). OSError is caught defensively because
        # some Python versions surface it for keys that name a directory.
        # One bad stored value must not abort the check for every user.
        logger.warning(
            "用户时区无效,已回退默认时区。timezone=%s fallback=%s",
            tz_name, DEFAULT_FALLBACK_TIMEZONE,
        )
        try:
            return ZoneInfo(DEFAULT_FALLBACK_TIMEZONE)
        except ZoneInfoNotFoundError:
            logger.warning("系统缺少时区数据库,最终回退为 UTC。建议安装 tzdata 包。")
            return timezone.utc
|
||||
|
||||
|
||||
def _user_local_time(now_utc: datetime, user_timezone: str | None) -> dt_time:
    """Convert the current UTC instant to the user's local wall-clock time.

    Args:
        now_utc: The current time in UTC. A naive value is treated as UTC
            rather than as system-local time.
        user_timezone: The user's time zone name; invalid or missing values
            use the fallback time zone.

    Returns:
        The local time truncated to the minute (seconds and microseconds 0).
    """
    # astimezone() on a naive datetime assumes system-local time; pin naive
    # input to UTC first so the result does not depend on the server's zone.
    local_dt = _ensure_aware(now_utc).astimezone(_resolve_user_timezone(user_timezone))
    return local_dt.time().replace(second=0, microsecond=0)
|
||||
|
||||
|
||||
def _ensure_aware(dt: datetime) -> datetime:
|
||||
if dt.tzinfo is None:
|
||||
return dt.replace(tzinfo=timezone.utc)
|
||||
return dt
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 数据库查询辅助
|
||||
# ==========================================
|
||||
def _should_skip_by_interval(db: Session, user_id: int) -> bool:
    """Return True while the user is still inside the push cool-down period.

    The cool-down is measured from the user's most recent successful delivery
    record and lasts ``MIN_PUSH_INTERVAL_MINUTES`` minutes.

    Args:
        db: Database session.
        user_id: The user to check.

    Returns:
        False when the user has no successful delivery yet or the cool-down
        has elapsed; True otherwise.
    """
    row = (
        db.query(DeliveryHistory.created_at)
        .filter(
            DeliveryHistory.user_id == user_id,
            DeliveryHistory.status == TaskStatus.SUCCESS,
        )
        .order_by(DeliveryHistory.created_at.desc())
        .first()
    )
    if row is None:
        return False
    # Stored timestamps may be naive; make them aware before subtracting.
    last_time = _ensure_aware(row[0])
    elapsed = (utcnow() - last_time).total_seconds() / 60.0
    return elapsed < MIN_PUSH_INTERVAL_MINUTES
|
||||
|
||||
|
||||
def _get_user_email_endpoints(db: Session, user_id: int) -> list[UserPushEndpoint]:
    """Fetch the user's active EMAIL endpoints, ordered by ascending priority_level."""
    active_email = (
        db.query(UserPushEndpoint)
        .filter(UserPushEndpoint.user_id == user_id)
        .filter(UserPushEndpoint.channel_type == "EMAIL")
        # SQL column comparison, not a Python truthiness test.
        .filter(UserPushEndpoint.is_active == True)  # noqa: E712
    )
    return active_email.order_by(UserPushEndpoint.priority_level.asc()).all()
|
||||
|
||||
|
||||
def _get_already_pushed_event_ids(db: Session, user_id: int) -> set[int]:
    """Collect the EVENT target ids already delivered successfully to the user.

    Used by callers to avoid sending the same event twice.
    """
    delivered = (
        db.query(DeliveryHistory.target_id)
        .filter(DeliveryHistory.user_id == user_id)
        .filter(DeliveryHistory.target_type == TargetType.EVENT)
        .filter(DeliveryHistory.status == TaskStatus.SUCCESS)
        .all()
    )
    pushed: set[int] = set()
    for record in delivered:
        pushed.add(record[0])
    return pushed
|
||||
|
||||
|
||||
def _load_event_platforms(db: Session, event_ids: list[int]) -> dict[int, list[dict]]:
    """Bulk-load the per-platform source entries of the given unified events.

    Returns a mapping ``event_id -> [entry, ...]`` where each entry is a dict
    with the keys ``source_name``, ``headline``, ``url``, ``ranking`` and
    ``icon_url``. Rows are fetched ordered by event id, then by
    ``current_ranking`` ascending with NULL rankings last, so each list keeps
    that order. An empty ``event_ids`` yields an empty dict without querying.
    """
    if not event_ids:
        return {}

    columns = (
        TrendingEvent.unified_event_id,
        TrendingEvent.current_headline,
        TrendingEvent.event_url,
        TrendingEvent.current_ranking,
        TrendingEvent.icon_url,
        InfoSource.source_name,
    )
    query = (
        db.query(*columns)
        .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
        .filter(TrendingEvent.unified_event_id.in_(event_ids))
        .order_by(
            TrendingEvent.unified_event_id,
            TrendingEvent.current_ranking.asc().nulls_last(),
        )
    )

    platforms_by_event: dict[int, list[dict]] = {}
    for event_id, headline, url, ranking, icon_url, source_name in query.all():
        # Missing text columns fall back to placeholders; ranking stays as-is.
        entry = {
            "source_name": source_name or "未知",
            "headline": headline or "",
            "url": url or "",
            "ranking": ranking,
            "icon_url": icon_url or "",
        }
        if event_id not in platforms_by_event:
            platforms_by_event[event_id] = []
        platforms_by_event[event_id].append(entry)
    return platforms_by_event
|
||||
|
||||
|
||||
def _load_event_tags(db: Session, event_ids: list[int]) -> dict[int, list[str]]:
    """Bulk-load EVENT topic keywords, returning ``event_id -> [keyword, ...]``.

    Falsy keywords are dropped. Events without any keyword do not appear in
    the result. An empty ``event_ids`` yields an empty dict without querying.
    """
    if not event_ids:
        return {}

    topic_rows = (
        db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
        .filter(ExtractedTopic.target_type == TargetType.EVENT)
        .filter(ExtractedTopic.target_id.in_(event_ids))
        .all()
    )

    tags_by_event: dict[int, list[str]] = {}
    for event_id, keyword in topic_rows:
        if not keyword:
            continue
        if event_id not in tags_by_event:
            tags_by_event[event_id] = []
        tags_by_event[event_id].append(keyword)
    return tags_by_event
|
||||
|
||||
|
||||
def _user_has_keywords(db: Session, user_id: int) -> bool:
    """Return True when at least one UserTopicPreference row exists for the user."""
    first_preference = (
        db.query(UserTopicPreference.id)
        .filter(UserTopicPreference.user_id == user_id)
        .first()
    )
    return first_preference is not None
|
||||
|
||||
|
||||
def _get_default_hot_events(
    db: Session,
    pushed_ids: set[int],
) -> list[_DefaultEventItem]:
    """Pick the hottest recent events the user has not received yet (default mode).

    Candidates are UnifiedEvent rows with ``hot_score`` of at least
    DEFAULT_MODE_HOT_THRESHOLD created within the last DEFAULT_MODE_HOURS
    hours, hottest first. Events whose id is in ``pushed_ids`` are skipped and
    selection stops once MAX_EVENTS_PER_PUSH events have been collected.

    Args:
        db: Active database session.
        pushed_ids: Ids of events already delivered to the user.

    Returns:
        ``_DefaultEventItem`` objects, each carrying the event and up to six
        de-duplicated tags.
    """
    time_limit = utcnow() - timedelta(hours=DEFAULT_MODE_HOURS)
    # No SQL LIMIT: the already-pushed filter runs in Python, and limiting
    # first would return nothing for a user who has already received the top
    # rows even though further qualifying events exist.
    candidates = (
        db.query(UnifiedEvent)
        .filter(
            UnifiedEvent.hot_score >= DEFAULT_MODE_HOT_THRESHOLD,
            UnifiedEvent.created_at >= time_limit,
        )
        .order_by(UnifiedEvent.hot_score.desc())
        .all()
    )

    fresh: list = []
    for ev in candidates:
        if ev.id in pushed_ids:
            continue
        fresh.append(ev)
        if len(fresh) >= MAX_EVENTS_PER_PUSH:
            break

    # Load tags only for the events that are actually returned.
    tags_map = _load_event_tags(db, [ev.id for ev in fresh])
    return [
        _DefaultEventItem(
            event=ev,
            # dict.fromkeys de-duplicates while keeping first-seen order.
            tags=list(dict.fromkeys(tags_map.get(ev.id, [])))[:6],
        )
        for ev in fresh
    ]
|
||||
|
||||
|
||||
def _record_delivery(
    db: Session,
    user_id: int,
    event_ids: list[int],
    status: TaskStatus,
) -> None:
    """Write one DeliveryHistory row per event id and commit.

    Args:
        db: Active database session; it is committed by this function.
        user_id: Recipient of the delivery.
        event_ids: Events contained in the delivery.
        status: Outcome recorded on every row.

    Raises:
        Exception: Whatever the commit raises, re-raised after the session
            has been rolled back.
    """
    db.add_all(
        [
            DeliveryHistory(
                user_id=user_id,
                target_type=TargetType.EVENT,
                target_id=eid,
                status=status,
            )
            for eid in event_ids
        ]
    )
    try:
        db.commit()
    except Exception:
        # The caller reuses this session for the next user; without a rollback
        # the failed transaction would make every later query fail as well.
        db.rollback()
        raise
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 推送准备
|
||||
# ==========================================
|
||||
@dataclass
class _PendingPush:
    """Everything needed to send one user's digest e-mail.

    Built synchronously from database data so the actual sending can happen
    later in the async context.
    """

    # Recipient user id; also used when recording the delivery outcome.
    user_id: int
    # Destination addresses, in the order they should be tried.
    email_targets: list[str]
    # Subject line of the e-mail.
    subject: str
    # Rendered HTML body of the digest.
    html_body: str
    # Ids of the events contained in the digest.
    event_ids: list[int]
|
||||
|
||||
|
||||
def _prepare_user_push(db: Session, user: AppUser, schedule: UserDeliverySchedule) -> _PendingPush | None:
    """Assemble one user's digest from the database without sending anything.

    Content selection:
      1. The user has keywords and fresh matches exist -> matched events.
      2. The user has keywords but nothing fresh matches -> default hot events.
      3. The user has no keywords -> default hot events.

    Returns None (nothing to send) when the user is still in the cool-down
    interval, has no active e-mail endpoint, or default mode finds no events.
    """
    user_id = user.id

    # Guard 1: respect the minimum interval between pushes.
    if _should_skip_by_interval(db, user_id):
        logger.info(f"用户 {user_id} 仍在 {MIN_PUSH_INTERVAL_MINUTES} 分钟冷却期内,跳过")
        return None

    # Guard 2: there must be somewhere to send the digest.
    email_endpoints = _get_user_email_endpoints(db, user_id)
    if not email_endpoints:
        logger.info(f"用户 {user_id} 无可用邮件渠道,跳过")
        return None

    pushed_ids = _get_already_pushed_event_ids(db, user_id)

    # Decide between keyword-matched mode and default hot-events mode.
    items: list = []
    is_default = True
    if not _user_has_keywords(db, user_id):
        logger.info(f"用户 {user_id} 未配置关键词,使用默认热点模式")
    else:
        # Fetch twice the push size so some candidates remain after the
        # already-pushed ones are removed.
        candidates = recommend_events_for_user(
            db,
            user_id=user_id,
            min_hot=1,
            hours=72,
            limit=MAX_EVENTS_PER_PUSH * 2,
        )
        unseen = [candidate for candidate in candidates if candidate.event.id not in pushed_ids]
        if not unseen:
            logger.info(f"用户 {user_id} 关键词无匹配结果,切换为默认热点模式")
        else:
            is_default = False
            items = unseen[:MAX_EVENTS_PER_PUSH]
            logger.info(f"用户 {user_id} 关键词匹配,推送 {len(items)} 条事件")

    if is_default:
        items = _get_default_hot_events(db, pushed_ids)
        if not items:
            logger.info(f"用户 {user_id} 默认热点无可推送内容,跳过")
            return None

    # Bulk-load platform data (source name, headline, URL, ranking) for rendering.
    event_ids = [item.event.id for item in items]
    platforms_map = _load_event_platforms(db, event_ids)

    time_str = schedule.delivery_time.strftime("%H:%M")
    html_body = build_digest_html(
        items=items,
        delivery_time_str=time_str,
        platforms_map=platforms_map,
        is_default_push=is_default,
    )

    if is_default:
        subject_suffix = "全网热点快报"
    else:
        subject_suffix = "个性化简报"

    targets = [endpoint.channel_account for endpoint in email_endpoints]
    return _PendingPush(
        user_id=user_id,
        email_targets=targets,
        subject=f"InsightRadar {subject_suffix} · {time_str}",
        html_body=html_body,
        event_ids=event_ids,
    )
|
||||
|
||||
|
||||
# ==========================================
|
||||
# 调度主入口
|
||||
# ==========================================
|
||||
async def check_and_deliver() -> None:
    """Scheduled entry point: deliver digests to every user whose slot is due.

    Per the original note this is invoked by APScheduler once a minute.

    Flow:
      1. Take the current UTC instant.
      2. Load all active delivery schedules.
      3. For each schedule, compare its delivery time with the user's local
         time and skip it unless it falls inside the delivery window.
      4. Prepare the digest, try each e-mail target in order until one
         succeeds, then record the outcome.

    A failure while handling one schedule is logged and does not stop the
    remaining schedules.
    """
    now = datetime.now(timezone.utc)
    current_utc = now.time().replace(second=0, microsecond=0)
    logger.debug(f"推送调度检查 @ UTC {current_utc.strftime('%H:%M')}")

    db: Session = SessionLocal()
    try:
        active_schedules = (
            db.query(UserDeliverySchedule)
            .filter(UserDeliverySchedule.is_active == True)  # noqa: E712 - SQL expression
            .all()
        )

        for schedule in active_schedules:
            # Read the id up front so the error log below never has to touch
            # an ORM object whose state was invalidated by a failure.
            schedule_user_id = schedule.user_id

            # Everything user-specific lives inside this try: a bad timezone
            # or a broken row for one user must not abort the whole pass.
            try:
                user = db.query(AppUser).filter(AppUser.id == schedule_user_id).first()
                if not user:
                    continue

                # Compare against the user's local wall clock, not UTC.
                user_current = _user_local_time(now, user.timezone)
                if not _is_within_window(schedule.delivery_time, user_current):
                    continue

                pending = _prepare_user_push(db, user, schedule)
                if pending is None:
                    continue

                # Try the e-mail targets in order; stop at the first success.
                sent = False
                for target_email in pending.email_targets:
                    try:
                        success = await send_html_email(
                            to_email=target_email,
                            subject=pending.subject,
                            html_content=pending.html_body,
                        )
                        if success:
                            sent = True
                            logger.info(f"用户 {pending.user_id} 邮件发送成功 → {target_email}")
                            break
                        else:
                            logger.warning(f"用户 {pending.user_id} 渠道 {target_email} 发送失败,尝试下一个")
                    except Exception as e:
                        logger.error(f"用户 {pending.user_id} 发送至 {target_email} 异常: {e}")

                _record_delivery(
                    db,
                    user_id=pending.user_id,
                    event_ids=pending.event_ids,
                    status=TaskStatus.SUCCESS if sent else TaskStatus.ERROR,
                )

            except Exception as e:
                logger.error(f"推送用户 {schedule_user_id} 时异常: {e}", exc_info=True)
                # Reset the shared session so the next user starts from a
                # clean transaction.
                db.rollback()

    except Exception as e:
        logger.error(f"推送调度主循环异常: {e}", exc_info=True)
    finally:
        db.close()
|
||||
@@ -0,0 +1,238 @@
|
||||
import os
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.models.models import ExtractedTopic, TargetType, UnifiedEvent, UserTopicPreference, utcnow
|
||||
from app.services.fetcher_service import embedder_model
|
||||
|
||||
|
||||
# Semantic-match threshold: a user keyword and an event tag count as a
# semantic hit only when their vector similarity reaches this value.
# Overridable through the PREFERENCE_SEMANTIC_THRESHOLD environment variable.
DEFAULT_PREFERENCE_SEMANTIC_THRESHOLD = 0.78
PREFERENCE_SEMANTIC_THRESHOLD = float(
    os.environ.get(
        "PREFERENCE_SEMANTIC_THRESHOLD",
        str(DEFAULT_PREFERENCE_SEMANTIC_THRESHOLD),
    )
)

# Upper bound on the number of recommendations returned.
# Overridable through the PREFERENCE_RECOMMEND_MAX_LIMIT environment variable.
DEFAULT_PREFERENCE_RECOMMEND_MAX_LIMIT = 50
PREFERENCE_RECOMMEND_MAX_LIMIT = int(
    os.environ.get(
        "PREFERENCE_RECOMMEND_MAX_LIMIT",
        str(DEFAULT_PREFERENCE_RECOMMEND_MAX_LIMIT),
    )
)
|
||||
|
||||
|
||||
@dataclass
class MatchedEventResult:
    """One event selected for a user by interest matching, with its evidence."""

    # The matched unified event.
    event: UnifiedEvent
    # Combined ranking score assigned by the recommender.
    match_score: float
    # Event tags that matched a user keyword exactly or by containment.
    exact_hits: list[str]
    # Semantic matches; each entry holds preference_keyword, topic_keyword, similarity.
    semantic_hits: list[dict[str, Any]]
    # Tags of the event.
    tags: list[str]
|
||||
|
||||
|
||||
def _normalize_text(text: str) -> str:
    """Trim surrounding whitespace and case-fold so comparisons are stable."""
    trimmed = text.strip()
    return trimmed.casefold()
|
||||
|
||||
|
||||
def _build_keyword_embedding_map(keywords: list[str]) -> dict[str, np.ndarray]:
    """Encode all keywords in one batch and map each keyword to its vector.

    The encoder is asked for normalized embeddings, so callers can treat a
    dot product of two vectors as their cosine similarity. Vectors are stored
    as float32 arrays. An empty input returns an empty dict without calling
    the model.
    """
    if not keywords:
        return {}

    encoded = embedder_model.encode(keywords, normalize_embeddings=True)
    return {
        word: np.asarray(vector, dtype=np.float32)
        for word, vector in zip(keywords, encoded)
    }
|
||||
|
||||
|
||||
def _ensure_aware(dt: datetime) -> datetime:
    """Return ``dt`` with a timezone, tagging naive values as UTC.

    Per the original note, datetimes read back from SQLite carry no tzinfo and
    must be made aware before they can be subtracted from ``utcnow()``.
    """
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _calc_freshness_bonus(event: UnifiedEvent) -> float:
    """Return a small score bonus that shrinks as the event gets older.

    Age is measured from ``event.created_at`` to now and floored at zero:
    up to 6 hours -> 12.0, up to 24 hours -> 8.0, up to 72 hours -> 4.0,
    older -> 0.0. This keeps stale hot events from holding the top slots.
    """
    age = utcnow() - _ensure_aware(event.created_at)
    age_hours = max(age.total_seconds() / 3600.0, 0.0)

    # (inclusive upper age bound in hours, bonus), checked youngest first.
    tiers = ((6, 12.0), (24, 8.0), (72, 4.0))
    for max_age_hours, bonus in tiers:
        if age_hours <= max_age_hours:
            return bonus
    return 0.0
|
||||
|
||||
|
||||
def _match_topic_to_preferences(
    topic_keyword: str,
    normalized_pref_set: set[str],
    pref_vec_map: dict[str, np.ndarray],
    topic_vec_map: dict[str, np.ndarray],
    similarity_threshold: float,
) -> tuple[str, str | None, float]:
    """Classify one event tag against the user's keywords.

    Returns ``("exact", None, 0.0)``, ``("semantic", best_pref, best_sim)`` or
    ``("none", None, 0.0)``.
    """
    normalized_topic = _normalize_text(topic_keyword)

    # Exact hit: equality, or either string containing the other. Equality is
    # a special case of containment, so one check covers both.
    for pref_word in normalized_pref_set:
        if pref_word in normalized_topic or normalized_topic in pref_word:
            return ("exact", None, 0.0)

    # Semantic hit: best dot product (cosine, vectors are normalized) against
    # every preference vector, accepted only at or above the threshold.
    topic_vec = topic_vec_map.get(topic_keyword)
    if topic_vec is None:
        return ("none", None, 0.0)

    best_pref = None
    best_sim = -1.0
    for pref_keyword, pref_vec in pref_vec_map.items():
        sim = float(np.dot(topic_vec, pref_vec))
        if sim > best_sim:
            best_sim = sim
            best_pref = pref_keyword

    if best_pref is not None and best_sim >= similarity_threshold:
        return ("semantic", best_pref, best_sim)
    return ("none", None, 0.0)


def recommend_events_for_user(
    db: Session,
    *,
    user_id: int,
    min_hot: int = 3,
    hours: int = 72,
    limit: int = 20,
    semantic_threshold: float | None = None,
) -> list[MatchedEventResult]:
    """Recommend recent events whose tags match the user's interest keywords.

    Steps:
      1. Exact match: a user keyword equals, contains, or is contained in an
         EVENT tag (after trimming and case-folding).
      2. Semantic match: keyword and tag embeddings reach the similarity
         threshold (checked only for tags without an exact match).
      3. Score fusion: match score + tag relevance + hot score + freshness.

    Args:
        db: Active database session.
        user_id: User whose UserTopicPreference rows drive the matching.
        min_hot: Minimum ``hot_score`` of candidate events.
        hours: Candidate events must have been created within this many hours.
        limit: Requested result size, clamped to
            ``[1, PREFERENCE_RECOMMEND_MAX_LIMIT]``.
        semantic_threshold: Similarity threshold override; None uses
            PREFERENCE_SEMANTIC_THRESHOLD.

    Returns:
        Matched events ordered by match score, then hot score, then creation
        time, all descending. Empty when the user has no usable keywords, no
        candidate events exist, or no candidate has tags.
    """
    final_limit = max(1, min(limit, PREFERENCE_RECOMMEND_MAX_LIMIT))
    similarity_threshold = (
        PREFERENCE_SEMANTIC_THRESHOLD
        if semantic_threshold is None
        else semantic_threshold
    )

    # Load the user's interest keywords. A NULL keyword is treated as blank
    # instead of crashing on .strip().
    preferences = (
        db.query(UserTopicPreference)
        .filter(UserTopicPreference.user_id == user_id)
        .all()
    )
    stripped_keywords = ((pref.interested_keyword or "").strip() for pref in preferences)
    unique_preference_keywords = list(dict.fromkeys(word for word in stripped_keywords if word))
    if not unique_preference_keywords:
        return []

    # Candidate events, pre-filtered by age and hot score.
    time_limit = utcnow() - timedelta(hours=hours)
    events = (
        db.query(UnifiedEvent)
        .filter(
            UnifiedEvent.hot_score >= min_hot,
            UnifiedEvent.created_at >= time_limit,
        )
        .order_by(UnifiedEvent.hot_score.desc(), UnifiedEvent.created_at.desc())
        .all()
    )
    if not events:
        return []

    event_id_list = [event.id for event in events]
    topic_rows = (
        db.query(
            ExtractedTopic.target_id,
            ExtractedTopic.topic_keyword,
            ExtractedTopic.relevance_score,
        )
        .filter(
            ExtractedTopic.target_type == TargetType.EVENT,
            ExtractedTopic.target_id.in_(event_id_list),
        )
        .all()
    )
    if not topic_rows:
        return []

    # event_id -> [(tag, relevance_score), ...]. Blank tags are dropped: a
    # whitespace-only tag normalizes to "", and "" is a substring of every
    # keyword, which would register as an exact hit for any user.
    event_topics: dict[int, list[tuple[str, float | None]]] = {}
    usable_topic_keywords: list[str] = []
    for event_id, topic_keyword, relevance_score in topic_rows:
        if not topic_keyword or not topic_keyword.strip():
            continue
        event_topics.setdefault(event_id, []).append((topic_keyword, relevance_score))
        usable_topic_keywords.append(topic_keyword)

    # Events without tags cannot be matched at all.
    if not event_topics:
        return []

    # Encode keywords and tags in two batches instead of one call per string.
    unique_topic_keywords = list(dict.fromkeys(usable_topic_keywords))
    pref_vec_map = _build_keyword_embedding_map(unique_preference_keywords)
    topic_vec_map = _build_keyword_embedding_map(unique_topic_keywords)

    # Normalized keyword set used for the exact/containment check.
    normalized_pref_set = {_normalize_text(word) for word in unique_preference_keywords}

    # A tag's classification depends only on the tag and the user's keywords,
    # so compute it once even when the tag appears on many events.
    match_cache: dict[str, tuple[str, str | None, float]] = {}

    scored_results: list[MatchedEventResult] = []
    for event in events:
        topic_list = event_topics.get(event.id, [])
        if not topic_list:
            continue

        exact_hits: list[str] = []
        semantic_hits: list[dict[str, Any]] = []
        score = 0.0

        for topic_keyword, topic_relevance in topic_list:
            # Tags without a stored relevance get a neutral 50.
            topic_relevance_score = float(topic_relevance) if topic_relevance is not None else 50.0

            outcome = match_cache.get(topic_keyword)
            if outcome is None:
                outcome = _match_topic_to_preferences(
                    topic_keyword,
                    normalized_pref_set,
                    pref_vec_map,
                    topic_vec_map,
                    similarity_threshold,
                )
                match_cache[topic_keyword] = outcome
            match_kind, best_pref, best_sim = outcome

            if match_kind == "exact":
                exact_hits.append(topic_keyword)
                # Exact hits get a high base score plus a share of the tag relevance.
                score += 45.0 + topic_relevance_score * 0.2
            elif match_kind == "semantic":
                semantic_hits.append(
                    {
                        "preference_keyword": best_pref,
                        "topic_keyword": topic_keyword,
                        "similarity": round(best_sim, 4),
                    }
                )
                # Semantic hits score slightly lower and scale with similarity.
                score += best_sim * 35.0 + topic_relevance_score * 0.12

        # Neither kind of hit: the event is not recommended.
        if not exact_hits and not semantic_hits:
            continue

        # Blend in hot score (capped at 100) and freshness.
        score += min(event.hot_score, 100) * 0.3
        score += _calc_freshness_bonus(event)

        scored_results.append(
            MatchedEventResult(
                event=event,
                match_score=round(score, 2),
                exact_hits=list(dict.fromkeys(exact_hits)),
                semantic_hits=semantic_hits,
                # De-duplicated, first-seen order.
                tags=list(dict.fromkeys(tag for tag, _ in topic_list)),
            )
        )

    scored_results.sort(
        key=lambda item: (item.match_score, item.event.hot_score, item.event.created_at),
        reverse=True,
    )
    return scored_results[:final_limit]
|
||||
@@ -1,104 +1,241 @@
|
||||
# app/services/summary_service.py
|
||||
import os
|
||||
import json
|
||||
import os
|
||||
from datetime import timedelta
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from openai import AsyncOpenAI
|
||||
|
||||
from app.database import SessionLocal
|
||||
from app.models.models import UnifiedEvent, TrendingEvent, InfoSource, utcnow
|
||||
from app.models.models import (
|
||||
ExtractedTopic,
|
||||
InfoSource,
|
||||
TargetType,
|
||||
TrendingEvent,
|
||||
UnifiedEvent,
|
||||
utcnow,
|
||||
)
|
||||
from app.prompts.summary_prompts import (
|
||||
SUMMARY_SYSTEM_PROMPT,
|
||||
SUMMARY_USER_PROMPT_TEMPLATE,
|
||||
)
|
||||
from app.services.fetcher_service import embedder_model
|
||||
|
||||
# Minimum hot score an event needs before a summary is generated for it.
HOT_SCORE_THRESHOLD = int(os.getenv("HOT_SCORE_THRESHOLD", 3))
# Minimum hot score before topic tags are (re)generated; defaults to the summary threshold.
TOPIC_TAG_MIN_HOT_SCORE = int(os.getenv("TOPIC_TAG_MIN_HOT_SCORE", HOT_SCORE_THRESHOLD))
# Embedding similarity at or above which two tags are merged into one.
TOPIC_SIMILARITY_THRESHOLD = float(os.getenv("TOPIC_SIMILARITY_THRESHOLD", 0.82))
# Maximum number of tags kept per event.
TOPIC_TAG_MAX_COUNT = int(os.getenv("TOPIC_TAG_MAX_COUNT", 8))
AI_API_KEY = os.getenv("AI_API_KEY", "")


# Shared async client, created once and reused by every call.
deepseek_client = AsyncOpenAI(
    api_key=AI_API_KEY,
    base_url="https://api.deepseek.com",
)
|
||||
|
||||
|
||||
async def call_llm_for_summary(platform_data_text: str) -> dict:
    """Call the LLM for a unified title, summary and topic candidates.

    Args:
        platform_data_text: Per-platform headline text inserted into
            SUMMARY_USER_PROMPT_TEMPLATE.

    Returns:
        The JSON object returned by the model, as a dict.

    Raises:
        ValueError: If the model returns no content, content that is not
            valid JSON (``json.JSONDecodeError``), or JSON that is not an
            object.
    """
    prompt = SUMMARY_USER_PROMPT_TEMPLATE.format(platform_data_text=platform_data_text)

    response = await deepseek_client.chat.completions.create(
        model="deepseek-chat",
        messages=[
            {"role": "system", "content": SUMMARY_SYSTEM_PROMPT},
            {"role": "user", "content": prompt},
        ],
        response_format={"type": "json_object"},
        temperature=1,
    )

    result_text = response.choices[0].message.content
    # Fail with a clear message instead of a TypeError from json.loads(None).
    if not result_text:
        raise ValueError("LLM returned an empty response body")

    parsed = json.loads(result_text)
    # Callers use dict methods on the result; reject arrays and scalars here.
    if not isinstance(parsed, dict):
        raise ValueError(f"LLM response is not a JSON object: {type(parsed).__name__}")
    return parsed
|
||||
|
||||
|
||||
def _normalize_score(raw_score: Any) -> float | None:
|
||||
try:
|
||||
score = float(raw_score)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
if score <= 1:
|
||||
score *= 100
|
||||
|
||||
return max(0.0, min(100.0, score))
|
||||
|
||||
|
||||
def parse_topic_keywords(llm_result: dict) -> list[dict[str, Any]]:
    """Parse topic keywords from the LLM response.

    ``llm_result["topic_keywords"]`` may be a list of strings, a list of
    objects, or a single string (treated as one keyword). For objects the
    keyword is read from ``keyword``, ``topic_keyword``, ``name`` or ``topic``
    and the score from ``relevance_score`` or ``score``.

    Args:
        llm_result: Decoded JSON object returned by the LLM.

    Returns:
        ``{"keyword": str, "score": float | None}`` dicts in input order.
        Keywords are trimmed, cut to 100 characters and de-duplicated
        case-insensitively; blank keywords are dropped.
    """
    raw_topics = llm_result.get("topic_keywords") or []
    if isinstance(raw_topics, str):
        # Iterating a bare string would yield one "keyword" per character.
        raw_topics = [raw_topics]
    elif not isinstance(raw_topics, (list, tuple)):
        return []

    parsed: list[dict[str, Any]] = []
    seen: set[str] = set()

    for item in raw_topics:
        keyword = ""
        score = None

        if isinstance(item, str):
            keyword = item.strip()
        elif isinstance(item, dict):
            raw_keyword = (
                item.get("keyword")
                or item.get("topic_keyword")
                or item.get("name")
                or item.get("topic")
                or ""
            )
            keyword = str(raw_keyword).strip()
            # Explicit None check: `a or b` would discard a real score of 0.
            raw_score = item.get("relevance_score")
            if raw_score is None:
                raw_score = item.get("score")
            score = _normalize_score(raw_score)

        if not keyword:
            continue

        keyword = keyword[:100]
        normalized_key = keyword.casefold()
        if normalized_key in seen:
            continue

        seen.add(normalized_key)
        parsed.append({"keyword": keyword, "score": score})

    return parsed
|
||||
|
||||
|
||||
def normalize_topic_keywords(topic_candidates: list[dict[str, Any]]) -> list[dict[str, Any]]:
    """Merge semantically similar tags using embedding similarity.

    Candidates are processed in order. Each one joins the existing cluster it
    is most similar to when that similarity reaches
    TOPIC_SIMILARITY_THRESHOLD; otherwise it starts a new cluster. A cluster
    keeps the shortest keyword and the highest non-None score of its members.
    When any cluster has a score, clusters are sorted by score descending
    (None last). At most TOPIC_TAG_MAX_COUNT entries are returned.
    """
    if not topic_candidates:
        return []

    embeddings = embedder_model.encode(
        [candidate["keyword"] for candidate in topic_candidates],
        normalize_embeddings=True,
    )

    clusters: list[dict[str, Any]] = []
    for candidate, embedding in zip(topic_candidates, embeddings):
        vec = np.asarray(embedding, dtype=np.float32)

        # Locate the most similar existing cluster (first one wins on ties).
        nearest = -1
        nearest_sim = -1.0
        for position, existing in enumerate(clusters):
            similarity = float(np.dot(vec, existing["vector"]))
            if similarity > nearest_sim:
                nearest, nearest_sim = position, similarity

        if not (nearest >= 0 and nearest_sim >= TOPIC_SIMILARITY_THRESHOLD):
            clusters.append(
                {
                    "keyword": candidate["keyword"],
                    "score": candidate["score"],
                    "vector": vec,
                    "count": 1,
                }
            )
            continue

        target = clusters[nearest]

        # Move the cluster vector toward the new member, weighting the old
        # vector by the member count, and re-normalize.
        combined = target["vector"] * target["count"] + vec
        magnitude = float(np.linalg.norm(combined))
        if magnitude > 0:
            target["vector"] = combined / magnitude
        target["count"] += 1

        candidate_score = candidate["score"]
        if candidate_score is not None and (
            target["score"] is None or candidate_score > target["score"]
        ):
            target["score"] = candidate_score

        # Prefer the shorter tag as the canonical keyword.
        if len(candidate["keyword"]) < len(target["keyword"]):
            target["keyword"] = candidate["keyword"]

    if any(cluster["score"] is not None for cluster in clusters):
        clusters.sort(
            key=lambda entry: entry["score"] if entry["score"] is not None else -1.0,
            reverse=True,
        )

    kept = clusters[:TOPIC_TAG_MAX_COUNT]
    return [{"keyword": cluster["keyword"], "score": cluster["score"]} for cluster in kept]
|
||||
|
||||
|
||||
def replace_event_topics(db, event_id: int, normalized_topics: list[dict[str, Any]]) -> None:
    """Replace all EVENT tags of one unified event inside the current transaction.

    Existing ExtractedTopic rows for the event are deleted and one row per
    entry of ``normalized_topics`` (keys ``keyword`` and ``score``) is added.
    Nothing is committed here.
    """
    existing_tags = db.query(ExtractedTopic).filter(
        ExtractedTopic.target_type == TargetType.EVENT,
        ExtractedTopic.target_id == event_id,
    )
    existing_tags.delete(synchronize_session=False)

    for topic in normalized_topics:
        new_tag = ExtractedTopic(
            target_type=TargetType.EVENT,
            target_id=event_id,
            topic_keyword=topic["keyword"],
            relevance_score=topic["score"],
        )
        db.add(new_tag)
|
||||
|
||||
|
||||
async def generate_unified_summaries():
    """Scheduled task: refresh summaries and topic tags for hot unified events.

    An event is processed when its ``hot_score`` is at least
    HOT_SCORE_THRESHOLD, is higher than ``last_summarized_trends_count`` and
    the event was created within the last three days. Each event is committed
    on its own; a failure rolls back that event only and the loop continues.
    """
    print(f"[{utcnow()}] Start unified summary generation task...")

    with SessionLocal() as db:
        recent_threshold = utcnow() - timedelta(days=3)

        # Required: hot enough AND hotter than at the last summary AND recent.
        events = db.query(UnifiedEvent).filter(
            UnifiedEvent.hot_score >= HOT_SCORE_THRESHOLD,
            UnifiedEvent.hot_score > UnifiedEvent.last_summarized_trends_count,
            UnifiedEvent.created_at >= recent_threshold,
        ).all()

        if not events:
            print("No events require summary update in this round.")
            return

        for event in events:
            # Read the id up front so the failure message never has to reload
            # the object after a rollback.
            event_id = event.id

            # Fetch the event's per-platform items together with the source name.
            trends = (
                db.query(TrendingEvent, InfoSource.source_name)
                .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
                .filter(TrendingEvent.unified_event_id == event_id)
                .all()
            )

            if not trends:
                continue

            # Group de-duplicated headlines by platform. Missing headlines are
            # skipped; they would break sorted() and join() below.
            platform_dict: dict[str, set[str]] = {}
            for trend_record, source_name in trends:
                if trend_record.current_headline:
                    platform_dict.setdefault(source_name, set()).add(trend_record.current_headline)

            if not platform_dict:
                continue

            # Build the prompt payload; sorting keeps the text deterministic.
            prompt_lines = [
                f"[{platform}] {', '.join(sorted(headlines))}"
                for platform, headlines in platform_dict.items()
            ]
            platform_data_text = "\n".join(prompt_lines)

            try:
                llm_result = await call_llm_for_summary(platform_data_text)

                if llm_result.get("unified_title"):
                    event.unified_title = llm_result["unified_title"]
                if llm_result.get("ai_comprehensive_summary"):
                    event.ai_comprehensive_summary = llm_result["ai_comprehensive_summary"]

                if event.hot_score >= TOPIC_TAG_MIN_HOT_SCORE:
                    topic_candidates = parse_topic_keywords(llm_result)
                    normalized_topics = normalize_topic_keywords(topic_candidates)
                    if normalized_topics:
                        replace_event_topics(db, event_id, normalized_topics)

                # Advance the watermark so the event is only re-summarized
                # once its hot score rises again.
                event.last_summarized_trends_count = event.hot_score
                hot_score = event.hot_score

                # Commit per event: a later failure must not discard summaries
                # that were already paid for.
                db.commit()
                print(
                    f"Updated event {event_id} summary"
                    f" (hot_score={hot_score})."
                )
            except Exception as exc:
                # Drop this event's partial changes so they are not committed
                # together with a later event.
                db.rollback()
                print(f"Event {event_id} summary generation failed: {exc}")
                continue
|
||||
|
||||
@@ -17,7 +17,7 @@ async def send_html_email(
|
||||
to_email: str,
|
||||
subject: str,
|
||||
html_content: str,
|
||||
sender_name: str = "AI 新闻早报",
|
||||
sender_name: str = "AI 新闻",
|
||||
sender_email: str = None
|
||||
) -> bool:
|
||||
"""底层纯异步发送邮件工具"""
|
||||
|
||||
Reference in New Issue
Block a user