# app/api/endpoints/events.py
# Read-only API endpoints for unified events: paginated listing, single-event
# detail, and keyword / semantic search with a timeline histogram.
import json
import logging
import os
import re
import time
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple

import numpy as np
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session

from app.api.dependencies import get_db
from app.models.models import (
    ExtractedTopic,
    InfoSource,
    RankingLog,
    TargetType,
    TrendingEvent,
    UnifiedEvent,
    utcnow,
)
from app.schemas.event_schema import (
    PaginatedUnifiedEventResponse,
    PlatformTrendResponse,
    SearchTimelineResponse,
    TimelineDataPoint,
    UnifiedEventResponse,
)
from app.services.fetcher_service import embedder_model

load_dotenv()

logger = logging.getLogger(__name__)

SEARCH_EMBEDDING_THRESHOLD = float(os.getenv("SEARCH_EMBEDDING_THRESHOLD", "0.75"))
SEARCH_MAX_LIMIT = int(os.getenv("SEARCH_MAX_LIMIT", "30"))
SEARCH_DEFAULT_HOURS = int(os.getenv("SEARCH_DEFAULT_HOURS", "168"))
SEARCH_MAX_HOURS = int(os.getenv("SEARCH_MAX_HOURS", "168"))

router = APIRouter()

# Maximum number of ranking-history points returned per platform trend, so the
# response body stays small when the requested time span is large.
MAX_RANKING_POINTS = 30

# Short-lived in-process cache for the unified event list endpoint.
# Maps a request-parameter key to (expiry timestamp, cached response).
_UNIFIED_EVENTS_CACHE: Dict[str, Tuple[float, PaginatedUnifiedEventResponse]] = {}
CACHE_TTL_SECONDS = 60
# Once the cache holds more than this many entries it is cleared wholesale.
CACHE_MAX_ENTRIES = 1000

# Ranking-history look-back window (hours) used by the single-event endpoint.
DETAIL_RANKING_WINDOW_HOURS = 720

# Maximum number of matched events returned by the search endpoint.
SEARCH_MAX_DISPLAY_EVENTS = 100


def _load_vector(raw_embedding: str | None) -> np.ndarray | None:
    """Parse a JSON-encoded embedding string into a 1-D float32 array.

    Returns None when the value is missing, is not valid JSON, cannot be
    converted to float32, or does not decode to a non-empty 1-D vector.
    Rejecting scalars and nested arrays here keeps the later dot product
    from raising or returning a non-scalar.
    """
    if not raw_embedding:
        return None
    try:
        vector = np.asarray(json.loads(raw_embedding), dtype=np.float32)
    except (TypeError, ValueError, json.JSONDecodeError):
        return None
    if vector.ndim != 1 or vector.size == 0:
        return None
    return vector


def _ensure_aware_datetime(dt: datetime) -> datetime:
    """Return *dt* with a timezone; naive values are interpreted as UTC."""
    if dt.tzinfo is None:
        return dt.replace(tzinfo=timezone.utc)
    return dt


def _is_semantic_match(query_vec: np.ndarray, raw_embedding: str | None, threshold: float) -> bool:
    """Return True when the stored embedding's dot product with *query_vec* reaches *threshold*.

    Missing, malformed, or differently-shaped embeddings count as a non-match
    instead of raising, so one bad row cannot fail the whole request.
    """
    candidate = _load_vector(raw_embedding)
    if candidate is None or candidate.shape != query_vec.shape:
        return False
    return float(np.dot(query_vec, candidate)) >= threshold


def _compile_search_pattern(query_text: str) -> re.Pattern[str]:
    """Compile *query_text* as a case-insensitive regex, falling back to a literal match."""
    # NOTE(security): the pattern comes straight from the request, so a
    # pathological expression can cause catastrophic backtracking (ReDoS).
    # Behaviour is intentionally unchanged here; consider a length cap or a
    # literal-only mode if this endpoint is exposed to untrusted clients.
    try:
        return re.compile(query_text, re.IGNORECASE)
    except re.error:
        return re.compile(re.escape(query_text), re.IGNORECASE)


def _encode_query(query_text: str) -> np.ndarray | None:
    """Embed *query_text* with the shared embedder; return None if embedding fails."""
    try:
        encoded = embedder_model.encode(
            [query_text],
            normalize_embeddings=True,
            show_progress_bar=False,
        )
        if len(encoded) > 0:
            return np.asarray(encoded[0], dtype=np.float32)
    except Exception:
        # Best-effort: the caller falls back to text matching, but the failure
        # is logged so it does not go unnoticed.
        logger.warning("Failed to embed search query; semantic matching disabled", exc_info=True)
    return None


