Files
2026-03-13 23:48:49 +08:00

348 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
# app/services/fetcher_service.py
"""
抓取服务:从外部 API 拉取热搜/RSS 数据,做查重、向量聚类、入库
热搜分支:语义聚类到 UnifiedEvent;RSS 分支:写入 NewsArticle
"""
import os
import hashlib
from datetime import timedelta
import httpx
import json
import numpy as np
from dotenv import load_dotenv
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from app.database import SessionLocal
from app.models.models import (
InfoSource, TrendingEvent, NewsArticle, DataSyncTask, TaskStatus,
HeadlineRevision, RankingLog, SourceType, utcnow, UnifiedEvent
)
# Load environment variables (e.g. from a local .env file) before reading config.
load_dotenv()
# NOTE(review): not referenced anywhere else in this file as shown; kept in
# case other modules import it.
hf_token = os.getenv("HF_TOKEN")
# Minimum cosine similarity for a new headline to join an existing UnifiedEvent.
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", 0.72))
# Upstream aggregator endpoint; the platform id is appended as `?id=...`.
API_BASE_URL = os.getenv("API_BASE_URL", "https://newsnow.busiyi.world/api/s")
# Local path of the sentence-transformers model; loaded with local_files_only=True.
EMBEDDING_MODEL_PATH = os.getenv("EMBEDDING_MODEL_PATH", "")
# Torch device for the embedding model. Defaults to "cuda" (the previous
# hard-coded value); set EMBEDDING_DEVICE=cpu on machines without a GPU.
EMBEDDING_DEVICE = os.getenv("EMBEDDING_DEVICE", "cuda")
if not EMBEDDING_MODEL_PATH:
    # An empty path does not load any model, so every later encode() call would
    # fail with an obscure error. Fail fast with an actionable message instead.
    raise RuntimeError(
        "EMBEDDING_MODEL_PATH is not set; point it at a local "
        "sentence-transformers model directory."
    )
print("正在加载 BAAI/bge-m3 向量模型...")
# Module-level singleton: loaded once at import and shared by every call below.
embedder_model = SentenceTransformer(
    EMBEDDING_MODEL_PATH, local_files_only=True, device=EMBEDDING_DEVICE
)
print("模型加载完成。")
def generate_md5(text: str) -> str:
    """Return the 32-character hex MD5 digest of ``text`` encoded as UTF-8.

    Used as a stable ``external_id`` fingerprint for de-duplicating items
    across fetches; it is not used for any security purpose.
    """
    digest = hashlib.md5()
    digest.update(text.encode("utf-8"))
    return digest.hexdigest()
def generate_embeddings_batch(texts: list[str]) -> dict:
    """Encode ``texts`` with the shared model in a single call.

    Duplicate inputs are collapsed before encoding.

    Returns:
        A dict mapping each distinct text to ``(embedding_json, vector)``:
        ``embedding_json`` is a compact JSON list with every component rounded
        to 5 decimals (the form written to the database), and ``vector`` is the
        unrounded vector returned by the model. An empty input returns ``{}``
        without touching the model.
    """
    if not texts:
        return {}

    distinct_texts = list(set(texts))
    # normalize_embeddings=True asks the model for L2-normalised vectors.
    vectors = embedder_model.encode(
        distinct_texts, normalize_embeddings=True, show_progress_bar=False
    )

    def _to_compact_json(vec) -> str:
        # Round to 5 decimals to keep the stored JSON short.
        return json.dumps([round(float(x), 5) for x in vec], separators=(',', ':'))

    return {
        text: (_to_compact_json(vec), vec)
        for text, vec in zip(distinct_texts, vectors)
    }
