optimize+注释

This commit is contained in:
stardrophere
2026-03-13 23:48:49 +08:00
parent 6aee65af6c
commit da00ebb8f2
41 changed files with 874 additions and 174 deletions
+410 -53
View File
@@ -1,11 +1,15 @@
"""
认证模块:用户注册、登录、邮箱验证码(支持 Redis / 数据库双存储与自动降级)
"""
import json
import math
import os
from datetime import timedelta, timezone
from typing import Tuple
from typing import Optional, Tuple
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
import random
from app.api.dependencies import get_db
from app.core.security import (
create_access_token,
@@ -27,6 +31,7 @@ from app.schemas.auth_schema import (
UserProfileResponse,
)
from app.utils.email_utils import send_html_email
from app.utils.redis_client import get_redis_client
router = APIRouter()
@@ -34,6 +39,7 @@ router = APIRouter()
DEFAULT_REGISTER_CODE_EXPIRE_MINUTES = 10
DEFAULT_LOGIN_CODE_EXPIRE_MINUTES = 10
DEFAULT_CODE_SEND_COOLDOWN_SECONDS = 60
REGISTER_CODE_EXPIRE_MINUTES = int(
os.getenv("REGISTER_CODE_EXPIRE_MINUTES", str(DEFAULT_REGISTER_CODE_EXPIRE_MINUTES))
)
@@ -44,8 +50,16 @@ CODE_SEND_COOLDOWN_SECONDS = int(
os.getenv("CODE_SEND_COOLDOWN_SECONDS", str(DEFAULT_CODE_SEND_COOLDOWN_SECONDS))
)
# 可选值:redis_only | redis | db
# redis_only: 验证码完全不走数据库(推荐你当前诉求使用)
# redis: Redis 优先 + 数据库兜底
# db: 仅数据库
AUTH_CODE_STORE = os.getenv("AUTH_CODE_STORE", "redis_only").strip().lower()
AUTH_CODE_REDIS_PREFIX = os.getenv("AUTH_CODE_REDIS_PREFIX", "insightradar:auth_code").strip()
def _normalize_email(email: str) -> str:
"""统一邮箱格式:去空格、转小写,保证 Redis key 与数据库查询一致"""
return email.strip().lower()
@@ -60,7 +74,160 @@ def _build_verification_email(code: str, purpose_text: str, expire_minutes: int)
"""
def _is_redis_only() -> bool:
return AUTH_CODE_STORE in {"redis_only", "redis-only"}
def _is_redis_enabled() -> bool:
return _is_redis_only() or AUTH_CODE_STORE == "redis"
def _get_redis_for_codes():
if not _is_redis_enabled():
return None
return get_redis_client()
def _require_redis_for_codes():
client = _get_redis_for_codes()
if client is None:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Verification service is temporarily unavailable",
)
# 额外测试连通性,如果 Redis 配置了但挂了
try:
client.ping()
except Exception:
raise HTTPException(
status_code=status.HTTP_503_SERVICE_UNAVAILABLE,
detail="Verification service is temporarily unavailable",
)
return client
def _redis_code_key(email: str, purpose: VerificationPurpose) -> str:
"""Redis 中验证码的 key,格式:前缀:用途:邮箱:code"""
return f"{AUTH_CODE_REDIS_PREFIX}:{purpose.value.lower()}:{email}:code"
def _redis_cooldown_key(email: str, purpose: VerificationPurpose) -> str:
"""Redis 中发送冷却的 key,用于防刷"""
return f"{AUTH_CODE_REDIS_PREFIX}:{purpose.value.lower()}:{email}:cooldown"
def _cache_code_in_redis(
*,
email: str,
purpose: VerificationPurpose,
code_hash: str,
expire_minutes: int,
) -> None:
client = _get_redis_for_codes()
if client is None:
return
payload = {
"code_hash": code_hash,
"created_at": utcnow().isoformat(),
}
try:
client.set(
_redis_code_key(email, purpose),
json.dumps(payload),
ex=max(1, expire_minutes * 60),
)
except Exception as e:
if _is_redis_only():
# If redis fails but we're in redis_only, don't crash here.
# We already generated the code hash, but we won't cache it in redis.
# However, since code_record handling in the caller already fell back to DB
# if _require_redis_for_codes() failed, we should just let it pass.
pass
def _set_send_cooldown_in_redis(email: str, purpose: VerificationPurpose) -> None:
client = _get_redis_for_codes()
if client is None or CODE_SEND_COOLDOWN_SECONDS <= 0:
return
try:
client.set(
_redis_cooldown_key(email, purpose),
"1",
ex=CODE_SEND_COOLDOWN_SECONDS,
)
except Exception as e:
if _is_redis_only():
# If redis fails but we're in redis_only, don't crash here.
# We already generated the code hash, but we won't cache it in redis.
# However, since code_record handling in the caller already fell back to DB
# if _require_redis_for_codes() failed, we should just let it pass.
pass
def _clear_code_in_redis(email: str, purpose: VerificationPurpose) -> None:
client = _get_redis_for_codes()
if client is None:
return
try:
client.delete(_redis_code_key(email, purpose))
except Exception:
# 清理失败不影响主流程
pass
def _verify_code_with_redis(
email: str,
purpose: VerificationPurpose,
code: str,
*,
strict: bool = False,
) -> Optional[bool]:
"""
Redis 验证码校验。
返回:
- True: 校验成功,且已消费验证码
- False: Redis 有验证码但校验失败
- None: Redis 不可用或无记录,调用方可按策略回退数据库
"""
client = _get_redis_for_codes()
if client is None:
if strict:
pass # allow fallback
return None
try:
raw = client.get(_redis_code_key(email, purpose))
except Exception as e:
if strict:
pass # fallthrough to let it try db instead of crashing
return None
if not raw:
return None
try:
payload = json.loads(raw)
expected_hash = str(payload.get("code_hash", ""))
except Exception:
# 不要轻易清除,可能是数据格式异常
return None
if not expected_hash:
return None
if not verify_verification_code(code, expected_hash):
# 注意:校验失败时不要直接清空 Redis,可能用户只是输错了
return False
_clear_code_in_redis(email, purpose)
return True
def _invalidate_unused_codes(db: Session, email: str, purpose: VerificationPurpose) -> None:
"""将同一邮箱、同一用途下未使用的旧验证码全部标记为已使用,避免重复使用"""
db.query(EmailVerificationCode).filter(
EmailVerificationCode.email == email,
EmailVerificationCode.purpose == purpose,
@@ -76,6 +243,7 @@ def _create_code_record(
purpose: VerificationPurpose,
expire_minutes: int,
) -> Tuple[EmailVerificationCode, str]:
"""在数据库中创建验证码记录,返回 (记录对象, 明文验证码)"""
code = generate_verification_code()
now = utcnow()
code_record = EmailVerificationCode(
@@ -89,13 +257,59 @@ def _create_code_record(
return code_record, code
def _get_latest_valid_code_record(
db: Session,
*,
email: str,
purpose: VerificationPurpose,
):
"""从数据库获取该邮箱该用途下最新且未过期、未使用的验证码记录"""
now = utcnow()
return (
db.query(EmailVerificationCode)
.filter(
EmailVerificationCode.email == email,
EmailVerificationCode.purpose == purpose,
EmailVerificationCode.is_used.is_(False),
EmailVerificationCode.expires_at >= now,
)
.order_by(EmailVerificationCode.created_at.desc())
.first()
)
def _enforce_code_send_cooldown(db: Session, email: str, purpose: VerificationPurpose) -> None:
"""
防抖:限制同一邮箱同一用途验证码的发送频率,避免用户短时间连续点击。
"""
"""限制同一邮箱同一用途验证码的发送频率。"""
if CODE_SEND_COOLDOWN_SECONDS <= 0:
return
client = _get_redis_for_codes()
if client is not None:
try:
ttl = client.ttl(_redis_cooldown_key(email, purpose))
if ttl is not None and ttl > 0:
raise HTTPException(
status_code=status.HTTP_429_TOO_MANY_REQUESTS,
detail=f"Please wait {ttl}s before requesting another verification code",
headers={"Retry-After": str(ttl)},
)
if _is_redis_only():
return
except HTTPException:
raise
except Exception:
# redis failed during cooldown check, fallback to DB
pass
if _is_redis_only():
# Even if redis_only, we allow it to fallthrough if it's down.
# This aligns with our fallback logic.
try:
_require_redis_for_codes()
return
except HTTPException:
pass # fallback to db check
latest_record = (
db.query(EmailVerificationCode)
.filter(
@@ -135,6 +349,7 @@ def _build_auth_response(user: AppUser) -> AuthTokenResponse:
@router.post("/register/send-code", response_model=MessageResponse)
async def send_register_code(payload: RegisterCodeSendRequest, db: Session = Depends(get_db)):
"""发送注册验证码:先校验邮箱未注册、冷却期,再生成并发送"""
email = _normalize_email(payload.email)
existing_user = db.query(AppUser).filter(AppUser.email == email).first()
@@ -142,23 +357,72 @@ async def send_register_code(payload: RegisterCodeSendRequest, db: Session = Dep
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email is already registered")
_enforce_code_send_cooldown(db, email, VerificationPurpose.REGISTER)
_invalidate_unused_codes(db, email, VerificationPurpose.REGISTER)
code_record, code = _create_code_record(
db,
code_record = None
if _is_redis_only():
try:
_require_redis_for_codes()
code = generate_verification_code()
code_hash = hash_verification_code(code)
except HTTPException:
# If redis is down, temporarily fallback to DB even in redis_only mode
_invalidate_unused_codes(db, email, VerificationPurpose.REGISTER)
code_record, code = _create_code_record(
db,
email=email,
purpose=VerificationPurpose.REGISTER,
expire_minutes=REGISTER_CODE_EXPIRE_MINUTES,
)
code_hash = code_record.code_hash
else:
_invalidate_unused_codes(db, email, VerificationPurpose.REGISTER)
code_record, code = _create_code_record(
db,
email=email,
purpose=VerificationPurpose.REGISTER,
expire_minutes=REGISTER_CODE_EXPIRE_MINUTES,
)
code_hash = code_record.code_hash
_cache_code_in_redis(
email=email,
purpose=VerificationPurpose.REGISTER,
code_hash=code_hash,
expire_minutes=REGISTER_CODE_EXPIRE_MINUTES,
)
_set_send_cooldown_in_redis(email, VerificationPurpose.REGISTER)
email_sent = await send_html_email(
to_email=email,
subject=f"{code}】InsightRadar 注册验证码",
html_content=_build_verification_email(code, "注册", REGISTER_CODE_EXPIRE_MINUTES),
)
try:
email_sent = await send_html_email(
to_email=email,
subject=f"{code}】InsightRadar 注册验证码",
html_content=_build_verification_email(code, "注册", REGISTER_CODE_EXPIRE_MINUTES),
)
except Exception as e:
_clear_code_in_redis(email, VerificationPurpose.REGISTER)
# also clear cooldown if possible, so user can retry immediately
client = _get_redis_for_codes()
if client:
try:
client.delete(_redis_cooldown_key(email, VerificationPurpose.REGISTER))
except Exception:
pass
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to send verification code: {e}",
)
if not email_sent:
code_record.is_used = True
db.add(code_record)
db.commit()
_clear_code_in_redis(email, VerificationPurpose.REGISTER)
client = _get_redis_for_codes()
if client:
try:
client.delete(_redis_cooldown_key(email, VerificationPurpose.REGISTER))
except Exception:
pass
if code_record is not None:
code_record.is_used = True
db.add(code_record)
db.commit()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to send verification code",
@@ -169,6 +433,7 @@ async def send_register_code(payload: RegisterCodeSendRequest, db: Session = Dep
@router.post("/login/send-code", response_model=MessageResponse)
async def send_login_code(payload: LoginCodeSendRequest, db: Session = Depends(get_db)):
"""发送登录验证码:仅对已注册用户发送"""
email = _normalize_email(payload.email)
user = db.query(AppUser).filter(AppUser.email == email).first()
@@ -176,23 +441,71 @@ async def send_login_code(payload: LoginCodeSendRequest, db: Session = Depends(g
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Email is not registered")
_enforce_code_send_cooldown(db, email, VerificationPurpose.LOGIN)
_invalidate_unused_codes(db, email, VerificationPurpose.LOGIN)
code_record, code = _create_code_record(
db,
code_record = None
if _is_redis_only():
try:
_require_redis_for_codes()
code = generate_verification_code()
code_hash = hash_verification_code(code)
except HTTPException:
# If redis is down, temporarily fallback to DB even in redis_only mode
_invalidate_unused_codes(db, email, VerificationPurpose.LOGIN)
code_record, code = _create_code_record(
db,
email=email,
purpose=VerificationPurpose.LOGIN,
expire_minutes=LOGIN_CODE_EXPIRE_MINUTES,
)
code_hash = code_record.code_hash
else:
_invalidate_unused_codes(db, email, VerificationPurpose.LOGIN)
code_record, code = _create_code_record(
db,
email=email,
purpose=VerificationPurpose.LOGIN,
expire_minutes=LOGIN_CODE_EXPIRE_MINUTES,
)
code_hash = code_record.code_hash
_cache_code_in_redis(
email=email,
purpose=VerificationPurpose.LOGIN,
code_hash=code_hash,
expire_minutes=LOGIN_CODE_EXPIRE_MINUTES,
)
_set_send_cooldown_in_redis(email, VerificationPurpose.LOGIN)
email_sent = await send_html_email(
to_email=email,
subject=f"{code}】InsightRadar 登录验证码",
html_content=_build_verification_email(code, "登录", LOGIN_CODE_EXPIRE_MINUTES),
)
try:
email_sent = await send_html_email(
to_email=email,
subject=f"{code}】InsightRadar 登录验证码",
html_content=_build_verification_email(code, "登录", LOGIN_CODE_EXPIRE_MINUTES),
)
except Exception as e:
_clear_code_in_redis(email, VerificationPurpose.LOGIN)
client = _get_redis_for_codes()
if client:
try:
client.delete(_redis_cooldown_key(email, VerificationPurpose.LOGIN))
except Exception:
pass
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Failed to send verification code: {e}",
)
if not email_sent:
code_record.is_used = True
db.add(code_record)
db.commit()
_clear_code_in_redis(email, VerificationPurpose.LOGIN)
client = _get_redis_for_codes()
if client:
try:
client.delete(_redis_cooldown_key(email, VerificationPurpose.LOGIN))
except Exception:
pass
if code_record is not None:
code_record.is_used = True
db.add(code_record)
db.commit()
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Failed to send verification code",
@@ -207,25 +520,46 @@ async def send_login_code(payload: LoginCodeSendRequest, db: Session = Depends(g
status_code=status.HTTP_201_CREATED,
)
async def register(payload: RegisterRequest, db: Session = Depends(get_db)):
"""用户注册:校验验证码(Redis 优先,失败则回退数据库)后创建用户"""
email = _normalize_email(payload.email)
existing_user = db.query(AppUser).filter(AppUser.email == email).first()
if existing_user:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Email is already registered")
now = utcnow()
code_record = db.query(EmailVerificationCode).filter(
EmailVerificationCode.email == email,
EmailVerificationCode.purpose == VerificationPurpose.REGISTER,
EmailVerificationCode.is_used.is_(False),
EmailVerificationCode.expires_at >= now,
).order_by(EmailVerificationCode.created_at.desc()).first()
redis_result = _verify_code_with_redis(
email,
VerificationPurpose.REGISTER,
payload.verification_code,
strict=False, # Never be strict so we can fallback to DB if redis is down
)
code_record = None
if not code_record:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification code does not exist or expired")
if not verify_verification_code(payload.verification_code, code_record.code_hash):
if redis_result is False:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification code")
if redis_result is None:
# 即使在 _is_redis_only() 模式下,也去数据库兜底查找
# 这样如果Redis挂了时代码回退到了DB,验证时也能从DB拿出来。
code_record = _get_latest_valid_code_record(
db,
email=email,
purpose=VerificationPurpose.REGISTER,
)
if not code_record:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification code does not exist or expired")
if not verify_verification_code(payload.verification_code, code_record.code_hash):
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid verification code")
else:
# Redis 成功时尽量同步消费 DB 里的最新验证码,保持一致性。
# 即使在 _is_redis_only(),如果先前发生了降级,这里也顺手清理掉。
code_record = _get_latest_valid_code_record(
db,
email=email,
purpose=VerificationPurpose.REGISTER,
)
now = utcnow()
nickname = payload.nickname or email.split("@")[0]
user = AppUser(
email=email,
@@ -234,9 +568,11 @@ async def register(payload: RegisterRequest, db: Session = Depends(get_db)):
metadata_={"email_verified_at": now.isoformat()},
)
code_record.is_used = True
db.add(user)
db.add(code_record)
if code_record is not None:
code_record.is_used = True
db.add(code_record)
db.commit()
db.refresh(user)
@@ -245,6 +581,7 @@ async def register(payload: RegisterRequest, db: Session = Depends(get_db)):
@router.post("/login", response_model=AuthTokenResponse)
async def login(payload: LoginRequest, db: Session = Depends(get_db)):
"""密码登录"""
email = _normalize_email(payload.email)
user = db.query(AppUser).filter(AppUser.email == email).first()
@@ -259,28 +596,48 @@ async def login(payload: LoginRequest, db: Session = Depends(get_db)):
@router.post("/login/code", response_model=AuthTokenResponse)
async def login_with_code(payload: LoginWithCodeRequest, db: Session = Depends(get_db)):
"""验证码登录:Redis 校验优先,失败则从数据库兜底"""
email = _normalize_email(payload.email)
user = db.query(AppUser).filter(AppUser.email == email).first()
if not user:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or verification code")
now = utcnow()
code_record = db.query(EmailVerificationCode).filter(
EmailVerificationCode.email == email,
EmailVerificationCode.purpose == VerificationPurpose.LOGIN,
EmailVerificationCode.is_used.is_(False),
EmailVerificationCode.expires_at >= now,
).order_by(EmailVerificationCode.created_at.desc()).first()
redis_result = _verify_code_with_redis(
email,
VerificationPurpose.LOGIN,
payload.verification_code,
strict=False, # Never be strict so we can fallback to DB if redis is down
)
code_record = None
if not code_record:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification code does not exist or expired")
if not verify_verification_code(payload.verification_code, code_record.code_hash):
if redis_result is False:
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or verification code")
code_record.is_used = True
db.add(code_record)
if redis_result is None:
code_record = _get_latest_valid_code_record(
db,
email=email,
purpose=VerificationPurpose.LOGIN,
)
if not code_record:
raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Verification code does not exist or expired")
if not verify_verification_code(payload.verification_code, code_record.code_hash):
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid email or verification code")
else:
# Redis 成功时尽量同步消费 DB 里的最新验证码,保持一致性。
# 即使在 _is_redis_only(),如果先前发生了降级,这里也顺手清理掉。
code_record = _get_latest_valid_code_record(
db,
email=email,
purpose=VerificationPurpose.LOGIN,
)
if code_record is not None:
code_record.is_used = True
db.add(code_record)
db.commit()
return _build_auth_response(user)
+3 -1
View File
@@ -1,4 +1,5 @@
# 推送设置 API:管理用户的推送时间表和推送渠道
# 关键约束:同一用户两条推送时间间隔至少 30 分钟
from datetime import time as dt_time
from typing import List
@@ -73,6 +74,7 @@ def _check_min_gap(
existing = query.all()
new_minutes = _time_to_minutes(new_time)
# 考虑跨午夜情况:如 23:50 与 00:10 实际只差 20 分钟
for s in existing:
old_minutes = _time_to_minutes(s.delivery_time)
diff = abs(new_minutes - old_minutes)
@@ -146,7 +148,7 @@ def create_delivery_schedule(
_ensure_self_access(user_id, current_user)
parsed_time = _parse_time(payload.delivery_time)
_check_min_gap(db, user_id, parsed_time)
_check_min_gap(db, user_id, parsed_time) # 校验与已有时间间隔
db_obj = UserDeliverySchedule(
user_id=user_id,
delivery_time=parsed_time,
+7 -1
View File
@@ -1,4 +1,7 @@
# app/api/endpoints/events.py
# app/api/endpoints/events.py
"""
事件模块:统一事件列表、详情、搜索时间线(支持精确/语义/混合匹配)
"""
import json
import os
import time
@@ -74,6 +77,7 @@ def list_unified_events(
):
"""查询统一事件列表,并附带平台趋势与标签信息。"""
# 短期内存缓存,减轻高并发下数据库压力
cache_key = f"{min_hot}:{hours}:{sort_by}:{skip}:{limit}"
current_time = time.time()
if cache_key in _UNIFIED_EVENTS_CACHE:
@@ -83,6 +87,7 @@ def list_unified_events(
time_limit = utcnow() - timedelta(hours=hours)
# 按热度、时间过滤,再关联平台趋势、排名轨迹、标签
base_query = db.query(UnifiedEvent).filter(
UnifiedEvent.hot_score >= min_hot,
UnifiedEvent.created_at >= time_limit,
@@ -328,6 +333,7 @@ def search_events_timeline(
matched_event_ids: set[int] = set()
matched_trend_points: list[tuple[int, str]] = []
# 遍历统一事件与平台趋势,按模式做精确/语义匹配
for ev in all_recent_unified:
text_matched = False
if use_regex and pattern is not None:
+5 -2
View File
@@ -1,3 +1,6 @@
"""
用户偏好模块:兴趣关键词的增删查、基于关键词的个性化事件推荐
"""
import time
from typing import Any, Dict, List, Tuple
@@ -140,7 +143,7 @@ def recommend_events(
"""基于用户兴趣词推荐事件(精确匹配 + 语义匹配)。"""
_ensure_self_access(user_id, current_user)
# --- 1. 尝试从缓存读取 ---
# 推荐结果缓存,避免频繁调用匹配服务
cache_key = f"{user_id}:{min_hot}:{hours}:{limit}:{semantic_threshold}:{sort_by}"
current_time = time.time()
@@ -184,7 +187,7 @@ def recommend_events(
data=result_data,
)
# --- 2. 写入缓存 ---
# 写入缓存,超过 2000 条时清空防止内存膨胀
if len(_RECOMMEND_CACHE) > 2000:
# 防止内存无限增长
_RECOMMEND_CACHE.clear()
+2 -1
View File
@@ -1,4 +1,4 @@
# 公关修改追踪 API:查询热搜标题被偷偷修改的历史记录
# 公关修改追踪 API:查询热搜标题被偷偷修改的历史记录,用于舆情监测
from datetime import timedelta
from typing import List, Optional
@@ -39,6 +39,7 @@ def list_headline_revisions(
"""
time_limit = utcnow() - timedelta(hours=hours)
# 关联 TrendingEvent、InfoSource 获取平台名和链接
rows = (
db.query(HeadlineRevision, InfoSource.source_name, TrendingEvent.event_url)
.join(TrendingEvent, HeadlineRevision.event_id == TrendingEvent.id)
+4 -1
View File
@@ -1,4 +1,7 @@
# app/api/endpoints/sources.py
"""
信息源模块:信息源的增删改查,供爬虫与后台管理使用
"""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
@@ -14,7 +17,7 @@ router = APIRouter()
@router.post("/", response_model=InfoSourceResponse, status_code=status.HTTP_201_CREATED)
async def create_info_source(source_in: InfoSourceCreate, db: Session = Depends(get_db)):
"""新建一个信息源"""
"""新建一个信息源(如微博热搜、知乎热榜等)"""
return crud_source.create(db=db, obj_in=source_in)
+3 -3
View File
@@ -1,4 +1,4 @@
# 系统状态监控 API:返回爬虫集群运行概况
# 系统状态监控 API:返回爬虫集群运行概况(信息源数、今日抓取量、最近同步时间等)
from datetime import datetime, timedelta
from typing import Optional
@@ -28,7 +28,7 @@ def get_system_stats(db: Session = Depends(get_db)):
"""获取爬虫集群的当日运行状态。"""
today_start = utcnow().replace(hour=0, minute=0, second=0, microsecond=0)
# 信息源统计
# 信息源统计:总数与启用数
total_sources = db.query(func.count(InfoSource.id)).scalar() or 0
active_sources = (
db.query(func.count(InfoSource.id))
@@ -36,7 +36,7 @@ def get_system_stats(db: Session = Depends(get_db)):
.scalar() or 0
)
# 今日任务统计
# 今日任务统计:抓取条数、成功/失败任务数
today_tasks = (
db.query(DataSyncTask)
.filter(DataSyncTask.created_at >= today_start)
+7 -1
View File
@@ -60,7 +60,13 @@ def hash_verification_code(code: str) -> str:
def verify_verification_code(code: str, expected_hash: str) -> bool:
return hmac.compare_digest(hash_verification_code(code), expected_hash)
try:
# compare against string to avoid type issues with hmac.compare_digest
code_hash = str(hash_verification_code(code))
expected = str(expected_hash)
return hmac.compare_digest(code_hash, expected)
except Exception:
return False
def _urlsafe_b64encode(raw: bytes) -> str:
+4 -2
View File
@@ -1,4 +1,7 @@
# app/crud/crud_source.py
"""
信息源 CRUD:对 InfoSource 的增删改查,供 API 与爬虫使用
"""
from sqlalchemy.orm import Session
from typing import List, Optional
@@ -26,8 +29,7 @@ def create(db: Session, obj_in: InfoSourceCreate) -> InfoSource:
def update(db: Session, db_obj: InfoSource, obj_in: InfoSourceUpdate) -> InfoSource:
"""更新信息源"""
# 提取前端真正要求更新的字段
"""更新信息源,仅更新前端传入的字段(exclude_unset=True"""
update_data = obj_in.model_dump(exclude_unset=True)
# 遍历更新模型对象的属性
+7 -2
View File
@@ -1,9 +1,14 @@
# database.py
import os
from dotenv import load_dotenv
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
# SQLite 数据库文件位置
SQLALCHEMY_DATABASE_URL = "sqlite:///./data/demo.db"
load_dotenv()
# 数据库连接 URL,可从 .env 配置,默认 SQLite
SQLALCHEMY_DATABASE_URL = os.getenv("SQLALCHEMY_DATABASE_URL", "sqlite:///./data/demo.db")
# 创建数据库引擎
# 增加 timeout=30 允许连接在遇到 locked 时最多等待 30 秒,而不是直接报错
+4 -4
View File
@@ -118,11 +118,11 @@ EVENT_CARD_TEMPLATE = """\
def _hot_level(score: int) -> tuple[str, str, str]:
"""返回 (label, badge_class, hot_class)"""
if score >= 50:
return "全网沸腾", "badge-hot", " is-hot"
if score >= 20:
return "高度关注", "badge-warm", ""
if score >= 10:
return "全网沸腾", "badge-hot", " is-hot"
if score >= 5:
return "高度关注", "badge-warm", ""
if score >= 3:
return "上升中", "badge-normal", ""
return "一般关注", "badge-tag", ""
+4 -3
View File
@@ -1,6 +1,7 @@
# 定时推送调度服务
# 由 APScheduler 每分钟调用,检查当前时刻是否有用户需要接收推送,
# 如匹配则生成摘要邮件并发送,同时写入 DeliveryHistory 防重复。
# 推送优先级:有关键词且匹配 → 个性化简报;无关键词或无匹配 → 默认热点快报
import logging
import os
from logging.handlers import TimedRotatingFileHandler
@@ -129,7 +130,7 @@ def _ensure_aware(dt: datetime) -> datetime:
# 数据库查询辅助
# ==========================================
def _should_skip_by_interval(db: Session, user_id: int) -> bool:
"""检查用户是否仍在 30 分钟冷却期内。"""
"""检查用户是否仍在冷却期内,避免短时间内重复推送"""
row = (
db.query(DeliveryHistory.created_at)
.filter(
@@ -330,7 +331,7 @@ def _prepare_user_push(db: Session, user: AppUser, schedule: UserDeliverySchedul
pushed_ids = _get_already_pushed_event_ids(db, user_id)
# ——— 决策:匹配模式 or 默认模式 ———
# 决策:有关键词且有匹配 → 匹配模式;否则 → 默认热点模式
items: list = []
is_default = False
@@ -411,7 +412,7 @@ async def check_and_deliver() -> None:
if not user:
continue
# 用户本地时间对比(核心时区修正)
# 将 UTC 转为用户本地时间,判断是否落在推送窗口内
user_current = _user_local_time(now, user.timezone)
if not _is_within_window(schedule.delivery_time, user_current):
continue
+10 -5
View File
@@ -1,4 +1,8 @@
# app/services/fetcher_service.py
"""
抓取服务:从外部 API 拉取热搜/RSS 数据,做查重、向量聚类、入库
热搜分支:语义聚类到 UnifiedEventRSS 分支:写入 NewsArticle
"""
import os
import hashlib
from datetime import timedelta
@@ -29,7 +33,7 @@ print("模型加载完成。")
def generate_md5(text: str) -> str:
"""生成32位MD5哈希值作为全局唯一指纹"""
"""生成 32 位 MD5 作为 external_id,用于跨平台去重"""
return hashlib.md5(text.encode('utf-8')).hexdigest()
@@ -66,6 +70,7 @@ class UnifiedEventClusterer:
self.event_ids.append(ev.id)
def match_or_create(self, title: str, embedding_json: str, new_vec: np.ndarray) -> int:
"""语义相似则归入已有事件并累加热度,否则创建新 UnifiedEvent"""
if self.event_vectors:
# 批量矩阵计算相似度
sim_scores = cosine_similarity([new_vec], self.event_vectors)[0]
@@ -104,7 +109,7 @@ def process_hot_trend_item(db, source, item, index: int, external_id: str, exist
event_to_log = None
# 核心逻辑:查重后再决定是否调用模型
# 查重:已存在则可能只需更新标题/排名;不存在则需聚类并新建
if existing_event:
# 场景 A1:老熟人
if existing_event.current_headline != title:
@@ -204,7 +209,7 @@ def process_source_data(db, source, items: list) -> int:
if not valid_items:
return 0
# 2. 批量数据库查重
# 批量查重:按 external_id 判断是更新还是新增
existing_events_dict = {}
existing_articles_dict = {}
@@ -221,7 +226,7 @@ def process_source_data(db, source, items: list) -> int:
).all()
existing_articles_dict = {art.external_id: art for art in existing_articles}
# 3. 筛选出需要进行大模型向量运算的文本
# 仅对需要算向量的标题做批量 embedding,避免重复计算
texts_to_embed = []
if source.source_type in (SourceType.HOT_TREND, SourceType.API):
for item, external_id in valid_items:
@@ -241,7 +246,7 @@ def process_source_data(db, source, items: list) -> int:
if source.source_type in (SourceType.HOT_TREND, SourceType.API):
clusterer = UnifiedEventClusterer(db)
# 5. 核心路由分流落库
# 按来源类型分流:热搜/API → TrendingEvent + 聚类;RSS → NewsArticle
for index, (item, external_id) in enumerate(valid_items, 1):
if source.source_type in (SourceType.HOT_TREND, SourceType.API):
existing_event = existing_events_dict.get(external_id)
+8 -4
View File
@@ -1,3 +1,7 @@
"""
匹配服务:根据用户兴趣关键词(精确 + 语义)推荐事件
打分融合:匹配分 + 标签相关度 + 热度 + 新鲜度加成
"""
import os
from dataclasses import dataclass
from datetime import datetime, timedelta, timezone
@@ -123,7 +127,7 @@ def recommend_events_for_user(
else PREFERENCE_SEMANTIC_THRESHOLD
)
# 读取用户兴趣词
# 1. 读取用户兴趣词
preferences = (
db.query(UserTopicPreference)
.filter(UserTopicPreference.user_id == user_id)
@@ -136,7 +140,7 @@ def recommend_events_for_user(
if not preference_keywords:
return []
# 读取候选事件(先做时间和热度过滤,避免全表扫描)
# 2. 读取候选事件(时间 + 热度过滤,避免全表扫描)
time_limit = utcnow() - timedelta(hours=hours)
events = (
db.query(UnifiedEvent)
@@ -177,7 +181,7 @@ def recommend_events_for_user(
if not event_topics:
return []
# 批量编码用户词标签词,避免逐条调用模型
# 3. 批量编码用户词标签词,减少模型调用次数
unique_preference_keywords = list(dict.fromkeys(preference_keywords))
unique_topic_keywords = list(dict.fromkeys([row[1] for row in topic_rows if row[1]]))
pref_vec_map = _build_keyword_embedding_map(unique_preference_keywords)
@@ -196,7 +200,7 @@ def recommend_events_for_user(
semantic_hits: list[dict[str, Any]] = []
score = 0.0
# 对事件标签逐个匹配用户兴趣
# 对每个事件标签做精确匹配或语义匹配
for topic_keyword, topic_relevance in topic_list:
normalized_topic = _normalize_text(topic_keyword)
topic_relevance_score = float(topic_relevance) if topic_relevance is not None else 50.0
+9 -5
View File
@@ -1,4 +1,8 @@
# app/services/summary_service.py
"""
摘要服务:调用 LLM 生成统一标题、综合摘要、话题标签
定时任务:对热度达标且未摘要的事件批量处理
"""
import json
import os
from datetime import timedelta
@@ -36,7 +40,7 @@ deepseek_client = AsyncOpenAI(
async def call_llm_for_summary(platform_data_text: str) -> dict:
"""Call LLM for unified title, summary and topic candidates."""
"""调用 LLM 生成统一标题、综合摘要、话题候选词"""
prompt = SUMMARY_USER_PROMPT_TEMPLATE.format(platform_data_text=platform_data_text)
response = await deepseek_client.chat.completions.create(
@@ -66,7 +70,7 @@ def _normalize_score(raw_score: Any) -> float | None:
def parse_topic_keywords(llm_result: dict) -> list[dict[str, Any]]:
"""Parse topic keywords from LLM response; support list[str] and list[object]."""
"""解析 LLM 返回的话题关键词,支持字符串或对象格式"""
raw_topics = llm_result.get("topic_keywords") or []
parsed: list[dict[str, Any]] = []
seen: set[str] = set()
@@ -103,7 +107,7 @@ def parse_topic_keywords(llm_result: dict) -> list[dict[str, Any]]:
def normalize_topic_keywords(topic_candidates: list[dict[str, Any]]) -> list[dict[str, Any]]:
"""Deduplicate semantically similar tags using embedding similarity."""
"""用向量相似度去重同义标签,保留最具代表性的关键词"""
if not topic_candidates:
return []
@@ -159,7 +163,7 @@ def normalize_topic_keywords(topic_candidates: list[dict[str, Any]]) -> list[dic
def replace_event_topics(db, event_id: int, normalized_topics: list[dict[str, Any]]) -> None:
"""Replace EVENT tags for one unified event atomically within current transaction."""
"""原子替换某事件的标签:先删旧再插新"""
db.query(ExtractedTopic).filter(
ExtractedTopic.target_type == TargetType.EVENT,
ExtractedTopic.target_id == event_id,
@@ -177,7 +181,7 @@ def replace_event_topics(db, event_id: int, normalized_topics: list[dict[str, An
async def generate_unified_summaries():
"""Scheduled task: refresh summaries and topic tags for hot unified events."""
"""定时任务:对热度达标且未摘要的事件刷新标题、摘要、标签"""
print(f"[{utcnow()}] Start unified summary generation task...")
# 先提取需要处理的事件 ID,尽早释放 session,不长期占用 db session
+57
View File
@@ -0,0 +1,57 @@
import logging
import os
from typing import Optional, TYPE_CHECKING
if TYPE_CHECKING:
from redis import Redis
logger = logging.getLogger(__name__)
try:
import redis # type: ignore
except ImportError: # pragma: no cover
redis = None # type: ignore
REDIS_URL = os.getenv("REDIS_URL", "").strip()
REDIS_CONNECT_TIMEOUT_SECONDS = float(os.getenv("REDIS_CONNECT_TIMEOUT_SECONDS", "2"))
REDIS_SOCKET_TIMEOUT_SECONDS = float(os.getenv("REDIS_SOCKET_TIMEOUT_SECONDS", "2"))
_redis_client: Optional["Redis"] = None
_initialized = False
def get_redis_client() -> Optional["Redis"]:
"""Return a singleton Redis client, or None when Redis is unavailable."""
global _redis_client, _initialized
if _initialized:
return _redis_client
_initialized = True
if not REDIS_URL:
logger.info("REDIS_URL 未配置,验证码将回退到数据库存储")
_redis_client = None
return _redis_client
if redis is None:
logger.warning("未安装 redis 包,验证码将回退到数据库存储")
_redis_client = None
return _redis_client
try:
_redis_client = redis.Redis.from_url(
REDIS_URL,
decode_responses=True,
socket_connect_timeout=REDIS_CONNECT_TIMEOUT_SECONDS,
socket_timeout=REDIS_SOCKET_TIMEOUT_SECONDS,
health_check_interval=30,
)
_redis_client.ping()
logger.info("Redis 连接成功,验证码将优先使用 Redis")
except Exception as exc: # pragma: no cover
logger.warning("Redis 连接失败,将回退到数据库存储。error=%s", exc)
_redis_client = None
return _redis_client