from __future__ import annotations
# DataFrame operation executor for Polars.
# Provides execution of DataFrame operations (filter, select, join, etc.)
# using the Polars DataFrame API.
import json
import logging
from typing import Any, Dict, List, Optional, Set, TYPE_CHECKING, Tuple, Union, cast
import polars as pl
from .window_handler import PolarsWindowHandler
from sparkless import config
from sparkless.functions import Column, ColumnOperation
from sparkless.functions.window_execution import WindowFunction
from sparkless.spark_types import StructType, get_row_value
from sparkless.core.ddl_adapter import parse_ddl_schema
from sparkless.utils.profiling import profiled
if TYPE_CHECKING:
from .expression_translator import PolarsExpressionTranslator
from sparkless.dataframe.evaluation.expression_evaluator import (
ExpressionEvaluator,
)
logger = logging.getLogger(__name__)
[docs]
class PolarsOperationExecutor:
"""Executes DataFrame operations using Polars DataFrame API."""
[docs]
def __init__(self, expression_translator: PolarsExpressionTranslator):
"""Initialize operation executor.
Args:
expression_translator: Polars expression translator instance
"""
self.translator = expression_translator
self.window_handler = PolarsWindowHandler()
self._shortcuts_enabled = config.is_feature_enabled(
"enable_polars_vectorized_shortcuts"
)
self._struct_field_cache: Dict[Tuple[str, str], List[str]] = {}
def _extract_window_function_with_arithmetic(
self, col: Any
) -> Tuple[Optional[WindowFunction], List[Tuple[str, Any, bool]]]:
"""Recursively extract WindowFunction and all arithmetic operations applied to it.
Args:
col: Column, WindowFunction, or ColumnOperation
Returns:
Tuple of (WindowFunction or None, list of (operation, value, is_reverse) tuples)
Operations are in order from innermost to outermost.
"""
from sparkless.functions.core.literals import Literal
if isinstance(col, WindowFunction):
return (col, [])
elif isinstance(col, ColumnOperation):
# Skip alias-only wrappers. .alias() sets _alias_name on the same op;
# it does not create operation="alias". Callers use getattr(col, "_alias_name", None).
if col.operation == "alias":
return self._extract_window_function_with_arithmetic(col.column)
# Check if this is a reverse operation (Literal - WindowFunction or Literal / WindowFunction)
if isinstance(col.column, Literal) and isinstance(
col.value, WindowFunction
):
# Reverse operation: Literal op WindowFunction
window_func, inner_ops = self._extract_window_function_with_arithmetic(
col.value
)
if window_func:
# Add reverse operation at the end
inner_ops.append(
(cast("str", col.operation), col.column.value, True)
)
return (window_func, inner_ops)
# Check if column is WindowFunction or contains one
elif isinstance(col.column, WindowFunction):
# Direct: WindowFunction op value
window_func, inner_ops = self._extract_window_function_with_arithmetic(
col.column
)
if window_func:
inner_ops.append((cast("str", col.operation), col.value, False))
return (window_func, inner_ops)
# Recursively check nested ColumnOperation
elif isinstance(col.column, ColumnOperation):
window_func, inner_ops = self._extract_window_function_with_arithmetic(
col.column
)
if window_func:
# Add this operation to the list (operations are in order from innermost to outermost)
inner_ops.append((cast("str", col.operation), col.value, False))
return (window_func, inner_ops)
# Also check if value is a ColumnOperation (for cases like (A * B) where both are operations)
elif isinstance(col.value, ColumnOperation):
# This handles cases where both operands are ColumnOperations
# For now, we only handle WindowFunction on the left side
pass
# Handle Column op WindowFunction (e.g., F.col("value") - F.lag("value", 1).over(window))
elif isinstance(col.value, WindowFunction):
# Left-side operation: Column op WindowFunction
window_func, inner_ops = self._extract_window_function_with_arithmetic(
col.value
)
if window_func:
# Record operation with is_reverse=False (it's left op window, not reverse)
# For subtraction, we'll need special handling in arithmetic application
inner_ops.append((cast("str", col.operation), col.column, False))
return (window_func, inner_ops)
return (None, [])
def _arith_operand_to_polars(self, val: Any, df: pl.DataFrame) -> Any:
"""Convert an arithmetic operand to a form Polars accepts (scalar or pl.Expr).
Scalars are returned as-is (Polars accepts them in binary ops). Column/ColumnOperation
are translated to pl.Expr for e.g. window_expr - pl.col("x").
"""
if isinstance(val, (int, float, bool, str, type(None))):
return val
if isinstance(val, (Column, ColumnOperation)):
case_sensitive = self._get_case_sensitive()
return self.translator.translate(
val,
available_columns=list(df.columns),
case_sensitive=case_sensitive,
)
return val
def _get_case_sensitive(self) -> bool:
"""Get case sensitivity setting from active session.
Returns:
True if case-sensitive mode is enabled, False otherwise.
Defaults to False (case-insensitive) to match PySpark behavior.
"""
try:
from sparkless.session.core.session import SparkSession
active_sessions = getattr(SparkSession, "_active_sessions", [])
if active_sessions:
session = active_sessions[-1]
if hasattr(session, "conf"):
return bool(session.conf.is_case_sensitive())
except (AttributeError, TypeError):
pass
return False # Default to case-insensitive (matching PySpark)
def _find_column(
self, df: pl.DataFrame, column_name: str, case_sensitive: bool = False
) -> Optional[str]:
"""Find column name in Polars DataFrame using ColumnResolver.
Args:
df: Polars DataFrame to search in.
column_name: Column name to find.
case_sensitive: Whether to use case-sensitive matching.
Returns:
Actual column name if found, None otherwise.
"""
from sparkless.core.column_resolver import ColumnResolver
available_columns = list(df.columns)
result = ColumnResolver.resolve_column_name(
column_name, available_columns, case_sensitive
)
return str(result) if result else None
def _resolve_window_col_name(
self, df: pl.DataFrame, col_name: str, case_sensitive: bool
) -> str:
"""Resolve window partition/order column name against DataFrame schema."""
resolved = self._find_column(df, col_name, case_sensitive)
return resolved if resolved is not None else col_name
[docs]
@profiled("polars.apply_filter", category="polars")
def apply_filter(self, df: pl.DataFrame, condition: Any) -> pl.DataFrame:
"""Apply a filter operation.
Args:
df: Source Polars DataFrame
condition: Filter condition (ColumnOperation or expression)
Returns:
Filtered Polars DataFrame
"""
# Check if condition is a WindowFunction comparison or isnull/isnotnull
is_window_function_comparison = (
isinstance(condition, ColumnOperation)
and condition.operation in [">", "<", ">=", "<=", "==", "!=", "eqNullSafe"]
and isinstance(condition.column, WindowFunction)
)
is_window_function_isnull = (
isinstance(condition, ColumnOperation)
and condition.operation in ["isnull", "isnotnull"]
and isinstance(condition.column, WindowFunction)
)
if is_window_function_comparison or is_window_function_isnull:
# Handle comparison operations or isnull/isnotnull operations with WindowFunction
# First, apply the window function to get a temporary column
window_func = condition.column
operation = condition.operation
operation_value = condition.value
# Apply window function to get a temporary column
temp_col_name = "__window_func_temp_filter"
df_with_window = self.apply_with_column(
df, temp_col_name, window_func, expected_field=None
)
# Now create the operation expression: temp_col_name op value (or isnull/isnotnull)
# Note: Column is already imported at the top of the file
temp_col = Column(temp_col_name)
if operation == ">":
operation_expr = temp_col > operation_value
elif operation == "<":
operation_expr = temp_col < operation_value
elif operation == ">=":
operation_expr = temp_col >= operation_value
elif operation == "<=":
operation_expr = temp_col <= operation_value
elif operation == "==":
operation_expr = temp_col == operation_value
elif operation == "!=":
operation_expr = temp_col != operation_value
elif operation == "eqNullSafe":
operation_expr = temp_col.eqNullSafe(operation_value)
elif operation == "isnull":
operation_expr = temp_col.isnull()
elif operation == "isnotnull":
operation_expr = temp_col.isnotnull()
else:
operation_expr = ColumnOperation(temp_col, operation, operation_value)
# Translate and apply the filter
case_sensitive = self._get_case_sensitive()
filter_expr = self.translator.translate(
operation_expr,
available_columns=list(df_with_window.columns),
case_sensitive=case_sensitive,
)
# Apply filter and drop temporary column
result_df = df_with_window.filter(filter_expr).drop(temp_col_name)
return result_df
# Extract column dtype if condition is a ColumnOperation with isin
# This enables type coercion for mixed types
input_col_dtype = None
if (
isinstance(condition, ColumnOperation)
and condition.operation == "isin"
and hasattr(condition, "column")
and hasattr(condition.column, "name")
):
# Get the column name from the condition
col_name = condition.column.name
# Find column using ColumnResolver
case_sensitive = self._get_case_sensitive()
actual_col_name = self._find_column(df, col_name, case_sensitive)
if actual_col_name and actual_col_name in df.columns:
# Get the column's dtype
input_col_dtype = df[actual_col_name].dtype
case_sensitive = self._get_case_sensitive()
try:
# Check if condition contains struct field paths that need special handling
# (e.g., F.col("StructVal")["E1"] creates Column("StructVal.E1"))
needs_struct_handling = False
struct_field_col = None
if isinstance(condition, ColumnOperation):
# Check if the column part is a Column with struct field path
if (
isinstance(condition.column, Column)
and "." in condition.column.name
):
needs_struct_handling = True
struct_field_col = condition.column
elif isinstance(condition, Column) and "." in condition.name:
needs_struct_handling = True
struct_field_col = condition
if needs_struct_handling and struct_field_col:
# Handle struct field access in filter by creating temporary column
# Extract the struct field path and apply it as a withColumn first
temp_col_name = f"__struct_field_temp_{id(condition)}"
df_with_temp = self.apply_with_column(
df, temp_col_name, struct_field_col, expected_field=None
)
# Now create the filter condition using the temp column
if isinstance(condition, ColumnOperation):
# Recreate the operation with the temp column
# Note: Column is already imported at the top of the file
from sparkless.functions.core.column import Column as Col
temp_col = Col(temp_col_name)
if condition.operation == ">":
filter_condition = temp_col > condition.value
elif condition.operation == "<":
filter_condition = temp_col < condition.value
elif condition.operation == ">=":
filter_condition = temp_col >= condition.value
elif condition.operation == "<=":
filter_condition = temp_col <= condition.value
elif condition.operation == "==":
filter_condition = temp_col == condition.value
elif condition.operation == "!=":
filter_condition = temp_col != condition.value
else:
# For other operations, use the temp column directly
filter_condition = temp_col
else:
# Simple Column - just use the temp column
# Note: Column is already imported at the top of the file
from sparkless.functions.core.column import Column as Col
filter_condition = Col(temp_col_name)
filter_expr = self.translator.translate(
filter_condition,
available_columns=list(df_with_temp.columns),
case_sensitive=case_sensitive,
)
result_df = df_with_temp.filter(filter_expr).drop(temp_col_name)
return result_df
filter_expr = self.translator.translate(
condition,
input_col_dtype=input_col_dtype,
available_columns=list(df.columns),
case_sensitive=case_sensitive,
)
except ValueError as e:
# Check if this is a WindowFunction comparison that should be handled
error_msg = str(e)
if "WindowFunction comparison" in error_msg and isinstance(
condition, ColumnOperation
):
# Recursively handle it (should have been caught above, but handle as fallback)
return self.apply_filter(df, condition)
raise
return df.filter(filter_expr)
[docs]
@profiled("polars.apply_select", category="polars")
def apply_select(self, df: pl.DataFrame, columns: Tuple[Any, ...]) -> pl.DataFrame:
"""Apply a select operation.
Args:
df: Source Polars DataFrame
columns: Columns to select
Returns:
Selected Polars DataFrame
"""
select_exprs = []
select_names = []
map_op_indices = set() # Track which columns are map operations
posexplode_pending: List[Tuple[str, str, str]] = [] # (temp_name, name0, name1)
python_columns: List[Tuple[str, List[Any]]] = []
rows_cache: Optional[List[Dict[str, Any]]] = None
evaluator: Union[ExpressionEvaluator, None] = None
# First pass: handle map_keys, map_values, map_entries using struct operations
for i, col in enumerate(columns):
# Check if this is a map_keys, map_values, or map_entries operation
is_map_op = False
map_op_name = None
map_col_name = None
if hasattr(col, "operation"):
if col.operation in [
"map_keys",
"map_values",
"map_entries",
"map_concat",
]:
is_map_op = True
map_op_name = col.operation
if hasattr(col, "column") and hasattr(col.column, "name"):
map_col_name = col.column.name
elif hasattr(col, "function_name") and col.function_name in [
"map_keys",
"map_values",
"map_entries",
"map_concat",
]:
is_map_op = True
map_op_name = col.function_name
if hasattr(col, "column") and hasattr(col.column, "name"):
map_col_name = col.column.name
if is_map_op and map_col_name and map_col_name in df.columns:
# Get the struct dtype for this column
struct_dtype = df[map_col_name].dtype
if hasattr(struct_dtype, "fields") and struct_dtype.fields:
# Build expression using struct.field() checks
field_names = self._get_struct_field_names(
map_col_name, struct_dtype
)
alias_name = (
getattr(col, "name", None) or f"{map_op_name}({map_col_name})"
)
if map_op_name == "map_keys":
# Get only non-null field names
keys_expr = pl.concat_list(
[
pl.when(
pl.col(map_col_name)
.struct.field(fname)
.is_not_null()
)
.then(pl.lit(fname))
.otherwise(None)
for fname in field_names
]
).list.drop_nulls()
select_exprs.append(keys_expr.alias(alias_name))
select_names.append(alias_name)
map_op_indices.add(i)
elif map_op_name == "map_values":
# Get only non-null field values
values_expr = pl.concat_list(
[
pl.when(
pl.col(map_col_name)
.struct.field(fname)
.is_not_null()
)
.then(pl.col(map_col_name).struct.field(fname))
.otherwise(None)
for fname in field_names
]
).list.drop_nulls()
select_exprs.append(values_expr.alias(alias_name))
select_names.append(alias_name)
map_op_indices.add(i)
elif map_op_name == "map_entries":
# Create array of structs with key and value
entries_list = pl.concat_list(
[
pl.struct(
[
pl.lit(fname).cast(pl.Utf8).alias("key"),
pl.col(map_col_name)
.struct.field(fname)
.alias("value"),
]
)
for fname in field_names
]
).list.filter(pl.element().struct.field("value").is_not_null())
select_exprs.append(entries_list.alias(alias_name))
select_names.append(alias_name)
map_op_indices.add(i)
elif map_op_name == "map_concat":
# map_concat(*cols) - merge multiple maps
# col.value contains additional columns (first column is in col.column)
if (
hasattr(col, "value")
and col.value
and isinstance(col.value, (list, tuple))
):
# Get all map column names
map_cols = [map_col_name] # Start with first column
for other_col in col.value:
if isinstance(other_col, str):
map_cols.append(other_col)
elif hasattr(other_col, "name"):
map_cols.append(other_col.name)
elif hasattr(other_col, "column") and hasattr(
other_col.column, "name"
):
# Handle nested column references
map_cols.append(other_col.column.name)
else:
# Try to get name from string representation or other attributes
col_str = str(other_col)
if col_str in df.columns:
map_cols.append(col_str)
# Verify all map columns exist in DataFrame
available_map_cols = [
mc for mc in map_cols if mc in df.columns
]
if len(available_map_cols) < len(map_cols):
# Some columns missing - this shouldn't happen but handle gracefully
map_cols = available_map_cols
# Merge all struct columns - combine all fields from all maps
# Get all field names from all struct columns
all_field_names: Set[str] = set()
for map_col in map_cols:
if map_col in df.columns:
struct_dtype = df[map_col].dtype
if hasattr(struct_dtype, "fields"):
field_names = self._get_struct_field_names(
map_col, struct_dtype
)
all_field_names.update(field_names)
sorted_field_names = sorted(all_field_names)
# Build merged struct: for each field, take value from later maps first (they override)
# Later maps override earlier ones (PySpark behavior)
# Then filter out null fields per row (PySpark only includes non-null keys)
struct_field_exprs = []
for fname in sorted_field_names:
# Check each map column in reverse order (later maps override earlier)
value_exprs = []
for map_col in reversed(map_cols):
if map_col in df.columns:
struct_dtype = df[map_col].dtype
if hasattr(struct_dtype, "fields") and any(
f.name == fname for f in struct_dtype.fields
):
value_exprs.append(
pl.col(map_col).struct.field(fname)
)
if value_exprs:
# Use coalesce to take first non-null value (later maps first)
if len(value_exprs) == 1:
struct_field_exprs.append(
value_exprs[0].alias(fname)
)
else:
struct_field_exprs.append(
pl.coalesce(value_exprs).alias(fname)
)
# Create merged struct with all fields
merged_struct = pl.struct(struct_field_exprs)
# Filter out null fields per row using map_elements
# PySpark only includes keys that have non-null values
filtered_merged = merged_struct.map_elements(
lambda x: (
{k: v for k, v in x.items() if v is not None}
if isinstance(x, dict)
else (
{
k: getattr(x, k)
for k in dir(x)
if not k.startswith("_")
and getattr(x, k, None) is not None
}
if hasattr(x, "__dict__")
else None
)
if x is not None
else None
),
return_dtype=pl.Object,
)
select_exprs.append(filtered_merged.alias(alias_name))
select_names.append(alias_name)
map_op_indices.add(i)
# Second pass: handle all other columns (skip map operations already handled)
for i, col in enumerate(columns):
if i in map_op_indices:
continue # Skip map operations already handled
# Check posexplode first (PySpark: posexplode().alias("A","B") yields two columns)
_has_op = hasattr(col, "operation")
_has_col = hasattr(col, "column")
if (
_has_op
and _has_col
and (
col.operation in ("posexplode", "posexplode_outer")
or (
col.operation == "alias"
and (
getattr(col.column, "operation", None)
in ("posexplode", "posexplode_outer")
or (
getattr(col, "_alias_names", None)
and len(getattr(col, "_alias_names", ())) >= 2
and getattr(col.column, "name", None)
)
)
)
)
):
# posexplode produces two columns (pos, val); alias("Name1", "Name2") names them
if col.operation == "alias":
posexplode_col = (
col.column
if getattr(col.column, "operation", None)
in ("posexplode", "posexplode_outer")
else col
)
alias_names_tuple = getattr(col, "_alias_names", None)
else:
posexplode_col = col
alias_names_tuple = None
from sparkless.functions import Column as ColumnType
source_col = posexplode_col.column
if isinstance(source_col, ColumnType):
source_col_name = source_col.name
elif isinstance(source_col, str):
source_col_name = source_col
else:
source_col_name = str(getattr(source_col, "name", ""))
if not source_col_name:
raise ValueError(
f"Cannot determine source column for posexplode: {posexplode_col}"
)
from ...core.column_resolver import ColumnResolver
case_sensitive = self._get_case_sensitive()
resolved_col_name = ColumnResolver.resolve_column_name(
source_col_name, list(df.columns), case_sensitive
)
if resolved_col_name is None:
raise ValueError(
f"Column '{source_col_name}' not found for posexplode"
)
_alias_names = alias_names_tuple or getattr(
posexplode_col, "_alias_names", None
)
if _alias_names and len(_alias_names) >= 2:
alias_names = (_alias_names[0], _alias_names[1])
else:
_first = (
getattr(col, "_alias_name", None)
or getattr(posexplode_col, "_alias_name", None)
or "pos"
)
alias_names = (_first, "col")
temp_name = "__posexplode_struct"
def _posexplode_list(arr: Any) -> Any:
if arr is None:
return []
# Polars 1.x passes Series or other iterables to map_elements for List cols
if hasattr(arr, "to_list"):
arr = arr.to_list()
elif hasattr(arr, "tolist"):
arr = arr.tolist()
elif not isinstance(arr, (list, tuple)):
return []
return [{"pos": i, "val": v} for i, v in enumerate(arr)]
list_dtype = df[resolved_col_name].dtype
val_dtype = (
list_dtype.inner
if hasattr(list_dtype, "inner") and list_dtype.inner is not None
else pl.Int64
)
posexplode_dtype = pl.List(
pl.Struct([pl.Field("pos", pl.Int64), pl.Field("val", val_dtype)])
)
posexplode_expr = (
pl.col(resolved_col_name)
.map_elements(
_posexplode_list,
return_dtype=posexplode_dtype,
)
.alias(temp_name)
)
select_exprs.append(posexplode_expr)
select_names.append(temp_name)
posexplode_pending.append((temp_name, alias_names[0], alias_names[1]))
continue
if isinstance(col, str):
if col == "*":
# Select all columns - return original DataFrame
return df
elif "." in col:
# Try table-prefixed first (e.g. "t1.id" -> "id" or "t1_id")
from ...core.column_resolver import ColumnResolver
case_sensitive = self._get_case_sensitive()
resolved_table_prefixed = ColumnResolver.resolve_column_name(
col, list(df.columns), case_sensitive
)
if resolved_table_prefixed is not None:
# Table-prefixed column - use resolved name
select_exprs.append(pl.col(resolved_table_prefixed).alias(col))
select_names.append(col)
continue
# Handle nested struct field access, or right-alias.column after join (#380)
parts = col.split(".", 1)
struct_col = parts[0]
field_name = parts[1]
right_prefixed = f"_right_{field_name}"
if right_prefixed in df.columns:
select_exprs.append(pl.col(right_prefixed).alias(col))
select_names.append(col)
continue
if struct_col in df.columns:
# Get struct dtype
struct_dtype = df[struct_col].dtype
if hasattr(struct_dtype, "fields") and struct_dtype.fields:
# Resolve field name case-insensitively within struct
field_names = [f.name for f in struct_dtype.fields]
resolved_field = ColumnResolver.resolve_column_name(
field_name, field_names, case_sensitive
)
if resolved_field:
# Use Polars struct.field() syntax for nested access
select_exprs.append(
pl.col(struct_col)
.struct.field(resolved_field)
.alias(col)
)
select_names.append(col)
continue
# If nested access failed, fall through to error handling
raise ValueError(
f"Cannot access nested field '{col}' - struct column '{struct_col}' or field '{field_name}' not found"
)
else:
# Resolve column name to find actual column (case-insensitive matching)
# PySpark behavior:
# - If there's only one match: use the original column name
# - If there are multiple matches (different cases): use the requested column name
from ...core.column_resolver import ColumnResolver
case_sensitive = self._get_case_sensitive()
resolved_col_name = ColumnResolver.resolve_column_name(
col, list(df.columns), case_sensitive
)
if resolved_col_name is None:
raise ValueError(f"Column '{col}' not found in DataFrame")
# Check if there are multiple matches (different cases)
column_name_lower = col.lower()
matches = [c for c in df.columns if c.lower() == column_name_lower]
has_multiple_matches = len(matches) > 1
# Use resolved column name for lookup
# If multiple matches exist, alias with requested name (PySpark behavior for issue #297)
# If single match, use original column name (PySpark default behavior)
if has_multiple_matches:
# Multiple matches: use requested name as output (issue #297)
select_exprs.append(pl.col(resolved_col_name).alias(col))
select_names.append(col)
elif resolved_col_name == col:
# Exact match: no alias needed
select_exprs.append(pl.col(col))
select_names.append(col)
else:
# Case-insensitive match but single match: use original column name
select_exprs.append(pl.col(resolved_col_name))
select_names.append(resolved_col_name)
elif isinstance(col, ColumnOperation) and col.operation == "json_tuple":
# json_tuple(col, *fields) expands to multiple output columns (c0, c1, ...)
fields = list(col.value) if isinstance(col.value, (list, tuple)) else []
case_sensitive = self._get_case_sensitive()
json_expr = self.translator.translate(
col.column,
available_columns=list(df.columns),
case_sensitive=case_sensitive,
)
import json as _json
def _extract_field(val: Any, field: str) -> Any:
if val is None:
return None
if not isinstance(val, str):
val = str(val)
try:
obj = _json.loads(val)
except (ValueError, TypeError):
return None
if not isinstance(obj, dict):
return None
v = obj.get(field)
return None if v is None else str(v)
for j, f in enumerate(fields):
name = f"c{j}"
select_exprs.append(
json_expr.map_elements(
lambda x, field=f: _extract_field(x, field), # noqa: B023
return_dtype=pl.Utf8,
).alias(name)
)
select_names.append(name)
elif (
isinstance(col, ColumnOperation)
or (hasattr(col, "operation") and hasattr(col, "column"))
) and (
col.operation in ("posexplode", "posexplode_outer")
or (
col.operation == "alias"
and (
getattr(col.column, "operation", None)
in ("posexplode", "posexplode_outer")
or (
getattr(col, "_alias_names", None)
and len(getattr(col, "_alias_names", ())) >= 2
and getattr(col.column, "name", None)
)
)
)
):
# posexplode produces two columns (pos, val); alias("Name1", "Name2") names them
# Unwrap alias(posexplode(...)) so we use the inner posexplode and alias names
if col.operation == "alias":
posexplode_col = (
col.column
if getattr(col.column, "operation", None)
in ("posexplode", "posexplode_outer")
else col
)
alias_names_tuple = getattr(col, "_alias_names", None)
else:
posexplode_col = col
alias_names_tuple = None
from sparkless.functions import Column as ColumnType
source_col = posexplode_col.column
if isinstance(source_col, ColumnType):
source_col_name = source_col.name
elif isinstance(source_col, str):
source_col_name = source_col
else:
source_col_name = str(getattr(source_col, "name", ""))
if not source_col_name:
raise ValueError(
f"Cannot determine source column for posexplode: {posexplode_col}"
)
from ...core.column_resolver import ColumnResolver
case_sensitive = self._get_case_sensitive()
resolved_col_name = ColumnResolver.resolve_column_name(
source_col_name, list(df.columns), case_sensitive
)
if resolved_col_name is None:
raise ValueError(
f"Column '{source_col_name}' not found for posexplode"
)
# PySpark alias(*names) supports multiple names; posexplode output columns (pos, val)
_alias_names = alias_names_tuple or getattr(
posexplode_col, "_alias_names", None
)
if _alias_names and len(_alias_names) >= 2:
alias_names = (_alias_names[0], _alias_names[1])
else:
_first = (
getattr(col, "_alias_name", None)
or getattr(posexplode_col, "_alias_name", None)
or "pos"
)
alias_names = (_first, "col")
temp_name = "__posexplode_struct"
def _posexplode_list(arr: Any) -> Any:
if arr is None:
return []
# Polars 1.x passes Series or other iterables to map_elements for List cols
if hasattr(arr, "to_list"):
arr = arr.to_list()
elif hasattr(arr, "tolist"):
arr = arr.tolist()
elif not isinstance(arr, (list, tuple)):
return []
return [{"pos": i, "val": v} for i, v in enumerate(arr)]
# Polars requires return_dtype for map_elements; use List(Struct) for pos+val
list_dtype = df[resolved_col_name].dtype
val_dtype = (
list_dtype.inner
if hasattr(list_dtype, "inner") and list_dtype.inner is not None
else pl.Int64
)
posexplode_dtype = pl.List(
pl.Struct([pl.Field("pos", pl.Int64), pl.Field("val", val_dtype)])
)
posexplode_expr = (
pl.col(resolved_col_name)
.map_elements(
_posexplode_list,
return_dtype=posexplode_dtype,
)
.alias(temp_name)
)
select_exprs.append(posexplode_expr)
select_names.append(temp_name)
posexplode_pending.append((temp_name, alias_names[0], alias_names[1]))
continue
elif isinstance(col, WindowFunction) or (
isinstance(col, ColumnOperation)
and (
col.operation == "cast"
or col.operation == "alias"
or col.operation in ["*", "+", "-", "/", "**"]
)
):
# Handle window functions or ColumnOperation wrapping WindowFunction
# Use helper to recursively extract WindowFunction and all arithmetic operations
window_func, arithmetic_ops = (
self._extract_window_function_with_arithmetic(col)
)
# Also check for cast operation
cast_type = None
if isinstance(col, ColumnOperation) and col.operation == "cast":
if isinstance(col.column, WindowFunction):
cast_type = col.value
elif isinstance(col.column, ColumnOperation):
# Cast might be applied to arithmetic result
# Extract window function from nested operation
nested_window_func, _ = (
self._extract_window_function_with_arithmetic(col.column)
)
if nested_window_func:
window_func = nested_window_func
cast_type = col.value
# Re-extract arithmetic ops from the nested operation
_, nested_ops = (
self._extract_window_function_with_arithmetic(
col.column
)
)
arithmetic_ops = nested_ops
if window_func is None:
# Not a window function (e.g. 2 * F.col("x")) - use regular handling
alias_name = getattr(col, "name", None) or getattr(
col, "_alias_name", None
)
# Check for struct field access - need to check original column name if aliased
# When F.col("StructValue.E1").alias("E1-Extract") is used, col.name is the alias,
# but we need to check col._original_column._name or col.column.name for the struct path
struct_field_path = None
if hasattr(col, "name") and "." in col.name:
struct_field_path = col.name
elif hasattr(col, "_original_column") and hasattr(
col._original_column, "_name"
):
# Check original column name for struct field path
original_col = col._original_column
if (
original_col is not None
and hasattr(original_col, "_name")
and "." in original_col._name
):
struct_field_path = original_col._name
elif hasattr(col, "column") and hasattr(col.column, "name"):
# For ColumnOperation, check the column attribute
col_attr = col.column
if (
col_attr is not None
and hasattr(col_attr, "name")
and "." in col_attr.name
):
struct_field_path = col_attr.name
elif isinstance(col, ColumnOperation) and hasattr(col, "column"):
# For ColumnOperation, check if column is a Column with struct field
col_attr = col.column
if col_attr is not None:
if hasattr(col_attr, "_name") and "." in col_attr._name:
struct_field_path = col_attr._name
elif hasattr(col_attr, "name") and "." in col_attr.name:
struct_field_path = col_attr.name
if struct_field_path:
parts = struct_field_path.split(".", 1)
struct_col, field_name = parts[0], parts[1]
from ...core.column_resolver import ColumnResolver
case_sensitive = self._get_case_sensitive()
# After join with right prefix, "right_alias.column" -> "_right_column" (#380)
right_prefixed = f"_right_{field_name}"
if right_prefixed in df.columns:
expr = pl.col(right_prefixed)
if alias_name:
expr = expr.alias(alias_name)
select_exprs.append(expr)
select_names.append(alias_name or struct_field_path)
continue
resolved_struct_col = ColumnResolver.resolve_column_name(
struct_col, list(df.columns), case_sensitive
)
if resolved_struct_col and resolved_struct_col in df.columns:
struct_dtype = df[resolved_struct_col].dtype
if hasattr(struct_dtype, "fields") and struct_dtype.fields:
field_names = [f.name for f in struct_dtype.fields]
resolved_field = ColumnResolver.resolve_column_name(
field_name,
field_names,
case_sensitive,
)
if resolved_field:
expr = pl.col(resolved_struct_col).struct.field(
resolved_field
)
if alias_name:
expr = expr.alias(alias_name)
select_exprs.append(expr)
select_names.append(alias_name or struct_field_path)
continue
try:
case_sensitive = self._get_case_sensitive()
# Get schema: DataFrame has .schema, LazyFrame has .collect_schema()
if hasattr(df, "schema") and df.schema:
column_dtypes = dict(df.schema)
elif hasattr(df, "collect_schema"):
column_dtypes = dict(df.collect_schema())
else:
column_dtypes = {}
expr = self.translator.translate(
col,
available_columns=list(df.columns),
case_sensitive=case_sensitive,
column_dtypes=column_dtypes,
)
if alias_name:
expr = expr.alias(alias_name)
select_exprs.append(expr)
select_names.append(alias_name)
else:
select_exprs.append(expr)
col_name = getattr(col, "name", None)
select_names.append(
col_name
if col_name is not None
else f"col_{len(select_exprs)}"
)
except ValueError:
if rows_cache is None:
rows_cache = df.to_dicts()
if evaluator is None:
from sparkless.dataframe.evaluation.expression_evaluator import (
ExpressionEvaluator,
)
evaluator = ExpressionEvaluator()
values = [
self._evaluate_python_expression(row, col, evaluator)
for row in rows_cache
]
column_name_candidate = alias_name or getattr(col, "name", None)
if not column_name_candidate:
column_name_candidate = (
f"col_{len(select_exprs) + len(python_columns) + 1}"
)
column_name = str(column_name_candidate)
if isinstance(col, ColumnOperation) and col.operation in {
"to_json",
"to_csv",
}:
struct_alias = self._format_struct_alias(col.column)
column_name = f"{col.operation}({struct_alias})"
python_columns.append((column_name, values))
select_names.append(column_name)
continue
else:
# We found a WindowFunction - process it
# Save the original col for alias extraction later
original_col_for_alias = col
# Ensure function_name is set correctly
if (
not hasattr(window_func, "function_name")
or not window_func.function_name
or window_func.function_name == "window_function"
) and hasattr(window_func, "function"):
function_name_from_func = getattr(
window_func.function, "function_name", None
)
if function_name_from_func:
window_func.function_name = function_name_from_func
window_spec = window_func.window_spec
function_name = getattr(window_func, "function_name", "").upper()
case_sensitive = self._get_case_sensitive()
# Build sort columns from partition_by and order_by
# Window functions need the DataFrame sorted before evaluation
sort_cols = []
has_order_by = bool(
hasattr(window_spec, "_order_by") and window_spec._order_by
)
has_partition_by = bool(
hasattr(window_spec, "_partition_by")
and window_spec._partition_by
)
# Sort DataFrame for this window function before building expression
if has_order_by:
# Add partition_by columns first
if has_partition_by:
for col in window_spec._partition_by:
if isinstance(col, str):
name = col
elif hasattr(col, "name"):
name = col.name
else:
name = None
if name:
resolved = self._resolve_window_col_name(
df, name, case_sensitive
)
sort_cols.append(resolved)
# Add order_by columns
# Build sort parameters for window functions
window_sort_cols = []
window_desc_flags = []
window_nulls_last_flags = []
for col in window_spec._order_by:
col_name = None
is_desc = False
nulls_last = None # None means default
if isinstance(col, str):
col_name = col
is_desc = False # Default ascending
nulls_last = None
elif hasattr(col, "operation"):
operation = col.operation
if hasattr(col, "column") and hasattr(
col.column, "name"
):
col_name = col.column.name
elif hasattr(col, "name"):
col_name = col.name
# Handle nulls variant operations
if operation == "desc_nulls_last":
is_desc = True
nulls_last = True
elif operation == "desc_nulls_first":
is_desc = True
nulls_last = False
elif operation == "asc_nulls_last":
is_desc = False
nulls_last = True
elif operation == "asc_nulls_first":
is_desc = False
nulls_last = False
elif operation == "desc":
is_desc = True
nulls_last = (
True # PySpark default: nulls last for desc()
)
elif operation == "asc":
is_desc = False
nulls_last = (
True # PySpark default: nulls last for asc()
)
elif hasattr(col, "name"):
col_name = col.name
is_desc = False
nulls_last = True # PySpark default: nulls last
if col_name:
resolved_name = self._resolve_window_col_name(
df, col_name, case_sensitive
)
window_sort_cols.append(resolved_name)
window_desc_flags.append(is_desc)
window_nulls_last_flags.append(nulls_last)
# Also add to sort_cols for compatibility with existing code
sort_cols.append(resolved_name)
# Sort window function data with proper nulls handling
if window_sort_cols:
has_nulls_spec = any(
n is not None for n in window_nulls_last_flags
)
# Polars sort: for single column, descending should be bool, not list
if len(window_sort_cols) == 1:
descending_single: bool = (
window_desc_flags[0] if window_desc_flags else False
)
nulls_last_single: Optional[bool] = (
window_nulls_last_flags[0]
if window_nulls_last_flags
and window_nulls_last_flags[0] is not None
else None
)
if has_nulls_spec and nulls_last_single is not None:
df = df.sort(
window_sort_cols[0],
descending=descending_single,
nulls_last=nulls_last_single,
)
else:
df = df.sort(
window_sort_cols[0],
descending=descending_single,
)
else:
descending_list: List[bool] = window_desc_flags
nulls_last_list: Optional[List[Optional[bool]]] = (
window_nulls_last_flags if has_nulls_spec else None
)
if has_nulls_spec and nulls_last_list is not None:
df = df.sort(
window_sort_cols,
descending=descending_list,
nulls_last=nulls_last_list,
)
else:
df = df.sort(
window_sort_cols, descending=descending_list
)
# Sort if we have string column names (and haven't already sorted with expressions)
# CRITICAL: For lag/lead functions, we MUST sort before applying the window function
if function_name in ("LAG", "LEAD") and has_order_by:
# Rebuild sort_cols if needed to ensure we sort
if not sort_cols or not all(
isinstance(c, str) for c in sort_cols
):
sort_cols = []
if has_partition_by:
for col in window_spec._partition_by:
if isinstance(col, str):
name = col
elif hasattr(col, "name"):
name = col.name
else:
name = None
if name:
sort_cols.append(
self._resolve_window_col_name(
df, name, case_sensitive
)
)
for col in window_spec._order_by:
if isinstance(col, str):
name = col
elif hasattr(col, "name"):
name = col.name
else:
name = None
if name:
sort_cols.append(
self._resolve_window_col_name(
df, name, case_sensitive
)
)
if sort_cols and all(isinstance(c, str) for c in sort_cols):
# Use the same descending flags as window_sort_cols
if len(sort_cols) == len(window_desc_flags):
if len(sort_cols) == 1:
df = df.sort(
sort_cols[0], descending=window_desc_flags[0]
)
else:
df = df.sort(
sort_cols, descending=window_desc_flags
)
else:
df = df.sort(sort_cols)
elif sort_cols and all(isinstance(c, str) for c in sort_cols):
# Use the same descending flags as window_sort_cols if available
if window_sort_cols and len(sort_cols) == len(
window_desc_flags
):
if len(sort_cols) == 1:
df = df.sort(
sort_cols[0], descending=window_desc_flags[0]
)
else:
df = df.sort(sort_cols, descending=window_desc_flags)
else:
df = df.sort(sort_cols)
# Special handling for rank()/dense_rank() without column_expr
# Polars ranks by value, but PySpark ranks by position in ordered window
# Solution: use row_number() to get position, then min(row_number) over value
# Check if column_expr would be None (no column argument to rank/dense_rank)
column_expr_would_be_none = (
not hasattr(window_func, "column_name")
or window_func.column_name is None
or (
hasattr(window_func, "column_name")
and window_func.column_name
in {
"__rank__",
"__dense_rank__",
"__row_number__",
"__cume_dist__",
"__percent_rank__",
"__ntile__",
}
)
)
# For rank()/dense_rank() without a column, PySpark ranks by position
# Polars ranks by value, so we need position-based ranking
needs_position_based_rank = (
function_name in ("RANK", "DENSE_RANK")
and column_expr_would_be_none
and has_order_by
)
try:
if needs_position_based_rank:
# Add row_number column to DataFrame (since it's already sorted)
# IMPORTANT: The DataFrame should already be sorted by the window's orderBy
# at this point, so row_numbers reflect the position in the ordered window
row_num_col = "__row_number_for_rank__"
row_num_expr = pl.int_range(pl.len()) + 1
df = df.with_columns(row_num_expr.alias(row_num_col))
# Debug: verify row_num_col is in df
if row_num_col not in df.columns:
import warnings
warnings.warn(
f"row_num_col {row_num_col} not in df.columns: {df.columns}"
)
# Get the order column name for grouping
order_col_name = None
if window_spec._order_by:
first_order_col = window_spec._order_by[0]
if isinstance(first_order_col, str):
order_col_name = self._resolve_window_col_name(
df, first_order_col, case_sensitive
)
elif hasattr(first_order_col, "column") and hasattr(
first_order_col.column, "name"
):
order_col_name = self._resolve_window_col_name(
df, first_order_col.column.name, case_sensitive
)
elif hasattr(first_order_col, "name"):
order_col_name = self._resolve_window_col_name(
df, first_order_col.name, case_sensitive
)
if order_col_name:
# Use min(row_number) over order column to get position-based rank
# For dense_rank, use rank(method="dense") on row_numbers
# Debug: verify order_col_name is in df
if order_col_name not in df.columns:
import warnings
warnings.warn(
f"order_col_name {order_col_name} not in df.columns: {df.columns}"
)
if function_name == "DENSE_RANK":
# For dense_rank: use min(row_number) over order to get position,
# then add as a column and rank(method='dense') on that
# We can't nest window expressions, so we need to do it in two steps
if has_partition_by:
partition_cols = []
for p_col in window_spec._partition_by:
if isinstance(p_col, str):
p_name = self._resolve_window_col_name(
df, p_col, case_sensitive
)
elif hasattr(p_col, "name"):
p_name = self._resolve_window_col_name(
df, p_col.name, case_sensitive
)
else:
continue
if p_name:
partition_cols.append(p_name)
# Compute min(row_number) over partition+order
min_rn_col = f"__min_rn_for_dense_rank_{len(select_exprs)}__"
min_expr = (
pl.col(row_num_col)
.min()
.over(partition_cols + [order_col_name])
)
df = df.with_columns(min_expr.alias(min_rn_col))
# Then rank(method='dense') on that column
window_expr = pl.col(min_rn_col).rank(
method="dense"
)
else:
# Compute min(row_number) over order
min_rn_col = f"__min_rn_for_dense_rank_{len(select_exprs)}__"
min_expr = (
pl.col(row_num_col)
.min()
.over([order_col_name])
)
df = df.with_columns(min_expr.alias(min_rn_col))
# Then rank(method='dense') on that column
# Note: rank(method='dense') ranks globally, which is what we want
# because min_rn already groups by order_col_name
window_expr = pl.col(min_rn_col).rank(
method="dense"
)
else:
# For RANK, use min(row_number) over order column
if has_partition_by:
partition_cols = []
for p_col in window_spec._partition_by:
if isinstance(p_col, str):
p_name = self._resolve_window_col_name(
df, p_col, case_sensitive
)
elif hasattr(p_col, "name"):
p_name = self._resolve_window_col_name(
df, p_col.name, case_sensitive
)
else:
continue
if p_name:
partition_cols.append(p_name)
window_expr = (
pl.col(row_num_col)
.min()
.over(partition_cols + [order_col_name])
)
else:
window_expr = (
pl.col(row_num_col)
.min()
.over([order_col_name])
)
else:
# Fallback to regular rank if we can't extract order column
window_expr = (
self.window_handler.translate_window_function(
window_func, df, case_sensitive=case_sensitive
)
)
else:
window_expr = self.window_handler.translate_window_function(
window_func, df, case_sensitive=case_sensitive
)
# Apply all arithmetic operations in order (innermost to outermost)
# (applies to both position-based and regular window functions)
for op, val, is_reverse in arithmetic_ops:
right = self._arith_operand_to_polars(val, df)
if op == "*":
window_expr = window_expr * right
elif op == "+":
window_expr = window_expr + right
elif op == "-":
if is_reverse:
# Literal - WindowFunction
window_expr = pl.lit(val) - window_expr
elif isinstance(val, (Column, ColumnOperation)):
# Column - WindowFunction (left - window)
window_expr = right - window_expr
else:
# WindowFunction - value
window_expr = window_expr - right
elif op == "/":
if is_reverse:
window_expr = pl.lit(val) / window_expr
else:
window_expr = window_expr / right
elif op == "**":
window_expr = window_expr.pow(
val if isinstance(val, (int, float)) else right
)
# Apply cast if this is a cast operation
if cast_type is not None:
from .type_mapper import mock_type_to_polars_dtype
from sparkless.spark_types import (
StringType,
IntegerType,
LongType,
DoubleType,
FloatType,
BooleanType,
DateType,
TimestampType,
ShortType,
ByteType,
)
# Handle string type names
if isinstance(cast_type, str):
type_name_map = {
"string": StringType(),
"str": StringType(),
"int": IntegerType(),
"integer": IntegerType(),
"long": LongType(),
"bigint": LongType(),
"double": DoubleType(),
"float": FloatType(),
"boolean": BooleanType(),
"bool": BooleanType(),
"date": DateType(),
"timestamp": TimestampType(),
"short": ShortType(),
"byte": ByteType(),
}
cast_type = type_name_map.get(cast_type.lower())
if cast_type is not None:
polars_dtype = mock_type_to_polars_dtype(cast_type)
window_expr = window_expr.cast(
polars_dtype, strict=False
)
# Extract alias from the original col (before any processing)
# Check _alias_name first (from .alias() call), then name, then fallback to default
alias_name = (
getattr(original_col_for_alias, "_alias_name", None)
or getattr(original_col_for_alias, "name", None)
or (
f"{window_func.function_name.lower()}_window"
if hasattr(window_func, "function_name")
else "window_result"
)
)
select_exprs.append(window_expr.alias(alias_name))
select_names.append(alias_name)
except ValueError:
# Fallback to Python evaluation for unsupported window functions
# This path needs to handle arithmetic operations too
# IMPORTANT: rows_cache must be built from the sorted DataFrame
# Rebuild rows_cache from the current (sorted) df to ensure correct order
rows_cache = df.to_dicts()
if rows_cache is None:
rows_cache = []
if not hasattr(self, "_python_window_functions"):
self._python_window_functions: List[Any] = []
# Evaluate window function first
# rows_cache should be built from sorted DataFrame (done above)
results = window_func.evaluate(rows_cache)
# Apply arithmetic operations
for op, val, is_reverse in arithmetic_ops:
if op == "*":
results = [
r * val if r is not None else None for r in results
]
elif op == "+":
results = [
r + val if r is not None else None for r in results
]
elif op == "-":
if is_reverse:
results = [
val - r if r is not None else None
for r in results
]
else:
results = [
r - val if r is not None else None
for r in results
]
elif op == "/":
if is_reverse:
results = [
val / r if r is not None and r != 0 else None
for r in results
]
else:
results = [
r / val if r is not None else None
for r in results
]
elif op == "**":
results = [
r**val if r is not None else None for r in results
]
# Apply cast if needed
if cast_type is not None:
from sparkless.dataframe.casting.type_converter import (
TypeConverter,
)
from sparkless.spark_types import (
StringType,
IntegerType,
LongType,
DoubleType,
FloatType,
BooleanType,
DateType,
TimestampType,
ShortType,
ByteType,
)
# Handle string type names
if isinstance(cast_type, str):
type_name_map = {
"string": StringType(),
"str": StringType(),
"int": IntegerType(),
"integer": IntegerType(),
"long": LongType(),
"bigint": LongType(),
"double": DoubleType(),
"float": FloatType(),
"boolean": BooleanType(),
"bool": BooleanType(),
"date": DateType(),
"timestamp": TimestampType(),
"short": ShortType(),
"byte": ByteType(),
}
cast_type = type_name_map.get(cast_type.lower())
if cast_type is not None:
results = [
TypeConverter.cast_to_type(r, cast_type)
if r is not None
else None
for r in results
]
# Extract alias from the original col (before any processing)
# Check _alias_name first (from .alias() call), then name, then fallback to default
alias_name = (
getattr(original_col_for_alias, "_alias_name", None)
or getattr(original_col_for_alias, "name", None)
or (
f"{window_func.function_name.lower()}_window"
if hasattr(window_func, "function_name")
else "window_result"
)
)
# Store for later addition to result (will be added after select_exprs are processed)
# Format: (alias_name, results_list, arithmetic_ops, cast_type)
if not hasattr(self, "_python_window_functions"):
self._python_window_functions = []
self._python_window_functions.append(
(alias_name, results, arithmetic_ops, cast_type)
)
select_names.append(alias_name)
continue
else:
alias_name = getattr(col, "name", None) or getattr(
col, "_alias_name", None
)
# Handle nested struct field access for Column objects (e.g., F.col("Person.name"))
# When F.col("StructValue.E1").alias("E1-Extract") is used, col.name is the alias,
# but we need to check col._original_column._name for the struct field path
struct_field_path = None
if hasattr(col, "name") and "." in col.name:
struct_field_path = col.name
elif (
hasattr(col, "_original_column")
and hasattr(col._original_column, "_name")
and "." in col._original_column._name
):
# Check original column name for struct field path (for aliased columns)
struct_field_path = col._original_column._name
if struct_field_path:
# Split into struct column and field name
parts = struct_field_path.split(".", 1)
struct_col = parts[0]
field_name = parts[1]
# Resolve struct column name case-insensitively
from ...core.column_resolver import ColumnResolver
case_sensitive = self._get_case_sensitive()
resolved_struct_col = ColumnResolver.resolve_column_name(
struct_col, list(df.columns), case_sensitive
)
if resolved_struct_col and resolved_struct_col in df.columns:
# Get struct dtype
struct_dtype = df[resolved_struct_col].dtype
if hasattr(struct_dtype, "fields") and struct_dtype.fields:
# Resolve field name case-insensitively within struct
field_names = [f.name for f in struct_dtype.fields]
resolved_field = ColumnResolver.resolve_column_name(
field_name, field_names, case_sensitive
)
if resolved_field:
# Use Polars struct.field() syntax for nested access
expr = pl.col(resolved_struct_col).struct.field(
resolved_field
)
if alias_name:
expr = expr.alias(alias_name)
select_exprs.append(expr)
select_names.append(alias_name or struct_field_path)
continue
try:
# Pass available columns and case sensitivity for column resolution
case_sensitive = self._get_case_sensitive()
expr = self.translator.translate(
col,
available_columns=list(df.columns),
case_sensitive=case_sensitive,
)
if alias_name:
expr = expr.alias(alias_name)
select_exprs.append(expr)
select_names.append(alias_name)
else:
select_exprs.append(expr)
if hasattr(col, "name"):
select_names.append(col.name)
elif isinstance(col, str):
select_names.append(col)
else:
select_names.append(f"col_{len(select_exprs)}")
except ValueError:
# Fallback to Python evaluation for unsupported expressions
if rows_cache is None:
rows_cache = df.to_dicts()
if evaluator is None:
from sparkless.dataframe.evaluation.expression_evaluator import (
ExpressionEvaluator,
)
evaluator = ExpressionEvaluator()
values = [
self._evaluate_python_expression(row, col, evaluator)
for row in rows_cache
]
column_name_candidate = alias_name or getattr(col, "name", None)
if not column_name_candidate:
column_name_candidate = (
f"col_{len(select_exprs) + len(python_columns) + 1}"
)
column_name = str(column_name_candidate)
if isinstance(col, ColumnOperation) and col.operation in {
"to_json",
"to_csv",
}:
struct_alias = self._format_struct_alias(col.column)
column_name = f"{col.operation}({struct_alias})"
python_columns.append((column_name, values))
select_names.append(column_name)
continue
if not select_exprs and not python_columns:
return df
# Check if any column uses explode or explode_outer operation
has_explode = False
has_explode_outer = False
explode_index = None
explode_outer_index = None
for i, col in enumerate(columns):
col_operation = getattr(col, "operation", None) or getattr(
col, "function_name", None
)
if col_operation == "explode":
has_explode = True
explode_index = i
elif col_operation == "explode_outer":
has_explode_outer = True
explode_outer_index = i
if select_exprs:
try:
if has_explode or has_explode_outer:
result = df.select(select_exprs)
exploded_col_name = None
if (
has_explode
and explode_index is not None
and explode_index < len(select_names)
):
exploded_col_name = select_names[explode_index]
elif (
has_explode_outer
and explode_outer_index is not None
and explode_outer_index < len(select_names)
):
exploded_col_name = select_names[explode_outer_index]
if exploded_col_name:
result = result.explode(exploded_col_name)
elif posexplode_pending:
result = df.select(select_exprs)
try:
for temp_name, name0, name1 in posexplode_pending:
if temp_name not in result.columns:
raise ValueError(
f"posexplode temp column '{temp_name}' not in result; "
f"columns={result.columns}"
)
temp_dtype = result[temp_name].dtype
# map_elements may return List(Struct) or Struct; only explode if list
if hasattr(temp_dtype, "inner") or str(
temp_dtype
).startswith("List"):
result = result.explode(temp_name)
result = result.unnest(temp_name)
result = result.rename({"pos": name0, "val": name1})
except Exception as _posex_err:
raise
# Update select_names so downstream reorder logic sees final column names
for temp_name, name0, name1 in posexplode_pending:
if temp_name in select_names:
idx = select_names.index(temp_name)
select_names[idx : idx + 1] = [name0, name1]
else:
result = df.select(select_exprs)
except Exception as e:
# Catch Polars InvalidOperationError for unsupported casts
# (e.g., string to boolean which Polars doesn't support directly)
import polars.exceptions
error_msg = str(e)
error_type = type(e).__name__
is_invalid_cast = (
isinstance(e, polars.exceptions.InvalidOperationError)
or "InvalidOperationError" in error_type
or (
"not supported" in error_msg.lower()
and "casting" in error_msg.lower()
)
)
if is_invalid_cast:
# Fallback to Python evaluation for all columns when Polars cast fails
# This handles cases like string to boolean where Polars doesn't support the cast
if rows_cache is None:
rows_cache = df.to_dicts()
if evaluator is None:
from sparkless.dataframe.evaluation.expression_evaluator import (
ExpressionEvaluator,
)
evaluator = ExpressionEvaluator()
# Evaluate all columns in Python (both translated and Python ones)
all_python_columns: List[Tuple[str, List[Any]]] = []
all_python_names: List[str] = []
# Process all columns (including ones that were in select_exprs)
for col in columns:
alias_name = None
if isinstance(col, tuple):
col, alias_name = col
values = [
self._evaluate_python_expression(row, col, evaluator)
for row in rows_cache
]
column_name_candidate = alias_name or getattr(col, "name", None)
if not column_name_candidate:
column_name_candidate = f"col_{len(all_python_columns) + 1}"
column_name = str(column_name_candidate)
if isinstance(col, ColumnOperation) and col.operation in {
"to_json",
"to_csv",
}:
struct_alias = self._format_struct_alias(col.column)
column_name = f"{col.operation}({struct_alias})"
all_python_columns.append((column_name, values))
all_python_names.append(column_name)
# Build result from Python-evaluated columns
if all_python_columns:
data_dict = dict(all_python_columns)
result = pl.DataFrame(data_dict)
else:
return df
else:
# Re-raise if it's not a cast-related error
raise
elif python_columns:
# Only Python-evaluated columns; build DataFrame from values
data_dict = dict(python_columns)
result = pl.DataFrame(data_dict)
else:
return df
# Special handling: if we're selecting only literals (no column references),
# Polars returns 1 row by default. We need to ensure the literal broadcasts
# to all rows in the source DataFrame.
# Check if result has fewer rows than source and we're selecting expressions
# (not string column names)
if select_exprs and len(result) == 1 and len(df) > 1:
# Check if all selected items are expressions (not string column names)
# If all are expressions and none reference columns from df, they're literals
has_column_reference = False
for col in columns:
if isinstance(col, str):
# String column name - definitely references a column
has_column_reference = True
break
# Check if expression references columns from the DataFrame
# We can't easily inspect Polars expressions, so we use a heuristic:
# If the result has 1 row and source has >1 rows, and we're not selecting
# string column names, it's likely all literals
# If no column references and result is shorter, replicate
if not has_column_reference and len(result) < len(df):
# Replicate the single row to match DataFrame length
result = pl.concat([result] * len(df))
# Append Python-evaluated columns
for name, values in python_columns:
result = result.with_columns(pl.Series(name, values))
# Evaluate window functions that require Python evaluation
# These need to be evaluated across all rows, not row-by-row
# Initialize had_python_window_functions before the if block to avoid UnboundLocalError
had_python_window_functions = hasattr(
self, "_python_window_functions"
) and bool(getattr(self, "_python_window_functions", None))
if hasattr(self, "_python_window_functions") and self._python_window_functions:
# Check if we have pre-computed results (with arithmetic) or need to evaluate
first_item = self._python_window_functions[0]
if len(first_item) == 4:
# New format: (alias_name, results, arithmetic_ops, cast_type)
# Results are already computed with arithmetic applied
for (
alias_name,
results,
arithmetic_ops,
cast_type,
) in self._python_window_functions:
# Apply cast if needed
if cast_type is not None:
from sparkless.dataframe.casting.type_converter import (
TypeConverter,
)
from sparkless.spark_types import (
StringType,
IntegerType,
LongType,
DoubleType,
FloatType,
BooleanType,
DateType,
TimestampType,
ShortType,
ByteType,
)
# Handle string type names
if isinstance(cast_type, str):
type_name_map = {
"string": StringType(),
"str": StringType(),
"int": IntegerType(),
"integer": IntegerType(),
"long": LongType(),
"bigint": LongType(),
"double": DoubleType(),
"float": FloatType(),
"boolean": BooleanType(),
"bool": BooleanType(),
"date": DateType(),
"timestamp": TimestampType(),
"short": ShortType(),
"byte": ByteType(),
}
cast_type = type_name_map.get(cast_type.lower())
if cast_type is not None:
results = [
TypeConverter.cast_to_type(r, cast_type)
if r is not None
else None
for r in results
]
result = result.with_columns(pl.Series(alias_name, results))
else:
# Old format: (alias_name, window_func, _)
from sparkless.dataframe.window_handler import WindowFunctionHandler
from sparkless.dataframe import DataFrame
# Use the cached rows for window function evaluation
data_rows = rows_cache if rows_cache else result.to_dicts()
# Create a temporary DataFrame for window function evaluation
from sparkless.spark_types import StructType
temp_df = DataFrame(data_rows, StructType([]), None)
window_handler = WindowFunctionHandler(temp_df)
# Evaluate all window functions
for alias_name, window_func, _ in self._python_window_functions:
# Evaluate window function across all rows
window_handler.evaluate_window_functions(
data_rows, [(alias_name, window_func)]
)
# Extract values from evaluated data
values = [get_row_value(row, alias_name) for row in data_rows]
result = result.with_columns(pl.Series(alias_name, values))
# Clean up
if had_python_window_functions:
delattr(self, "_python_window_functions")
# Only reorder if we have python_columns AND the order doesn't match
# This ensures we preserve all columns while matching the requested order
# Note: When aliases are applied, result.columns should already match select_names,
# so reordering should preserve the aliased column names
if select_names and (python_columns or had_python_window_functions):
existing_cols = list(result.columns)
# Check if reordering is needed and safe
# select_names contains the requested column names (with aliases applied)
# existing_cols should match select_names if aliases were applied correctly
if existing_cols != select_names and all(
name in existing_cols for name in select_names
):
# Reorder using select_names - these should already be in the DataFrame from aliases
result = result.select(select_names)
return result
def _evaluate_python_expression(
self,
row: Dict[str, Any],
expression: Any,
evaluator: ExpressionEvaluator,
) -> Any:
"""Evaluate expressions that require Python fallbacks."""
if isinstance(expression, ColumnOperation):
op_name = expression.operation
if op_name == "from_json":
return self._python_from_json(row, expression)
if op_name == "to_json":
return self._python_to_json(row, expression)
if op_name == "to_csv":
return self._python_to_csv(row, expression)
return evaluator.evaluate_expression(row, expression)
def _get_struct_field_names(self, column_name: str, struct_dtype: Any) -> List[str]:
"""Return struct field names, caching results when shortcuts are enabled."""
if not hasattr(struct_dtype, "fields") or not struct_dtype.fields:
return []
cache_key = (column_name, repr(struct_dtype))
if self._shortcuts_enabled:
cached = self._struct_field_cache.get(cache_key)
if cached is not None:
return cached
field_names = [field.name for field in struct_dtype.fields]
if self._shortcuts_enabled:
# Store a shallow copy in case downstream users mutate the list.
self._struct_field_cache[cache_key] = list(field_names)
return field_names
def _python_from_json(
self, row: Dict[str, Any], expression: ColumnOperation
) -> Any:
column_name = self._extract_column_name(expression.column)
if not column_name:
return None
raw_value = get_row_value(row, column_name)
if raw_value is None:
return None
schema_spec, _ = self._unpack_schema_and_options(expression)
try:
parsed = json.loads(raw_value)
except json.JSONDecodeError:
return None
schema = self._resolve_struct_schema(schema_spec)
if schema is None:
return parsed
if not isinstance(parsed, dict):
return None
return {field.name: parsed.get(field.name) for field in schema.fields}
def _python_to_json(
self, row: Dict[str, Any], expression: ColumnOperation
) -> Union[str, None]:
field_names = self._extract_struct_field_names(expression.column)
if not field_names:
return None
struct_dict = {name: get_row_value(row, name) for name in field_names}
return json.dumps(struct_dict, ensure_ascii=False, separators=(",", ":"))
def _python_to_csv(
self, row: Dict[str, Any], expression: ColumnOperation
) -> Union[str, None]:
field_names = self._extract_struct_field_names(expression.column)
if not field_names:
return None
values = []
for name in field_names:
val = get_row_value(row, name)
values.append("" if val is None else str(val))
return ",".join(values)
def _extract_column_name(self, expr: Any) -> Optional[str]:
if isinstance(expr, Column):
name = expr.name
return str(name) if name is not None else None
if isinstance(expr, ColumnOperation) and hasattr(expr, "name"):
name_attr = getattr(expr, "name", None)
if name_attr is not None:
return str(name_attr)
return None
if isinstance(expr, str):
return expr
name_attr = getattr(expr, "name", None)
if name_attr is not None:
return str(name_attr)
return None
def _extract_struct_field_names(self, expr: Any) -> List[str]:
names: List[str] = []
if isinstance(expr, ColumnOperation) and expr.operation == "struct":
first = self._extract_column_name(expr.column)
if first:
names.append(first)
additional = expr.value
if isinstance(additional, tuple):
for item in additional:
name = self._extract_column_name(item)
if name:
names.append(name)
else:
name = self._extract_column_name(expr)
if name:
names.append(name)
return names
def _format_struct_alias(self, expr: Any) -> str:
names = self._extract_struct_field_names(expr)
if names:
return f"struct({', '.join(names)})"
return "struct(...)"
def _unpack_schema_and_options(
self, expression: ColumnOperation
) -> Tuple[Any, Dict[str, Any]]:
schema_spec: Any = None
options: Dict[str, Any] = {}
raw_value = getattr(expression, "value", None)
if isinstance(raw_value, tuple):
if len(raw_value) >= 1:
schema_spec = raw_value[0]
if len(raw_value) >= 2 and isinstance(raw_value[1], dict):
options = dict(raw_value[1])
elif isinstance(raw_value, dict):
options = dict(raw_value)
return schema_spec, options
def _resolve_struct_schema(self, schema_spec: Any) -> Union[StructType, None]:
if schema_spec is None:
return None
if isinstance(schema_spec, StructType):
return schema_spec
if hasattr(schema_spec, "value"):
return self._resolve_struct_schema(schema_spec.value)
if isinstance(schema_spec, str):
try:
return parse_ddl_schema(schema_spec)
except Exception:
logger.debug(
"parse_ddl_schema failed, returning empty StructType",
exc_info=True,
)
return StructType([])
return None
[docs]
@profiled("polars.apply_with_column", category="polars")
def apply_with_column(
self,
df: pl.DataFrame,
column_name: str,
expression: Any,
expected_field: Any = None,
) -> pl.DataFrame:
"""Apply a withColumn operation.
Args:
df: Source Polars DataFrame
column_name: Name of new/updated column
expression: Expression for the column
Returns:
DataFrame with new column
"""
# Check if expression is a ColumnOperation containing window function arithmetic
# This must be checked before the simple window function check
window_func, arithmetic_ops = self._extract_window_function_with_arithmetic(
expression
)
cast_type = None
# Check for cast operation on window function arithmetic
if isinstance(expression, ColumnOperation) and expression.operation == "cast":
if window_func:
cast_type = expression.value
elif isinstance(expression.column, ColumnOperation):
# Cast might be applied to arithmetic result
nested_window_func, nested_ops = (
self._extract_window_function_with_arithmetic(expression.column)
)
if nested_window_func:
window_func = nested_window_func
arithmetic_ops = nested_ops
cast_type = expression.value
# Check if expression is a WindowFunction or a ColumnOperation wrapping a WindowFunction (e.g., WindowFunction.cast())
is_window_function = isinstance(expression, WindowFunction)
is_window_function_cast = (
isinstance(expression, ColumnOperation)
and expression.operation == "cast"
and isinstance(expression.column, WindowFunction)
)
if window_func is not None or is_window_function or is_window_function_cast:
# Window functions need special handling
# For window functions with order_by, we need to sort the DataFrame first
# to ensure correct window function evaluation
# Extract the WindowFunction if it's wrapped in a ColumnOperation
# Use extracted window_func if available, otherwise use existing logic
if window_func is None:
window_func = (
expression
if isinstance(expression, WindowFunction)
else expression.column
)
cast_type = expression.value if is_window_function_cast else None
# If window_func was extracted from arithmetic, arithmetic_ops and cast_type are already set
# Ensure function_name is set correctly - extract from function if needed
if (
not hasattr(window_func, "function_name")
or not window_func.function_name
or window_func.function_name == "window_function"
) and hasattr(window_func, "function"):
function_name_from_func = getattr(
window_func.function, "function_name", None
)
if function_name_from_func:
window_func.function_name = function_name_from_func
window_spec = window_func.window_spec
function_name = getattr(window_func, "function_name", "").upper()
case_sensitive = self._get_case_sensitive()
# Build sort columns from partition_by and order_by
sort_cols = []
has_order_by = hasattr(window_spec, "_order_by") and window_spec._order_by
has_partition_by = (
hasattr(window_spec, "_partition_by") and window_spec._partition_by
)
if has_order_by:
# Add partition_by columns first
if has_partition_by:
for col in window_spec._partition_by:
if isinstance(col, str):
name = col
elif hasattr(col, "name"):
name = col.name
else:
name = None
if name:
resolved = self._resolve_window_col_name(
df, name, case_sensitive
)
sort_cols.append(resolved)
# Add order_by columns
# Build sort parameters for window functions
window_sort_cols = []
window_desc_flags = []
window_nulls_last_flags = []
for col in window_spec._order_by:
col_name = None
is_desc = False
nulls_last = None # None means default
if isinstance(col, str):
col_name = col
is_desc = False # Default ascending
nulls_last = None
elif hasattr(col, "operation"):
operation = col.operation
if hasattr(col, "column") and hasattr(col.column, "name"):
col_name = col.column.name
elif hasattr(col, "name"):
col_name = col.name
# Handle nulls variant operations
if operation == "desc_nulls_last":
is_desc = True
nulls_last = True
elif operation == "desc_nulls_first":
is_desc = True
nulls_last = False
elif operation == "asc_nulls_last":
is_desc = False
nulls_last = True
elif operation == "asc_nulls_first":
is_desc = False
nulls_last = False
elif operation == "desc":
is_desc = True
nulls_last = True # PySpark default: nulls last for desc()
elif operation == "asc":
is_desc = False
nulls_last = True # PySpark default: nulls last for asc()
elif hasattr(col, "name"):
col_name = col.name
is_desc = False
nulls_last = True # PySpark default: nulls last
if col_name:
resolved_name = self._resolve_window_col_name(
df, col_name, case_sensitive
)
window_sort_cols.append(resolved_name)
window_desc_flags.append(is_desc)
window_nulls_last_flags.append(nulls_last)
# Also add to sort_cols for compatibility with existing code
sort_cols.append(resolved_name)
# Sort window function data with proper nulls handling
if window_sort_cols:
has_nulls_spec = any(n is not None for n in window_nulls_last_flags)
if has_nulls_spec:
df = df.sort(
window_sort_cols,
descending=window_desc_flags,
nulls_last=window_nulls_last_flags,
)
else:
df = df.sort(window_sort_cols, descending=window_desc_flags)
# Sort if we have string column names (and haven't already sorted with expressions)
# CRITICAL: For lag/lead functions, we MUST sort before applying the window function
if function_name in ("LAG", "LEAD") and has_order_by:
# Rebuild sort_cols if needed to ensure we sort
if not sort_cols or not all(isinstance(c, str) for c in sort_cols):
sort_cols = []
if has_partition_by:
for col in window_spec._partition_by:
if isinstance(col, str):
name = col
elif hasattr(col, "name"):
name = col.name
else:
name = None
if name:
sort_cols.append(
self._resolve_window_col_name(
df, name, case_sensitive
)
)
for col in window_spec._order_by:
if isinstance(col, str):
name = col
elif hasattr(col, "name"):
name = col.name
else:
name = None
if name:
sort_cols.append(
self._resolve_window_col_name(df, name, case_sensitive)
)
if sort_cols and all(isinstance(c, str) for c in sort_cols):
df = df.sort(sort_cols)
elif sort_cols and all(isinstance(c, str) for c in sort_cols):
df = df.sort(sort_cols)
try:
window_expr = self.window_handler.translate_window_function(
window_func, df, case_sensitive=case_sensitive
)
# Apply all arithmetic operations in order (innermost to outermost)
for op, val, is_reverse in arithmetic_ops:
right = self._arith_operand_to_polars(val, df)
if op == "*":
window_expr = window_expr * right
elif op == "+":
window_expr = window_expr + right
elif op == "-":
if is_reverse:
# Literal - WindowFunction
window_expr = pl.lit(val) - window_expr
elif isinstance(val, ColumnOperation) or (
hasattr(val, "name") and hasattr(val, "column_type")
):
# Column - WindowFunction (left - window)
# Check for Column by checking for Column-like attributes
window_expr = right - window_expr
else:
# WindowFunction - value
window_expr = window_expr - right
elif op == "/":
if is_reverse:
window_expr = pl.lit(val) / window_expr
else:
window_expr = window_expr / right
elif op == "**":
window_expr = window_expr.pow(
val if isinstance(val, (int, float)) else right
)
# Apply cast if this is a cast operation
if (
is_window_function_cast or cast_type is not None
) and cast_type is not None:
from .type_mapper import mock_type_to_polars_dtype
from sparkless.spark_types import (
StringType,
IntegerType,
LongType,
DoubleType,
FloatType,
BooleanType,
DateType,
TimestampType,
ShortType,
ByteType,
)
# Handle string type names
if isinstance(cast_type, str):
type_name_map = {
"string": StringType(),
"str": StringType(),
"int": IntegerType(),
"integer": IntegerType(),
"long": LongType(),
"bigint": LongType(),
"double": DoubleType(),
"float": FloatType(),
"boolean": BooleanType(),
"bool": BooleanType(),
"date": DateType(),
"timestamp": TimestampType(),
"short": ShortType(),
"byte": ByteType(),
}
cast_type = type_name_map.get(cast_type.lower())
if cast_type is not None:
polars_dtype = mock_type_to_polars_dtype(cast_type)
window_expr = window_expr.cast(polars_dtype, strict=False)
result = df.with_columns(window_expr.alias(column_name))
except ValueError:
# Fallback to Python evaluation for unsupported window functions
# Ensure data is sorted before evaluation (df should already be sorted above)
# Convert Polars DataFrame to list of dicts for Python evaluation
data = df.to_dicts()
# Evaluate window function using Python implementation
# WindowFunction.evaluate() expects sorted data for correct results
results = window_func.evaluate(data)
# Apply arithmetic operations
for op, val, is_reverse in arithmetic_ops:
if op == "*":
results = [r * val if r is not None else None for r in results]
elif op == "+":
results = [r + val if r is not None else None for r in results]
elif op == "-":
if is_reverse:
results = [
val - r if r is not None else None for r in results
]
else:
results = [
r - val if r is not None else None for r in results
]
elif op == "/":
if is_reverse:
results = [
val / r if r is not None and r != 0 else None
for r in results
]
else:
results = [
r / val if r is not None else None for r in results
]
elif op == "**":
results = [r**val if r is not None else None for r in results]
# Apply cast if this is a cast operation
if (
is_window_function_cast or cast_type is not None
) and cast_type is not None:
from sparkless.dataframe.casting.type_converter import TypeConverter
from sparkless.spark_types import (
StringType,
IntegerType,
LongType,
DoubleType,
FloatType,
BooleanType,
DateType,
TimestampType,
ShortType,
ByteType,
)
# Handle string type names
if isinstance(cast_type, str):
type_name_map = {
"string": StringType(),
"str": StringType(),
"int": IntegerType(),
"integer": IntegerType(),
"long": LongType(),
"bigint": LongType(),
"double": DoubleType(),
"float": FloatType(),
"boolean": BooleanType(),
"bool": BooleanType(),
"date": DateType(),
"timestamp": TimestampType(),
"short": ShortType(),
"byte": ByteType(),
}
cast_type = type_name_map.get(cast_type.lower())
if cast_type is not None:
results = [
TypeConverter.cast_to_type(r, cast_type)
if r is not None
else None
for r in results
]
# Add results as new column
result = df.with_columns(pl.Series(column_name, results))
# For lag/lead, ensure result maintains sort order
# CRITICAL: Must sort result to preserve correct window function values
if function_name in ("LAG", "LEAD"):
# Rebuild sort_cols if needed
if not sort_cols or not all(isinstance(c, str) for c in sort_cols):
sort_cols = []
if has_partition_by:
for col in window_spec._partition_by:
if isinstance(col, str):
sort_cols.append(col)
elif hasattr(col, "name"):
sort_cols.append(col.name)
if has_order_by:
for col in window_spec._order_by:
if isinstance(col, str):
sort_cols.append(col)
elif hasattr(col, "name"):
sort_cols.append(col.name)
if sort_cols and all(isinstance(c, str) for c in sort_cols):
result = result.sort(sort_cols)
return result
else:
# Check if this is an explode or explode_outer operation
# These operations need special handling: they explode arrays into multiple rows
is_explode = (
isinstance(expression, ColumnOperation)
and expression.operation == "explode"
)
is_explode_outer = (
isinstance(expression, ColumnOperation)
and expression.operation == "explode_outer"
)
if is_explode or is_explode_outer:
# For explode operations, we need to:
# 1. Get the source column (the array column to explode)
# 2. Add it as a new column with the array values
# 3. Explode the DataFrame on that new column (this creates multiple rows)
from sparkless.functions import Column as ColumnType
# Get the source column name
source_col = expression.column
if isinstance(source_col, ColumnType):
source_col_name = source_col.name
elif isinstance(source_col, str):
source_col_name = source_col
else:
# Fallback: try to get name from column attribute
source_col_name_attr = getattr(source_col, "name", None)
if source_col_name_attr is None:
raise ValueError(
f"Cannot determine source column for explode operation: {expression}"
)
source_col_name = str(source_col_name_attr)
# Resolve column name (case-insensitive)
from ...core.column_resolver import ColumnResolver
case_sensitive = self._get_case_sensitive()
resolved_col_name = ColumnResolver.resolve_column_name(
source_col_name, list(df.columns), case_sensitive
)
if resolved_col_name is None:
raise ValueError(
f"Column '{source_col_name}' not found in DataFrame for explode operation"
)
# Check if the column exists
if resolved_col_name not in df.columns:
raise ValueError(
f"Column '{resolved_col_name}' not found in DataFrame"
)
# Add the new column with array values from the source column
# Then explode the DataFrame on this new column
# This will create multiple rows, one for each element in the array
result = df.with_columns(pl.col(resolved_col_name).alias(column_name))
# Explode the DataFrame on the new column
result = result.explode(column_name)
# For regular explode, filter out rows where the exploded value is None
# (PySpark drops rows with null/empty arrays)
# For explode_outer, keep all rows (including those with None)
if not is_explode_outer:
# Filter out rows where ExplodedValue is None
result = result.filter(pl.col(column_name).is_not_null())
return result
# Check if this is a to_timestamp operation and if the input column is a string
# This helps us choose the right method (str.strptime vs map_elements)
input_col_dtype = None
# ColumnOperation is already imported at the top of the file
if (
isinstance(expression, ColumnOperation)
and expression.operation == "to_timestamp"
):
# Check if the input column is a simple Column (direct column reference)
from sparkless.functions import Column as ColumnType
if isinstance(expression.column, ColumnType) and not isinstance(
expression.column, ColumnOperation
):
# Check the dtype of the input column in the DataFrame
col_name = expression.column.name
if col_name in df.columns:
input_col_dtype = df[col_name].dtype
elif isinstance(expression.column, ColumnOperation):
# For ColumnOperation chains, check if the result is a string type
# This handles cases like regexp_replace().cast("string")
col_op = expression.column
# Check if it's a cast to string
if col_op.operation == "cast":
cast_target = col_op.value
if isinstance(cast_target, str) and cast_target.lower() in [
"string",
"varchar",
]:
input_col_dtype = pl.Utf8
# Check if it's a string operation (regexp_replace, substring, etc.)
elif col_op.operation in [
"regexp_replace",
"substring",
"concat",
"upper",
"lower",
"trim",
"ltrim",
"rtrim",
]:
input_col_dtype = pl.Utf8
# For nested ColumnOperations, check recursively
elif isinstance(col_op.column, ColumnOperation):
# Recursively check the inner operation
inner_op = col_op.column
if inner_op.operation == "cast":
cast_target = inner_op.value
if isinstance(cast_target, str) and cast_target.lower() in [
"string",
"varchar",
]:
input_col_dtype = pl.Utf8
elif inner_op.operation in [
"regexp_replace",
"substring",
"concat",
"upper",
"lower",
"trim",
"ltrim",
"rtrim",
]:
input_col_dtype = pl.Utf8
# Check if expression is a comparison operation with WindowFunction
# (e.g., F.row_number().over(w) > 0)
# Also check for isnull/isnotnull operations on WindowFunction
is_window_function_comparison = (
isinstance(expression, ColumnOperation)
and expression.operation
in [">", "<", ">=", "<=", "==", "!=", "eqNullSafe"]
and isinstance(expression.column, WindowFunction)
)
is_window_function_isnull = (
isinstance(expression, ColumnOperation)
and expression.operation in ["isnull", "isnotnull"]
and isinstance(expression.column, WindowFunction)
)
if is_window_function_comparison or is_window_function_isnull:
# Handle comparison operations or isnull/isnotnull operations with WindowFunction
# First, apply the window function to get a temporary column
window_func = expression.column
operation = expression.operation
operation_value = expression.value
# Apply window function to get a temporary column
temp_col_name = f"__window_func_temp_{column_name}"
df_with_window = self.apply_with_column(
df, temp_col_name, window_func, expected_field=None
)
# Now create the operation expression: temp_col_name op value (or isnull/isnotnull)
from sparkless.functions.core.column import Column
temp_col = Column(temp_col_name)
if operation == ">":
operation_expr = temp_col > operation_value
elif operation == "<":
operation_expr = temp_col < operation_value
elif operation == ">=":
operation_expr = temp_col >= operation_value
elif operation == "<=":
operation_expr = temp_col <= operation_value
elif operation == "==":
operation_expr = temp_col == operation_value
elif operation == "!=":
operation_expr = temp_col != operation_value
elif operation == "eqNullSafe":
operation_expr = temp_col.eqNullSafe(operation_value)
elif operation == "isnull":
operation_expr = temp_col.isnull()
elif operation == "isnotnull":
operation_expr = temp_col.isnotnull()
else:
operation_expr = ColumnOperation(
temp_col, operation, operation_value
)
# Translate and apply the operation
case_sensitive = self._get_case_sensitive()
operation_pl_expr = self.translator.translate(
operation_expr,
available_columns=list(df_with_window.columns),
case_sensitive=case_sensitive,
)
# Apply the operation and drop the temporary column
result_df = (
df_with_window.with_columns(
pl.col(temp_col_name).alias(column_name)
)
.with_columns(operation_pl_expr.alias(column_name))
.drop(temp_col_name)
)
return result_df
# Check if expression is a CaseWhen that contains WindowFunction comparisons
from sparkless.functions.conditional import CaseWhen
if isinstance(expression, CaseWhen):
# Check if any condition contains a WindowFunction comparison or isnull/isnotnull
has_window_function_comparison = False
for condition, _ in expression.conditions:
if (
isinstance(condition, ColumnOperation)
and isinstance(condition.column, WindowFunction)
and condition.operation
in [
">",
"<",
">=",
"<=",
"==",
"!=",
"eqNullSafe",
"isnull",
"isnotnull",
]
):
has_window_function_comparison = True
break
if has_window_function_comparison:
# Handle CaseWhen with WindowFunction comparisons
# First, apply all window functions to get temporary columns
temp_cols: Dict[int, str] = {}
current_df = df
for condition, _ in expression.conditions:
if isinstance(condition, ColumnOperation) and isinstance(
condition.column, WindowFunction
):
operation = condition.operation
if operation in [
">",
"<",
">=",
"<=",
"==",
"!=",
"eqNullSafe",
"isnull",
"isnotnull",
]:
window_func = condition.column
# Create a unique temp column name
temp_col_name = f"__window_func_temp_{len(temp_cols)}"
temp_cols[id(window_func)] = temp_col_name
# Apply window function
current_df = self.apply_with_column(
current_df,
temp_col_name,
window_func,
expected_field=None,
)
# Now replace WindowFunction comparisons with column comparisons
from sparkless.functions.core.column import Column
new_conditions = []
for condition, value in expression.conditions:
if isinstance(condition, ColumnOperation) and isinstance(
condition.column, WindowFunction
):
operation = condition.operation
if operation in [
">",
"<",
">=",
"<=",
"==",
"!=",
"eqNullSafe",
"isnull",
"isnotnull",
]:
window_func = condition.column
temp_col_name = temp_cols[id(window_func)]
temp_col = Column(temp_col_name)
operation_value = condition.value
# Create new operation expression
if operation == ">":
new_condition = temp_col > operation_value
elif operation == "<":
new_condition = temp_col < operation_value
elif operation == ">=":
new_condition = temp_col >= operation_value
elif operation == "<=":
new_condition = temp_col <= operation_value
elif operation == "==":
new_condition = temp_col == operation_value
elif operation == "!=":
new_condition = temp_col != operation_value
elif operation == "eqNullSafe":
new_condition = temp_col.eqNullSafe(operation_value)
elif operation == "isnull":
new_condition = temp_col.isnull()
elif operation == "isnotnull":
new_condition = temp_col.isnotnull()
else:
# This branch should never be reached due to the if check above
# But include it for completeness
# For other operations, create ColumnOperation
# operation_value can be None for unary operations
# ColumnOperation accepts Any for value parameter (including None)
op_str: str = operation # type: ignore[assignment]
new_condition = ColumnOperation(
temp_col, op_str, None, name=None
)
new_conditions.append((new_condition, value))
else:
new_conditions.append((condition, value))
else:
new_conditions.append((condition, value))
# Create new CaseWhen with replaced conditions
new_case_when = CaseWhen()
new_case_when.conditions = new_conditions
new_case_when.default_value = expression.default_value
# Translate the new CaseWhen (pass column_dtypes for nested between - Issue #445)
case_sensitive = self._get_case_sensitive()
column_dtypes = (
dict(current_df.schema)
if hasattr(current_df, "schema") and current_df.schema
else {}
)
expr = self.translator.translate(
new_case_when,
available_columns=list(current_df.columns),
case_sensitive=case_sensitive,
column_dtypes=column_dtypes,
)
# Apply the expression and drop temporary columns
result_df = current_df.with_columns(expr.alias(column_name))
if temp_cols:
result_df = result_df.drop(list(temp_cols.values()))
return result_df
# Check if expression is a Column with struct field path (e.g., "StructVal.E1")
# This needs special handling to use Polars struct.field() syntax
from sparkless.functions.core.column import Column
if isinstance(expression, Column) and "." in expression.name:
# Handle struct field access
struct_field_path = expression.name
parts = struct_field_path.split(".", 1)
struct_col = parts[0]
field_name = parts[1]
# Resolve struct column name case-insensitively
from ...core.column_resolver import ColumnResolver
case_sensitive = self._get_case_sensitive()
resolved_struct_col = ColumnResolver.resolve_column_name(
struct_col, list(df.columns), case_sensitive
)
if resolved_struct_col and resolved_struct_col in df.columns:
# Get struct dtype
struct_dtype = df[resolved_struct_col].dtype
if hasattr(struct_dtype, "fields") and struct_dtype.fields:
# Resolve field name case-insensitively within struct
field_names = [f.name for f in struct_dtype.fields]
resolved_field = ColumnResolver.resolve_column_name(
field_name, field_names, case_sensitive
)
if resolved_field:
# Use Polars struct.field() syntax for nested access
expr = pl.col(resolved_struct_col).struct.field(
resolved_field
)
return df.with_columns(expr.alias(column_name))
try:
# Pass case sensitivity for column resolution
# Build column_dtypes from df schema for between/isin type coercion (Issue #445)
case_sensitive = self._get_case_sensitive()
column_dtypes = (
dict(df.schema) if hasattr(df, "schema") and df.schema else {}
)
expr = self.translator.translate(
expression,
input_col_dtype=input_col_dtype,
available_columns=list(df.columns),
case_sensitive=case_sensitive,
column_dtypes=column_dtypes,
)
except ValueError as e:
# Fallback to Python evaluation for unsupported operations (e.g., withField, + with strings, WindowFunction)
error_msg = str(e)
# Check if this is a WindowFunction comparison that should be handled
is_window_function_comparison_fallback = (
isinstance(expression, ColumnOperation)
and expression.operation
in [">", "<", ">=", "<=", "==", "!=", "eqNullSafe"]
and isinstance(expression.column, WindowFunction)
) or (
"WindowFunction comparison" in error_msg
and isinstance(expression, ColumnOperation)
and expression.operation
in [">", "<", ">=", "<=", "==", "!=", "eqNullSafe"]
)
if is_window_function_comparison_fallback:
# Handle comparison operations with WindowFunction
# First, apply the window function to get a temporary column
window_func = expression.column
comparison_op = expression.operation
comparison_value = expression.value
# Apply window function to get a temporary column
temp_col_name = f"__window_func_temp_{column_name}"
df_with_window = self.apply_with_column(
df, temp_col_name, window_func, expected_field=None
)
# Now create a comparison expression: temp_col_name op value
from sparkless.functions.core.column import Column
temp_col = Column(temp_col_name)
if comparison_op == ">":
comparison_expr = temp_col > comparison_value
elif comparison_op == "<":
comparison_expr = temp_col < comparison_value
elif comparison_op == ">=":
comparison_expr = temp_col >= comparison_value
elif comparison_op == "<=":
comparison_expr = temp_col <= comparison_value
elif comparison_op == "==":
comparison_expr = temp_col == comparison_value
elif comparison_op == "!=":
comparison_expr = temp_col != comparison_value
elif comparison_op == "eqNullSafe":
comparison_expr = temp_col.eqNullSafe(comparison_value)
else:
comparison_expr = ColumnOperation(
temp_col, comparison_op, comparison_value
)
# Translate and apply the comparison
case_sensitive = self._get_case_sensitive()
comparison_pl_expr = self.translator.translate(
comparison_expr,
available_columns=list(df_with_window.columns),
case_sensitive=case_sensitive,
)
# Apply the comparison and drop the temporary column
result_df = (
df_with_window.with_columns(
pl.col(temp_col_name).alias(column_name)
)
.with_columns(comparison_pl_expr.alias(column_name))
.drop(temp_col_name)
)
return result_df
# Check if this is a WindowFunction cast that should be handled above
# Check both the expression structure and the error message
is_window_function_cast_fallback = (
isinstance(expression, ColumnOperation)
and expression.operation == "cast"
and isinstance(expression.column, WindowFunction)
) or (
"WindowFunction" in error_msg
and isinstance(expression, ColumnOperation)
and expression.operation == "cast"
)
if is_window_function_cast_fallback:
# This should have been caught above, but handle it here as fallback
# This is the same logic as above, but as a safety net
# Extract WindowFunction and cast type
window_func = expression.column
cast_type = expression.value
# Window functions need sorting - reuse the same logic from above
window_spec = window_func.window_spec
sort_cols = []
has_order_by = (
hasattr(window_spec, "_order_by") and window_spec._order_by
)
has_partition_by = (
hasattr(window_spec, "_partition_by")
and window_spec._partition_by
)
# Build sort columns
if has_order_by:
if has_partition_by:
for col in window_spec._partition_by:
if isinstance(col, str):
sort_cols.append(col)
elif hasattr(col, "name"):
sort_cols.append(col.name)
for col in window_spec._order_by:
if isinstance(col, str):
sort_cols.append(col)
elif hasattr(col, "name"):
sort_cols.append(col.name)
# Sort DataFrame if needed for window function evaluation
# Window functions need sorted data for correct evaluation
df_sorted = (
df.sort(sort_cols)
if sort_cols and all(isinstance(c, str) for c in sort_cols)
else df
)
# Evaluate window function on sorted data
# WindowFunction.evaluate() handles sorting internally, but we need to sort first for correctness
data = df_sorted.to_dicts()
results = window_func.evaluate(data)
# Apply cast
if cast_type is not None:
from sparkless.dataframe.casting.type_converter import (
TypeConverter,
)
from sparkless.spark_types import (
StringType,
IntegerType,
LongType,
DoubleType,
FloatType,
BooleanType,
DateType,
TimestampType,
ShortType,
ByteType,
)
# Handle string type names
if isinstance(cast_type, str):
type_name_map = {
"string": StringType(),
"str": StringType(),
"int": IntegerType(),
"integer": IntegerType(),
"long": LongType(),
"bigint": LongType(),
"double": DoubleType(),
"float": FloatType(),
"boolean": BooleanType(),
"bool": BooleanType(),
"date": DateType(),
"timestamp": TimestampType(),
"short": ShortType(),
"byte": ByteType(),
}
cast_type = type_name_map.get(cast_type.lower())
if cast_type is not None:
results = [
TypeConverter.cast_to_type(r, cast_type)
if r is not None
else None
for r in results
]
# Add results as new column to sorted DataFrame
# Results are in the same order as df_sorted
result = df_sorted.with_columns(pl.Series(column_name, results))
# If we sorted, we need to restore original order
# Create a row identifier to match back to original DataFrame
if df_sorted is not df and sort_cols:
# Use all columns as join keys to match rows
join_cols = list(df.columns)
# Add row number to both DataFrames for matching
original_with_row_num = df.with_row_count("__original_index__")
result_with_row_num = result.with_row_count("__sorted_index__")
# Join to restore original order
result_joined = (
original_with_row_num.join(
result_with_row_num.select(
join_cols + [column_name, "__sorted_index__"]
),
on=join_cols,
how="left",
)
.select(join_cols + [column_name])
.drop("__original_index__")
)
return result_joined
return result
elif (
"withField" in error_msg
or (
isinstance(expression, ColumnOperation)
and expression.operation == "withField"
)
or "+ operation requires Python evaluation" in error_msg
or "format_string operation requires Python evaluation" in error_msg
or "array function requires Python evaluation" in error_msg
or (
"getItem" in error_msg
and "requires Python evaluation" in error_msg
)
):
# Convert Polars DataFrame to list of dicts for Python evaluation
data = df.to_dicts()
# Evaluate using ExpressionEvaluator (Issue #398: pass full_data
# and row_index for WindowFunction inside withField)
from sparkless.dataframe.evaluation.expression_evaluator import (
ExpressionEvaluator,
)
# Get dataframe context from expression if available (for case sensitivity)
dataframe_context = None
if hasattr(expression, "_dataframe_context"):
dataframe_context = expression._dataframe_context
elif hasattr(expression, "column") and hasattr(
expression.column, "_dataframe_context"
):
dataframe_context = expression.column._dataframe_context
evaluator = ExpressionEvaluator(
dataframe_context=dataframe_context,
full_data=data,
)
results = []
for i, row in enumerate(data):
evaluator._current_row_index = i
results.append(
evaluator.evaluate_expression(row, expression, row_index=i)
)
# Add results as new column
# Polars will automatically infer struct type from dict values
# Create Series - use strict=False if available to handle mixed types
# This is needed for conditional expressions and other complex cases
# For array() function, results are lists which may contain mixed types
# Use Object dtype for array() to handle mixed types correctly
is_array_function = (
isinstance(expression, ColumnOperation)
and expression.operation == "array"
)
try:
if is_array_function:
# Array function may have mixed types - use Object dtype
result_series = pl.Series(
column_name, results, dtype=pl.Object
)
else:
# Try with strict=False first (Polars 0.19+)
result_series = pl.Series(
column_name, results, strict=False
)
except TypeError:
# Fallback: try without strict parameter or use Object dtype
try:
result_series = pl.Series(column_name, results)
except TypeError:
# Last resort: use Object dtype explicitly
result_series = pl.Series(
column_name, results, dtype=pl.Object
)
# Replace the column (if it exists) or add new column
result = df.with_columns(result_series)
return result
else:
# Re-raise if it's not a withField error
raise
# If expected_field is provided, use it to explicitly cast the result
# This fixes issue #151 where Polars was expecting String but got datetime
# for to_timestamp() operations
if expected_field is not None:
from sparkless.spark_types import TimestampType
from .type_mapper import mock_type_to_polars_dtype
# Check if the expected type is TimestampType
if isinstance(expected_field.dataType, TimestampType):
# Explicitly cast to pl.Datetime to ensure Polars recognizes the correct type
# This is critical for to_timestamp operations to avoid schema validation errors
polars_dtype = mock_type_to_polars_dtype(expected_field.dataType)
# Cast immediately to ensure type is correct before any operations
expr = expr.cast(polars_dtype)
# Apply with_columns - with schema inference fix, this should work correctly
# The expression translator already handles cast operations correctly
# For to_timestamp operations with TimestampType expected_field, evaluate eagerly
# and use hstack to add column without creating lazy frame
if expected_field is not None:
from sparkless.spark_types import TimestampType
if isinstance(expected_field.dataType, TimestampType):
# Evaluate the expression eagerly and add as Series to avoid lazy validation
# This avoids Polars' lazy frame schema validation that checks input types
try:
# For to_timestamp, use with_columns directly with explicit cast
# The cast ensures Polars recognizes the output type before validation
from .type_mapper import mock_type_to_polars_dtype
polars_dtype = mock_type_to_polars_dtype(
expected_field.dataType
)
# Cast the expression to the expected type before using with_columns
# Use strict=False to handle edge cases gracefully
cast_expr = expr.cast(polars_dtype, strict=False)
# Use with_columns - the cast should prevent validation errors
result = df.with_columns([cast_expr.alias(column_name)])
return result
except Exception:
logger.debug(
"Cast with_columns path failed, falling through",
exc_info=True,
)
# For to_date() operations on datetime columns, use .dt.date() directly
# This avoids schema validation issues that map_elements can cause
if (
isinstance(expression, ColumnOperation)
and expression.operation == "to_date"
):
from sparkless.functions.core.column import Column
# Check if the input column is a simple Column reference (not a ColumnOperation)
input_col = expression.column
if isinstance(input_col, Column) and not isinstance(
input_col, ColumnOperation
):
# Simple column reference - check if it's a datetime type in the DataFrame
col_name = input_col.name
if col_name in df.columns:
# Check the actual Polars dtype
col_dtype = df[col_name].dtype
is_datetime = (
isinstance(col_dtype, pl.Datetime)
or str(col_dtype).startswith("Datetime")
or (hasattr(pl, "Datetime") and col_dtype == pl.Datetime)
)
is_date = col_dtype == pl.Date
if is_datetime or is_date:
# For datetime/date columns, use .dt.date() directly
# This avoids schema validation issues
try:
if expression.value is None:
# No format - use .dt.date() for datetime/date columns
date_expr = pl.col(col_name).dt.date()
result = df.with_columns(
date_expr.alias(column_name)
)
return result
else:
# With format - still need to use map_elements for string parsing
# But try select to avoid validation
all_exprs = [pl.col(c) for c in df.columns] + [
expr.alias(column_name)
]
result = df.select(all_exprs)
return result
except Exception:
logger.debug(
"to_date .dt.date() path failed, falling back",
exc_info=True,
)
# For complex expressions or string columns, try using select to avoid validation
try:
# Use select to avoid schema validation issues
# This works for both StringType and TimestampType inputs
all_exprs = [pl.col(c) for c in df.columns] + [
expr.alias(column_name)
]
result = df.select(all_exprs)
return result
except Exception:
logger.debug(
"to_date select path failed, falling through to with_columns",
exc_info=True,
)
# Try to execute with Polars, but catch ColumnNotFoundError for Python evaluation fallback
try:
result = df.with_columns(expr.alias(column_name))
return result
except pl.exceptions.ColumnNotFoundError as e:
# Polars couldn't find a column - this might be a case sensitivity issue
# Fall back to Python evaluation
error_msg = str(e)
# Check if this looks like a case sensitivity issue
# (column name exists but with different case)
missing_col_match = None
for col in df.columns:
# Check if the missing column name matches any existing column case-insensitively
if col.lower() in error_msg.lower() or any(
missing.lower() == col.lower()
for missing in error_msg.split('"')
if missing and missing.strip()
):
missing_col_match = col
break
if missing_col_match:
# Column exists but with different case - use Python evaluation
data = df.to_dicts()
from sparkless.dataframe.evaluation.expression_evaluator import (
ExpressionEvaluator,
)
# Get dataframe context from expression if available (for case sensitivity)
dataframe_context = None
if hasattr(expression, "_dataframe_context"):
dataframe_context = expression._dataframe_context
elif hasattr(expression, "column") and hasattr(
expression.column, "_dataframe_context"
):
dataframe_context = expression.column._dataframe_context
evaluator = ExpressionEvaluator(dataframe_context=dataframe_context)
results = [
evaluator.evaluate_expression(row, expression) for row in data
]
try:
result_series = pl.Series(column_name, results, strict=False)
except TypeError:
try:
result_series = pl.Series(column_name, results)
except TypeError:
result_series = pl.Series(
column_name, results, dtype=pl.Object
)
result = df.with_columns(result_series)
return result
else:
# Re-raise if it's not a case sensitivity issue
raise
def _add_missing_join_key_columns(
self,
joined: pl.DataFrame,
df1: pl.DataFrame,
df2: pl.DataFrame,
resolved_left_on: List[str],
resolved_right_on: List[str],
polars_how: str,
) -> pl.DataFrame:
"""Add join key columns dropped by Polars to match PySpark semantics.
Polars drops right_on for left/inner join, drops left_on for right join.
PySpark keeps both. For unmatched rows, dropped keys should be null.
"""
right_non_key_cols = [c for c in df2.columns if c not in resolved_right_on]
left_non_key_cols = [c for c in df1.columns if c not in resolved_left_on]
# Match indicator: for left join use right-side cols; for right join use left-side
if polars_how == "left":
indicator_cols = [c for c in right_non_key_cols if c in joined.columns]
else:
indicator_cols = [c for c in left_non_key_cols if c in joined.columns]
has_match_indicator = bool(indicator_cols)
if has_match_indicator:
is_matched = pl.col(indicator_cols[0]).is_not_null()
for c in indicator_cols[1:]:
is_matched = is_matched | pl.col(c).is_not_null()
# Left/inner join: Polars drops right_on
if polars_how in ("left", "inner"):
for right_col in resolved_right_on:
if right_col not in joined.columns:
left_col = resolved_left_on[resolved_right_on.index(right_col)]
if has_match_indicator and polars_how == "left":
left_dtype = joined[left_col].dtype
joined = joined.with_columns(
pl.when(is_matched)
.then(pl.col(left_col))
.otherwise(pl.lit(None).cast(left_dtype))
.alias(right_col)
)
else:
# Inner: all matched, or no indicator
joined = joined.with_columns(pl.col(left_col).alias(right_col))
# Right join: Polars drops left_on
elif polars_how == "right":
for left_col in resolved_left_on:
if left_col not in joined.columns:
right_col = resolved_right_on[resolved_left_on.index(left_col)]
if has_match_indicator:
right_dtype = joined[right_col].dtype
joined = joined.with_columns(
pl.when(is_matched)
.then(pl.col(right_col))
.otherwise(pl.lit(None).cast(right_dtype))
.alias(left_col)
)
else:
joined = joined.with_columns(pl.col(right_col).alias(left_col))
return joined
def _coerce_join_key_types(
self,
df1: pl.DataFrame,
df2: pl.DataFrame,
join_keys: Optional[List[str]] = None,
left_on: Optional[List[str]] = None,
right_on: Optional[List[str]] = None,
) -> Tuple[
pl.DataFrame,
pl.DataFrame,
Optional[List[str]],
Optional[List[str]],
Optional[List[str]],
]:
"""Coerce join key types to match if needed (numeric vs string).
PySpark allows joining on columns with different types (e.g., i64 vs str),
automatically casting string keys to numeric. This method replicates that behavior.
Args:
df1: Left DataFrame
df2: Right DataFrame
join_keys: List of column names for on= joins (both DataFrames use same column names)
left_on: List of column names from df1 for left_on/right_on joins
right_on: List of column names from df2 for left_on/right_on joins
Returns:
Tuple of (casted_df1, casted_df2, updated_join_keys, updated_left_on, updated_right_on)
"""
# Define numeric types
numeric_types = (
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Float32,
pl.Float64,
)
string_type = pl.Utf8
cast_exprs_df1 = []
cast_exprs_df2 = []
result_df1 = df1
result_df2 = df2
# Handle on= joins (same column names in both DataFrames)
if join_keys is not None:
for key in join_keys:
if key not in df1.columns or key not in df2.columns:
continue # Skip if column doesn't exist (will error later)
dtype1 = df1[key].dtype
dtype2 = df2[key].dtype
# If types match, no coercion needed
if dtype1 == dtype2:
continue
# Check if one is numeric and one is string
is_numeric1 = dtype1 in numeric_types
is_numeric2 = dtype2 in numeric_types
is_string1 = dtype1 == string_type
is_string2 = dtype2 == string_type
if (is_numeric1 and is_string2) or (is_string1 and is_numeric2):
# One is numeric, one is string - cast string to numeric
if is_numeric1 and is_string2:
# df2[key] is string, cast to df1[key]'s numeric type
cast_exprs_df2.append(pl.col(key).cast(dtype1, strict=False))
else:
# df1[key] is string, cast to df2[key]'s numeric type
cast_exprs_df1.append(pl.col(key).cast(dtype2, strict=False))
elif is_numeric1 and is_numeric2:
# Both numeric but different types - prefer larger type
# Int8 < Int16 < Int32 < Int64, Float32 < Float64
# Prefer Int64 over Int32, Float64 over Float32
# Prefer Float over Int
if isinstance(dtype1, (pl.Float32, pl.Float64)) or isinstance(
dtype2, (pl.Float32, pl.Float64)
):
# At least one is float, use Float64
target_dtype = pl.Float64
else:
# Both integers, use Int64
target_dtype = pl.Int64
if dtype1 != target_dtype:
cast_exprs_df1.append(
pl.col(key).cast(target_dtype, strict=False)
)
if dtype2 != target_dtype:
cast_exprs_df2.append(
pl.col(key).cast(target_dtype, strict=False)
)
else:
# Types can't be coerced (e.g., boolean vs string, date vs string)
raise ValueError(
f"Cannot join on column '{key}' with incompatible types: "
f"left={dtype1}, right={dtype2}. "
f"PySpark only supports joining numeric types with string types."
)
# Handle left_on/right_on joins (different column names)
elif left_on is not None and right_on is not None:
if len(left_on) != len(right_on):
raise ValueError(
f"left_on and right_on must have the same length: "
f"left_on={left_on}, right_on={right_on}"
)
for i, (left_key, right_key) in enumerate(zip(left_on, right_on)):
if left_key not in df1.columns or right_key not in df2.columns:
continue # Skip if column doesn't exist (will error later)
dtype1 = df1[left_key].dtype
dtype2 = df2[right_key].dtype
# If types match, no coercion needed
if dtype1 == dtype2:
continue
# Check if one is numeric and one is string
is_numeric1 = dtype1 in numeric_types
is_numeric2 = dtype2 in numeric_types
is_string1 = dtype1 == string_type
is_string2 = dtype2 == string_type
if (is_numeric1 and is_string2) or (is_string1 and is_numeric2):
# One is numeric, one is string - cast string to numeric
if is_numeric1 and is_string2:
# df2[right_key] is string, cast to df1[left_key]'s numeric type
cast_exprs_df2.append(
pl.col(right_key).cast(dtype1, strict=False)
)
else:
# df1[left_key] is string, cast to df2[right_key]'s numeric type
cast_exprs_df1.append(
pl.col(left_key).cast(dtype2, strict=False)
)
elif is_numeric1 and is_numeric2:
# Both numeric but different types - prefer larger type
if isinstance(dtype1, (pl.Float32, pl.Float64)) or isinstance(
dtype2, (pl.Float32, pl.Float64)
):
target_dtype = pl.Float64
else:
target_dtype = pl.Int64
if dtype1 != target_dtype:
cast_exprs_df1.append(
pl.col(left_key).cast(target_dtype, strict=False)
)
if dtype2 != target_dtype:
cast_exprs_df2.append(
pl.col(right_key).cast(target_dtype, strict=False)
)
else:
# Types can't be coerced
raise ValueError(
f"Cannot join on columns '{left_key}' (left) and '{right_key}' (right) "
f"with incompatible types: left={dtype1}, right={dtype2}. "
f"PySpark only supports joining numeric types with string types."
)
# Apply casts if any
if cast_exprs_df1:
result_df1 = df1.with_columns(cast_exprs_df1)
if cast_exprs_df2:
result_df2 = df2.with_columns(cast_exprs_df2)
return result_df1, result_df2, join_keys, left_on, right_on
[docs]
@profiled("polars.apply_join", category="polars")
def apply_join(
self,
df1: pl.DataFrame,
df2: pl.DataFrame,
on: Optional[Union[str, List[str], ColumnOperation]] = None,
how: str = "inner",
right_alias: Optional[str] = None,
) -> pl.DataFrame:
"""Apply a join operation.
Args:
df1: Left DataFrame
df2: Right DataFrame
on: Join key(s) - column name(s), list of column names, or ColumnOperation with ==
how: Join type ("inner", "left", "right", "outer", "cross", "semi", "anti")
Returns:
Joined DataFrame
"""
# Extract column names from join condition if it's a ColumnOperation
join_keys: Optional[List[str]] = None
left_on: Optional[List[str]] = None
right_on: Optional[List[str]] = None
expression_condition: Optional[ColumnOperation] = None
if isinstance(on, ColumnOperation):
operation = getattr(on, "operation", None)
if operation in ("==", "eqNullSafe"):
# Equality-based join - extract column names
if not hasattr(on, "column") or not hasattr(on, "value"):
raise ValueError(
"Join condition must have column and value attributes"
)
left_col = (
on.column.name if hasattr(on.column, "name") else str(on.column)
)
right_col = (
on.value.name if hasattr(on.value, "name") else str(on.value)
)
left_col_str = str(left_col)
right_col_str = str(right_col)
case_sensitive_early = self._get_case_sensitive()
col_in_df1_left = self._find_column(
df1, left_col_str, case_sensitive_early
)
col_in_df2_left = self._find_column(
df2, left_col_str, case_sensitive_early
)
col_in_df1_right = self._find_column(
df1, right_col_str, case_sensitive_early
)
col_in_df2_right = self._find_column(
df2, right_col_str, case_sensitive_early
)
# Same column name - check if it exists in both DataFrames
if left_col_str == right_col_str:
if col_in_df1_left and col_in_df2_left:
join_keys = [col_in_df1_left]
else:
left_on = [col_in_df1_left or left_col_str]
right_on = [col_in_df2_left or right_col_str]
else:
# Different column names - assign each to the DataFrame that has it (Issue #421)
# F.col("Key") == F.col("Name"): Key may be in df2, Name in df1
if col_in_df1_left and col_in_df2_right:
left_on = [col_in_df1_left]
right_on = [col_in_df2_right]
elif col_in_df2_left and col_in_df1_right:
left_on = [col_in_df1_right]
right_on = [col_in_df2_left]
else:
# Fallback: assume left side of == is from df1, right from df2
left_on = [left_col_str]
right_on = [right_col_str]
else:
# Expression-based join (e.g., array_contains, other expressions)
# Store the expression for evaluation after cross join
expression_condition = on
elif on is None:
common_cols = set(df1.columns) & set(df2.columns)
if not common_cols:
raise ValueError("No common columns found for join")
join_keys = list(common_cols)
elif isinstance(on, str):
join_keys = [on]
elif isinstance(on, list):
join_keys = list(on)
else:
raise ValueError("Join keys must be column name(s) or a ColumnOperation")
# Map join types (Polars 0.20.29+ uses "full" instead of deprecated "outer")
join_type_map = {
"inner": "inner",
"left": "left",
"right": "right",
"outer": "full",
"full": "full",
"full_outer": "full",
"cross": "cross",
}
polars_how = join_type_map.get(how.lower(), "inner")
# Handle expression-based joins (e.g., array_contains)
if expression_condition is not None:
return self._apply_expression_join(
df1, df2, expression_condition, polars_how, how.lower(), right_alias
)
# Resolve join_keys using ColumnResolver if they are strings
case_sensitive = self._get_case_sensitive()
resolved_join_keys = None
left_on_keys = []
right_on_keys = []
use_left_right_on = False
if join_keys is not None:
resolved_join_keys = []
for col in join_keys:
if isinstance(col, str):
# Resolve column name using ColumnResolver in both DataFrames
actual_col_df1 = self._find_column(df1, col, case_sensitive)
actual_col_df2 = self._find_column(df2, col, case_sensitive)
if actual_col_df1 is None or actual_col_df2 is None:
# Will be handled in the else branch below
resolved_join_keys.append(col)
elif actual_col_df1 != actual_col_df2:
# Column names differ - need to use left_on/right_on
use_left_right_on = True
left_on_keys.append(actual_col_df1)
right_on_keys.append(actual_col_df2)
# Don't add to resolved_join_keys when using left_on/right_on
else:
# Same column name, can use on=
resolved_join_keys.append(actual_col_df1)
# Note: join_keys is List[str], so all items are strings
# Non-string items would be handled above
# Coerce join key types if needed (e.g., numeric vs string)
# PySpark allows joining on columns with different types, casting string to numeric
# Coerce types after resolving column names (case-insensitively)
if use_left_right_on and left_on_keys and right_on_keys:
# Use left_on/right_on for coercion
df1, df2, _, coerced_left_on, coerced_right_on = (
self._coerce_join_key_types(df1, df2, None, left_on_keys, right_on_keys)
)
left_on_keys = (
coerced_left_on if coerced_left_on is not None else left_on_keys
)
right_on_keys = (
coerced_right_on if coerced_right_on is not None else right_on_keys
)
elif resolved_join_keys:
# Use resolved_join_keys for coercion
df1, df2, coerced_join_keys, _, _ = self._coerce_join_key_types(
df1, df2, resolved_join_keys, None, None
)
resolved_join_keys = (
coerced_join_keys
if coerced_join_keys is not None
else resolved_join_keys
)
elif join_keys:
# Fallback to original join_keys
df1, df2, coerced_join_keys, _, _ = self._coerce_join_key_types(
df1, df2, join_keys, None, None
)
if coerced_join_keys is not None:
resolved_join_keys = coerced_join_keys
# Handle semi and anti joins (Polars doesn't support natively)
if how.lower() in ("semi", "left_semi", "leftsemi"):
# Semi join: return rows from left where match exists in right
# Do inner join, then select only left columns and distinct
if use_left_right_on and left_on_keys and right_on_keys:
joined = df1.join(
df2, left_on=left_on_keys, right_on=right_on_keys, how="inner"
)
elif resolved_join_keys:
joined = df1.join(df2, on=resolved_join_keys, how="inner")
else:
joined = df1.join(df2, on=join_keys, how="inner")
# Select only columns from df1 (preserve original column order)
left_cols = [col for col in df1.columns if col in joined.columns]
return joined.select(left_cols).unique()
elif how.lower() in ("anti", "left_anti", "leftanti"):
# Anti join: return rows from left where no match exists in right
# Do left join, then filter where right columns are null
if use_left_right_on and left_on_keys and right_on_keys:
joined = df1.join(
df2, left_on=left_on_keys, right_on=right_on_keys, how="left"
)
elif resolved_join_keys:
joined = df1.join(df2, on=resolved_join_keys, how="left")
else:
joined = df1.join(df2, on=join_keys, how="left")
# Find right-side columns (columns in df2 but not in df1)
right_cols = [col for col in df2.columns if col not in df1.columns]
if right_cols:
# Filter where any right column is null
filter_expr = pl.col(right_cols[0]).is_null()
for col in right_cols[1:]:
filter_expr = filter_expr | pl.col(col).is_null()
joined = joined.filter(filter_expr)
else:
# If no right columns (all match left), check if join key exists
# This case shouldn't happen, but handle it
keys_to_use = resolved_join_keys if resolved_join_keys else join_keys
if keys_to_use is not None and len(keys_to_use) > 0:
joined = joined.filter(pl.col(keys_to_use[0]).is_null())
# Select only columns from df1
left_cols = [col for col in joined.columns if col in df1.columns]
return joined.select(left_cols)
elif polars_how == "cross":
return df1.join(df2, how="cross")
else:
# Handle different column names with left_on/right_on
if left_on is not None and right_on is not None:
# Validate right_on is not empty
if not right_on:
raise ValueError(
f"Join right_on is empty. left_on={left_on}, right_on={right_on}, "
f"df1.columns={df1.columns}, df2.columns={df2.columns}"
)
# Verify columns exist using ColumnResolver
resolved_left_on = []
for col in left_on:
actual_col = self._find_column(df1, col, case_sensitive)
if actual_col is None:
raise ValueError(
f"Join column '{col}' not found in left DataFrame. Available columns: {df1.columns}"
)
resolved_left_on.append(actual_col)
resolved_right_on = []
for col in right_on:
actual_col = self._find_column(df2, col, case_sensitive)
if actual_col is None:
raise ValueError(
f"Join column '{col}' not found in right DataFrame. Available columns: {df2.columns}"
)
resolved_right_on.append(actual_col)
# Coerce join key types if needed (for left_on/right_on case)
df1, df2, _, coerced_left_on, coerced_right_on = (
self._coerce_join_key_types(
df1, df2, None, resolved_left_on, resolved_right_on
)
)
if coerced_left_on is not None:
resolved_left_on = coerced_left_on
if coerced_right_on is not None:
resolved_right_on = coerced_right_on
# Polars join with left_on/right_on doesn't include right_on column
# But PySpark includes both columns, so we need to add it back
joined = df1.join(
df2,
left_on=resolved_left_on,
right_on=resolved_right_on,
how=polars_how,
)
joined = self._add_missing_join_key_columns(
joined, df1, df2, resolved_left_on, resolved_right_on, polars_how
)
return joined
else:
# Verify columns exist in both DataFrames using ColumnResolver
# Check if we already resolved with left_on/right_on in the earlier block
if use_left_right_on and left_on_keys and right_on_keys:
# Use left_on/right_on when column names differ (from earlier resolution)
joined = df1.join(
df2,
left_on=left_on_keys,
right_on=right_on_keys,
how=polars_how,
)
joined = self._add_missing_join_key_columns(
joined, df1, df2, left_on_keys, right_on_keys, polars_how
)
return joined
elif resolved_join_keys and len(resolved_join_keys) > 0:
# Use resolved_join_keys from earlier resolution
return df1.join(df2, on=resolved_join_keys, how=polars_how)
elif resolved_join_keys is None:
# Need to resolve join_keys
# If resolved_join_keys is None, join_keys must have been None
# (since we set resolved_join_keys = [] when join_keys is not None)
raise ValueError("Join keys must be specified")
else:
# resolved_join_keys is empty list, fallback to original join_keys
return df1.join(df2, on=join_keys, how=polars_how)
def _apply_expression_join(
self,
df1: pl.DataFrame,
df2: pl.DataFrame,
condition: ColumnOperation,
polars_how: str,
how: str,
right_alias: Optional[str] = None,
) -> pl.DataFrame:
"""Apply join with expression-based condition (e.g., array_contains).
For expression-based joins, we:
1. Do a cross join to get all row combinations
2. Evaluate the expression condition for each row pair
3. Filter based on the expression result
4. Handle join types appropriately
Args:
df1: Left DataFrame
df2: Right DataFrame
condition: ColumnOperation expression to evaluate (e.g., array_contains)
polars_how: Polars join type string
how: Original join type string
Returns:
Joined DataFrame
"""
# Compound condition (A == B) & (C > D): join on equality then filter (#380)
if isinstance(condition, ColumnOperation) and condition.operation == "&":
eq_part = None
filter_part = None
for attr in ("column", "value"):
part = getattr(condition, attr, None)
if (
isinstance(part, ColumnOperation)
and getattr(part, "operation", None) == "=="
):
eq_part = part
filter_part = getattr(
condition, "value" if attr == "column" else "column"
)
break
if eq_part is not None and filter_part is not None:
# Resolve equality columns: left_col from df1, right_col from df2 (prefixed)
left_col = getattr(eq_part.column, "name", None) or str(eq_part.column)
right_col = getattr(eq_part.value, "name", None) or str(eq_part.value)
if (
isinstance(left_col, str)
and isinstance(right_col, str)
and "." in left_col
and "." in right_col
):
left_base = left_col.split(".", 1)[1]
right_base = right_col.split(".", 1)[1]
df2_prefixed = df2.select(
[pl.col(c).alias(f"_right_{c}") for c in df2.columns]
)
right_join_col = (
f"_right_{right_base}"
if right_base in df2.columns
else right_base
)
if (
left_base in df1.columns
and right_join_col in df2_prefixed.columns
):
joined = df1.join(
df2_prefixed,
left_on=[left_base],
right_on=[right_join_col],
how=polars_how,
)
# Apply filter on joined result
case_sensitive = self._get_case_sensitive()
filter_expr = self.translator.translate(
filter_part,
available_columns=joined.columns,
case_sensitive=case_sensitive,
)
return joined.filter(filter_expr)
# Step 1: Do cross join to get all row combinations
# Check for column name conflicts and prefix df2 columns if needed
df1_cols = set(df1.columns)
df2_cols = set(df2.columns)
has_conflicts = bool(df1_cols & df2_cols)
if has_conflicts:
# Prefix right DataFrame columns to avoid conflicts
df2_prefixed = df2.select(
[pl.col(col).alias(f"_right_{col}") for col in df2.columns]
)
cross_joined = df1.join(df2_prefixed, how="cross")
available_columns = list(df1.columns) + list(df2_prefixed.columns)
else:
# No conflicts - can use original column names
cross_joined = df1.join(df2, how="cross")
available_columns = list(df1.columns) + list(df2.columns)
# Step 2: Translate and evaluate the condition expression
case_sensitive = self._get_case_sensitive()
# If we have conflicts but no right_alias, infer from condition.
# Prefer the alias on the *value* side of equality (e.g. a.id == b.id -> b is right).
if has_conflicts and not right_alias:
def _infer_right_alias(expr: Any) -> Optional[str]:
from sparkless.functions.core.column import Column
if isinstance(expr, Column) and "." in getattr(expr, "name", ""):
alias_part, col_part = expr.name.split(".", 1)
if col_part in df2.columns:
return alias_part
elif isinstance(expr, ColumnOperation):
op = getattr(expr, "operation", None)
if op == "==":
# In left == right, right is typically .value; infer from value only
return _infer_right_alias(expr.value)
a = _infer_right_alias(expr.column)
if a is not None:
return a
return _infer_right_alias(expr.value)
return None
right_alias = _infer_right_alias(condition)
# If there are column name conflicts, we need to map df2 column references
# to prefixed names. Recursively rewrite Column("right_alias.col") -> Column("_right_col")
def _rewrite_right_cols(expr: Any) -> Any:
"""Replace Column('alias.col') with Column('_right_col') when alias is right."""
from sparkless.functions.core.column import Column
if isinstance(expr, Column) and "." in getattr(expr, "name", ""):
name = expr.name
if right_alias and name.startswith(f"{right_alias}."):
col = name.split(".", 1)[1]
if col in df2.columns:
return Column(f"_right_{col}")
elif isinstance(expr, ColumnOperation):
col_rewritten = _rewrite_right_cols(expr.column)
val_rewritten = _rewrite_right_cols(expr.value)
if col_rewritten is not expr.column or val_rewritten is not expr.value:
return ColumnOperation(
col_rewritten, expr.operation or "", val_rewritten
)
return expr
condition_to_translate = (
_rewrite_right_cols(condition)
if (has_conflicts and right_alias)
else condition
)
if has_conflicts and isinstance(condition_to_translate, ColumnOperation):
# For array_contains(df1.IDs, df2.ID), we need to replace df2.ID with _right_ID
from sparkless.functions.core.column import Column
if condition_to_translate.operation == "array_contains":
# array_contains(column, value)
# column is from df1 (keep as is), value might be from df2 (needs prefix)
original_column = condition_to_translate.column
original_value = condition_to_translate.value
# Check if value is a Column from df2 that needs prefixing
modified_value = original_value
if isinstance(original_value, Column):
col_name = original_value.name
if col_name in df2.columns:
# This is a df2 column - use prefixed name
prefixed_name = f"_right_{col_name}"
modified_value = Column(prefixed_name)
elif (
isinstance(original_value, ColumnOperation)
and hasattr(original_value, "column")
and isinstance(original_value.column, Column)
):
# Nested ColumnOperation - check if column needs prefixing
# operation is always a string for ColumnOperation instances
op = getattr(original_value, "operation", None)
if op is not None:
col_name = original_value.column.name
if col_name in df2.columns:
prefixed_name = f"_right_{col_name}"
# Create new ColumnOperation with prefixed column
modified_value = ColumnOperation(
Column(prefixed_name),
op,
original_value.value,
)
if modified_value != original_value:
# Create new condition with modified value
op = getattr(condition_to_translate, "operation", None)
if op is not None:
condition_to_translate = ColumnOperation(
original_column,
op,
modified_value,
)
try:
# Try to translate the condition
filter_expr = self.translator.translate(
condition_to_translate,
available_columns=available_columns,
case_sensitive=case_sensitive,
)
# Filter based on expression result
filtered = cross_joined.filter(filter_expr)
# Handle join types
if how in ("left", "outer", "full", "full_outer"):
# For left/outer joins, include unmatched rows from left DataFrame
if filtered.height == 0:
# No matches - return all left rows with null right columns
result_df = df1
for col in df2.columns:
prefixed_col = f"_right_{col}" if has_conflicts else col
result_df = result_df.with_columns(
pl.lit(None).alias(prefixed_col)
)
# Remove prefix if needed
if has_conflicts:
rename_dict = {f"_right_{col}": col for col in df2.columns}
result_df = result_df.rename(rename_dict)
return result_df
elif len(df1.columns) > 0:
# Get matched left rows using all columns as composite key
# Create a hash or use all columns to identify matched rows
matched_left_cols = df1.columns
matched_left = filtered.select(matched_left_cols).unique()
# Get all left rows
all_left = df1.select(matched_left_cols)
# Find unmatched left rows using anti join on all columns
unmatched_left = all_left.join(
matched_left, on=matched_left_cols, how="anti"
)
if unmatched_left.height > 0:
# Add unmatched left rows with null right columns
unmatched_with_nulls = unmatched_left
for col in df2.columns:
prefixed_col = f"_right_{col}" if has_conflicts else col
unmatched_with_nulls = unmatched_with_nulls.with_columns(
pl.lit(None).alias(prefixed_col)
)
# Combine matched and unmatched rows
filtered = pl.concat([filtered, unmatched_with_nulls])
if how in ("outer", "full", "full_outer"):
# For outer joins, also include unmatched right rows
# This is more complex and may not be needed for basic cases
pass
except (ValueError, AttributeError, TypeError):
# If direct translation fails, use Python fallback
from sparkless.dataframe.evaluation.expression_evaluator import (
ExpressionEvaluator,
)
evaluator = ExpressionEvaluator()
rows = cross_joined.to_dicts()
# Evaluate condition for each row
filtered_rows = []
for row in rows:
# Create evaluation row - map prefixed columns back if needed
eval_row = row.copy()
if has_conflicts:
# Map prefixed columns back to original names for evaluation
for col in df2.columns:
prefixed_col = f"_right_{col}"
if prefixed_col in eval_row:
eval_row[col] = eval_row[prefixed_col]
try:
result = evaluator.evaluate_expression(eval_row, condition)
if result is True:
filtered_rows.append(row)
except Exception:
logger.debug(
"Row evaluation failed for join condition, skipping row",
exc_info=True,
)
if not filtered_rows:
# No matches
if how in ("left", "outer", "full", "full_outer"):
# For left/outer joins, return all left rows with null right columns
result_df = df1
for col in df2.columns:
result_df = result_df.with_columns(pl.lit(None).alias(col))
return result_df
else:
# For inner/right joins with no matches, return empty with correct schema
empty_schema = {col: df1[col].dtype for col in df1.columns}
for col in df2.columns:
empty_schema[col] = df2[col].dtype
return pl.DataFrame(schema=empty_schema)
# Convert back to Polars DataFrame
filtered = pl.DataFrame(filtered_rows)
# Step 3: Handle join types
# For now, we return the filtered result (inner join behavior)
# Full left/outer join support would require tracking unmatched rows
# This is a simplified implementation that works for inner joins
# Step 4: Remove prefix from column names if we prefixed df2 columns
# Only rename if the original column name doesn't already exist (from df1)
if has_conflicts:
rename_dict = {}
for col in df2.columns:
prefixed_col = f"_right_{col}"
# Only rename if the original name doesn't exist (to avoid duplicates)
# If it exists, keep the prefixed name (PySpark keeps both with same name)
if prefixed_col in filtered.columns and col not in filtered.columns:
rename_dict[prefixed_col] = col
if rename_dict:
filtered = filtered.rename(rename_dict)
return filtered
def _coerce_union_types(
self, df1: pl.DataFrame, df2: pl.DataFrame
) -> Tuple[pl.DataFrame, pl.DataFrame]:
"""Coerce union column types to match if needed (numeric vs string).
PySpark allows unioning DataFrames with different types (e.g., i64 vs str),
automatically normalizing to string when mixing numeric and string types.
Args:
df1: First DataFrame
df2: Second DataFrame
Returns:
Tuple of (coerced_df1, coerced_df2)
"""
# Define numeric types
numeric_types = (
pl.Int8,
pl.Int16,
pl.Int32,
pl.Int64,
pl.UInt8,
pl.UInt16,
pl.UInt32,
pl.UInt64,
pl.Float32,
pl.Float64,
)
string_type = pl.Utf8
cast_exprs_df1 = []
cast_exprs_df2 = []
result_df1 = df1
result_df2 = df2
# Ensure both DataFrames have the same columns
df1_cols = set(df1.columns)
df2_cols = set(df2.columns)
all_cols = sorted(df1_cols | df2_cols)
# Process each column that exists in both DataFrames
for col in all_cols:
if col not in df1.columns or col not in df2.columns:
continue
dtype1 = df1[col].dtype
dtype2 = df2[col].dtype
# If types match, no coercion needed
if dtype1 == dtype2:
continue
# Check if one is numeric and one is string
is_numeric1 = dtype1 in numeric_types
is_numeric2 = dtype2 in numeric_types
is_string1 = dtype1 == string_type
is_string2 = dtype2 == string_type
if (is_numeric1 and is_string2) or (is_string1 and is_numeric2):
# One is numeric, one is string - normalize to string (PySpark behavior)
# Issue #242: LongType + StringType -> StringType
if is_numeric1 and is_string2:
# df1[col] is numeric, cast to string
cast_exprs_df1.append(pl.col(col).cast(string_type, strict=False))
else:
# df2[col] is numeric, cast to string
cast_exprs_df2.append(pl.col(col).cast(string_type, strict=False))
elif is_numeric1 and is_numeric2:
# Both numeric but different types - promote to wider type
if isinstance(dtype1, (pl.Float32, pl.Float64)) or isinstance(
dtype2, (pl.Float32, pl.Float64)
):
target_dtype = pl.Float64
else:
target_dtype = pl.Int64
if dtype1 != target_dtype:
cast_exprs_df1.append(pl.col(col).cast(target_dtype, strict=False))
if dtype2 != target_dtype:
cast_exprs_df2.append(pl.col(col).cast(target_dtype, strict=False))
# For other type combinations, don't coerce (will be handled by validation)
# Apply casts if any
if cast_exprs_df1:
result_df1 = df1.with_columns(cast_exprs_df1)
if cast_exprs_df2:
result_df2 = df2.with_columns(cast_exprs_df2)
return result_df1, result_df2
[docs]
def apply_union(self, df1: pl.DataFrame, df2: pl.DataFrame) -> pl.DataFrame:
"""Apply a union operation.
Args:
df1: First DataFrame
df2: Second DataFrame
Returns:
Unioned DataFrame
"""
# Coerce types before union (PySpark behavior: normalize numeric+string to string)
df1, df2 = self._coerce_union_types(df1, df2)
# Ensure schemas match
df1_cols = set(df1.columns)
df2_cols = set(df2.columns)
# Add missing columns with correct types
for col in df1_cols - df2_cols:
# Use the type from df1's column
col_type = df1[col].dtype
df2 = df2.with_columns(pl.lit(None, dtype=col_type).alias(col))
for col in df2_cols - df1_cols:
# Use the type from df2's column
col_type = df2[col].dtype
df1 = df1.with_columns(pl.lit(None, dtype=col_type).alias(col))
# Ensure column order matches
column_order = df1.columns
df2 = df2.select(column_order)
return pl.concat([df1, df2])
[docs]
def apply_order_by(
self, df: pl.DataFrame, columns: List[Any], ascending: bool = True
) -> pl.DataFrame:
"""Apply an orderBy operation.
Args:
df: Source Polars DataFrame
columns: Columns to sort by
ascending: Sort direction
Returns:
Sorted DataFrame
"""
sort_by = []
descending_flags = []
nulls_last_flags = []
for col in columns:
is_desc = False
nulls_last = None # None means default behavior
col_name = None
if isinstance(col, str):
col_name = col
is_desc = not ascending
nulls_last = True # PySpark default: nulls last
elif hasattr(col, "operation"):
operation = col.operation
col_name = col.column.name if hasattr(col, "column") else col.name
# Handle nulls variant operations
if operation == "desc_nulls_last":
is_desc = True
nulls_last = True
elif operation == "desc_nulls_first":
is_desc = True
nulls_last = False
elif operation == "asc_nulls_last":
is_desc = False
nulls_last = True
elif operation == "asc_nulls_first":
is_desc = False
nulls_last = False
elif operation == "desc":
is_desc = True
nulls_last = True # PySpark default: nulls last for desc()
elif operation == "asc":
is_desc = False
nulls_last = True # PySpark default: nulls last for asc()
else:
# Fallback for other operations
is_desc = not ascending
nulls_last = True # PySpark default: nulls last
else:
col_name = col.name if hasattr(col, "name") else str(col)
is_desc = not ascending
nulls_last = True # PySpark default: nulls last
if col_name:
sort_by.append(col_name)
descending_flags.append(is_desc)
nulls_last_flags.append(nulls_last)
if not sort_by:
return df
# Use sort() with by, descending, and nulls_last parameters
has_nulls_specification = any(n is not None for n in nulls_last_flags)
if has_nulls_specification:
return df.sort(
sort_by, descending=descending_flags, nulls_last=nulls_last_flags
)
else:
# No nulls specification, use default
return df.sort(sort_by, descending=descending_flags)
[docs]
def apply_limit(self, df: pl.DataFrame, n: int) -> pl.DataFrame:
"""Apply a limit operation.
Args:
df: Source Polars DataFrame
n: Number of rows to return
Returns:
Limited DataFrame
"""
return df.head(n)
[docs]
def apply_offset(self, df: pl.DataFrame, n: int) -> pl.DataFrame:
"""Apply an offset operation (skip first n rows).
Args:
df: Source Polars DataFrame
n: Number of rows to skip
Returns:
DataFrame with first n rows skipped
"""
return df.slice(n)
[docs]
@profiled("polars.apply_group_by_agg", category="polars")
def apply_group_by_agg(
self, df: pl.DataFrame, group_by: List[Any], aggs: List[Any]
) -> pl.DataFrame:
"""Apply a groupBy().agg() operation.
Args:
df: Source Polars DataFrame
group_by: Columns to group by
aggs: Aggregation expressions
Returns:
Aggregated DataFrame
"""
# Translate group by columns
group_by_cols = []
for col in group_by:
if isinstance(col, str):
group_by_cols.append(col)
elif hasattr(col, "name"):
group_by_cols.append(col.name)
else:
raise ValueError(f"Cannot determine column name for group by: {col}")
# Translate aggregation expressions
agg_exprs = []
for agg in aggs:
expr = self.translator.translate(agg)
# Get alias if available
alias_name = getattr(agg, "name", None) or getattr(agg, "_alias_name", None)
if alias_name:
expr = expr.alias(alias_name)
agg_exprs.append(expr)
if not group_by_cols:
# Global aggregation
return df.select(agg_exprs)
else:
return df.group_by(group_by_cols).agg(agg_exprs)
[docs]
def apply_distinct(self, df: pl.DataFrame) -> pl.DataFrame:
"""Apply a distinct operation.
Args:
df: Source Polars DataFrame
Returns:
DataFrame with distinct rows
"""
return df.unique()
[docs]
def apply_drop(self, df: pl.DataFrame, columns: List[str]) -> pl.DataFrame:
"""Apply a drop operation.
Args:
df: Source Polars DataFrame
columns: Columns to drop
Returns:
DataFrame with columns dropped
"""
return df.drop(columns)
[docs]
def apply_with_column_renamed(
self, df: pl.DataFrame, old_name: str, new_name: str
) -> pl.DataFrame:
"""Apply a withColumnRenamed operation.
Args:
df: Source Polars DataFrame
old_name: Old column name
new_name: New column name
Returns:
DataFrame with renamed column
"""
return df.rename({old_name: new_name})