def _load_trends(
    db: Session,
    event_ids: list[int],
) -> tuple[dict[int, list[tuple[TrendingEvent, str]]], list[int]]:
    """Load platform trends for the given unified events.

    Returns a mapping of unified event id -> list of (trend, source name),
    plus the flat list of trend ids that were loaded.
    """
    trend_map: dict[int, list[tuple[TrendingEvent, str]]] = {}
    trend_ids: list[int] = []
    if not event_ids:
        return trend_map, trend_ids

    trend_rows = (
        db.query(TrendingEvent, InfoSource.source_name)
        .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
        .filter(TrendingEvent.unified_event_id.in_(event_ids))
        .all()
    )
    for trend, source_name in trend_rows:
        trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
        trend_ids.append(trend.id)
    return trend_map, trend_ids


def _load_ranking_history(db: Session, trend_ids: list[int], since: datetime) -> dict[int, list[int]]:
    """Load ranking positions observed at or after *since*, grouped by trend id.

    Positions within each list are ordered by observation time, oldest first.
    """
    ranking_map: dict[int, list[int]] = {}
    if not trend_ids:
        return ranking_map

    ranking_rows = (
        db.query(RankingLog.event_id, RankingLog.ranking_position)
        .filter(
            RankingLog.event_id.in_(trend_ids),
            RankingLog.observed_at >= since,
        )
        .order_by(RankingLog.event_id, RankingLog.observed_at.asc())
        .all()
    )
    for trend_id, position in ranking_rows:
        ranking_map.setdefault(trend_id, []).append(position)
    return ranking_map


def _load_tags(db: Session, event_ids: list[int]) -> dict[int, list[str]]:
    """Load topic keywords for the given events, most relevant (then newest) first."""
    tag_map: dict[int, list[str]] = {}
    if not event_ids:
        return tag_map

    tag_rows = (
        db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
        .filter(
            ExtractedTopic.target_type == TargetType.EVENT,
            ExtractedTopic.target_id.in_(event_ids),
        )
        .order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
        .all()
    )
    for target_id, topic_keyword in tag_rows:
        tag_map.setdefault(target_id, []).append(topic_keyword)
    return tag_map


def _build_event_response(
    ev: UnifiedEvent,
    trends: list[tuple[TrendingEvent, str]],
    ranking_map: dict[int, list[int]],
    tags: list[str],
    *,
    dedupe_platforms: bool = False,
) -> UnifiedEventResponse:
    """Assemble the API response for one unified event.

    When *dedupe_platforms* is true, trends sharing the same source name and
    headline are collapsed to the first occurrence. ``last_active_at`` is the
    latest ``updated_at`` across all of the event's trends (duplicates
    included), or the event's own ``updated_at`` when it has no trends.
    """
    platforms: list[PlatformTrendResponse] = []
    seen_platforms: set[tuple[str, str | None]] = set()
    for trend, source_name in trends:
        if dedupe_platforms:
            # Tuple key: a joined string could collide across different pairs.
            platform_key = (source_name, trend.current_headline)
            if platform_key in seen_platforms:
                continue
            seen_platforms.add(platform_key)

        # Keep only the most recent points; the slice is a no-op for short lists.
        history = ranking_map.get(trend.id, [])[-MAX_RANKING_POINTS:]
        platforms.append(
            PlatformTrendResponse(
                source_id=trend.source_id,
                platform_name=source_name,
                headline=trend.current_headline,
                url=trend.event_url,
                current_ranking=trend.current_ranking,
                ranking_history=history,
            )
        )

    if trends:
        last_active_at = max(trend.updated_at for trend, _ in trends)
    else:
        last_active_at = ev.updated_at

    return UnifiedEventResponse(
        event_id=ev.id,
        unified_title=ev.unified_title if ev.unified_title else "Untitled",
        summary=ev.ai_comprehensive_summary,
        hot_score=ev.hot_score,
        created_at=ev.created_at,
        last_active_at=last_active_at,
        platforms=platforms,
        tags=tags,
    )


def _store_unified_cache(cache_key: str, expires_at: float, response: PaginatedUnifiedEventResponse) -> None:
    """Store *response* in the list cache, clearing the cache first if it grew too large."""
    if len(_UNIFIED_EVENTS_CACHE) > CACHE_MAX_ENTRIES:
        _UNIFIED_EVENTS_CACHE.clear()
    _UNIFIED_EVENTS_CACHE[cache_key] = (expires_at, response)


