What Are Accumulators in PySpark?
Accumulators are write-only shared variables that executors can only add to, while the driver can read their aggregated value after an action completes.
Feature | Detail |
---|---|
Purpose | Collect side-effect statistics (counters, sums) during distributed computation |
Visibility | Executors can only add(); the driver reads the aggregated value (reliable only after an action) |
Data types | Python API: spark.sparkContext.accumulator() covers int/float out of the box; custom types via an AccumulatorParam. Scala/Java API: LongAccumulator, DoubleAccumulator, CollectionAccumulator, or a custom AccumulatorV2 subclass |
Fault tolerance | Each task keeps a local copy; values are merged on the driver → survives retries (but a retried task may add its contribution twice) |
UI | Named Scala/Java accumulators show up in the Spark UI (Accumulators table on the stage pages) for quick debugging; Python accumulators are not listed by name, so log acc.value from the driver |
How Accumulators Are Implemented (Under the Hood)
- Created on the driver, e.g. `bad_rows_acc = spark.sparkContext.accumulator(0)`.
- Serialized and sent to each executor along with the task.
- Task-local slot: each task works on its own local copy and calls `add()` without locking.
- Task completion: the executor sends the local delta back to the driver through the scheduler's RPC.
- The driver merges the deltas using the accumulator's merge logic (for numeric types, plain addition).
- The Spark UI and driver code can read the final value after the action (e.g. `collect`, `count`, `write`, …).
Important: if a task is retried or a stage is recomputed, updates made inside transformations may be applied again, so treat accumulators as at-least-once; only updates performed inside actions (e.g. `foreach`) are applied exactly once per task. They should not drive program logic, only diagnostics.
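To make that concrete, here is a minimal sketch (assuming a local SparkSession; the RDD and counts are invented for illustration) showing that updates made inside a transformation are re-applied when the lineage is recomputed, while updates made inside an action are counted once per task:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

acc_transform = sc.accumulator(0)
acc_action = sc.accumulator(0)
rdd = sc.parallelize(range(100), 4)

def tag(x):
    acc_transform.add(1)       # update inside a transformation
    return x

mapped = rdd.map(tag)          # not cached, so every action recomputes it
mapped.count()
mapped.count()                 # the map runs again -> the adds run again
print(acc_transform.value)     # 200, not 100

rdd.foreach(lambda _: acc_action.add(1))   # update inside an action
print(acc_action.value)        # 100 (applied once per task, even if a task is retried)
```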
Creating & Using Accumulators
1. Built-in numeric accumulator (sc.accumulator)
```python
bad_rows_acc = spark.sparkContext.accumulator(0)   # PySpark's built-in int accumulator

def parse(line):
    try:
        return int(line)
    except ValueError:
        bad_rows_acc.add(1)
        return None   # or raise

rdd = spark.sparkContext.textFile("numbers.txt")
valid = rdd.map(parse).filter(lambda x: x is not None)

valid.sum()                               # <- action triggers the computation
print("Bad rows:", bad_rows_acc.value)    # safe to read here
```
2. Custom Accumulator (AccumulatorParam)
PySpark does not expose AccumulatorV2 (that is the Scala/Java API with isZero/copy/reset/add/merge); instead you supply an AccumulatorParam that defines the zero value and how partial results merge:

```python
from pyspark.accumulators import AccumulatorParam   # PySpark's custom-accumulator hook
from pyspark.sql import SparkSession

class SetAccumulatorParam(AccumulatorParam):
    """Accumulates values into a Python set (the PySpark analogue of a
    Scala/Java AccumulatorV2 subclass)."""
    def zero(self, initial):
        return set()
    def addInPlace(self, v1, v2):
        v1 |= v2          # both sides are sets; call add() with one-element sets
        return v1

spark = SparkSession.builder.getOrCreate()
unique_errs = spark.sparkContext.accumulator(set(), SetAccumulatorParam())

# add inside tasks …
```
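A hedged usage sketch for the set accumulator above; the log file path, the log format, and the error-code parsing are made-up assumptions. Each task adds a one-element set, and the driver reads the merged set after the action:

```python
def note_error_code(line):
    # hypothetical log format: "<code> <message>"
    if "ERROR" in line:
        unique_errs.add({line.split()[0]})   # add as a one-element set
    return line

logs = spark.sparkContext.textFile("app.log")   # hypothetical path
logs.map(note_error_code).count()               # the action triggers the adds
print("Distinct error codes:", unique_errs.value)
```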
Typical Use Cases
Category | Example |
---|---|
Data-quality counters | Count malformed records, null columns, late events (sketched after this table) |
ETL metrics | Track the number of rows skipped, converted, anonymized |
Model training | Count label-distribution skew, feature missingness |
Debugging | Verify that a branch of logic executes the expected number of times |
Performance tuning | Count heavy joins, fallback partitions, spills (with custom instrumented code) |
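As a sketch of the data-quality row above (the amount and event_date columns and the cut-off date are assumptions for illustration), several accumulators can collect independent tallies in a single pass:

```python
import datetime

null_amounts = spark.sparkContext.accumulator(0)
late_events  = spark.sparkContext.accumulator(0)
cutoff = datetime.date(2024, 1, 1)       # hypothetical cut-off

def audit(row):
    # `amount` and `event_date` are assumed column names, purely for illustration
    if row.amount is None:
        null_amounts.add(1)
    if row.event_date is not None and row.event_date < cutoff:
        late_events.add(1)
    return row

df.rdd.map(audit).count()   # one pass collects both tallies
print("null amounts:", null_amounts.value, "| late events:", late_events.value)
```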
Limitations & Gotchas
- Not idempotent: task retries may duplicate additions → use them only for approximate or logging stats.
- Executors can't read an accumulator's value; it is write-only on the workers.
- Lazy evaluation: the value appears only after an action; reading it earlier returns the initial value (see the sketch after this list).
- No control flow: never make program decisions based on an accumulator mid-job (the value may be incomplete or inconsistent).
- Structured Streaming: values keep growing across micro-batches and reset on query restart, which makes them awkward as streaming metrics.
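A small sketch of the lazy-evaluation gotcha, assuming a local SparkSession: the accumulator keeps its initial value until an action actually runs.

```python
acc = spark.sparkContext.accumulator(0)
rdd = spark.sparkContext.parallelize(range(10))

def tally(x):
    acc.add(1)
    return x

counted = rdd.map(tally)   # transformation only; nothing has run yet
print(acc.value)           # still 0
counted.count()            # the action triggers the computation
print(acc.value)           # now 10
```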
Best Practices
Tip | Why |
---|---|
Give them names (Scala/Java longAccumulator("bad_rows")); in Python, use a descriptive variable name | Easier to spot in the Spark UI or in driver logs |
Read after an action, not in the middle | Guarantees the aggregation is complete |
For custom types, supply an AccumulatorParam (PySpark) or subclass AccumulatorV2 (Scala/Java) | Defines the zero value and how partial results merge |
Avoid large per-task objects | Send small deltas; big ones add serialization and network overhead |
Reset between jobs if reused | In PySpark, set acc.value = 0 on the driver (see the sketch below) |
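A sketch of the reset-and-reuse tip; df_day1, df_day2, and the status column are placeholders. In the Python API the driver simply reassigns value:

```python
bad_rows = spark.sparkContext.accumulator(0)

def count_bad(row):
    if row.status == "ERROR":            # assumed column
        bad_rows.add(1)
    return row

df_day1.rdd.map(count_bad).count()       # job 1 (df_day1/df_day2 are placeholder DataFrames)
print("day 1 bad rows:", bad_rows.value)

bad_rows.value = 0                       # reset on the driver before reusing it

df_day2.rdd.map(count_bad).count()       # job 2
print("day 2 bad rows:", bad_rows.value)
```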
Takeaway
Accumulators are simple, driver-visible counters designed for instrumentation and diagnostics in PySpark:
Executors add values → the driver reads the final metric.
Why do we need "special" counters at all?
In a single-machine Python script | In a Spark job |
---|---|
You can increment a normal count += 1 variable inside a for loop. | Your code is split across dozens or hundreds of JVM/Python worker processes. Each worker has its own copy of every variable, so a simple count += 1 only changes a local copy that the driver never sees. |
Accumulators exist to collect simple metrics coming from many distributed tasks back to the driver as one aggregated value. They solve the "many workers → one total" problem without you having to collect() the whole dataset or run a second pass with groupBy()/count().
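A minimal demonstration of the problem, assuming a local SparkSession: a plain Python counter captured in the closure is incremented only in each worker's copy, while an accumulator brings the per-task deltas back to the driver:

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext
rdd = sc.parallelize(range(1000), 8)

count = 0
def bump(_):
    global count
    count += 1                 # increments the worker's copy only

rdd.foreach(bump)
print(count)                   # 0 on the driver - the updates never come back

acc = sc.accumulator(0)
rdd.foreach(lambda _: acc.add(1))
print(acc.value)               # 1000 - per-task deltas merged on the driver
```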
Three ways to keep a counter in PySpark
Method | Where it runs | How it works | Pros | Cons | Typical use |
---|---|---|---|---|---|
1. DataFrame/RDD aggregation: df.filter(...).count() | Executed on the cluster, result sent to the driver | Uses Spark's built-in shuffle & reduce | Accurate, fault-tolerant | Triggers another full pass over the data | Business metrics you'll query anyway |
2. Manual per-partition return: mapPartitions → yield (data, counter), then reduce | Workers compute partial counts and return them as records | Pure functional, deterministic | No side effects; full control | You must change the data path (add extra columns / union an extra RDD) | When the counter is part of the pipeline output |
3. Accumulator (e.g. sc.accumulator(0)) | Workers call acc.add(1); Spark merges the deltas on the driver after the action | Side-effect counter, zero extra data shuffle | Almost zero overhead; easy to log from the driver | "At-least-once" updates → slight over-count if tasks retry; write-only from executors | Debugging, data-quality tallies, progress metrics |
Key distinction
Accumulators are not meant to replace normal aggregations; they are a lightweight tap for side metrics when you don't want to alter the main data flow.
Minimal working examples
1. Classic aggregation (accurate but full extra pass)
```python
errors    = df.filter("status = 'ERROR'").count()   # action #1
processed = df.count()                              # action #2
print(errors, processed)
```
Two actions → two jobs.
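If that extra pass hurts, one common mitigation (a sketch, assuming the DataFrame fits in the available cache) is to cache it so the second action reuses the cached data:

```python
df.cache()
errors    = df.filter("status = 'ERROR'").count()   # materializes the cache
processed = df.count()                              # served from the cache
df.unpersist()
```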
2. MapPartitions pattern (brings counters back in data)
```python
def tag_and_count(iterator):
    bad = 0
    for row in iterator:
        if row.status == "ERROR":
            bad += 1
            continue
        yield row                    # only good rows pass downstream
    yield {"_counter": bad}          # plus one special "counter" row per partition

cleaned = (df.rdd
             .mapPartitions(tag_and_count)
             .toDF(df.schema))       # the counter rows don't fit the schema, so they
                                     # must be split out first - messy
```
Accurate, but you've polluted the data with "fake" rows.
3. Accumulator (side metric, minimal extra code)
```python
bad_acc = spark.sparkContext.accumulator(0)

def is_good(row):
    if row.status == "ERROR":
        bad_acc.add(1)       # every executor just "fires and forgets"
        return False
    return True

# DataFrame.filter() expects a Column/SQL string, so a Python predicate goes via the RDD API
good_df = spark.createDataFrame(df.rdd.filter(is_good), df.schema)
good_df.write.mode("overwrite").parquet("/data/good")   # single job
print("Bad rows seen:", bad_acc.value)                  # safe AFTER the write
```
No extra shuffle, no schema tricks; the counter rides along as a cheap side channel.
When not to use accumulators
- Driving control flow:

  ```python
  if bad_acc.value > 1000:                     # dangerous: might still be 0 mid-job
      spark.sparkContext.cancelJobGroup(...)
  ```

  The value is only reliable after the action completes.
- Exact audit counts for financial/regulatory reporting: retries may double-increment; use a deterministic aggregation instead.
When accumulators shine
Use case | Why a normal aggregation is heavy / awkward |
---|---|
Counting malformed JSON lines while streaming a 200 GB file into Parquet (sketched below) | You'd have to read the file twice or cache it just for the error tally. |
Tracking how many records hit the "fallback" code path inside a big ML feature-engineering pipeline | You don't want an extra column or a second pass. |
Emitting simple health metrics while the job runs | Accumulators are cheap to add and easy to log from the driver (named Scala/Java accumulators also surface in the Spark UI). |
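A hedged sketch of the first row (the file path, the output path, and the assumption that every valid line is a JSON object with the same keys are all illustrative): count malformed lines with an accumulator while writing the good records to Parquet in one pass.

```python
import json
from pyspark.sql import Row

bad_json = spark.sparkContext.accumulator(0)

def parse_line(line):
    try:
        return Row(**json.loads(line))   # assumes each line is a JSON object with the same keys
    except ValueError:
        bad_json.add(1)
        return None

good = (spark.sparkContext.textFile("/data/events.jsonl")   # placeholder path
          .map(parse_line)
          .filter(lambda r: r is not None))

spark.createDataFrame(good).write.mode("overwrite").parquet("/data/events_parquet")
print("Malformed JSON lines:", bad_json.value)   # read after the write action
```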
Quick checklist for using accumulators safely
- Name them where the API allows it: Scala/Java `longAccumulator("late_events")`; in Python, use a descriptive variable name such as `late_events_acc`
- Add, don't read, in executors
- Read only after an action, or at the end of a `foreachBatch`
- Expect a slight over-count if tasks may retry
- Reset before the next independent job if reused (in PySpark, `acc.value = 0` on the driver)
TL;DR
- Counters in distributed Spark require a mechanism that all workers can update and the driver can read.
- Accumulators give you that with almost zero coding overhead, perfect for metrics & debugging, but they're not a substitute for real aggregations when you need 100% accuracy.
A related question: is a monotonically increasing column the same kind of counter?
Short answer:
No, a monotonically increasing column (like monotonically_increasing_id()) is not a counter in the same sense as an accumulator. They serve very different purposes.
Comparison: Counter vs. Monotonically Increasing ID
Feature | monotonically_increasing_id() | Accumulator |
---|---|---|
What is it? | A generated column that assigns a unique (but not strictly sequential) ID to each row | A shared variable updated by executors, aggregated on the driver |
Scope | Row-level; runs during transformations | Job-wide; updated live while tasks run, read on the driver |
Purpose | Add unique row IDs (for joins, surrogate keys) | Count metrics like bad rows, nulls, filtered items |
Type | Column (LongType) | Driver-side variable (pyspark.Accumulator) |
Guaranteed to be contiguous or sequential? | ❌ No: values are increasing but not contiguous | ✅ Grows exactly as you add() (barring task retries) |
Used for control logic or metrics? | ❌ No: just a data transformation | ✅ Yes: intended for metrics/debug counters |
monotonically_increasing_id() in Detail
```python
from pyspark.sql.functions import monotonically_increasing_id

df = df.withColumn("row_id", monotonically_increasing_id())
```
- Spark creates a unique long per row, built from the partition ID (upper bits) and the record number within that partition (lower bits).
- Not guaranteed to start from 0 or be sequential: row_id: 0, 1, 8589934592, 8589934593, ...
Useful for:
- Creating row identifiers in otherwise unordered datasets.
- Doing joins where unique keys are missing.
But not useful for:
- Tallying the number of errors, nulls, or fallback cases
- Tracking distributed metrics
- Logging or UI metrics
Example Comparison
Want to count rows with NULLs in column amount?
Accumulator:
```python
nulls_acc = spark.sparkContext.accumulator(0)

def keep(row):                      # drop rows with a null amount, counting them
    if row.amount is None:
        nulls_acc.add(1)
        return False
    return True

df_filtered = spark.createDataFrame(df.rdd.filter(keep), df.schema)
df_filtered.write.parquet("/cleaned")
print("Rows with nulls:", nulls_acc.value)
```
monotonically_increasing_id():
```python
df = df.withColumn("row_id", monotonically_increasing_id())
```
This helps assign IDs but doesn't count anything.
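A quick sketch of why it can't serve as a tally (df is any multi-partition DataFrame): the generated IDs jump between partitions, so even max(row_id) says nothing about the row count.

```python
from pyspark.sql import functions as F

df_ids = df.withColumn("row_id", F.monotonically_increasing_id())
df_ids.select(F.max("row_id"), F.count(F.lit(1))).show()
# With more than one partition, max(row_id) is far larger than the row count,
# because the upper bits of each ID encode the partition number.
```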
Summary Table
You want to… | Use |
---|---|
Track how many bad records occurred | Accumulator |
Get a unique ID per row | monotonically_increasing_id() |
Count how many rows matched a condition | Accumulator or df.filter(...).count() |
Generate surrogate keys for a dimension table | monotonically_increasing_id() |
Expose lightweight job metrics for debugging | Accumulator |