"""Fetch trending/news items from the aggregation API and persist them.

Hot-trend/API items become TrendingEvents (with headline revisions, ranking logs
and embedding-based clustering into UnifiedEvents); RSS items become NewsArticles.
"""
import hashlib
import json
import os
from datetime import timedelta

import httpx
import numpy as np
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer
from sklearn.metrics.pairwise import cosine_similarity

from app.database import SessionLocal
from app.models.models import (
    DataSyncTask,
    HeadlineRevision,
    InfoSource,
    NewsArticle,
    RankingLog,
    SourceType,
    TaskStatus,
    TrendingEvent,
    UnifiedEvent,
    utcnow,
)

# AI-assisted generation: deepseek-v3-2, 2026-03-20
# Load environment variables (from a .env file, if present).
load_dotenv()
# NOTE(review): not referenced anywhere in this module; presumably consumed
# elsewhere or left over — confirm before removing.
hf_token = os.getenv("HF_TOKEN")
# Minimum cosine similarity for a headline to join an existing UnifiedEvent.
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", 0.72))
API_BASE_URL = os.getenv("API_BASE_URL", "https://newsnow.busiyi.world/api/s")
EMBEDDING_MODEL_PATH = os.getenv("EMBEDDING_MODEL_PATH", "")

print("正在加载模型...")
# Process-wide singleton: loaded once at import time (offline only, because of
# local_files_only=True) and shared by every generate_embeddings_batch() call.
embedder_model = SentenceTransformer(EMBEDDING_MODEL_PATH, local_files_only=True)
print("模型加载完成。")
# End of AI-generated section

# Source types whose items are stored as TrendingEvents (and clustered).
_HOT_TREND_TYPES = (SourceType.HOT_TREND, SourceType.API)


def generate_md5(text: str) -> str:
    """Return the 32-character hex MD5 of *text*.

    Used as ``external_id`` so the same item can be recognised across fetches.
    """
    return hashlib.md5(text.encode('utf-8')).hexdigest()


def generate_embeddings_batch(texts: list[str]) -> dict:
    """Embed *texts* in one model call.

    Args:
        texts: Texts to embed; duplicates are embedded only once.

    Returns:
        ``{text: (embedding_json, numpy_vector)}``. ``embedding_json`` is the
        L2-normalised vector rounded to 5 decimals and serialised compactly (for
        storage); ``numpy_vector`` is the unrounded model output (for in-memory
        similarity). An empty input returns ``{}`` without touching the model.
    """
    if not texts:
        return {}
    # dict.fromkeys de-duplicates while keeping a deterministic order.
    unique_texts = list(dict.fromkeys(texts))
    raw_vecs = embedder_model.encode(
        unique_texts, normalize_embeddings=True, show_progress_bar=False
    )
    result = {}
    for text, raw_vec in zip(unique_texts, raw_vecs):
        # Rounding keeps the stored JSON small; 5 decimals is ample for cosine.
        truncated_vec = [round(float(x), 5) for x in raw_vec]
        emb_json = json.dumps(truncated_vec, separators=(',', ':'))
        result[text] = (emb_json, raw_vec)
    return result


class UnifiedEventClusterer:
    """Assign headlines to UnifiedEvent clusters by embedding similarity.

    The constructor caches the center embeddings of recent UnifiedEvents;
    match_or_create() compares a new vector against that cache and appends any
    event it creates, so later items in the same batch can join it.
    """

    def __init__(self, db, lookback_days: int = 3, max_events: int = 300):
        """Load candidate UnifiedEvents into the in-memory cache.

        Args:
            db: SQLAlchemy session used for queries, inserts and flushes.
            lookback_days: Only events created within this many days are
                candidates (default 3).
            max_events: Maximum number of most-recent candidates loaded
                (default 300).
        """
        self.db = db
        cutoff = utcnow() - timedelta(days=lookback_days)
        recent_events = (
            db.query(UnifiedEvent)
            .filter(UnifiedEvent.created_at >= cutoff)
            .order_by(UnifiedEvent.created_at.desc())
            .limit(max_events)
            .all()
        )
        # Parallel lists: event_vectors[i] is the center embedding of event_ids[i].
        self.event_vectors = []
        self.event_ids = []
        for ev in recent_events:
            if ev.center_embedding:
                self.event_vectors.append(np.array(json.loads(ev.center_embedding)))
                self.event_ids.append(ev.id)

    def match_or_create(self, title: str, embedding_json: str, new_vec: np.ndarray) -> int:
        """Return the id of the UnifiedEvent this headline belongs to.

        If the most similar cached event reaches SIMILARITY_THRESHOLD, its
        ``hot_score`` is incremented and its id returned. Otherwise a new
        UnifiedEvent titled *title* is inserted (and flushed to obtain its id),
        added to the cache, and its id returned.

        Args:
            title: Headline text, used as the title of a newly created event.
            embedding_json: Serialised embedding stored as the new event's center.
            new_vec: The same embedding as a numpy vector, used for similarity.
        """
        if self.event_vectors:
            # One vectorised call scores the new vector against every candidate.
            sim_scores = cosine_similarity([new_vec], self.event_vectors)[0]
            best_idx = int(np.argmax(sim_scores))
            if sim_scores[best_idx] >= SIMILARITY_THRESHOLD:
                matched_event_id = self.event_ids[best_idx]
                matched_event = self.db.get(UnifiedEvent, matched_event_id)
                if matched_event is not None:
                    matched_event.hot_score = (matched_event.hot_score or 0) + 1
                    return matched_event_id
                # The cached event no longer exists; returning its id would leave
                # a dangling reference, so evict it and create a fresh event.
                del self.event_vectors[best_idx]
                del self.event_ids[best_idx]

        new_unified = UnifiedEvent(
            unified_title=title,
            center_embedding=embedding_json,
            hot_score=1,
        )
        self.db.add(new_unified)
        self.db.flush()  # populate new_unified.id
        self.event_vectors.append(new_vec)
        self.event_ids.append(new_unified.id)
        return new_unified.id