class UnifiedEventClusterer:
    """Group hot-trend headlines into ``UnifiedEvent`` rows by embedding similarity.

    The constructor snapshots the ``center_embedding`` of up to 300 of the most
    recently created ``UnifiedEvent`` rows from the last 3 days. Each call to
    ``match_or_create`` compares a new headline vector against that snapshot,
    plus any events created by earlier calls on the same instance.
    """

    def __init__(self, db):
        """Load the recent-event embedding cache.

        Args:
            db: SQLAlchemy session used for all lookups and inserts.
        """
        self.db = db
        three_days_ago = utcnow() - timedelta(days=3)
        recent_events = db.query(UnifiedEvent).filter(
            UnifiedEvent.created_at >= three_days_ago
        ).order_by(UnifiedEvent.created_at.desc()).limit(300).all()
        # Parallel lists: event_vectors[i] is the centre embedding of event_ids[i].
        self.event_vectors = []
        self.event_ids = []
        for ev in recent_events:
            # Events without a stored embedding can never be matched; skip them.
            if ev.center_embedding:
                self.event_vectors.append(np.array(json.loads(ev.center_embedding)))
                self.event_ids.append(ev.id)

    def match_or_create(self, title: str, embedding_json: str, new_vec: np.ndarray) -> int:
        """Return the id of the UnifiedEvent this headline belongs to.

        Matching case: the most similar cached event reaches
        ``SIMILARITY_THRESHOLD`` (cosine similarity). That event's
        ``hot_score`` is incremented and its id is returned.

        New-event case: otherwise a new ``UnifiedEvent`` is added with
        ``unified_title=title``, ``center_embedding=embedding_json`` and
        ``hot_score=1``. It is flushed to obtain its id and appended to the
        cache.

        Stale entries: a cached id whose row can no longer be loaded is
        removed from the cache and the next-best candidate is tried. A stale
        cache entry therefore never produces a dangling event id.

        Args:
            title: headline text, used as the title of a newly created event.
            embedding_json: serialized embedding stored on a new event.
            new_vec: embedding vector compared against the cache.

        Returns:
            The matched or newly created ``UnifiedEvent`` id.
        """
        while self.event_vectors:
            sim_scores = cosine_similarity([new_vec], self.event_vectors)[0]
            best_idx = int(np.argmax(sim_scores))
            # Written as `not >=` so a NaN score also counts as "no match".
            if not sim_scores[best_idx] >= SIMILARITY_THRESHOLD:
                break
            matched_event_id = self.event_ids[best_idx]
            matched_event = self.db.query(UnifiedEvent).get(matched_event_id)
            if matched_event is not None:
                # `or 0` guards a NULL hot_score.
                # NOTE(review): whether the column can be NULL is not visible here.
                matched_event.hot_score = (matched_event.hot_score or 0) + 1
                return matched_event_id
            # Stale cache entry: the row can no longer be loaded. Forget it and
            # retry rather than returning an id that points nowhere.
            del self.event_vectors[best_idx]
            del self.event_ids[best_idx]

        # No sufficiently similar event: create a new unified event.
        new_unified = UnifiedEvent(
            unified_title=title,
            center_embedding=embedding_json,
            hot_score=1,  # initial heat
        )
        self.db.add(new_unified)
        self.db.flush()  # populate the auto-increment primary key
        # Make the new event matchable by later headlines in the same batch.
        self.event_vectors.append(new_vec)
        self.event_ids.append(new_unified.id)
        return new_unified.id
def process_hot_trend_item(db, source, item, index: int, external_id: str, existing_event, embeddings_dict: dict, clusterer: UnifiedEventClusterer):
    """Upsert one hot-trend / short-news item and append a ranking log entry.

    Args:
        db: SQLAlchemy session. Rows are added here (new events are also
            flushed), but nothing is committed.
        source: the InfoSource the item came from.
        item: raw API item; ``title`` and optional ``url`` are read.
        index: position of the item in the current batch, stored as its ranking.
        external_id: de-duplication fingerprint of the item.
        existing_event: the TrendingEvent already stored under ``external_id``,
            or a falsy value when the item is new.
        embeddings_dict: mapping from headline to ``(embedding_json, vector)``.
            It must contain ``title`` whenever the item is new or its headline
            changed.
        clusterer: assigns new items to a UnifiedEvent.
    """
    headline = item.get("title")
    link = item.get("url", "")

    if not existing_event:
        # Brand-new item: find (or create) its UnifiedEvent, then insert it.
        emb_json, emb_vec = embeddings_dict[headline]
        unified_id = clusterer.match_or_create(headline, emb_json, emb_vec)
        tracked = TrendingEvent(
            source_id=source.id,
            external_id=external_id,
            current_headline=headline,
            event_url=link,
            current_ranking=index,
            title_embedding=emb_json,
            unified_event_id=unified_id,
        )
        db.add(tracked)
        db.flush()  # obtain tracked.id before the ranking log references it
    else:
        tracked = existing_event
        if tracked.current_headline != headline:
            # Headline was edited upstream. Record the old text as a revision
            # and refresh the stored embedding (precomputed by the caller).
            emb_json, _ = embeddings_dict[headline]
            db.add(HeadlineRevision(
                event_id=tracked.id,
                previous_headline=tracked.current_headline,
                revised_headline=headline,
            ))
            tracked.current_headline = headline
            tracked.title_embedding = emb_json
            # unified_event_id is deliberately left as-is: still the same story.
        tracked.current_ranking = index
        tracked.event_url = link

    # Every sighting records its ranking position.
    db.add(RankingLog(event_id=tracked.id, ranking_position=index))
