Files
InsightRadar/backend/app/services/fetcher_service.py
T
2026-04-20 15:53:02 +08:00

325 lines
12 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os
import hashlib
from datetime import timedelta
import httpx
import json
import numpy as np
from dotenv import load_dotenv
from sklearn.metrics.pairwise import cosine_similarity
from sentence_transformers import SentenceTransformer
from app.database import SessionLocal
from app.models.models import (
InfoSource, TrendingEvent, NewsArticle, DataSyncTask, TaskStatus,
HeadlineRevision, RankingLog, SourceType, utcnow, UnifiedEvent
)
# AI-assisted generation: deepseek-v3-2, 2026-03-20
# Load environment variables (python-dotenv) before any os.getenv() call below.
load_dotenv()
# HF_TOKEN is read here but not referenced again in the visible part of this file.
hf_token = os.getenv("HF_TOKEN")
# Minimum cosine similarity for a headline to be merged into an existing
# UnifiedEvent (compared with ">=" in UnifiedEventClusterer.match_or_create).
SIMILARITY_THRESHOLD = float(os.getenv("SIMILARITY_THRESHOLD", 0.72))
# Base endpoint queried per source as "<API_BASE_URL>?id=<home_url>&latest".
API_BASE_URL = os.getenv("API_BASE_URL", "https://newsnow.busiyi.world/api/s")
# Path/name handed to SentenceTransformer; because local_files_only=True is
# passed below, the default "" presumably needs overriding in the environment.
EMBEDDING_MODEL_PATH = os.getenv("EMBEDDING_MODEL_PATH", "")
print("正在加载模型...")
# Module-level singleton: the embedding model is loaded once, at import time.
embedder_model = SentenceTransformer(EMBEDDING_MODEL_PATH, local_files_only=True)
print("模型加载完成。")
# End of AI-generated section
def generate_md5(text: str) -> str:
    """Return the 32-character hexadecimal MD5 digest of ``text``.

    The digest serves as an ``external_id`` fingerprint for de-duplication,
    not for any security purpose, so the hash is requested with
    ``usedforsecurity=False``. That keeps the function usable on
    FIPS-restricted Python builds, where a plain ``hashlib.md5()`` call is
    rejected. The resulting digest is unchanged.

    Args:
        text: String to fingerprint; it is encoded as UTF-8 before hashing.

    Returns:
        Lowercase hexadecimal digest, always 32 characters long.
    """
    return hashlib.md5(text.encode("utf-8"), usedforsecurity=False).hexdigest()
def generate_embeddings_batch(texts: list[str]) -> dict:
    """Encode a batch of texts and return their embeddings keyed by text.

    Duplicate texts are encoded only once. Each value is a tuple
    ``(embedding_json, vector)``:

    * ``embedding_json`` - compact JSON array of the vector's components,
      each rounded to 5 decimal places (the form stored in the database);
    * ``vector`` - the row returned by the encoder for that text, requested
      with ``normalize_embeddings=True``.

    Returns an empty dict when ``texts`` is empty.
    """
    if not texts:
        return {}

    distinct = list(set(texts))
    vectors = embedder_model.encode(
        distinct,
        normalize_embeddings=True,
        show_progress_bar=False,
    )
    return {
        sentence: (
            json.dumps([round(float(c), 5) for c in vector], separators=(',', ':')),
            vector,
        )
        for sentence, vector in zip(distinct, vectors)
    }
class UnifiedEventClusterer:
    """Group headlines into ``UnifiedEvent`` rows by embedding similarity.

    On construction the clusterer loads at most 300 ``UnifiedEvent`` rows
    created within the last three days (newest first) and caches the center
    embedding and id of every row that has one. ``match_or_create`` compares
    new headlines against that in-memory cache.
    """

    def __init__(self, db):
        """Build the in-memory cache of recent event embeddings.

        Args:
            db: Database session used for the initial query and for every
                later lookup/insert performed by ``match_or_create``.
        """
        self.db = db
        window_start = utcnow() - timedelta(days=3)
        recent_events = (
            db.query(UnifiedEvent)
            .filter(UnifiedEvent.created_at >= window_start)
            .order_by(UnifiedEvent.created_at.desc())
            .limit(300)
            .all()
        )
        # Parallel lists: event_vectors[i] is the center embedding of event_ids[i].
        self.event_vectors = []
        self.event_ids = []
        for ev in recent_events:
            # Events without a stored embedding cannot be matched; skip them.
            if ev.center_embedding:
                self.event_vectors.append(np.array(json.loads(ev.center_embedding)))
                self.event_ids.append(ev.id)

    def match_or_create(self, title: str, embedding_json: str, new_vec: np.ndarray) -> int:
        """Return the id of the unified event that ``title`` belongs to.

        If the most similar cached event reaches ``SIMILARITY_THRESHOLD``
        (cosine similarity), that event's ``hot_score`` is incremented and its
        id is returned. Otherwise a new ``UnifiedEvent`` is inserted with
        ``hot_score=1``, flushed so that it receives an id, and added to the
        cache so that later items in the same batch can match it.

        Args:
            title: Headline text; becomes ``unified_title`` of a new event.
            embedding_json: JSON form of the embedding, stored as
                ``center_embedding`` of a new event.
            new_vec: Embedding vector used for the similarity comparison.

        Returns:
            Id of the matched or newly created ``UnifiedEvent``.
        """
        if self.event_vectors:
            # One matrix operation scores the new vector against every cached event.
            sim_scores = cosine_similarity([new_vec], self.event_vectors)[0]
            max_idx = int(np.argmax(sim_scores))
            if sim_scores[max_idx] >= SIMILARITY_THRESHOLD:
                matched_event_id = self.event_ids[max_idx]
                # Session.get() replaces the legacy Query.get() API.
                matched_event = self.db.get(UnifiedEvent, matched_event_id)
                if matched_event:
                    # Tolerate a NULL hot_score instead of raising TypeError.
                    matched_event.hot_score = (matched_event.hot_score or 0) + 1
                # The cached id is returned even if the row could not be loaded.
                return matched_event_id

        # No sufficiently similar event: start a new unified event.
        new_unified = UnifiedEvent(
            unified_title=title,
            center_embedding=embedding_json,
            hot_score=1,
        )
        self.db.add(new_unified)
        self.db.flush()  # assigns new_unified.id

        # Keep the cache in sync for the remainder of this batch.
        self.event_vectors.append(new_vec)
        self.event_ids.append(new_unified.id)
        return new_unified.id