def process_hot_trend_item(db, source, item, index: int, external_id: str, existing_event,
                           embeddings_dict: dict, clusterer: UnifiedEventClusterer):
    """Upsert one hot-trend/short-news item as a TrendingEvent and log its rank.

    Existing event:
        - If the headline changed, a HeadlineRevision is recorded and the
          headline/embedding are updated.
        - The ranking and URL are always refreshed.

    New item:
        - It is clustered into a UnifiedEvent via *clusterer*.
        - It is inserted as a new TrendingEvent.

    In both cases a RankingLog row is added.

    Args:
        index: 1-based position of the item within this fetch.
        embeddings_dict: Output of generate_embeddings_batch(); must contain
            *item*'s title whenever a new embedding is needed.
    """
    title = item.get("title")
    item_url = item.get("url", "")

    if existing_event:
        if existing_event.current_headline != title:
            new_embedding_json, _ = embeddings_dict[title]
            db.add(HeadlineRevision(
                event_id=existing_event.id,
                previous_headline=existing_event.current_headline,
                revised_headline=title,
            ))
            existing_event.current_headline = title
            existing_event.title_embedding = new_embedding_json
        existing_event.current_ranking = index
        existing_event.event_url = item_url
        event_to_log = existing_event
    else:
        new_embedding_json, new_vec = embeddings_dict[title]
        matched_event_id = clusterer.match_or_create(title, new_embedding_json, new_vec)
        new_event = TrendingEvent(
            source_id=source.id,
            external_id=external_id,
            current_headline=title,
            event_url=item_url,
            current_ranking=index,
            title_embedding=new_embedding_json,
            unified_event_id=matched_event_id,
        )
        db.add(new_event)
        db.flush()  # populate new_event.id for the RankingLog below
        event_to_log = new_event

    # Always record the ranking trajectory, whether or not anything else changed.
    db.add(RankingLog(event_id=event_to_log.id, ranking_position=index))


def process_rss_feed_item(db, source, item, external_id: str, existing_article):
    """Upsert one long-form/RSS item into the NewsArticle table.

    An existing article only has its title and URL refreshed; otherwise a new
    NewsArticle is added to the session.
    """
    title = item.get("title")
    item_url = item.get("url", "")
    if existing_article:
        existing_article.article_title = title
        existing_article.article_url = item_url
    else:
        db.add(NewsArticle(
            source_id=source.id,
            external_id=external_id,
            article_title=title,
            article_url=item_url,
        ))


