PySpark Debugging & Introspection Toolkit

Here’s a practical toolbox 🧰 of methods, tools, and commands you can use to debug PySpark jobs, check the number of partitions, inspect query plans, and quickly detect errors, whether in Databricks or a standard PySpark environment.


🧰 PySpark Debugging & Introspection Toolkit


✅ 1. Instant Error Detection

🔹 A. Use .explain() Early

df.explain(True)  # or df.explain(mode="formatted")
  • Shows the full logical and physical plan
  • Helps detect inefficient operations (e.g., cartesian joins, large shuffles)
  • Reveals if Spark chooses broadcast join or not

🔹 B. Wrap in try-except

try:
    df = spark.read.csv("path/to/file")
    df.show()
except Exception as e:
    print("ERROR:", str(e))
  • Useful to instantly trap schema mismatches, bad file formats, etc.
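
Catching AnalysisException separately from generic exceptions makes schema and column problems stand out immediately (a minimal sketch; the path and column name are placeholders):

from pyspark.sql.utils import AnalysisException

try:
    df = spark.read.option("header", "true").csv("path/to/file.csv")
    df.select("missing_column").show()  # deliberately wrong column to trigger AnalysisException
except AnalysisException as e:
    print("Schema/column problem:", e)
except Exception as e:
    print("Other failure:", e)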

✅ 2. Know Number of Partitions

A. Count partitions of a DataFrame

df.rdd.getNumPartitions()

B. Count partitions after join or groupBy

df_after = df1.join(df2, "id")
print("Partitions:", df_after.rdd.getNumPartitions())

C. Repartition Example

df = df.repartition(200, "region")  # or .coalesce(10)
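
A quick before/after check confirms what each call did (a sketch; the values 200 and 10 are arbitrary and a region column is assumed):

print("Before:", df.rdd.getNumPartitions())

df_wide = df.repartition(200, "region")   # full shuffle, hash-partitioned by region
print("After repartition:", df_wide.rdd.getNumPartitions())

df_narrow = df_wide.coalesce(10)          # merges partitions without a full shuffle
print("After coalesce:", df_narrow.rdd.getNumPartitions())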

✅ 3. Trace Query Plans

🔹 Logical & Physical Plan

df.explain(True)

Key things to spot:

  • Exchange = shuffle
  • BroadcastHashJoin = good (for small tables)
  • SortMergeJoin = shuffle + sort on both sides (can be costly for large tables)
  • WholeStageCodegen = optimized execution
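
To watch these operators change, disable auto-broadcast and compare the plan with and without an explicit broadcast hint (a sketch; small_df is assumed to be a small lookup table joined on id):

from pyspark.sql.functions import broadcast

spark.conf.set("spark.sql.autoBroadcastJoinThreshold", -1)  # turn off automatic broadcasting
df.join(small_df, "id").explain()                           # expect SortMergeJoin + Exchange
df.join(broadcast(small_df), "id").explain()                # hint forces BroadcastHashJoin
spark.conf.unset("spark.sql.autoBroadcastJoinThreshold")    # restore the default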

🔹 Show SQL Plan with AQE Changes

spark.conf.set("spark.sql.adaptive.enabled", "true")
df.explain(True)
  • Adaptive plans will show “ReusedExchange”, skew join optimization, etc.
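
A few related AQE settings are worth checking when you suspect skew or many tiny shuffle partitions (a sketch of commonly used Spark 3.x configs; defaults vary by version):

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")  # merge small shuffle partitions
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")            # split skewed partitions at join time
print(spark.conf.get("spark.sql.adaptive.enabled"))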

✅ 4. Profiling Execution in Spark UI

Open Spark UI (e.g., via Databricks or YARN):

| Tab | What It Shows |
| --- | --- |
| Stages | Shuffle operations, task skew |
| Tasks | Duration per task, GC time, input size |
| SQL | Query timeline and physical plan |
| Storage | Cache usage |
| Executors | Memory usage per executor, disk spill |

🧠 Pro tip: Look for long-running tasks → may indicate data skew or bad partitioning.


✅ 5. Schema Inspection

df.printSchema()
df.dtypes
df.schema.simpleString()

Detects:

  • Wrong column types (e.g., string instead of int)
  • Nested columns (structs, arrays)
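
When a column comes back as the wrong type, supplying an explicit schema on read is usually safer than relying on inference (a sketch with placeholder column names and path):

from pyspark.sql.types import StructType, StructField, IntegerType, StringType

expected_schema = StructType([
    StructField("id", IntegerType(), True),
    StructField("name", StringType(), True),
    StructField("age", IntegerType(), True),
])

df = spark.read.schema(expected_schema).option("header", "true").csv("path/to/file.csv")
df.printSchema()  # should now match expected_schema exactly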

✅ 6. Data Sampling & Inspection

df.limit(5).show()
df.select("column").distinct().show()

Use .cache() on reused DataFrames before joins:

df.cache().count()
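
On larger tables, a small random sample plus summary statistics is often faster than pulling rows, and unpersist() frees the cache once the DataFrame is no longer reused (a sketch; the 1% fraction and seed are arbitrary):

sample_df = df.sample(fraction=0.01, seed=42)  # roughly 1% of rows
sample_df.describe().show()                    # count/mean/stddev/min/max for numeric columns

df.unpersist()                                 # release the cached data when done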

✅ 7. Detect Skew or Heavy Partitions

🔹 Estimate Row Count Per Partition

df.rdd.mapPartitions(lambda x: [sum(1 for _ in x)]).collect()
  • If you see one partition with 1 million rows and others with 1k → data skew!
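
An equivalent check that stays in the DataFrame API groups rows by spark_partition_id() (a sketch):

from pyspark.sql.functions import spark_partition_id

df.groupBy(spark_partition_id().alias("partition_id")) \
  .count() \
  .orderBy("count", ascending=False) \
  .show()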

🔹 Detect Wide/Skewed Joins

df.explain(True)  # Look for: SortMergeJoin, Exchange, SkewedJoin
  • Consider salting (sketched below), a broadcast join, or changing the join keys.
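
A minimal salting sketch, assuming a large DataFrame df_large skewed on join key id and a smaller df_small on the other side (the names and the salt range of 10 are illustrative):

from pyspark.sql.functions import floor, rand, explode, array, lit

N = 10  # number of salt buckets; tune to the degree of skew

# add a random salt to the skewed (large) side
salted_large = df_large.withColumn("salt", floor(rand() * N).cast("int"))

# replicate each row of the small side once per salt value
salted_small = df_small.withColumn("salt", explode(array([lit(i) for i in range(N)])))

# join on the original key plus the salt, then drop the helper column
joined = salted_large.join(salted_small, ["id", "salt"]).drop("salt")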

✅ 8. Logging & Debug Prints

Use Spark logs:

spark.sparkContext.setLogLevel("INFO")  # or DEBUG

Or log specific stages:

print("Start reading...")
df = spark.read.json("file.json")
print("Finished reading.")

✅ 9. Check Memory/Cache Usage

spark.catalog.isCached("df_name")
spark.catalog.clearCache()

Spark UI > Storage tab → shows:

  • Cached RDDs/DataFrames
  • Size in memory/disk
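
For DataFrames registered as temp views, the catalog API can also cache and uncache by name (a sketch using a hypothetical view called my_view):

df.createOrReplaceTempView("my_view")

spark.catalog.cacheTable("my_view")               # lazily marks the view for caching
spark.sql("SELECT COUNT(*) FROM my_view").show()  # first action materializes the cache
print(spark.catalog.isCached("my_view"))          # True

spark.catalog.uncacheTable("my_view")             # drop just this entry
spark.catalog.clearCache()                        # or drop everything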

✅ 10. Databricks-Specific Debug Tips

  • Use %sql cells and view table lineage via the SQL editor
  • Job run pages → stage/task duration, skew alerts
  • Delta tables (see the sketch after this list):
    • Use DESCRIBE DETAIL table to inspect size, file count, and location
    • OPTIMIZE, ZORDER, VACUUM if queries are slow
  • Cluster config:
    • Check driver/executor memory
    • Use Ganglia metrics for memory usage graphs
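
The Delta maintenance commands mentioned above look roughly like this from a notebook (a sketch; my_table and event_date are placeholders, and these commands apply to Delta tables only):

# Table-level metadata: size, number of files, partition columns, location
spark.sql("DESCRIBE DETAIL my_table").show(truncate=False)

# Compact small files and co-locate data on a frequently filtered column
spark.sql("OPTIMIZE my_table ZORDER BY (event_date)")

# Remove files no longer referenced by the table (default retention is 7 days)
spark.sql("VACUUM my_table")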

🎯 Summary: Go-To Commands Cheat Sheet

| Task | Command |
| --- | --- |
| Show schema | df.printSchema() |
| Count partitions | df.rdd.getNumPartitions() |
| Explain plan | df.explain(True) |
| Show distinct keys | df.select("col").distinct().show() |
| Broadcast join | broadcast(small_df) |
| Sample rows | df.limit(10).show() |
| Count rows per partition | df.rdd.mapPartitions(...) |
| View Spark UI | Spark Web UI or Databricks UI |

# PySpark Debugging & Introspection Notebook

from pyspark.sql import SparkSession
from pyspark.sql.functions import broadcast

# ✅ 1. Spark Session
spark = SparkSession.builder \
    .appName("PySpark Debugging Tools") \
    .config("spark.sql.adaptive.enabled", "true") \
    .config("spark.executor.memory", "4g") \
    .config("spark.driver.memory", "2g") \
    .getOrCreate()

# ✅ 2. Sample Data
rdd = spark.sparkContext.parallelize([
    (1, "Alice", 29, "IN"),
    (2, "Bob", 41, "US"),
    (3, "Charlie", 35, "IN"),
    (4, "David", 23, "UK"),
    (5, "Eve", 36, "US")
])

df = rdd.toDF(["id", "name", "age", "country"])

# ✅ 3. Print Schema & Data
df.printSchema()
df.show()

# ✅ 4. Partition Details
print("\nNumber of Partitions:", df.rdd.getNumPartitions())

# Count Rows Per Partition
rows_per_partition = df.rdd.mapPartitions(lambda x: [sum(1 for _ in x)]).collect()
print("Rows per Partition:", rows_per_partition)

# ✅ 5. Explain Query Plan
df_filtered = df.filter(df.age > 30)
df_filtered.explain(True)

# ✅ 6. SQL Registration and Query
df.createOrReplaceTempView("people")
sql_df = spark.sql("SELECT country, COUNT(*) AS cnt FROM people GROUP BY country")
sql_df.show()

# ✅ 7. Broadcast Join Demo
small_df = spark.createDataFrame([("IN", "India"), ("US", "United States"), ("UK", "United Kingdom")], ["code", "country_name"])
broadcast_df = df.join(broadcast(small_df), df.country == small_df.code, "left")
broadcast_df.show()

# ✅ 8. Cache & Check
from pyspark import StorageLevel

cached_df = df_filtered.persist(StorageLevel.MEMORY_AND_DISK)
cached_df.count()  # Action triggers materialization of the persisted DataFrame

# isCached() takes a table/view name, not a DataFrame, so cache the temp view before checking it
spark.catalog.cacheTable("people")
print("Is Cached:", spark.catalog.isCached("people"))

# ✅ 9. Stop Session
spark.stop()
