# app/services/fetcher_service.py import os import hashlib from datetime import timedelta import httpx import json import numpy as np from dotenv import load_dotenv from sklearn.metrics.pairwise import cosine_similarity from sentence_transformers import SentenceTransformer from app.database import SessionLocal from app.models.models import ( InfoSource, TrendingEvent, NewsArticle, DataSyncTask, TaskStatus, HeadlineRevision, RankingLog, SourceType, utcnow, UnifiedEvent ) # 加载环境变量 load_dotenv() hf_token = os.getenv("HF_TOKEN") SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", 0.72)) API_BASE_URL = os.getenv("API_BASE_URL", "https://newsnow.busiyi.world/api/s") EMBEDDING_MODEL_PATH = os.getenv("EMBEDDING_MODEL_PATH", "") print("正在加载 BAAI/bge-m3 向量模型...") # 全局单例 embedder_model = SentenceTransformer(EMBEDDING_MODEL_PATH, local_files_only=True, device="cuda") print("模型加载完成。") def generate_md5(text: str) -> str: """生成32位MD5哈希值作为全局唯一指纹""" return hashlib.md5(text.encode('utf-8')).hexdigest() def generate_embedding_json(text: str) -> str: """辅助函数:调用大模型生成向量,并序列化为 JSON 字符串""" raw_vec = embedder_model.encode([text], normalize_embeddings=True, show_progress_bar=False)[0] truncated_vec = [round(float(x), 5) for x in raw_vec] return json.dumps(truncated_vec, separators=(',', ':')) def match_or_create_unified_event(db, title: str, embedding_json: str) -> int: """ 辅助函数:大事件聚类中枢。 拿着新计算的向量去数据库里碰,碰到了就返回老 ID,碰不到就建新的。 """ # 提取刚算出来的向量 new_vec = np.array(json.loads(embedding_json)) # 只取最近 3 天的活跃大事件进行比对 three_days_ago = utcnow() - timedelta(days=3) recent_events = db.query(UnifiedEvent).filter( UnifiedEvent.created_at >= three_days_ago ).order_by(UnifiedEvent.created_at.desc()).limit(200).all() if recent_events: valid_events = [ev for ev in recent_events if ev.center_embedding] if valid_events: event_vectors = [json.loads(ev.center_embedding) for ev in valid_events] # 批量矩阵计算相似度 sim_scores = cosine_similarity([new_vec], event_vectors)[0] max_idx = np.argmax(sim_scores) if sim_scores[max_idx] >= SIMILARITY_THRESHOLD: matched_event = valid_events[max_idx] matched_event.hot_score += 1 return matched_event.id # 没匹配到,创建一个新的统一大事件 new_unified = UnifiedEvent( unified_title=title, center_embedding=embedding_json, hot_score=1 # 初始热度 ) db.add(new_unified) db.flush() # 获取自增的主键 ID return new_unified.id def process_hot_trend_item(db, source, item, index: int, external_id: str): """ 处理【热搜/短新闻】的业务逻辑,现已加入 AI 聚类功能 """ title = item.get("title") item_url = item.get("url", "") existing_event = db.query(TrendingEvent).filter( TrendingEvent.source_id == source.id, TrendingEvent.external_id == external_id ).first() event_to_log = None # 核心逻辑:查重后再决定是否调用模型 if existing_event: # 场景 A1:老熟人 if existing_event.current_headline != title: # 标题被暗改,此时需要重新算一次 Embedding new_embedding_json = generate_embedding_json(title) revision = HeadlineRevision( event_id=existing_event.id, previous_headline=existing_event.current_headline, revised_headline=title ) db.add(revision) existing_event.current_headline = title existing_event.title_embedding = new_embedding_json # 更新为新标题的语义向量 # 注:这里不改变它所属的 unified_event_id,因为大体还是同一件事 existing_event.current_ranking = index existing_event.event_url = item_url event_to_log = existing_event else: # 场景 A2:这是一条彻底的全新热搜 # 1. 计算向量 new_embedding_json = generate_embedding_json(title) # 2. 扔进聚类中枢找归宿 matched_event_id = match_or_create_unified_event(db, title, new_embedding_json) # 3. 落库 new_event = TrendingEvent( source_id=source.id, external_id=external_id, current_headline=title, event_url=item_url, current_ranking=index, title_embedding=new_embedding_json, # 存入向量 unified_event_id=matched_event_id # 挂载到大事件下 ) db.add(new_event) db.flush() event_to_log = new_event # 强制记录排名轨迹 rank_log = RankingLog( event_id=event_to_log.id, ranking_position=index ) db.add(rank_log) def process_rss_feed_item(db, source, item, external_id: str): """ 处理【长文章/传统订阅】分支的核心业务逻辑 (写入 NewsArticle 表) """ title = item.get("title") item_url = item.get("url", "") existing_article = db.query(NewsArticle).filter( NewsArticle.source_id == source.id, NewsArticle.external_id == external_id ).first() if existing_article: # 文章若存在,仅更新基础字段 existing_article.article_title = title existing_article.article_url = item_url else: # 全新文章入库 new_article = NewsArticle( source_id=source.id, external_id=external_id, article_title=title, article_url=item_url, ) db.add(new_article) def process_source_data(db, source, items: list) -> int: """ 数据清洗与路由分发层: 遍历 API 返回的 items,生成唯一指纹,并路由到不同的处理模块。 返回成功处理的条目数量。 """ saved_count = 0 platform_id = source.home_url for index, item in enumerate(items, 1): title = item.get("title") if not title: continue item_url = item.get("url", "") # ID 兜底策略:接口ID -> URL -> Title raw_id = item.get("id") or item_url or title external_id = generate_md5(f"{platform_id}_{raw_id}") # 核心路由分流 if source.source_type in (SourceType.HOT_TREND, SourceType.API): process_hot_trend_item(db, source, item, index, external_id) elif source.source_type == SourceType.RSS_FEED: process_rss_feed_item(db, source, item, external_id) saved_count += 1 return saved_count async def fetch_and_save_trending_data(): """ 调度层:负责网络请求、数据库事务管理和异常监控隔离。 """ print(f"[{utcnow()}] 开始执行定时抓取任务...") with SessionLocal() as db: # 获取启用的信息源 sources = db.query(InfoSource).filter(InfoSource.is_enabled == True).all() if not sources: print("没有找到启用的信息源,任务结束。") return # 伪装请求头,规避反爬 custom_headers = { "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36", "Accept": "application/json, text/plain, */*", "Referer": "https://newsnow.busiyi.world/", "Origin": "https://newsnow.busiyi.world" } async with httpx.AsyncClient(timeout=15.0, headers=custom_headers) as client: for source in sources: platform_id = source.home_url if not platform_id: continue url = f"{API_BASE_URL}?id={platform_id}&latest" # 初始化监控日志 task_log = DataSyncTask(source_id=source.id, items_fetched=0) try: # 发起网络请求 response = await client.get(url) response.raise_for_status() data_json = response.json() items = data_json.get("items", []) # 调用数据处理层 saved_count = process_source_data(db, source, items) # 业务事务成功提交 task_log.items_fetched = saved_count task_log.task_status = TaskStatus.SUCCESS db.add(task_log) db.commit() print(f"[{source.source_name}] ({source.source_type}) 成功抓取并更新了 {saved_count} 条数据") except Exception as e: # 异常拦截与错误隔离 db.rollback() task_log.task_status = TaskStatus.ERROR task_log.error_trace = str(e) db.add(task_log) db.commit() print(f"[{source.source_name}] 抓取失败: {e}")