一、为什么需要共享变量?
1. 普通全局变量的分布式缺陷
Spark程序分为Driver主进程与多个Executor工作进程。算子闭包中引用的外部全局变量会参与任务序列化,每一个Task都会单独生成一份独立副本。
每个Task仅能修改自身本地副本,修改逻辑无法同步回Driver主节点,也不能跨Task共享数据。任务数量越多,重复拷贝带来的网络传输、内存开销越大,无法实现分布式全局计数、全局维表匹配等业务需求。
2. Spark提供两类专用共享变量
| 类型 | 核心用途 | 核心特性 |
|---|---|---|
| 广播变量Broadcast | 高效下发只读维表、静态配置,消除Task重复拷贝的资源损耗 | 每个Executor仅存一份数据,节点内所有Task共享;推荐存储不可变对象 |
| 累加器Accumulator | 分布式全局计数、数值汇总,分区结果统一汇总至Driver | Executor仅支持增量add操作,原生数字累加器仅加法;自定义AccumulatorV2可实现任意复杂聚合 |
二、广播变量(Broadcast Variable)
1. 普通变量直接分发的性能缺陷
# Driver端定义映射字典
user_province = {101:"广东",102:"北京",103:"上海",104:"江苏"}
user_logs = sc.parallelize([101,102,101,102])
# 每个Task都会完整拷贝字典,任务越多开销越大
result = user_logs.map(lambda uid: (uid, user_province.get(uid, "未知")))
若集群存在上千个Task,字典会重复传输上千次,大维度字典场景会严重占用带宽与节点内存。
2. 标准正确使用方式
# Driver封装广播变量
broadcast_province = sc.broadcast(user_province)
# 算子通过.value读取原始数据
result = user_logs.map(lambda uid: (uid, broadcast_province.value.get(uid, "未知")))
result.collect()
3. 核心使用规则
- 创建:仅允许在Driver主线程执行
sc.broadcast(原始数据); - 读取:算子内部必须通过
.value获取包裹的数据; - 约束:广播引用对象无法重新赋值;若
.value内是list、dict等可变容器,算子中修改容器元素不会报错,但会造成集群数据不一致,业务中禁止此类操作,建议广播元组、常量字典等不可变数据; - 资源释放:大广播维表使用完毕调用
unpersist()释放内存,支持blocking=True开启阻塞释放,适合超大维度表场景。
4. 广播变量完整可运行综合案例
from pyspark import SparkConf, SparkContext
if __name__ == "__main__":
conf = SparkConf().setMaster("local[*]").setAppName("broadcast_demo")
sc = SparkContext(conf=conf)
# 静态区域维度映射
area_map = {1:"华南", 2:"华北", 3:"华东"}
# 封装广播变量
bc_area = sc.broadcast(area_map)
# 用户ID数据流
user_rdd = sc.parallelize([1,2,3,1,2,3,1])
# 维表关联匹配
res_rdd = user_rdd.map(lambda x: (x, bc_area.value.get(x, "未知区域")))
print(res_rdd.collect())
# 释放广播占用内存(修复变量名bug)
bc_area.unpersist()
# 阻塞释放写法:bc_area.unpersist(blocking=True)
sc.stop()
# 输出结果:[(1, '华南'), (2, '华北'), (3, '华南'), (1, '华南'), (2, '华北'), (3, '华东'), (1, '华南')]
三、累加器(Accumulator)
1. 普通变量分布式计数失效原理
count = 0
data = sc.parallelize([1,15,8,22,30,5])
def process_num(num):
global count
if num>10:
count +=1
return num
data.map(process_num).collect()
print(count) # 输出恒为0
每个Task持有独立count副本,仅修改本地值,无法同步修改Driver原始变量,全局统计逻辑完全失效。
2. 原生数字累加器标准写法
# Driver创建累加器,初始值为全局汇总基准
accum = sc.accumulator(0)
data = sc.parallelize([1,15,8,22,30,5])
def process_num(num):
# Python函数内修改外部共享变量,必须声明global
global accum
if num>10:
accum.add(1)
return num
data.map(process_num).collect()
print(accum.value) # 输出3
3. 核心使用规范
- 创建:仅Driver执行
sc.accumulator(初始基准值); - 更新:Executor仅能调用
.add()执行增量叠加,不支持直接赋值覆盖; - 读写区分:算子内禁止读取
.value;local[*]多线程本地环境偶尔巧合读出完整值,但集群环境100%统计失真,仅全部任务执行完毕后,在Driver读取全局汇总值; - 底层逻辑:各分区独立局部累加,任务结束后分区结果统一合并至Driver。
4. 分层全套可运行累加器案例
案例1 基础数据条数统计
from pyspark import SparkConf, SparkContext
if __name__ == "__main__":
conf = SparkConf().setMaster("local[*]").setAppName("acc_base")
sc = SparkContext(conf=conf)
acc = sc.accumulator(0)
rdd = sc.parallelize([1,2,3,4,5])
rdd.foreach(lambda x: acc.add(1))
print("总数据量:", acc.value)
sc.stop()
# 输出:总数据量:5
案例2 多次Action重复计数演示
from pyspark import SparkConf, SparkContext
if __name__ == "__main__":
conf = SparkConf().setMaster("local[*]").setAppName("acc_repeat")
sc = SparkContext(conf=conf)
acc = sc.accumulator(0)
rdd = sc.parallelize([1,2,3])
# foreach、map等转换算子配合多次collect都会重复执行DAG
rdd.foreach(lambda x: acc.add(1))
print("第一次统计:", acc.value)
rdd.foreach(lambda x: acc.add(1))
print("第二次统计:", acc.value)
sc.stop()
# 输出
# 第一次统计:3
# 第二次统计:6
原理:RDD惰性求值,每次Action会完整重算整条DAG;Stage任务失败重试同样触发重复累加。
案例3 算子内部创建累加器失效演示
from pyspark import SparkConf, SparkContext
if __name__ == "__main__":
conf = SparkConf().setMaster("local[*]").setAppName("acc_error")
sc = SparkContext(conf=conf)
# Driver定义全局累加器
global_acc = sc.accumulator(0)
rdd = sc.parallelize([1,2,3])
# 算子内新建局部累加器,与全局完全无关
rdd.foreach(lambda x: sc.accumulator(0).add(1))
print("全局统计总数:", global_acc.value)
sc.stop()
# 输出:全局统计总数:0
案例4 foreachPartition分区优化统计
from pyspark import SparkConf, SparkContext
if __name__ == "__main__":
conf = SparkConf().setMaster("local[*]").setAppName("acc_opt")
sc = SparkContext(conf=conf)
acc = sc.accumulator(0)
rdd = sc.parallelize([1,2,3,4,5,6,7,8,9], numSlices=3)
def batch_count(iter_data):
for item in iter_data:
acc.add(1)
rdd.foreachPartition(batch_count)
print("总数据量:", acc.value)
sc.stop()
# 输出:总数据量:9
优势:foreach逐条触发函数,foreachPartition以分区批量处理,海量数据大幅减少函数交互开销。
案例5 自定义AccumulatorV2复杂聚合
from pyspark import SparkConf, SparkContext
from pyspark.accumulators import AccumulatorV2
class NegativeAcc(AccumulatorV2[int, list]):
def __init__(self):
self.neg_list = []
def isZero(self):
return len(self.neg_list) == 0
def copy(self):
# 任务重试/快照时复制累加器实例
new_acc = NegativeAcc()
new_acc.neg_list = self.neg_list.copy()
return new_acc
def reset(self):
self.neg_list.clear()
def add(self, val):
if val < 0:
self.neg_list.append(val)
def merge(self, other):
# 合并不同Executor分区的累加结果
self.neg_list.extend(other.neg_list)
def value(self):
return self.neg_list
if __name__ == "__main__":
conf = SparkConf().setMaster("local[*]").setAppName("acc_v2")
sc = SparkContext(conf=conf)
neg_acc = NegativeAcc()
sc.register(neg_acc)
data_rdd = sc.parallelize([-1, 4, -6, 9, -2])
data_rdd.foreach(lambda x: neg_acc.add(x))
print("所有负数:", neg_acc.value)
sc.stop()
# 输出:所有负数:[-1, -6, -2]
补充:原生sc.accumulator仅支持加法;AccumulatorV2可自定义任意聚合逻辑(最值、集合、减法等)。
四、广播变量与累加器核心区分对照表
| 对比维度 | 广播变量Broadcast | 累加器Accumulator |
|---|---|---|
| 数据流向 | Driver下发数据至所有Executor | 各Executor汇总数据上传至Driver |
| 修改权限 | 广播引用不可重赋值;value可变容器禁止修改 | 仅支持add增量,原生无其他运算,自定义可扩展 |
| 共享粒度 | 单Executor内所有Task共用一份数据 | 仅Driver创建实例全局生效,算子新建实例无效 |
| 典型业务场景 | 维表关联、静态全局配置 | 日志总量、异常数据、分布式数值统计 |
五、实操避坑全要点
- 广播变量禁止在算子内修改list、dict可变容器,会引发数据不一致;大维表使用完毕调用
unpersist()释放内存,超大表可传入blocking=True阻塞释放。 - 原生数字累加器仅支持加法,集合、最值等复杂聚合需求使用AccumulatorV2自定义累加器。
- 累加器不可在map、filter转换算子内部读取
.value,本地测试正常,集群统计失真。 - 多次Action、Stage失败重试都会触发DAG重算,累加数值重复叠加;对统计精度要求高的业务,每次独立统计建议新建累加器实例。
- 广播、累加器两类共享变量均只能在Driver主线程实例化,算子内部创建完全失效。
- 仅做统计无数据转换逻辑时,优先使用foreachPartition,避免map生成冗余RDD占用集群内存。
转载自 CSDN-专业IT技术社区
原文链接:https://blog.csdn.net/2301_79652681/article/details/162244854



