使用 Faiss 和 Spark 进行可扩展的用户相似度搜索(分布式索引)¶
🎯 目标:在海量数据中寻找相似用户¶
在许多推荐系统中,一个关键任务是找到与给定用户相似的其他用户。如果我们知道用户 A 与用户 B 相似,我们就可以将用户 B 喜欢过的物品推荐给用户 A(这是一种 "u2u" 或用户到用户的方法,进而可以引导出 "u2u2i" 或用户到用户到物品的推荐)。
在本项目中,用户兴趣被表示为24维的数值向量,称为 embedding。我们的任务是,从一个大规模用户集合(数据集 A,约2亿用户)中,为每个用户从另一个大规模用户集合(数据集 B,约1000万用户)中找到与他们最相似的 TopK 个用户。
😫 挑战:速度与规模¶
此前,我们曾尝试使用 Spark 的 BucketedRandomProjectionLSH
(Locality Sensitive Hashing,局部敏感哈希)。虽然 LSH 设计用于近似最近邻搜索,但它本身也带来了一些问题:
- 手动调参:
bucketLength
和numHashTables
等参数需要针对特定数据进行仔细调整。在召回率(找到足够多的相关相似用户)和计算成本之间找到合适的平衡点非常棘手。 - 运行时间长: 对于我们的数据集规模,该过程大约需要10小时,这对于实际应用、迭代开发或频繁更新来说太长了。
这些限制促使我们寻找一种更高效的解决方案。
✨ 解决方案:Faiss 与 Spark 的结合¶
Faiss 由 Facebook AI 开发,是一个用于密集向量高效相似度搜索和聚类的库。它速度极快,但传统上在单机上运行。
- 单机 Faiss 的局限性:
- 数据集 B(1000万用户,每个用户一个24维向量)可能过大,难以构建一个能完全加载到单机内存中的索引。
- 更关键的是,数据集 A(2亿用户)规模太大,无法高效地从单机执行搜索。我们不能简单地加载所有2亿向量然后逐个查询。
为了克服这些限制,我们将 Faiss 的强大功能与 Apache Spark 的分布式计算能力相结合。核心思想是:
- 分布式索引构建 (针对 Dataset B):
- 在 Spark Driver 端使用 Dataset B 的一个样本训练一个“粗量化器”(coarse quantizer,可以理解为数据结构的初步草图)。
- 将这个预训练好的量化器广播到所有 Spark Worker 节点。
- 对 Dataset B 进行分区。每个 Worker 为其分配到的 Dataset B 数据子集构建一个部分 Faiss 索引,这个过程会使用之前广播的量化器。
- 每个 Worker 将其部分索引,以Hive表的形式序列化到HDFS存储。
- 分布式搜索 (针对 Dataset A):
- 将HDFS存储的索引分片(Shards)地址,用
SparkContext.addFiles()
分发到所有Worker节点上。 - 在Worker中通过SparkFiles.get()获取本地目录路径,然后内部循环遍历所有本地索引文件。每次只有一个索引分片需要被加载进内存。
- 对于每个查询用户,我们使用一个Heap(最小堆)来实时维护Top-K结果,只产生很小的时间开销。
- 对 Dataset A 进行分区。每个 Worker 获取其分配到的 Dataset A 用户子集,收集搜索结果到Hive表中。
- 将HDFS存储的索引分片(Shards)地址,用
这种方法通过将索引构建和搜索过程都进行分布式处理,使我们能够处理海量数据集。
🤔 Faiss IndexIVFFlat
工作原理 (一个简单的类比)¶
我们使用的是 Faiss 的 IndexIVFFlat
。让我们用一个图书馆的类比来解析它的工作方式:
整理图书馆 (KMeans 聚类 - "IVF" 部分):
- 想象一下,数据集 B 是一个巨大的书籍收藏库(代表用户向量)。
IndexIVFFlat
首先尝试将它们组织起来。它使用一种聚类算法(如 K-Means)将相似的书籍分到nlist
个区域或“单元格”中。每个单元格由一个“质心”(centroid,该区域书籍的平均代表)来表示。这就是“粗量化”步骤。 - 当我们“训练”索引时,Faiss 实际上是在学习这些质心的位置。
- 想象一下,数据集 B 是一个巨大的书籍收藏库(代表用户向量)。
找到正确的过道 (搜索量化器):
- 现在,你带着一本来自数据集 A 的书(一个查询用户向量)来寻找数据集 B 中的相似书籍。
- Faiss 不会直接将你的书与图书馆中的每一本书进行比较,而是首先将其与所有
nlist
个区域的质心进行比较。 - 它会找出那些质心与你的书最接近的
nprobe
个区域。nprobe
是你设定的一个参数——代表你愿意搜索多少条“过道”。
在这些过道内搜索 ( "Flat" 部分):
- 现在,Faiss 只查看选定的
nprobe
个区域内的书籍。 - 在每个选定的区域内部,它执行一次穷举搜索或称为“Flat”搜索。它会仔细地将你的查询书籍(用户 A 向量)与这些区域中的每一本书(用户 B 向量)进行比较。
- 现在,Faiss 只查看选定的
相似度度量 (
METRIC_INNER_PRODUCT
):- 在比较书籍(向量)时,我们使用
METRIC_INNER_PRODUCT
(内积)。对于归一化向量(长度为1的向量),内积等同于余弦相似度。更高的内积意味着向量更对齐,因此更相似。
- 在比较书籍(向量)时,我们使用
这种两步过程(首先找到有希望的区域,然后仅在这些区域内进行详尽搜索)比直接将查询向量与数据集 B 中的每个向量进行比较要快得多。
🌊 整体工作流程¶
以下是我们在 Spark 上使用分布式 Faiss 方法的流程图:
现在,让我们深入代码实现。
0. 导入与设置¶
导入标准库、初始化 Spark 会话以及 Faiss。
import os
import sys
import argparse
import faiss
import pandas as pd
import numpy as np
import base64
from pyspark.sql import SparkSession
from pyspark.sql.types import FloatType, StringType, StructField, StructType
from pyspark.ml.linalg import Vector
from pyspark.sql.functions import lit, row_number
from pyspark.sql.window import Window
from faiss.contrib.evaluation import knn_intersection_measure
from loguru import logger
__doc__ = """
- 在其他步骤中生成用户向量,落地hive表中,含有`uid bigint`列和`features array<float>`列
- 对datasetA的每个用户,在datasetB中检索,用cosine_similarity度量,返回topK的相似结果
- 索引使用faiss的`IndexIVFFLat`,度量使用`Inner Product`
- 建立索引的过程和匹配过程均使用spark的分布式引擎
"""
# 配置Pandas显示选项
pd.options.display.max_colwidth = 500
def in_notebook():
"""
检查当前代码是否在Jupyter Notebook中运行
"""
try:
return get_ipython().__class__.__name__ == 'ZMQInteractiveShell'
logger.info("当前代码在Jupyter Notebook中运行")
except NameError:
return False
def get_spark():
"""
获取Spark会话,如果不存在则创建一个
"""
if 'spark' not in locals():
spark = SparkSession.builder \
.appName("ALS") \
.config("spark.sql.catalogImplementation", "hive") \
.config("spark.executorEnv.HADOOP_CONF_DIR", "/usr/local/hadoop-2.7.3/etc/hadoop") \
.enableHiveSupport() \
.getOrCreate()
spark.sparkContext.setLogLevel("ERROR")
return spark
1. 配置参数 (Args
类)¶
所有可配置的参数都在这里定义。这包括用于数据集的 SQL 查询、输出表名以及 Faiss 特定的参数。
# ===============================
# 主函数,执行ALS模型训练和评估"
# ===============================
class Args:
# --- I/O and Data Parameters ---
datasetA_sql = "select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (0) and rand() < 0.1"
datasetB_sql = "select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (1)"
knn_table = "bigdata_vf_long_inte_user_als_knn"
knn_dt = "20250315"
feature_col_name = "features"
uid_col_name = "uid"
dimension = 24
# --- Faiss Index Construction ---
index_type = "IVFSQ8" # Options: "IVFFlat", "IVFSQ8", "IVFSQfp16", "IVFPQ", "Flat", "IVFPQR"
nlist_for_ivf = 1024
faiss_metric_type = 1 # 0: METRIC_INNER_PRODUCT, 1: METRIC_L2
# --- Faiss Preprocessing & Quantization ---
pre_normalize = False # 是否对特征向量进行预L2归一化
m_for_pq = 12 # m: 将向量分割为 m 段。dimension 必须能被 m 整除。24维可以被12整除,每段2维。压缩后12字节
nbits_for_pq = 8 # nbits_per_code=8, 表示每个子空间的码本大小为 2^8 = 256
d_for_pca = -1 # PCA矩阵的线性降维到D,通常设置D<dimension
# --- Search and Execution Parameters ---
topK = 50
distance_cutoff = 0.9
num_partitions_for_index = 5 # 索引阶段对datasetB进行分区,生成对应数量的分布式索引
num_partitions_for_match = 100 # 匹配阶段对datasetA进行分区
nprobe_for_ivf = 70 # 表示在搜索过程中搜索的候选分区项数
training_sample_fraction = 0.2 # 索引阶段对datasetB进行采样的样本比例,用于训练quantizer
args = Args()
if not in_notebook():
example = """ # 示例
python step_knn.py \\
--datasetA_sql "select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (0) and rand() < 0.01" \\
--datasetB_sql "select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (1) and rand() < 0.01" \\
--knn_table "bigdata_vf_long_inte_user_als_knn" \\
--knn_dt "20250315" \\
--distance_cutoff 0.9 \\
--topK 50 \\
--feature_col_name "features" \\
--uid_col_name "uid" \\
--dimension 24 \\
--nlist_for_ivf 1024 \\
--faiss_metric_type 0 \\
--num_partitions_for_index 5 \\
--num_partitions_for_match 100 \\
--training_sample_fraction 0.2 \\
--nprobe_for_ivf 70 \\
--index_type "IVFSQ8" \\
--pre_normalize \\
--m_for_pq 12 \\
--nbits_for_pq 8 \\
--d_for_pca -1 \\
"""
parser = argparse.ArgumentParser(
description=__doc__,
epilog=example,
formatter_class=argparse.RawDescriptionHelpFormatter
)
# --- I/O and Data Parameters ---
io_group = parser.add_argument_group("I/O and Data Parameters")
io_group.add_argument("--datasetA_sql", type=str, default=args.datasetA_sql, help="[Input] 数据集A (查询集) 的SQL查询")
io_group.add_argument("--datasetB_sql", type=str, default=args.datasetB_sql, help="[Input] 数据集B (被索引集) 的SQL查询")
io_group.add_argument("--knn_table", type=str, default=args.knn_table, help="[Output] 输出KNN结果的表名")
io_group.add_argument("--knn_dt", type=str, default=args.knn_dt, help="[Output] 输出KNN结果表的日期分区")
io_group.add_argument("--feature_col_name", type=str, default=args.feature_col_name, help="特征向量的列名")
io_group.add_argument("--uid_col_name", type=str, default=args.uid_col_name, help="用户ID的列名")
io_group.add_argument("--dimension", type=int, default=args.dimension, help="特征向量的维度")
# --- Faiss Index Construction ---
index_group = parser.add_argument_group("Faiss Index Construction")
index_group.add_argument("--index_type", type=str, default=args.index_type,
choices=["IVFFlat", "IVFSQ8", "IVFSQfp16", "IVFPQ", "Flat", "IVFPQR"],
help="Faiss索引类型,决定了速度、内存和精度的权衡")
index_group.add_argument("--faiss_metric_type", type=int, default=args.faiss_metric_type, choices=[0, 1],
help="相似度度量类型 (0: METRIC_INNER_PRODUCT for cosine, 1: METRIC_L2 for euclidean)")
index_group.add_argument("--nlist_for_ivf", type=int, default=args.nlist_for_ivf,
help="IVF索引的聚类中心数 (nlist),影响索引的粒度")
# --- Faiss Preprocessing & Quantization ---
quant_group = parser.add_argument_group("Faiss Preprocessing & Quantization")
quant_group.add_argument("--pre_normalize", action='store_true',
help="是否对特征向量进行L2归一化。使用内积时推荐开启")
quant_group.add_argument("--d_for_pca", type=int, default=args.d_for_pca,
help="使用PCA将向量降维至此维度 (-1表示不使用)")
quant_group.add_argument("--m_for_pq", type=int, default=args.m_for_pq,
help="PQ: 将向量分割为m段 (dimension必须能被m整除)")
quant_group.add_argument("--nbits_for_pq", type=int, default=args.nbits_for_pq,
help="PQ: 每段的比特数,码本大小为 2^nbits")
# --- Search and Execution Parameters ---
exec_group = parser.add_argument_group("Search and Execution Parameters")
exec_group.add_argument("--topK", type=int, default=args.topK, help="为每个查询向量返回的Top-K结果数")
exec_group.add_argument("--distance_cutoff", type=float, default=args.distance_cutoff,
help="结果相似度/距离阈值,用于过滤最终结果")
exec_group.add_argument("--nprobe_for_ivf", type=int, default=args.nprobe_for_ivf,
help="IVF搜索时查询的倒排列表数,nprobe越大,召回越高,速度越慢")
exec_group.add_argument("--training_sample_fraction", type=float, default=args.training_sample_fraction,
help="用于训练索引(quantizer)的样本比例")
exec_group.add_argument("--num_partitions_for_index", type=int, default=args.num_partitions_for_index,
help="Spark: 构建索引阶段的分区数")
exec_group.add_argument("--num_partitions_for_match", type=int, default=args.num_partitions_for_match,
help="Spark: 匹配(搜索)阶段的分区数")
# Set default for the boolean flag based on Args class
parser.set_defaults(pre_normalize=args.pre_normalize)
args = parser.parse_args()
spark = get_spark()
spark
SparkSession - hive
2. 加载数据集¶
使用在参数中定义的 SQL 查询加载数据集 A(查询用户)和数据集 B(用于构建索引的用户)。
logger.info(f"加载数据集A: {args.datasetA_sql}")
datasetA = spark.sql(args.datasetA_sql)
datasetA.createOrReplaceTempView("datasetA")
2025-07-31 17:17:50.112 | INFO | __main__:<module>:2 - 加载数据集A: select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (0) and rand() < 0.1
logger.info(f"加载数据集B: {args.datasetB_sql}")
datasetB = spark.sql(args.datasetB_sql)
datasetB.createOrReplaceTempView("datasetB")
2025-07-31 17:17:54.752 | INFO | __main__:<module>:2 - 加载数据集B: select * from bigdata_vf_long_inte_user_als_factor where dt='20250315' and pt='0' and substr(uid,3,1) in (1)
步骤 0: 索引创建工厂函数¶
定义一个工厂函数 create_faiss_index
,用于创建 Faiss 索引。该函数接受数据集 B 的样本、量化器参数和其他配置参数。
目前支持的量化器类型包括:
Flat
:直接存储向量,不进行量化。IVFFlat
:使用倒排文件索引和扁平量化器。IVFPQ
:使用倒排文件索引和乘积量化器。IVFScalarQuantizer
:使用倒排文件索引和标量量化器。
此外,工厂函数也兼容了Faiss的预处理步骤,例如PCA和归一化。
logger.info(40 * "=")
logger.info("** 步骤 0: 索引创建工厂函数 **")
logger.info(40 * "=")
def create_faiss_index(index_type, dimension, nlist, metric_type,
pre_normalize=False, m_for_pq=None, nbits_for_pq=8,
d_for_pca=-1):
"""
根据指定的类型创建Faiss索引对象
"""
# =================================================================
# 步骤 1: 确定预处理链和计算最终维度
# =================================================================
transforms = []
current_dimension = dimension
# 首先应用PCA降维
if d_for_pca and d_for_pca > 0:
logger.info(f"Step 1.1: 添加 PCA 转换 (d={current_dimension} -> {d_for_pca})")
pca_transform = faiss.PCAMatrix(current_dimension, d_for_pca)
transforms.append(pca_transform)
current_dimension = d_for_pca
else:
logger.info("Step 1.1: 跳过 PCA 转换 (未指定 d_for_pca)。")
# 最后,应用归一化(通常在所有降维和旋转之后)
if pre_normalize:
logger.info(f"Step 1.2: 添加 L2 Normalization 转换 (d={current_dimension})")
norm_transform = faiss.NormalizationTransform(current_dimension)
transforms.append(norm_transform)
else:
logger.info("Step 1.2: 跳过 L2 Normalization 转换 (未指定 pre_normalize)。")
logger.info(f"预处理链构建完成。最终到达索引的向量维度为: {current_dimension}")
# =================================================================
# 步骤 2: 使用最终维度创建核心索引
# =================================================================
# 创建用于粗量化的quantizer,它的维度必须是最终维度
quantizer = faiss.IndexFlatIP(current_dimension) if metric_type == faiss.METRIC_INNER_PRODUCT else faiss.IndexFlatL2(current_dimension)
core_index = None # 核心索引
core_index_type = index_type
if core_index_type == "Flat":
core_index = faiss.IndexIDMap2(faiss.IndexFlat(current_dimension, metric_type))
bytes_per_vector = current_dimension * 4
elif core_index_type == "IVFFlat":
core_index = faiss.IndexIVFFlat(quantizer, current_dimension, nlist, metric_type)
bytes_per_vector = current_dimension * 4
elif core_index_type == "IVFSQ8":
core_index = faiss.IndexIVFScalarQuantizer(quantizer, current_dimension, nlist, faiss.ScalarQuantizer.QT_8bit, metric_type)
bytes_per_vector = current_dimension * 1
elif core_index_type == "IVFSQfp16":
core_index = faiss.IndexIVFScalarQuantizer(quantizer, current_dimension, nlist, faiss.ScalarQuantizer.QT_fp16, metric_type)
bytes_per_vector = current_dimension * 2
elif core_index_type == "IVFPQ":
if m_for_pq is None or current_dimension % m_for_pq != 0:
raise ValueError(f"对于 PQ, 最终维度 ({current_dimension}) 必须能被 m ({m_for_pq}) 整除。")
core_index = faiss.IndexIVFPQ(quantizer, current_dimension, nlist, m_for_pq, nbits_for_pq, metric_type)
bytes_per_vector = m_for_pq * (nbits_for_pq / 8)
elif core_index_type == "IVFPQR":
core_index = faiss.IndexIVFPQR(quantizer, current_dimension, nlist, m_for_pq, nbits_for_pq, m_for_pq, nbits_for_pq, metric_type)
else:
raise ValueError(f"Unsupported core_index_type: {core_index_type}")
logger.info(f"Step 2: 核心索引 '{core_index_type}' 已创建 (d={current_dimension})")
logger.info(f" -> 理论上,压缩后每个向量占用内存: {bytes_per_vector:.0f} 字节")
# =================================================================
# 步骤 3: 反向包裹预处理层,构建最终索引
# =================================================================
final_index = core_index
for transform in reversed(transforms):
final_index = faiss.IndexPreTransform(transform, final_index)
logger.info(f"Step 3: 预处理层包裹完成。最终索引构建成功!")
return final_index
2025-07-31 17:17:54.951 | INFO | __main__:<module>:2 - ======================================== 2025-07-31 17:17:54.953 | INFO | __main__:<module>:3 - ** 步骤 0: 索引创建工厂函数 ** 2025-07-31 17:17:54.954 | INFO | __main__:<module>:4 - ========================================
步骤 1: (可选但推荐) Driver 端本地评估 Faiss IVF 召回率¶
在启动完整的分布式作业之前,最好在数据的样本上测试 Faiss 参数(如 nlist
, nprobe
)。这有助于选择在召回率和速度之间提供良好权衡的参数。
函数 evaluate_local_ivf_recall
通过以下方式实现此目的:
- 从数据集 A(查询向量)和数据集 B(索引向量)中采样。
- 构建一个“真实”的精确索引(
IndexFlatIP
)和待评估的IndexIVFFlat
索引。 - 比较它们的搜索结果,以估算
IndexIVFFlat
的召回率。
logger.info(40 * "=")
logger.info("** 步骤 1: 在Driver端测试参数的召回效果 **")
logger.info(40 * "=")
def evaluate_local_ivf_recall(
datasetA,
datasetB,
sample_A_count: int = 10000, # 查询样本数量
sample_B_count: int = 20000, # 构建索引的样本数量
top_k_eval: int = 100, # 检索 Top-K 的数量
ranks_to_evaluate_recall: list = [1, 10, 50], # 需要评估的召回 rank
ndimension: int = 24,
nlist: int = 1024, # IVF 索引的聚类中心数
nprobe: int = 100, # 搜索时查询的倒排列表数
metric_type: int = faiss.METRIC_INNER_PRODUCT, # 距离度量类型
index_type_to_eval: str = "IVFFlat",
pre_normalize: bool = False,
m_for_pq: int = None,
nbits_for_pq: int = 8,
d_for_pca: int = -1,
):
"""
在Driver端评估Faiss IndexIVFFlat索引的召回率。
步骤:
1. 从datasetA中采样查询向量。
2. 从datasetB中采样构建IndexFlatIP(基准)和IndexIVFFlat(待评估)。
3. 对比两个索引的检索结果以计算召回率。
参数:
datasetA (DataFrame): 包含用户特征向量的 Spark DataFrame (用于查询)。
datasetB (DataFrame): 包含用户特征向量的 Spark DataFrame (用于构建索引)。
sample_A_count (int): 从datasetA中抽取的查询向量数。
sample_B_count (int): 从datasetB中抽取的构建索引的向量数。
top_k_eval (int): 每次查询返回的Top-K个结果。
ranks_to_evaluate_recall (List[int]): 计算召回率的rank值。
ndimension (int): 向量维度。
nlist (int): Faiss IVF索引的聚类中心数。
nprobe (int): IVF索引搜索时检查的倒排列表数。
metric_type (int): Faiss距离度量方式,如faiss.METRIC_INNER_PRODUCT或faiss.METRIC_L2。
index_type_to_eval (str): 待评估的索引类型,如"IVFFlat"。
pre_normalize (bool): 是否在训练之前对特征向量进行预处理,如归一化,默认为False。
m_for_pq (int): 如果使用PQ索引,则指定PQ索引的子空间数量,默认为None。
nbits_for_pq (int): PQ索引的每个子空间的码本大小,默认为8。
d_for_pca (int): PCA矩阵的线性降维到D,通常设置D<dimension。
注意:
- 函数会从datasetA和datasetB中抽取指定数量的向量进行测试。
- 函数会创建一个基于IndexFlatIP的索引作为基准,并使用IndexIVFFlat进行评估。
- 函数会返回一个字典,包含每个rank对应的knn交集度量值(即近似召回率)。
返回:
Dict[int, float]: 每个rank对应的knn交集度量值(即近似召回率)。
"""
try:
logger.info("开始评估本地Faiss IVF索引的召回率...")
logger.info(f"评估用到的参数:{locals()}")
# --- Step 1: 从datasetA中提取查询向量 ---
logger.info(f"从datasetA中采样 {sample_A_count} 条数据作为查询向量...")
xq = np.concatenate(datasetA.select(args.feature_col_name).take(sample_A_count))
if xq.size == 0:
logger.warning("未获取到有效的查询向量!")
return {}
# --- Step 2: 从datasetB中提取构建索引所需的向量和UID ---
logger.info(f"从datasetB中采样 {sample_B_count} 条数据用于构建索引...")
db_sample = datasetB.select(args.feature_col_name, args.uid_col_name).take(sample_B_count)
xb = np.array([row[args.feature_col_name] for row in db_sample]).astype('float32')
uids = np.array([row[args.uid_col_name] for row in db_sample]).astype('int64')
if xb.size == 0 or uids.size == 0:
logger.warning("未获取到有效的构建索引所需向量或UID。")
return {}
# --- Step 3: 构建参考索引 (IndexFlatIP) ---
logger.info("正在构建参考索引 (IndexFlat) 用于对比召回结果...")
index_flat_ref = create_faiss_index("Flat", ndimension, nlist, metric_type, pre_normalize, m_for_pq=None, nbits_for_pq=None, d_for_pca=-1)
index_flat_ref.add_with_ids(xb, uids)
logger.info(f"参考索引已构建,包含 {index_flat_ref.ntotal} 个向量。")
# --- Step 4: 执行参考索引的搜索 ---
logger.info(f"使用参考索引进行Top-{top_k_eval}检索...")
Dref, Iref = index_flat_ref.search(xq, top_k_eval)
# --- Step 5: 构建并训练待评估的索引 ---
logger.info(f"正在构建待评估的 {index_type_to_eval} 索引 (nlist={nlist}, nprobe={nprobe})...")
# Use the factory to create the index
index_to_eval = create_faiss_index(index_type_to_eval, ndimension, nlist, metric_type, pre_normalize, m_for_pq, nbits_for_pq, d_for_pca)
if xb.shape[0] < nlist:
logger.warning(f"警告:构建IVF索引的数据点({xb.shape[0]})小于nlist({nlist}), 可能影响训练效果。")
index_to_eval.train(xb) # 训练索引 (trains coarse and scalar quantizers if applicable)
index_to_eval.add_with_ids(xb, uids) # 添加向量及其对应UID
# 设置nprobe参数
ps = faiss.ParameterSpace()
ps.set_index_parameter(index_to_eval, "nprobe", nprobe)
# index_to_eval.nprobe = nprobe
logger.info(f"{index_type_to_eval} 已构建并训练完成,当前包含 {index_to_eval.ntotal} 个向量。")
# --- Step 6: 执行待评估索引的搜索 ---
logger.info(f"使用待评估索引进行Top-{top_k_eval}检索...")
Dnew, Inew = index_to_eval.search(xq, top_k_eval)
# --- Step 7: 计算召回率 ---
recall_results = {}
logger.info("开始计算召回率指标...")
for rank in ranks_to_evaluate_recall:
if rank > top_k_eval:
logger.warning(f"Rank {rank} 超出Top-K范围 {top_k_eval},跳过此rank的评估。")
continue
try:
recall_value = knn_intersection_measure(Inew[:, :rank], Iref[:, :rank])
recall_results[rank] = recall_value
logger.info(f"Recall@{rank}: {recall_value:.4f}")
except Exception as e:
logger.error(f"在Rank@{rank}计算召回率时发生错误: {e}")
logger.info("召回率评估完成。")
return recall_results
except Exception as e:
logger.error(f"执行本地IVF召回率评估失败: {e}")
import traceback
traceback.print_exc()
return {}
2025-07-31 17:17:56.425 | INFO | __main__:<module>:2 - ======================================== 2025-07-31 17:17:56.426 | INFO | __main__:<module>:3 - ** 步骤 1: 在Driver端测试参数的召回效果 ** 2025-07-31 17:17:56.428 | INFO | __main__:<module>:4 - ========================================
recall_metrics = evaluate_local_ivf_recall(
datasetA=datasetA,
datasetB=datasetB,
sample_A_count=5000,
sample_B_count=60000,
top_k_eval=100,
ranks_to_evaluate_recall=[1, 10, 50, 100],
ndimension=args.dimension,
nlist=args.nlist_for_ivf,
nprobe=args.nprobe_for_ivf,
metric_type=args.faiss_metric_type,
index_type_to_eval=args.index_type,
pre_normalize=args.pre_normalize,
m_for_pq=args.m_for_pq,
nbits_for_pq=args.nbits_for_pq,
d_for_pca=args.d_for_pca
)
logger.info(f"索引 {args.index_type} 召回率评估结果: {recall_metrics}")
2025-07-31 17:17:59.135 | INFO | __main__:evaluate_local_ivf_recall:57 - 开始评估本地Faiss IVF索引的召回率... 2025-07-31 17:17:59.161 | INFO | __main__:evaluate_local_ivf_recall:58 - 评估用到的参数:{'datasetA': DataFrame[uid: string, code: int, features: array<float>, norm: float, dt: string, pt: string], 'datasetB': DataFrame[uid: string, code: int, features: array<float>, norm: float, dt: string, pt: string], 'sample_A_count': 5000, 'sample_B_count': 60000, 'top_k_eval': 100, 'ranks_to_evaluate_recall': [1, 10, 50, 100], 'ndimension': 24, 'nlist': 1024, 'nprobe': 70, 'metric_type': 1, 'index_type_to_eval': 'IVFSQ8', 'pre_normalize': False, 'm_for_pq': 12, 'nbits_for_pq': 8, 'd_for_pca': -1} 2025-07-31 17:17:59.162 | INFO | __main__:evaluate_local_ivf_recall:61 - 从datasetA中采样 5000 条数据作为查询向量... 2025-07-31 17:18:07.850 | INFO | __main__:evaluate_local_ivf_recall:68 - 从datasetB中采样 60000 条数据用于构建索引... 2025-07-31 17:18:15.440 | INFO | __main__:evaluate_local_ivf_recall:78 - 正在构建参考索引 (IndexFlat) 用于对比召回结果... 2025-07-31 17:18:15.441 | INFO | __main__:create_faiss_index:25 - Step 1.1: 跳过 PCA 转换 (未指定 d_for_pca)。 2025-07-31 17:18:15.442 | INFO | __main__:create_faiss_index:33 - Step 1.2: 跳过 L2 Normalization 转换 (未指定 pre_normalize)。 2025-07-31 17:18:15.442 | INFO | __main__:create_faiss_index:35 - 预处理链构建完成。最终到达索引的向量维度为: 24 2025-07-31 17:18:15.443 | INFO | __main__:create_faiss_index:67 - Step 2: 核心索引 'Flat' 已创建 (d=24) 2025-07-31 17:18:15.445 | INFO | __main__:create_faiss_index:68 - -> 理论上,压缩后每个向量占用内存: 96 字节 2025-07-31 17:18:15.446 | INFO | __main__:create_faiss_index:76 - Step 3: 预处理层包裹完成。最终索引构建成功! 2025-07-31 17:18:15.459 | INFO | __main__:evaluate_local_ivf_recall:81 - 参考索引已构建,包含 60000 个向量。 2025-07-31 17:18:15.460 | INFO | __main__:evaluate_local_ivf_recall:84 - 使用参考索引进行Top-100检索... 2025-07-31 17:18:16.251 | INFO | __main__:evaluate_local_ivf_recall:88 - 正在构建待评估的 IVFSQ8 索引 (nlist=1024, nprobe=70)... 2025-07-31 17:18:16.252 | INFO | __main__:create_faiss_index:25 - Step 1.1: 跳过 PCA 转换 (未指定 d_for_pca)。 2025-07-31 17:18:16.253 | INFO | __main__:create_faiss_index:33 - Step 1.2: 跳过 L2 Normalization 转换 (未指定 pre_normalize)。 2025-07-31 17:18:16.254 | INFO | __main__:create_faiss_index:35 - 预处理链构建完成。最终到达索引的向量维度为: 24 2025-07-31 17:18:16.255 | INFO | __main__:create_faiss_index:67 - Step 2: 核心索引 'IVFSQ8' 已创建 (d=24) 2025-07-31 17:18:16.256 | INFO | __main__:create_faiss_index:68 - -> 理论上,压缩后每个向量占用内存: 24 字节 2025-07-31 17:18:16.257 | INFO | __main__:create_faiss_index:76 - Step 3: 预处理层包裹完成。最终索引构建成功! 2025-07-31 17:18:18.095 | INFO | __main__:evaluate_local_ivf_recall:102 - IVFSQ8 已构建并训练完成,当前包含 60000 个向量。 2025-07-31 17:18:18.097 | INFO | __main__:evaluate_local_ivf_recall:105 - 使用待评估索引进行Top-100检索... 2025-07-31 17:18:18.187 | INFO | __main__:evaluate_local_ivf_recall:110 - 开始计算召回率指标... 2025-07-31 17:18:18.312 | INFO | __main__:evaluate_local_ivf_recall:119 - Recall@1: 0.9680 2025-07-31 17:18:18.438 | INFO | __main__:evaluate_local_ivf_recall:119 - Recall@10: 0.9803 2025-07-31 17:18:18.585 | INFO | __main__:evaluate_local_ivf_recall:119 - Recall@50: 0.9866 2025-07-31 17:18:18.770 | INFO | __main__:evaluate_local_ivf_recall:119 - Recall@100: 0.9879 2025-07-31 17:18:18.771 | INFO | __main__:evaluate_local_ivf_recall:123 - 召回率评估完成。 2025-07-31 17:18:18.818 | INFO | __main__:<module>:19 - 索引 IVFSQ8 召回率评估结果: {1: 0.968, 10: 0.9803, 50: 0.986564, 100: 0.987892}
步骤 2: Driver 端训练全局粗量化器¶
“粗量化器”(coarse quantizer,本质上是 IndexIVFFlat
的 K-Means 质心)需要被训练。为了确保所有 Worker 使用相同的数据“地图”,我们在 Spark Driver 端使用 Dataset B 的一个样本来集中训练这个量化器。
函数 train_global_coarse_quantizer_on_driver
:
- 将 Dataset B 的一部分数据采样到 Driver 端。
- 使用这些样本训练一个临时的
IndexIVFFlat
。 - 提取并序列化训练好的粗量化器(其中包含聚类质心)。
logger.info(40 * "=")
logger.info(" ** 步骤 2: Driver端训练全局索引 **")
logger.info(40 * "=")
def train_global_index_shell_on_driver(
dataset: pd.DataFrame, # PySpark DataFrame for datasetB
d_dimension: int = 24,
nlist_global: int = 50, # 全局的nlist
metric_type_global: int = 0,
index_type_global: str = "IVFFlat", # 全局索引类型
training_sample_fraction: float = 0.2,
feature_col_name: str = "features",
pre_normalize: bool = False,
m_for_pq: int = None,
nbits_for_pq: int = 8,
d_for_pca: int = -1
):
"""
在Driver端训练用于Faiss IndexIVFFlat索引的全局粗量化器(coarse quantizer)。
此过程包括从datasetB中采样数据、将特征向量转换为NumPy数组、使用指定参数创建并训练一个临时IndexIVFFlat,
然后提取并序列化其内部的粗量化器以供后续在所有Worker上共享。
参数:
dataset (pd.DataFrame): 输入数据集 datasetB,包含用于训练粗量化器的特征向量。
d_dimension (int): 特征向量的维度,默认值为24。
nlist_global (int): Faiss IVF索引的聚类中心数(即nlist),也是粗量化器训练的目标数量,默认值为50。
metric_type_global (int): 距离度量类型,如 faiss.METRIC_INNER_PRODUCT 或 faiss.METRIC_L2,默认为0。
index_type_global (str): 全局索引类型,如 "IVFFlat", "IVFSQ8" 等。
training_sample_fraction (float): 从datasetB中采样的比例,用于控制训练数据量,默认为0.2。
feature_col_name (str): 数据集中特征列的名称,默认为"features"。
pre_normalize (bool): 是否在训练之前对特征向量进行预处理,如归一化,默认为False。
m_for_pq (int): 如果使用PQ索引,则指定PQ索引的子空间数量,默认为None。
nbits_for_pq (int): PQ索引的每个子空间的码本大小,默认为8。
d_for_pca (int): PCA矩阵的线性降维到D,通常设置D<dimension。
返回:
bytes: 序列化后的训练好的粗量化器对象,可用于广播并在各个Worker上构建一致的IVF索引结构。
异常:
ValueError: 如果采样后的训练数据为空或无效,则抛出异常。
"""
logger.info(f"开始在Driver端为 {index_type_global} 训练全局索引外壳 (nlist={nlist_global})...")
logger.info(f"从 datasetB 采样 {training_sample_fraction*100}% 的数据用于训练...")
# 采样数据到Driver端
# 注意:如果采样比例过大或datasetB本身巨大,toPandas() 可能导致Driver OOM
# 需要根据实际情况调整采样比例,确保Driver能够处理。
sample_pd_df = dataset.select(feature_col_name)\
.sample(False, training_sample_fraction)\
.toPandas()
if sample_pd_df.empty:
raise ValueError("为Faiss粗量化器训练的样本数据为空!请检查采样比例或源数据。")
# 将Pandas Series中的Spark Vector转换为NumPy数组
training_vectors_list = []
for vec_obj in sample_pd_df[feature_col_name]:
if vec_obj is not None:
if isinstance(vec_obj, Vector):
training_vectors_list.append(vec_obj.toArray())
elif isinstance(vec_obj, (list, tuple)):
training_vectors_list.append(vec_obj)
if not training_vectors_list:
raise ValueError("转换后的训练向量列表为空!")
training_vectors_np = np.array(training_vectors_list).astype('float32')
if training_vectors_np.shape[0] < nlist_global:
logger.warning(f"警告: 训练样本数 ({training_vectors_np.shape[0]}) 小于 NLIST_GLOBAL ({nlist_global})。"
"Faiss IVF训练质量可能受影响或失败。建议增加采样数据或减小NLIST_GLOBAL。")
if training_vectors_np.shape[0] == 0:
raise ValueError("没有有效的训练向量来训练粗量化器。")
logger.info(f"使用 {training_vectors_np.shape[0]} 个向量 (维度={d_dimension}) 训练粗量化器...")
# 1. 使用工厂函数创建索引
global_index_shell = create_faiss_index(index_type_global, d_dimension, nlist_global, metric_type_global, pre_normalize, m_for_pq, nbits_for_pq, d_for_pca)
# 2. 训练索引(这将训练所有需要训练的部分,如IVF聚类和SQ范围)
global_index_shell.train(training_vectors_np)
logger.info("全局索引外壳训练完毕。")
# 3. 序列化这个训练好的、但仍然是空的索引
return faiss.serialize_index(global_index_shell)
serialized_global_index_shell = None
try:
serialized_global_index_shell = train_global_index_shell_on_driver(
dataset=datasetB,
d_dimension=args.dimension,
nlist_global=args.nlist_for_ivf,
metric_type_global=args.faiss_metric_type,
index_type_global=args.index_type,
training_sample_fraction=args.training_sample_fraction,
feature_col_name=args.feature_col_name,
pre_normalize=args.pre_normalize,
m_for_pq=args.m_for_pq,
nbits_for_pq=args.nbits_for_pq,
d_for_pca=args.d_for_pca,
)
except Exception as e_train:
logger.error(f"Driver端训练全局索引外壳失败: {e_train}")
2025-07-31 17:18:19.470 | INFO | __main__:<module>:2 - ======================================== 2025-07-31 17:18:19.472 | INFO | __main__:<module>:3 - ** 步骤 2: Driver端训练全局粗量化器 ** 2025-07-31 17:18:19.473 | INFO | __main__:<module>:4 - ======================================== 2025-07-31 17:18:20.817 | INFO | __main__:train_global_index_shell_on_driver:44 - 开始在Driver端为 IVFSQ8 训练全局索引外壳 (nlist=1024)... 2025-07-31 17:18:20.818 | INFO | __main__:train_global_index_shell_on_driver:45 - 从 datasetB 采样 20.0% 的数据用于训练... 2025-07-31 17:18:37.187 | INFO | __main__:train_global_index_shell_on_driver:77 - 使用 346103 个向量 (维度=24) 训练粗量化器... 2025-07-31 17:18:37.188 | INFO | __main__:create_faiss_index:25 - Step 1.1: 跳过 PCA 转换 (未指定 d_for_pca)。 2025-07-31 17:18:37.189 | INFO | __main__:create_faiss_index:33 - Step 1.2: 跳过 L2 Normalization 转换 (未指定 pre_normalize)。 2025-07-31 17:18:37.189 | INFO | __main__:create_faiss_index:35 - 预处理链构建完成。最终到达索引的向量维度为: 24 2025-07-31 17:18:37.190 | INFO | __main__:create_faiss_index:67 - Step 2: 核心索引 'IVFSQ8' 已创建 (d=24) 2025-07-31 17:18:37.191 | INFO | __main__:create_faiss_index:68 - -> 理论上,压缩后每个向量占用内存: 24 字节 2025-07-31 17:18:37.192 | INFO | __main__:create_faiss_index:76 - Step 3: 预处理层包裹完成。最终索引构建成功! 2025-07-31 17:18:42.691 | INFO | __main__:train_global_index_shell_on_driver:85 - 全局索引外壳训练完毕。
步骤 3: 广播序列化的全局粗量化器¶
训练好并序列化后的粗量化器随后被广播到所有 Spark Worker 节点。这使得每个 Worker 都可以使用相同的质心集合来初始化其本地的 IndexIVFFlat
。
logger.info(40 * "=")
logger.info(" ** 步骤 3: 广播序列化的全局索引外壳 **")
logger.info(40 * "=")
broadcasted_index_shell_bytes = spark.sparkContext.broadcast(serialized_global_index_shell)
logger.info("全局索引外壳已广播。")
2025-07-31 17:19:12.268 | INFO | __main__:<module>:2 - ======================================== 2025-07-31 17:19:12.270 | INFO | __main__:<module>:3 - ** 步骤 3: 广播序列化的全局索引外壳 ** 2025-07-31 17:19:12.271 | INFO | __main__:<module>:4 - ======================================== 2025-07-31 17:19:12.278 | INFO | __main__:<module>:7 - 全局索引外壳已广播。
步骤 4: Worker 端填充部分 IVF 索引¶
数据集 B 被重新分区,每个分区由一个 Worker 处理。
函数 worker_populate_partial_index
在每个 Worker 上运行:
- 反序列化广播过来的全局粗量化器。
- 将其分配到的 Dataset B 分区中的向量(及其对应的 UID)添加到这个本地索引中。
- 序列化这个已填充数据的本地索引(一个
np.array
),进行Base64编码,并解码为ascii字符串 - 将文本化的索引,以Hive表的形式写入HDFS,形成Distributed Index Shards
logger.info(40 * "=")
logger.info("** 步骤 4: Worker填充部分索引**")
logger.info(40 * "=")
def worker_populate_partial_index(
p_idx: int,
iterator: iter,
broadcasted_ser_index_shell_bytes: bytes, # 序列化后的全局索引外壳
d_dimension: int,
feature_col_name: str,
uid_col_name: str
):
"""
在每个worker上,反序列化全局索引外壳,并用本地数据填充它。
参数:
p_idx: 当前Worker的索引
iterator: 迭代器,包含每个Worker的数据
broadcasted_ser_index_shell_bytes: 广播的全局索引外壳
d_dimension: 特征向量的维度
feature_col_name: 特征列的名称
uid_col_name: 用户ID列的名称
返回:
序列化后的局部索引列表
"""
import sys
import faiss
import numpy as np
from pyspark.ml.linalg import Vector
local_vectors_list = []
local_uids_int64_list = []
for row in iterator:
if row[feature_col_name] is not None and row[uid_col_name] is not None:
try:
uid_int64 = np.int64(str(row[uid_col_name]))
feature_data = row[feature_col_name]
current_vector_py_list = None
if isinstance(feature_data, Vector):
if feature_data.size == d_dimension:
current_vector_py_list = feature_data.toArray().tolist()
elif isinstance(feature_data, (list, tuple)):
if len(feature_data) == d_dimension:
temp_list = []; valid_list = True
for x in feature_data:
try: temp_list.append(float(x))
except: valid_list = False; break
if valid_list: current_vector_py_list = temp_list
if current_vector_py_list:
local_uids_int64_list.append(uid_int64)
local_vectors_list.append(current_vector_py_list)
except ValueError: pass
except Exception: pass
if not local_vectors_list:
return []
local_vectors_np = np.array(local_vectors_list).astype('float32')
local_uids_np_int64 = np.array(local_uids_int64_list)
if local_vectors_np.ndim == 1:
if local_vectors_np.shape[0] == d_dimension and d_dimension > 0:
local_vectors_np = local_vectors_np.reshape(1, -1)
else:
return []
if local_vectors_np.shape[0] == 0 or local_vectors_np.shape[0] != local_uids_np_int64.shape[0]:
return []
try:
# 1. Worker反序列化广播过来的空的、但已训练的索引外壳
local_index = faiss.deserialize_index(broadcasted_ser_index_shell_bytes)
# 2. 将当前分区的向量添加到这个本地索引中
local_index.add_with_ids(local_vectors_np, local_uids_np_int64)
# 3. 序列化这个填充了数据的本地索引
serialized_bytes_to_return = faiss.serialize_index(local_index)
# 4. 进行Base64编码,并解码为ascii字符串以便存入Hive
serialized_str_to_return = base64.b64encode(serialized_bytes_to_return).decode('ascii')
ntotal_to_return = local_index.ntotal
# 使用 .nbytes 来获取真实的内存占用
size_in_bytes = serialized_bytes_to_return.nbytes
return [(p_idx, serialized_str_to_return, ntotal_to_return, size_in_bytes)]
except Exception as e_faiss_worker_op:
sys.stderr.write(f"P{p_idx}: Error in worker_populate_partial_index: {e_faiss_worker_op}\n")
return []
2025-07-31 17:19:14.771 | INFO | __main__:<module>:2 - ======================================== 2025-07-31 17:19:14.772 | INFO | __main__:<module>:3 - ** 步骤 4: Worker填充部分索引** 2025-07-31 17:19:14.773 | INFO | __main__:<module>:4 - ========================================
app_id = spark.sparkContext.applicationId.replace("-", "_")
temp_hive_table = f"weibo_bigdata_tmp.yandi_{args.knn_table}_{args.knn_dt}_{app_id}"
logger.info(f"创建临时Hive表: {temp_hive_table}")
spark.sql(f"""
CREATE TABLE IF NOT EXISTS {temp_hive_table} (
-- 分区ID,标识当前Worker的索引分片
shard_id INT,
-- 列类型为STRING,用于存储Base64编码后的索引
encoded_shard STRING,
-- 分片中向量的数量
num_vectors INT,
-- 分片的内存占用大小(字节数)
shard_size_bytes BIGINT
)
ROW FORMAT DELIMITED FIELDS TERMINATED BY '\t'
LINES TERMINATED BY '\n'
STORED AS TEXTFILE
""")
2025-07-31 17:19:16.531 | INFO | __main__:<module>:4 - 创建临时Hive表: weibo_bigdata_tmp.yandi_bigdata_vf_long_inte_user_als_knn_20250315_local_1753953461155
DataFrame[]
logger.info(f"将在 {args.num_partitions_for_index} 个分区上并行填充局部索引...")
try:
collected_worker_outputs = datasetB.rdd.repartition(args.num_partitions_for_index).mapPartitionsWithIndex(
lambda p_idx, iterator: worker_populate_partial_index(
p_idx, iterator,
broadcasted_index_shell_bytes.value,
args.dimension,
args.feature_col_name,
args.uid_col_name
)
)
spark.createDataFrame(
collected_worker_outputs,
schema=["shard_id", "encoded_shard", "num_vectors", "shard_size_bytes"]
).write.mode("overwrite").insertInto(temp_hive_table)
logger.info(f"将索引分片写入临时Hive表 {temp_hive_table} 完成")
except Exception as e:
logger.error(f"Error during worker index population: {e}")
import traceback
traceback.print_exc()
2025-07-31 17:19:20.225 | INFO | __main__:<module>:2 - 将在 5 个分区上并行填充局部索引... 2025-07-31 17:19:59.807 | INFO | __main__:<module>:17 - 将索引分片写入临时Hive表 weibo_bigdata_tmp.yandi_bigdata_vf_long_inte_user_als_knn_20250315_local_1753953461155 完成
logger.info("--- 开始检查序列化索引的内存占用 ---")
sizes_df = spark.sql(f"""
SELECT shard_id, num_vectors, shard_size_bytes / 1024 / 1024 AS shard_size_mb
FROM {temp_hive_table}
""").toPandas()
for index, row in sizes_df.iterrows():
shard_id = row['shard_id']
num_vectors = row['num_vectors']
shard_size_mb = row['shard_size_mb']
logger.info(f"- 分片 {int(shard_id)}: 向量数量={int(num_vectors)}, 内存占用={shard_size_mb:.2f} MB")
logger.info(f"--- 检查完毕 ---")
logger.info(f"所有有效索引分片序列化后总大小: {sizes_df['shard_size_mb'].sum():.2f} MB")
2025-07-31 17:19:59.902 | INFO | __main__:<module>:2 - --- 开始检查序列化索引的内存占用 --- 2025-07-31 17:20:01.420 | INFO | __main__:<module>:12 - - 分片 2: 向量数量=345397, 内存占用=10.64 MB 2025-07-31 17:20:01.422 | INFO | __main__:<module>:12 - - 分片 0: 向量数量=345396, 内存占用=10.64 MB 2025-07-31 17:20:01.423 | INFO | __main__:<module>:12 - - 分片 3: 向量数量=345373, 内存占用=10.64 MB 2025-07-31 17:20:01.423 | INFO | __main__:<module>:12 - - 分片 4: 向量数量=345394, 内存占用=10.64 MB 2025-07-31 17:20:01.424 | INFO | __main__:<module>:12 - - 分片 1: 向量数量=345390, 内存占用=10.64 MB 2025-07-31 17:20:01.425 | INFO | __main__:<module>:14 - --- 检查完毕 --- 2025-07-31 17:20:01.427 | INFO | __main__:<module>:15 - 所有有效索引分片序列化后总大小: 53.21 MB
步骤 5: 将索引分片打包并广播¶
Driver 获取Hive表的HDFS路径,并通过 addFile 注册分发到每个Worker的本地目录。Worker 可以通过 SparkFiles.get()
获取这些索引分片的本地路径。
logger.info(40 * "=")
logger.info("** 步骤 5: 将索引分片打包并广播**")
logger.info(40 * "=")
# 执行DESCRIBE命令并获取结果
desc_df = spark.sql(f"DESCRIBE FORMATTED {temp_hive_table}")
# 过滤出Location那一行
location_row = desc_df.filter("col_name = 'Location'").first()
# 提取出HDFS路径
hdfs_shards_path = location_row["data_type"]
logger.info(f"成功获取到Hive表的数据位置: {hdfs_shards_path}")
# 使用addFile分发这个HDFS目录
sc = spark.sparkContext
sc.addFile(hdfs_shards_path, recursive=True)
hdfs_shards_basename = os.path.basename(hdfs_shards_path)
logger.info(f"✅ 已通过addFile分发HDFS目录 '{hdfs_shards_basename}'")
2025-07-31 17:20:01.461 | INFO | __main__:<module>:2 - ======================================== 2025-07-31 17:20:01.463 | INFO | __main__:<module>:3 - ** 步骤 5: 将索引分片打包并广播 ** 2025-07-31 17:20:01.464 | INFO | __main__:<module>:4 - ======================================== 2025-07-31 17:20:01.724 | INFO | __main__:<module>:12 - 成功获取到Hive表的数据位置: viewfs://c9/user_ext/weibo_bigdata_vf/warehouse/yandi_bigdata_vf_long_inte_user_als_knn_20250315_local_1753953461155 2025-07-31 17:20:04.004 | INFO | __main__:<module>:18 - ✅ 已通过addFile分发HDFS目录 'yandi_bigdata_vf_long_inte_user_als_knn_20250315_local_1753953461155'
步骤 6: 执行并行搜索并收集结果¶
函数 worker_search_and_collect_results
在每个 Worker 上运行:
- 依次轮训本地目录中的索引分片路径,每次加载一个Sharded Index到内存中
- 对于每个查询用户(来自 Dataset A),使用
IndexIVFFlat
执行搜索。 - 对于不同索引分片的搜索结果,每个用户都使用最小堆(Heap)维护合并后的 Top-K 结果。
- 将所有结果写入到一个 Hive 表中。
logger.info(40 * "=")
logger.info("** 步骤 6: 执行并行搜索并收集结果 **")
logger.info(40 * "=")
def parallel_search_with_heap_worker(
partition_id: int,
iterator_rows,
broadcast_dir_info: str,
NPROBE_SEARCH: int,
TOP_K_SEARCH: int,
METRIC_TYPE: int,
dimension: int,
feature_col_name: str,
uid_col_name: str
):
"""
Worker函数,它从一个通过Spark分发的HDFS目录中读取索引分片,
并使用这些索引来处理一个datasetA分区。
它在内部对每个查询用户使用一个heap(堆)来实时维护Top-K结果。
"""
import faiss
import base64
import numpy as np
import heapq
from pyspark import SparkFiles
# ------------------- 数据准备 -------------------
# 一次性将分区内所有查询向量加载到内存
query_vectors_list = []
query_uids_list = []
for row in iterator_rows:
try:
feature_data = row[feature_col_name]
if feature_data is not None and len(feature_data) == dimension:
query_uids_list.append(str(row[uid_col_name]))
# 直接使用list, 避免早期转换为numpy增加内存开销
query_vectors_list.append(feature_data)
except Exception as e:
sys.stderr.write(f"P{partition_id}: 跳过一行数据,原因: {e}\n")
if not query_uids_list:
return []
xq_part = np.array(query_vectors_list, dtype='float32')
num_queries_in_part = xq_part.shape[0]
# ------------------- Heap 初始化 -------------------
query_heaps = {i: [] for i in range(num_queries_in_part)}
is_max_heap = (METRIC_TYPE == faiss.METRIC_L2)
# ------------------- 拉取远程目录到本地 -------------------
local_shards_dir_path = SparkFiles.get(broadcast_dir_info)
part_files = [f for f in os.listdir(local_shards_dir_path) if not f.startswith('.')]
if not part_files:
sys.stderr.write(f"P{partition_id}: 分发的目录 '{local_shards_dir_path}' 中没有找到索引分片文件.\n")
return []
# ------------------- 迭代搜索 -------------------
for part_file in part_files:
try:
full_part_path = os.path.join(local_shards_dir_path, part_file)
with open(full_part_path, 'r') as f:
for line in f:
if not line.strip(): continue
# 解析每一行,提取分片ID、序列化字符串、向量数量和分片大小
fields = line.rstrip("\r\n").split("\t")
shard_id, serialized_str, num_vectors, shard_size_bytes = fields
# 读取一个索引分片文件的字节流到内存
serialized_bytes = base64.b64decode(serialized_str)
# b. 将bytes转换为numpy数组以供Faiss使用
buffer_as_np_array = np.frombuffer(serialized_bytes, dtype='uint8')
# c. 解序列化
index_shard = faiss.deserialize_index(buffer_as_np_array)
ps = faiss.ParameterSpace()
ps.set_index_parameter(index_shard, 'nprobe', NPROBE_SEARCH)
# d. 搜索
D_results, I_results = index_shard.search(xq_part, TOP_K_SEARCH)
# e. 更新Heaps
for i in range(num_queries_in_part):
current_heap = query_heaps[i]
for k in range(TOP_K_SEARCH):
dist = D_results[i][k]
uid_b = I_results[i][k]
if uid_b == -1: break
heap_item = (-dist, uid_b) if is_max_heap else (dist, uid_b)
if len(current_heap) < TOP_K_SEARCH:
heapq.heappush(current_heap, heap_item)
else:
heapq.heappushpop(current_heap, heap_item)
del index_shard, serialized_bytes, buffer_as_np_array
except Exception as e:
sys.stderr.write(f"P{partition_id}: 处理文件 '{part_file}' 时出错: {e}\n")
continue
# ------------------- 格式化输出 -------------------
final_results = []
for i in range(num_queries_in_part):
uid_a = query_uids_list[i]
heap = query_heaps[i]
for item in heap:
dist = -item[0] if is_max_heap else item[0]
uid_b = str(item[1])
final_results.append((uid_a, uid_b, float(dist)))
return final_results
2025-07-31 17:20:04.103 | INFO | __main__:<module>:2 - ======================================== 2025-07-31 17:20:04.104 | INFO | __main__:<module>:3 - ** 步骤 6: 执行并行搜索并收集结果 ** 2025-07-31 17:20:04.105 | INFO | __main__:<module>:4 - ========================================
# 定义结果RDD的Schema
# 注意:现在不需要rank和shard_id了,因为heap已经处理了排序
knn_results_schema = StructType([
StructField("uid_a", StringType(), True),
StructField("uid_b", StringType(), True),
StructField("distance", FloatType(), True)
])
# 执行唯一的Spark Job
knn_results_rdd = datasetA.rdd.repartition(args.num_partitions_for_match).mapPartitionsWithIndex(
lambda p_idx, iterator: parallel_search_with_heap_worker(
partition_id=p_idx,
iterator_rows=iterator,
broadcast_dir_info=hdfs_shards_basename,
NPROBE_SEARCH=args.nprobe_for_ivf,
TOP_K_SEARCH=args.topK,
METRIC_TYPE=args.faiss_metric_type,
dimension=args.dimension,
feature_col_name=args.feature_col_name,
uid_col_name=args.uid_col_name
)
)
# 将RDD转换为DataFrame
all_pairs_df = spark.createDataFrame(knn_results_rdd, schema=knn_results_schema)
步骤 7: 将 KNN 结果输出到 Hive 表¶
最终得到的相似用户对 DataFrame (knn_pairs_df
) 被写入到一个 Hive 表中,该表按日期和 uid_a
的一部分进行分区,以便于查询和管理。结果会根据 distance_cutoff
进行过滤。
logger.info(40 * "=")
logger.info("** 步骤 7: 将最终结果输出到Hive表 **")
logger.info(40 * "=")
logger.info(f"创建或准备最终结果表: {args.knn_table}")
spark.sql(f"""
CREATE TABLE IF NOT EXISTS {args.knn_table} (
uid_a STRING,
uid_b STRING,
distance FLOAT,
rank INT
)
PARTITIONED BY (dt STRING, pt STRING)
STORED AS PARQUET
""")
# 根据度量类型确定排序方式
order_col = "distance"
if args.faiss_metric_type == faiss.METRIC_L2:
# L2距离,升序 (距离越小越好)
window_spec = Window.partitionBy("uid_a").orderBy(all_pairs_df[order_col].asc())
else:
# 内积,降序 (值越大越好)
window_spec = Window.partitionBy("uid_a").orderBy(all_pairs_df[order_col].desc())
# 添加最终的rank列
final_df = all_pairs_df.withColumn("rank", row_number().over(window_spec))
# 应用距离阈值过滤
if args.distance_cutoff is not None:
if args.faiss_metric_type == faiss.METRIC_L2:
final_df = final_df.filter(f"distance < {args.distance_cutoff}")
else:
final_df = final_df.filter(f"distance > {args.distance_cutoff}")
# 增加分区列并写入Hive
final_df.filter(f"rank <= {args.topK}") \
.filter("uid_a <> uid_b") \
.withColumn("dt", lit(args.knn_dt)) \
.withColumn("pt", final_df.uid_a.substr(-1, 1)) \
.write.mode("overwrite") \
.insertInto(args.knn_table)
logger.info(f"KNN结果已成功写入表: {args.knn_table} (分区 dt={args.knn_dt})")
2025-07-31 17:20:04.645 | INFO | __main__:<module>:2 - ======================================== 2025-07-31 17:20:04.647 | INFO | __main__:<module>:3 - ** 步骤 7: 将最终结果输出到Hive表 ** 2025-07-31 17:20:04.648 | INFO | __main__:<module>:4 - ======================================== 2025-07-31 17:20:04.649 | INFO | __main__:<module>:6 - 创建或准备最终结果表: bigdata_vf_long_inte_user_als_knn 2025-07-31 17:21:45.779 | INFO | __main__:<module>:45 - KNN结果已成功写入表: bigdata_vf_long_inte_user_als_knn (分区 dt=20250315)
logger.info("任务完成!显示部分结果:")
spark.sql(f"SELECT * FROM {args.knn_table} WHERE dt='{args.knn_dt}' ORDER BY uid_a, rank LIMIT 20").show()
2025-07-31 17:21:45.864 | INFO | __main__:<module>:2 - 任务完成!显示部分结果:
+----------+----------+----------+----+--------+---+ | uid_a| uid_b| distance|rank| dt| pt| +----------+----------+----------+----+--------+---+ |1000000153|6138721030|0.96903944| 0|20250315| 3| |1000000153|5621332690|0.96661687| 1|20250315| 3| |1000000153|5525586120| 0.9633894| 2|20250315| 3| |1000000153|6500644830| 0.9600231| 3|20250315| 3| |1000000153|5807222500| 0.9594191| 4|20250315| 3| |1000000153|6523887000|0.95883214| 5|20250315| 3| |1000000153|2305743370| 0.9573312| 6|20250315| 3| |1000000153|6143038350| 0.9557892| 7|20250315| 3| |1000000153|2126798500| 0.9551554| 8|20250315| 3| |1000000153|5721626560| 0.9541809| 9|20250315| 3| |1000000153|1912830870| 0.9538749| 10|20250315| 3| |1000000153|3340577420|0.95359516| 11|20250315| 3| |1000000153|6016112630| 0.9533439| 12|20250315| 3| |1000000153|2133592790|0.95221055| 13|20250315| 3| |1000000153|6120397410|0.95187557| 14|20250315| 3| |1000000153|2312551120| 0.9517801| 15|20250315| 3| |1000000153|2600758530| 0.951372| 16|20250315| 3| |1000000153|3026391770|0.95084023| 17|20250315| 3| |1000000153|6532585820|0.95029217| 18|20250315| 3| |1000000153|3922555110| 0.9500053| 19|20250315| 3| +----------+----------+----------+----+--------+---+
logger.info(f"清理临时Hive表: {temp_hive_table}")
spark.sql(f"DROP TABLE IF EXISTS {temp_hive_table}")
2025-07-31 17:29:10.144 | INFO | __main__:<module>:2 - 清理临时Hive表: weibo_bigdata_tmp.yandi_bigdata_vf_long_inte_user_als_knn_20250315_local_1753953461155
DataFrame[]