PySpark Coding Practice Questions

In Spark SQL, when you use aggregate functions like SUM() or AVG(), you must either:

  • Group by every non-aggregated column in the SELECT list
  • Or select only aggregated columns

🔎 Example in Spark SQL:

SELECT dept, SUM(salary) 
FROM employee 
GROUP BY dept;      -- ✅ Allowed

But this would fail:

SELECT dept, salary, SUM(salary) 
FROM employee 
GROUP BY dept        -- ❌ salary is not grouped or aggregated
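
All the snippets in this post can be reproduced with a small toy DataFrame. Here is a minimal setup sketch (the data values, including the manager name, are invented for illustration and chosen to match the sample output shown later):

from pyspark.sql import SparkSession, functions as F

spark = SparkSession.builder.appName("groupby-demo").getOrCreate()

# Toy data: name, dept, salary, manager, region
df = spark.createDataFrame(
    [
        ("Alice",   "Sales", 1000, "Dana", "East"),
        ("Bob",     "Sales", 1000, "Dana", "West"),
        ("Charlie", "Sales", 1000, "Dana", "East"),
    ],
    ["name", "dept", "salary", "manager", "region"],
)

# The SQL examples above can be run against a temp view:
df.createOrReplaceTempView("employee")
spark.sql("SELECT dept, SUM(salary) FROM employee GROUP BY dept").show()
spark.sql("SELECT SUM(salary) FROM employee").show()  # aggregates-only SELECT is also allowed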

✅ How This Is Handled in PySpark DataFrame API

PySpark’s DataFrame API enforces the same rule; you express the grouping through .groupBy().

🔹 Example:

df.groupBy("dept").agg(
    F.sum("salary").alias("total_salary"),
    F.count("*").alias("count")
)

Here:

  • Only dept can appear un-aggregated in the result, because it is in groupBy()
  • You cannot write df.select("dept", "salary", F.sum("salary")) unless salary is grouped or aggregated (see the sketch below)
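
To see that second failure concretely, here is a minimal sketch (the exact error text varies by Spark version, and in Spark 3.4+ the exception also lives in pyspark.errors):

from pyspark.sql import functions as F
from pyspark.sql.utils import AnalysisException

try:
    df.select("dept", "salary", F.sum("salary"))  # mixes plain columns with an aggregate
except AnalysisException as e:
    # e.g. "grouping expressions sequence is empty, and 'dept' is not an aggregate function ..."
    print(e)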

❌ Invalid PySpark Code:

df.groupBy("dept").agg(F.sum("salary")).select("dept", "salary")  # ❌

This will throw an AnalysisException like:

cannot resolve 'salary' given input columns: [dept, sum(salary)]

because after .agg() the resulting DataFrame contains only dept and sum(salary); the original salary column no longer exists.
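
The fix is to select the aggregate by its alias from the aggregated result, for example:

agg_df = df.groupBy("dept").agg(F.sum("salary").alias("total_salary"))
agg_df.select("dept", "total_salary").show()  # ✅ both columns exist in agg_df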


✅ If You Still Need Other Columns (non-aggregated)

You can use:

  • first(), max(), or collect_list() to carry those columns forward.

Example:

df.groupBy("dept").agg(
    F.sum("salary").alias("total_salary"),
    F.first("manager").alias("any_manager")  # preserves one value
)

🧠 Why It Works Like This

Spark DataFrames and SQL are both based on logical plans. This rule ensures:

  • Data is grouped correctly
  • There’s no ambiguity about which row’s non-aggregated value to use

✅ Summary

Scenario                        | Required Action
--------------------------------|----------------------------------
Aggregating some columns        | Group by the non-aggregated ones
Want non-aggregated columns too | Use first(), max(), etc.
Need full rows post-aggregation | Join back with the original

Now let’s walk through how to “carry forward” non-grouped columns like name, manager, and region when aggregating with PySpark’s groupBy().


🎯 The Problem

You want to group by one column (say dept) and still include other columns that are not in groupBy() — which would normally cause an error.

In SQL:

SELECT dept, name, SUM(salary) FROM emp GROUP BY dept; -- ❌ ERROR: 'name' not grouped/aggregated

Same issue in PySpark:

df.groupBy("dept").agg(F.sum("salary"))  # you cannot directly access 'name'

✅ The Solution: Use Aggregates to “Preserve” Other Columns

🔹 Option 1: Use F.first() or F.max() (if 1 row per group is enough)

from pyspark.sql import functions as F

df.groupBy("dept").agg(
    F.sum("salary").alias("total_salary"),
    F.first("name").alias("any_name"),
    F.max("region").alias("max_region")
).show()

  • F.first("name") returns an arbitrary value from the group (non-deterministic, because row order after a shuffle is not guaranteed).
  • F.max("region") returns the greatest value (lexicographic ordering for strings).
  • These are useful when one representative value per group is enough.

🔹 Option 2: Use F.collect_list() or F.collect_set() (for full values)

df.groupBy("dept").agg(
    F.sum("salary").alias("total_salary"),
    F.collect_list("name").alias("all_names"),
    F.collect_set("region").alias("unique_regions")
).show(truncate=False)

This gives:

+------+-------------+------------------------+----------------------+
|dept  |total_salary |all_names               |unique_regions        |
+------+-------------+------------------------+----------------------+
|Sales | 3000        |[Alice, Bob, Charlie]   |[East, West]          |
+------+-------------+------------------------+----------------------+

🧠 When to Use Which?

Goal                          | Function to Use
------------------------------|------------------------
One arbitrary value per group | F.first(), F.max()
All values per group          | F.collect_list()
Unique values only            | F.collect_set()
Need full original row        | Join after aggregation

🔁 BONUS: Join Back with Original

If you need full row context after aggregation, do this:

agg_df = df.groupBy("dept").agg(F.sum("salary").alias("total_salary"))
result = df.join(agg_df, on="dept", how="left")

This way you keep original columns and bring in aggregate values.
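
With the toy DataFrame from the setup sketch above, every Sales row keeps its original columns and gains the department total (output sketch):

+-----+-------+------+-------+------+------------+
| dept|   name|salary|manager|region|total_salary|
+-----+-------+------+-------+------+------------+
|Sales|  Alice|  1000|   Dana|  East|        3000|
|Sales|    Bob|  1000|   Dana|  West|        3000|
|Sales|Charlie|  1000|   Dana|  East|        3000|
+-----+-------+------+-------+------+------------+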


✅ Summary

Function         | Behavior
-----------------|-----------------------------------------
F.first(col)     | Returns one (arbitrary) value per group
F.max(col)       | Returns the maximum value in the group
F.collect_list() | Returns all values (incl. duplicates)
F.collect_set()  | Returns all unique values in the group
Join after agg   | Preserves the full original rows
