Source code for sparkless.functions.base

"""
Base function classes for Sparkless.

This module provides base classes for all function types.
Most classes are imported from core/ modules to avoid duplication.
"""

from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union
from sparkless.spark_types import DataType, StringType, get_row_value

# Import core classes from their canonical locations
from .core.column import Column, ColumnOperation
from .core.literals import Literal
from .core.lambda_parser import (
    MockLambdaExpression,
    LambdaParser,
    LambdaTranslationError,
)

if TYPE_CHECKING:
    from .window_execution import WindowFunction

# Re-export for backward compatibility
__all__ = [
    "Column",
    "ColumnOperation",
    "Literal",
    "AggregateFunction",
    "MockLambdaExpression",
    "LambdaParser",
    "LambdaTranslationError",
]


[docs] class AggregateFunction: """Base class for aggregate functions. This class provides the base functionality for all aggregate functions including count, sum, avg, max, min, etc. """
[docs] def __init__( self, column: Union[Column, ColumnOperation, str, None], function_name: str, data_type: Optional[DataType] = None, ignorenulls: Optional[bool] = None, ): """Initialize AggregateFunction. Args: column: The column to aggregate (None for count(*)). function_name: Name of the aggregate function. data_type: Optional return data type. ignorenulls: Optional flag to ignore nulls (for first/last functions). """ self.column = column self.function_name = function_name self.data_type = self._configure_data_type(data_type) self.name = self._generate_name() # Optional attributes for specific functions self.ord_column: Optional[Union[Column, str]] = None # For max_by, min_by self.ignorenulls: Optional[bool] = ignorenulls # For first/last functions self.rsd: Optional[float] = ( None # For approx_count_distinct (relative standard deviation) ) self.percentage: Optional[float] = None # For percentile function
def _configure_data_type(self, data_type: Optional[DataType]) -> DataType: """Configure data type with appropriate nullability based on function type.""" if not data_type: return StringType() # Functions that always return non-nullable results in PySpark non_nullable_functions = { "count", "countDistinct", "row_number", "rank", "dense_rank", "isNull", "isnan", "coalesce", } if self.function_name in non_nullable_functions: data_type.nullable = False return data_type @property def column_name(self) -> str: """Get the column name for compatibility.""" if self.column is None: return "*" elif isinstance(self.column, str): return self.column else: return str(self.column.name) def _generate_name(self) -> str: """Generate a name for this aggregate function.""" # PySpark uses 'avg' as column name even when using mean() # PySpark uses 'stddev_samp' as column name for stddev() # PySpark uses 'var_samp' as column name for variance() if self.function_name == "mean": display_name = "avg" elif self.function_name == "stddev": display_name = "stddev_samp" elif self.function_name == "variance": display_name = "var_samp" else: display_name = self.function_name if self.column is None: # For count(*), PySpark agg(F.count("*")) uses "count(1)". # GroupedData.count() shorthand uses alias("count") for cleaner output. if self.function_name == "count": return "count(1)" else: return f"{display_name}(*)" elif isinstance(self.column, str): # For count("*"), match PySpark agg: use "count(1)". if self.function_name == "count" and self.column == "*": return "count(1)" elif self.function_name == "countDistinct": # PySpark uses "count(column)" not "count(DISTINCT column)" for column names return f"count({self.column})" else: return f"{display_name}({self.column})" else: if self.function_name == "countDistinct": # PySpark uses "count(column)" not "count(DISTINCT column)" for column names return f"count({self.column.name})" elif self.function_name == "approx_count_distinct": # PySpark doesn't include rsd in column name, just use the base name return f"{display_name}({self.column.name})" else: return f"{display_name}({self.column.name})"
[docs] def evaluate(self, data: List[Dict[str, Any]]) -> Any: """Evaluate the aggregate function on the given data. Args: data: List of data rows to aggregate. Returns: The aggregated result. """ if self.function_name == "count": return self._evaluate_count(data) elif self.function_name == "sum": return self._evaluate_sum(data) elif self.function_name == "avg": return self._evaluate_avg(data) elif self.function_name == "max": return self._evaluate_max(data) elif self.function_name == "min": return self._evaluate_min(data) else: return None
def _evaluate_count(self, data: List[Dict[str, Any]]) -> int: """Evaluate count function.""" if self.column is None: return len(data) else: column_name = ( self.column if isinstance(self.column, str) else self.column.name ) return sum(1 for row in data if get_row_value(row, column_name) is not None) def _evaluate_sum(self, data: List[Dict[str, Any]]) -> Any: """Evaluate sum function.""" if self.column is None: return 0 column_name = self.column if isinstance(self.column, str) else self.column.name total = 0 for row in data: value = get_row_value(row, column_name) if value is not None: total += value return total def _evaluate_avg(self, data: List[Dict[str, Any]]) -> Any: """Evaluate average function.""" if self.column is None: return 0.0 column_name = self.column if isinstance(self.column, str) else self.column.name values = [ get_row_value(row, column_name) for row in data if get_row_value(row, column_name) is not None ] numeric_values = [v for v in values if isinstance(v, (int, float))] if numeric_values: return sum(numeric_values) / len(numeric_values) else: return None def _evaluate_max(self, data: List[Dict[str, Any]]) -> Any: """Evaluate max function.""" if self.column is None: return None column_name = self.column if isinstance(self.column, str) else self.column.name values = [ get_row_value(row, column_name) for row in data if get_row_value(row, column_name) is not None ] if values: return max(values) else: return None def _evaluate_min(self, data: List[Dict[str, Any]]) -> Any: """Evaluate min function.""" if self.column is None: return None column_name = self.column if isinstance(self.column, str) else self.column.name values = [ get_row_value(row, column_name) for row in data if get_row_value(row, column_name) is not None ] if values: return min(values) else: return None
[docs] def over(self, window_spec: Any) -> "WindowFunction": """Apply window function over window specification.""" from .window_execution import WindowFunction return WindowFunction(self, window_spec)
[docs] def alias(self, name: str) -> "AggregateFunction": """Create an alias for this aggregate function. Args: name: The alias name. Returns: Self for method chaining. """ self.name = name return self
[docs] def cast(self, data_type: Union[DataType, str]) -> "ColumnOperation": """Cast the aggregate function result to a different data type. Args: data_type: The target data type (DataType instance or string type name). Returns: ColumnOperation representing the cast operation. Example: >>> F.mean(F.col("value")).cast("string") """ return ColumnOperation(self, "cast", data_type)
def _create_operation(self, operation: str, other: Any) -> "ColumnOperation": """Create a ColumnOperation with the given operation and other operand. Args: operation: The operation to perform (e.g., "+", "-", etc.) other: The other operand Returns: ColumnOperation instance """ return ColumnOperation(self, operation, other)
[docs] def __add__(self, other: Any) -> "ColumnOperation": """Addition operation (PySpark-compatible).""" return self._create_operation("+", other)
[docs] def __sub__(self, other: Any) -> "ColumnOperation": """Subtraction operation (PySpark-compatible).""" return self._create_operation("-", other)
[docs] def __mul__(self, other: Any) -> "ColumnOperation": """Multiplication operation (PySpark-compatible).""" return self._create_operation("*", other)
[docs] def __truediv__(self, other: Any) -> "ColumnOperation": """Division operation (PySpark-compatible).""" return self._create_operation("/", other)
[docs] def __mod__(self, other: Any) -> "ColumnOperation": """Modulo operation (PySpark-compatible).""" return self._create_operation("%", other)
[docs] def __radd__(self, other: Any) -> "ColumnOperation": """Reverse addition operation (for `2 + agg_func`).""" # For commutative operations, we can just swap operands return self._create_operation("+", other)
[docs] def __rsub__(self, other: Any) -> "ColumnOperation": """Reverse subtraction operation (for `2 - agg_func`).""" # For non-commutative operations, create ColumnOperation with literal as left operand return ColumnOperation(other, "-", self)
[docs] def __rmul__(self, other: Any) -> "ColumnOperation": """Reverse multiplication operation (for `2 * agg_func`).""" # For commutative operations, we can just swap operands return self._create_operation("*", other)
[docs] def __rtruediv__(self, other: Any) -> "ColumnOperation": """Reverse division operation (for `2 / agg_func`).""" # For non-commutative operations, create ColumnOperation with literal as left operand return ColumnOperation(other, "/", self)
[docs] def __rmod__(self, other: Any) -> "ColumnOperation": """Reverse modulo operation (for `2 % agg_func`).""" # For non-commutative operations, create ColumnOperation with literal as left operand return ColumnOperation(other, "%", self)