def process_hot_trend_item(db, source, item, index: int, external_id: str, existing_event, embeddings_dict: dict, clusterer: UnifiedEventClusterer):
    """Persist one hot-trend / short-news item, including semantic clustering.

    * Unknown item (``existing_event`` is falsy): the headline is assigned to
      a unified event through ``clusterer`` and a new ``TrendingEvent`` is
      inserted and flushed.
    * Known item: a ``HeadlineRevision`` is recorded and the stored headline
      and embedding are replaced when the headline changed; ranking and URL
      are refreshed in either case.

    A ``RankingLog`` entry with the current position is always added.

    ``embeddings_dict`` must contain ``item["title"]`` whenever the item is
    new or its headline changed; values are ``(embedding_json, vector)``.
    """
    headline = item.get("title")
    link = item.get("url", "")

    if not existing_event:
        embedding_json, vector = embeddings_dict[headline]
        unified_id = clusterer.match_or_create(headline, embedding_json, vector)
        tracked_event = TrendingEvent(
            source_id=source.id,
            external_id=external_id,
            current_headline=headline,
            event_url=link,
            current_ranking=index,
            title_embedding=embedding_json,
            unified_event_id=unified_id,
        )
        db.add(tracked_event)
        # Flush so the new row has an id for the ranking log below.
        db.flush()
    else:
        tracked_event = existing_event
        if tracked_event.current_headline != headline:
            embedding_json = embeddings_dict[headline][0]
            db.add(
                HeadlineRevision(
                    event_id=tracked_event.id,
                    previous_headline=tracked_event.current_headline,
                    revised_headline=headline,
                )
            )
            tracked_event.current_headline = headline
            tracked_event.title_embedding = embedding_json
        tracked_event.current_ranking = index
        tracked_event.event_url = link

    # Always record the ranking trajectory.
    db.add(RankingLog(event_id=tracked_event.id, ranking_position=index))
def process_rss_feed_item(db, source, item, external_id: str, existing_article):
    """Persist one long-form / feed item in the ``NewsArticle`` table.

    When ``existing_article`` is given, only its title and URL are refreshed;
    otherwise a new ``NewsArticle`` is added to the session.
    """
    headline = item.get("title")
    link = item.get("url", "")

    if not existing_article:
        # Brand-new article: insert it.
        db.add(
            NewsArticle(
                source_id=source.id,
                external_id=external_id,
                article_title=headline,
                article_url=link,
            )
        )
        return

    # Known article: update the basic fields only.
    existing_article.article_title = headline
    existing_article.article_url = link
def _fingerprint_items(platform_id, items) -> list:
    """Pair each titled item with its MD5 fingerprint, keeping the first item per fingerprint."""
    seen_ids = set()
    pairs = []
    for item in items or []:  # a JSON ``"items": null`` is treated as empty
        title = item.get("title")
        if not title:
            continue
        raw_id = item.get("id") or item.get("url", "") or title
        external_id = generate_md5(f"{platform_id}_{raw_id}")
        # A repeated fingerprint would miss the "existing" lookup twice and be
        # inserted twice, so only the first occurrence is kept.
        if external_id in seen_ids:
            continue
        seen_ids.add(external_id)
        pairs.append((item, external_id))
    return pairs


