Here's a practical toolbox 🧰 of methods, tools, and commands you can use to debug PySpark jobs, check the number of partitions, inspect queries, and quickly detect errors, whether in Databricks or a standard PySpark environment.
🧰 PySpark Debugging & Introspection Toolkit
✅ 1. Instant Error Detection
🔹 A. Use .explain() Early
df.explain(True) # or df.explain(mode="formatted")
- Shows the full logical and physical plan
- Helps detect inefficient operations (e.g., cartesian joins, large shuffles)
- Reveals if Spark chooses broadcast join or not
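If the plan shows a SortMergeJoin where you expected a broadcast, you can hint the join explicitly and re-check the plan. A minimal sketch, assuming an active SparkSession named spark (the tiny DataFrames are only illustrative):
from pyspark.sql.functions import broadcast

fact_df = spark.createDataFrame([(1, 100), (2, 200)], ["id", "amount"])
dim_df = spark.createDataFrame([(1, "A"), (2, "B")], ["id", "label"])

# Hint that the small side should be broadcast, then confirm that
# BroadcastHashJoin (not SortMergeJoin + Exchange) appears in the physical plan
fact_df.join(broadcast(dim_df), "id").explain(True)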
🔹 B. Wrap in try-except
try:
    df = spark.read.csv("path/to/file")
    df.show()
except Exception as e:
    print("ERROR:", str(e))
- Useful to instantly trap schema mismatches, bad file formats, etc.
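For more targeted handling, you can catch Spark's analysis errors separately from everything else. A sketch assuming an active spark session; the path and column name are placeholders (on Spark 3.4+ the same class is also available as pyspark.errors.AnalysisException):
from pyspark.sql.utils import AnalysisException

try:
    df = spark.read.csv("path/to/file", header=True)
    df.select("nonexistent_column").show()   # placeholder column that triggers an analysis error
except AnalysisException as e:
    # Missing columns/tables, unresolved references, etc.
    print("ANALYSIS ERROR:", str(e))
except Exception as e:
    print("OTHER ERROR:", str(e))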
✅ 2. Know the Number of Partitions
A. Count partitions of a DataFrame
df.rdd.getNumPartitions()
B. Count partitions after join or groupBy
df_after = df1.join(df2, "id")
print("Partitions:", df_after.rdd.getNumPartitions())
C. Repartition Example
df = df.repartition(200, "region") # or .coalesce(10)
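The partition count you see after a join or groupBy is usually driven by spark.sql.shuffle.partitions (200 by default) unless AQE coalesces them, so checking and tuning that setting is often the quickest fix. A sketch, assuming an active spark session:
# Shuffle partitions used after joins/aggregations (default 200)
print(spark.conf.get("spark.sql.shuffle.partitions"))

# Lower it for small datasets to avoid lots of tiny tasks
spark.conf.set("spark.sql.shuffle.partitions", "64")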
✅ 3. Trace Query Plans
🔹 Logical & Physical Plan
df.explain(True)
Key things to spot:
- Exchange = shuffle
- BroadcastHashJoin = good (for small tables)
- SortMergeJoin = shuffle + sort on both sides (costly for large inputs)
- WholeStageCodegen = optimized execution
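One way to see these operators side by side is to toggle the automatic broadcast threshold and compare the plans. A sketch assuming an active spark session and the df1/df2 DataFrames from the join example above:
# Disable automatic broadcasting: expect SortMergeJoin + Exchange in the plan
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)
df1.join(df2, "id").explain(True)

# Restore the default (~10 MB): a small enough table should now show BroadcastHashJoin
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", 10 * 1024 * 1024)
df1.join(df2, "id").explain(True)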
🔹 Show SQL Plan with AQE Changes
spark.conf.set("spark.sql.adaptive.enabled", "true")
df.explain(True)
- Adaptive plans can show ReusedExchange, skew join optimization, coalesced shuffle partitions, etc. (the related settings are sketched below)
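The AQE knobs worth knowing are the standard Spark 3.x ones below; they are set here only as an illustration (these are the usual defaults on recent versions):
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")  # merge small shuffle partitions
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")            # split skewed partitions at join time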
✅ 4. Profiling Execution in Spark UI
Open Spark UI (e.g., via Databricks or YARN):
| Tab | What It Shows |
|---|---|
| Stages | Shuffle operations, task skew |
| Tasks | Duration per task, GC time, input size |
| SQL | Query timeline and physical plan |
| Storage | Cache usage |
| Executors | Memory usage per executor, disk spill |
🧠 Pro tip: Look for long-running tasks; they often indicate data skew or bad partitioning.
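Outside Databricks you can grab the UI address straight from the driver; a small sketch assuming an active spark session:
# URL of the driver-hosted Spark UI (None if the UI is disabled);
# in Databricks, use the cluster's "Spark UI" tab instead
print(spark.sparkContext.uiWebUrl)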
✅ 5. Schema Inspection
df.printSchema()
df.dtypes
df.schema.simpleString()
Detects:
- Wrong column types (e.g., string instead of int)
- Nested columns (structs, arrays)
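When a column comes back with the wrong type (e.g., numbers read as strings from CSV), an explicit cast fixes it; a sketch, with the column name being illustrative:
from pyspark.sql.functions import col

df = df.withColumn("age", col("age").cast("int"))  # string -> int
df.printSchema()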
✅ 6. Data Sampling & Inspection
df.limit(5).show()
df.select("column").distinct().show()
Use .cache() on reused DataFrames before joins:
df.cache().count()
✅ 7. Detect Skew or Heavy Partitions
🔹 Estimate Row Count Per Partition
df.rdd.mapPartitions(lambda x: [sum(1 for _ in x)]).collect()
- If one partition holds 1 million rows while the others hold around 1k, you have data skew.
🔹 Detect Wide/Skewed Joins
df.explain(True) # Look for: SortMergeJoin, Exchange, or skew=true on the join node (with AQE)
- Consider salting, broadcast join, or changing keys.
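Salting is the least obvious of these, so here is a minimal sketch of the idea, assuming an active spark session and two placeholder DataFrames: big_df (skewed on "id") and small_df:
from pyspark.sql import functions as F

SALT_BUCKETS = 8  # illustrative value; size it to the observed skew

# Add a random salt to the large, skewed side
big_salted = big_df.withColumn("salt", (F.rand() * SALT_BUCKETS).cast("int"))

# Replicate each row of the small side once per salt value
small_salted = small_df.withColumn(
    "salt", F.explode(F.array([F.lit(i) for i in range(SALT_BUCKETS)]))
)

# Join on the original key plus the salt, then drop the helper column
joined = big_salted.join(small_salted, ["id", "salt"]).drop("salt")
This spreads the hot key across SALT_BUCKETS partitions at the cost of duplicating the small side.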
✅ 8. Logging & Debug Prints
Use Spark logs:
spark.sparkContext.setLogLevel("INFO") # or DEBUG
Or log specific stages:
print("Start reading...")
df = spark.read.json("file.json")
print("Finished reading.")
✅ 9. Check Memory/Cache Usage
spark.catalog.isCached("table_or_view_name")  # catalog-level check for registered tables/temp views
spark.catalog.clearCache()
Spark UI > Storage tab shows:
- Cached RDDs/DataFrames
- Size in memory/disk
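On the DataFrame side you can also check and release storage directly; a sketch assuming df is a DataFrame you have been reusing:
from pyspark import StorageLevel

df.persist(StorageLevel.MEMORY_AND_DISK)
df.count()               # action that materializes the cache
print(df.storageLevel)   # effective storage level of this DataFrame
df.unpersist()           # free memory/disk once you are done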
✅ 10. Databricks-Specific Debug Tips
- Use %sql cells and view lineage via the SQL editor
- Job run pages show stage/task duration and skew alerts
- Delta tables:
  - Use DESCRIBE DETAIL table
  - Use OPTIMIZE, ZORDER, and VACUUM if queries are slow (see the sketch below)
- Cluster config:
  - Check driver/executor memory
  - Use Ganglia metrics for memory usage graphs
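A rough sketch of those Delta maintenance commands, run through spark.sql (the table and column names are placeholders; ZORDER requires Databricks or a recent Delta Lake release):
spark.sql("DESCRIBE DETAIL my_delta_table").show(truncate=False)   # file count, size, location
spark.sql("OPTIMIZE my_delta_table ZORDER BY (event_date)")        # compact small files, cluster by a common filter column
spark.sql("VACUUM my_delta_table RETAIN 168 HOURS")                # remove old files (default retention is 7 days)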
🎯 Summary: Go-To Commands Cheat Sheet
| Task | Command |
|---|---|
| Show schema | df.printSchema() |
| Count partitions | df.rdd.getNumPartitions() |
| Explain plan | df.explain(True) |
| Show distinct keys | df.select("col").distinct().show() |
| Broadcast join | broadcast(small_df) |
| Sample rows | df.limit(10).show() |
| Count rows per partition | df.rdd.mapPartitions(...) |
| View Spark UI | Spark Web UI or Databricks UI |
# PySpark Debugging & Introspection Notebook
from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast
# ✅ 1. Spark Session
spark = SparkSession.builder \
    .appName("PySpark Debugging Tools") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.executor.memory", "4g") \
    .config("spark.driver.memory", "2g") \
    .getOrCreate()
# ✅ 2. Sample Data
rdd = spark.sparkContext.parallelize([
    (1, "Alice", 29, "IN"),
    (2, "Bob", 41, "US"),
    (3, "Charlie", 35, "IN"),
    (4, "David", 23, "UK"),
    (5, "Eve", 36, "US")
])
df = rdd.toDF(["id", "name", "age", "country"])
# ✅ 3. Print Schema & Data
df.printSchema()
df.show()
# ✅ 4. Partition Details
print("\nNumber of Partitions:", df.rdd.getNumPartitions())
# Count Rows Per Partition
rows_per_partition = df.rdd.mapPartitions(lambda x: [sum(1 for _ in x)]).collect()
print("Rows per Partition:", rows_per_partition)
# ✅ 5. Explain Query Plan
df_filtered = df.filter(df.age > 30)
df_filtered.explain(True)
# ✅ 6. SQL Registration and Query
df.createOrReplaceTempView("people")
sql_df = spark.sql("SELECT country, COUNT(*) AS cnt FROM people GROUP BY country")
sql_df.show()
# ✅ 7. Broadcast Join Demo
small_df = spark.createDataFrame([("IN", "India"), ("US", "United States"), ("UK", "United Kingdom")], ["code", "country_name"])
broadcast_df = df.join(broadcast(small_df), df.country == small_df.code, "left")
broadcast_df.show()
# ✅ 8. Cache & Check
from pyspark import StorageLevel
cached_df = df_filtered.persist(StorageLevel.MEMORY_AND_DISK)
cached_df.count() # Triggers caching
print("Is Cached:", spark.catalog.isCached("people"))
# ✅ 9. Stop Session
spark.stop()