diff --git a/backend/app/database.py b/backend/app/database.py index 30c2c4d..785f947 100644 --- a/backend/app/database.py +++ b/backend/app/database.py @@ -6,8 +6,10 @@ from sqlalchemy.orm import sessionmaker SQLALCHEMY_DATABASE_URL = "sqlite:///./data/demo.db" # 创建数据库引擎 +# 增加 timeout=30 允许连接在遇到 locked 时最多等待 30 秒,而不是直接报错 engine = create_engine( - SQLALCHEMY_DATABASE_URL, connect_args={"check_same_thread": False} + SQLALCHEMY_DATABASE_URL, + connect_args={"check_same_thread": False, "timeout": 30} ) diff --git a/backend/app/services/fetcher_service.py b/backend/app/services/fetcher_service.py index 85ebc69..2a25757 100644 --- a/backend/app/services/fetcher_service.py +++ b/backend/app/services/fetcher_service.py @@ -33,65 +33,75 @@ def generate_md5(text: str) -> str: return hashlib.md5(text.encode('utf-8')).hexdigest() -def generate_embedding_json(text: str) -> str: - """辅助函数:调用大模型生成向量,并序列化为 JSON 字符串""" - raw_vec = embedder_model.encode([text], normalize_embeddings=True, show_progress_bar=False)[0] - truncated_vec = [round(float(x), 5) for x in raw_vec] - return json.dumps(truncated_vec, separators=(',', ':')) +def generate_embeddings_batch(texts: list[str]) -> dict: + """批量生成向量,返回 {text: (embedding_json, numpy_array)}""" + if not texts: + return {} + + unique_texts = list(set(texts)) + raw_vecs = embedder_model.encode(unique_texts, normalize_embeddings=True, show_progress_bar=False) + + result = {} + for text, raw_vec in zip(unique_texts, raw_vecs): + truncated_vec = [round(float(x), 5) for x in raw_vec] + emb_json = json.dumps(truncated_vec, separators=(',', ':')) + result[text] = (emb_json, raw_vec) + + return result -def match_or_create_unified_event(db, title: str, embedding_json: str) -> int: - """ - 辅助函数:大事件聚类中枢。 - 拿着新计算的向量去数据库里碰,碰到了就返回老 ID,碰不到就建新的。 - """ - # 提取刚算出来的向量 - new_vec = np.array(json.loads(embedding_json)) - - # 只取最近 3 天的活跃大事件进行比对 - three_days_ago = utcnow() - timedelta(days=3) - recent_events = db.query(UnifiedEvent).filter( - UnifiedEvent.created_at >= three_days_ago - ).order_by(UnifiedEvent.created_at.desc()).limit(200).all() - - if recent_events: - valid_events = [ev for ev in recent_events if ev.center_embedding] - if valid_events: - event_vectors = [json.loads(ev.center_embedding) for ev in valid_events] +class UnifiedEventClusterer: + def __init__(self, db): + self.db = db + three_days_ago = utcnow() - timedelta(days=3) + recent_events = db.query(UnifiedEvent).filter( + UnifiedEvent.created_at >= three_days_ago + ).order_by(UnifiedEvent.created_at.desc()).limit(300).all() + + self.event_vectors = [] + self.event_ids = [] + for ev in recent_events: + if ev.center_embedding: + self.event_vectors.append(np.array(json.loads(ev.center_embedding))) + self.event_ids.append(ev.id) + def match_or_create(self, title: str, embedding_json: str, new_vec: np.ndarray) -> int: + if self.event_vectors: # 批量矩阵计算相似度 - sim_scores = cosine_similarity([new_vec], event_vectors)[0] + sim_scores = cosine_similarity([new_vec], self.event_vectors)[0] max_idx = np.argmax(sim_scores) if sim_scores[max_idx] >= SIMILARITY_THRESHOLD: - matched_event = valid_events[max_idx] + matched_event_id = self.event_ids[max_idx] + # 更新热度 + matched_event = self.db.query(UnifiedEvent).get(matched_event_id) + if matched_event: + matched_event.hot_score += 1 + return matched_event_id - matched_event.hot_score += 1 - return matched_event.id - - # 没匹配到,创建一个新的统一大事件 - new_unified = UnifiedEvent( - unified_title=title, - center_embedding=embedding_json, - hot_score=1 # 初始热度 - ) - db.add(new_unified) - db.flush() # 获取自增的主键 ID - return new_unified.id + # 没匹配到,创建一个新的统一大事件 + new_unified = UnifiedEvent( + unified_title=title, + center_embedding=embedding_json, + hot_score=1 # 初始热度 + ) + self.db.add(new_unified) + self.db.flush() # 获取自增的主键 ID + + # 更新缓存 + self.event_vectors.append(new_vec) + self.event_ids.append(new_unified.id) + + return new_unified.id -def process_hot_trend_item(db, source, item, index: int, external_id: str): +def process_hot_trend_item(db, source, item, index: int, external_id: str, existing_event, embeddings_dict: dict, clusterer: UnifiedEventClusterer): """ 处理【热搜/短新闻】的业务逻辑,现已加入 AI 聚类功能 """ title = item.get("title") item_url = item.get("url", "") - existing_event = db.query(TrendingEvent).filter( - TrendingEvent.source_id == source.id, - TrendingEvent.external_id == external_id - ).first() - event_to_log = None # 核心逻辑:查重后再决定是否调用模型 @@ -99,7 +109,7 @@ def process_hot_trend_item(db, source, item, index: int, external_id: str): # 场景 A1:老熟人 if existing_event.current_headline != title: # 标题被暗改,此时需要重新算一次 Embedding - new_embedding_json = generate_embedding_json(title) + new_embedding_json, _ = embeddings_dict[title] revision = HeadlineRevision( event_id=existing_event.id, @@ -118,10 +128,10 @@ def process_hot_trend_item(db, source, item, index: int, external_id: str): else: # 场景 A2:这是一条彻底的全新热搜 # 1. 计算向量 - new_embedding_json = generate_embedding_json(title) + new_embedding_json, new_vec = embeddings_dict[title] # 2. 扔进聚类中枢找归宿 - matched_event_id = match_or_create_unified_event(db, title, new_embedding_json) + matched_event_id = clusterer.match_or_create(title, new_embedding_json, new_vec) # 3. 落库 new_event = TrendingEvent( @@ -145,18 +155,13 @@ def process_hot_trend_item(db, source, item, index: int, external_id: str): db.add(rank_log) -def process_rss_feed_item(db, source, item, external_id: str): +def process_rss_feed_item(db, source, item, external_id: str, existing_article): """ 处理【长文章/传统订阅】分支的核心业务逻辑 (写入 NewsArticle 表) """ title = item.get("title") item_url = item.get("url", "") - existing_article = db.query(NewsArticle).filter( - NewsArticle.source_id == source.id, - NewsArticle.external_id == external_id - ).first() - if existing_article: # 文章若存在,仅更新基础字段 existing_article.article_title = title @@ -176,27 +181,77 @@ def process_source_data(db, source, items: list) -> int: """ 数据清洗与路由分发层: 遍历 API 返回的 items,生成唯一指纹,并路由到不同的处理模块。 + 采用批量查重和批量向量计算优化性能,避免数据库锁死。 返回成功处理的条目数量。 """ saved_count = 0 platform_id = source.home_url - for index, item in enumerate(items, 1): + # 1. 批量计算外部 ID 并聚合要计算的文本 + valid_items = [] + external_ids = [] + for item in items: title = item.get("title") if not title: continue - item_url = item.get("url", "") - - # ID 兜底策略:接口ID -> URL -> Title raw_id = item.get("id") or item_url or title external_id = generate_md5(f"{platform_id}_{raw_id}") + + valid_items.append((item, external_id)) + external_ids.append(external_id) + + if not valid_items: + return 0 - # 核心路由分流 + # 2. 批量数据库查重 + existing_events_dict = {} + existing_articles_dict = {} + + if source.source_type in (SourceType.HOT_TREND, SourceType.API): + existing_events = db.query(TrendingEvent).filter( + TrendingEvent.source_id == source.id, + TrendingEvent.external_id.in_(external_ids) + ).all() + existing_events_dict = {ev.external_id: ev for ev in existing_events} + elif source.source_type == SourceType.RSS_FEED: + existing_articles = db.query(NewsArticle).filter( + NewsArticle.source_id == source.id, + NewsArticle.external_id.in_(external_ids) + ).all() + existing_articles_dict = {art.external_id: art for art in existing_articles} + + # 3. 筛选出需要进行大模型向量运算的文本 + texts_to_embed = [] + if source.source_type in (SourceType.HOT_TREND, SourceType.API): + for item, external_id in valid_items: + title = item.get("title") + existing_event = existing_events_dict.get(external_id) + if existing_event: + if existing_event.current_headline != title: + texts_to_embed.append(title) + else: + texts_to_embed.append(title) + + # 4. 批量执行大模型推理 + embeddings_dict = generate_embeddings_batch(texts_to_embed) + + # 初始化聚类器(只在热搜模式下需要,且只初始化一次) + clusterer = None + if source.source_type in (SourceType.HOT_TREND, SourceType.API): + clusterer = UnifiedEventClusterer(db) + + # 5. 核心路由分流落库 + for index, (item, external_id) in enumerate(valid_items, 1): if source.source_type in (SourceType.HOT_TREND, SourceType.API): - process_hot_trend_item(db, source, item, index, external_id) + existing_event = existing_events_dict.get(external_id) + process_hot_trend_item( + db, source, item, index, external_id, + existing_event, embeddings_dict, clusterer + ) elif source.source_type == SourceType.RSS_FEED: - process_rss_feed_item(db, source, item, external_id) + existing_article = existing_articles_dict.get(external_id) + process_rss_feed_item(db, source, item, external_id, existing_article) saved_count += 1 @@ -209,55 +264,79 @@ async def fetch_and_save_trending_data(): """ print(f"[{utcnow()}] 开始执行定时抓取任务...") + # 获取启用的信息源 - 这个只读操作用一个短连接 with SessionLocal() as db: - # 获取启用的信息源 sources = db.query(InfoSource).filter(InfoSource.is_enabled == True).all() if not sources: print("没有找到启用的信息源,任务结束。") return + + # 我们把 source 的信息提前提取出来,避免在异步中长期持有 session + source_configs = [ + { + "id": s.id, + "home_url": s.home_url, + "source_name": s.source_name, + "source_type": s.source_type + } + for s in sources + ] - # 伪装请求头,规避反爬 - custom_headers = { - "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36", - "Accept": "application/json, text/plain, */*", - "Referer": "https://newsnow.busiyi.world/", - "Origin": "https://newsnow.busiyi.world" - } + # 伪装请求头,规避反爬 + custom_headers = { + "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36", + "Accept": "application/json, text/plain, */*", + "Referer": "https://newsnow.busiyi.world/", + "Origin": "https://newsnow.busiyi.world" + } - async with httpx.AsyncClient(timeout=15.0, headers=custom_headers) as client: - for source in sources: - platform_id = source.home_url - if not platform_id: - continue + async with httpx.AsyncClient(timeout=15.0, headers=custom_headers) as client: + for s_config in source_configs: + platform_id = s_config["home_url"] + if not platform_id: + continue - url = f"{API_BASE_URL}?id={platform_id}&latest" + url = f"{API_BASE_URL}?id={platform_id}&latest" - # 初始化监控日志 - task_log = DataSyncTask(source_id=source.id, items_fetched=0) + try: + # 1. 网络请求(可能耗时较长,不要包在 db session 里) + response = await client.get(url) + response.raise_for_status() + data_json = response.json() + items = data_json.get("items", []) + + # 2. 数据库事务操作(尽量短,单独使用 session) + with SessionLocal() as db: + # 重新从短 session 中获取 source 实例,以免 detached + source = db.query(InfoSource).get(s_config["id"]) + if not source: + continue + + task_log = DataSyncTask(source_id=source.id, items_fetched=0) + try: + # 调用数据处理层 + saved_count = process_source_data(db, source, items) - try: - # 发起网络请求 - response = await client.get(url) - response.raise_for_status() - data_json = response.json() - items = data_json.get("items", []) - - # 调用数据处理层 - saved_count = process_source_data(db, source, items) - - # 业务事务成功提交 - task_log.items_fetched = saved_count - task_log.task_status = TaskStatus.SUCCESS - db.add(task_log) - db.commit() - print(f"[{source.source_name}] ({source.source_type}) 成功抓取并更新了 {saved_count} 条数据") - - except Exception as e: - # 异常拦截与错误隔离 - db.rollback() - - task_log.task_status = TaskStatus.ERROR - task_log.error_trace = str(e) - db.add(task_log) - db.commit() - print(f"[{source.source_name}] 抓取失败: {e}") + # 业务事务成功提交 + task_log.items_fetched = saved_count + task_log.task_status = TaskStatus.SUCCESS + db.add(task_log) + db.commit() + print(f"[{source.source_name}] ({source.source_type}) 成功抓取并更新了 {saved_count} 条数据") + except Exception as e: + db.rollback() + raise e # 抛出给外层捕获记录日志 + + except Exception as e: + # 异常拦截与错误隔离,另起一个超短事务记录日志 + with SessionLocal() as log_db: + try: + new_task_log = DataSyncTask(source_id=s_config["id"], items_fetched=0) + new_task_log.task_status = TaskStatus.ERROR + new_task_log.error_trace = str(e) + log_db.add(new_task_log) + log_db.commit() + print(f"[{s_config['source_name']}] 抓取失败: {e}") + except Exception as inner_e: + log_db.rollback() + print(f"[{s_config['source_name']}] 抓取失败,且日志写入失败: {e}, {inner_e}") diff --git a/backend/app/services/summary_service.py b/backend/app/services/summary_service.py index 136b81e..eaa49a7 100644 --- a/backend/app/services/summary_service.py +++ b/backend/app/services/summary_service.py @@ -180,62 +180,76 @@ async def generate_unified_summaries(): """Scheduled task: refresh summaries and topic tags for hot unified events.""" print(f"[{utcnow()}] Start unified summary generation task...") + # 先提取需要处理的事件 ID,尽早释放 session,不长期占用 db session with SessionLocal() as db: recent_threshold = utcnow() - timedelta(days=3) - events = db.query(UnifiedEvent).filter( UnifiedEvent.hot_score >= HOT_SCORE_THRESHOLD, UnifiedEvent.hot_score > UnifiedEvent.last_summarized_trends_count, UnifiedEvent.created_at >= recent_threshold, ).all() - + if not events: print("No events require summary update in this round.") return + + # 复制出需要的信息,脱离 session + event_ids = [e.id for e in events] + event_hot_scores = {e.id: e.hot_score for e in events} - for event in events: + # 外层循环:针对每个 event_id 开启一个极短生命周期的 session 获取依赖数据 + for event_id in event_ids: + platform_dict: dict[str, set[str]] = {} + with SessionLocal() as db: trends = ( db.query(TrendingEvent, InfoSource.source_name) .join(InfoSource, TrendingEvent.source_id == InfoSource.id) - .filter(TrendingEvent.unified_event_id == event.id) + .filter(TrendingEvent.unified_event_id == event_id) .all() ) if not trends: continue - platform_dict: dict[str, set[str]] = {} for trend_record, source_name in trends: platform_dict.setdefault(source_name, set()).add(trend_record.current_headline) - prompt_lines = [ - f"[{platform}] {', '.join(sorted(headlines))}" - for platform, headlines in platform_dict.items() - ] - platform_data_text = "\n".join(prompt_lines) + prompt_lines = [ + f"[{platform}] {', '.join(sorted(headlines))}" + for platform, headlines in platform_dict.items() + ] + platform_data_text = "\n".join(prompt_lines) - try: - llm_result = await call_llm_for_summary(platform_data_text) + try: + # 大模型调用可能耗时几十秒,绝对不能把它包裹在数据库事务里 + llm_result = await call_llm_for_summary(platform_data_text) + + # 调用完成后,再开启一个新的极短事务,进行数据回写 + with SessionLocal() as write_db: + event = write_db.query(UnifiedEvent).get(event_id) + if not event: + continue if "unified_title" in llm_result and llm_result["unified_title"]: event.unified_title = llm_result["unified_title"] if "ai_comprehensive_summary" in llm_result and llm_result["ai_comprehensive_summary"]: event.ai_comprehensive_summary = llm_result["ai_comprehensive_summary"] - if event.hot_score >= TOPIC_TAG_MIN_HOT_SCORE: + hot_score = event_hot_scores.get(event_id, event.hot_score) + if hot_score >= TOPIC_TAG_MIN_HOT_SCORE: topic_candidates = parse_topic_keywords(llm_result) normalized_topics = normalize_topic_keywords(topic_candidates) if normalized_topics: - replace_event_topics(db, event.id, normalized_topics) + replace_event_topics(write_db, event.id, normalized_topics) event.last_summarized_trends_count = event.hot_score + write_db.commit() + print( f"Updated event {event.id} summary" f" (hot_score={event.hot_score})." ) - except Exception as exc: - print(f"Event {event.id} summary generation failed: {exc}") - continue - - db.commit() + except Exception as exc: + print(f"Event {event_id} summary generation failed: {exc}") + continue