@router.get("/unified", response_model=PaginatedUnifiedEventResponse)
def list_unified_events(
    min_hot: int = Query(5, ge=0, description="最低热度阈值"),
    hours: int = Query(48, ge=1, le=720, description="查询最近多少小时的数据"),
    sort_by: str = Query("hot_score", description="排序字段:hot_score | created_at"),
    skip: int = Query(0, ge=0, description="分页偏移量"),
    limit: int = Query(10, ge=1, le=50, description="每页条数"),
    db: Session = Depends(get_db),
):
    """List unified events with their platform trends and tags.

    Events are filtered by minimum hot score and creation time, sorted by
    ``hot_score`` (default) or ``created_at``, and paginated. Responses are
    cached in-process for ``CACHE_TTL_SECONDS`` per parameter combination.
    """
    # Any value other than "created_at" sorts by hot score, so normalise it
    # before building the cache key to avoid duplicate entries.
    sort_field = "created_at" if sort_by == "created_at" else "hot_score"
    cache_key = f"{min_hot}:{hours}:{sort_field}:{skip}:{limit}"
    current_time = time.time()

    # Single lookup: a concurrent clear() cannot cause a KeyError between a
    # membership test and the subsequent read.
    cached_entry = _UNIFIED_EVENTS_CACHE.get(cache_key)
    if cached_entry is not None:
        expire_time, cached_data = cached_entry
        if current_time < expire_time:
            return cached_data

    time_limit = utcnow() - timedelta(hours=hours)
    base_query = db.query(UnifiedEvent).filter(
        UnifiedEvent.hot_score >= min_hot,
        UnifiedEvent.created_at >= time_limit,
    )
    total = base_query.count()

    if sort_field == "created_at":
        base_query = base_query.order_by(UnifiedEvent.created_at.desc())
    else:
        base_query = base_query.order_by(UnifiedEvent.hot_score.desc(), UnifiedEvent.created_at.desc())

    events = base_query.offset(skip).limit(limit).all()
    expires_at = current_time + CACHE_TTL_SECONDS

    if not events:
        # Cache empty pages too, so repeated empty queries do not hit the database.
        empty_response = PaginatedUnifiedEventResponse(total=total, has_more=False, data=[])
        _store_unified_cache(cache_key, expires_at, empty_response)
        return empty_response

    event_ids = [ev.id for ev in events]
    trend_map, trend_ids = _load_trends(db, event_ids)
    ranking_map = _load_ranking_history(db, trend_ids, time_limit)
    tag_map = _load_tags(db, event_ids)

    results = [
        _build_event_response(ev, trend_map.get(ev.id, []), ranking_map, tag_map.get(ev.id, []))
        for ev in events
    ]

    has_more = (skip + limit) < total
    response = PaginatedUnifiedEventResponse(total=total, has_more=has_more, data=results)
    _store_unified_cache(cache_key, expires_at, response)
    return response


@router.get("/unified/{event_id}", response_model=UnifiedEventResponse)
def get_unified_event(
    event_id: int,
    db: Session = Depends(get_db),
):
    """Fetch a single unified event by id.

    Responds with HTTP 404 when no event has the given id. Ranking history is
    limited to the last ``DETAIL_RANKING_WINDOW_HOURS`` hours.
    """
    ev = db.query(UnifiedEvent).filter(UnifiedEvent.id == event_id).first()
    if not ev:
        raise HTTPException(status_code=404, detail="Event not found")

    time_limit = utcnow() - timedelta(hours=DETAIL_RANKING_WINDOW_HOURS)

    trend_map, trend_ids = _load_trends(db, [event_id])
    ranking_map = _load_ranking_history(db, trend_ids, time_limit)
    tag_map = _load_tags(db, [event_id])

    return _build_event_response(
        ev,
        trend_map.get(event_id, []),
        ranking_map,
        tag_map.get(event_id, []),
    )


