mirror of
https://github.com/stardrophere/InsightRadar.git
synced 2026-06-05 23:56:36 +08:00
搜索功能加入
This commit is contained in:
@@ -1,8 +1,12 @@
|
||||
# app/api/endpoints/events.py
|
||||
# app/api/endpoints/events.py
|
||||
import json
|
||||
import os
|
||||
import time
|
||||
from datetime import timedelta
|
||||
from typing import Dict, List, Tuple
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Dict, Tuple
|
||||
|
||||
import numpy as np
|
||||
from dotenv import load_dotenv
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@@ -19,63 +23,83 @@ from app.models.models import (
|
||||
from app.schemas.event_schema import (
|
||||
PaginatedUnifiedEventResponse,
|
||||
PlatformTrendResponse,
|
||||
SearchTimelineResponse,
|
||||
TimelineDataPoint,
|
||||
UnifiedEventResponse,
|
||||
)
|
||||
from app.services.fetcher_service import embedder_model
|
||||
|
||||
load_dotenv()
|
||||
|
||||
# Search tuning knobs. Each value can be overridden through an environment
# variable of the same name; the literal below is the fallback default.

# Minimum dot-product similarity for a semantic match when the request does
# not supply its own threshold.
SEARCH_EMBEDDING_THRESHOLD = float(os.environ.get("SEARCH_EMBEDDING_THRESHOLD", "0.75"))

# NOTE(review): presumably the cap on returned search results; no use of this
# constant is visible in this part of the file -- confirm against callers.
SEARCH_MAX_LIMIT = int(os.environ.get("SEARCH_MAX_LIMIT", "30"))

# Look-back window (in hours) applied when a search request omits `hours`.
SEARCH_DEFAULT_HOURS = int(os.environ.get("SEARCH_DEFAULT_HOURS", "168"))

# Largest look-back window (in hours) a search request may ask for.
SEARCH_MAX_HOURS = int(os.environ.get("SEARCH_MAX_HOURS", "168"))
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
# 排名轨迹最多返回多少个点,避免长时间跨度下数据过大
|
||||
# 排名轨迹最多返回的点数,避免时间跨度过大时响应体过重。
|
||||
MAX_RANKING_POINTS = 30
|
||||
|
||||
# --- 轻量级接口缓存配置 ---
|
||||
# 统一事件列表接口的短期缓存。
|
||||
_UNIFIED_EVENTS_CACHE: Dict[str, Tuple[float, PaginatedUnifiedEventResponse]] = {}
|
||||
CACHE_TTL_SECONDS = 60
|
||||
# ---------------------------
|
||||
|
||||
|
||||
def _load_vector(raw_embedding: str | None) -> np.ndarray | None:
|
||||
"""将字符串形式的向量安全解析为 numpy 数组。"""
|
||||
if not raw_embedding:
|
||||
return None
|
||||
try:
|
||||
return np.asarray(json.loads(raw_embedding), dtype=np.float32)
|
||||
except (TypeError, ValueError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
|
||||
def _ensure_aware_datetime(dt: datetime) -> datetime:
    """Return *dt* with timezone information attached.

    A value without ``tzinfo`` (as SQLite hands back) is interpreted as UTC;
    a value that already carries ``tzinfo`` is returned unchanged.
    """
    return dt if dt.tzinfo is not None else dt.replace(tzinfo=timezone.utc)
|
||||
|
||||
|
||||
@router.get("/unified", response_model=PaginatedUnifiedEventResponse)
|
||||
def list_unified_events(
|
||||
min_hot: int = Query(5, ge=0, description="热度阈值,仅返回 hot_score >= 此值的事件"),
|
||||
min_hot: int = Query(5, ge=0, description="最低热度阈值"),
|
||||
hours: int = Query(48, ge=1, le=720, description="查询最近多少小时的数据"),
|
||||
sort_by: str = Query("hot_score", description="排序字段: hot_score | created_at"),
|
||||
sort_by: str = Query("hot_score", description="排序字段:hot_score | created_at"),
|
||||
skip: int = Query(0, ge=0, description="分页偏移量"),
|
||||
limit: int = Query(10, ge=1, le=50, description="每页返回条数"),
|
||||
limit: int = Query(10, ge=1, le=50, description="每页条数"),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""分页返回统一事件,附带各平台热搜、排名轨迹和标签。"""
|
||||
|
||||
# --- 1. 尝试从缓存读取 ---
|
||||
"""查询统一事件列表,并附带平台趋势与标签信息。"""
|
||||
|
||||
cache_key = f"{min_hot}:{hours}:{sort_by}:{skip}:{limit}"
|
||||
current_time = time.time()
|
||||
|
||||
if cache_key in _UNIFIED_EVENTS_CACHE:
|
||||
expire_time, cached_data = _UNIFIED_EVENTS_CACHE[cache_key]
|
||||
if current_time < expire_time:
|
||||
return cached_data
|
||||
# -----------------------
|
||||
|
||||
time_limit = utcnow() - timedelta(hours=hours)
|
||||
|
||||
# 先查总数,用于前端判断是否还有更多
|
||||
base_query = db.query(UnifiedEvent).filter(
|
||||
UnifiedEvent.hot_score >= min_hot,
|
||||
UnifiedEvent.created_at >= time_limit,
|
||||
)
|
||||
total = base_query.count()
|
||||
|
||||
# 分页查询
|
||||
if sort_by == "created_at":
|
||||
base_query = base_query.order_by(UnifiedEvent.created_at.desc())
|
||||
else:
|
||||
base_query = base_query.order_by(UnifiedEvent.hot_score.desc(), UnifiedEvent.created_at.desc())
|
||||
|
||||
events = base_query.offset(skip).limit(limit).all()
|
||||
|
||||
if not events:
|
||||
return PaginatedUnifiedEventResponse(total=total, has_more=False, data=[])
|
||||
|
||||
event_ids = [ev.id for ev in events]
|
||||
|
||||
# 批量查询所有相关的热搜条目(避免 N+1)
|
||||
trend_rows = (
|
||||
db.query(TrendingEvent, InfoSource.source_name)
|
||||
.join(InfoSource, TrendingEvent.source_id == InfoSource.id)
|
||||
@@ -83,21 +107,16 @@ def list_unified_events(
|
||||
.all()
|
||||
)
|
||||
|
||||
# 按 unified_event_id 分组
|
||||
trend_map: dict[int, list[tuple]] = {}
|
||||
trend_map: dict[int, list[tuple[TrendingEvent, str]]] = {}
|
||||
trend_ids: list[int] = []
|
||||
for trend, source_name in trend_rows:
|
||||
trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
|
||||
trend_ids.append(trend.id)
|
||||
|
||||
# 批量查询排名日志(避免逐条查询)
|
||||
ranking_map: dict[int, list[int]] = {}
|
||||
if trend_ids:
|
||||
ranking_rows = (
|
||||
db.query(
|
||||
RankingLog.event_id,
|
||||
RankingLog.ranking_position,
|
||||
)
|
||||
db.query(RankingLog.event_id, RankingLog.ranking_position)
|
||||
.filter(
|
||||
RankingLog.event_id.in_(trend_ids),
|
||||
RankingLog.observed_at >= time_limit,
|
||||
@@ -108,7 +127,6 @@ def list_unified_events(
|
||||
for event_id, position in ranking_rows:
|
||||
ranking_map.setdefault(event_id, []).append(position)
|
||||
|
||||
# 批量查询标签
|
||||
tag_map: dict[int, list[str]] = {}
|
||||
tag_rows = (
|
||||
db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
|
||||
@@ -122,14 +140,13 @@ def list_unified_events(
|
||||
for target_id, keyword in tag_rows:
|
||||
tag_map.setdefault(target_id, []).append(keyword)
|
||||
|
||||
# 组装响应
|
||||
results: list[UnifiedEventResponse] = []
|
||||
for ev in events:
|
||||
platform_list: list[PlatformTrendResponse] = []
|
||||
trends_for_ev = trend_map.get(ev.id, [])
|
||||
|
||||
for trend, source_name in trends_for_ev:
|
||||
history = ranking_map.get(trend.id, [])
|
||||
# 截取尾部,只保留最近的点
|
||||
if len(history) > MAX_RANKING_POINTS:
|
||||
history = history[-MAX_RANKING_POINTS:]
|
||||
|
||||
@@ -144,17 +161,12 @@ def list_unified_events(
|
||||
)
|
||||
)
|
||||
|
||||
# 取所有关联热搜条目中最新的 updated_at,代表"最后一次在平台热搜榜看到"的时间
|
||||
last_active_at = (
|
||||
max(t.updated_at for t, _ in trends_for_ev)
|
||||
if trends_for_ev
|
||||
else ev.updated_at
|
||||
)
|
||||
last_active_at = max(t.updated_at for t, _ in trends_for_ev) if trends_for_ev else ev.updated_at
|
||||
|
||||
results.append(
|
||||
UnifiedEventResponse(
|
||||
event_id=ev.id,
|
||||
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
|
||||
unified_title=ev.unified_title if ev.unified_title else "Untitled",
|
||||
summary=ev.ai_comprehensive_summary,
|
||||
hot_score=ev.hot_score,
|
||||
created_at=ev.created_at,
|
||||
@@ -167,13 +179,9 @@ def list_unified_events(
|
||||
has_more = (skip + limit) < total
|
||||
response = PaginatedUnifiedEventResponse(total=total, has_more=has_more, data=results)
|
||||
|
||||
# --- 2. 写入缓存 ---
|
||||
if len(_UNIFIED_EVENTS_CACHE) > 1000:
|
||||
# 防止内存无限增长
|
||||
_UNIFIED_EVENTS_CACHE.clear()
|
||||
|
||||
_UNIFIED_EVENTS_CACHE[cache_key] = (current_time + CACHE_TTL_SECONDS, response)
|
||||
# ------------------
|
||||
|
||||
return response
|
||||
|
||||
@@ -183,7 +191,7 @@ def get_unified_event(
|
||||
event_id: int,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""按 ID 查询单个统一事件,用于推荐跳转时的聚光灯展示。"""
|
||||
"""按事件 ID 获取单个统一事件。"""
|
||||
ev = db.query(UnifiedEvent).filter(UnifiedEvent.id == event_id).first()
|
||||
if not ev:
|
||||
raise HTTPException(status_code=404, detail="Event not found")
|
||||
@@ -228,6 +236,7 @@ def get_unified_event(
|
||||
history = ranking_map.get(trend.id, [])
|
||||
if len(history) > MAX_RANKING_POINTS:
|
||||
history = history[-MAX_RANKING_POINTS:]
|
||||
|
||||
platform_list.append(
|
||||
PlatformTrendResponse(
|
||||
source_id=trend.source_id,
|
||||
@@ -239,15 +248,11 @@ def get_unified_event(
|
||||
)
|
||||
)
|
||||
|
||||
last_active_at = (
|
||||
max(t.updated_at for t, _ in trend_rows)
|
||||
if trend_rows
|
||||
else ev.updated_at
|
||||
)
|
||||
last_active_at = max(t.updated_at for t, _ in trend_rows) if trend_rows else ev.updated_at
|
||||
|
||||
return UnifiedEventResponse(
|
||||
event_id=ev.id,
|
||||
unified_title=ev.unified_title if ev.unified_title else "暂无标题",
|
||||
unified_title=ev.unified_title if ev.unified_title else "Untitled",
|
||||
summary=ev.ai_comprehensive_summary,
|
||||
hot_score=ev.hot_score,
|
||||
created_at=ev.created_at,
|
||||
@@ -255,3 +260,205 @@ def get_unified_event(
|
||||
platforms=platform_list,
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
@router.get("/search_timeline", response_model=SearchTimelineResponse)
def search_events_timeline(
    keyword: str = Query(..., description="搜索关键词,支持正则表达式"),
    hours: int | None = Query(None, ge=1, le=SEARCH_MAX_HOURS, description="查询最近多少小时的数据"),
    mode: str = Query("hybrid", description="匹配模式:exact | semantic | hybrid"),
    semantic_threshold: float | None = Query(None, ge=0.0, le=1.0, description="语义匹配相似度阈值"),
    utc_offset_minutes: int | None = Query(None, ge=-840, le=840, description="客户端相对 UTC 的分钟偏移,东八区传 480"),
    db: Session = Depends(get_db),
):
    """Search recent events by keyword and return a timeline plus event details.

    Unified events and trending entries created within the last ``hours``
    hours are matched against ``keyword`` by case-insensitive regex
    ("exact"), by embedding similarity ("semantic"), or by either ("hybrid").

    Args:
        keyword: Search text; compiled as a regular expression, falling back
            to a literal match when it is not a valid pattern.
        hours: Look-back window. Defaults to ``SEARCH_DEFAULT_HOURS`` and is
            clamped to ``SEARCH_MAX_HOURS``.
        mode: ``exact``, ``semantic`` or ``hybrid``; any other value is
            treated as ``hybrid``.
        semantic_threshold: Minimum dot-product similarity for a semantic
            match. Defaults to ``SEARCH_EMBEDDING_THRESHOLD``.
        utc_offset_minutes: Client offset from UTC in minutes, used to label
            timeline buckets. Buckets are labelled in UTC when omitted.
        db: Database session.

    Returns:
        A ``SearchTimelineResponse`` holding the stripped keyword, timeline
        buckets (hourly when ``hours <= 48``, daily otherwise) and up to 100
        matched unified events, newest first.

    Raises:
        HTTPException: 400 when ``keyword`` is empty after stripping.
    """
    import logging
    import re

    query_text = (keyword or "").strip()
    if not query_text:
        raise HTTPException(status_code=400, detail="keyword cannot be empty")

    if hours is None:
        hours = SEARCH_DEFAULT_HOURS
    hours = min(hours, SEARCH_MAX_HOURS)

    match_mode = (mode or "hybrid").strip().lower()
    if match_mode not in {"exact", "semantic", "hybrid"}:
        match_mode = "hybrid"

    use_regex = match_mode in {"exact", "hybrid"}
    use_semantic = match_mode in {"semantic", "hybrid"}
    sim_threshold = semantic_threshold if semantic_threshold is not None else SEARCH_EMBEDDING_THRESHOLD

    # NOTE(review): the pattern comes straight from the request, so a
    # pathological expression can be slow to evaluate (ReDoS). Consider a
    # length limit or a literal-only mode if this endpoint is public.
    pattern = None
    if use_regex:
        try:
            pattern = re.compile(query_text, re.IGNORECASE)
        except re.error:
            # Not a valid regex: fall back to matching the text literally.
            pattern = re.compile(re.escape(query_text), re.IGNORECASE)

    query_vec: np.ndarray | None = None
    if use_semantic:
        try:
            query_encoded = embedder_model.encode([query_text], normalize_embeddings=True, show_progress_bar=False)
            if len(query_encoded) > 0:
                query_vec = np.asarray(query_encoded[0], dtype=np.float32)
        except Exception:
            # Best effort: the search still works through text matching, but
            # the failure is recorded instead of being dropped silently.
            logging.getLogger(__name__).exception("Failed to embed search keyword; semantic matching disabled")
            query_vec = None

    # Pure semantic mode without a usable query vector would match nothing,
    # so degrade to a literal text search.
    if match_mode == "semantic" and query_vec is None:
        use_regex = True
        if pattern is None:
            pattern = re.compile(re.escape(query_text), re.IGNORECASE)

    time_limit = utcnow() - timedelta(hours=hours)
    date_format = "%Y-%m-%d %H:00" if hours <= 48 else "%Y-%m-%d"

    display_timezone = timezone.utc
    if utc_offset_minutes is not None:
        display_timezone = timezone(timedelta(minutes=utc_offset_minutes))

    def _bucket_label(dt: datetime) -> str:
        aware_dt = _ensure_aware_datetime(dt)
        return aware_dt.astimezone(display_timezone).strftime(date_format)

    def _text_matches(text: str | None) -> bool:
        return bool(use_regex and pattern is not None and text and pattern.search(text))

    def _semantic_matches(raw_embedding: str | None) -> bool:
        if not use_semantic or query_vec is None:
            return False
        candidate = _load_vector(raw_embedding)
        # A stored embedding with a different shape (other model, corrupt
        # row) would make np.dot raise or return a non-scalar and fail the
        # whole request; treat it as "no match" instead.
        if candidate is None or candidate.ndim != 1 or candidate.shape != query_vec.shape:
            return False
        return float(np.dot(query_vec, candidate)) >= sim_threshold

    all_recent_unified = db.query(UnifiedEvent).filter(UnifiedEvent.created_at >= time_limit).all()
    all_recent_trends = db.query(TrendingEvent).filter(TrendingEvent.created_at >= time_limit).all()

    matched_event_ids: set[int] = set()
    matched_trend_points: list[tuple[int, str]] = []

    for ev in all_recent_unified:
        text_to_search = f"{ev.unified_title or ''} {ev.ai_comprehensive_summary or ''}"
        if _text_matches(text_to_search) or _semantic_matches(ev.center_embedding):
            matched_event_ids.add(ev.id)

    for trend in all_recent_trends:
        # A trend that is not linked to a unified event can never contribute
        # to the result, so skip it before doing any matching work.
        if not trend.unified_event_id:
            continue
        if _text_matches(trend.current_headline) or _semantic_matches(trend.title_embedding):
            matched_event_ids.add(trend.unified_event_id)
            matched_trend_points.append((trend.unified_event_id, _bucket_label(trend.created_at)))

    matched_unified_events = [ev for ev in all_recent_unified if ev.id in matched_event_ids]
    matched_unified_events.sort(key=lambda x: x.created_at, reverse=True)

    # NOTE(review): the cap is hard-coded to 100 here; confirm whether
    # SEARCH_MAX_LIMIT was meant to be used instead.
    display_events = matched_unified_events[:100]
    display_event_ids = [ev.id for ev in display_events]
    display_event_id_set = set(display_event_ids)

    # Load the platform entries of the displayed events in one query.
    trend_map: dict[int, list[tuple[TrendingEvent, str]]] = {}
    trend_ids: list[int] = []
    if display_event_ids:
        trend_rows = (
            db.query(TrendingEvent, InfoSource.source_name)
            .join(InfoSource, TrendingEvent.source_id == InfoSource.id)
            .filter(TrendingEvent.unified_event_id.in_(display_event_ids))
            .all()
        )
        for trend, source_name in trend_rows:
            trend_map.setdefault(trend.unified_event_id, []).append((trend, source_name))
            trend_ids.append(trend.id)

    # Ranking history per trend, oldest observation first.
    ranking_map: dict[int, list[int]] = {}
    if trend_ids:
        ranking_rows = (
            db.query(RankingLog.event_id, RankingLog.ranking_position)
            .filter(
                RankingLog.event_id.in_(trend_ids),
                RankingLog.observed_at >= time_limit,
            )
            .order_by(RankingLog.event_id, RankingLog.observed_at.asc())
            .all()
        )
        for event_id, position in ranking_rows:
            ranking_map.setdefault(event_id, []).append(position)

    # An event is counted in the bucket of its own creation time and in the
    # bucket of every matched trend that belongs to it.
    timeline_event_map: dict[str, set[int]] = {}
    for ev in display_events:
        timeline_event_map.setdefault(_bucket_label(ev.created_at), set()).add(ev.id)

    for event_id, time_label in matched_trend_points:
        if event_id in display_event_id_set:
            timeline_event_map.setdefault(time_label, set()).add(event_id)

    # Both label formats sort chronologically as plain strings.
    timeline_data = [
        TimelineDataPoint(time_label=time_label, count=len(event_ids), event_ids=sorted(event_ids))
        for time_label, event_ids in sorted(timeline_event_map.items())
    ]

    tag_map: dict[int, list[str]] = {}
    if display_event_ids:
        tag_rows = (
            db.query(ExtractedTopic.target_id, ExtractedTopic.topic_keyword)
            .filter(
                ExtractedTopic.target_type == TargetType.EVENT,
                ExtractedTopic.target_id.in_(display_event_ids),
            )
            .order_by(ExtractedTopic.relevance_score.desc(), ExtractedTopic.created_at.desc())
            .all()
        )
        for target_id, kw in tag_rows:
            tag_map.setdefault(target_id, []).append(kw)

    results: list[UnifiedEventResponse] = []
    for ev in display_events:
        trends_for_ev = trend_map.get(ev.id, [])
        platform_list: list[PlatformTrendResponse] = []
        seen_platforms = set()

        for trend, source_name in trends_for_ev:
            # Show each (platform, headline) pair once per event.
            uniq_key = f"{source_name}_{trend.current_headline}"
            if uniq_key in seen_platforms:
                continue
            seen_platforms.add(uniq_key)

            # Keep only the most recent ranking points.
            history = ranking_map.get(trend.id, [])
            if len(history) > MAX_RANKING_POINTS:
                history = history[-MAX_RANKING_POINTS:]

            platform_list.append(
                PlatformTrendResponse(
                    source_id=trend.source_id,
                    platform_name=source_name,
                    headline=trend.current_headline,
                    url=trend.event_url,
                    current_ranking=trend.current_ranking,
                    ranking_history=history,
                )
            )

        last_active_at = max(t.updated_at for t, _ in trends_for_ev) if trends_for_ev else ev.updated_at

        results.append(
            UnifiedEventResponse(
                event_id=ev.id,
                unified_title=ev.unified_title if ev.unified_title else "Untitled",
                summary=ev.ai_comprehensive_summary,
                hot_score=ev.hot_score,
                created_at=ev.created_at,
                last_active_at=last_active_at,
                platforms=platform_list,
                tags=tag_map.get(ev.id, []),
            )
        )

    return SearchTimelineResponse(keyword=query_text, timeline=timeline_data, events=results)
|
||||
|
||||
Reference in New Issue
Block a user