big update

This commit is contained in:
stardrophere
2026-03-11 20:52:58 +08:00
parent 8ed819a580
commit 966bcfbba4
44 changed files with 7124 additions and 650 deletions
+45 -2
View File
@@ -1,5 +1,12 @@
# app/api/dependencies.py
from fastapi import Depends, HTTPException, status
from fastapi.security import HTTPAuthorizationCredentials, HTTPBearer
from sqlalchemy.orm import Session
from app.core.security import decode_access_token
from app.database import SessionLocal
from app.models.models import AppUser
bearer_scheme = HTTPBearer(auto_error=False)
def get_db():
"""
@@ -10,4 +17,40 @@ def get_db():
try:
yield db
finally:
db.close()
db.close()
def get_current_user(
credentials: HTTPAuthorizationCredentials | None = Depends(bearer_scheme),
db: Session = Depends(get_db),
) -> AppUser:
"""
从 Bearer Token 中解析并返回当前登录用户。
要求:
1. 必须携带 Authorization: Bearer <token>
2. token 验签通过且未过期
3. 用户在数据库中存在
"""
if credentials is None or credentials.scheme.lower() != "bearer":
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Authentication credentials were not provided",
)
token = credentials.credentials
try:
user_id, email = decode_access_token(token)
except ValueError:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid or expired token",
)
user = db.query(AppUser).filter(AppUser.id == user_id).first()
if not user or user.email != email:
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail="Invalid token user",
)
return user
+51 -3
View File
@@ -1,5 +1,6 @@
import math
import os
from datetime import timedelta
from datetime import timedelta, timezone
from typing import Tuple
from fastapi import APIRouter, Depends, HTTPException, status
@@ -30,8 +31,18 @@ from app.utils.email_utils import send_html_email
router = APIRouter()
REGISTER_CODE_EXPIRE_MINUTES = int(os.getenv("REGISTER_CODE_EXPIRE_MINUTES", "10"))
LOGIN_CODE_EXPIRE_MINUTES = int(os.getenv("LOGIN_CODE_EXPIRE_MINUTES", "10"))
DEFAULT_REGISTER_CODE_EXPIRE_MINUTES = 10
DEFAULT_LOGIN_CODE_EXPIRE_MINUTES = 10
DEFAULT_CODE_SEND_COOLDOWN_SECONDS = 60
REGISTER_CODE_EXPIRE_MINUTES = int(
os.getenv("REGISTER_CODE_EXPIRE_MINUTES", str(DEFAULT_REGISTER_CODE_EXPIRE_MINUTES))
)
LOGIN_CODE_EXPIRE_MINUTES = int(
os.getenv("LOGIN_CODE_EXPIRE_MINUTES", str(DEFAULT_LOGIN_CODE_EXPIRE_MINUTES))
)
CODE_SEND_COOLDOWN_SECONDS = int(
os.getenv("CODE_SEND_COOLDOWN_SECONDS", str(DEFAULT_CODE_SEND_COOLDOWN_SECONDS))
)
def _normalize_email(email: str) -> str:
@@ -78,6 +89,41 @@ def _create_code_record(
return code_record, code
def _enforce_code_send_cooldown(db: Session, email: str, purpose: VerificationPurpose) -> None:
"""
防抖:限制同一邮箱同一用途验证码的发送频率,避免用户短时间连续点击。
"""
if CODE_SEND_COOLDOWN_SECONDS <= 0:
return
latest_record = (
db.query(EmailVerificationCode)
.filter(
EmailVerificationCode.email == email,
EmailVerificationCode.purpose == purpose,
)
.order_by(EmailVerificationCode.created_at.desc())
.first()
)
if not latest_record:
return
now = utcnow()
record_time = latest_record.created_at
if record_time.tzinfo is None:
record_time = record_time.replace(tzinfo=timezone.utc)
elapsed_seconds = (now - record_time).total_seconds()
if elapsed_seconds >= CODE_SEND_COOLDOWN_SECONDS:
return
retry_after_seconds = max(1, math.ceil(CODE_SEND_COOLDOWN_SECONDS - elapsed_seconds))
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Please wait {retry_after_seconds}s before requesting another verification code",
headers={"Retry-After": str(retry_after_seconds)},
)
def _build_auth_response(user: AppUser) -> AuthTokenResponse:
token, expires_in = create_access_token(user_id=user.id, email=user.email)
return AuthTokenResponse(
@@ -95,6 +141,7 @@ async def send_register_code(payload: RegisterCodeSendRequest, db: Session = Dep
if existing_user:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email is already registered")
_enforce_code_send_cooldown(db, email, VerificationPurpose.REGISTER)
_invalidate_unused_codes(db, email, VerificationPurpose.REGISTER)
code_record, code = _create_code_record(
db,
@@ -128,6 +175,7 @@ async def send_login_code(payload: LoginCodeSendRequest, db: Session = Depends(g
if not user:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Email is not registered")
_enforce_code_send_cooldown(db, email, VerificationPurpose.LOGIN)
_invalidate_unused_codes(db, email, VerificationPurpose.LOGIN)
code_record, code = _create_code_record(
db,
+353
View File
@@ -0,0 +1,353 @@
# 推送设置 API:管理用户的推送时间表和推送渠道
from datetime import time as dt_time
from typing import List
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.api.dependencies import get_current_user, get_db
from app.models.models import AppUser, UserDeliverySchedule, UserPushEndpoint
from app.schemas.delivery_schema import (
DeliveryScheduleCreate,
DeliveryScheduleResponse,
DeliveryScheduleUpdate,
PushEndpointCreate,
PushEndpointResponse,
PushEndpointUpdate,
UserDeliveryConfigResponse,
)
router = APIRouter()
# 两条推送时间之间的最小间隔(分钟)
MIN_SCHEDULE_GAP_MINUTES = 30
def _ensure_self_access(path_user_id: int, current_user: AppUser) -> None:
"""校验路径 user_id 是否为当前登录用户本人。"""
if path_user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only operate your own resources",
)
def _parse_time(time_str: str) -> dt_time:
"""将 HH:MM 字符串解析为 time 对象"""
try:
parts = time_str.split(":")
return dt_time(hour=int(parts[0]), minute=int(parts[1]))
except (ValueError, IndexError):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Invalid time format, expected HH:MM",
)
def _format_time(t: dt_time) -> str:
"""将 time 对象格式化为 HH:MM 字符串"""
return t.strftime("%H:%M")
def _time_to_minutes(t: dt_time) -> int:
return t.hour * 60 + t.minute
def _check_min_gap(
db: Session,
user_id: int,
new_time: dt_time,
exclude_id: int | None = None,
) -> None:
"""
校验新时间与用户已有的所有推送时间之间是否满足最小间隔要求(30 分钟)。
不满足时直接抛出 400 异常。
"""
query = db.query(UserDeliverySchedule).filter(
UserDeliverySchedule.user_id == user_id
)
if exclude_id is not None:
query = query.filter(UserDeliverySchedule.id != exclude_id)
existing = query.all()
new_minutes = _time_to_minutes(new_time)
for s in existing:
old_minutes = _time_to_minutes(s.delivery_time)
diff = abs(new_minutes - old_minutes)
circular_diff = min(diff, 1440 - diff)
if circular_diff < MIN_SCHEDULE_GAP_MINUTES:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"推送时间间隔不能少于 {MIN_SCHEDULE_GAP_MINUTES} 分钟,"
f"与已有的 {_format_time(s.delivery_time)} 冲突",
)
# ==========================================
# 聚合查询:一次性返回用户全部推送配置
# ==========================================
@router.get(
"/users/{user_id}/delivery-config",
response_model=UserDeliveryConfigResponse,
)
def get_delivery_config(
user_id: int,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""获取用户的完整推送配置(时间表 + 渠道)。"""
_ensure_self_access(user_id, current_user)
schedules = (
db.query(UserDeliverySchedule)
.filter(UserDeliverySchedule.user_id == user_id)
.order_by(UserDeliverySchedule.delivery_time.asc())
.all()
)
endpoints = (
db.query(UserPushEndpoint)
.filter(UserPushEndpoint.user_id == user_id)
.order_by(UserPushEndpoint.priority_level.asc())
.all()
)
# 手动转换 time 字段为字符串
schedule_list = [
DeliveryScheduleResponse(
id=s.id,
user_id=s.user_id,
delivery_time=_format_time(s.delivery_time),
is_active=s.is_active,
created_at=s.created_at,
)
for s in schedules
]
return UserDeliveryConfigResponse(schedules=schedule_list, endpoints=endpoints)
# ==========================================
# 推送时间表 CRUD
# ==========================================
@router.post(
"/users/{user_id}/delivery-schedules",
response_model=DeliveryScheduleResponse,
status_code=status.HTTP_201_CREATED,
)
def create_delivery_schedule(
user_id: int,
payload: DeliveryScheduleCreate,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""新增一条推送时间。"""
_ensure_self_access(user_id, current_user)
parsed_time = _parse_time(payload.delivery_time)
_check_min_gap(db, user_id, parsed_time)
db_obj = UserDeliverySchedule(
user_id=user_id,
delivery_time=parsed_time,
is_active=payload.is_active,
)
db.add(db_obj)
try:
db.commit()
except IntegrityError:
db.rollback()
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="This delivery time already exists",
)
db.refresh(db_obj)
return DeliveryScheduleResponse(
id=db_obj.id,
user_id=db_obj.user_id,
delivery_time=_format_time(db_obj.delivery_time),
is_active=db_obj.is_active,
created_at=db_obj.created_at,
)
@router.patch(
"/users/{user_id}/delivery-schedules/{schedule_id}",
response_model=DeliveryScheduleResponse,
)
def update_delivery_schedule(
user_id: int,
schedule_id: int,
payload: DeliveryScheduleUpdate,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""更新一条推送时间。"""
_ensure_self_access(user_id, current_user)
db_obj = (
db.query(UserDeliverySchedule)
.filter(
UserDeliverySchedule.id == schedule_id,
UserDeliverySchedule.user_id == user_id,
)
.first()
)
if not db_obj:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")
if payload.delivery_time is not None:
new_time = _parse_time(payload.delivery_time)
_check_min_gap(db, user_id, new_time, exclude_id=schedule_id)
db_obj.delivery_time = new_time
if payload.is_active is not None:
db_obj.is_active = payload.is_active
try:
db.commit()
except IntegrityError:
db.rollback()
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="This delivery time already exists",
)
db.refresh(db_obj)
return DeliveryScheduleResponse(
id=db_obj.id,
user_id=db_obj.user_id,
delivery_time=_format_time(db_obj.delivery_time),
is_active=db_obj.is_active,
created_at=db_obj.created_at,
)
@router.delete(
"/users/{user_id}/delivery-schedules/{schedule_id}",
status_code=status.HTTP_204_NO_CONTENT,
)
def delete_delivery_schedule(
user_id: int,
schedule_id: int,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""删除一条推送时间。"""
_ensure_self_access(user_id, current_user)
db_obj = (
db.query(UserDeliverySchedule)
.filter(
UserDeliverySchedule.id == schedule_id,
UserDeliverySchedule.user_id == user_id,
)
.first()
)
if not db_obj:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Schedule not found")
db.delete(db_obj)
db.commit()
return None
# ==========================================
# 推送渠道 CRUD
# ==========================================
@router.post(
"/users/{user_id}/push-endpoints",
response_model=PushEndpointResponse,
status_code=status.HTTP_201_CREATED,
)
def create_push_endpoint(
user_id: int,
payload: PushEndpointCreate,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""新增一个推送渠道。"""
_ensure_self_access(user_id, current_user)
db_obj = UserPushEndpoint(
user_id=user_id,
channel_type=payload.channel_type.upper().strip(),
channel_account=payload.channel_account.strip(),
is_active=payload.is_active,
priority_level=payload.priority_level,
)
db.add(db_obj)
try:
db.commit()
except IntegrityError:
db.rollback()
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="This channel type already exists for the user",
)
db.refresh(db_obj)
return db_obj
@router.patch(
"/users/{user_id}/push-endpoints/{endpoint_id}",
response_model=PushEndpointResponse,
)
def update_push_endpoint(
user_id: int,
endpoint_id: int,
payload: PushEndpointUpdate,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""更新一个推送渠道配置。"""
_ensure_self_access(user_id, current_user)
db_obj = (
db.query(UserPushEndpoint)
.filter(
UserPushEndpoint.id == endpoint_id,
UserPushEndpoint.user_id == user_id,
)
.first()
)
if not db_obj:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")
if payload.channel_account is not None:
db_obj.channel_account = payload.channel_account.strip()
if payload.is_active is not None:
db_obj.is_active = payload.is_active
if payload.priority_level is not None:
db_obj.priority_level = payload.priority_level
db.commit()
db.refresh(db_obj)
return db_obj
@router.delete(
"/users/{user_id}/push-endpoints/{endpoint_id}",
status_code=status.HTTP_204_NO_CONTENT,
)
def delete_push_endpoint(
user_id: int,
endpoint_id: int,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""删除一个推送渠道。"""
_ensure_self_access(user_id, current_user)
db_obj = (
db.query(UserPushEndpoint)
.filter(
UserPushEndpoint.id == endpoint_id,
UserPushEndpoint.user_id == user_id,
)
.first()
)
if not db_obj:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Push endpoint not found")
db.delete(db_obj)
db.commit()
return None
+192 -46
View File
@@ -1,69 +1,215 @@
# app/api/endpoints/events.py
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from datetime import timedelta
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.api.dependencies import get_db
from app.models.models import UnifiedEvent, TrendingEvent, InfoSource, RankingLog, utcnow
# 导入你上传的 Schema
from app.schemas.event_schema import UnifiedEventResponse, PlatformTrendResponse
from app.models.models import (
ExtractedTopic,
InfoSource,
RankingLog,
TargetType,
TrendingEvent,
UnifiedEvent,
utcnow,
)
from app.schemas.event_schema import (
PaginatedUnifiedEventResponse,
PlatformTrendResponse,
UnifiedEventResponse,
)
router = APIRouter()
# 排名轨迹最多返回多少个点,避免长时间跨度下数据过大
MAX_RANKING_POINTS = 30
@router.get("/unified", response_model=List[UnifiedEventResponse])
@router.get("/unified", response_model=PaginatedUnifiedEventResponse)
def list_unified_events(
min_hot: int = Query(5, description="热度过滤阈值"),
hours: int = Query(24, description="查询过去 X 小时的数据"),
db: Session = Depends(get_db)
min_hot: int = Query(5, ge=0, description="热度阈值,仅返回 hot_score >= 此值的事件"),
hours: int = Query(24, ge=1, le=720, description="查询最近多少小时的数据"),
skip: int = Query(0, ge=0, description="分页偏移量"),
limit: int = Query(10, ge=1, le=50, description="每页返回条数"),
db: Session = Depends(get_db),
):
"""
获取聚合大事件列表,完全适配前端 template.html 所需的数据结构
"""
# 计算时间水位线
"""分页返回统一事件,附带各平台热搜、排名轨迹和标签。"""
time_limit = utcnow() - timedelta(hours=hours)
# 1. 查询大事件(按热度降序,且满足时间范围)
events = db.query(UnifiedEvent).filter(
# 先查总数,用于前端判断是否还有更多
base_query = db.query(UnifiedEvent).filter(
UnifiedEvent.hot_score >= min_hot,
UnifiedEvent.created_at >= time_limit
).order_by(UnifiedEvent.hot_score.desc()).all()
UnifiedEvent.created_at >= time_limit,
)
total = base_query.count()
results = []
# 分页查询
events = (
base_query
.order_by(UnifiedEvent.hot_score.desc())
.offset(skip)
.limit(limit)
.all()
)
if not events:
return PaginatedUnifiedEventResponse(total=total, has_more=False, data=[])
event_ids = [ev.id for ev in events]
# 批量查询所有相关的热搜条目(避免 N+1)
trend_rows = (
db.query(TrendingEvent, InfoSource.source_name)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
.filter(TrendingEvent.unified_event_id.in_(event_ids))
.all()
)
# 按 unified_event_id 分组
trend_map: dict[int, list[tuple]] = {}
trend_ids: list[int] = []
for trend, source_name in trend_rows:
trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
trend_ids.append(trend.id)
# 批量查询排名日志(避免逐条查询)
ranking_map: dict[int, list[int]] = {}
if trend_ids:
ranking_rows = (
db.query(
RankingLog.event_id,
RankingLog.ranking_position,
)
.filter(
RankingLog.event_id.in_(trend_ids),
RankingLog.observed_at >= time_limit,
)
.order_by(RankingLog.event_id, RankingLog.observed_at.asc())
.all()
)
for event_id, position in ranking_rows:
ranking_map.setdefault(event_id, []).append(position)
# 批量查询标签
tag_map: dict[int, list[str]] = {}
tag_rows = (
db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
.filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id.in_(event_ids),
)
.order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
.all()
)
for target_id, keyword in tag_rows:
tag_map.setdefault(target_id, []).append(keyword)
# 组装响应
results: list[UnifiedEventResponse] = []
for ev in events:
# 2. 联表查询:获取该大事件下关联的所有平台及其具体热搜信息
trends = db.query(TrendingEvent, InfoSource.source_name).join(
InfoSource, TrendingEvent.source_id == InfoSource.id
).filter(TrendingEvent.unified_event_id == ev.id).all()
platform_list: list[PlatformTrendResponse] = []
for trend, source_name in trend_map.get(ev.id, []):
history = ranking_map.get(trend.id, [])
# 截取尾部,只保留最近的点
if len(history) > MAX_RANKING_POINTS:
history = history[-MAX_RANKING_POINTS:]
platform_list = []
for trend, s_name in trends:
# 3. 获取排名历史轨迹 (用于前端渲染)
# 这里的排序顺序 asc 保证了数组从旧到新
logs = db.query(RankingLog.ranking_position).filter(
RankingLog.event_id == trend.id,
RankingLog.observed_at >= time_limit
).order_by(RankingLog.observed_at.asc()).all()
platform_list.append(
PlatformTrendResponse(
source_id=trend.source_id,
platform_name=source_name,
headline=trend.current_headline,
url=trend.event_url,
current_ranking=trend.current_ranking,
ranking_history=history,
)
)
# 组装符合 PlatformTrendResponse 结构的字典
platform_list.append(PlatformTrendResponse(
results.append(
UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
platforms=platform_list,
tags=tag_map.get(ev.id, []),
)
)
has_more = (skip + limit) < total
return PaginatedUnifiedEventResponse(total=total, has_more=has_more, data=results)
@router.get("/unified/{event_id}", response_model=UnifiedEventResponse)
def get_unified_event(
event_id: int,
db: Session = Depends(get_db),
):
"""按 ID 查询单个统一事件,用于推荐跳转时的聚光灯展示。"""
ev = db.query(UnifiedEvent).filter(UnifiedEvent.id == event_id).first()
if not ev:
raise HTTPException(status_code=404, detail="Event not found")
time_limit = utcnow() - timedelta(hours=720)
trend_rows = (
db.query(TrendingEvent, InfoSource.source_name)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
.filter(TrendingEvent.unified_event_id == event_id)
.all()
)
trend_ids = [t.id for t, _ in trend_rows]
ranking_map: dict[int, list[int]] = {}
if trend_ids:
ranking_rows = (
db.query(RankingLog.event_id, RankingLog.ranking_position)
.filter(
RankingLog.event_id.in_(trend_ids),
RankingLog.observed_at >= time_limit,
)
.order_by(RankingLog.event_id, RankingLog.observed_at.asc())
.all()
)
for eid, pos in ranking_rows:
ranking_map.setdefault(eid, []).append(pos)
tag_rows = (
db.query(ExtractedTopic.topic_keyword)
.filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id == event_id,
)
.order_by(ExtractedTopic.relevance_score.desc())
.all()
)
tags = [row[0] for row in tag_rows]
platform_list: list[PlatformTrendResponse] = []
for trend, source_name in trend_rows:
history = ranking_map.get(trend.id, [])
if len(history) > MAX_RANKING_POINTS:
history = history[-MAX_RANKING_POINTS:]
platform_list.append(
PlatformTrendResponse(
source_id=trend.source_id,
platform_name=s_name,
platform_name=source_name,
headline=trend.current_headline,
url=trend.event_url,
current_ranking=trend.current_ranking,
ranking_history=[log[0] for log in logs]
))
ranking_history=history,
)
)
# 4. 组装符合 UnifiedEventResponse 结构的字典
results.append(UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
platforms=platform_list
))
return results
return UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
platforms=platform_list,
tags=tags,
)
+158
View File
@@ -0,0 +1,158 @@
from typing import List
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.api.dependencies import get_current_user, get_db
from app.models.models import AppUser, UserTopicPreference
from app.schemas.preference_schema import (
MatchedEventResponse,
UserPreferenceRecommendationResponse,
UserTopicPreferenceCreate,
UserTopicPreferenceResponse,
)
from app.services.matching_service import recommend_events_for_user
router = APIRouter()
def _ensure_self_access(path_user_id: int, current_user: AppUser) -> None:
"""校验路径 user_id 是否为当前登录用户本人。"""
if path_user_id != current_user.id:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="You can only operate your own resources",
)
@router.get(
"/users/{user_id}/preferences",
response_model=List[UserTopicPreferenceResponse],
)
def list_user_preferences(
user_id: int,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""获取用户已设置的兴趣关键词。"""
_ensure_self_access(user_id, current_user)
preferences = (
db.query(UserTopicPreference)
.filter(UserTopicPreference.user_id == user_id)
.order_by(UserTopicPreference.created_at.desc())
.all()
)
return preferences
@router.post(
"/users/{user_id}/preferences",
response_model=UserTopicPreferenceResponse,
status_code=status.HTTP_201_CREATED,
)
def create_user_preference(
user_id: int,
payload: UserTopicPreferenceCreate,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""新增一个用户兴趣关键词。"""
_ensure_self_access(user_id, current_user)
keyword = payload.interested_keyword.strip()
if not keyword:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Keyword cannot be empty")
db_obj = UserTopicPreference(
user_id=user_id,
interested_keyword=keyword,
)
db.add(db_obj)
try:
db.commit()
except IntegrityError:
db.rollback()
raise HTTPException(
status_code=status.HTTP_409_CONFLICT,
detail="Preference keyword already exists for this user",
)
db.refresh(db_obj)
return db_obj
@router.delete(
"/users/{user_id}/preferences/{preference_id}",
status_code=status.HTTP_204_NO_CONTENT,
)
def delete_user_preference(
user_id: int,
preference_id: int,
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""删除一个用户兴趣关键词。"""
_ensure_self_access(user_id, current_user)
preference = (
db.query(UserTopicPreference)
.filter(
UserTopicPreference.id == preference_id,
UserTopicPreference.user_id == user_id,
)
.first()
)
if not preference:
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Preference not found")
db.delete(preference)
db.commit()
return None
@router.get(
"/users/{user_id}/recommended-events",
response_model=UserPreferenceRecommendationResponse,
)
def recommend_events(
user_id: int,
min_hot: int = Query(3, ge=1, description="最小热度阈值"),
hours: int = Query(72, ge=1, le=24 * 30, description="仅匹配最近多少小时的事件"),
limit: int = Query(20, ge=1, le=50, description="最多返回多少条推荐"),
semantic_threshold: float = Query(0.78, ge=0.0, le=1.0, description="语义匹配相似度阈值"),
db: Session = Depends(get_db),
current_user: AppUser = Depends(get_current_user),
):
"""基于用户兴趣词推荐事件(精确匹配 + 语义匹配)。"""
_ensure_self_access(user_id, current_user)
matched = recommend_events_for_user(
db,
user_id=user_id,
min_hot=min_hot,
hours=hours,
limit=limit,
semantic_threshold=semantic_threshold,
)
result_data: list[MatchedEventResponse] = []
for item in matched:
result_data.append(
MatchedEventResponse(
event_id=item.event.id,
unified_title=item.event.unified_title,
summary=item.event.ai_comprehensive_summary,
hot_score=item.event.hot_score,
created_at=item.event.created_at,
tags=item.tags,
match_score=item.match_score,
exact_hits=item.exact_hits,
semantic_hits=item.semantic_hits,
)
)
return UserPreferenceRecommendationResponse(
user_id=user_id,
total=len(result_data),
data=result_data,
)
+75
View File
@@ -0,0 +1,75 @@
# 公关修改追踪 API:查询热搜标题被偷偷修改的历史记录
from datetime import timedelta
from typing import List, Optional
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from app.api.dependencies import get_db
from app.models.models import HeadlineRevision, InfoSource, TrendingEvent, utcnow
from pydantic import BaseModel, ConfigDict
from datetime import datetime
router = APIRouter()
class HeadlineRevisionResponse(BaseModel):
"""标题修改记录响应体"""
id: int
event_id: int
previous_headline: str
revised_headline: str
source_name: Optional[str] = None
platform_icon: Optional[str] = None
created_at: datetime
model_config = ConfigDict(from_attributes=True)
@router.get("/headline-revisions", response_model=List[HeadlineRevisionResponse])
def list_headline_revisions(
hours: int = Query(48, ge=1, le=720, description="查询最近多少小时内的修改记录"),
limit: int = Query(50, ge=1, le=500, description="最多返回条数"),
db: Session = Depends(get_db),
):
"""
获取最近的标题修改记录列表。
用于公关监测:发现哪些平台偷偷改了热搜标题。
"""
time_limit = utcnow() - timedelta(hours=hours)
rows = (
db.query(HeadlineRevision, InfoSource.source_name)
.join(TrendingEvent, HeadlineRevision.event_id == TrendingEvent.id)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
.filter(HeadlineRevision.created_at >= time_limit)
.order_by(HeadlineRevision.created_at.desc())
.limit(limit)
.all()
)
# 平台名到图标的简单映射
icon_map = {
"微博热搜": "weibo",
"知乎热榜": "zhihu",
"百度热搜": "baidu",
"今日头条": "toutiao",
"抖音热榜": "douyin",
"B站热搜": "bilibili",
}
results: list[HeadlineRevisionResponse] = []
for revision, source_name in rows:
results.append(
HeadlineRevisionResponse(
id=revision.id,
event_id=revision.event_id,
previous_headline=revision.previous_headline,
revised_headline=revision.revised_headline,
source_name=source_name,
platform_icon=icon_map.get(source_name, "newspaper"),
created_at=revision.created_at,
)
)
return results
+65
View File
@@ -0,0 +1,65 @@
# 系统状态监控 API:返回爬虫集群运行概况
from datetime import datetime, timedelta
from typing import Optional
from fastapi import APIRouter, Depends
from pydantic import BaseModel
from sqlalchemy import func
from sqlalchemy.orm import Session
from app.api.dependencies import get_db
from app.models.models import DataSyncTask, InfoSource, TaskStatus, utcnow
router = APIRouter()
class SystemStatsResponse(BaseModel):
"""系统运行状态汇总"""
active_sources: int
total_sources: int
items_today: int
success_tasks_today: int
error_tasks_today: int
last_sync_at: Optional[datetime] = None
@router.get("/system/stats", response_model=SystemStatsResponse)
def get_system_stats(db: Session = Depends(get_db)):
"""获取爬虫集群的当日运行状态。"""
today_start = utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
# 信息源统计
total_sources = db.query(func.count(InfoSource.id)).scalar() or 0
active_sources = (
db.query(func.count(InfoSource.id))
.filter(InfoSource.is_enabled.is_(True))
.scalar() or 0
)
# 今日任务统计
today_tasks = (
db.query(DataSyncTask)
.filter(DataSyncTask.created_at >= today_start)
.all()
)
items_today = sum(t.items_fetched for t in today_tasks)
success_count = sum(1 for t in today_tasks if t.task_status == TaskStatus.SUCCESS)
error_count = sum(1 for t in today_tasks if t.task_status == TaskStatus.ERROR)
# 最后一次同步时间
last_task = (
db.query(DataSyncTask)
.filter(DataSyncTask.task_status == TaskStatus.SUCCESS)
.order_by(DataSyncTask.created_at.desc())
.first()
)
return SystemStatsResponse(
active_sources=active_sources,
total_sources=total_sources,
items_today=items_today,
success_tasks_today=success_count,
error_tasks_today=error_count,
last_sync_at=last_task.created_at if last_task else None,
)
+15 -1
View File
@@ -1,6 +1,6 @@
# app/api/router.py
from fastapi import APIRouter
from app.api.endpoints import auth, sources, events
from app.api.endpoints import auth, delivery, events, preferences, revisions, sources, stats
api_router = APIRouter()
@@ -9,4 +9,18 @@ api_router.include_router(sources.router, prefix="/sources", tags=["信息源管
# 注册大事件相关的路由
api_router.include_router(events.router, prefix="/events", tags=["Unified Events"])
# 认证
api_router.include_router(auth.router, prefix="/auth", tags=["Auth"])
# 用户偏好(关键词订阅)
api_router.include_router(preferences.router, tags=["User Preferences"])
# 推送设置(时间表 + 渠道)
api_router.include_router(delivery.router, tags=["Delivery Settings"])
# 公关修改追踪
api_router.include_router(revisions.router, prefix="/events", tags=["Headline Revisions"])
# 系统状态监控
api_router.include_router(stats.router, tags=["System Stats"])
+61 -2
View File
@@ -8,9 +8,15 @@ import time
from typing import Tuple
PASSWORD_HASH_ITERATIONS = int(os.getenv("PASSWORD_HASH_ITERATIONS", "120000"))
DEFAULT_PASSWORD_HASH_ITERATIONS = 120000
PASSWORD_HASH_ITERATIONS = int(
os.getenv("PASSWORD_HASH_ITERATIONS", str(DEFAULT_PASSWORD_HASH_ITERATIONS))
)
AUTH_SECRET_KEY = os.getenv("AUTH_SECRET_KEY", "change-this-secret-in-env")
AUTH_TOKEN_EXPIRE_MINUTES = int(os.getenv("AUTH_TOKEN_EXPIRE_MINUTES", "10080"))
DEFAULT_AUTH_TOKEN_EXPIRE_MINUTES = 10080
AUTH_TOKEN_EXPIRE_MINUTES = int(
os.getenv("AUTH_TOKEN_EXPIRE_MINUTES", str(DEFAULT_AUTH_TOKEN_EXPIRE_MINUTES))
)
def hash_password(password: str) -> str:
@@ -61,6 +67,11 @@ def _urlsafe_b64encode(raw: bytes) -> str:
return base64.urlsafe_b64encode(raw).decode("utf-8").rstrip("=")
def _urlsafe_b64decode(raw: str) -> bytes:
padding = "=" * (-len(raw) % 4)
return base64.urlsafe_b64decode(raw + padding)
def create_access_token(user_id: int, email: str) -> Tuple[str, int]:
expires_in = AUTH_TOKEN_EXPIRE_MINUTES * 60
payload = {
@@ -77,3 +88,51 @@ def create_access_token(user_id: int, email: str) -> Tuple[str, int]:
).digest()
token = f"{encoded_payload}.{_urlsafe_b64encode(signature)}"
return token, expires_in
def decode_access_token(token: str) -> Tuple[int, str]:
"""
解码并校验访问令牌,返回 (user_id, email)。
校验项包括:结构、签名、过期时间、字段完整性。
"""
try:
encoded_payload, encoded_signature = token.split(".", 1)
except ValueError as exc:
raise ValueError("Invalid token format") from exc
try:
provided_signature = _urlsafe_b64decode(encoded_signature)
except Exception as exc:
raise ValueError("Invalid token signature encoding") from exc
expected_signature = hmac.new(
AUTH_SECRET_KEY.encode("utf-8"),
encoded_payload.encode("utf-8"),
hashlib.sha256,
).digest()
if not hmac.compare_digest(provided_signature, expected_signature):
raise ValueError("Invalid token signature")
try:
payload_bytes = _urlsafe_b64decode(encoded_payload)
payload = json.loads(payload_bytes.decode("utf-8"))
except Exception as exc:
raise ValueError("Invalid token payload") from exc
sub = payload.get("sub")
email = payload.get("email")
exp = payload.get("exp")
if not sub or not email or exp is None:
raise ValueError("Token payload missing required fields")
try:
user_id = int(sub)
exp_ts = int(exp)
except (TypeError, ValueError) as exc:
raise ValueError("Invalid token payload types") from exc
if time.time() >= exp_ts:
raise ValueError("Token expired")
return user_id, str(email)
+46
View File
@@ -0,0 +1,46 @@
import requests
import json
# 请将此处的 URL 替换为您实际的 API 基础域名
api_url = "http://10.252.130.135:8000/api/v1/sources/"
# 请求头
headers = {
"Content-Type": "application/json",
# "Authorization": "Bearer YOUR_TOKEN" # 如果接口需要鉴权,请取消注释并填入 Token
}
# 解析后的数据源列表
sources_data = [
{"name": "今日头条", "url": "toutiao"},
{"name": "百度热搜", "url": "baidu"},
{"name": "华尔街见闻", "url": "wallstreetcn-hot"},
{"name": "澎湃新闻", "url": "thepaper"},
{"name": "bilibili 热搜", "url": "bilibili-hot-search"},
{"name": "财联社热门", "url": "cls-hot"},
{"name": "凤凰网", "url": "ifeng"},
{"name": "贴吧", "url": "tieba"},
{"name": "微博", "url": "weibo"},
{"name": "抖音", "url": "douyin"},
{"name": "知乎", "url": "zhihu"}
]
# 遍历数据并发送 POST 请求
for item in sources_data:
payload = {
"source_name": item["name"],
"source_type": "HOT_TREND",
"home_url": item["url"],
"is_enabled": True
}
try:
response = requests.post(api_url, headers=headers, data=json.dumps(payload))
if response.status_code in (200, 201):
print(f"✅ 成功创建: {item['name']}")
else:
print(f"❌ 创建失败: {item['name']} - 状态码: {response.status_code} - 详情: {response.text}")
except Exception as e:
print(f"⚠️ 请求异常: {item['name']} - 错误: {e}")
print("执行完毕!")
+37 -3
View File
@@ -1,12 +1,24 @@
# app/main.py
import logging
import os
from contextlib import asynccontextmanager
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from dotenv import load_dotenv
# 统一配置日志格式和级别,确保 delivery_service 等的 INFO 日志可见
logging.basicConfig(
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# 降低 APScheduler 运行心跳日志,避免每分钟刷屏
logging.getLogger("apscheduler").setLevel(logging.WARNING)
from apscheduler.schedulers.asyncio import AsyncIOScheduler
from app.services.fetcher_service import fetch_and_save_trending_data
from app.services.summary_service import generate_unified_summaries
from app.services.delivery_service import check_and_deliver
from app.database import engine
from app.models.models import Base
@@ -47,14 +59,24 @@ async def lifespan(app: FastAPI):
id='ai_summary_job',
replace_existing=True
)
# 推送调度:每分钟检查是否有用户需要接收邮件推送
scheduler.add_job(
check_and_deliver,
'interval',
minutes=1,
id='delivery_check_job',
replace_existing=True,
)
scheduler.start()
print(f"定时抓取任务已启动,每 {CRAWL_INTERVAL} 分钟执行一次")
print(f"AI 摘要生成任务已启动,每 {SUMMARY_INTERVAL} 分钟执行一次")
print("邮件推送调度已启动,每分钟检查一次")
# 为了测试方便,启动时立即执行一次
await fetch_and_save_trending_data()
# await fetch_and_save_trending_data()
await generate_unified_summaries()
# await generate_unified_summaries()
yield # 此时 FastAPI 开始接受请求
@@ -67,7 +89,19 @@ async def lifespan(app: FastAPI):
app = FastAPI(title="AI 新闻聚合引擎 API", lifespan=lifespan)
# ==========================================
# 2. 挂载路由总线
# 2. CORS 中间件:允许前端开发服务器跨域请求
# ==========================================
app.add_middleware(
CORSMiddleware,
# allow_origins=["http://localhost:5173", "http://127.0.0.1:5173"],
allow_origins=["*"],
# allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# ==========================================
# 3. 挂载路由总线
# ==========================================
# 版本控制
app.include_router(api_router, prefix="/api/v1")
@@ -0,0 +1,264 @@
# 推送邮件 HTML 模板
# 用于生成定时推送给用户的热点摘要邮件
# 邮件客户端不支持 Font Awesome,改用 Emoji 代替平台图标
PLATFORM_EMOJI: dict[str, str] = {
"微博热搜": "🔴",
"微博": "🔴",
"知乎热榜": "🔵",
"知乎": "🔵",
"百度热搜": "🔍",
"今日头条": "📰",
"抖音热榜": "🎵",
"抖音": "🎵",
"bilibili 热搜": "📺",
"B站热搜": "📺",
"华尔街见闻": "📈",
"澎湃新闻": "🌊",
"财联社热门": "💰",
"凤凰网": "🦅",
"贴吧": "💬",
}
DIGEST_HTML_TEMPLATE = """\
<!DOCTYPE html>
<html lang="zh-CN">
<head>
<meta charset="UTF-8">
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<style>
body{{margin:0;padding:0;background:#0d1117;color:#e6edf3;font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,"Helvetica Neue",Arial,sans-serif;}}
.container{{max-width:640px;margin:0 auto;padding:32px 16px;}}
.header{{text-align:center;padding:10px 0 30px;margin-bottom:24px;border-bottom:1px solid rgba(255,255,255,0.06);}}
.header h1{{font-size:26px;font-weight:800;margin:0 0 10px;color:#ffffff;text-shadow:0 0 16px rgba(139,92,246,0.5);letter-spacing:0.5px;}}
.header p{{font-size:14px;color:#8b949e;margin:0;}}
.mode-badge{{display:inline-block;margin-top:12px;padding:4px 14px;border-radius:20px;font-size:12px;font-weight:600;letter-spacing:0.5px;}}
.mode-default{{background:rgba(59,130,246,0.15);color:#7dd3fc;border:1px solid rgba(59,130,246,0.3);}}
.mode-keyword{{background:rgba(168,85,247,0.15);color:#e879f9;border:1px solid rgba(168,85,247,0.3);}}
.event-card{{background:#161b22;border:1px solid #30363d;border-radius:16px;padding:20px;margin-bottom:20px;box-shadow:0 4px 12px rgba(0,0,0,0.2);}}
.event-card.is-hot{{border-left:4px solid #f85149;background:linear-gradient(90deg, rgba(248,81,73,0.03) 0%, transparent 100%), #161b22;}}
.event-title{{font-size:18px;font-weight:700;margin:0 0 14px;color:#ffffff;line-height:1.5;}}
.event-meta{{margin-bottom:12px;}}
.badge{{display:inline-block;padding:3px 10px;border-radius:6px;font-size:12px;font-weight:600;margin-right:6px;margin-bottom:6px;}}
.badge-hot{{background:rgba(248,81,73,0.15);color:#ff7b72;border:1px solid rgba(248,81,73,0.3);}}
.badge-warm{{background:rgba(210,153,34,0.15);color:#d29922;border:1px solid rgba(210,153,34,0.3);}}
.badge-normal{{background:rgba(56,139,253,0.15);color:#58a6ff;border:1px solid rgba(56,139,253,0.3);}}
.badge-tag{{background:rgba(139,148,158,0.15);color:#8b949e;border:1px solid rgba(139,148,158,0.2);}}
.summary{{font-size:14px;line-height:1.6;color:#c9d1d9;padding:12px 16px;background:rgba(139,92,246,0.06);border-radius:0 8px 8px 0;border-left:3px solid #a78bfa;margin-bottom:16px;}}
.summary strong{{color:#a78bfa;font-weight:600;}}
.platforms-list{{margin:0;padding:0;list-style:none;background:rgba(255,255,255,0.02);border-radius:10px;padding:12px;}}
.platform-item{{padding:8px 0;border-bottom:1px solid rgba(255,255,255,0.05);}}
.platform-item:last-child{{border-bottom:none;padding-bottom:0;}}
.platform-item:first-child{{padding-top:0;}}
.platform-source{{font-size:12px;color:#8b949e;margin-bottom:4px;display:flex;align-items:center;}}
.platform-rank{{display:inline-block;padding:2px 6px;border-radius:4px;background:rgba(210,153,34,0.15);color:#d29922;font-size:10px;font-weight:700;margin-left:6px;}}
.platform-link{{font-size:14px;color:#79c0ff;text-decoration:none;line-height:1.5;display:block;transition:color 0.2s;}}
.platform-link:hover{{text-decoration:underline;color:#a5d6ff;}}
.platform-text{{font-size:14px;color:#e6edf3;line-height:1.5;}}
/* 匹配信息底部栏 */
.match-info{{font-size:12px;color:#8b949e;margin-top:16px;padding-top:12px;border-top:1px dashed #30363d;}}
.hit{{display:inline-block;padding:2px 8px;border-radius:4px;font-size:11px;font-weight:600;margin-right:4px;margin-top:4px;}}
.hit-exact{{background:rgba(46,160,67,0.15);color:#3fb950;}}
.hit-semantic{{background:rgba(163,113,247,0.15);color:#d2a8ff;}}
/* 页脚 */
.footer{{text-align:center;padding:30px 0 10px;margin-top:20px;font-size:12px;color:#484f58;}}
.footer a{{color:#79c0ff;text-decoration:none;}}
.footer a:hover{{text-decoration:underline;}}
</style>
</head>
<body>
<div class="container">
<div class="header">
<h1>InsightRadar · 热点快报</h1>
<p>{delivery_time} · 为你精选了 {event_count} 条事件</p>
<span class="mode-badge {mode_badge_class}">{mode_label}</span>
</div>
{event_cards_html}
<div class="footer">
<p>此邮件由 InsightRadar 自动推送。</p>
<p>如需调整推送设置,请登录 <a href="{app_url}">InsightRadar 控制台</a></p>
</div>
</div>
</body>
</html>
"""
EVENT_CARD_TEMPLATE = """\
<div class="event-card{hot_class}">
<div class="event-meta">
<span class="badge {badge_class}">{hot_label} {hot_score}</span>
{tags_html}
</div>
<div class="event-title">{title}</div>
{summary_html}
{platforms_html}
{match_html}
</div>
"""
def _hot_level(score: int) -> tuple[str, str, str]:
"""返回 (label, badge_class, hot_class)"""
if score >= 50:
return "全网沸腾", "badge-hot", " is-hot"
if score >= 20:
return "高度关注", "badge-warm", ""
if score >= 10:
return "上升中", "badge-normal", ""
return "一般关注", "badge-tag", ""
def _get_event_summary(ev) -> str:
"""
兼容 ORM 字段名(ai_comprehensive_summary)和 schema 字段名(summary)。
"""
return (
getattr(ev, "summary", None)
or getattr(ev, "ai_comprehensive_summary", None)
or ""
)
def _build_platforms_html(platform_list: list[dict]) -> str:
"""
将平台数据列表渲染为 HTML。
每条包含:emoji 图标 + 来源名 + 排名徽章 + 可点击标题链接。
"""
if not platform_list:
return ""
rows = []
seen_sources: set[str] = set()
for p in platform_list[:8]:
source_name = p.get("source_name", "未知")
# 同一来源只显示第一条(通常是排名最靠前的那条)
if source_name in seen_sources:
continue
seen_sources.add(source_name)
headline = p.get("headline", "")
url = p.get("url", "")
ranking = p.get("ranking")
emoji = PLATFORM_EMOJI.get(source_name, "🔗")
rank_html = ""
if ranking:
rank_html = f'<span class="platform-rank">TOP {ranking}</span>'
if url:
title_html = (
f'<a href="{url}" class="platform-link">{headline}</a>'
)
else:
title_html = f'<span class="platform-text">{headline}</span>'
rows.append(
f'<li class="platform-item">'
f'<div class="platform-source">{emoji} {source_name}{rank_html}</div>'
f'{title_html}'
f'</li>'
)
if not rows:
return ""
return '<div class="platforms-list-wrapper"><ul class="platforms-list">' + "".join(rows) + "</ul></div>"
def build_digest_html(
items: list,
delivery_time_str: str,
platforms_map: dict[int, list[dict]] | None = None,
app_url: str = "http://localhost:5173",
is_default_push: bool = False,
) -> str:
"""
根据事件列表生成推送邮件 HTML 正文。
items 元素可以是 MatchedEventResult 或 _DefaultEventItem
二者均有 .event / .tags / .exact_hits / .semantic_hits / .match_score 属性。
platforms_map: event_id → [{source_name, headline, url, ranking}]
"""
if platforms_map is None:
platforms_map = {}
mode_label = "全网热点推送" if is_default_push else "个性化关键词匹配"
mode_badge_class = "mode-default" if is_default_push else "mode-keyword"
cards = []
for item in items:
ev = item.event
hot_label, badge_class, hot_class = _hot_level(ev.hot_score)
tags_html = "".join(
f'<span class="badge badge-tag">{t}</span>'
for t in item.tags[:4]
)
summary_text = _get_event_summary(ev)
summary_html = ""
if summary_text:
summary_html = (
f'<div class="summary"><strong>AI 洞察:</strong>{summary_text}</div>'
)
platform_list = platforms_map.get(ev.id, [])
platforms_html = _build_platforms_html(platform_list)
match_parts = []
# 仅个性化模式才显示匹配信息
if not getattr(item, "is_default", False):
for h in item.exact_hits[:3]:
match_parts.append(f'<span class="hit hit-exact">精确 {h}</span>')
for s in item.semantic_hits[:2]:
sim_pct = int(s.get("similarity", 0) * 100)
match_parts.append(
f'<span class="hit hit-semantic">语义 {s.get("topic_keyword", "")} {sim_pct}%</span>'
)
match_html = ""
if match_parts:
match_html = (
f'<div class="match-info">匹配度 {item.match_score:.0f} · '
+ " ".join(match_parts)
+ "</div>"
)
cards.append(
EVENT_CARD_TEMPLATE.format(
hot_class=hot_class,
badge_class=badge_class,
hot_label=hot_label,
hot_score=ev.hot_score,
tags_html=tags_html,
title=ev.unified_title,
summary_html=summary_html,
platforms_html=platforms_html,
match_html=match_html,
)
)
return DIGEST_HTML_TEMPLATE.format(
delivery_time=delivery_time_str,
event_count=len(items),
event_cards_html="\n".join(cards),
app_url=app_url,
mode_label=mode_label,
mode_badge_class=mode_badge_class,
)
+21 -8
View File
@@ -1,14 +1,27 @@
SUMMARY_SYSTEM_PROMPT = "你是一个输出严格 JSON 格式的后台引擎。"
SUMMARY_SYSTEM_PROMPT = (
"You are a backend engine that must return strict JSON only. "
"Do not include markdown, explanation, or extra keys."
)
SUMMARY_USER_PROMPT_TEMPLATE = """
你是一个专业的新闻聚合编辑。请根据以下同一个大事件在不同平台的热搜标题,
为该事件生成一个客观、吸睛的【统一大标题】,以及一段【多平台视角的综合摘要】。
You are a professional cross-platform news editor.
Based on the following headlines about the same event from different platforms,
return:
1) a neutral unified title
2) a cross-platform comprehensive summary
3) topic tags
要求:
1. 摘要结构类似:"该事件在多平台发酵。微博侧重讨论...,知乎硬核解析...,科技媒体关注..."
2. 提炼出各平台的讨论侧重点,不要简单罗列标题。
3. 必须以严格的 JSON 格式返回,只包含 "unified_title""ai_comprehensive_summary" 两个字段,不要有多余的说明。
Rules:
1. Return strict JSON with exactly these keys:
- "unified_title": string
- "ai_comprehensive_summary": string
- "topic_keywords": array of 3 to 8 objects
2. Each item in "topic_keywords" must be:
{{"keyword": string, "relevance_score": number}}
3. relevance_score must be in [0, 100].
4. keyword should be concise (max 12 chars preferred).
5. The language should follow the dominant language in the input.
各平台热搜标题数据:
Cross-platform headline data:
{platform_data_text}
"""
+72
View File
@@ -0,0 +1,72 @@
# 推送设置相关的请求/响应模型
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
# ==========================================
# 推送时间表 (UserDeliverySchedule)
# ==========================================
class DeliveryScheduleCreate(BaseModel):
"""新增推送时间请求体,时间格式 HH:MM"""
delivery_time: str = Field(..., pattern=r"^\d{2}:\d{2}$", description="每天推送的时间,格式 HH:MM")
is_active: bool = Field(default=True, description="是否启用此时段")
class DeliveryScheduleUpdate(BaseModel):
"""更新推送时间请求体"""
delivery_time: Optional[str] = Field(None, pattern=r"^\d{2}:\d{2}$")
is_active: Optional[bool] = None
class DeliveryScheduleResponse(BaseModel):
"""推送时间响应体"""
id: int
user_id: int
delivery_time: str
is_active: bool
created_at: datetime
model_config = ConfigDict(from_attributes=True)
# ==========================================
# 推送渠道端点 (UserPushEndpoint)
# ==========================================
class PushEndpointCreate(BaseModel):
"""新增推送渠道请求体"""
channel_type: str = Field(..., max_length=50, description="渠道类型,如 EMAIL / WECHAT_BOT / TELEGRAM")
channel_account: str = Field(..., max_length=255, description="具体接收账号(邮箱地址/Webhook等)")
is_active: bool = Field(default=True, description="是否启用")
priority_level: int = Field(default=1, ge=1, le=10, description="优先级,1最高")
class PushEndpointUpdate(BaseModel):
"""更新推送渠道请求体"""
channel_account: Optional[str] = Field(None, max_length=255)
is_active: Optional[bool] = None
priority_level: Optional[int] = Field(None, ge=1, le=10)
class PushEndpointResponse(BaseModel):
"""推送渠道响应体"""
id: int
user_id: int
channel_type: str
channel_account: str
is_active: bool
priority_level: int
created_at: datetime
updated_at: datetime
model_config = ConfigDict(from_attributes=True)
# ==========================================
# 推送设置聚合响应(一次性返回全部推送配置)
# ==========================================
class UserDeliveryConfigResponse(BaseModel):
"""用户的完整推送配置(时间表 + 渠道列表)"""
schedules: List[DeliveryScheduleResponse] = Field(default_factory=list)
endpoints: List[PushEndpointResponse] = Field(default_factory=list)
+19 -12
View File
@@ -1,23 +1,30 @@
# app/schemas/event_schema.py
from pydantic import BaseModel
from pydantic import BaseModel, Field
from typing import List, Optional
from datetime import datetime
class PlatformTrendResponse(BaseModel):
source_id: int
platform_name: str # 平台名称,如 "微博热搜"
headline: str # 平台对应的具体热搜标题
url: Optional[str] # 跳转链接
current_ranking: Optional[int] # 当前排名
ranking_history: List[int] # 排名历史轨迹,如 [50, 45, 20, 5, 1],供 ApexCharts 渲染
platform_name: str
headline: str
url: Optional[str]
current_ranking: Optional[int]
ranking_history: List[int]
class UnifiedEventResponse(BaseModel):
event_id: int
unified_title: str # AI 生成的统一大标题
summary: Optional[str] # AI 生成的摘要
hot_score: int # 总热度值
created_at: datetime # 事件发现时间
platforms: List[PlatformTrendResponse] # 挂载的各个平台子热搜
# tags: List[str] = [] # 如果后续打通了 ExtractedTopic,可以在这里返回标签
unified_title: str
summary: Optional[str]
hot_score: int
created_at: datetime
platforms: List[PlatformTrendResponse]
tags: List[str] = Field(default_factory=list)
class PaginatedUnifiedEventResponse(BaseModel):
"""分页包装:避免一次性返回全量数据"""
total: int
has_more: bool
data: List[UnifiedEventResponse]
+46
View File
@@ -0,0 +1,46 @@
from datetime import datetime
from typing import List, Optional
from pydantic import BaseModel, ConfigDict, Field
class UserTopicPreferenceCreate(BaseModel):
"""新增用户兴趣词请求体。"""
interested_keyword: str = Field(..., min_length=1, max_length=100, description="用户感兴趣的关键词")
class UserTopicPreferenceResponse(BaseModel):
"""用户兴趣词响应体。"""
id: int
user_id: int
interested_keyword: str
created_at: datetime
model_config = ConfigDict(from_attributes=True)
class EventMatchSemanticHit(BaseModel):
"""语义命中的明细。"""
preference_keyword: str
topic_keyword: str
similarity: float
class MatchedEventResponse(BaseModel):
"""推荐事件响应体。"""
event_id: int
unified_title: str
summary: Optional[str]
hot_score: int
created_at: datetime
tags: List[str] = Field(default_factory=list)
match_score: float
exact_hits: List[str] = Field(default_factory=list)
semantic_hits: List[EventMatchSemanticHit] = Field(default_factory=list)
class UserPreferenceRecommendationResponse(BaseModel):
"""用户兴趣推荐结果。"""
user_id: int
total: int
data: List[MatchedEventResponse] = Field(default_factory=list)
+454
View File
@@ -0,0 +1,454 @@
# 定时推送调度服务
# 由 APScheduler 每分钟调用,检查当前时刻是否有用户需要接收推送,
# 如匹配则生成摘要邮件并发送,同时写入 DeliveryHistory 防重复。
import logging
from logging.handlers import TimedRotatingFileHandler
from dataclasses import dataclass, field
from datetime import datetime, time as dt_time, timedelta, timezone, tzinfo
from pathlib import Path
from typing import Any
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from sqlalchemy.orm import Session
from app.database import SessionLocal
from app.models.models import (
AppUser,
DeliveryHistory,
ExtractedTopic,
InfoSource,
TargetType,
TaskStatus,
TrendingEvent,
UnifiedEvent,
UserDeliverySchedule,
UserPushEndpoint,
UserTopicPreference,
utcnow,
)
from app.prompts.digest_email_template import build_digest_html
from app.services.matching_service import recommend_events_for_user
from app.utils.email_utils import send_html_email
logger = logging.getLogger("delivery_service")
# delivery_service 日志单独写文件,不再输出到控制台
_delivery_log_dir = Path(__file__).resolve().parents[2] / "logs"
_delivery_log_dir.mkdir(parents=True, exist_ok=True)
_delivery_log_file = _delivery_log_dir / "delivery_check.log"
if not logger.handlers:
_file_handler = TimedRotatingFileHandler(
filename=str(_delivery_log_file),
when="midnight",
interval=1,
backupCount=14,
encoding="utf-8",
)
_file_handler.setFormatter(logging.Formatter("%(asctime)s [%(levelname)s] %(name)s: %(message)s"))
logger.addHandler(_file_handler)
logger.setLevel(logging.INFO)
logger.propagate = False
# 推送时间窗口:实际执行时刻与设定时间的最大容差(分钟)
DELIVERY_WINDOW_MINUTES = 2
# 同一用户两次推送之间的最小间隔(分钟)
MIN_PUSH_INTERVAL_MINUTES = 30
# 单次推送最多携带的事件数
MAX_EVENTS_PER_PUSH = 12
# 默认模式热度阈值(无关键词或无匹配时使用)
DEFAULT_MODE_HOT_THRESHOLD = 3
# 默认模式查询时间窗口(小时)
DEFAULT_MODE_HOURS = 48
# 用户时区无效时的兜底时区
DEFAULT_FALLBACK_TIMEZONE = "Asia/Shanghai"
# ==========================================
# 默认热点事件容器(无关键词时使用)
# ==========================================
@dataclass
class _DefaultEventItem:
"""
无关键词订阅或关键词无匹配时的默认热点包装器,
接口与 MatchedEventResult 保持一致,方便统一传给模板。
"""
event: UnifiedEvent
match_score: float = 0.0
exact_hits: list[str] = field(default_factory=list)
semantic_hits: list[dict[str, Any]] = field(default_factory=list)
tags: list[str] = field(default_factory=list)
is_default: bool = True
# ==========================================
# 时区工具
# ==========================================
def _time_to_minutes(t: dt_time) -> int:
return t.hour * 60 + t.minute
def _is_within_window(schedule_time: dt_time, current_time: dt_time, window: int = DELIVERY_WINDOW_MINUTES) -> bool:
"""判断 schedule_time 是否在 current_time ± window 分钟范围内(跨午夜安全)。"""
s = _time_to_minutes(schedule_time)
c = _time_to_minutes(current_time)
diff = abs(s - c)
return min(diff, 1440 - diff) <= window
def _resolve_user_timezone(user_timezone: str | None) -> tzinfo:
"""解析用户时区,异常时回退到默认时区。"""
tz_name = (user_timezone or "").strip() or DEFAULT_FALLBACK_TIMEZONE
try:
return ZoneInfo(tz_name)
except ZoneInfoNotFoundError:
logger.warning(
"用户时区无效,已回退默认时区。timezone=%s fallback=%s",
tz_name, DEFAULT_FALLBACK_TIMEZONE,
)
try:
return ZoneInfo(DEFAULT_FALLBACK_TIMEZONE)
except ZoneInfoNotFoundError:
logger.warning("系统缺少时区数据库,最终回退为 UTC。建议安装 tzdata 包。")
return timezone.utc
def _user_local_time(now_utc: datetime, user_timezone: str | None) -> dt_time:
"""把 UTC 当前时刻转换为用户本地时间(仅取 HH:MM)。"""
local_dt = now_utc.astimezone(_resolve_user_timezone(user_timezone))
return local_dt.time().replace(second=0, microsecond=0)
def _ensure_aware(dt: datetime) -> datetime:
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
# ==========================================
# 数据库查询辅助
# ==========================================
def _should_skip_by_interval(db: Session, user_id: int) -> bool:
"""检查用户是否仍在 30 分钟冷却期内。"""
row = (
db.query(DeliveryHistory.created_at)
.filter(
DeliveryHistory.user_id == user_id,
DeliveryHistory.status == TaskStatus.SUCCESS,
)
.order_by(DeliveryHistory.created_at.desc())
.first()
)
if row is None:
return False
last_time = _ensure_aware(row[0])
elapsed = (utcnow() - last_time).total_seconds() / 60.0
return elapsed < MIN_PUSH_INTERVAL_MINUTES
def _get_user_email_endpoints(db: Session, user_id: int) -> list[UserPushEndpoint]:
"""获取用户已启用的邮件类型推送渠道,按优先级排序。"""
return (
db.query(UserPushEndpoint)
.filter(
UserPushEndpoint.user_id == user_id,
UserPushEndpoint.channel_type == "EMAIL",
UserPushEndpoint.is_active == True,
)
.order_by(UserPushEndpoint.priority_level.asc())
.all()
)
def _get_already_pushed_event_ids(db: Session, user_id: int) -> set[int]:
"""获取已经推送过的事件 ID 集合,避免重复轰炸。"""
rows = (
db.query(DeliveryHistory.target_id)
.filter(
DeliveryHistory.user_id == user_id,
DeliveryHistory.target_type == TargetType.EVENT,
DeliveryHistory.status == TaskStatus.SUCCESS,
)
.all()
)
return {r[0] for r in rows}
def _load_event_platforms(db: Session, event_ids: list[int]) -> dict[int, list[dict]]:
"""
批量加载事件的平台来源数据。
返回:event_id → [{source_name, headline, url, ranking, icon_url}, ...]
按排名升序排列(rank 1 最靠前)。
"""
if not event_ids:
return {}
rows = (
db.query(
TrendingEvent.unified_event_id,
TrendingEvent.current_headline,
TrendingEvent.event_url,
TrendingEvent.current_ranking,
TrendingEvent.icon_url,
InfoSource.source_name,
)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
.filter(TrendingEvent.unified_event_id.in_(event_ids))
.order_by(
TrendingEvent.unified_event_id,
TrendingEvent.current_ranking.asc().nulls_last(),
)
.all()
)
result: dict[int, list[dict]] = {}
for event_id, headline, url, ranking, icon_url, source_name in rows:
result.setdefault(event_id, []).append({
"source_name": source_name or "未知",
"headline": headline or "",
"url": url or "",
"ranking": ranking,
"icon_url": icon_url or "",
})
return result
def _load_event_tags(db: Session, event_ids: list[int]) -> dict[int, list[str]]:
"""批量加载事件的标签,返回 event_id → [tag, ...]。"""
if not event_ids:
return {}
rows = (
db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
.filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id.in_(event_ids),
)
.all()
)
tags_map: dict[int, list[str]] = {}
for eid, kw in rows:
if kw:
tags_map.setdefault(eid, []).append(kw)
return tags_map
def _user_has_keywords(db: Session, user_id: int) -> bool:
"""判断用户是否配置了关键词订阅。"""
return (
db.query(UserTopicPreference.id)
.filter(UserTopicPreference.user_id == user_id)
.first()
) is not None
def _get_default_hot_events(
db: Session,
pushed_ids: set[int],
) -> list[_DefaultEventItem]:
"""
默认模式:获取热度 >= DEFAULT_MODE_HOT_THRESHOLD 的近期热点,
排除已推送过的,封装成与 MatchedEventResult 接口相同的对象。
"""
time_limit = utcnow() - timedelta(hours=DEFAULT_MODE_HOURS)
events = (
db.query(UnifiedEvent)
.filter(
UnifiedEvent.hot_score >= DEFAULT_MODE_HOT_THRESHOLD,
UnifiedEvent.created_at >= time_limit,
)
.order_by(UnifiedEvent.hot_score.desc())
.limit(MAX_EVENTS_PER_PUSH * 2)
.all()
)
event_ids = [e.id for e in events if e.id not in pushed_ids]
tags_map = _load_event_tags(db, event_ids)
result: list[_DefaultEventItem] = []
for ev in events:
if ev.id in pushed_ids:
continue
result.append(_DefaultEventItem(
event=ev,
tags=list(dict.fromkeys(tags_map.get(ev.id, [])))[:6],
))
if len(result) >= MAX_EVENTS_PER_PUSH:
break
return result
def _record_delivery(
db: Session,
user_id: int,
event_ids: list[int],
status: TaskStatus,
) -> None:
"""批量写入推送历史记录。"""
for eid in event_ids:
record = DeliveryHistory(
user_id=user_id,
target_type=TargetType.EVENT,
target_id=eid,
status=status,
)
db.add(record)
db.commit()
# ==========================================
# 推送准备
# ==========================================
@dataclass
class _PendingPush:
"""暂存需要发送邮件的信息,便于在 async 上下文中发送。"""
user_id: int
email_targets: list[str]
subject: str
html_body: str
event_ids: list[int]
def _prepare_user_push(db: Session, user: AppUser, schedule: UserDeliverySchedule) -> _PendingPush | None:
"""
同步准备单个用户的推送数据(DB 操作),不实际发送邮件。
推送优先级:
1. 有关键词 且 有匹配 → 发送匹配事件
2. 有关键词 但 无匹配 → 发送默认热点(热度 >= 3)
3. 无关键词 → 发送默认热点(热度 >= 3)
"""
user_id = user.id
if _should_skip_by_interval(db, user_id):
logger.info(f"用户 {user_id} 仍在 {MIN_PUSH_INTERVAL_MINUTES} 分钟冷却期内,跳过")
return None
email_endpoints = _get_user_email_endpoints(db, user_id)
if not email_endpoints:
logger.info(f"用户 {user_id} 无可用邮件渠道,跳过")
return None
pushed_ids = _get_already_pushed_event_ids(db, user_id)
# ——— 决策:匹配模式 or 默认模式 ———
items: list = []
is_default = False
has_keywords = _user_has_keywords(db, user_id)
if has_keywords:
matched = recommend_events_for_user(
db,
user_id=user_id,
min_hot=1,
hours=72,
limit=MAX_EVENTS_PER_PUSH * 2,
)
fresh_matched = [m for m in matched if m.event.id not in pushed_ids]
if fresh_matched:
items = fresh_matched[:MAX_EVENTS_PER_PUSH]
logger.info(f"用户 {user_id} 关键词匹配,推送 {len(items)} 条事件")
else:
logger.info(f"用户 {user_id} 关键词无匹配结果,切换为默认热点模式")
is_default = True
else:
logger.info(f"用户 {user_id} 未配置关键词,使用默认热点模式")
is_default = True
if is_default:
items = _get_default_hot_events(db, pushed_ids)
if not items:
logger.info(f"用户 {user_id} 默认热点无可推送内容,跳过")
return None
# 批量加载平台数据(来源名、标题、URL、排名)
event_ids = [item.event.id for item in items]
platforms_map = _load_event_platforms(db, event_ids)
time_str = schedule.delivery_time.strftime("%H:%M")
html_body = build_digest_html(
items=items,
delivery_time_str=time_str,
platforms_map=platforms_map,
is_default_push=is_default,
)
subject_suffix = "全网热点快报" if is_default else "个性化简报"
return _PendingPush(
user_id=user_id,
email_targets=[ep.channel_account for ep in email_endpoints],
subject=f"InsightRadar {subject_suffix} · {time_str}",
html_body=html_body,
event_ids=event_ids,
)
# ==========================================
# 调度主入口
# ==========================================
async def check_and_deliver() -> None:
"""
定时推送主入口,由 APScheduler 每分钟调用。
流程:
1. 获取当前 UTC 时间
2. 查询所有启用的推送计划
3. 对每个计划,按用户本地时区判断是否在推送窗口
4. 同步准备推送数据 → 异步发送邮件 → 记录结果
"""
now = datetime.now(timezone.utc)
current_utc = now.time().replace(second=0, microsecond=0)
logger.debug(f"推送调度检查 @ UTC {current_utc.strftime('%H:%M')}")
db: Session = SessionLocal()
try:
active_schedules = (
db.query(UserDeliverySchedule)
.filter(UserDeliverySchedule.is_active == True)
.all()
)
for schedule in active_schedules:
user = db.query(AppUser).filter(AppUser.id == schedule.user_id).first()
if not user:
continue
# 用户本地时间对比(核心时区修正)
user_current = _user_local_time(now, user.timezone)
if not _is_within_window(schedule.delivery_time, user_current):
continue
try:
pending = _prepare_user_push(db, user, schedule)
if pending is None:
continue
# 异步按优先级尝试各邮件渠道
sent = False
for target_email in pending.email_targets:
try:
success = await send_html_email(
to_email=target_email,
subject=pending.subject,
html_content=pending.html_body,
)
if success:
sent = True
logger.info(f"用户 {pending.user_id} 邮件发送成功 → {target_email}")
break
else:
logger.warning(f"用户 {pending.user_id} 渠道 {target_email} 发送失败,尝试下一个")
except Exception as e:
logger.error(f"用户 {pending.user_id} 发送至 {target_email} 异常: {e}")
_record_delivery(
db,
user_id=pending.user_id,
event_ids=pending.event_ids,
status=TaskStatus.SUCCESS if sent else TaskStatus.ERROR,
)
except Exception as e:
logger.error(f"推送用户 {schedule.user_id} 时异常: {e}", exc_info=True)
except Exception as e:
logger.error(f"推送调度主循环异常: {e}", exc_info=True)
finally:
db.close()
+238
View File
@@ -0,0 +1,238 @@
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
from typing import Any
import numpy as np
from sqlalchemy.orm import Session
from app.models.models import ExtractedTopic, TargetType, UnifiedEvent, UserTopicPreference, utcnow
from app.services.fetcher_service import embedder_model
# 语义匹配阈值:用户关键词和事件标签向量相似度达到该值才计入语义命中
DEFAULT_PREFERENCE_SEMANTIC_THRESHOLD = 0.78
PREFERENCE_SEMANTIC_THRESHOLD = float(
os.getenv("PREFERENCE_SEMANTIC_THRESHOLD", str(DEFAULT_PREFERENCE_SEMANTIC_THRESHOLD))
)
# 推荐列表最大返回条数
DEFAULT_PREFERENCE_RECOMMEND_MAX_LIMIT = 50
PREFERENCE_RECOMMEND_MAX_LIMIT = int(
os.getenv("PREFERENCE_RECOMMEND_MAX_LIMIT", str(DEFAULT_PREFERENCE_RECOMMEND_MAX_LIMIT))
)
@dataclass
class MatchedEventResult:
"""用户兴趣匹配后的事件结果。"""
event: UnifiedEvent
match_score: float
exact_hits: list[str]
semantic_hits: list[dict[str, Any]]
tags: list[str]
def _normalize_text(text: str) -> str:
"""统一小写与首尾空白,便于做稳定匹配。"""
return text.strip().casefold()
def _build_keyword_embedding_map(keywords: list[str]) -> dict[str, np.ndarray]:
"""
批量生成关键词向量,并返回原词到向量的映射。
这里要求向量已归一化,后续可直接用点积表示余弦相似度。
"""
if not keywords:
return {}
vectors = embedder_model.encode(keywords, normalize_embeddings=True)
result: dict[str, np.ndarray] = {}
for keyword, vec in zip(keywords, vectors):
result[keyword] = np.asarray(vec, dtype=np.float32)
return result
def _ensure_aware(dt: datetime) -> datetime:
"""SQLite 读出的 datetime 不带时区信息,统一补上 UTC 后才能和 utcnow() 做减法。"""
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
def _calc_freshness_bonus(event: UnifiedEvent) -> float:
"""根据事件新鲜度给一个小额加分,避免旧热点长期占据推荐位。"""
age_hours = max((utcnow() - _ensure_aware(event.created_at)).total_seconds() / 3600.0, 0.0)
if age_hours <= 6:
return 12.0
if age_hours <= 24:
return 8.0
if age_hours <= 72:
return 4.0
return 0.0
def recommend_events_for_user(
db: Session,
*,
user_id: int,
min_hot: int = 3,
hours: int = 72,
limit: int = 20,
semantic_threshold: float | None = None,
) -> list[MatchedEventResult]:
"""
用户兴趣推荐主流程:
1) 精确匹配:用户词 == EVENT 标签
2) 语义匹配:用户词向量 vs EVENT 标签向量(超过阈值)
3) 打分融合:匹配分 + 标签相关度 + 热度 + 新鲜度
"""
final_limit = max(1, min(limit, PREFERENCE_RECOMMEND_MAX_LIMIT))
similarity_threshold = (
semantic_threshold
if semantic_threshold is not None
else PREFERENCE_SEMANTIC_THRESHOLD
)
# 读取用户兴趣词
preferences = (
db.query(UserTopicPreference)
.filter(UserTopicPreference.user_id == user_id)
.all()
)
if not preferences:
return []
preference_keywords = [pref.interested_keyword.strip() for pref in preferences if pref.interested_keyword.strip()]
if not preference_keywords:
return []
# 读取候选事件(先做时间和热度过滤,避免全表扫描)
time_limit = utcnow() - timedelta(hours=hours)
events = (
db.query(UnifiedEvent)
.filter(
UnifiedEvent.hot_score >= min_hot,
UnifiedEvent.created_at >= time_limit,
)
.order_by(UnifiedEvent.hot_score.desc(), UnifiedEvent.created_at.desc())
.all()
)
if not events:
return []
event_id_list = [event.id for event in events]
topic_rows = (
db.query(
ExtractedTopic.target_id,
ExtractedTopic.topic_keyword,
ExtractedTopic.relevance_score,
)
.filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id.in_(event_id_list),
)
.all()
)
if not topic_rows:
return []
# 组织事件标签映射:event_id -> [(tag, relevance_score), ...]
event_topics: dict[int, list[tuple[str, float | None]]] = {}
for event_id, topic_keyword, relevance_score in topic_rows:
if not topic_keyword:
continue
event_topics.setdefault(event_id, []).append((topic_keyword, relevance_score))
# 如果某事件没有标签,就不参与推荐
if not event_topics:
return []
# 批量编码用户词和标签词,避免逐条调用模型
unique_preference_keywords = list(dict.fromkeys(preference_keywords))
unique_topic_keywords = list(dict.fromkeys([row[1] for row in topic_rows if row[1]]))
pref_vec_map = _build_keyword_embedding_map(unique_preference_keywords)
topic_vec_map = _build_keyword_embedding_map(unique_topic_keywords)
# 预先建立“标准化后用户词集合”,用于精确匹配
normalized_pref_set = {_normalize_text(word) for word in unique_preference_keywords}
scored_results: list[MatchedEventResult] = []
for event in events:
topic_list = event_topics.get(event.id, [])
if not topic_list:
continue
exact_hits: list[str] = []
semantic_hits: list[dict[str, Any]] = []
score = 0.0
# 对事件标签逐个匹配用户兴趣
for topic_keyword, topic_relevance in topic_list:
normalized_topic = _normalize_text(topic_keyword)
topic_relevance_score = float(topic_relevance) if topic_relevance is not None else 50.0
# 1) 精确命中(包括完全相等与包含关系)
matched_exact = False
if normalized_topic in normalized_pref_set:
matched_exact = True
else:
for pref_word in normalized_pref_set:
if pref_word and (pref_word in normalized_topic or normalized_topic in pref_word):
matched_exact = True
break
if matched_exact:
exact_hits.append(topic_keyword)
# 精确命中给较高基础分,标签自身相关度作为增益
score += 45.0 + topic_relevance_score * 0.2
continue
# 2) 语义命中(未精确命中时再算)
topic_vec = topic_vec_map.get(topic_keyword)
if topic_vec is None:
continue
best_pref = None
best_sim = -1.0
for pref_keyword, pref_vec in pref_vec_map.items():
sim = float(np.dot(topic_vec, pref_vec))
if sim > best_sim:
best_sim = sim
best_pref = pref_keyword
if best_pref is not None and best_sim >= similarity_threshold:
semantic_hits.append(
{
"preference_keyword": best_pref,
"topic_keyword": topic_keyword,
"similarity": round(best_sim, 4),
}
)
# 语义命中分略低于精确命中,并由相似度放大
score += best_sim * 35.0 + topic_relevance_score * 0.12
# 如果精确和语义都没命中,直接跳过
if not exact_hits and not semantic_hits:
continue
# 融合事件热度和新鲜度,避免只看语义分
score += min(event.hot_score, 100) * 0.3
score += _calc_freshness_bonus(event)
# 返回标签时做去重,保证接口稳定
tags = list(dict.fromkeys([item[0] for item in topic_list]))
scored_results.append(
MatchedEventResult(
event=event,
match_score=round(score, 2),
exact_hits=list(dict.fromkeys(exact_hits)),
semantic_hits=semantic_hits,
tags=tags,
)
)
scored_results.sort(
key=lambda item: (item.match_score, item.event.hot_score, item.event.created_at),
reverse=True,
)
return scored_results[:final_limit]
+175 -38
View File
@@ -1,104 +1,241 @@
# app/services/summary_service.py
import os
import json
import os
from datetime import timedelta
from typing import Any
import numpy as np
from openai import AsyncOpenAI
from app.database import SessionLocal
from app.models.models import UnifiedEvent, TrendingEvent, InfoSource, utcnow
from app.models.models import (
ExtractedTopic,
InfoSource,
TargetType,
TrendingEvent,
UnifiedEvent,
utcnow,
)
from app.prompts.summary_prompts import (
SUMMARY_SYSTEM_PROMPT,
SUMMARY_USER_PROMPT_TEMPLATE,
)
from app.services.fetcher_service import embedder_model
HOT_SCORE_THRESHOLD = int(os.getenv("HOT_SCORE_THRESHOLD", 3))
AI_API_KEY = os.getenv("AI_API_KEY", '')
TOPIC_TAG_MIN_HOT_SCORE = int(os.getenv("TOPIC_TAG_MIN_HOT_SCORE", HOT_SCORE_THRESHOLD))
TOPIC_SIMILARITY_THRESHOLD = float(os.getenv("TOPIC_SIMILARITY_THRESHOLD", 0.82))
TOPIC_TAG_MAX_COUNT = int(os.getenv("TOPIC_TAG_MAX_COUNT", 8))
AI_API_KEY = os.getenv("AI_API_KEY", "")
# 1. 初始化异步客户端 (全局复用)
deepseek_client = AsyncOpenAI(
api_key=AI_API_KEY,
base_url="https://api.deepseek.com"
base_url="https://api.deepseek.com",
)
async def call_llm_for_summary(platform_data_text: str) -> dict:
"""调用 DeepSeek 生成统一标题和多平台视角摘要"""
prompt = SUMMARY_USER_PROMPT_TEMPLATE.format(
platform_data_text=platform_data_text
)
"""Call LLM for unified title, summary and topic candidates."""
prompt = SUMMARY_USER_PROMPT_TEMPLATE.format(platform_data_text=platform_data_text)
# await
response = await deepseek_client.chat.completions.create(
model="deepseek-chat",
messages=[
{"role": "system", "content": SUMMARY_SYSTEM_PROMPT},
{"role": "user", "content": prompt}
{"role": "user", "content": prompt},
],
response_format={"type": "json_object"},
temperature=1
temperature=1,
)
result_text = response.choices[0].message.content
return json.loads(result_text)
def _normalize_score(raw_score: Any) -> float | None:
try:
score = float(raw_score)
except (TypeError, ValueError):
return None
if score <= 1:
score *= 100
return max(0.0, min(100.0, score))
def parse_topic_keywords(llm_result: dict) -> list[dict[str, Any]]:
"""Parse topic keywords from LLM response; support list[str] and list[object]."""
raw_topics = llm_result.get("topic_keywords") or []
parsed: list[dict[str, Any]] = []
seen: set[str] = set()
for item in raw_topics:
keyword = ""
score = None
if isinstance(item, str):
keyword = item.strip()
elif isinstance(item, dict):
raw_keyword = (
item.get("keyword")
or item.get("topic_keyword")
or item.get("name")
or item.get("topic")
or ""
)
keyword = str(raw_keyword).strip()
score = _normalize_score(item.get("relevance_score") or item.get("score"))
if not keyword:
continue
keyword = keyword[:100]
normalized_key = keyword.casefold()
if normalized_key in seen:
continue
seen.add(normalized_key)
parsed.append({"keyword": keyword, "score": score})
return parsed
def normalize_topic_keywords(topic_candidates: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Deduplicate semantically similar tags using embedding similarity."""
if not topic_candidates:
return []
keywords = [item["keyword"] for item in topic_candidates]
vectors = embedder_model.encode(keywords, normalize_embeddings=True)
clusters: list[dict[str, Any]] = []
for item, vector in zip(topic_candidates, vectors):
vec = np.asarray(vector, dtype=np.float32)
best_idx = -1
best_sim = -1.0
for idx, cluster in enumerate(clusters):
sim = float(np.dot(vec, cluster["vector"]))
if sim > best_sim:
best_sim = sim
best_idx = idx
if best_idx >= 0 and best_sim >= TOPIC_SIMILARITY_THRESHOLD:
cluster = clusters[best_idx]
merged = cluster["vector"] * cluster["count"] + vec
norm = float(np.linalg.norm(merged))
if norm > 0:
cluster["vector"] = merged / norm
cluster["count"] += 1
if item["score"] is not None and (
cluster["score"] is None or item["score"] > cluster["score"]
):
cluster["score"] = item["score"]
# Prefer shorter tag as canonical keyword.
if len(item["keyword"]) < len(cluster["keyword"]):
cluster["keyword"] = item["keyword"]
else:
clusters.append(
{
"keyword": item["keyword"],
"score": item["score"],
"vector": vec,
"count": 1,
}
)
if any(cluster["score"] is not None for cluster in clusters):
clusters.sort(key=lambda x: x["score"] if x["score"] is not None else -1.0, reverse=True)
result = [
{"keyword": cluster["keyword"], "score": cluster["score"]}
for cluster in clusters[:TOPIC_TAG_MAX_COUNT]
]
return result
def replace_event_topics(db, event_id: int, normalized_topics: list[dict[str, Any]]) -> None:
"""Replace EVENT tags for one unified event atomically within current transaction."""
db.query(ExtractedTopic).filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id == event_id,
).delete(synchronize_session=False)
for item in normalized_topics:
db.add(
ExtractedTopic(
target_type=TargetType.EVENT,
target_id=event_id,
topic_keyword=item["keyword"],
relevance_score=item["score"],
)
)
async def generate_unified_summaries():
"""定时任务:扫描高热度事件并生成/更新摘要"""
print(f"[{utcnow()}] 开始执行 DeepSeek 摘要生成任务...")
"""Scheduled task: refresh summaries and topic tags for hot unified events."""
print(f"[{utcnow()}] Start unified summary generation task...")
with SessionLocal() as db:
recent_threshold = utcnow() - timedelta(days=3)
# 必须满足:热度达标 AND (当前热度 > 上次生成摘要时的热度) AND 近期活跃
events = db.query(UnifiedEvent).filter(
UnifiedEvent.hot_score >= HOT_SCORE_THRESHOLD,
UnifiedEvent.hot_score > UnifiedEvent.last_summarized_trends_count,
UnifiedEvent.created_at >= recent_threshold
UnifiedEvent.created_at >= recent_threshold,
).all()
if not events:
print("当前没有需要更新摘要的大事件,任务结束。")
print("No events require summary update in this round.")
return
for event in events:
# 联合查询获取该事件在各平台的子新闻
trends = db.query(TrendingEvent, InfoSource.source_name) \
.join(InfoSource, TrendingEvent.source_id == InfoSource.id) \
.filter(TrendingEvent.unified_event_id == event.id) \
trends = (
db.query(TrendingEvent, InfoSource.source_name)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
.filter(TrendingEvent.unified_event_id == event.id)
.all()
)
if not trends:
continue
# 按平台归类标题并去重
platform_dict = {}
platform_dict: dict[str, set[str]] = {}
for trend_record, source_name in trends:
if source_name not in platform_dict:
platform_dict[source_name] = set()
platform_dict[source_name].add(trend_record.current_headline)
platform_dict.setdefault(source_name, set()).add(trend_record.current_headline)
# 组装给大模型的 Prompt 数据
prompt_lines = [f"{platform}】: {', '.join(headlines)}" for platform, headlines in platform_dict.items()]
prompt_lines = [
f"[{platform}] {', '.join(sorted(headlines))}"
for platform, headlines in platform_dict.items()
]
platform_data_text = "\n".join(prompt_lines)
try:
# 调用封装好的异步函数
llm_result = await call_llm_for_summary(platform_data_text)
if "unified_title" in llm_result:
if "unified_title" in llm_result and llm_result["unified_title"]:
event.unified_title = llm_result["unified_title"]
if "ai_comprehensive_summary" in llm_result:
if "ai_comprehensive_summary" in llm_result and llm_result["ai_comprehensive_summary"]:
event.ai_comprehensive_summary = llm_result["ai_comprehensive_summary"]
# 成功后更新水位线
# 将最后一次总结时的热搜数量,更新为当前最新的 hot_score
if event.hot_score >= TOPIC_TAG_MIN_HOT_SCORE:
topic_candidates = parse_topic_keywords(llm_result)
normalized_topics = normalize_topic_keywords(topic_candidates)
if normalized_topics:
replace_event_topics(db, event.id, normalized_topics)
event.last_summarized_trends_count = event.hot_score
print(
f"Updated event {event.id} summary"
f" (hot_score={event.hot_score})."
)
print(f"成功更新大事件 ID {event.id} 的深度摘要 (当前热度: {event.hot_score})。")
except Exception as e:
print(f"大事件 ID {event.id} 摘要生成失败: {e}")
except Exception as exc:
print(f"Event {event.id} summary generation failed: {exc}")
continue
# 提交事务
db.commit()
+1 -1
View File
@@ -17,7 +17,7 @@ async def send_html_email(
to_email: str,
subject: str,
html_content: str,
sender_name: str = "AI 新闻早报",
sender_name: str = "AI 新闻",
sender_email: str = None
) -> bool:
"""底层纯异步发送邮件工具"""