Files
2026-04-20 15:53:02 +08:00

465 lines
17 KiB
Python

import json
import os
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
import numpy as np
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
from app.api.dependencies import get_db
from app.models.models import (
ExtractedTopic,
InfoSource,
RankingLog,
TargetType,
TrendingEvent,
UnifiedEvent,
utcnow,
)
from app.schemas.event_schema import (
PaginatedUnifiedEventResponse,
PlatformTrendResponse,
SearchTimelineResponse,
TimelineDataPoint,
UnifiedEventResponse,
)
from app.services.fetcher_service import embedder_model
# Load variables from a .env file before the configuration below is read from the environment.
load_dotenv()
# Search settings; each can be overridden by the environment variable of the same name.
# Default similarity cutoff used when a search request does not pass `semantic_threshold`.
SEARCH_EMBEDDING_THRESHOLD = float(os.getenv("SEARCH_EMBEDDING_THRESHOLD", "0.75"))
# NOTE(review): not referenced anywhere in the visible code; the search endpoint caps
# its result list with a hard-coded value instead — confirm whether this should be used.
SEARCH_MAX_LIMIT = int(os.getenv("SEARCH_MAX_LIMIT", "30"))
# Look-back window (hours) applied when a search request omits `hours`.
SEARCH_DEFAULT_HOURS = int(os.getenv("SEARCH_DEFAULT_HOURS", "168"))
# Upper bound (hours) for the search look-back window.
SEARCH_MAX_HOURS = int(os.getenv("SEARCH_MAX_HOURS", "168"))
router = APIRouter()
# Maximum number of ranking-history points returned per platform trend (most recent kept).
MAX_RANKING_POINTS = 30
# In-process response cache for the unified-events list:
# cache key -> (expiry timestamp from time.time(), cached response).
_UNIFIED_EVENTS_CACHE: Dict[str, Tuple[float, PaginatedUnifiedEventResponse]] = {}
# Lifetime of a cache entry, in seconds.
CACHE_TTL_SECONDS = 60
def _load_vector(raw_embedding: str | None) -> np.ndarray | None:
"""将字符串形式的向量安全解析为 numpy 数组。"""
if not raw_embedding:
return None
try:
return np.asarray(json.loads(raw_embedding), dtype=np.float32)
except (TypeError, ValueError, json.JSONDecodeError):
return None
def _ensure_aware_datetime(dt: datetime) -> datetime:
    """Return *dt* with a timezone attached, reading naive values as UTC.

    Datetimes that already carry ``tzinfo`` are returned unchanged. Naive
    values (as returned by SQLite) get ``timezone.utc`` attached without any
    clock adjustment.
    """
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
@router.get("/unified", response_model=PaginatedUnifiedEventResponse)
def list_unified_events(
    min_hot: int = Query(5, ge=0, description="最低热度阈值"),
    hours: int = Query(48, ge=1, le=720, description="查询最近多少小时的数据"),
    sort_by: str = Query("hot_score", description="排序字段:hot_score | created_at"),
    skip: int = Query(0, ge=0, description="分页偏移量"),
    limit: int = Query(10, ge=1, le=50, description="每页条数"),
    db: Session = Depends(get_db),
):
    """List unified events with their platform trends, ranking history and tags.

    Events are filtered by minimum hot score and by ``created_at`` within the
    last ``hours`` hours, then sorted and paginated. Responses are kept in a
    short-lived in-process cache to reduce database load under concurrent
    traffic.

    Args:
        min_hot: Minimum ``hot_score`` an event must have.
        hours: Look-back window in hours, applied to event creation time and
            to the ranking log.
        sort_by: ``"created_at"`` sorts newest first; any other value sorts by
            hot score, with ties broken newest first.
        skip: Pagination offset.
        limit: Page size.
        db: SQLAlchemy session supplied by the ``get_db`` dependency.

    Returns:
        A ``PaginatedUnifiedEventResponse`` with the total match count, a
        ``has_more`` flag and the requested page of events.
    """
    # Every value other than "created_at" sorts by hot score, so the cache key
    # uses the effective sort. Otherwise arbitrary sort_by strings would each
    # create their own entry for identical results.
    effective_sort = "created_at" if sort_by == "created_at" else "hot_score"
    cache_key = f"{min_hot}:{hours}:{effective_sort}:{skip}:{limit}"
    current_time = time.time()

    # A single .get() avoids the check-then-index pattern, which can raise
    # KeyError if another request clears the cache between the two steps.
    cached_entry = _UNIFIED_EVENTS_CACHE.get(cache_key)
    if cached_entry is not None:
        expire_time, cached_data = cached_entry
        if current_time < expire_time:
            return cached_data

    def _remember(payload: PaginatedUnifiedEventResponse) -> PaginatedUnifiedEventResponse:
        """Store *payload* under the current key and hand it back."""
        # Crude bound on memory use: drop everything once the cache grows large.
        if len(_UNIFIED_EVENTS_CACHE) > 1000:
            _UNIFIED_EVENTS_CACHE.clear()
        _UNIFIED_EVENTS_CACHE[cache_key] = (current_time + CACHE_TTL_SECONDS, payload)
        return payload

    time_limit = utcnow() - timedelta(hours=hours)

    # Filter by hot score and age; trends, ranking history and tags are
    # attached afterwards with one batched query each.
    base_query = db.query(UnifiedEvent).filter(
        UnifiedEvent.hot_score >= min_hot,
        UnifiedEvent.created_at >= time_limit,
    )
    total = base_query.count()

    if effective_sort == "created_at":
        base_query = base_query.order_by(UnifiedEvent.created_at.desc())
    else:
        base_query = base_query.order_by(UnifiedEvent.hot_score.desc(), UnifiedEvent.created_at.desc())

    events = base_query.offset(skip).limit(limit).all()
    if not events:
        # Empty pages are cached too, so repeated out-of-range requests do not
        # hit the database every time.
        return _remember(PaginatedUnifiedEventResponse(total=total, has_more=False, data=[]))

    event_ids = [ev.id for ev in events]

    # Platform trends (with their source name) grouped by unified event id.
    trend_rows = (
        db.query(TrendingEvent, InfoSource.source_name)
        .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
        .filter(TrendingEvent.unified_event_id.in_(event_ids))
        .all()
    )
    trend_map: dict[int, list[tuple[TrendingEvent, str]]] = {}
    trend_ids: list[int] = []
    for trend, source_name in trend_rows:
        trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
        trend_ids.append(trend.id)

    # Ranking positions per trend, oldest first, limited to the look-back window.
    ranking_map: dict[int, list[int]] = {}
    if trend_ids:
        ranking_rows = (
            db.query(RankingLog.event_id, RankingLog.ranking_position)
            .filter(
                RankingLog.event_id.in_(trend_ids),
                RankingLog.observed_at >= time_limit,
            )
            .order_by(RankingLog.event_id, RankingLog.observed_at.asc())
            .all()
        )
        for trend_id, position in ranking_rows:
            ranking_map.setdefault(trend_id, []).append(position)

    # Tag keywords per event, most relevant first.
    tag_map: dict[int, list[str]] = {}
    tag_rows = (
        db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
        .filter(
            ExtractedTopic.target_type == TargetType.EVENT,
            ExtractedTopic.target_id.in_(event_ids),
        )
        .order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
        .all()
    )
    for target_id, keyword in tag_rows:
        tag_map.setdefault(target_id, []).append(keyword)

    results: list[UnifiedEventResponse] = []
    for ev in events:
        trends_for_ev = trend_map.get(ev.id, [])
        platform_list: list[PlatformTrendResponse] = [
            PlatformTrendResponse(
                source_id=trend.source_id,
                platform_name=source_name,
                headline=trend.current_headline,
                url=trend.event_url,
                current_ranking=trend.current_ranking,
                # Keep only the most recent MAX_RANKING_POINTS observations.
                ranking_history=ranking_map.get(trend.id, [])[-MAX_RANKING_POINTS:],
            )
            for trend, source_name in trends_for_ev
        ]
        # The latest update among the event's trends; fall back to the event's
        # own timestamp when it has no trends.
        last_active_at = max(t.updated_at for t, _ in trends_for_ev) if trends_for_ev else ev.updated_at
        results.append(
            UnifiedEventResponse(
                event_id=ev.id,
                unified_title=ev.unified_title if ev.unified_title else "Untitled",
                summary=ev.ai_comprehensive_summary,
                hot_score=ev.hot_score,
                created_at=ev.created_at,
                last_active_at=last_active_at,
                platforms=platform_list,
                tags=tag_map.get(ev.id, []),
            )
        )

    has_more = (skip + limit) < total
    return _remember(PaginatedUnifiedEventResponse(total=total, has_more=has_more, data=results))
