Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 116 additions & 15 deletions python/pyspark/testing/goldenutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,11 @@
# limitations under the License.
#

from typing import Any, Optional
from typing import Any, Callable, List, Optional
import inspect
import os
import time

import pandas as pd

try:
import numpy as np

Expand All @@ -29,6 +28,17 @@
have_numpy = False


# PyArrow uses internal names ("halffloat", "float", "double") that differ from
# the commonly used names ("float16", "float32", "float64"). This mapping
# normalises the str() representation of Arrow DataType so that repr_type()
# returns the more intuitive names.
_ARROW_FLOAT_ALIASES = {
"halffloat": "float16",
"float": "float32",
"double": "float64",
}


class GoldenFileTestMixin:
"""
Mixin class providing utilities for golden file based testing.
Expand Down Expand Up @@ -87,14 +97,21 @@ def tearDownClass(cls) -> None:
def setup_timezone(cls, tz: str = "America/Los_Angeles") -> None:
"""
Setup timezone for deterministic test results.
Synchronizes timezone between Python and Java.

Sets the OS-level TZ environment variable and, when a Spark session
is available, synchronises the timezone with the JVM and Spark config.
This allows the mixin to be used with both ReusedSQLTestCase (Spark)
and plain unittest.TestCase (no Spark).
"""
cls._tz_prev = os.environ.get("TZ", None)
os.environ["TZ"] = tz
time.tzset()

cls.sc.environment["TZ"] = tz
cls.spark.conf.set("spark.sql.session.timeZone", tz)
# Sync with Spark / Java if a session is available.
if hasattr(cls, "sc"):
cls.sc.environment["TZ"] = tz
if hasattr(cls, "spark"):
cls.spark.conf.set("spark.sql.session.timeZone", tz)

@classmethod
def teardown_timezone(cls) -> None:
Expand Down Expand Up @@ -128,6 +145,8 @@ def load_golden_csv(golden_csv: str, use_index: bool = True) -> "pd.DataFrame":
pd.DataFrame
The loaded golden data with string dtype.
"""
import pandas as pd

