使用 Faiss 和 Spark 进行可扩展的用户相似度搜索¶
🎯 目标:在海量数据中寻找相似用户¶
在许多推荐系统中,一个关键任务是找到与给定用户相似的其他用户。如果我们知道用户 A 与用户 B 相似,我们就可以将用户 B 喜欢过的物品推荐给用户 A(这是一种 "u2u" 或用户到用户的方法,进而可以引导出 "u2u2i" 或用户到用户到物品的推荐)。
在本项目中,用户兴趣被表示为24维的数值向量,称为 embedding。我们的任务是,从一个大规模用户集合(数据集 A,约2亿用户)中,为每个用户从另一个大规模用户集合(数据集 B,约1000万用户)中找到与他们最相似的 TopK 个用户。
😫 挑战:速度与规模¶
此前,我们曾尝试使用 Spark 的 BucketedRandomProjectionLSH
(Locality Sensitive Hashing,局部敏感哈希)。虽然 LSH 设计用于近似最近邻搜索,但它本身也带来了一些问题:
- 手动调参:
bucketLength
和numHashTables
等参数需要针对特定数据进行仔细调整。在召回率(找到足够多的相关相似用户)和计算成本之间找到合适的平衡点非常棘手。 - 运行时间长: 对于我们的数据集规模,该过程大约需要10小时,这对于实际应用、迭代开发或频繁更新来说太长了。
这些限制促使我们寻找一种更高效的解决方案。
✨ 解决方案:Faiss 与 Spark 的结合¶
Faiss 由 Facebook AI 开发,是一个用于密集向量高效相似度搜索和聚类的库。它速度极快,但传统上在单机上运行。
- 单机 Faiss 的局限性:
- 数据集 B(1000万用户,每个用户一个24维向量)可能过大,难以构建一个能完全加载到单机内存中的索引。
- 更关键的是,数据集 A(2亿用户)规模太大,无法高效地从单机执行搜索。我们不能简单地加载所有2亿向量然后逐个查询。
为了克服这些限制,我们将 Faiss 的强大功能与 Apache Spark 的分布式计算能力相结合。核心思想是:
- 分布式索引构建 (针对 Dataset B):
- 在 Spark Driver 端使用 Dataset B 的一个样本训练一个“粗量化器”(coarse quantizer,可以理解为数据结构的初步草图)。
- 将这个预训练好的量化器广播到所有 Spark Worker 节点。
- 对 Dataset B 进行分区。每个 Worker 为其分配到的 Dataset B 数据子集构建一个部分 Faiss 索引,这个过程会使用之前广播的量化器。
- 在 Driver 端合并这些部分索引,从而为整个 Dataset B 创建一个单一、完整的 Faiss 索引。
- 分布式搜索 (针对 Dataset A):
- 将最终合并好的 Faiss 索引(来自 Dataset B)广播到所有 Spark Worker 节点。
- 对 Dataset A 进行分区。每个 Worker 获取其分配到的 Dataset A 用户子集,并在广播过来的 Faiss 索引中搜索相似用户。
- 收集搜索结果。
这种方法通过将索引构建和搜索过程都进行分布式处理,使我们能够处理海量数据集。
🤔 Faiss IndexIVFFlat
工作原理 (一个简单的类比)¶
我们使用的是 Faiss 的 IndexIVFFlat
。让我们用一个图书馆的类比来解析它的工作方式:
整理图书馆 (KMeans 聚类 - "IVF" 部分):
- 想象一下,数据集 B 是一个巨大的书籍收藏库(代表用户向量)。
IndexIVFFlat
首先尝试将它们组织起来。它使用一种聚类算法(如 K-Means)将相似的书籍分到nlist
个区域或“单元格”中。每个单元格由一个“质心”(centroid,该区域书籍的平均代表)来表示。这就是“粗量化”步骤。 - 当我们“训练”索引时,Faiss 实际上是在学习这些质心的位置。
- 想象一下,数据集 B 是一个巨大的书籍收藏库(代表用户向量)。
找到正确的过道 (搜索量化器):
- 现在,你带着一本来自数据集 A 的书(一个查询用户向量)来寻找数据集 B 中的相似书籍。
- Faiss 不会直接将你的书与图书馆中的每一本书进行比较,而是首先将其与所有
nlist
个区域的质心进行比较。 - 它会找出那些质心与你的书最接近的
nprobe
个区域。nprobe
是你设定的一个参数——代表你愿意搜索多少条“过道”。
在这些过道内搜索 ( "Flat" 部分):
- 现在,Faiss 只查看选定的
nprobe
个区域内的书籍。 - 在每个选定的区域内部,它执行一次穷举搜索或称为“Flat”搜索。它会仔细地将你的查询书籍(用户 A 向量)与这些区域中的每一本书(用户 B 向量)进行比较。
- 现在,Faiss 只查看选定的
相似度度量 (
METRIC_INNER_PRODUCT
):- 在比较书籍(向量)时,我们使用
METRIC_INNER_PRODUCT
(内积)。对于归一化向量(长度为1的向量),内积等同于余弦相似度。更高的内积意味着向量更对齐,因此更相似。
- 在比较书籍(向量)时,我们使用
这种两步过程(首先找到有希望的区域,然后仅在这些区域内进行详尽搜索)比直接将查询向量与数据集 B 中的每个向量进行比较要快得多。
🌊 整体工作流程¶
以下是我们在 Spark 上使用分布式 Faiss 方法的流程图:
graph TD A[Spark 加载数据集 A & B] --> B(步骤 1: Driver 端本地评估 Faiss 参数); A --> C[采样数据集 B]; C --> D(步骤 2: Driver 端训练全局粗量化器); D --> E(步骤 3: 将序列化的全局量化器广播给 Worker); F[对数据集 B 进行分区] --> G(步骤 4: Worker 使用全局量化器构建部分 Faiss 索引); E --> G; G --> H(步骤 5: Driver 端合并部分索引为最终全局索引); H --> I(步骤 6a: 将最终全局索引广播给 Worker); J[对数据集 A 进行分区] --> K(步骤 6b: Worker 基于最终索引执行分布式搜索); I --> K; K --> L[收集 KNN 结果]; L --> M(步骤 7: 将结果保存到 Hive 表);
现在,让我们深入代码实现。
0. 导入与设置¶
导入标准库、初始化 Spark 会话以及 Faiss。
import argparse
import faiss
import pandas as pd
import numpy as np
from pyspark.sql import SparkSession
from pyspark.sql.types import FloatType, IntegerType, StringType, StructField, StructType
from pyspark.ml.linalg import Vector
from faiss.contrib.evaluation import knn_intersection_measure
from loguru import logger
__doc__ = """
- 在其他步骤中生成用户向量,落地hive表中,含有`uid bigint`列和`features array<float>`列
- 对datasetA的每个用户,在datasetB中检索,用cosine_similarity度量,返回topK的相似结果
- 索引使用faiss的`IndexIVFFLat`,度量使用`Inner Product`
- 建立索引的过程和匹配过程均使用spark的分布式引擎
"""
# 配置Pandas显示选项
pd.options.display.max_colwidth = 500
def in_notebook():
"""
检查当前代码是否在Jupyter Notebook中运行
"""
try:
return get_ipython().__class__.__name__ == 'ZMQInteractiveShell'
except NameError:
return False
def get_spark():
"""
获取Spark会话,如果不存在则创建一个
"""
if 'spark' not in locals():
spark = SparkSession.builder \
.appName("ALS") \
.config("spark.sql.catalogImplementation", "hive") \
.enableHiveSupport() \
.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
return spark
spark = get_spark()
logger.info("Spark session initialized.")
2025-05-30 18:11:36.080 | INFO | __main__:<module>:25 - Spark session initialized.
1. 配置参数 (Args
类)¶
所有可配置的参数都在这里定义。这包括用于数据集的 SQL 查询、输出表名以及 Faiss 特定的参数。
class Args:
datasetA_sql = "select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (0) and rand() < 0.5"
datasetB_sql = "select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (1) and rand() < 0.2"
knn_table = "bigdata_vf_long_inte_user_als_knn"
knn_dt = "20250315"
topK = 50
distance_cutoff = 0.9
feature_col_name = "features"
uid_col_name = "uid"
dimension = 24
nlist_for_ivf = 1024
faiss_metric_type = faiss.METRIC_INNER_PRODUCT
num_partitions_for_index = 5 # 索引阶段对datasetB进行分区,确保有足够的数据分布到各个分区进行有意义的局部训练
num_partitions_for_match = 100 # 匹配阶段对datasetA进行分区
nprobe_for_ivf = 50 # 表示在搜索过程中搜索的候选分区项数
training_sample_fraction = 0.2 # 索引阶段对datasetB进行采样的样本比例,用于训练quantizer
args = Args()
if not in_notebook():
example = """ # 示例
python step_knn.py \
--datasetA_sql "select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (0) and rand() < 0.1" \
--datasetB_sql "select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (1)" \
--knn_table "bigdata_vf_long_inte_user_als_knn" \
--knn_dt "20250315" \
--distance_cutoff 0.9 \
--topK 50 \
--feature_col_name "features" \
--uid_col_name "uid" \
--dimension 24 \
--nlist_for_ivf 1024 \
--faiss_metric_type 0 \
--num_partitions_for_index 5 \
--num_partitions_for_match 100 \
--training_sample_fraction 0.2 \
--nprobe_for_ivf 70
"""
parser = argparse.ArgumentParser(description=__doc__, epilog=example, formatter_class=argparse.RawDescriptionHelpFormatter)
# dataset group
group = parser.add_argument_group("input")
group.add_argument("--datasetA_sql", type=str, default=args.datasetA_sql, help="数据集A SQL查询")
group.add_argument("--datasetB_sql", type=str, default=args.datasetB_sql, help="数据集B SQL查询")
# table group
group = parser.add_argument_group("output")
group.add_argument("--knn_table", type=str, default=args.knn_table, help="输出KNN表名")
group.add_argument("--knn_dt", type=str, default=args.knn_dt, help="输出KNN表日期")
# lsh group
group = parser.add_argument_group("lsh")
group.add_argument("--distance_cutoff", type=float, default=args.distance_cutoff, help="相似度阈值")
group.add_argument("--topK", type=int, default=args.topK, help="输出的topK数量")
group.add_argument("--feature_col_name", type=str, default=args.feature_col_name, help="输入列名")
group.add_argument("--uid_col_name", type=str, default=args.uid_col_name, help="ID列名")
group.add_argument("--dimension", type=int, default=args.dimension, help="特征向量维度")
group.add_argument("--nlist_for_ivf", type=int, default=args.nlist_for_ivf, help="faiss的nlist")
group.add_argument("--faiss_metric_type", type=int, default=0, choices=[0, 1], help="相似度度量类型 (e.g., 0: METRIC_INNER_PRODUCT, 1: METRIC_L2)")
group.add_argument("--num_partitions_for_index", type=int, default=args.num_partitions_for_index, help="索引阶段对datasetB进行分区,确保足够的数据分布到各个分区,进行有意义的局部训练")
group.add_argument("--num_partitions_for_match", type=int, default=args.num_partitions_for_match, help="匹配阶段对datasetA进行分区")
group.add_argument("--nprobe_for_ivf", type=int, default=args.nprobe_for_ivf, help="在搜索过程中的候选分区项数")
group.add_argument("--training_sample_fraction", type=float, default=args.training_sample_fraction, help="索引阶段对datasetB进行采样的样本比例,用于训练quantizer")
args = parser.parse_args()
2. 加载数据集¶
使用在参数中定义的 SQL 查询加载数据集 A(查询用户)和数据集 B(用于构建索引的用户)。
# 此单元格将包含加载 datasetA 和 datasetB 的 Spark SQL 查询。
logger.info(f"加载数据集A: {args.datasetA_sql}")
datasetA = spark.sql(args.datasetA_sql)
datasetA.createOrReplaceTempView("datasetA")
logger.info(f"Dataset A count: {datasetA.count()}")
2025-05-30 18:11:41.322 | INFO | __main__:<module>:3 - 加载数据集A: select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (0) and rand() < 0.5 2025-05-30 18:11:42.486 | INFO | __main__:<module>:6 - Dataset A count: 894260
logger.info(f"加载数据集B: {args.datasetB_sql}")
datasetB = spark.sql(args.datasetB_sql)
datasetB.createOrReplaceTempView("datasetB")
logger.info(f"Dataset B count: {datasetB.count()}")
2025-05-30 18:11:43.420 | INFO | __main__:<module>:2 - 加载数据集B: select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (1) and rand() < 0.2 2025-05-30 18:11:44.657 | INFO | __main__:<module>:5 - Dataset B count: 345754
步骤 1: (可选但推荐) Driver 端本地评估 Faiss IVF 召回率¶
在启动完整的分布式作业之前,最好在数据的样本上测试 Faiss 参数(如 nlist
, nprobe
)。这有助于选择在召回率和速度之间提供良好权衡的参数。
函数 evaluate_local_ivf_recall
通过以下方式实现此目的:
- 从数据集 A(查询向量)和数据集 B(索引向量)中采样。
- 构建一个“真实”的精确索引(
IndexFlatIP
)和待评估的IndexIVFFlat
索引。 - 比较它们的搜索结果,以估算
IndexIVFFlat
的召回率。
logger.info(40 * "=")
logger.info("** 步骤 1: 在Driver端测试参数的召回效果 **")
logger.info(40 * "=")
def evaluate_local_ivf_recall(
datasetA,
datasetB,
sample_A_count: int = 10000, # 查询样本数量
sample_B_count: int = 20000, # 构建索引的样本数量
top_k_eval: int = 100, # 检索 Top-K 的数量
ranks_to_evaluate_recall: list = [1, 10, 50], # 需要评估的召回 rank
ndimension: int = 24,
nlist: int = 1024, # IVF 索引的聚类中心数
nprobe: int = 100, # 搜索时查询的倒排列表数
metric_type: int = faiss.METRIC_INNER_PRODUCT # 距离度量类型
):
"""
在Driver端评估Faiss IndexIVFFlat索引的召回率。
步骤:
1. 从datasetA中采样查询向量。
2. 从datasetB中采样构建IndexFlatIP(基准)和IndexIVFFlat(待评估)。
3. 对比两个索引的检索结果以计算召回率。
参数:
datasetA (DataFrame): 包含用户特征向量的 Spark DataFrame (用于查询)。
datasetB (DataFrame): 包含用户特征向量的 Spark DataFrame (用于构建索引)。
sample_A_count (int): 从datasetA中抽取的查询向量数。
sample_B_count (int): 从datasetB中抽取的构建索引的向量数。
top_k_eval (int): 每次查询返回的Top-K个结果。
ranks_to_evaluate_recall (List[int]): 计算召回率的rank值。
ndimension (int): 向量维度。
nlist (int): Faiss IVF索引的聚类中心数。
nprobe (int): IVF索引搜索时检查的倒排列表数。
metric_type (int): Faiss距离度量方式,如faiss.METRIC_INNER_PRODUCT或faiss.METRIC_L2。
返回:
Dict[int, float]: 每个rank对应的knn交集度量值(即近似召回率)。
"""
try:
logger.info("开始评估本地Faiss IVF索引的召回率...")
logger.info(f"评估用到的参数:{locals()}")
# --- Step 1: 从datasetA中提取查询向量 ---
logger.info(f"从datasetA中采样 {sample_A_count} 条数据作为查询向量...")
xq = np.concatenate(datasetA.select(args.feature_col_name).take(sample_A_count))
if xq.size == 0:
logger.warning("未获取到有效的查询向量!")
return {}
# --- Step 2: 从datasetB中提取构建索引所需的向量和UID ---
logger.info(f"从datasetB中采样 {sample_B_count} 条数据用于构建索引...")
xb = np.concatenate(datasetB.select(args.feature_col_name).take(sample_B_count))
uids = np.concatenate(datasetB.select(args.uid_col_name).take(sample_B_count))
if xb.size == 0 or uids.size == 0:
logger.warning("未获取到有效的构建索引所需向量或UID。")
return {}
# --- Step 3: 构建参考索引 (IndexFlatIP) ---
logger.info("正在构建参考索引 (IndexFlatIP) 用于对比召回结果...")
index_flat_ref = faiss.IndexIDMap2(faiss.IndexFlatIP(ndimension)) # 使用IDMap2保留原始UID
index_flat_ref.add_with_ids(xb, uids)
logger.info(f"参考索引已构建,包含 {index_flat_ref.ntotal} 个向量。")
# --- Step 4: 执行参考索引的搜索 ---
logger.info(f"使用参考索引进行Top-{top_k_eval}检索...")
Dref, Iref = index_flat_ref.search(xq, top_k_eval)
# --- Step 5: 构建并训练待评估的IVF索引 ---
logger.info(f"正在构建待评估的IndexIVFFlat索引 (nlist={nlist}, nprobe={nprobe})...")
quantizer = faiss.IndexFlatIP(ndimension) if metric_type == faiss.METRIC_INNER_PRODUCT else faiss.IndexFlatL2(ndimension)
index_ivfflat = faiss.IndexIVFFlat(quantizer, ndimension, nlist, metric_type)
if xb.shape[0] < nlist:
logger.warning(f"警告:构建IVF索引的数据点({xb.shape[0]})小于nlist({nlist}), 可能影响训练效果。")
index_ivfflat.train(xb) # 训练索引
index_ivfflat.add_with_ids(xb, uids) # 添加向量及其对应UID
# 设置nprobe参数
index_ivfflat.nprobe = nprobe
logger.info(f"IndexIVFFlat已构建并训练完成,当前包含 {index_ivfflat.ntotal} 个向量。")
# --- Step 6: 执行待评估索引的搜索 ---
logger.info(f"使用待评估索引进行Top-{top_k_eval}检索...")
Dnew, Inew = index_ivfflat.search(xq, top_k_eval)
# --- Step 7: 计算召回率 ---
recall_results = {}
logger.info("开始计算召回率指标...")
for rank in ranks_to_evaluate_recall:
if rank > top_k_eval:
logger.warning(f"Rank {rank} 超出Top-K范围 {top_k_eval},跳过此rank的评估。")
continue
try:
recall_value = knn_intersection_measure(Inew[:, :rank], Iref[:, :rank])
recall_results[rank] = recall_value
logger.info(f"Recall@{rank}: {recall_value:.4f}")
except Exception as e:
logger.error(f"在Rank@{rank}计算召回率时发生错误: {e}")
logger.info("召回率评估完成。")
return recall_results
except Exception as e:
logger.error(f"执行本地IVF召回率评估失败: {e}")
return {}
2025-05-30 18:11:47.868 | INFO | __main__:<module>:2 - ======================================== 2025-05-30 18:11:47.869 | INFO | __main__:<module>:3 - ** 步骤 1: 在Driver端测试参数的召回效果 ** 2025-05-30 18:11:47.870 | INFO | __main__:<module>:4 - ========================================
recall_metrics = evaluate_local_ivf_recall(
datasetA=datasetA,
datasetB=datasetB,
sample_A_count=5000,
sample_B_count=10000,
top_k_eval=100,
ranks_to_evaluate_recall=[1, 10, 50, 100],
ndimension=args.dimension,
nlist=args.nlist_for_ivf,
nprobe=args.nprobe_for_ivf,
metric_type=args.faiss_metric_type
)
logger.info(f"IVF索引召回率评估结果: {recall_metrics}")
2025-05-30 18:11:49.120 | INFO | __main__:evaluate_local_ivf_recall:43 - 开始评估本地Faiss IVF索引的召回率... 2025-05-30 18:11:49.124 | INFO | __main__:evaluate_local_ivf_recall:44 - 评估用到的参数:{'datasetA': DataFrame[uid: string, code: int, features: array<float>, norm: float, dt: string, pt: string], 'datasetB': DataFrame[uid: string, code: int, features: array<float>, norm: float, dt: string, pt: string], 'sample_A_count': 5000, 'sample_B_count': 10000, 'top_k_eval': 100, 'ranks_to_evaluate_recall': [1, 10, 50, 100], 'ndimension': 24, 'nlist': 1024, 'nprobe': 50, 'metric_type': 0} 2025-05-30 18:11:49.125 | INFO | __main__:evaluate_local_ivf_recall:47 - 从datasetA中采样 5000 条数据作为查询向量... 2025-05-30 18:11:49.124 | INFO | __main__:evaluate_local_ivf_recall:44 - 评估用到的参数:{'datasetA': DataFrame[uid: string, code: int, features: array<float>, norm: float, dt: string, pt: string], 'datasetB': DataFrame[uid: string, code: int, features: array<float>, norm: float, dt: string, pt: string], 'sample_A_count': 5000, 'sample_B_count': 10000, 'top_k_eval': 100, 'ranks_to_evaluate_recall': [1, 10, 50, 100], 'ndimension': 24, 'nlist': 1024, 'nprobe': 50, 'metric_type': 0} 2025-05-30 18:11:49.125 | INFO | __main__:evaluate_local_ivf_recall:47 - 从datasetA中采样 5000 条数据作为查询向量... 2025-05-30 18:11:53.357 | INFO | __main__:evaluate_local_ivf_recall:54 - 从datasetB中采样 10000 条数据用于构建索引... 2025-05-30 18:11:59.415 | INFO | __main__:evaluate_local_ivf_recall:63 - 正在构建参考索引 (IndexFlatIP) 用于对比召回结果... 2025-05-30 18:11:59.425 | INFO | __main__:evaluate_local_ivf_recall:66 - 参考索引已构建,包含 10000 个向量。 2025-05-30 18:11:59.427 | INFO | __main__:evaluate_local_ivf_recall:69 - 使用参考索引进行Top-100检索... 2025-05-30 18:11:59.643 | INFO | __main__:evaluate_local_ivf_recall:73 - 正在构建待评估的IndexIVFFlat索引 (nlist=1024, nprobe=50)... 2025-05-30 18:12:00.102 | INFO | __main__:evaluate_local_ivf_recall:85 - IndexIVFFlat已构建并训练完成,当前包含 10000 个向量。 2025-05-30 18:12:00.104 | INFO | __main__:evaluate_local_ivf_recall:88 - 使用待评估索引进行Top-100检索... 2025-05-30 18:12:00.137 | INFO | __main__:evaluate_local_ivf_recall:93 - 开始计算召回率指标... 2025-05-30 18:12:00.265 | INFO | __main__:evaluate_local_ivf_recall:103 - Recall@1: 0.9992 2025-05-30 18:12:00.401 | INFO | __main__:evaluate_local_ivf_recall:103 - Recall@10: 0.9963 2025-05-30 18:12:00.558 | INFO | __main__:evaluate_local_ivf_recall:103 - Recall@50: 0.9871 2025-05-30 18:12:00.756 | INFO | __main__:evaluate_local_ivf_recall:103 - Recall@100: 0.9716 2025-05-30 18:12:00.757 | INFO | __main__:evaluate_local_ivf_recall:107 - 召回率评估完成。 2025-05-30 18:12:00.759 | INFO | __main__:<module>:15 - IVF索引召回率评估结果: {1: 0.9992, 10: 0.99632, 50: 0.987104, 100: 0.971576}
步骤 2: Driver 端训练全局粗量化器¶
“粗量化器”(coarse quantizer,本质上是 IndexIVFFlat
的 K-Means 质心)需要被训练。为了确保所有 Worker 使用相同的数据“地图”,我们在 Spark Driver 端使用 Dataset B 的一个样本来集中训练这个量化器。
函数 train_global_coarse_quantizer_on_driver
:
- 将 Dataset B 的一部分数据采样到 Driver 端。
- 使用这些样本训练一个临时的
IndexIVFFlat
。 - 提取并序列化训练好的粗量化器(其中包含聚类质心)。
logger.info(40 * "=")
logger.info(" ** 步骤 2: Driver端训练全局粗量化器 **")
logger.info(40 * "=")
def train_global_coarse_quantizer_on_driver(
dataset: pd.DataFrame, # PySpark DataFrame for datasetB
d_dimension: int = 24,
nlist_global: int = 50, # 全局的nlist
metric_type_global: int = 0,
training_sample_fraction: float = 0.2,
feature_col_name: str = "features"
):
"""
在Driver端训练用于Faiss IndexIVFFlat索引的全局粗量化器(coarse quantizer)。
此过程包括从datasetB中采样数据、将特征向量转换为NumPy数组、使用指定参数创建并训练一个临时IndexIVFFlat,
然后提取并序列化其内部的粗量化器以供后续在所有Worker上共享。
参数:
dataset (pd.DataFrame): 输入数据集 datasetB,包含用于训练粗量化器的特征向量。
d_dimension (int): 特征向量的维度,默认值为24。
nlist_global (int): Faiss IVF索引的聚类中心数(即nlist),也是粗量化器训练的目标数量,默认值为50。
metric_type_global (int): 距离度量类型,如 faiss.METRIC_INNER_PRODUCT 或 faiss.METRIC_L2,默认为0。
training_sample_fraction (float): 从datasetB中采样的比例,用于控制训练数据量,默认为0.2。
feature_col_name (str): 数据集中特征列的名称,默认为"features"。
返回:
bytes: 序列化后的训练好的粗量化器对象,可用于广播并在各个Worker上构建一致的IVF索引结构。
异常:
ValueError: 如果采样后的训练数据为空或无效,则抛出异常。
"""
logger.info(f"开始在Driver端训练全局粗量化器 (nlist={nlist_global})...")
logger.info(f"从 datasetB 采样 {training_sample_fraction*100}% 的数据用于训练...")
# 采样数据到Driver端
# 注意:如果采样比例过大或datasetB本身巨大,toPandas() 可能导致Driver OOM
# 需要根据实际情况调整采样比例,确保Driver能够处理。
sample_pd_df = dataset.select(feature_col_name)\
.sample(False, training_sample_fraction)\
.toPandas()
if sample_pd_df.empty:
raise ValueError("为Faiss粗量化器训练的样本数据为空!请检查采样比例或源数据。")
# 将Pandas Series中的Spark Vector转换为NumPy数组
training_vectors_list = []
for vec_obj in sample_pd_df[feature_col_name]:
if vec_obj is not None:
if isinstance(vec_obj, Vector):
training_vectors_list.append(vec_obj.toArray())
elif isinstance(vec_obj, (list, tuple)): # 如果已经是list (不太可能从DataFrame直接是list)
training_vectors_list.append(vec_obj)
if not training_vectors_list:
raise ValueError("转换后的训练向量列表为空!")
training_vectors_np = np.array(training_vectors_list).astype('float32')
if training_vectors_np.shape[0] < nlist_global:
logger.warning(f"警告: 训练样本数 ({training_vectors_np.shape[0]}) 小于 NLIST_GLOBAL ({nlist_global})。"
"Faiss IVF训练质量可能受影响或失败。建议增加采样数据或减小NLIST_GLOBAL。")
if training_vectors_np.shape[0] == 0:
raise ValueError("没有有效的训练向量来训练粗量化器。")
logger.info(f"使用 {training_vectors_np.shape[0]} 个向量 (维度={d_dimension}) 训练粗量化器...")
# 1. 创建将用于IndexIVFFlat的粗量化器 (例如IndexFlatIP)
coarse_quantizer = faiss.IndexFlatIP(d_dimension) if metric_type_global == faiss.METRIC_INNER_PRODUCT else faiss.IndexFlatL2(d_dimension)
# 2. 创建一个临时的IndexIVFFlat用于训练其内部的粗量化器(即获取centroids)
# 这个临时的IVF索引的nlist必须与我们最终期望的全局nlist一致。
temp_ivf_for_training = faiss.IndexIVFFlat(coarse_quantizer, d_dimension, nlist_global, metric_type_global)
temp_ivf_for_training.train(training_vectors_np)
logger.info("全局粗量化器训练完毕。")
# 3. 提取并序列化这个训练好的粗量化器 (它现在是 temp_ivf_for_training.quantizer)
# 这个quantizer对象包含了训练好的聚类中心。
serialized_trained_coarse_quantizer = faiss.serialize_index(temp_ivf_for_training.quantizer)
return serialized_trained_coarse_quantizer
2025-05-30 18:12:03.470 | INFO | __main__:<module>:2 - ======================================== 2025-05-30 18:12:03.472 | INFO | __main__:<module>:3 - ** 步骤 2: Driver端训练全局粗量化器 ** 2025-05-30 18:12:03.473 | INFO | __main__:<module>:4 - ========================================
serialized_global_quantizer = None
try:
serialized_global_quantizer = train_global_coarse_quantizer_on_driver(
dataset=datasetB,
d_dimension=args.dimension,
nlist_global=args.nlist_for_ivf,
metric_type_global=args.faiss_metric_type,
training_sample_fraction=args.training_sample_fraction,
feature_col_name=args.feature_col_name
)
except Exception as e_train:
logger.error(f"Driver端训练全局粗量化器失败: {e_train}")
2025-05-30 18:12:04.508 | INFO | __main__:train_global_coarse_quantizer_on_driver:34 - 开始在Driver端训练全局粗量化器 (nlist=1024)... 2025-05-30 18:12:04.509 | INFO | __main__:train_global_coarse_quantizer_on_driver:35 - 从 datasetB 采样 20.0% 的数据用于训练... 2025-05-30 18:12:04.509 | INFO | __main__:train_global_coarse_quantizer_on_driver:35 - 从 datasetB 采样 20.0% 的数据用于训练... 2025-05-30 18:12:15.074 | INFO | __main__:train_global_coarse_quantizer_on_driver:67 - 使用 68820 个向量 (维度=24) 训练粗量化器... 2025-05-30 18:12:16.506 | INFO | __main__:train_global_coarse_quantizer_on_driver:77 - 全局粗量化器训练完毕。
步骤 3: 广播序列化的全局粗量化器¶
训练好并序列化后的粗量化器随后被广播到所有 Spark Worker 节点。这使得每个 Worker 都可以使用相同的质心集合来初始化其本地的 IndexIVFFlat
。
logger.info(40 * "=")
logger.info(" ** 步骤 3: 广播序列化的全局粗量化器 **")
logger.info(40 * "=")
broadcasted_quantizer_ser_bytes = spark.sparkContext.broadcast(serialized_global_quantizer)
logger.info("全局粗量化器已广播。")
2025-05-30 18:12:19.141 | INFO | __main__:<module>:2 - ======================================== 2025-05-30 18:12:19.143 | INFO | __main__:<module>:3 - ** 步骤 3: 广播序列化的全局粗量化器 ** 2025-05-30 18:12:19.145 | INFO | __main__:<module>:4 - ======================================== 2025-05-30 18:12:19.151 | INFO | __main__:<module>:7 - 全局粗量化器已广播。
步骤 4: Worker 端填充部分 IVF 索引¶
数据集 B 被重新分区,每个分区由一个 Worker 处理。
函数 worker_populate_partial_ivf_with_global_quantizer
在每个 Worker 上运行:
- 反序列化广播过来的全局粗量化器。
- 使用此量化器初始化一个本地的
IndexIVFFlat
。关键的一点是,它将此索引的量化器部分标记为is_trained = True
,因为量化器部分已经训练完毕。 - 将其分配到的 Dataset B 分区中的向量(及其对应的 UID)添加到这个本地索引中。
- 序列化这个已填充数据的本地索引,并将其发送回 Driver。
logger.info(40 * "=")
logger.info("** 步骤 4: Worker填充部分索引**")
logger.info(40 * "=")
def worker_populate_partial_ivf_with_global_quantizer(
p_idx: int,
iterator: iter,
broadcasted_ser_global_quantizer_bytes: bytes, # 序列化后的全局粗量化器
d_dimension: int,
nlist_global: int, # 全局的nlist,用于初始化IVF结构
metric_type: int,
feature_col_name: str,
uid_col_name: str
):
"""
将数据并行划分到Worker上运行,在每个worker上填充部分IVF索引。
参数:
p_idx: 当前Worker的索引
iterator: 迭代器,包含每个Worker的数据
broadcasted_ser_global_quantizer_bytes: 广播的全局粗量化器
d_dimension: 特征向量的维度
nlist_global: 全局的nlist,用于初始化IVF结构
metric_type: 索引的度量类型
feature_col_name: 特征列的名称
uid_col_name: 用户ID列的名称
返回:
序列化后的IVF索引
"""
import sys
import faiss
import numpy as np
from pyspark.ml.linalg import Vector
# --- 数据提取和NumPy转换 (与您之前的 worker_test_phase6_full 相同) ---
local_vectors_list = []
local_uids_int64_list = []
for row in iterator:
if row[feature_col_name] is not None and row[uid_col_name] is not None:
try:
uid_int64 = np.int64(str(row[uid_col_name]))
feature_data = row[feature_col_name]
current_vector_py_list = None
if isinstance(feature_data, Vector):
if feature_data.size == d_dimension:
current_vector_py_list = feature_data.toArray().tolist()
elif isinstance(feature_data, (list, tuple)):
if len(feature_data) == d_dimension:
temp_list = []
valid_list = True
for x in feature_data:
try:
temp_list.append(float(x))
except:
valid_list = False
break
if valid_list:
current_vector_py_list = temp_list
if current_vector_py_list:
local_uids_int64_list.append(uid_int64)
local_vectors_list.append(current_vector_py_list)
except ValueError: pass
except Exception: pass
if not local_vectors_list:
return []
local_vectors_np = np.array(local_vectors_list).astype('float32')
local_uids_np_int64 = np.array(local_uids_int64_list)
if local_vectors_np.ndim == 1:
if local_vectors_np.shape[0] == d_dimension and d_dimension > 0:
local_vectors_np = local_vectors_np.reshape(1, -1)
else:
return []
if local_vectors_np.shape[0] == 0 or local_vectors_np.shape[0] != local_uids_np_int64.shape[0]:
return []
# --- 数据提取和NumPy转换部分结束 ---
try:
# 1. Worker反序列化广播过来的全局粗量化器
global_coarse_quantizer_on_worker = faiss.deserialize_index(broadcasted_ser_global_quantizer_bytes)
# 2. 创建本地IndexIVFFlat,使用全局训练的粗量化器
# nlist_global 必须与训练全局粗量化器时使用的nlist一致
local_ivf_index = faiss.IndexIVFFlat(
global_coarse_quantizer_on_worker,
d_dimension,
nlist_global, # 使用全局nlist
metric_type
)
# 3. **关键:标记此索引的粗量化器部分为已训练**
local_ivf_index.is_trained = True
# 注意:如果将来使用IndexIVFPQ,PQ部分也需要类似处理(广播训练好的PQ,或在worker训练局部PQ——后者不推荐)
# 4. 将当前分区的向量添加到这个本地IndexIVFFlat中
# 由于粗量化器是全局的,这些向量会被分配到全局一致的聚类单元中
if local_vectors_np.shape[0] > 0: # 确保有数据才添加
local_ivf_index.add_with_ids(local_vectors_np, local_uids_np_int64)
# 5. 序列化这个填充了数据的本地索引
# 它现在包含了基于全局聚类中心的倒排列表数据
serialized_bytes_to_return = faiss.serialize_index(local_ivf_index)
ntotal_to_return = local_ivf_index.ntotal
return [(p_idx, serialized_bytes_to_return, ntotal_to_return)]
except Exception as e_faiss_worker_op:
sys.stderr.write(f"P{p_idx}: Error in worker_populate_partial_ivf: {e_faiss_worker_op}\n")
return []
2025-05-30 18:12:21.898 | INFO | __main__:<module>:2 - ======================================== 2025-05-30 18:12:21.899 | INFO | __main__:<module>:3 - ** 步骤 4: Worker填充部分索引** 2025-05-30 18:12:21.900 | INFO | __main__:<module>:4 - ========================================
logger.info(f"将在 {args.num_partitions_for_index} 个分区上并行填充局部IndexIVFFlat...")
try:
collected_worker_outputs = datasetB.rdd.repartition(args.num_partitions_for_index).mapPartitionsWithIndex(
lambda p_idx, iterator: worker_populate_partial_ivf_with_global_quantizer(
p_idx, iterator,
broadcasted_quantizer_ser_bytes.value, # 使用广播的序列化全局粗量化器
args.dimension,
args.nlist_for_ivf,
args.faiss_metric_type,
args.feature_col_name,
args.uid_col_name
)
).collect() # This collect will trigger the serialization of results
logger.info(f"成功从 {len(collected_worker_outputs)} 个分区收集了局部填充的索引信息。")
if collected_worker_outputs:
logger.info(f"Sample result from Phase 6: {collected_worker_outputs}")
except Exception as e:
logger.error(f"Error in worker_populate_partial_ivf_with_global_quantizer: {e}")
import traceback
traceback.print_exc()
2025-05-30 18:12:23.088 | INFO | __main__:<module>:2 - 将在 5 个分区上并行填充局部IndexIVFFlat... 2025-05-30 18:12:37.944 | INFO | __main__:<module>:15 - 成功从 5 个分区收集了局部填充的索引信息。 2025-05-30 18:12:37.946 | INFO | __main__:<module>:17 - Sample result from Phase 6: [(0, array([ 73, 119, 70, ..., 0, 0, 0], dtype=uint8), 69159), (1, array([ 73, 119, 70, ..., 0, 0, 0], dtype=uint8), 69145), (2, array([ 73, 119, 70, ..., 0, 0, 0], dtype=uint8), 69162), (3, array([ 73, 119, 70, ..., 0, 0, 0], dtype=uint8), 69127), (4, array([ 73, 119, 70, ..., 0, 0, 0], dtype=uint8), 69161)]
步骤 5: Driver 端合并部分索引¶
Driver 从 Worker 收集所有序列化后的部分索引。
函数 merge_worker_indexes_on_driver
:
- 在 Driver 端初始化一个新的“空”
IndexIVFFlat
,使用相同的全局粗量化器。 - 遍历收集到的序列化部分索引。
- 对于每个部分索引,将其反序列化,并使用 Faiss 的
merge_from
方法将其内容合并到 Driver 上的主索引中。
最终得到一个单一的、最终的 IndexIVFFlat
,它包含了 Dataset B 中的所有向量,并且是基于全局一致的质心集合构建的。
logger.info(40 * "=")
logger.info("** 步骤 5: Driver端合并部分索引**")
logger.info(40 * "=")
def merge_worker_indexes_on_driver(collected_worker_outputs, broadcasted_quantizer_ser_bytes, args):
"""
在Driver端合并所有从Worker收集到的部分Faiss IndexIVFFlat索引。
参数:
collected_worker_outputs (list): 从各个Worker收集到的序列化索引及其元数据。
每个元素形如 (partition_id, serialized_index_bytes, num_vectors_in_part)。
broadcasted_quantizer_ser_bytes: 广播的全局粗量化器(已训练)的序列化字节。
args: 包含参数的对象,至少应包含以下属性:
- dimension (int): 向量维度。
- nlist_for_ivf (int): IVF中聚类中心的数量。
- faiss_metric_type (int): Faiss距离度量类型,例如faiss.METRIC_INNER_PRODUCT或faiss.METRIC_L2。
返回:
final_merged_ivf_index (faiss.IndexIVFFlat or None): 合并后的最终Faiss IVF索引对象;
如果没有有效分区则返回None。
"""
import faiss
if not collected_worker_outputs:
return None
logger.info("开始在Driver端合并所有局部填充的IndexIVFFlat...")
# 1. 反序列化广播的全局粗量化器
final_coarse_quantizer_on_driver = faiss.deserialize_index(broadcasted_quantizer_ser_bytes.value)
# 2. 创建最终合并索引的“壳”
final_merged_ivf_index = faiss.IndexIVFFlat(
final_coarse_quantizer_on_driver,
args.dimension,
args.nlist_for_ivf,
args.faiss_metric_type
)
final_merged_ivf_index.is_trained = True # 因为粗量化器已训练
# 3. 逐个合并来自worker的、已填充的、序列化的局部IVF索引
for part_id, ser_idx_bytes, num_vectors_in_part in collected_worker_outputs:
if num_vectors_in_part == 0 or ser_idx_bytes is None:
logger.info(f"跳过分区 {part_id},无有效数据。")
continue
try:
partial_index_obj = faiss.deserialize_index(ser_idx_bytes)
logger.info(f"Driver: 准备合并分区 {part_id} (含 {num_vectors_in_part} 向量)...")
# 调用 merge_from。第二个参数 '0' 表示不偏移已存入的全局唯一 int64 UID。
final_merged_ivf_index.merge_from(partial_index_obj, 0)
# Faiss merge_from会正确处理ntotal的更新。
logger.info(f"Driver: 合并分区 {part_id} 后,主索引总向量数: {final_merged_ivf_index.ntotal}")
except Exception as e_merge_final:
logger.error(f"Driver: 合并分区 {part_id} 的索引时出错: {e_merge_final}")
return final_merged_ivf_index
2025-05-30 18:12:42.666 | INFO | __main__:<module>:2 - ======================================== 2025-05-30 18:12:42.668 | INFO | __main__:<module>:3 - ** 步骤 5: Driver端合并部分索引** 2025-05-30 18:12:42.669 | INFO | __main__:<module>:4 - ========================================
final_index = merge_worker_indexes_on_driver(
collected_worker_outputs=collected_worker_outputs,
broadcasted_quantizer_ser_bytes=broadcasted_quantizer_ser_bytes,
args=args
)
if final_index:
logger.info(f"合并完成,最终索引包含 {final_index.ntotal} 个向量。")
else:
logger.warning("未成功合并任何索引。")
2025-05-30 18:12:43.866 | INFO | __main__:merge_worker_indexes_on_driver:28 - 开始在Driver端合并所有局部填充的IndexIVFFlat... 2025-05-30 18:12:43.875 | INFO | __main__:merge_worker_indexes_on_driver:50 - Driver: 准备合并分区 0 (含 69159 向量)... 2025-05-30 18:12:43.880 | INFO | __main__:merge_worker_indexes_on_driver:56 - Driver: 合并分区 0 后,主索引总向量数: 69159 2025-05-30 18:12:43.889 | INFO | __main__:merge_worker_indexes_on_driver:50 - Driver: 准备合并分区 1 (含 69145 向量)... 2025-05-30 18:12:43.893 | INFO | __main__:merge_worker_indexes_on_driver:56 - Driver: 合并分区 1 后,主索引总向量数: 138304 2025-05-30 18:12:43.903 | INFO | __main__:merge_worker_indexes_on_driver:50 - Driver: 准备合并分区 2 (含 69162 向量)... 2025-05-30 18:12:43.908 | INFO | __main__:merge_worker_indexes_on_driver:56 - Driver: 合并分区 2 后,主索引总向量数: 207466 2025-05-30 18:12:43.917 | INFO | __main__:merge_worker_indexes_on_driver:50 - Driver: 准备合并分区 3 (含 69127 向量)... 2025-05-30 18:12:43.921 | INFO | __main__:merge_worker_indexes_on_driver:56 - Driver: 合并分区 3 后,主索引总向量数: 276593 2025-05-30 18:12:43.931 | INFO | __main__:merge_worker_indexes_on_driver:50 - Driver: 准备合并分区 4 (含 69161 向量)... 2025-05-30 18:12:43.936 | INFO | __main__:merge_worker_indexes_on_driver:56 - Driver: 合并分区 4 后,主索引总向量数: 345754 2025-05-30 18:12:43.939 | INFO | __main__:<module>:8 - 合并完成,最终索引包含 345754 个向量。
步骤 6: 分布式检索¶
现在我们有了 Dataset B 的完整 final_index
,可以为 Dataset A 中的用户执行相似度搜索了。
final_index
被序列化并广播到所有 Spark Worker。- 数据集 A 被重新分区。
函数 distributed_search_worker
在每个 Worker 上运行:
- 反序列化广播过来的
final_index
。 - 在此索引上设置用于搜索的
nprobe
参数。 - 对于其分配到的 Dataset A 分区中的每个用户向量,它查询
final_index
以找到来自 Dataset B 的 TopK 个相似用户 UID 及其距离。 - 返回一个包含 (uid_a, uid_b, distance, rank) 元组的列表。
logger.info(40 * "=")
logger.info("** 步骤 6: 进行分布式检索**")
logger.info(40 * "=")
def distributed_search_worker(
partition_id: int,
iterator_rows,
broadcasted_index_bytes: bytes, # 序列化后的全局Faiss索引
D_DIMENSION: int,
NPROBE_SEARCH: int, # 搜索时使用的nprobe
TOP_K_SEARCH: int, # 检索Top-K
feature_col_name: str,
uid_col_name: str
):
"""
在Worker端对 datasetA 的一个分区中的向量执行Faiss搜索,查询训练好的 final_index(来自datasetB)。
参数:
partition_id (int): 分区ID(日志用)
iterator_rows (iter): datasetA 的当前分区数据行迭代器
broadcasted_index_bytes (bytes): 广播的序列化最终Faiss索引
D_DIMENSION (int): 向量维度
NPROBE_SEARCH (int): IVF搜索的nprobe值
TOP_K_SEARCH (int): 返回Top-K个结果
feature_col_name (str): datasetA 中特征列名
uid_col_name (str): datasetA 中UID列名
返回:
List[Tuple[str, str, float]]: 每个元素为 (uid_a, uid_b, distance)
"""
import sys
import faiss
import numpy as np
from pyspark.ml.linalg import Vector
results = []
try:
# 反序列化广播过来的索引(每个worker只做一次)
index_worker = faiss.deserialize_index(broadcasted_index_bytes)
if hasattr(index_worker, 'nprobe') and NPROBE_SEARCH > 0:
index_worker.nprobe = NPROBE_SEARCH
except Exception as e_load:
sys.stderr.write(f"Partition {partition_id}: 反序列化Faiss索引失败: {e_load}\n")
return []
sys.stdout.write(f"Partition {partition_id}: 反序列化Faiss索引成功,nprobe={index_worker.nprobe}\n")
query_vectors_list = []
query_uids_list = []
# 提取本分区所有有效的A向量和UID
for row in iterator_rows:
if row[feature_col_name] is not None and row[uid_col_name] is not None:
try:
feature_data = row[feature_col_name]
current_vector_list = None
if isinstance(feature_data, Vector):
if feature_data.size == D_DIMENSION:
current_vector_list = feature_data.toArray().tolist()
elif isinstance(feature_data, (list, tuple)):
if len(feature_data) == D_DIMENSION:
try: current_vector_list = [float(x) for x in feature_data]
except: pass
if current_vector_list:
query_uids_list.append(str(row[uid_col_name]))
query_vectors_list.append(current_vector_list)
except Exception as e_row:
sys.stderr.write(f"Partition {partition_id}: 处理某行时出错: {e_row}\n")
if not query_vectors_list:
return []
xq = np.array(query_vectors_list).astype('float32')
if xq.ndim == 1:
if xq.shape[0] == D_DIMENSION and D_DIMENSION > 0:
xq = xq.reshape(1, -1)
else:
return []
if xq.shape[0] == 0:
return []
sys.stdout.write(f"Partition {partition_id}: 搜索向量数: {xq.shape[0]}\n")
# 执行批量搜索
try:
D_results, I_results = index_worker.search(xq, TOP_K_SEARCH)
except Exception as e_search:
sys.stderr.write(f"Partition {partition_id}: Faiss搜索出错: {e_search}\n")
return []
# 转换结果为(uid_a, uid_b, score, rank)
for i in range(xq.shape[0]):
uid_a = query_uids_list[i]
for k_idx in range(TOP_K_SEARCH):
uid_b_int64 = I_results[i][k_idx]
if uid_b_int64 == -1:
break
distance_score = float(D_results[i][k_idx])
results.append((uid_a, str(uid_b_int64), distance_score, k_idx))
sys.stdout.write(f"Partition {partition_id}: 搜索完成,返回结果数: {len(results)}\n")
return results
2025-05-30 18:12:46.975 | INFO | __main__:<module>:2 - ======================================== 2025-05-30 18:12:46.976 | INFO | __main__:<module>:3 - ** 步骤 6: 进行分布式检索** 2025-05-30 18:12:46.977 | INFO | __main__:<module>:4 - ========================================
logger.info("开始将最终索引序列化并广播给所有Worker...")
serialized_final_index_bytes = faiss.serialize_index(final_index)
broadcasted_index = spark.sparkContext.broadcast(serialized_final_index_bytes)
try:
logger.info(f"开始对 datasetA 的所有分区执行分布式检索 ...")
knn_pairs_rdd = datasetA.rdd.repartition(args.num_partitions_for_match).mapPartitionsWithIndex(
lambda p_idx, iterator: distributed_search_worker(
partition_id=p_idx,
iterator_rows=iterator,
broadcasted_index_bytes=broadcasted_index.value,
D_DIMENSION=args.dimension,
NPROBE_SEARCH=args.nprobe_for_ivf,
TOP_K_SEARCH=args.topK,
feature_col_name=args.feature_col_name,
uid_col_name=args.uid_col_name
)
)
knn_pairs_schema = StructType([
StructField("uid_a", StringType(), True),
StructField("uid_b", StringType(), True),
StructField("distance", FloatType(), True),
StructField("rank", IntegerType(), True),
])
knn_pairs_df = spark.createDataFrame(knn_pairs_rdd, schema=knn_pairs_schema)
logger.info(f"分布式检索完成")
knn_pairs_df.show(50, truncate=False)
except Exception as e_dist_search:
logger.error(f"执行分布式检索时出错: {e_dist_search}")
import traceback
traceback.print_exc()
2025-05-30 18:12:48.342 | INFO | __main__:<module>:2 - 开始将最终索引序列化并广播给所有Worker... 2025-05-30 18:12:48.582 | INFO | __main__:<module>:7 - 开始对 datasetA 的所有分区执行分布式检索 ... 2025-05-30 18:12:48.698 | INFO | __main__:<module>:28 - 分布式检索完成
+----------+----------+----------+----+ |uid_a |uid_b |distance |rank| +----------+----------+----------+----+ |7405360310|6015969270|0.9592322 |0 | |7405360310|7811785930|0.9353308 |1 | |7405360310|7811168410|0.9277942 |2 | |7405360310|6412357340|0.9266744 |3 | |7405360310|7810441900|0.92530954|4 | |7405360310|6517782970|0.9237408 |5 | |7405360310|7913581060|0.92291665|6 | |7405360310|7810159120|0.91983354|7 | |7405360310|7410549450|0.9191794 |8 | |7405360310|7413725220|0.915444 |9 | |7405360310|7216370090|0.9144743 |10 | |7405360310|7412415580|0.91367227|11 | |7405360310|6616963530|0.91232824|12 | |7405360310|5914261830|0.9116912 |13 | |7405360310|7315083990|0.91028804|14 | |7405360310|7916618590|0.91025925|15 | |7405360310|6919813820|0.9101951 |16 | |7405360310|2719533260|0.90888333|17 | |7405360310|7911630640|0.90748376|18 | |7405360310|6618833940|0.9074596 |19 | |7405360310|7813426290|0.90602654|20 | |7405360310|7412083140|0.9059177 |21 | |7405360310|7718520790|0.9057053 |22 | |7405360310|7315530320|0.9056492 |23 | |7405360310|5616194500|0.90551794|24 | |7405360310|7211852130|0.9041486 |25 | |7405360310|7914652460|0.90381217|26 | |7405360310|7718600980|0.9037398 |27 | |7405360310|7819706260|0.9023258 |28 | |7405360310|5717773090|0.90217376|29 | |7405360310|7911507180|0.9016013 |30 | |7405360310|7716563560|0.90149474|31 | |7405360310|7811007510|0.9010081 |32 | |7405360310|7913502200|0.9009439 |33 | |7405360310|6117511260|0.90038925|34 | |7405360310|7917898920|0.90028185|35 | |7405360310|7810609130|0.9001176 |36 | |7405360310|7910967690|0.90003 |37 | |7405360310|7810844680|0.89974564|38 | |7405360310|7510759490|0.89876306|39 | |7405360310|7913173720|0.8986522 |40 | |7405360310|1913677540|0.8986089 |41 | |7405360310|7915411520|0.89806074|42 | |7405360310|7919242310|0.89804214|43 | |7405360310|7415795400|0.89790463|44 | |7405360310|6914326140|0.8973903 |45 | |7405360310|7815709790|0.8972927 |46 | |7405360310|7811828780|0.8966652 |47 | |7405360310|7418111170|0.89505917|48 | |7405360310|6212225650|0.894961 |49 | +----------+----------+----------+----+ only showing top 50 rows
步骤 7: 将 KNN 结果输出到 Hive 表¶
最终得到的相似用户对 DataFrame (knn_pairs_df
) 被写入到一个 Hive 表中,该表按日期和 uid_a
的一部分进行分区,以便于查询和管理。结果会根据 distance_cutoff
进行过滤。
logger.info(40 * "=")
logger.info("** 步骤 7: 将KNN结果输出到hive表 **")
logger.info(40 * "=")
logger.info(f"创建KNN结果表 {args.knn_table}")
spark.sql(f"""
CREATE TABLE IF NOT EXISTS {args.knn_table} (
uid_a STRING COMMENT '用户A的UID',
uid_b STRING COMMENT '用户B的UID',
distance FLOAT COMMENT '用户A和用户B之间的距离',
rank INT COMMENT '用户A和用户B之间的距离的排名'
)
PARTITIONED BY (
dt string COMMENT '日期分区',
pt string COMMENT '数据分区(默认用户尾号分区,0-9)'
)
STORED AS PARQUET
""")
logger.info(f"开始将KNN结果写入表 {args.knn_table}")
knn_pairs_df.createOrReplaceTempView("knn_pairs_df")
spark.sql(f"""
INSERT OVERWRITE TABLE {args.knn_table} PARTITION (dt='{args.knn_dt}', pt)
SELECT uid_a, uid_b, distance, rank, substr(uid_a, -1, 1) AS pt
FROM knn_pairs_df
WHERE distance > {args.distance_cutoff}
""")
2025-05-30 18:13:11.111 | INFO | __main__:<module>:2 - ======================================== 2025-05-30 18:13:11.112 | INFO | __main__:<module>:3 - ** 步骤 7: 将KNN结果输出到hive表 ** 2025-05-30 18:13:11.114 | INFO | __main__:<module>:4 - ======================================== 2025-05-30 18:13:11.115 | INFO | __main__:<module>:5 - 创建KNN结果表 bigdata_vf_long_inte_user_als_knn 2025-05-30 18:13:11.163 | INFO | __main__:<module>:19 - 开始将KNN结果写入表 bigdata_vf_long_inte_user_als_knn
DataFrame[]
logger.info(f"查看KNN结果表 {args.knn_table}")
spark.sql(f"""
SELECT * FROM {args.knn_table} WHERE dt='{args.knn_dt}' AND pt='0' ORDER BY uid_a, rank LIMIT 100
""").show()
2025-05-30 18:15:48.721 | INFO | __main__:<module>:2 - 查看KNN结果表 bigdata_vf_long_inte_user_als_knn
+----------+----------+----------+----+--------+---+ | uid_a| uid_b| distance|rank| dt| pt| +----------+----------+----------+----+--------+---+ |1000000820|6516198400|0.99393636| 0|20250315| 0| |1000000820|6318273430| 0.9929233| 1|20250315| 0| |1000000820|6012572690| 0.9925984| 2|20250315| 0| |1000000820|7316930610|0.99214613| 3|20250315| 0| |1000000820|3212350630|0.98974764| 4|20250315| 0| |1000000820|1919647980| 0.9895695| 5|20250315| 0| |1000000820|7311782830| 0.9890783| 6|20250315| 0| |1000000820|6219967230|0.98898816| 7|20250315| 0| |1000000820|5515029330| 0.9886268| 8|20250315| 0| |1000000820|7318415650|0.98857033| 9|20250315| 0| |1000000820|7412586940| 0.9883963| 10|20250315| 0| |1000000820|7412719450| 0.9880947| 11|20250315| 0| |1000000820|7317085670| 0.9880854| 12|20250315| 0| |1000000820|5812558820| 0.9871978| 13|20250315| 0| |1000000820|6613478540| 0.9864643| 14|20250315| 0| |1000000820|5719920840| 0.9864158| 15|20250315| 0| |1000000820|2813458000|0.98621565| 16|20250315| 0| |1000000820|2115298370| 0.9860913| 17|20250315| 0| |1000000820|6214402040|0.98598206| 18|20250315| 0| |1000000820|7415923620| 0.9856943| 19|20250315| 0| +----------+----------+----------+----+--------+---+ only showing top 20 rows
结论¶
通过将 Faiss 的高性能向量搜索与 Spark 的分布式计算框架相结合,我们为海量数据集中的 TopK 相似用户查找创建了一个可扩展的解决方案。这种方法克服了单机 Faiss 的内存限制以及纯 Spark LSH 实现的性能/调优挑战。 经过线上测试,原先LSH需要运行10小时的任务,现在只需要运行30分钟,提升为20倍。 关键步骤包括:
- 集中训练全局 Faiss 量化器。
- 在 Worker 上使用全局量化器分布式构建部分 Faiss 索引。
- 在 Driver 端将部分索引合并为最终的全局索引。
- 广播最终索引,供 Worker 进行分布式搜索。
该方法为实现大规模相似度搜索提供了一种稳健且高效的途径。
提交代码¶
$SPARK_HOME/bin/spark-submit \
--master yarn \
--deploy-mode cluster \
--conf "spark.app.name=${app_name}" \
--num-executors 50 \
--executor-cores 4 \
--conf spark.dynamicAllocation.maxExecutors=300 \
--driver-memory 10g \
--conf spark.driver.memoryOverhead=10g \
--executor-memory 4g \
--conf spark.executor.memoryOverhead=14g \
--archives viewfs://c9/user_ext/weibo_bigdata_text/yandi/udf/Python-3.9.16-spark.zip#Python \
step_knn.py \
--datasetA_sql "select * from ${factor_table} where dt='${dt}'" \
--datasetB_sql "select * from ${factor_table} where dt='${dt}' and pt='0' and substr(uid, 3, 1) in (0, 1, 2, 3, 4)" \
--knn_table "${knn_table}" \
--knn_dt "${dt}" \
--distance_cutoff 0.9 \
--topK 50 \
--feature_col_name "features" \
--uid_col_name "uid" \
--dimension 24 \
--nlist_for_ivf 1024 \
--faiss_metric_type 0 \
--num_partitions_for_index 50 \
--num_partitions_for_match 8000 \
--training_sample_fraction 0.1 \
--nprobe_for_ivf 70