@router.get("/unified/{event_id}", response_model=UnifiedEventResponse)
def get_unified_event(
    event_id: int,
    db: Session = Depends(get_db),
):
    """Fetch a single unified event by its ID.

    Args:
        event_id: Primary key of the unified event.
        db: SQLAlchemy session supplied by the ``get_db`` dependency.

    Returns:
        A ``UnifiedEventResponse`` with the event's platform trends, recent
        ranking history and tags.

    Raises:
        HTTPException: 404 when no event with ``event_id`` exists.
    """
    ev = db.query(UnifiedEvent).filter(UnifiedEvent.id == event_id).first()
    if not ev:
        raise HTTPException(status_code=404, detail="Event not found")

    # Ranking history is limited to the last 720 hours (30 days).
    ranking_window_hours = 720
    time_limit = utcnow() - timedelta(hours=ranking_window_hours)

    trend_rows = (
        db.query(TrendingEvent, InfoSource.source_name)
        .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
        .filter(TrendingEvent.unified_event_id == event_id)
        .all()
    )
    trend_ids = [trend.id for trend, _ in trend_rows]

    # Ranking positions per trend, oldest first.
    ranking_map: dict[int, list[int]] = {}
    if trend_ids:
        ranking_rows = (
            db.query(RankingLog.event_id, RankingLog.ranking_position)
            .filter(
                RankingLog.event_id.in_(trend_ids),
                RankingLog.observed_at >= time_limit,
            )
            .order_by(RankingLog.event_id, RankingLog.observed_at.asc())
            .all()
        )
        for trend_id, position in ranking_rows:
            ranking_map.setdefault(trend_id, []).append(position)

    tag_rows = (
        db.query(ExtractedTopic.topic_keyword)
        .filter(
            ExtractedTopic.target_type == TargetType.EVENT,
            ExtractedTopic.target_id == event_id,
        )
        # Break relevance ties by newest first, matching the list and search
        # endpoints, so the same event shows its tags in the same order.
        .order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
        .all()
    )
    tags = [row[0] for row in tag_rows]

    platform_list: list[PlatformTrendResponse] = [
        PlatformTrendResponse(
            source_id=trend.source_id,
            platform_name=source_name,
            headline=trend.current_headline,
            url=trend.event_url,
            current_ranking=trend.current_ranking,
            # Keep only the most recent MAX_RANKING_POINTS observations.
            ranking_history=ranking_map.get(trend.id, [])[-MAX_RANKING_POINTS:],
        )
        for trend, source_name in trend_rows
    ]

    # The latest update among the event's trends; fall back to the event's own
    # timestamp when it has no trends.
    last_active_at = max(t.updated_at for t, _ in trend_rows) if trend_rows else ev.updated_at

    return UnifiedEventResponse(
        event_id=ev.id,
        unified_title=ev.unified_title if ev.unified_title else "Untitled",
        summary=ev.ai_comprehensive_summary,
        hot_score=ev.hot_score,
        created_at=ev.created_at,
        last_active_at=last_active_at,
        platforms=platform_list,
        tags=tags,
    )
