Introduction
One of the features I love about PySpark is that the data frame abstraction is agnostic about data sources and destinations. This makes it possible to unit test data transformations without connecting to a database or file system. As a bonus, data frames are immutable, and each transformation returns a new one, so we can test transformations on the data in bite-sized chunks.
All we need to do is create an in-memory data frame, perform a set of transformations, then check the data by executing an action. Better still, transformations can be tested without any data at all: given an input schema, the resulting output schema can be confirmed on its own. No data needed. This lets us make sure that our PySpark script fulfills its contract, with unit tests that verify preconditions and postconditions by comparing schemas.
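Here is the idea in miniature, a sketch only (the column name and the transformation are illustrative; the rest of this post builds it out into a fuller example):
# A minimal sketch of the schema-contract idea; the column name and the
# transformation here are illustrative only.
import pyspark.sql.functions as sqlf
import pyspark.sql.types as sqlt
from pyspark.sql import SparkSession

spark = SparkSession.builder.getOrCreate()
in_schema = sqlt.StructType([sqlt.StructField("name", sqlt.StringType(), True)])
empty = spark.createDataFrame(data=[], schema=in_schema)
# Apply a transformation and inspect the resulting schema; no rows needed.
result = empty.select(sqlf.upper(sqlf.col("name")).alias("name"))
assert result.columns == ["name"]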
Example
Let’s start with a simple example. Our dataset is a short list of musicians with their bands and roles, just enough to keep the explanations clear.
import pyspark.sql.functions as sqlf
import pyspark.sql.types as sqlt
from pyspark.sql import Column, DataFrame, SparkSession

spark = SparkSession.builder.getOrCreate()

data = [
    ("Rob", "Halford", "Judas Priest", "Singer"),
    ("Alice", "Cooper", "Hollywood Vampires", "Singer"),
    ("Steve", "Harris", "Iron Maiden", "Bassist"),
    ("James", "Hetfield", "Metallica", "Singer"),
    ("Bernie", "Worrell", "Parliament", "Keyboardist"),
]
schema = sqlt.StructType(
    [
        sqlt.StructField("first_name", sqlt.StringType(), True),
        sqlt.StructField("last_name", sqlt.StringType(), True),
        sqlt.StructField("band", sqlt.StringType(), True),
        sqlt.StructField("role", sqlt.StringType(), True),
    ]
)
df = spark.createDataFrame(data=data, schema=schema)
df.printSchema()
df.show(truncate=False)
root
|-- first_name: string (nullable = true)
|-- last_name: string (nullable = true)
|-- band: string (nullable = true)
|-- role: string (nullable = true)
+----------+---------+------------------+-----------+
|first_name|last_name|band |role |
+----------+---------+------------------+-----------+
|Rob |Halford |Judas Priest |Singer |
|Alice |Cooper |Hollywood Vampires|Singer |
|Steve |Harris |Iron Maiden |Bassist |
|James |Hetfield |Metallica |Singer |
|Bernie |Worrell |Parliament |Keyboardist|
+----------+---------+------------------+-----------+
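As an aside, createDataFrame can also infer a schema from Row objects. A quick sketch for comparison (the single row here is illustrative); an explicit schema like the one above keeps the test contract unambiguous:
# For comparison: schema inference from Row objects. An explicit schema, as
# above, is the better fit for tests because the contract is spelled out.
from pyspark.sql import Row

inferred_df = spark.createDataFrame(
    [Row(first_name="Rob", last_name="Halford", band="Judas Priest", role="Singer")]
)
inferred_df.printSchema()  # all four columns inferred as string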
Implement Transformations
Let’s create the transformation functions that we will test.
def drop_unnecessary(df: DataFrame) -> DataFrame:
    """
    Remove the role column
    """
    return df.drop("role")


def only_singers(df: DataFrame) -> DataFrame:
    """
    Keep only the singers
    """
    return df.where(sqlf.col("role") == "Singer")


def combine_names(first: Column, last: Column) -> Column:
    """
    Combine the first name and last name columns into a single name struct
    """
    return sqlf.struct(first.alias("first"), last.alias("last"))


def fix_name(df: DataFrame) -> DataFrame:
    """
    Replace the two name columns with a single name struct
    """
    return df.select(
        "band",
        combine_names(sqlf.col("first_name"), sqlf.col("last_name")).alias("name"),
    )
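Because each function returns a new data frame and leaves its input untouched, the transformations can be exercised independently, which is what makes bite-sized testing possible:
# Transformations return new data frames; the original is unchanged.
dropped = drop_unnecessary(df)
assert "role" in df.columns  # the input still has the role column
assert "role" not in dropped.columns  # the result does not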
Run Transformations
Let’s run the transformations that we just implemented. The “transform” method on DataFrame lets us call the transformations in sequence without creating temporary variables, which also makes the code easier to read.
"""
Since data frames are lazy, we can filter out the singers after dropping the role column. Notice that
the code below runs and returns the correct answer.
"""
new_df = df.transform(drop_unnecessary).transform(only_singers).transform(fix_name)
new_df.show()
+------------------+-----------------+
| band| name|
+------------------+-----------------+
| Judas Priest| {Rob, Halford}|
|Hollywood Vampires| {Alice, Cooper}|
| Metallica|{James, Hetfield}|
+------------------+-----------------+
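For contrast, the same pipeline without transform needs either nested calls or temporary variables:
# The same pipeline written with nested calls instead of transform.
nested_df = fix_name(only_singers(drop_unnecessary(df)))
assert nested_df.schema == new_df.schema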
Tests
Now, let’s write the tests: schema tests that verify the output schema of each transformation, and a data test that verifies the transformations change the input data as expected.
import pytest
import ipytest

ipytest.autoconfig()
ipytest.clean()


@pytest.fixture
def spark() -> SparkSession:
    return SparkSession.builder.getOrCreate()


@pytest.fixture
def schema() -> sqlt.StructType:
    """
    Create the input schema to be used for tests
    """
    return sqlt.StructType(
        [
            sqlt.StructField("first_name", sqlt.StringType(), True),
            sqlt.StructField("last_name", sqlt.StringType(), True),
            sqlt.StructField("band", sqlt.StringType(), True),
            sqlt.StructField("role", sqlt.StringType(), True),
        ]
    )


@pytest.fixture
def data(spark: SparkSession, schema: sqlt.StructType) -> DataFrame:
    """
    Sample input data to use in tests
    """
    data = [
        ("Rob", "Halford", "Judas Priest", "Singer"),
        ("Alice", "Cooper", "Hollywood Vampires", "Singer"),
        ("Steve", "Harris", "Iron Maiden", "Bassist"),
        ("James", "Hetfield", "Metallica", "Singer"),
        ("Bernie", "Worrell", "Parliament", "Keyboardist"),
    ]
    return spark.createDataFrame(data=data, schema=schema)


# Schema tests
def test_drop_unnecessary_schema(spark: SparkSession, schema: sqlt.StructType):
    """
    This test verifies only the column names
    """
    empty = spark.createDataFrame(data=[], schema=schema)
    result = empty.transform(drop_unnecessary)
    assert ["first_name", "last_name", "band"] == result.columns


def test_only_singers_schema(spark: SparkSession, schema: sqlt.StructType):
    """
    You can also verify the contract by comparing schema instances.
    In this case, the schemas should be the same.
    """
    empty = spark.createDataFrame(data=[], schema=schema)
    result = empty.transform(only_singers)
    assert result.schema == schema


def test_fix_name_schema(spark: SparkSession, schema: sqlt.StructType):
    """
    Schema can also be verified with a simple string
    """
    empty = spark.createDataFrame(data=[], schema=schema)
    result = empty.transform(fix_name)
    assert ["band", "name"] == result.columns
    assert (
        result.schema["name"].simpleString() == "name:struct<first:string,last:string>"
    )


# Data test
def test_end_to_end(spark: SparkSession, data: DataFrame):
    result = (
        data.transform(drop_unnecessary).transform(only_singers).transform(fix_name)
    )
    # call an action
    output = result.collect()
    assert len(output) == 3
    assert output[0]["band"] == "Judas Priest"
    assert output[0]["name"]["first"] == "Rob"
    assert output[0]["name"]["last"] == "Halford"
    assert output[1]["band"] == "Hollywood Vampires"
    assert output[2]["band"] == "Metallica"


ipytest.run();
.... [100%]
4 passed in 0.20s
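The same style extends to precondition tests. Here is a sketch you could add alongside the tests above (the test name and the DDL-string schema are my own illustrative additions): applying only_singers to a frame that is missing the role column should fail fast with an AnalysisException, since the column reference cannot be resolved.
# A sketch of a precondition test: only_singers needs the role column, so a
# frame without it should raise an AnalysisException. The test name and the
# DDL-string schema are illustrative additions.
from pyspark.sql.utils import AnalysisException


def test_only_singers_missing_role(spark: SparkSession):
    no_role = spark.createDataFrame(
        data=[], schema="first_name string, last_name string, band string"
    )
    with pytest.raises(AnalysisException):
        no_role.transform(only_singers)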
Conclusion
I hope this gave you some ideas on how to test the transformations in your own pipelines. This approach has helped me simplify my tests and verify edge cases more easily. It has caught many errors early, without a full run of the script or all of the data, and it makes debugging quicker since the tests are smaller.