This commit is contained in:
stardrophere
2026-03-12 01:50:08 +08:00
parent 966bcfbba4
commit e28b893a12
7 changed files with 123 additions and 14 deletions
+38 -8
View File
@@ -37,18 +37,48 @@ def _normalize_text(text: str) -> str:
return text.strip().casefold()
_EMBEDDING_CACHE: dict[str, np.ndarray] = {}
MAX_CACHE_SIZE = 10000
def _build_keyword_embedding_map(keywords: list[str]) -> dict[str, np.ndarray]:
"""
批量生成关键词向量,并返回原词到向量的映射。
这里要求向量已归一化,后续可直接用点积表示余弦相似度
批量生成或从缓存获取关键词向量,并返回原词到向量的映射。
结合了批量推理(Batching)的极速优势和内存缓存的 O(1) 读取优势
"""
if not keywords:
return {}
vectors = embedder_model.encode(keywords, normalize_embeddings=True)
result: dict[str, np.ndarray] = {}
for keyword, vec in zip(keywords, vectors):
result[keyword] = np.asarray(vec, dtype=np.float32)
if not keywords:
return result
uncached_keywords = []
# 1. 尝试从缓存获取
for keyword in keywords:
if not keyword:
continue
if keyword in _EMBEDDING_CACHE:
result[keyword] = _EMBEDDING_CACHE[keyword]
else:
uncached_keywords.append(keyword)
# 2. 对未命中的词进行统一的批量推理
if uncached_keywords:
# 去重,避免同一个未缓存的词被计算多次
unique_uncached = list(dict.fromkeys(uncached_keywords))
vectors = embedder_model.encode(unique_uncached, normalize_embeddings=True, show_progress_bar=False)
# 防止缓存无限增长:超过阈值时清空最早存入的一半(简单粗暴的内存控制)
if len(_EMBEDDING_CACHE) > MAX_CACHE_SIZE:
keys_to_delete = list(_EMBEDDING_CACHE.keys())[: MAX_CACHE_SIZE // 2]
for k in keys_to_delete:
del _EMBEDDING_CACHE[k]
# 3. 将新计算的向量存入缓存并回填结果
for keyword, vec in zip(unique_uncached, vectors):
vec_array = np.asarray(vec, dtype=np.float32)
_EMBEDDING_CACHE[keyword] = vec_array
result[keyword] = vec_array
return result