187 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
			
		
		
	
	
			187 lines
		
	
	
		
			7.0 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
| from __future__ import annotations
 | |
| 
 | |
| from typing import TYPE_CHECKING, Any, Generic, TypeVar
 | |
| 
 | |
| from narwhals._expression_parsing import all_exprs_are_scalar_like
 | |
| from narwhals._utils import flatten, tupleify
 | |
| from narwhals.exceptions import InvalidOperationError
 | |
| from narwhals.typing import DataFrameT
 | |
| 
 | |
| if TYPE_CHECKING:
 | |
|     from collections.abc import Iterable, Iterator, Sequence
 | |
| 
 | |
|     from narwhals._compliant.typing import CompliantExprAny
 | |
|     from narwhals.dataframe import LazyFrame
 | |
|     from narwhals.expr import Expr
 | |
| 
 | |
# Type variable for the frame wrapped by ``LazyGroupBy``; bound to narwhals'
# ``LazyFrame`` (given as a string so the TYPE_CHECKING-only import suffices).
LazyFrameT = TypeVar("LazyFrameT", bound="LazyFrame[Any]")
 | |
| 
 | |
| 
 | |
class GroupBy(Generic[DataFrameT]):
    """Group-by handle for an eager narwhals frame (as returned by ``df.group_by``).

    Grouping itself is delegated to the backend's compliant frame; results are
    re-wrapped into narwhals frames via the parent frame's ``_with_compliant``.
    """

    def __init__(
        self,
        df: DataFrameT,
        keys: Sequence[str] | Sequence[CompliantExprAny],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        # Keep the narwhals frame so compliant results can be re-wrapped later.
        self._df: DataFrameT = df
        self._keys = keys
        # The backend-specific grouping object; all real work happens here.
        self._grouped = self._df._compliant_frame.group_by(
            self._keys, drop_null_keys=drop_null_keys
        )

    def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> DataFrameT:
        """Compute aggregations for each group of a group by operation.

        Arguments:
            aggs: Aggregations to compute for each group of the group by operation,
                specified as positional arguments.
            named_aggs: Additional aggregations, specified as keyword arguments.

        Returns:
            A new frame of the same type as the grouped one, containing the group
            keys and the computed aggregations.

        Raises:
            InvalidOperationError: If any of the given expressions does not
                aggregate (i.e. is not scalar-like).

        Examples:
            Group by one column or by multiple columns and call `agg` to compute
            the grouped sum of another column.

            >>> import pandas as pd
            >>> import narwhals as nw
            >>> df_native = pd.DataFrame(
            ...     {
            ...         "a": ["a", "b", "a", "b", "c"],
            ...         "b": [1, 2, 1, 3, 3],
            ...         "c": [5, 4, 3, 2, 1],
            ...     }
            ... )
            >>> df = nw.from_native(df_native)
            >>>
            >>> df.group_by("a").agg(nw.col("b").sum()).sort("a")
            ┌──────────────────┐
            |Narwhals DataFrame|
            |------------------|
            |        a  b      |
            |     0  a  2      |
            |     1  b  5      |
            |     2  c  3      |
            └──────────────────┘
            >>>
            >>> df.group_by("a", "b").agg(nw.col("c").sum()).sort("a", "b").to_native()
               a  b  c
            0  a  1  8
            1  b  2  4
            2  b  3  2
            3  c  3  1
        """
        # Positional arguments may be single expressions or iterables of them.
        flat_aggs = tuple(flatten(aggs))
        if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
            msg = (
                "Found expression which does not aggregate.\n\n"
                "All expressions passed to GroupBy.agg must aggregate.\n"
                "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
                "but `df.group_by('a').agg(nw.col('b'))` is not."
            )
            raise InvalidOperationError(msg)
        plx = self._df.__narwhals_namespace__()
        compliant_aggs = (
            *(x._to_compliant_expr(plx) for x in flat_aggs),
            # Keyword aggregations take their keyword as output column name.
            *(
                value.alias(key)._to_compliant_expr(plx)
                for key, value in named_aggs.items()
            ),
        )
        return self._df._with_compliant(self._grouped.agg(*compliant_aggs))

    def __iter__(self) -> Iterator[tuple[Any, DataFrameT]]:
        """Iterate over the groups as ``(key, sub_frame)`` pairs.

        Yields:
            The group key passed through ``tupleify`` together with the group's
            rows, re-wrapped as a narwhals frame via the parent frame.
        """
        for key, compliant_df in self._grouped:
            yield tupleify(key), self._df._with_compliant(compliant_df)
 | |
| 
 | |
| 
 | |
class LazyGroupBy(Generic[LazyFrameT]):
    """Group-by handle for a lazy narwhals frame (as returned by ``lf.group_by``).

    Grouping itself is delegated to the backend's compliant frame; results are
    re-wrapped into narwhals frames via the parent frame's ``_with_compliant``.
    """

    def __init__(
        self,
        df: LazyFrameT,
        keys: Sequence[str] | Sequence[CompliantExprAny],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        # Keep the narwhals frame so compliant results can be re-wrapped later.
        self._df: LazyFrameT = df
        self._keys = keys
        # The backend-specific grouping object; all real work happens here.
        self._grouped = self._df._compliant_frame.group_by(
            self._keys, drop_null_keys=drop_null_keys
        )

    def agg(self, *aggs: Expr | Iterable[Expr], **named_aggs: Expr) -> LazyFrameT:
        """Compute aggregations for each group of a group by operation.

        Arguments:
            aggs: Aggregations to compute for each group of the group by operation,
                specified as positional arguments.
            named_aggs: Additional aggregations, specified as keyword arguments.

        Returns:
            A new lazy frame of the same type as the grouped one, containing the
            group keys and the computed aggregations.

        Raises:
            InvalidOperationError: If any of the given expressions does not
                aggregate (i.e. is not scalar-like).

        Examples:
            Group by one column or by multiple columns and call `agg` to compute
            the grouped sum of another column.

            >>> import polars as pl
            >>> import narwhals as nw
            >>> lf_native = pl.LazyFrame(
            ...     {
            ...         "a": ["a", "b", "a", "b", "c"],
            ...         "b": [1, 2, 1, 3, 3],
            ...         "c": [5, 4, 3, 2, 1],
            ...     }
            ... )
            >>> lf = nw.from_native(lf_native)
            >>>
            >>> nw.to_native(lf.group_by("a").agg(nw.col("b").sum()).sort("a")).collect()
            shape: (3, 2)
            ┌─────┬─────┐
            │ a   ┆ b   │
            │ --- ┆ --- │
            │ str ┆ i64 │
            ╞═════╪═════╡
            │ a   ┆ 2   │
            │ b   ┆ 5   │
            │ c   ┆ 3   │
            └─────┴─────┘
            >>>
            >>> lf.group_by("a", "b").agg(nw.sum("c")).sort("a", "b").collect()
            ┌───────────────────┐
            |Narwhals DataFrame |
            |-------------------|
            |shape: (4, 3)      |
            |┌─────┬─────┬─────┐|
            |│ a   ┆ b   ┆ c   │|
            |│ --- ┆ --- ┆ --- │|
            |│ str ┆ i64 ┆ i64 │|
            |╞═════╪═════╪═════╡|
            |│ a   ┆ 1   ┆ 8   │|
            |│ b   ┆ 2   ┆ 4   │|
            |│ b   ┆ 3   ┆ 2   │|
            |│ c   ┆ 3   ┆ 1   │|
            |└─────┴─────┴─────┘|
            └───────────────────┘
        """
        # Positional arguments may be single expressions or iterables of them.
        flat_aggs = tuple(flatten(aggs))
        if not all_exprs_are_scalar_like(*flat_aggs, **named_aggs):
            msg = (
                "Found expression which does not aggregate.\n\n"
                "All expressions passed to GroupBy.agg must aggregate.\n"
                "For example, `df.group_by('a').agg(nw.col('b').sum())` is valid,\n"
                "but `df.group_by('a').agg(nw.col('b'))` is not."
            )
            raise InvalidOperationError(msg)
        plx = self._df.__narwhals_namespace__()
        compliant_aggs = (
            *(x._to_compliant_expr(plx) for x in flat_aggs),
            # Keyword aggregations take their keyword as output column name.
            *(
                value.alias(key)._to_compliant_expr(plx)
                for key, value in named_aggs.items()
            ),
        )
        return self._df._with_compliant(self._grouped.agg(*compliant_aggs))
 |