def process_rss_feed_item(db, source, item, external_id: str, existing_article):
    """Upsert a single RSS / long-form item into ``NewsArticle``.

    If ``existing_article`` is truthy, only its title and URL are refreshed.
    Otherwise a new ``NewsArticle`` is added to ``db``; it is not flushed or
    committed here.
    """
    headline = item.get("title")
    link = item.get("url", "")

    if not existing_article:
        db.add(NewsArticle(
            source_id=source.id,
            external_id=external_id,
            article_title=headline,
            article_url=link,
        ))
        return

    existing_article.article_title = headline
    existing_article.article_url = link
def _collect_valid_items(platform_id, items) -> list:
    """Return ``(item, external_id)`` pairs for the titled items in ``items``.

    ``external_id`` is the MD5 of ``"{platform_id}_{raw_id}"``. ``raw_id`` is
    the item's ``id``, falling back to its ``url`` and then its ``title``.
    Only the first occurrence of each ``external_id`` is kept: the
    duplicate check in ``process_source_data`` only looks at rows already in
    the DB, so a repeated id within one response would otherwise be inserted
    twice.
    """
    pairs = []
    seen_ids = set()
    for item in items or ():
        title = item.get("title")
        if not title:
            continue
        raw_id = item.get("id") or item.get("url", "") or title
        external_id = generate_md5(f"{platform_id}_{raw_id}")
        if external_id in seen_ids:
            continue
        seen_ids.add(external_id)
        pairs.append((item, external_id))
    return pairs


def process_source_data(db, source, items: list) -> int:
    """Clean, de-duplicate and route one source's API items.

    Steps:
      1. Fingerprint each titled item and drop in-batch duplicates
         (``_collect_valid_items``).
      2. Load the already-stored rows for those fingerprints in one query.
      3. For hot-trend / API sources:
         - embed every new or re-titled headline in one batch;
         - upsert each item via ``process_hot_trend_item``, sharing a single
           ``UnifiedEventClusterer``;
         - use each item's 1-based position in the de-duplicated batch as
           its ranking.
      4. For RSS sources, upsert each item via ``process_rss_feed_item``.

    Nothing is committed here; the caller owns the transaction.

    Returns:
        The number of items handed to a processor. Sources whose type is
        neither hot-trend/API nor RSS are not processed and return 0.
    """
    is_hot_trend = source.source_type in (SourceType.HOT_TREND, SourceType.API)
    is_rss = source.source_type == SourceType.RSS_FEED
    if not (is_hot_trend or is_rss):
        # No processor for this type: report 0 instead of claiming N saves.
        return 0

    valid_items = _collect_valid_items(source.home_url, items)
    if not valid_items:
        return 0
    external_ids = [external_id for _, external_id in valid_items]

    if is_rss:
        existing_articles = db.query(NewsArticle).filter(
            NewsArticle.source_id == source.id,
            NewsArticle.external_id.in_(external_ids)
        ).all()
        existing_articles_dict = {art.external_id: art for art in existing_articles}
        for item, external_id in valid_items:
            process_rss_feed_item(
                db, source, item, external_id,
                existing_articles_dict.get(external_id)
            )
        return len(valid_items)

    # Hot-trend / API branch.
    existing_events = db.query(TrendingEvent).filter(
        TrendingEvent.source_id == source.id,
        TrendingEvent.external_id.in_(external_ids)
    ).all()
    existing_events_dict = {ev.external_id: ev for ev in existing_events}

    # Only new items and items whose headline changed need an embedding.
    texts_to_embed = []
    for item, external_id in valid_items:
        title = item.get("title")
        existing_event = existing_events_dict.get(external_id)
        if not existing_event or existing_event.current_headline != title:
            texts_to_embed.append(title)
    embeddings_dict = generate_embeddings_batch(texts_to_embed)

    # One clusterer per batch so events created here are matchable by later items.
    clusterer = UnifiedEventClusterer(db)
    for index, (item, external_id) in enumerate(valid_items, 1):
        process_hot_trend_item(
            db, source, item, index, external_id,
            existing_events_dict.get(external_id), embeddings_dict, clusterer
        )
    return len(valid_items)
