Incremental ALS Updates¶
1. Why Do We Need Incremental Updates?¶
With ALS we can model user behavior and interests, embedding users and tags as vectors.
1.1 Challenges¶
First, in a recommendation business user behavior changes every day, so the vectors need to be updated accordingly; training is not a one-off task. In the standard ALS module, however, every call to fit() recomputes all factors from scratch, so the new factors bear no relation to the previous ones and have no continuity. Comparing or tracking users/items across model versions becomes difficult.
Second, Spark ALS requires user and item IDs to fit into a 32-bit integer, which caps the usable ID space, while our business has users on the order of a billion and raw uids well beyond the integer range. One workaround is to process users in batches, but vectors produced by different batches then do not live in the same vector space.
1.2 From Transductive to Inductive¶
Spark ALS is a transductive algorithm: every ID must be included in training for its vector to be learned. Recommendation systems, however, favor inductive algorithms, which learn model parameters from a sample of the data and then infer representations over the full data set. The distinction is similar to the improvement that graph neural networks (GNNs) bring over the traditional node2vec approach.
1.3 Solution¶
We implemented an extension plugin over the Spark ALS source code: spark-incremental-als, whose Scala class is org.apache.spark.ml.recommendation.IncrementalALS.
Its fit() method behaves exactly like the original Spark ALS (training from scratch); the main enhancement is a single-step update capability. Through a PySpark wrapper it exposes static methods (stepUser, stepItem) that run a single ALS update iteration for the user factors or the item factors, using the pre-existing factors of the other side. This supports scenarios such as online updates or fine-tuning, where only one set of factors needs to be refreshed against new interaction data while keeping vector continuity with the previous state; a conceptual sketch of this one-sided update follows.
The implementation is designed as a standalone extension: you include the compiled JAR file and, if needed, the Python wrapper script in your Spark environment.
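Conceptually, a single-step update is just the closed-form least-squares solve that ALS already performs for one side while the other side is held fixed. The NumPy sketch below is purely illustrative and is not the plugin's actual implementation: it uses the explicit-feedback form, whereas the runs in this notebook use implicitPrefs=True with confidence weighting, which changes the normal equations but not the overall structure.
import numpy as np

def step_user_factors(ratings, item_factors, rank, reg):
    """Illustrative single-step user update: for each user, solve a ridge
    regression against the fixed item factors (explicit-feedback form).

    ratings:      iterable of (user, item, weight) triples
    item_factors: dict item -> np.ndarray of shape (rank,)
    """
    by_user = {}
    for u, i, w in ratings:
        by_user.setdefault(u, []).append((i, w))

    user_factors = {}
    for u, interactions in by_user.items():
        Y = np.stack([item_factors[i] for i, _ in interactions])   # (n_u, rank)
        r = np.array([w for _, w in interactions])                 # (n_u,)
        # Normal equations: (Y^T Y + reg * I) x_u = Y^T r
        A = Y.T @ Y + reg * np.eye(rank)
        b = Y.T @ r
        user_factors[u] = np.linalg.solve(A, b)
    return user_factors
stepItem is symmetric, with the roles of users and items swapped.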
1.4 Core Features¶
- Standard ALS fit: fit() behaves exactly like pyspark.ml.recommendation.ALS.fit().
- Single-step user-factor update (stepUser): computes updated user factors from new rating data and fixed, pre-existing item factors; new users appearing in the rating data are handled by initializing their factors with zero vectors.
- Single-step item-factor update (stepItem): computes updated item factors from new rating data and fixed, pre-existing user factors; new items appearing in the rating data are likewise handled by initializing their factors.
- Standalone deployment: the core Scala logic is used as an external JAR, so no changes to the Spark source code are required.
- PySpark wrapper: a Python class IncrementalALS (subclassing pyspark.ml.recommendation.ALS) makes it easy to integrate and to call stepUser/stepItem; see the usage sketch after this list.
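For orientation, a minimal usage sketch of the wrapper, consistent with the calls used later in this notebook. ratings_df and new_ratings_df are placeholder DataFrames with user/item/rating columns, and the compiled JAR is assumed to be on the Spark classpath.
from incremental_als_wrapper import IncrementalALS

als_inc = IncrementalALS(userCol="user_code", itemCol="tag_code", ratingCol="weight",
                         rank=24, implicitPrefs=True, alpha=500, regParam=0.01)

# Full retrain, identical to pyspark.ml.recommendation.ALS.fit()
model = als_inc.fit(ratings_df)

# Single-step updates: refresh one side's factors while the other stays fixed
new_user_factors = als_inc.stepUser(new_ratings_df, model.itemFactors)  # DataFrame[id, features]
new_item_factors = als_inc.stepItem(new_ratings_df, model.userFactors)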
2. Overall Architecture¶
We take the model trained on 20250306 as the initial state and fit the new data from 20250314 to verify that:
- with ALS's fit() method, the model has no continuity with its predecessor;
- with the stepUser and stepItem methods, continuity is preserved.
from pyspark.sql import SparkSession
from pyspark.ml.recommendation import ALS, ALSModel
from pyspark.sql.functions import col, udf, expr, collect_set, lit
import pandas as pd
import numpy as np
import random
from loguru import logger
def get_spark():
    """
    Get the Spark session, creating one if it does not already exist
    (SparkSession.builder.getOrCreate() reuses any active session)
    """
    spark = SparkSession.builder \
        .appName("ALS") \
        .config("spark.sql.catalogImplementation", "hive") \
        .enableHiveSupport() \
        .getOrCreate()
    spark.sparkContext.setLogLevel("ERROR")
    return spark
spark = get_spark()
spark
SparkSession - hive
def split_dataset(dataset, train_test_split=0.9, seed=123, split_by_user=True, add_negative_samples=False, neg_sample_ratio=1.0):
    """
    Split the dataset into train and test sets, optionally adding negative samples
    """
    def add_random_negative_samples(df, all_tags, ratio=1.0):
        """Generate random negative samples for each user"""
        def sample_negatives(row):
            user_code = row['user_code']
            positive_tags = set(row['positive_tags'])
            negative_tags = random.sample(list(all_tags.value - positive_tags), int(len(positive_tags) * ratio))
            return [(user_code, tag_code, 0.0) for tag_code in negative_tags]
        partitions = df.rdd.getNumPartitions()
        df_grouped = df.groupby('user_code').agg(collect_set('tag_code').alias('positive_tags')).repartition(partitions)
        rdd_neg = df_grouped.rdd.flatMap(sample_negatives)
        return spark.createDataFrame(rdd_neg, ['user_code', 'tag_code', 'weight'])

    # Split the dataset
    train_data, test_data = dataset.select("user_code", "tag_code", "weight") \
        .randomSplit([train_test_split, 1 - train_test_split], seed=seed)
    # Split by user (optional)
    if split_by_user:
        dividend = int(1.0 / (1 - train_test_split))
        test_data = test_data.where(f"user_code % {dividend} = 1")
    # Add negative samples to the test set
    all_tags = dataset.select('tag_code').distinct().rdd.map(lambda row: row['tag_code']).collect()
    all_tags = spark.sparkContext.broadcast(set(all_tags))
    logger.info(f"Total number of tags in the data: {len(all_tags.value)}")
    test_data_neg = add_random_negative_samples(test_data, all_tags, ratio=neg_sample_ratio)
    test_data = test_data.union(test_data_neg)
    # Add negative samples to the training set (optional)
    if add_negative_samples:
        train_data_neg = add_random_negative_samples(train_data, all_tags, ratio=neg_sample_ratio)
        train_data = train_data.union(train_data_neg)
    return train_data, test_data
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd
import numpy as np
# --- Define the Pandas UDF ---
@pandas_udf(DoubleType())
def cosine_similarity_pandas_impl(vec1_series: pd.Series, vec2_series: pd.Series) -> pd.Series:
    """
    Calculates cosine similarity between two pandas Series of vectors using NumPy.
    """
    def calculate_similarity(v1, v2):
        if v1 is None or v2 is None:
            return 0.0
        # Handle PySpark Vector types if necessary
        if hasattr(v1, "toArray"):
            v1 = v1.toArray()
        if hasattr(v2, "toArray"):
            v2 = v2.toArray()
        vec1 = np.array(v1)
        vec2 = np.array(v2)
        dot_product = np.dot(vec1, vec2)
        norm1 = np.linalg.norm(vec1)
        norm2 = np.linalg.norm(vec2)
        if norm1 == 0.0 or norm2 == 0.0:
            return 0.0
        similarity = dot_product / (norm1 * norm2)
        return float(np.clip(similarity, -1.0, 1.0))
    return pd.Series([calculate_similarity(v1, v2) for v1, v2 in zip(vec1_series, vec2_series)])
# Register the Pandas UDF with a name for SQL
spark.udf.register("cosine_similarity", cosine_similarity_pandas_impl)
<function __main__.cosine_similarity_pandas_impl(vec1_series: pandas.core.series.Series, vec2_series: pandas.core.series.Series) -> pandas.core.series.Series>
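As a quick sanity check (illustrative, not part of the original run), the registered function can be applied to a toy DataFrame of array<double> columns: identical vectors should score 1.0 and orthogonal vectors 0.0.
# Toy check of the registered cosine_similarity UDF
toy = spark.createDataFrame(
    [(1, [1.0, 0.0], [1.0, 0.0]),
     (2, [1.0, 0.0], [0.0, 1.0])],
    ["id", "v1", "v2"],
)
toy.selectExpr("id", "cosine_similarity(v1, v2) as cos").show()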
# ===============================
# Main flow: ALS model training and evaluation
# ===============================
class Args:
    seed = 123
    dataset = "bigdata_vf_als_user_tag_tuple_test"
    dataset_dt = "20250314"
    dataset_pt = "long_obj"
    dataset_substr = "5"
    maxIter = 10
    regParam = 0.01
    rank = 24
    implicitPrefs = True
    alpha = 500
    train_test_split = 0.9
    als_blocks = 10
    split_by_user = True
    add_negative_samples = False
    neg_sample_ratio = 1.0
    transform = "power4"
    topk_items = 50
    faiss_index_type = "IP"

args = Args()
Load the new data for 20250314
# Load the new dataset
dataset = spark.sql(f"""
    select *
    from {args.dataset}
    where dt='{args.dataset_dt}' and pt='{args.dataset_pt}'
        and substr(uid, 3, 1) in ({args.dataset_substr})
        and substr(uid, 4, 1) = '0'
""")
logger.info("Dataset schema:")
dataset.printSchema()
2025-04-07 11:13:05.057 | INFO | __main__:<module>:9 - Dataset schema:
root
 |-- uid: string (nullable = true)
 |-- user_code: integer (nullable = true)
 |-- tag_id: string (nullable = true)
 |-- tag_code: integer (nullable = true)
 |-- tag_name: string (nullable = true)
 |-- weight: double (nullable = true)
 |-- card_u: long (nullable = true)
 |-- card_t: long (nullable = true)
 |-- dt: string (nullable = true)
 |-- pt: string (nullable = true)
# Split the dataset
train_data, test_data = split_dataset(
    dataset,
    train_test_split=args.train_test_split,
    split_by_user=args.split_by_user,
    add_negative_samples=args.add_negative_samples,
    seed=args.seed,
    neg_sample_ratio=args.neg_sample_ratio
)
2025-04-07 11:13:49.204 | INFO | __main__:split_dataset:30 - Total number of tags in the data: 50758
# Set ALS model parameters
params = {
    "maxIter": args.maxIter,
    "regParam": args.regParam,
    "userCol": 'user_code',
    "itemCol": 'tag_code',
    "ratingCol": 'weight',
    "rank": args.rank,
    "coldStartStrategy": 'drop',
    "implicitPrefs": args.implicitPrefs,
    "alpha": args.alpha,
    "numUserBlocks": args.als_blocks,
    "numItemBlocks": args.als_blocks,
    "seed": 14683
}
als = ALS(**params)
logger.info("Model parameters: {}", params)
2025-04-07 11:18:41.862 | INFO | __main__:<module>:17 - Model parameters: {'maxIter': 10, 'regParam': 0.01, 'userCol': 'user_code', 'itemCol': 'tag_code', 'ratingCol': 'weight', 'rank': 24, 'coldStartStrategy': 'drop', 'implicitPrefs': True, 'alpha': 500, 'numUserBlocks': 10, 'numItemBlocks': 10, 'seed': 14683}
# Rating transformation
transform = {
    "power5": expr("pow(weight, 5)"),
    "power4": expr("pow(weight, 4)"),
    "power3": expr("pow(weight, 3)"),
    "power2": expr("pow(weight, 2)"),
    "logit": expr("ln(weight/(1.001 - weight)) + 0.5"),
    "none": col("weight")
}
transform_name = args.transform
logger.info(f"Using transform: {transform_name} -> {transform[transform_name]}")
train_data = train_data.withColumn("weight", transform[transform_name])
2025-04-07 11:18:44.235 | INFO | __main__:<module>:11 - Using transform: power4 -> Column<'pow(weight, 4)'>
Load the model from 20250306
model_path = "viewfs:///user_ext/weibo_bigdata_vf/yandi/als/checkpoints/dt=20250306/pt=long_obj/train.d_24.imp_True.reg_0.01.a_500.it_10.tf_power4.sub_5"
model = ALSModel.load(model_path)
3.1 Validation: the Original fit() Has No Continuity¶
The enhanced IncrementalALS class keeps the construction methods and interfaces of the original ALS, including __init__ and fit().
from incremental_als_wrapper import IncrementalALS
als_inc = IncrementalALS(**params)
als_inc.setMaxIter(10)
# als_inc.setImplicitPrefs(False)
IncrementalALS_470d7faa19c5
model_update = als_inc.fit(dataset)
Continuity check for the item vectors¶
model.itemFactors.createOrReplaceTempView("factors_old")
model_update.itemFactors.createOrReplaceTempView("factors_new")
/data0/spark/spark-3.2.0-bin/python/pyspark/sql/context.py:127: FutureWarning: Deprecated in 3.0.0. Use SparkSession.builder.getOrCreate() instead. FutureWarning
spark.sql(f"""
    create or replace temp view factor_similarity as
    select
        t1.id,
        cosine_similarity(t1.features, t2.features) AS cosine_similarity
    from factors_old t1
    join factors_new t2
        on t1.id=t2.id
""")
DataFrame[]
spark.sql("select * from factor_similarity limit 10").show()
+---+-------------------+
| id|  cosine_similarity|
+---+-------------------+
| 10|0.28464841842651367|
| 20| 0.4691894054412842|
| 30| 0.3763030171394348|
| 40| 0.2475971281528473|
| 60| 0.4742222726345062|
| 70| 0.4416641592979431|
| 80|0.18065597116947174|
|100|  0.403485506772995|
|110| 0.3420761227607727|
|120|0.35528796911239624|
+---+-------------------+
spark.sql("select count(1) from factor_similarity").show()
+--------+
|count(1)|
+--------+
|   50133|
+--------+
spark.sql("""
select percentile(cosine_similarity, array(0.05, 0.25, 0.5, 0.75, 0.95)) as percentiles
from factor_similarity
""").collect()
[Row(percentiles=[-0.33890747427940365, -0.1805485188961029, -0.053095754235982895, 0.07690224796533585, 0.27937580943107604])]
Continuity check for the user vectors¶
This one is slightly more involved: because the raw user uids exceed the integer ID range that ALS supports, users were encoded into integer user_code values before training; for the comparison we therefore need to map the factors back to the original uid.
# Build the uid <-> user_code mapping across the two snapshots and
# attach the old user factors under the new user_code
model.userFactors.createOrReplaceTempView("user_factors_old")
user_factors_new = spark.sql(f"""
    with yandi_tmp_1 as (
        select distinct
            uid, user_code
        from {args.dataset}
        where dt='20250306' and pt='{args.dataset_pt}'
            and substr(uid, 3, 1) in ({args.dataset_substr})
    ),
    yandi_tmp_2 as (
        select distinct
            uid, user_code
        from {args.dataset}
        where dt='20250314' and pt='{args.dataset_pt}'
            and substr(uid, 3, 1) in ({args.dataset_substr})
    ),
    yandi_tmp_3 as (
        select
            a.uid as uid,
            a.user_code as user_code_old,
            b.user_code as user_code_new
        from yandi_tmp_1 a
        join yandi_tmp_2 b
            on a.uid = b.uid
    ),
    yandi_tmp_4 as (
        select
            a.uid as uid,
            cast(a.user_code_old as int) as user_code_old,
            cast(a.user_code_new as int) as user_code,
            b.features
        from yandi_tmp_3 a
        join user_factors_old b
            on a.user_code_old = b.id
    )
    select * from yandi_tmp_4
""")
user_factors_new.createOrReplaceTempView("user_factors_new")
user_factors_new.cache().show()
+----------+-------------+---------+--------------------+
|       uid|user_code_old|user_code|            features|
+----------+-------------+---------+--------------------+
|1950169903|     24567405| 24518631|[0.18755348, 1.42...|
|5150721851|    217720849|217530304|[1.333896, 1.6321...|
|7350496386|    251051325|250478445|[-0.17739375, 1.7...|
|7450187252|    254006648|253380183|[0.91706353, 1.40...|
|6450777661|    240068194|239630225|[0.11180763, 5.21...|
|3950390013|    216129908|215956728|[0.45320043, 2.30...|
|7850407323|    270351805|269352445|[0.95235443, -0.0...|
|7650669818|    159673079|158938373|[0.5653007, 2.657...|
|5250152815|    118780714|118579072|[0.7234812, 1.623...|
|5650288470|    225158579|224891269|[-1.4311621, 5.08...|
|7650715853|    259928557|259193471|[-0.018236224, 0....|
|6050140208|    132753095|132417886|[0.99114317, 1.67...|
|1950064504|     14547480| 14499060|[0.48918027, 4.77...|
|5850242192|    128961964|128663648|[-0.88004345, 0.1...|
|7850072659|    270319395|269321272|[1.2595501, 2.826...|
|7750384336|    162932224|162126258|[0.2827783, 0.486...|
|6050567818|    132762497|132427183|[0.3833326, 1.509...|
|7450001822|    253999907|253373602|[-0.29681727, 1.7...|
|7850419835|    170090188|169094020|[-1.3753982, -0.9...|
|5350719094|    220584694|220366204|[-0.12574238, 0.1...|
+----------+-------------+---------+--------------------+
only showing top 20 rows
model_update.userFactors.where("id==24518631").show()
+--------+--------------------+
|      id|            features|
+--------+--------------------+
|24518631|[1.3217652, -0.27...|
+--------+--------------------+
model.userFactors.createOrReplaceTempView("factors_old")
model_update.userFactors.createOrReplaceTempView("factors_new")
spark.sql(f"""
    create or replace temp view factor_similarity as
    select
        t1.id,
        cosine_similarity(t1.features, t3.features) AS cosine_similarity
    from factors_old t1
    join user_factors_new t2
        on t1.id=t2.user_code_old
    join factors_new t3
        on t3.id=t2.user_code
""")
DataFrame[]
spark.sql("select * from factor_similarity limit 10").show()
+---------+--------------------+
|       id|   cosine_similarity|
+---------+--------------------+
|214829050| 0.40151792764663696|
|144385270| 0.28758251667022705|
|234782490| 0.07918567955493927|
|240066950| 0.18589799106121063|
|225177780| 0.09377120435237885|
|234786600|0.008733067661523819|
|128989250| 0.03242792561650276|
| 22174520| 0.47465264797210693|
|170098170| 0.07043634355068207|
|153761460| -0.5123785138130188|
+---------+--------------------+
spark.sql("select count(1) from factor_similarity").show()
+--------+
|count(1)|
+--------+
|  774178|
+--------+
spark.sql("""
select percentile(cosine_similarity, array(0.05, 0.25, 0.5, 0.75, 0.95)) as percentiles
from factor_similarity
""").collect()
[Row(percentiles=[-0.16253283694386483, 0.027862628921866417, 0.14594485610723495, 0.27255917340517044, 0.4497376471757889])]
# show some new user factors
spark.sql(f"""
    select *
    from factors_new t1
    left join user_factors_new t2
        on t1.id=t2.user_code
    where t2.user_code_old is null
""").show()
+---------+--------------------+----+-------------+---------+--------+
|       id|            features| uid|user_code_old|user_code|features|
+---------+--------------------+----+-------------+---------+--------+
|232653140|[0.8850426, -0.68...|null|         null|     null|    null|
|215956580|[-0.048879553, 0....|null|         null|     null|    null|
|148093929|[1.7840456, 0.737...|null|         null|     null|    null|
|146360840|[2.2643657, -1.01...|null|         null|     null|    null|
| 27153255|[1.2829087, 0.539...|null|         null|     null|    null|
| 18892414|[1.6726558, 0.231...|null|         null|     null|    null|
|111049551|[-2.944071, -2.86...|null|         null|     null|    null|
|169076896|[-0.0690281, 0.71...|null|         null|     null|    null|
| 27155829|[-1.5334187, -1.7...|null|         null|     null|    null|
|224888705|[1.0476906, 0.579...|null|         null|     null|    null|
|158927767|[1.9354192, -1.41...|null|         null|     null|    null|
|153152841|[-0.094397455, -0...|null|         null|     null|    null|
|158922285|[3.2785456, -0.05...|null|         null|     null|    null|
|150234227|[-0.90303326, 1.0...|null|         null|     null|    null|
|250488543|[-1.6161408, 0.43...|null|         null|     null|    null|
|224884246|[-0.8153594, 0.91...|null|         null|     null|    null|
|135786706|[1.3937851, 0.866...|null|         null|     null|    null|
|259180177|[-0.30027804, -0....|null|         null|     null|    null|
|227126401|[0.87676966, 0.54...|null|         null|     null|    null|
|112906177|[-0.5431016, -0.4...|null|         null|     null|    null|
+---------+--------------------+----+-------------+---------+--------+
only showing top 20 rows
3.2 Validation: Incremental Updates Preserve Continuity¶
Scenario 1: Fix the Item Factors, Update the User Factors¶
# Scenario 1: fix the item factors, update the user factors
print("Running stepUser...")
updatedUserFactorsDF = als_inc.stepUser(dataset, model.itemFactors)
print("stepUser done.")
updatedUserFactorsDF.show(5)
Running stepUser...
stepUser done.
+-------+--------------------+
|     id|            features|
+-------+--------------------+
|1172960|[0.7890749, 2.878...|
|1172970|[-2.1875012, 1.22...|
|1172980|[1.1610833, 0.045...|
|1173020|[-2.1288333, 0.76...|
|1173030|[-0.55638933, 2.8...|
+-------+--------------------+
only showing top 5 rows
model.userFactors.createOrReplaceTempView("factors_old")
updatedUserFactorsDF.createOrReplaceTempView("factors_new")
spark.sql(f"""
    create or replace temp view factor_similarity as
    select
        t1.id,
        cosine_similarity(t1.features, t3.features) AS cosine_similarity
    from factors_old t1
    join user_factors_new t2
        on t1.id=t2.user_code_old
    join factors_new t3
        on t3.id=t2.user_code
""")
DataFrame[]
spark.sql("select * from factor_similarity limit 10").show()
+---------+------------------+
|       id| cosine_similarity|
+---------+------------------+
|117623033|0.9203768968582153|
|212243996|0.9786779284477234|
|263219567|0.9904569983482361|
|138026987| 0.967840313911438|
|162931264|0.9622671604156494|
| 21152721|0.9958115816116333|
|259919373|0.8720948696136475|
| 21467706|0.9725728631019592|
|114434965|0.9195927977561951|
|225156298|0.8778463006019592|
+---------+------------------+
spark.sql("select count(1) from factor_similarity").show()
+--------+
|count(1)|
+--------+
|  774178|
+--------+
spark.sql("""
select percentile(cosine_similarity, array(0.05, 0.25, 0.5, 0.75, 0.95)) as percentiles
from factor_similarity
""").collect()
[Row(percentiles=[0.812578371167183, 0.9143407046794891, 0.9556348323822021, 0.9793210327625275, 0.994638392329216])]
# show some new user factors
spark.sql(f"""
    select *
    from factors_new t1
    left join user_factors_new t2
        on t1.id=t2.user_code
    where t2.user_code_old is null
""").show()
+-------+--------------------+----+-------------+---------+--------+
|     id|            features| uid|user_code_old|user_code|features|
+-------+--------------------+----+-------------+---------+--------+
|1173017|[0.871421, -0.444...|null|         null|     null|    null|
|1173027|[0.909961, 1.0669...|null|         null|     null|    null|
|1173589|[1.2461905, 2.190...|null|         null|     null|    null|
|1173873|[1.9630883, -1.64...|null|         null|     null|    null|
|1174156|[0.8216207, 1.327...|null|         null|     null|    null|
|1174317|[-0.28331318, 0.4...|null|         null|     null|    null|
|1174562|[0.26192972, 1.49...|null|         null|     null|    null|
|1174688|[0.14374854, -1.4...|null|         null|     null|    null|
|1174749|[-0.29708534, -0....|null|         null|     null|    null|
|1175025|[1.3713406, 0.058...|null|         null|     null|    null|
|1416764|[0.9838402, 1.362...|null|         null|     null|    null|
|1416860|[-0.32301664, 1.2...|null|         null|     null|    null|
|1417219|[-1.2974075, 3.19...|null|         null|     null|    null|
|1417766|[1.8327683, 0.022...|null|         null|     null|    null|
|1418443|[1.0985812, 0.163...|null|         null|     null|    null|
|1616245|[-0.2764664, 0.38...|null|         null|     null|    null|
|1616485|[-1.6654207, 0.27...|null|         null|     null|    null|
|1616704|[0.45543376, 0.31...|null|         null|     null|    null|
|1617069|[-0.95509017, 2.1...|null|         null|     null|    null|
|1617107|[0.8142003, 0.325...|null|         null|     null|    null|
+-------+--------------------+----+-------------+---------+--------+
only showing top 20 rows
Scenario 2: Fix the User Factors, Update the Item Factors¶
# Scenario 2: fix the user factors, update the item factors
print("Running stepItem...")
updatedItemFactorsDF = als_inc.stepItem(dataset, user_factors_new.select(["uid", "user_code", "features"]).withColumnRenamed("user_code", "id"))
print("stepItem done.")
updatedItemFactorsDF.show(5)
Running stepItem...
stepItem done.
+---+--------------------+
| id|            features|
+---+--------------------+
| 10|[0.011689293, 0.0...|
| 20|[-0.04079031, 0.0...|
| 30|[-0.038054943, 0....|
| 40|[0.0066935807, 0....|
| 60|[0.016619585, -0....|
+---+--------------------+
only showing top 5 rows
model.itemFactors.createOrReplaceTempView("factors_old")
updatedItemFactorsDF.createOrReplaceTempView("factors_new")
spark.sql(f"""
    create or replace temp view factor_similarity as
    select
        t1.id,
        cosine_similarity(t1.features, t2.features) AS cosine_similarity
    from factors_old t1
    join factors_new t2
        on t1.id=t2.id
""")
DataFrame[]
spark.sql("select * from factor_similarity limit 10").show()
+---+------------------+
| id| cosine_similarity|
+---+------------------+
| 10|0.9931339025497437|
| 20|0.9904462099075317|
| 30|0.9818369746208191|
| 40|0.9907748699188232|
| 60|0.9800931811332703|
| 70|0.9680361747741699|
| 80|0.8900574445724487|
|100|0.9880322813987732|
|110|0.9266476631164551|
|120|0.9420976042747498|
+---+------------------+
spark.sql("select count(1) from factor_similarity").show()
+--------+
|count(1)|
+--------+
|   50133|
+--------+
spark.sql("""
select percentile(cosine_similarity, array(0.05, 0.25, 0.5, 0.75, 0.95)) as percentiles
from factor_similarity
""").collect()
[Row(percentiles=[0.8721876740455627, 0.9683352112770081, 0.9871618151664734, 0.9942935705184937, 0.9982350945472718])]