搜索功能加入

This commit is contained in:
stardrophere
2026-03-13 18:25:38 +08:00
parent 9440b7f590
commit 6aee65af6c
18 changed files with 1545 additions and 103 deletions
+253 -46
View File
@@ -1,8 +1,12 @@
# app/api/endpoints/events.py
# app/api/endpoints/events.py
import json
import os
import time
from datetime import timedelta
from typing import Dict, List, Tuple
from datetime import datetime, timedelta, timezone
from typing import Dict, Tuple
import numpy as np
from dotenv import load_dotenv
from fastapi import APIRouter, Depends, HTTPException, Query
from sqlalchemy.orm import Session
@@ -19,63 +23,83 @@ from app.models.models import (
from app.schemas.event_schema import (
PaginatedUnifiedEventResponse,
PlatformTrendResponse,
SearchTimelineResponse,
TimelineDataPoint,
UnifiedEventResponse,
)
from app.services.fetcher_service import embedder_model
load_dotenv()
SEARCH_EMBEDDING_THRESHOLD = float(os.getenv("SEARCH_EMBEDDING_THRESHOLD", "0.75"))
SEARCH_MAX_LIMIT = int(os.getenv("SEARCH_MAX_LIMIT", "30"))
SEARCH_DEFAULT_HOURS = int(os.getenv("SEARCH_DEFAULT_HOURS", "168"))
SEARCH_MAX_HOURS = int(os.getenv("SEARCH_MAX_HOURS", "168"))
router = APIRouter()
# 排名轨迹最多返回多少个点,避免时间跨度下数据过大
# 排名轨迹最多返回的点数,避免时间跨度过大时响应体过重。
MAX_RANKING_POINTS = 30
# --- 轻量级接口缓存配置 ---
# 统一事件列表接口的短期缓存。
_UNIFIED_EVENTS_CACHE: Dict[str, Tuple[float, PaginatedUnifiedEventResponse]] = {}
CACHE_TTL_SECONDS = 60
# ---------------------------
def _load_vector(raw_embedding: str | None) -> np.ndarray | None:
"""将字符串形式的向量安全解析为 numpy 数组。"""
if not raw_embedding:
return None
try:
return np.asarray(json.loads(raw_embedding), dtype=np.float32)
except (TypeError, ValueError, json.JSONDecodeError):
return None
def _ensure_aware_datetime(dt: datetime) -> datetime:
"""确保 datetime 带时区;SQLite 返回无时区值时按 UTC 解释。"""
if dt.tzinfo is None:
return dt.replace(tzinfo=timezone.utc)
return dt
@router.get("/unified", response_model=PaginatedUnifiedEventResponse)
def list_unified_events(
min_hot: int = Query(5, ge=0, description="热度阈值,仅返回 hot_score >= 此值的事件"),
min_hot: int = Query(5, ge=0, description="最低热度阈值"),
hours: int = Query(48, ge=1, le=720, description="查询最近多少小时的数据"),
sort_by: str = Query("hot_score", description="排序字段: hot_score | created_at"),
sort_by: str = Query("hot_score", description="排序字段hot_score | created_at"),
skip: int = Query(0, ge=0, description="分页偏移量"),
limit: int = Query(10, ge=1, le=50, description="每页返回条数"),
limit: int = Query(10, ge=1, le=50, description="每页条数"),
db: Session = Depends(get_db),
):
"""分页返回统一事件,附带平台热搜、排名轨迹和标签"""
# --- 1. 尝试从缓存读取 ---
"""查询统一事件列表,并附带平台趋势与标签信息"""
cache_key = f"{min_hot}:{hours}:{sort_by}:{skip}:{limit}"
current_time = time.time()
if cache_key in _UNIFIED_EVENTS_CACHE:
expire_time, cached_data = _UNIFIED_EVENTS_CACHE[cache_key]
if current_time < expire_time:
return cached_data
# -----------------------
time_limit = utcnow() - timedelta(hours=hours)
# 先查总数,用于前端判断是否还有更多
base_query = db.query(UnifiedEvent).filter(
UnifiedEvent.hot_score >= min_hot,
UnifiedEvent.created_at >= time_limit,
)
total = base_query.count()
# 分页查询
if sort_by == "created_at":
base_query = base_query.order_by(UnifiedEvent.created_at.desc())
else:
base_query = base_query.order_by(UnifiedEvent.hot_score.desc(), UnifiedEvent.created_at.desc())
events = base_query.offset(skip).limit(limit).all()
if not events:
return PaginatedUnifiedEventResponse(total=total, has_more=False, data=[])
event_ids = [ev.id for ev in events]
# 批量查询所有相关的热搜条目(避免 N+1)
trend_rows = (
db.query(TrendingEvent, InfoSource.source_name)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
@@ -83,21 +107,16 @@ def list_unified_events(
.all()
)
# 按 unified_event_id 分组
trend_map: dict[int, list[tuple]] = {}
trend_map: dict[int, list[tuple[TrendingEvent, str]]] = {}
trend_ids: list[int] = []
for trend, source_name in trend_rows:
trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
trend_ids.append(trend.id)
# 批量查询排名日志(避免逐条查询)
ranking_map: dict[int, list[int]] = {}
if trend_ids:
ranking_rows = (
db.query(
RankingLog.event_id,
RankingLog.ranking_position,
)
db.query(RankingLog.event_id, RankingLog.ranking_position)
.filter(
RankingLog.event_id.in_(trend_ids),
RankingLog.observed_at >= time_limit,
@@ -108,7 +127,6 @@ def list_unified_events(
for event_id, position in ranking_rows:
ranking_map.setdefault(event_id, []).append(position)
# 批量查询标签
tag_map: dict[int, list[str]] = {}
tag_rows = (
db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
@@ -122,14 +140,13 @@ def list_unified_events(
for target_id, keyword in tag_rows:
tag_map.setdefault(target_id, []).append(keyword)
# 组装响应
results: list[UnifiedEventResponse] = []
for ev in events:
platform_list: list[PlatformTrendResponse] = []
trends_for_ev = trend_map.get(ev.id, [])
for trend, source_name in trends_for_ev:
history = ranking_map.get(trend.id, [])
# 截取尾部,只保留最近的点
if len(history) > MAX_RANKING_POINTS:
history = history[-MAX_RANKING_POINTS:]
@@ -144,17 +161,12 @@ def list_unified_events(
)
)
# 取所有关联热搜条目中最新的 updated_at,代表"最后一次在平台热搜榜看到"的时间
last_active_at = (
max(t.updated_at for t, _ in trends_for_ev)
if trends_for_ev
else ev.updated_at
)
last_active_at = max(t.updated_at for t, _ in trends_for_ev) if trends_for_ev else ev.updated_at
results.append(
UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
unified_title=ev.unified_title if ev.unified_title else "Untitled",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
@@ -167,13 +179,9 @@ def list_unified_events(
has_more = (skip + limit) < total
response = PaginatedUnifiedEventResponse(total=total, has_more=has_more, data=results)
# --- 2. 写入缓存 ---
if len(_UNIFIED_EVENTS_CACHE) > 1000:
# 防止内存无限增长
_UNIFIED_EVENTS_CACHE.clear()
_UNIFIED_EVENTS_CACHE[cache_key] = (current_time + CACHE_TTL_SECONDS, response)
# ------------------
return response
@@ -183,7 +191,7 @@ def get_unified_event(
event_id: int,
db: Session = Depends(get_db),
):
"""按 ID 查询单个统一事件,用于推荐跳转时的聚光灯展示"""
"""事件 ID 获取单个统一事件。"""
ev = db.query(UnifiedEvent).filter(UnifiedEvent.id == event_id).first()
if not ev:
raise HTTPException(status_code=404, detail="Event not found")
@@ -228,6 +236,7 @@ def get_unified_event(
history = ranking_map.get(trend.id, [])
if len(history) > MAX_RANKING_POINTS:
history = history[-MAX_RANKING_POINTS:]
platform_list.append(
PlatformTrendResponse(
source_id=trend.source_id,
@@ -239,15 +248,11 @@ def get_unified_event(
)
)
last_active_at = (
max(t.updated_at for t, _ in trend_rows)
if trend_rows
else ev.updated_at
)
last_active_at = max(t.updated_at for t, _ in trend_rows) if trend_rows else ev.updated_at
return UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
unified_title=ev.unified_title if ev.unified_title else "Untitled",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
@@ -255,3 +260,205 @@ def get_unified_event(
platforms=platform_list,
tags=tags,
)
@router.get("/search_timeline", response_model=SearchTimelineResponse)
def search_events_timeline(
keyword: str = Query(..., description="搜索关键词,支持正则表达式"),
hours: int = Query(None, ge=1, le=SEARCH_MAX_HOURS, description="查询最近多少小时的数据"),
mode: str = Query("hybrid", description="匹配模式:exact | semantic | hybrid"),
semantic_threshold: float = Query(None, ge=0.0, le=1.0, description="语义匹配相似度阈值"),
utc_offset_minutes: int | None = Query(None, ge=-840, le=840, description="客户端相对 UTC 的分钟偏移,东八区传 480"),
db: Session = Depends(get_db),
):
import re
query_text = (keyword or "").strip()
if not query_text:
raise HTTPException(status_code=400, detail="keyword cannot be empty")
if hours is None:
hours = SEARCH_DEFAULT_HOURS
if hours > SEARCH_MAX_HOURS:
hours = SEARCH_MAX_HOURS
match_mode = (mode or "hybrid").strip().lower()
if match_mode not in {"exact", "semantic", "hybrid"}:
match_mode = "hybrid"
use_regex = match_mode in {"exact", "hybrid"}
use_semantic = match_mode in {"semantic", "hybrid"}
sim_threshold = semantic_threshold if semantic_threshold is not None else SEARCH_EMBEDDING_THRESHOLD
pattern = None
if use_regex:
try:
pattern = re.compile(query_text, re.IGNORECASE)
except re.error:
pattern = re.compile(re.escape(query_text), re.IGNORECASE)
query_vec: np.ndarray | None = None
if use_semantic:
try:
query_encoded = embedder_model.encode([query_text], normalize_embeddings=True, show_progress_bar=False)
if len(query_encoded) > 0:
query_vec = np.asarray(query_encoded[0], dtype=np.float32)
except Exception:
query_vec = None
if match_mode == "semantic" and query_vec is None:
use_regex = True
if pattern is None:
pattern = re.compile(re.escape(query_text), re.IGNORECASE)
time_limit = utcnow() - timedelta(hours=hours)
date_format = "%Y-%m-%d %H:00" if hours <= 48 else "%Y-%m-%d"
display_timezone = timezone.utc
if utc_offset_minutes is not None:
display_timezone = timezone(timedelta(minutes=utc_offset_minutes))
def _bucket_label(dt: datetime) -> str:
aware_dt = _ensure_aware_datetime(dt)
return aware_dt.astimezone(display_timezone).strftime(date_format)
all_recent_unified = db.query(UnifiedEvent).filter(UnifiedEvent.created_at >= time_limit).all()
all_recent_trends = db.query(TrendingEvent).filter(TrendingEvent.created_at >= time_limit).all()
matched_event_ids: set[int] = set()
matched_trend_points: list[tuple[int, str]] = []
for ev in all_recent_unified:
text_matched = False
if use_regex and pattern is not None:
text_to_search = f"{ev.unified_title or ''} {ev.ai_comprehensive_summary or ''}"
text_matched = bool(pattern.search(text_to_search))
semantic_matched = False
if use_semantic and query_vec is not None:
ev_vec = _load_vector(ev.center_embedding)
if ev_vec is not None:
semantic_matched = float(np.dot(query_vec, ev_vec)) >= sim_threshold
if text_matched or semantic_matched:
matched_event_ids.add(ev.id)
for trend in all_recent_trends:
text_matched = False
if use_regex and pattern is not None and trend.current_headline:
text_matched = bool(pattern.search(trend.current_headline))
semantic_matched = False
if use_semantic and query_vec is not None:
trend_vec = _load_vector(trend.title_embedding)
if trend_vec is not None:
semantic_matched = float(np.dot(query_vec, trend_vec)) >= sim_threshold
if (text_matched or semantic_matched) and trend.unified_event_id:
matched_event_ids.add(trend.unified_event_id)
matched_trend_points.append((trend.unified_event_id, _bucket_label(trend.created_at)))
matched_unified_events = [ev for ev in all_recent_unified if ev.id in matched_event_ids]
matched_unified_events.sort(key=lambda x: x.created_at, reverse=True)
display_events = matched_unified_events[:100]
display_event_ids = [ev.id for ev in display_events]
display_event_id_set = set(display_event_ids)
trend_map: dict[int, list[tuple[TrendingEvent, str]]] = {}
trend_ids: list[int] = []
if display_event_ids:
trend_rows = (
db.query(TrendingEvent, InfoSource.source_name)
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
.filter(TrendingEvent.unified_event_id.in_(display_event_ids))
.all()
)
for trend, source_name in trend_rows:
trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
trend_ids.append(trend.id)
ranking_map: dict[int, list[int]] = {}
if trend_ids:
ranking_rows = (
db.query(RankingLog.event_id, RankingLog.ranking_position)
.filter(
RankingLog.event_id.in_(trend_ids),
RankingLog.observed_at >= time_limit,
)
.order_by(RankingLog.event_id, RankingLog.observed_at.asc())
.all()
)
for event_id, position in ranking_rows:
ranking_map.setdefault(event_id, []).append(position)
timeline_event_map: dict[str, set[int]] = {}
for ev in display_events:
timeline_event_map.setdefault(_bucket_label(ev.created_at), set()).add(ev.id)
for event_id, time_label in matched_trend_points:
if event_id in display_event_id_set:
timeline_event_map.setdefault(time_label, set()).add(event_id)
timeline_data = [
TimelineDataPoint(time_label=time_label, count=len(event_ids), event_ids=sorted(event_ids))
for time_label, event_ids in sorted(timeline_event_map.items())
]
results: list[UnifiedEventResponse] = []
if display_event_ids:
tag_map: dict[int, list[str]] = {}
tag_rows = (
db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
.filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id.in_(display_event_ids),
)
.order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
.all()
)
for target_id, kw in tag_rows:
tag_map.setdefault(target_id, []).append(kw)
for ev in display_events:
trends_for_ev = trend_map.get(ev.id, [])
platform_list: list[PlatformTrendResponse] = []
seen_platforms = set()
for trend, source_name in trends_for_ev:
uniq_key = f"{source_name}_{trend.current_headline}"
if uniq_key in seen_platforms:
continue
seen_platforms.add(uniq_key)
history = ranking_map.get(trend.id, [])
if len(history) > MAX_RANKING_POINTS:
history = history[-MAX_RANKING_POINTS:]
platform_list.append(
PlatformTrendResponse(
source_id=trend.source_id,
platform_name=source_name,
headline=trend.current_headline,
url=trend.event_url,
current_ranking=trend.current_ranking,
ranking_history=history,
)
)
last_active_at = max(t.updated_at for t, _ in trends_for_ev) if trends_for_ev else ev.updated_at
results.append(
UnifiedEventResponse(
event_id=ev.id,
unified_title=ev.unified_title if ev.unified_title else "Untitled",
summary=ev.ai_comprehensive_summary,
hot_score=ev.hot_score,
created_at=ev.created_at,
last_active_at=last_active_at,
platforms=platform_list,
tags=tag_map.get(ev.id, []),
)
)
return SearchTimelineResponse(keyword=query_text, timeline=timeline_data, events=results)