async def fetch_and_save_trending_data():
    """Fetch every enabled InfoSource from the upstream API and persist its items.

    This is the scheduling layer.

    - Each source gets one HTTP request, then its items are processed and
      committed in a separate short-lived DB session.
    - A failure for one source is caught and does not stop the loop.
    - Every attempted source writes a ``DataSyncTask`` row: SUCCESS with the
      processed count, or ERROR with the error text.

    NOTE(review): ``process_source_data`` (DB queries and embedding inference)
    is synchronous and runs on the event-loop thread, blocking other
    coroutines while it executes. Offloading it (e.g. ``asyncio.to_thread``)
    requires the session factory / DB driver to tolerate cross-thread use —
    confirm before changing.
    """
    print(f"[{utcnow()}] 开始执行定时抓取任务...")
    # Read the enabled sources in a short session and copy out plain values,
    # so no session or ORM instance is held across the awaits below.
    with SessionLocal() as db:
        sources = db.query(InfoSource).filter(InfoSource.is_enabled == True).all()  # noqa: E712 - SQLAlchemy needs ==
        if not sources:
            print("没有找到启用的信息源,任务结束。")
            return
        source_configs = [
            {
                "id": s.id,
                "home_url": s.home_url,
                "source_name": s.source_name,
                "source_type": s.source_type
            }
            for s in sources
        ]

    # Browser-like headers; the upstream API may reject bare clients.
    custom_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36",
        "Accept": "application/json, text/plain, */*",
        "Referer": "https://newsnow.busiyi.world/",
        "Origin": "https://newsnow.busiyi.world"
    }

    async with httpx.AsyncClient(timeout=15.0, headers=custom_headers) as client:
        for s_config in source_configs:
            platform_id = s_config["home_url"]
            if not platform_id:
                continue
            url = f"{API_BASE_URL}?id={platform_id}&latest"
            try:
                # 1. Network request, kept outside any DB session.
                response = await client.get(url)
                response.raise_for_status()
                data_json = response.json()
                items = data_json.get("items", [])

                # 2. Short DB transaction dedicated to this source.
                with SessionLocal() as db:
                    # Re-load the source in this session; the earlier instance
                    # belongs to a session that is already closed.
                    source = db.query(InfoSource).get(s_config["id"])
                    if not source:
                        continue
                    task_log = DataSyncTask(source_id=source.id, items_fetched=0)
                    try:
                        saved_count = process_source_data(db, source, items)
                        task_log.items_fetched = saved_count
                        task_log.task_status = TaskStatus.SUCCESS
                        db.add(task_log)
                        db.commit()
                        print(f"[{source.source_name}] ({source.source_type}) 成功抓取并更新了 {saved_count} 条数据")
                    except Exception:
                        db.rollback()
                        raise  # bare raise keeps the original traceback for the outer handler
            except Exception as e:
                # Top-level isolation boundary: record the failure in its own
                # short transaction. The exception type is included because
                # str(e) is empty for exceptions raised without a message.
                error_text = f"{type(e).__name__}: {e}"
                with SessionLocal() as log_db:
                    try:
                        new_task_log = DataSyncTask(source_id=s_config["id"], items_fetched=0)
                        new_task_log.task_status = TaskStatus.ERROR
                        new_task_log.error_trace = error_text
                        log_db.add(new_task_log)
                        log_db.commit()
                        print(f"[{s_config['source_name']}] 抓取失败: {error_text}")
                    except Exception as inner_e:
                        log_db.rollback()
                        print(f"[{s_config['source_name']}] 抓取失败,且日志写入失败: {error_text}, {inner_e}")