🔧 Optimizing Repartitioning & Minimizing Shuffling in PySpark
Repartitioning is essential in distributed computing to optimize parallel execution, but excessive shuffling can degrade performance. Here’s how to handle it efficiently:
🔹 1️⃣ Understanding Repartitioning Methods
1. repartition(n)
✅ Increases parallelism but causes a full shuffle
df = df.repartition(10) # Redistributes into 10 partitions
✅ Use Case: When load balancing is needed (e.g., skewed data).
❌ Downside: Full shuffle across worker nodes.
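A quick way to verify the effect is to check the partition count before and after (a minimal sketch, assuming df is an already loaded DataFrame):
print(df.rdd.getNumPartitions()) # partition count before
df = df.repartition(10) # full shuffle into 10 partitions
print(df.rdd.getNumPartitions()) # now 10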
2. repartition(col)
✅ Redistributes data based on column values
df = df.repartition("category") # Partition based on 'category'
✅ Use Case: Optimizes joins and aggregations when filtering by category.
❌ Downside: Shuffles data across the cluster.
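To see how evenly rows land after a column-based repartition, you can count rows per partition with spark_partition_id() (a sketch, assuming df has a 'category' column; heavy skew shows up as one oversized partition):
from pyspark.sql.functions import spark_partition_id
df.repartition("category") \
  .withColumn("pid", spark_partition_id()) \
  .groupBy("pid").count() \
  .show() # rows per physical partition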
3. coalesce(n)
✅ Reduces partitions without a full shuffle
df = df.coalesce(4) # Reduce to 4 partitions
✅ Use Case: Use after filtering to reduce the number of partitions without triggering a shuffle.
❌ Downside: Can only reduce partitions, not increase them.
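One subtlety: asking coalesce() for more partitions than currently exist is simply ignored, so the count stays where it is (a small sketch, assuming df starts with more than 4 partitions):
df = df.coalesce(4)
print(df.rdd.getNumPartitions()) # 4
df = df.coalesce(16) # no shuffle, so Spark cannot grow the count
print(df.rdd.getNumPartitions()) # still 4; use repartition(16) if more partitions are truly needed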
🔹 2️⃣ How to Minimize Shuffling?
1️⃣ Repartition Early in ETL
If you know that your data will be processed by a specific column, repartition before transformations:
df = df.repartition("category") # Avoid unnecessary shuffle in later joins/aggregations
✅ Prevents multiple shuffle operations later.
2️⃣ Use broadcast() Instead of Repartitioning for Small Tables
For small lookup tables, use a broadcast join instead of repartitioning:
from pyspark.sql.functions import broadcast
df_final = df_large.join(broadcast(df_small), "id", "inner")
✅ Reduces shuffle when joining a large table with a small table.
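If you prefer not to import the helper, the same request can be made with a join hint on the small DataFrame (an equivalent sketch):
df_final = df_large.join(df_small.hint("broadcast"), "id", "inner")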
3️⃣ Optimize Joins Using DISTRIBUTE BY Instead of repartition()
When joining two large tables, use DISTRIBUTE BY instead of repartition():
df.createOrReplaceTempView("big_table")
df_small.createOrReplaceTempView("small_table")
query = """
SELECT *
FROM (SELECT * FROM big_table DISTRIBUTE BY category) b
JOIN (SELECT * FROM small_table DISTRIBUTE BY category) s
  ON b.category = s.category
"""
df_final = spark.sql(query)
✅ Distributes data efficiently before joining, reducing shuffle.
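On Spark 3.0+, the same pre-join distribution can also be requested with a partitioning hint inside a subquery, which some find easier to read than DISTRIBUTE BY (a sketch reusing the temp views above):
df_final = spark.sql("""
SELECT *
FROM (SELECT /*+ REPARTITION(category) */ * FROM big_table) b
JOIN small_table s
  ON b.category = s.category
""")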
4️⃣ Use coalesce() Instead of repartition() for Output
When reducing partitions before saving to disk, use coalesce() to avoid a full shuffle:
df_final.coalesce(1).write.csv("output.csv", header=True) # Writes a single part file into an 'output.csv' directory
✅ Avoids a full shuffle while reducing partitions (note that coalesce(1) funnels the write through a single task).
🔹 3️⃣ Example: Combining All Optimizations
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
# Initialize Spark
spark = SparkSession.builder \
.appName("Optimized_ETL") \
.config("spark.sql.autoBroadcastJoinThreshold", "524288000") \
.config("spark.executor.memory", "4g") \
.config("spark.memory.fraction", "0.8") \
.getOrCreate()
# Load Large Data
df_large = spark.read.parquet("large_data.parquet").repartition("category") # Repartition early
df_small = spark.read.parquet("small_data.parquet")
# Optimize Join
if df_small.count() < 1000000:
    df_final = df_large.join(broadcast(df_small), "id", "inner")
else:
    df_final = df_large.join(df_small, "id", "inner").repartition("category")
# Reduce Partitions Before Writing
df_final = df_final.coalesce(4)
df_final.write.parquet("optimized_output.parquet")
spark.stop()
✅ Repartitioned early
✅ Used broadcast join for the small table
✅ Distributed the large table before the join
✅ Reduced partitions before writing
📌 Summary: Best Practices for Repartitioning
✅ Use repartition(col) early to reduce shuffling later
✅ Use broadcast() for small tables instead of repartitioning
✅ Use DISTRIBUTE BY instead of repartition() for large joins
✅ Use coalesce() to reduce partitions before writing
🔧 Ensuring Early Repartition & Sequential Execution in PySpark
By default, PySpark follows lazy evaluation, meaning transformations (like repartition()) are not executed immediately. They are only triggered when an action (e.g., .count(), .show(), .write()) is called.
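A tiny illustration of this laziness (assuming df has an amount column):
df_repart = df.repartition("category") # nothing runs yet; only the logical plan changes
df_filtered = df_repart.filter("amount > 100") # still nothing
print(df_filtered.count()) # this action executes the whole plan, including the shuffle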
🔹 1️⃣ Ensuring Repartition Happens Early
Since repartition is a transformation, Spark does not execute it immediately. To force execution, we can use an action immediately after repartitioning:
df = df.repartition("category") # Repartition before heavy transformations
df.count() # Triggers the repartition immediately
✅ Ensures data is shuffled before moving to the next steps.
🔹 2️⃣ Forcing Sequential Execution
To make sure operations happen in order, use actions after key transformations:
df = df.repartition("category") # Step 1: Repartition
df.cache().count() # Step 2: Trigger execution & cache result
df_filtered = df.filter("amount > 100") # Step 3: Apply filter
df_filtered.count() # Step 4: Force execution
df_aggregated = df_filtered.groupBy("category").sum("amount") # Step 5: Aggregate
df_aggregated.show() # Step 6: Trigger execution
✅ Each step is executed in sequence
✅ Avoids unnecessary recomputation (because of cache())
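Once the downstream steps have finished, it is good practice to release the cache so executors can reuse the memory (a small addition to the sequence above):
df.unpersist() # drop the cached partitions when they are no longer needed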
🔹 3️⃣ Using persist() or cache() for Sequential Execution
Instead of triggering execution with a bare .count(), persist the result so the partitioned data is reused:
df = df.repartition("category").persist() # Persist after repartition
df.count() # Triggers execution & caches partitioned data
df_filtered = df.filter("amount > 100")
df_filtered.persist()
df_filtered.count() # Executes filter before moving ahead
✅ Ensures data is partitioned before filtering
✅ Reduces recomputation in later stages
🔹 4️⃣ Using checkpoint() for Strict Sequential Execution
If the data is very large, use checkpointing instead of caching:
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")
df = df.repartition("category").checkpoint() # Save intermediate state
df.count() # Executes the repartition
df_filtered = df.filter("amount > 100").checkpoint() # Checkpoint filtered data
df_filtered.count() # Ensures execution before proceeding
✅ Forces Spark to save intermediate results to disk
✅ Prevents re-execution of previous steps
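If writing to the checkpoint directory is too slow and full fault tolerance is not required, localCheckpoint() is a lighter alternative: it truncates the lineage but keeps the data on executor storage rather than a reliable filesystem (a sketch; the data is lost if an executor dies):
df = df.repartition("category").localCheckpoint() # eager by default, no checkpoint directory needed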
🔹 5️⃣ Using foreachPartition() to Trigger Execution
Another way to force execution is foreachPartition(), which runs a function over every partition and therefore materializes the preceding transformations:
df.repartition("category").foreachPartition(lambda x: list(x)) # Forces execution
✅ Ensures repartitioning is completed before moving ahead.
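On Spark 3.0+, another way to force a full materialization without collecting anything to the driver is to write to the built-in noop data source (a sketch):
df.repartition("category").write.format("noop").mode("overwrite").save() # executes the plan, discards the output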
📌 Best Practices for Sequential Execution in PySpark
✅ Use .count() after repartitioning to trigger execution
✅ Use .persist() or .cache() to avoid recomputation
✅ Use .checkpoint() for large datasets
✅ Use .foreachPartition() to force execution per partition
# Best Practice Template for PySpark SQL API & CTE-based ETL with Optimizations
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
# Initialize Spark Session
spark = (
    SparkSession.builder
    .appName("PySparkSQL_ETL")
    .config("spark.sql.autoBroadcastJoinThreshold", "524288000")  # Set auto-broadcast threshold to 500MB
    .config("spark.executor.memory", "4g")  # Increase executor memory
    .config("spark.driver.memory", "2g")  # Increase driver memory
    .config("spark.memory.fraction", "0.8")  # Allocate more memory for computation
    .getOrCreate()
)
# Set Checkpoint Directory
spark.sparkContext.setCheckpointDir("/tmp/checkpoint_dir")
# Sample Data (Creating a DataFrame)
data = [(1, "A", "active", 100),
(2, "B", "inactive", 200),
(3, "A", "active", 150),
(4, "C", "active", 120),
(5, "B", "inactive", 300)]
columns = ["id", "category", "status", "amount"]
df = spark.createDataFrame(data, columns)
# Repartition Early & Trigger Execution
print("Repartitioning DataFrame...")
df = df.repartition("category").persist()
df.count() # Forces execution & caching
# Approach 1: Using Temp Views for Step-by-Step ETL
df.createOrReplaceTempView("source_data")
# Step 1: Filter Active Records
filtered_query = """
SELECT * FROM source_data WHERE status = 'active'
"""
filtered_df = spark.sql(filtered_query).checkpoint()
filtered_df.count() # Ensures execution before proceeding
filtered_df.createOrReplaceTempView("filtered_data")
# Cache intermediate result
spark.sql("CACHE TABLE filtered_data")
# Step 2: Aggregation
aggregated_query = """
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
"""
aggregated_df = spark.sql(aggregated_query).persist()
aggregated_df.count() # Forces execution
aggregated_df.show()
# Approach 2: Using CTE for Optimized Query Execution
cte_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
aggregated_data AS (
SELECT category, SUM(amount) AS total_amount
FROM filtered_data
GROUP BY category
)
SELECT * FROM aggregated_data
"""
cte_df = spark.sql(cte_query).checkpoint()
cte_df.count() # Ensures execution
cte_df.show()
# Additional Example: Using Multiple CTEs for Complex Transformations
complex_query = """
WITH filtered_data AS (
SELECT * FROM source_data WHERE status = 'active'
),
ranked_data AS (
SELECT *, RANK() OVER (PARTITION BY category ORDER BY amount DESC) AS rank
FROM filtered_data
)
SELECT * FROM ranked_data WHERE rank = 1
"""
ranked_df = spark.sql(complex_query).checkpoint()
ranked_df.count() # Ensures execution
ranked_df.show()
# Broadcast Join Optimization
small_data = [(1, "extraA"), (2, "extraB"), (3, "extraC")]
small_columns = ["id", "extra_info"]
df_small = spark.createDataFrame(small_data, small_columns)
# Decide whether to broadcast based on size
if df_small.count() < 1000000:  # Example: Broadcast if less than 1 million rows
    df_final = df.join(broadcast(df_small), "id", "inner")
else:
    df_final = df.join(df_small, "id", "inner").repartition("category")
df_final.persist()
df_final.count() # Forces execution before writing
df_final.write.mode("overwrite").parquet("optimized_output.parquet")
# Closing Spark Session
spark.stop()