post data-engineering · 2024-12-20 · 4 min read

PySpark skew detection, and the three fixes that actually work

#pyspark #data-engineering #spark #performance

Most slow Spark jobs are not slow for the reason you think. They're skewed: 199 of your 200 partitions finish in 12 seconds, while the 200th holds half the data and runs for 40 minutes. The cluster sits idle, you wait, and your cost-per-job blows up.

This post is the field guide I wish I had three years ago: how to actually detect skew, and the three fix patterns that resolve 90% of cases.

Detecting skew, accurately

The Spark UI's "Stages" tab shows median and max task duration per stage. The naive read is "compare them; a big gap means skew." But the Spark UI reports task duration including shuffle reads, GC pauses, network blips, and the rest. A 3× max-to-median ratio is not always skew.

The real test: partition record count. Skew is when a small number of partitions hold a disproportionate share of the rows.

from pyspark.sql import functions as F

# How many records per partition does this DataFrame produce?
df.groupBy(F.spark_partition_id().alias("pid")) \
    .count() \
    .orderBy(F.desc("count")) \
    .show(20)
+---+---------+
|pid|    count|
+---+---------+
|142|2_847_193| ← skewed partition
| 17|  102_443|
|  3|  101_882|
| 88|  101_517|
| 11|  101_204|
+---+---------+

If the top partition is 5×+ the median, you have skew. Below that, look elsewhere (network, executor sizing, GC).
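You can make that check mechanical. A minimal sketch, reusing the F import from above (percentile_approx is available through expr on any recent Spark):

pid_counts = df.groupBy(F.spark_partition_id().alias("pid")).count()
stats = pid_counts.agg(
    F.max("count").alias("mx"),
    F.expr("percentile_approx(`count`, 0.5)").alias("med"),
).first()
print(f"max/median partition ratio: {stats['mx'] / stats['med']:.1f}x")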

Knowing what’s skewed: a one-liner

For a df.join(other, "user_id") that’s slow, the question is: which user_id values are causing the skew?

df.groupBy("user_id") \
    .count() \
    .orderBy(F.desc("count")) \
    .show(20)

If one user has 10 million rows and the next-biggest has 50k, that’s your skewed key. From here, the fix you pick depends on the kind of join.
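The same 5× intuition works at the key level. A rough sketch to put a number on it:

key_counts = df.groupBy("user_id").count()
hot = key_counts.orderBy(F.desc("count")).first()["count"]
median = key_counts.approxQuantile("count", [0.5], 0.01)[0]
print(f"hot key is ~{hot / median:.0f}x the median key")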

Fix 1: Salt-and-aggregate (skewed group-by)

When: df.groupBy("k").agg(...) and k has a hot value.

The trick: replace the single hot key with N synthetic keys, aggregate locally, then re-aggregate without the salt.

# Before: one massive partition for user_id=42
slow = df.groupBy("user_id").agg(F.sum("amount").alias("total"))

# After: salt the key. A built-in random column beats a Python UDF:
# F.rand runs natively on the executors, with no serialisation cost,
# and Spark already knows it is non-deterministic.
salted = (df
    .withColumn("salt", F.floor(F.rand() * 20).cast("int"))
    .groupBy("user_id", "salt")
    .agg(F.sum("amount").alias("partial"))   # stage 1: per-(key, salt) partials
    .groupBy("user_id")
    .agg(F.sum("partial").alias("total")))   # stage 2: collapse the salt

The first groupBy now has 20× more keys to distribute across executors. The second groupBy on user_id alone has tiny per-key data because each row is already a partial sum.

Salt cardinality (20 here) is the knob. Higher → more parallelism, more shuffle. Start at 16-32, tune from job timing.
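Salting is easy to get subtly wrong (dropping the second aggregate, summing the wrong column), so check the salted pipeline against the straightforward one before trusting it. A quick sanity check, using slow and salted from above:

# Empty differences in both directions mean the results match exactly.
# (For floating-point sums, compare with a tolerance instead.)
assert salted.exceptAll(slow).count() == 0
assert slow.exceptAll(salted).count() == 0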

Fix 2: Broadcast join (one side is small)

When: you’re joining a big table to a small one (lookups, dimensions, allow-lists).

Spark’s default is sort-merge join, which shuffles both sides. If one side fits in executor memory (rule of thumb: ≤ 100 MB serialised), broadcast it. The big side does not move, and there is no shuffle for the join.

from pyspark.sql.functions import broadcast

joined = big_df.join(
    broadcast(small_df),  # explicit hint
    "user_id",
    "left",
)

Spark’s AQE (spark.sql.adaptive.enabled) will sometimes pick this automatically when statistics are good. But hinting it explicitly is cheap insurance.
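The automatic path is gated by spark.sql.autoBroadcastJoinThreshold, which defaults to 10 MB, well below the 100 MB rule of thumb above. If you'd rather raise the ceiling than hint every join:

# Let Spark auto-broadcast anything it estimates at under 100 MB.
spark.conf.set("spark.sql.autoBroadcastJoinThreshold", str(100 * 1024 * 1024))

# Or set it to -1 to disable auto-broadcast and rely on explicit hints only.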

Watch out: broadcasting a too-big DataFrame OOMs every executor, and there is no graceful failure mode. If you're not sure whether the small side fits, estimate its size from a row count:

estimated_row_size_bytes = 64  # rough per-row figure for your schema; adjust

small_size_bytes = small_df.count() * estimated_row_size_bytes
print(f"~{small_size_bytes / 1e6:.1f} MB")

If it's under 100 MB, broadcast. Above 100 MB, skip the broadcast and let AQE's skew-join handling take over instead (see fix 3).

Fix 3: AQE skew-join optimisation (Spark 3.0+)

When: a sort-merge join is skewed but neither side is small enough to broadcast.

Spark’s adaptive query execution has built-in skew-join handling. Enable two flags:

spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")

What this does at runtime: AQE inspects partition sizes after the shuffle stage. Any partition that’s > 5× the median (and over a threshold, default 256 MB) gets split into smaller sub-partitions. The matching partition on the other side gets replicated across each split. The join now runs in parallel even on the hot key.
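To confirm the optimisation actually fired, run the job and read the final adaptive plan. A sketch, with the caveats that the plan node names vary by version (AQEShuffleRead in 3.2+, CustomShuffleReader in 3.0-3.1) and that other_big_df is a stand-in name:

joined = big_df.join(other_big_df, "user_id")
joined.write.format("noop").mode("overwrite").save()  # execute without writing output
joined.explain()  # final plan (isFinalPlan=true); look for "skewed" on the shuffle read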

Two tuning knobs that matter:

# how skewed is "skewed enough to act on"
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
# minimum size before splitting; smaller partitions are not worth splitting
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256MB")

The defaults are 5× and 256 MB. For a job that's known to skew on a particular key with very large hot values, drop the factor to 3× or 2× and the threshold to 64 MB.

AQE skew-join is the "easy button": try it before reaching for manual salting. Manual salting only wins when AQE fails to detect or split correctly (rare in 3.4+).

Diagnostic loop

When a job is slow, this is the loop I run:

1. partition-record-count query → is it skewed?
2. if yes: groupBy(key).count() → which key is hot? (steps 1-2 are scripted in the helper below)
3. join or aggregate?
├─ aggregate: fix 1 (salt + reaggregate)
├─ join, small side: fix 2 (broadcast)
└─ join, both big: fix 3 (AQE skew-join)
4. re-run, compare partition-record-count and stage timing
5. if still slow, look elsewhere (shuffle, GC, executor sizing)
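
Steps 1 and 2 are mechanical enough to wrap in a helper. A sketch using the thresholds from this post:

from pyspark.sql import functions as F

def diagnose_skew(df, key, top_n=10):
    """Print partition-level skew, then the hot keys if it looks skewed."""
    pid_counts = df.groupBy(F.spark_partition_id().alias("pid")).count()
    stats = pid_counts.agg(
        F.max("count").alias("mx"),
        F.expr("percentile_approx(`count`, 0.5)").alias("med"),
    ).first()
    ratio = stats["mx"] / max(stats["med"], 1)
    print(f"partition max/median: {ratio:.1f}x "
          f"({'skewed' if ratio >= 5 else 'looks ok'})")
    if ratio >= 5:
        df.groupBy(key).count().orderBy(F.desc("count")).show(top_n)

diagnose_skew(df, "user_id")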

The 80/20 rule: most slow Spark jobs that look like “the cluster is too small” are actually skew, and tightening the join or aggregate gets you a 5-10× speedup with zero extra spend.

What this is not

A few things to call out so this post does not get misread: not every slow job is skew. If the partition-record-count query comes back flat, these three fixes will not help; go look at shuffle volume, GC, and executor sizing instead (step 5 of the loop).

But if your job is currently 40 minutes and one partition is doing 90% of the work, one of these three patterns will get you to 4 minutes by Friday.