@router.get("/search_timeline", response_model=SearchTimelineResponse)
def search_events_timeline(
    keyword: str = Query(..., description="搜索关键词,支持正则表达式"),
    hours: int | None = Query(None, ge=1, le=SEARCH_MAX_HOURS, description="查询最近多少小时的数据"),
    mode: str = Query("hybrid", description="匹配模式:exact | semantic | hybrid"),
    semantic_threshold: float | None = Query(None, ge=0.0, le=1.0, description="语义匹配相似度阈值"),
    utc_offset_minutes: int | None = Query(None, ge=-840, le=840, description="客户端相对 UTC 的分钟偏移,东八区传 480"),
    db: Session = Depends(get_db),
):
    """Search recent events by keyword and return matches plus a timeline histogram.

    ``mode`` selects regex matching (``exact``), embedding similarity
    (``semantic``), or both (``hybrid``); unknown values fall back to
    ``hybrid``. If the query cannot be embedded in ``semantic`` mode, a
    literal text match is used instead. Timeline buckets are hourly for
    windows of up to 48 hours and daily otherwise, rendered in the client's
    UTC offset when one is supplied. Responds with HTTP 400 when the keyword
    is blank.
    """
    query_text = (keyword or "").strip()
    if not query_text:
        raise HTTPException(status_code=400, detail="keyword cannot be empty")

    if hours is None:
        hours = SEARCH_DEFAULT_HOURS
    # The configured default may exceed the configured maximum; clamp it.
    hours = min(hours, SEARCH_MAX_HOURS)

    match_mode = (mode or "hybrid").strip().lower()
    if match_mode not in {"exact", "semantic", "hybrid"}:
        match_mode = "hybrid"

    use_regex = match_mode in {"exact", "hybrid"}
    use_semantic = match_mode in {"semantic", "hybrid"}
    sim_threshold = semantic_threshold if semantic_threshold is not None else SEARCH_EMBEDDING_THRESHOLD

    pattern = _compile_search_pattern(query_text) if use_regex else None
    query_vec = _encode_query(query_text) if use_semantic else None

    if match_mode == "semantic" and query_vec is None:
        # Embedding is unavailable: degrade to a literal, case-insensitive text match.
        pattern = re.compile(re.escape(query_text), re.IGNORECASE)

    time_limit = utcnow() - timedelta(hours=hours)
    date_format = "%Y-%m-%d %H:00" if hours <= 48 else "%Y-%m-%d"

    display_timezone = timezone.utc
    if utc_offset_minutes is not None:
        display_timezone = timezone(timedelta(minutes=utc_offset_minutes))

    def _bucket_label(dt: datetime) -> str:
        """Format *dt* as its timeline bucket label in the display timezone."""
        aware_dt = _ensure_aware_datetime(dt)
        return aware_dt.astimezone(display_timezone).strftime(date_format)

    def _matches(text: str | None, raw_embedding: str | None) -> bool:
        """Return True if *text* matches the pattern or the embedding is similar enough."""
        if pattern is not None and text and pattern.search(text):
            return True
        # Only parse the stored embedding when the text check did not already match.
        return query_vec is not None and _is_semantic_match(query_vec, raw_embedding, sim_threshold)

    all_recent_unified = db.query(UnifiedEvent).filter(UnifiedEvent.created_at >= time_limit).all()
    all_recent_trends = db.query(TrendingEvent).filter(TrendingEvent.created_at >= time_limit).all()

    matched_event_ids: set[int] = set()
    matched_trend_points: list[tuple[int, str]] = []

    for ev in all_recent_unified:
        text_to_search = f"{ev.unified_title or ''} {ev.ai_comprehensive_summary or ''}"
        if _matches(text_to_search, ev.center_embedding):
            matched_event_ids.add(ev.id)

    for trend in all_recent_trends:
        # Trends not attached to a unified event cannot contribute to the result.
        if not trend.unified_event_id:
            continue
        if _matches(trend.current_headline, trend.title_embedding):
            matched_event_ids.add(trend.unified_event_id)
            matched_trend_points.append((trend.unified_event_id, _bucket_label(trend.created_at)))

    matched_unified_events = [ev for ev in all_recent_unified if ev.id in matched_event_ids]
    matched_unified_events.sort(key=lambda matched: matched.created_at, reverse=True)

    display_events = matched_unified_events[:SEARCH_MAX_DISPLAY_EVENTS]
    display_event_ids = [ev.id for ev in display_events]
    display_event_id_set = set(display_event_ids)

    trend_map, trend_ids = _load_trends(db, display_event_ids)
    ranking_map = _load_ranking_history(db, trend_ids, time_limit)

    # Each bucket collects the distinct events seen in it, either through the
    # event's own creation time or through one of its matching trends.
    timeline_event_map: dict[str, set[int]] = {}
    for ev in display_events:
        timeline_event_map.setdefault(_bucket_label(ev.created_at), set()).add(ev.id)
    for matched_event_id, time_label in matched_trend_points:
        if matched_event_id in display_event_id_set:
            timeline_event_map.setdefault(time_label, set()).add(matched_event_id)

    timeline_data = [
        TimelineDataPoint(time_label=time_label, count=len(bucket_ids), event_ids=sorted(bucket_ids))
        for time_label, bucket_ids in sorted(timeline_event_map.items())
    ]

    tag_map = _load_tags(db, display_event_ids)
    results = [
        _build_event_response(
            ev,
            trend_map.get(ev.id, []),
            ranking_map,
            tag_map.get(ev.id, []),
            dedupe_platforms=True,
        )
        for ev in display_events
    ]

    return SearchTimelineResponse(keyword=query_text, timeline=timeline_data, events=results)