@router.get("/search_timeline", response_model=SearchTimelineResponse)
def search_events_timeline(
    keyword: str = Query(..., description="搜索关键词,支持正则表达式"),
    hours: int | None = Query(None, ge=1, le=SEARCH_MAX_HOURS, description="查询最近多少小时的数据"),
    mode: str = Query("hybrid", description="匹配模式:exact | semantic | hybrid"),
    semantic_threshold: float | None = Query(None, ge=0.0, le=1.0, description="语义匹配相似度阈值"),
    utc_offset_minutes: int | None = Query(None, ge=-840, le=840, description="客户端相对 UTC 的分钟偏移,东八区传 480"),
    db: Session = Depends(get_db),
):
    """Search recent events by keyword and return matches plus a time histogram.

    Unified events and platform trends created within the look-back window are
    matched by regular expression, by embedding similarity, or by either,
    depending on ``mode``. The most recent matching events are returned with a
    timeline that counts matched events per hour (windows up to 48 hours) or
    per day (longer windows).

    Args:
        keyword: Search text; interpreted as a regular expression when it
            compiles, otherwise matched literally.
        hours: Look-back window in hours; defaults to ``SEARCH_DEFAULT_HOURS``
            and is capped at ``SEARCH_MAX_HOURS``.
        mode: ``"exact"`` (regex only), ``"semantic"`` (embedding only) or
            ``"hybrid"`` (either). Unknown values are treated as ``"hybrid"``.
        semantic_threshold: Minimum dot-product similarity for a semantic
            match; defaults to ``SEARCH_EMBEDDING_THRESHOLD``.
        utc_offset_minutes: Client offset from UTC in minutes, used for the
            timeline bucket labels; UTC when omitted.
        db: SQLAlchemy session supplied by the ``get_db`` dependency.

    Returns:
        A ``SearchTimelineResponse`` with the normalized keyword, the timeline
        buckets and the matched events.

    Raises:
        HTTPException: 400 when ``keyword`` is empty after stripping whitespace.
    """
    import logging
    import re

    query_text = (keyword or "").strip()
    if not query_text:
        raise HTTPException(status_code=400, detail="keyword cannot be empty")

    # SEARCH_DEFAULT_HOURS comes from the environment and may exceed the cap,
    # so clamp after applying the default.
    hours = min(SEARCH_DEFAULT_HOURS if hours is None else hours, SEARCH_MAX_HOURS)

    match_mode = (mode or "hybrid").strip().lower()
    if match_mode not in {"exact", "semantic", "hybrid"}:
        match_mode = "hybrid"
    use_regex = match_mode in {"exact", "hybrid"}
    use_semantic = match_mode in {"semantic", "hybrid"}
    sim_threshold = semantic_threshold if semantic_threshold is not None else SEARCH_EMBEDDING_THRESHOLD

    # NOTE(review): the pattern comes straight from the request, so a crafted
    # expression can cause catastrophic backtracking (ReDoS). Regex support is
    # part of the documented API, so this is flagged here rather than changed.
    pattern = None
    if use_regex:
        try:
            pattern = re.compile(query_text, re.IGNORECASE)
        except re.error:
            # Not a valid regular expression: match the text literally.
            pattern = re.compile(re.escape(query_text), re.IGNORECASE)

    query_vec: np.ndarray | None = None
    if use_semantic:
        try:
            query_encoded = embedder_model.encode([query_text], normalize_embeddings=True, show_progress_bar=False)
            if len(query_encoded) > 0:
                query_vec = np.asarray(query_encoded[0], dtype=np.float32)
        except Exception:
            # Best effort: the search continues without semantic matching, but
            # the failure is recorded rather than silently discarded.
            logging.getLogger(__name__).warning(
                "Query embedding failed; semantic matching is skipped for this request",
                exc_info=True,
            )
            query_vec = None

    # A semantic-only search without a usable query vector would match
    # nothing, so fall back to literal text matching.
    if match_mode == "semantic" and query_vec is None:
        use_regex = True
        if pattern is None:
            pattern = re.compile(re.escape(query_text), re.IGNORECASE)

    def _semantic_hit(raw_embedding) -> bool:
        """Return True when the stored embedding is similar enough to the query."""
        if query_vec is None:
            return False
        candidate = _load_vector(raw_embedding)
        # np.dot raises on mismatched shapes and yields a non-scalar for
        # non-1-D input; skip such vectors instead of failing the request.
        if candidate is None or candidate.ndim != 1 or candidate.shape != query_vec.shape:
            return False
        return float(np.dot(query_vec, candidate)) >= sim_threshold

    time_limit = utcnow() - timedelta(hours=hours)
    # Hourly buckets for short windows, daily buckets otherwise.
    date_format = "%Y-%m-%d %H:00" if hours <= 48 else "%Y-%m-%d"
    display_timezone = timezone.utc
    if utc_offset_minutes is not None:
        display_timezone = timezone(timedelta(minutes=utc_offset_minutes))

    def _bucket_label(dt: datetime) -> str:
        """Format *dt* as a timeline bucket label in the display timezone."""
        aware_dt = _ensure_aware_datetime(dt)
        return aware_dt.astimezone(display_timezone).strftime(date_format)

    all_recent_unified = db.query(UnifiedEvent).filter(UnifiedEvent.created_at >= time_limit).all()
    all_recent_trends = db.query(TrendingEvent).filter(TrendingEvent.created_at >= time_limit).all()

    matched_event_ids: set[int] = set()
    matched_trend_points: list[tuple[int, str]] = []

    # Match unified events on title + summary text and on their centre embedding.
    for ev in all_recent_unified:
        text_matched = False
        if use_regex and pattern is not None:
            text_to_search = f"{ev.unified_title or ''} {ev.ai_comprehensive_summary or ''}"
            text_matched = bool(pattern.search(text_to_search))
        if text_matched or _semantic_hit(ev.center_embedding):
            matched_event_ids.add(ev.id)

    # Match platform trends on headline and title embedding. A matching trend
    # pulls in its unified event and adds a timeline point at its own
    # creation time.
    for trend in all_recent_trends:
        text_matched = False
        if use_regex and pattern is not None and trend.current_headline:
            text_matched = bool(pattern.search(trend.current_headline))
        if (text_matched or _semantic_hit(trend.title_embedding)) and trend.unified_event_id:
            matched_event_ids.add(trend.unified_event_id)
            matched_trend_points.append((trend.unified_event_id, _bucket_label(trend.created_at)))

    matched_unified_events = [ev for ev in all_recent_unified if ev.id in matched_event_ids]
    matched_unified_events.sort(key=lambda x: x.created_at, reverse=True)
    # NOTE(review): the cap is hard-coded to 100 while the module-level
    # SEARCH_MAX_LIMIT setting is not used here — confirm which is intended.
    display_events = matched_unified_events[:100]
    display_event_ids = [ev.id for ev in display_events]
    display_event_id_set = set(display_event_ids)

    # Platform trends (with their source name) for the events being returned.
    trend_map: dict[int, list[tuple[TrendingEvent, str]]] = {}
    trend_ids: list[int] = []
    if display_event_ids:
        trend_rows = (
            db.query(TrendingEvent, InfoSource.source_name)
            .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
            .filter(TrendingEvent.unified_event_id.in_(display_event_ids))
            .all()
        )
        for trend, source_name in trend_rows:
            trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
            trend_ids.append(trend.id)

    # Ranking positions per trend, oldest first, limited to the look-back window.
    ranking_map: dict[int, list[int]] = {}
    if trend_ids:
        ranking_rows = (
            db.query(RankingLog.event_id, RankingLog.ranking_position)
            .filter(
                RankingLog.event_id.in_(trend_ids),
                RankingLog.observed_at >= time_limit,
            )
            .order_by(RankingLog.event_id, RankingLog.observed_at.asc())
            .all()
        )
        for trend_id, position in ranking_rows:
            ranking_map.setdefault(trend_id, []).append(position)

    # Timeline: each returned event counts once in its own creation bucket and
    # once in every bucket where one of its trends matched.
    timeline_event_map: dict[str, set[int]] = {}
    for ev in display_events:
        timeline_event_map.setdefault(_bucket_label(ev.created_at), set()).add(ev.id)
    for matched_id, time_label in matched_trend_points:
        if matched_id in display_event_id_set:
            timeline_event_map.setdefault(time_label, set()).add(matched_id)
    timeline_data = [
        TimelineDataPoint(time_label=time_label, count=len(bucket_ids), event_ids=sorted(bucket_ids))
        for time_label, bucket_ids in sorted(timeline_event_map.items())
    ]

    # Tag keywords per returned event, most relevant first.
    tag_map: dict[int, list[str]] = {}
    if display_event_ids:
        tag_rows = (
            db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
            .filter(
                ExtractedTopic.target_type == TargetType.EVENT,
                ExtractedTopic.target_id.in_(display_event_ids),
            )
            .order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
            .all()
        )
        for target_id, kw in tag_rows:
            tag_map.setdefault(target_id, []).append(kw)

    results: list[UnifiedEventResponse] = []
    for ev in display_events:
        trends_for_ev = trend_map.get(ev.id, [])
        platform_list: list[PlatformTrendResponse] = []
        # De-duplicate on (platform, headline). A tuple key cannot collide the
        # way a concatenated "platform_headline" string can.
        seen_platforms: set[tuple] = set()
        for trend, source_name in trends_for_ev:
            uniq_key = (source_name, trend.current_headline)
            if uniq_key in seen_platforms:
                continue
            seen_platforms.add(uniq_key)
            platform_list.append(
                PlatformTrendResponse(
                    source_id=trend.source_id,
                    platform_name=source_name,
                    headline=trend.current_headline,
                    url=trend.event_url,
                    current_ranking=trend.current_ranking,
                    # Keep only the most recent MAX_RANKING_POINTS observations.
                    ranking_history=ranking_map.get(trend.id, [])[-MAX_RANKING_POINTS:],
                )
            )
        # The latest update among the event's trends; fall back to the event's
        # own timestamp when it has no trends.
        last_active_at = max(t.updated_at for t, _ in trends_for_ev) if trends_for_ev else ev.updated_at
        results.append(
            UnifiedEventResponse(
                event_id=ev.id,
                unified_title=ev.unified_title if ev.unified_title else "Untitled",
                summary=ev.ai_comprehensive_summary,
                hot_score=ev.hot_score,
                created_at=ev.created_at,
                last_active_at=last_active_at,
                platforms=platform_list,
                tags=tag_map.get(ev.id, []),
            )
        )

    return SearchTimelineResponse(keyword=query_text, timeline=timeline_data, events=results)