return pd.read_csv(
golden_csv,
sep="\t",
Expand Down Expand Up @@ -167,31 +186,39 @@ def save_golden(df: "pd.DataFrame", golden_csv: str, golden_md: Optional[str] =
@staticmethod
def repr_type(t: Any) -> str:
"""
Convert a type to string representation.
Convert a type to a readable string representation.

Handles different type representations:
- Spark DataType: uses simpleString() (e.g., "int", "string", "array<int>")
- Python type: uses __name__ (e.g., "int", "str", "list")
- Other: uses str()

- Spark DataType: uses simpleString()
(e.g. "int", "string", "array<int>")
- PyArrow DataType: uses str(t) with float-name normalisation
(e.g. "int8", "float32", "timestamp[s, tz=UTC]")
- Python type: uses __name__
(e.g. "int", "str", "list")
- Other: falls back to str(t)

Parameters
----------
t : Any
The type to represent. Can be Spark DataType or Python type.
The type to represent.

Returns
-------
str
String representation of the type.
Human-readable string representation of the type.
"""
# Check if it's a Spark DataType (has simpleString method)
# Spark DataType
if hasattr(t, "simpleString"):
return t.simpleString()
# Check if it's a Python type
# Python type (class)
elif isinstance(t, type):
return t.__name__
else:
return str(t)
s = str(t)
# Normalise PyArrow float type names to be more intuitive:
# "halffloat" -> "float16", "float" -> "float32", "double" -> "float64"
return _ARROW_FLOAT_ALIASES.get(s, s)

@classmethod
def repr_value(cls, value: Any, max_len: int = 32) -> str:
Expand Down Expand Up @@ -235,6 +262,8 @@ def repr_pandas_value(cls, value: Any, max_len: int = 32) -> str:
str
String representation in format "value@type[dtype]".
"""
import pandas as pd

if isinstance(value, pd.DataFrame):
v_str = value.to_json()
else:
Expand All @@ -252,3 +281,75 @@ def repr_pandas_value(cls, value: Any, max_len: int = 32) -> str:
def clean_result(result: str) -> str:
"""Clean result string by removing newlines and extra whitespace."""
return result.replace("\n", " ").replace("\r", " ").replace("\t", " ")

def compare_or_generate_golden_matrix(
self,
row_names: List[str],
col_names: List[str],
compute_cell: Callable[[str, str], str],
golden_file_prefix: str,
index_name: str = "source \\ target",
) -> None:
"""
Run a matrix of computations and compare against (or generate) a golden file.

This is the standard pattern for golden-file matrix tests:

1. If SPARK_GENERATE_GOLDEN_FILES=1, compute every cell, build a
DataFrame, and save it as the new golden CSV / Markdown file.
2. Otherwise, load the existing golden file and assert that every cell
matches the freshly computed value.

Parameters
----------
row_names : list[str]
Ordered row labels (becomes the DataFrame index).
col_names : list[str]
Ordered column labels.
compute_cell : (row_name, col_name) -> str
Function that computes the string result for one cell.
golden_file_prefix : str
Prefix for the golden CSV/MD files (without extension).
Files are placed in the same directory as the concrete test file.
index_name : str, default "source \\ target"
Name for the index column in the golden file.
"""
generating = self.is_generating_golden()

test_dir = os.path.dirname(inspect.getfile(type(self)))
golden_csv = os.path.join(test_dir, f"{golden_file_prefix}.csv")
golden_md = os.path.join(test_dir, f"{golden_file_prefix}.md")

golden = None
if not generating:
golden = self.load_golden_csv(golden_csv)

errors = []
results = {}

for row_name in row_names:
for col_name in col_names:
result = compute_cell(row_name, col_name)
results[(row_name, col_name)] = result

if not generating:
expected = golden.loc[row_name, col_name]
if expected != result:
errors.append(
f"{row_name} -> {col_name}: " f"expected '{expected}', got '{result}'"
)

if generating:
import pandas as pd

index = pd.Index(row_names, name=index_name)
df = pd.DataFrame(index=index)
for col_name in col_names:
df[col_name] = [results[(row, col_name)] for row in row_names]
self.save_golden(df, golden_csv, golden_md)
else:
self.assertEqual(
len(errors),
0,
f"\n{len(errors)} golden file mismatches:\n" + "\n".join(errors),
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
source \ target list<item: int32> list<item: int64> list<item: string> large_list<item: int32> large_list<item: int64> fixed_size_list<item: int32>[2] fixed_size_list<item: int32>[3] map<string, int32> map<string, int64> struct<x: int32, y: string> struct<x: int64, y: string> struct<y: string, x: int32> string int32
list<item: int32>:standard [[1, 2, 3], None]@list<item: int32> [[1, 2, 3], None]@list<item: int64> [['1', '2', '3'], None]@list<item: string> [[1, 2, 3], None]@large_list<item: int32> [[1, 2, 3], None]@large_list<item: int64> ERR@ArrowInvalid [[1, 2, 3], None]@fixed_size_list<item: int32>[3] ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
list<item: int32>:empty [[], None]@list<item: int32> [[], None]@list<item: int64> [[], None]@list<item: string> [[], None]@large_list<item: int32> [[], None]@large_list<item: int64> ERR@ArrowInvalid ERR@ArrowInvalid ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
list<item: int32>:null_elem [[None], [1, None, 3], None]@list<item: int32> [[None], [1, None, 3], None]@list<item: int64> [[None], ['1', None, '3'], None]@list<item: string> [[None], [1, None, 3], None]@large_list<item: int32> [[None], [1, None, 3], None]@large_list<item: int64> ERR@ArrowInvalid ERR@ArrowInvalid ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
list<item: struct<x: int32, y: string>>:standard ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowInvalid ERR@ArrowInvalid ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
list<item: struct<x: int32, y: string>>:null_fields ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowInvalid ERR@ArrowInvalid ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
large_list<item: int32>:standard [[1, 2, 3], None]@list<item: int32> [[1, 2, 3], None]@list<item: int64> [['1', '2', '3'], None]@list<item: string> [[1, 2, 3], None]@large_list<item: int32> [[1, 2, 3], None]@large_list<item: int64> ERR@ArrowInvalid [[1, 2, 3], None]@fixed_size_list<item: int32>[3] ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
large_list<item: int32>:empty [[], None]@list<item: int32> [[], None]@list<item: int64> [[], None]@list<item: string> [[], None]@large_list<item: int32> [[], None]@large_list<item: int64> ERR@ArrowInvalid ERR@ArrowInvalid ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
large_list<item: int32>:null_elem [[None], None]@list<item: int32> [[None], None]@list<item: int64> [[None], None]@list<item: string> [[None], None]@large_list<item: int32> [[None], None]@large_list<item: int64> ERR@ArrowInvalid ERR@ArrowInvalid ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
fixed_size_list<item: int32>[2]:standard [[1, 2], [3, 4], None]@list<item: int32> [[1, 2], [3, 4], None]@list<item: int64> [['1', '2'], ['3', '4'], None]@list<item: string> [[1, 2], [3, 4], None]@large_list<item: int32> [[1, 2], [3, 4], None]@large_list<item: int64> [[1, 2], [3, 4], None]@fixed_size_list<item: int32>[2] ERR@ArrowTypeError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
fixed_size_list<item: int32>[2]:null_elem [[None, None], None]@list<item: int32> [[None, None], None]@list<item: int64> [[None, None], None]@list<item: string> [[None, None], None]@large_list<item: int32> [[None, None], None]@large_list<item: int64> [[None, None], None]@fixed_size_list<item: int32>[2] ERR@ArrowTypeError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
map<string, int32>:standard ERR@ArrowTypeError ERR@ArrowTypeError ERR@ArrowTypeError ERR@ArrowTypeError ERR@ArrowTypeError ERR@ArrowNotImplementedError ERR@ArrowInvalid [[('a', 1), ('b', 2)], None]@map<string, int32> [[('a', 1), ('b', 2)], None]@map<string, int64> ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
map<string, int32>:empty ERR@ArrowTypeError ERR@ArrowTypeError ERR@ArrowTypeError ERR@ArrowTypeError ERR@ArrowTypeError ERR@ArrowInvalid ERR@ArrowInvalid [[], None]@map<string, int32> [[], None]@map<string, int64> ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
struct<x: int32, y: string>:standard ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError [{'x': 1, 'y': 'a'}, None]@struct<x: int32, y: string> [{'x': 1, 'y': 'a'}, None]@struct<x: int64, y: string> [{'y': 'a', 'x': 1}, None]@struct<y: string, x: int32> ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
struct<x: int32, y: string>:null_fields ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError [{'x': None, 'y': None}, None]@struct<x: int32, y: string> [{'x': None, 'y': None}, None]@struct<x: int64, y: string> [{'y': None, 'x': None}, None]@struct<y: string, x: int32> ERR@ArrowNotImplementedError ERR@ArrowNotImplementedError
Loading