def process_source_data(db, source, items: list) -> int:
    """Clean, fingerprint and route one source's fetched items.

    Existing rows are looked up in one query and embeddings are computed in one
    batch. Each item is then dispatched to process_hot_trend_item (HOT_TREND /
    API sources) or process_rss_feed_item (RSS_FEED sources).

    Args:
        db: SQLAlchemy session; the caller commits or rolls back.
        source: InfoSource whose ``home_url`` is the API platform id.
        items: Raw item dicts from the API response.

    Returns:
        Number of items processed. It is 0 when no item has a title or when the
        source type is not handled here.
    """
    is_hot_trend = source.source_type in _HOT_TREND_TYPES
    is_rss = source.source_type == SourceType.RSS_FEED
    if not (is_hot_trend or is_rss):
        # Nothing would be written for this type; don't report items as saved.
        return 0

    platform_id = source.home_url
    valid_items = []
    seen_ids = set()
    for item in items:
        title = item.get("title")
        if not title:
            continue
        raw_id = item.get("id") or item.get("url", "") or title
        external_id = generate_md5(f"{platform_id}_{raw_id}")
        if external_id in seen_ids:
            # Same fingerprint twice in one payload: inserting both would create
            # duplicate rows, so keep only the first occurrence.
            continue
        seen_ids.add(external_id)
        valid_items.append((item, external_id))

    if not valid_items:
        return 0
    external_ids = [external_id for _, external_id in valid_items]

    existing_events_dict = {}
    existing_articles_dict = {}
    if is_hot_trend:
        existing_events = db.query(TrendingEvent).filter(
            TrendingEvent.source_id == source.id,
            TrendingEvent.external_id.in_(external_ids),
        ).all()
        existing_events_dict = {ev.external_id: ev for ev in existing_events}
    else:
        existing_articles = db.query(NewsArticle).filter(
            NewsArticle.source_id == source.id,
            NewsArticle.external_id.in_(external_ids),
        ).all()
        existing_articles_dict = {art.external_id: art for art in existing_articles}

    # Only new events and events whose headline changed need a fresh embedding.
    texts_to_embed = []
    if is_hot_trend:
        for item, external_id in valid_items:
            title = item.get("title")
            existing_event = existing_events_dict.get(external_id)
            if existing_event is None or existing_event.current_headline != title:
                texts_to_embed.append(title)
    embeddings_dict = generate_embeddings_batch(texts_to_embed)

    clusterer = UnifiedEventClusterer(db) if is_hot_trend else None

    saved_count = 0
    for index, (item, external_id) in enumerate(valid_items, 1):
        if is_hot_trend:
            process_hot_trend_item(
                db, source, item, index, external_id,
                existing_events_dict.get(external_id), embeddings_dict, clusterer,
            )
        else:
            process_rss_feed_item(
                db, source, item, external_id, existing_articles_dict.get(external_id)
            )
        saved_count += 1
    return saved_count


async def fetch_and_save_trending_data():
    """Scheduler entry point: fetch every enabled source and persist its items.

    Each source is fetched and processed in its own short-lived DB session, so a
    failure in one source (HTTP error, bad payload, DB error) does not abort the
    others. Each failure is recorded as an ERROR DataSyncTask in a separate
    session.
    """
    print(f"[{utcnow()}] 开始执行定时抓取任务...")
    with SessionLocal() as db:
        # `== True` is required here: SQLAlchemy overloads `==` to build SQL.
        sources = db.query(InfoSource).filter(InfoSource.is_enabled == True).all()  # noqa: E712
        if not sources:
            print("没有找到启用的信息源,任务结束。")
            return
        # Copy plain values so they stay usable after this session closes.
        source_configs = [
            {
                "id": s.id,
                "home_url": s.home_url,
                "source_name": s.source_name,
                "source_type": s.source_type,
            }
            for s in sources
        ]

    custom_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36",
        "Accept": "application/json, text/plain, */*",
        "Referer": "https://newsnow.busiyi.world/",
        "Origin": "https://newsnow.busiyi.world",
    }

    async with httpx.AsyncClient(timeout=15.0, headers=custom_headers) as client:
        for s_config in source_configs:
            platform_id = s_config["home_url"]
            if not platform_id:
                continue
            url = f"{API_BASE_URL}?id={platform_id}&latest"
            try:
                response = await client.get(url)
                response.raise_for_status()
                data_json = response.json()
                # `or []` also covers an explicit `"items": null` in the payload.
                items = data_json.get("items") or []

                with SessionLocal() as db:
                    # Re-load the source in this short session to avoid a detached instance.
                    source = db.get(InfoSource, s_config["id"])
                    if not source:
                        continue
                    task_log = DataSyncTask(source_id=source.id, items_fetched=0)
                    try:
                        saved_count = process_source_data(db, source, items)
                        task_log.items_fetched = saved_count
                        task_log.task_status = TaskStatus.SUCCESS
                        db.add(task_log)
                        db.commit()
                    except Exception:
                        db.rollback()
                        raise  # bare raise preserves the original traceback
                # Use the copied config: the ORM instance is expired after commit,
                # and reading it would trigger an extra SELECT.
                print(f"[{s_config['source_name']}] ({s_config['source_type']}) 成功抓取并更新了 {saved_count} 条数据")
            except Exception as e:
                # Record the failure in a fresh session; the processing session
                # (if any) has already been rolled back.
                with SessionLocal() as log_db:
                    try:
                        new_task_log = DataSyncTask(source_id=s_config["id"], items_fetched=0)
                        new_task_log.task_status = TaskStatus.ERROR
                        new_task_log.error_trace = str(e)
                        log_db.add(new_task_log)
                        log_db.commit()
                        print(f"[{s_config['source_name']}] 抓取失败: {e}")
                    except Exception as inner_e:
                        log_db.rollback()
                        print(f"[{s_config['source_name']}] 抓取失败,且日志写入失败: {e}, {inner_e}")