def process_source_data(db, source, items: list) -> int:
    """Clean one source's fetched items and route them to the right handler.

    Items without a title are dropped. Every remaining item receives a
    fingerprint ``md5("<source.home_url>_<id | url | title>")``; items whose
    fingerprint repeats within the batch are dropped as well. Existing rows
    are then looked up with a single ``IN`` query and, for hot-trend/API
    sources, all required embeddings are computed in a single batch before
    the per-item handlers run.

    Args:
        db: Database session; this function adds and flushes but never commits.
        source: ``InfoSource`` row; ``id``, ``home_url`` and ``source_type``
            are read.
        items: Raw item dicts from the API response (``None`` counts as empty).

    Returns:
        Number of items handed to a handler. ``0`` when nothing is valid or
        the source type is neither hot-trend/API nor RSS feed.
    """
    valid_items = _fingerprint_items(source.home_url, items)
    if not valid_items:
        return 0
    external_ids = [external_id for _, external_id in valid_items]

    if source.source_type in (SourceType.HOT_TREND, SourceType.API):
        known_events = {
            ev.external_id: ev
            for ev in db.query(TrendingEvent).filter(
                TrendingEvent.source_id == source.id,
                TrendingEvent.external_id.in_(external_ids),
            ).all()
        }

        # Only new items and items whose headline changed need an embedding.
        texts_to_embed = []
        for item, external_id in valid_items:
            title = item.get("title")
            known = known_events.get(external_id)
            if not known or known.current_headline != title:
                texts_to_embed.append(title)
        embeddings_dict = generate_embeddings_batch(texts_to_embed)

        clusterer = UnifiedEventClusterer(db)
        # The 1-based position within the valid items is the ranking.
        for index, (item, external_id) in enumerate(valid_items, 1):
            process_hot_trend_item(
                db, source, item, index, external_id,
                known_events.get(external_id), embeddings_dict, clusterer
            )
        return len(valid_items)

    if source.source_type == SourceType.RSS_FEED:
        known_articles = {
            art.external_id: art
            for art in db.query(NewsArticle).filter(
                NewsArticle.source_id == source.id,
                NewsArticle.external_id.in_(external_ids),
            ).all()
        }
        for item, external_id in valid_items:
            process_rss_feed_item(
                db, source, item, external_id, known_articles.get(external_id)
            )
        return len(valid_items)

    # Unsupported source type: nothing was written, so report nothing saved.
    return 0
async def fetch_and_save_trending_data():
    """Scheduling layer: fetch every enabled source and persist the result.

    Handles the network requests, the database transaction boundaries and
    per-source failure isolation:

    1. Read the enabled ``InfoSource`` rows in a short session and copy the
       needed fields into plain dicts.
    2. For each source with a non-empty ``home_url``, request
       ``<API_BASE_URL>?id=<home_url>&latest`` and pass the returned
       ``items`` to ``process_source_data`` inside a fresh session, then
       commit together with a SUCCESS ``DataSyncTask`` row.
    3. On any exception for a source, write an ERROR ``DataSyncTask`` row in
       a separate session and continue with the next source.
    """
    print(f"[{utcnow()}] 开始执行定时抓取任务...")

    with SessionLocal() as db:
        enabled_sources = db.query(InfoSource).filter(InfoSource.is_enabled == True).all()  # noqa: E712
        if not enabled_sources:
            print("没有找到启用的信息源,任务结束。")
            return
        # Plain dicts stay usable after this short-lived session is closed.
        source_configs = [
            dict(
                id=src.id,
                home_url=src.home_url,
                source_name=src.source_name,
                source_type=src.source_type,
            )
            for src in enabled_sources
        ]

    request_headers = {
        "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/145.0.0.0 Safari/537.36",
        "Accept": "application/json, text/plain, */*",
        "Referer": "https://newsnow.busiyi.world/",
        "Origin": "https://newsnow.busiyi.world",
    }

    async with httpx.AsyncClient(timeout=15.0, headers=request_headers) as client:
        for cfg in source_configs:
            platform_id = cfg["home_url"]
            if not platform_id:
                continue

            url = f"{API_BASE_URL}?id={platform_id}&latest"
            try:
                response = await client.get(url)
                response.raise_for_status()
                items = response.json().get("items", [])

                with SessionLocal() as db:
                    # Re-load the source in this short session so that it is not detached.
                    source = db.query(InfoSource).get(cfg["id"])
                    if not source:
                        continue

                    task_log = DataSyncTask(source_id=source.id, items_fetched=0)
                    try:
                        saved_count = process_source_data(db, source, items)
                        task_log.items_fetched = saved_count
                        task_log.task_status = TaskStatus.SUCCESS
                        db.add(task_log)
                        db.commit()
                        print(f"[{source.source_name}] ({source.source_type}) 成功抓取并更新了 {saved_count} 条数据")
                    except Exception as e:
                        # Undo the partial batch, then let the outer handler log the failure.
                        db.rollback()
                        raise e
            except Exception as e:
                # The failure record is written in its own session, independent
                # of the rolled-back one.
                with SessionLocal() as log_db:
                    try:
                        failure_log = DataSyncTask(source_id=cfg["id"], items_fetched=0)
                        failure_log.task_status = TaskStatus.ERROR
                        failure_log.error_trace = str(e)
                        log_db.add(failure_log)
                        log_db.commit()
                        print(f"[{cfg['source_name']}] 抓取失败: {e}")
                    except Exception as inner_e:
                        log_db.rollback()
                        print(f"[{cfg['source_name']}] 抓取失败,且日志写入失败: {e}, {inner_e}")