commit 2fc0d000b6 (parent e1b817252c), 2025-09-07 22:09:54 +02:00
7796 changed files with 2159515 additions and 933 deletions

@@ -0,0 +1,185 @@
from __future__ import annotations
import typing as _t
from narwhals import dependencies, dtypes, exceptions, selectors
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
is_ordered_categorical,
maybe_align_index,
maybe_convert_dtypes,
maybe_get_index,
maybe_reset_index,
maybe_set_index,
)
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dtypes import (
Array,
Binary,
Boolean,
Categorical,
Date,
Datetime,
Decimal,
Duration,
Enum,
Field,
Float32,
Float64,
Int8,
Int16,
Int32,
Int64,
Int128,
List,
Object,
String,
Struct,
Time,
UInt8,
UInt16,
UInt32,
UInt64,
UInt128,
Unknown,
)
from narwhals.expr import Expr
from narwhals.functions import (
all_ as all,
all_horizontal,
any_horizontal,
coalesce,
col,
concat,
concat_str,
exclude,
from_arrow,
from_dict,
from_numpy,
len_ as len,
lit,
max,
max_horizontal,
mean,
mean_horizontal,
median,
min,
min_horizontal,
new_series,
nth,
read_csv,
read_parquet,
scan_csv,
scan_parquet,
show_versions,
sum,
sum_horizontal,
when,
)
from narwhals.schema import Schema
from narwhals.series import Series
from narwhals.translate import (
from_native,
get_native_namespace,
narwhalify,
to_native,
to_py_scalar,
)
__version__: str
__all__ = [
"Array",
"Binary",
"Boolean",
"Categorical",
"DataFrame",
"Date",
"Datetime",
"Decimal",
"Duration",
"Enum",
"Expr",
"Field",
"Float32",
"Float64",
"Implementation",
"Int8",
"Int16",
"Int32",
"Int64",
"Int128",
"LazyFrame",
"List",
"Object",
"Schema",
"Series",
"String",
"Struct",
"Time",
"UInt8",
"UInt16",
"UInt32",
"UInt64",
"UInt128",
"Unknown",
"all",
"all_horizontal",
"any_horizontal",
"coalesce",
"col",
"concat",
"concat_str",
"dependencies",
"dtypes",
"exceptions",
"exclude",
"from_arrow",
"from_dict",
"from_native",
"from_numpy",
"generate_temporary_column_name",
"get_native_namespace",
"is_ordered_categorical",
"len",
"lit",
"max",
"max_horizontal",
"maybe_align_index",
"maybe_convert_dtypes",
"maybe_get_index",
"maybe_reset_index",
"maybe_set_index",
"mean",
"mean_horizontal",
"median",
"min",
"min_horizontal",
"narwhalify",
"new_series",
"nth",
"read_csv",
"read_parquet",
"scan_csv",
"scan_parquet",
"selectors",
"show_versions",
"sum",
"sum_horizontal",
"to_native",
"to_py_scalar",
"when",
]
def __getattr__(name: _t.Literal["__version__"]) -> str: # type: ignore[misc]
if name == "__version__":
global __version__ # noqa: PLW0603
from importlib import metadata
__version__ = metadata.version(__name__)
return __version__
msg = f"module {__name__!r} has no attribute {name!r}"
raise AttributeError(msg)
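
Usage note (not part of the commit): a minimal sketch of the public API re-exported above, written only against names listed in `__all__`; the pandas input and column names are illustrative.

import pandas as pd
import narwhals as nw

@nw.narwhalify
def add_total(df):
    # `df` arrives as a narwhals frame regardless of the native backend passed in.
    return df.with_columns((nw.col("a") + nw.col("b")).alias("total"))

native = pd.DataFrame({"a": [1, 2], "b": [10, 20]})
print(add_total(native))  # a native pandas DataFrame is returned again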

@@ -0,0 +1,792 @@
from __future__ import annotations
from collections.abc import Collection, Iterator, Mapping, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, cast, overload
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals._utils import (
Implementation,
Version,
check_column_names_are_unique,
convert_str_slice_to_int_slice,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
scale_bytes,
supports_arrow_c_stream,
zip_strict,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError
if TYPE_CHECKING:
from collections.abc import Iterable
from io import BytesIO
from pathlib import Path
from types import ModuleType
import pandas as pd
import polars as pl
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.group_by import ArrowGroupBy
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
ChunkedArrayAny,
Mask,
Order,
)
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
from narwhals._spark_like.utils import SparkSession
from narwhals._translate import IntoArrowTable
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals._utils import Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import (
IntoSchema,
JoinStrategy,
SizedMultiIndexSelector,
SizedMultiNameSelector,
SizeUnit,
UniqueKeepStrategy,
_1DArray,
_2DArray,
_SliceIndex,
_SliceName,
)
JoinType: TypeAlias = Literal[
"left semi",
"right semi",
"left anti",
"right anti",
"inner",
"left outer",
"right outer",
"full outer",
]
PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]
class ArrowDataFrame(
EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
):
_implementation = Implementation.PYARROW
def __init__(
self,
native_dataframe: pa.Table,
*,
version: Version,
validate_column_names: bool,
validate_backend_version: bool = False,
) -> None:
if validate_column_names:
check_column_names_are_unique(native_dataframe.column_names)
if validate_backend_version:
self._validate_backend_version()
self._native_frame = native_dataframe
self._version = version
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
backend_version = context._implementation._backend_version()
if cls._is_native(data):
native = data
elif backend_version >= (14,) or isinstance(data, Collection):
native = pa.table(data)
elif supports_arrow_c_stream(data): # pragma: no cover
msg = f"'pyarrow>=14.0.0' is required for `from_arrow` for object of type {type(data).__name__!r}."
raise ModuleNotFoundError(msg)
else: # pragma: no cover
msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
raise TypeError(msg)
return cls.from_native(native, context=context)
@classmethod
def from_dict(
cls,
data: Mapping[str, Any],
/,
*,
context: _LimitedContext,
schema: IntoSchema | None,
) -> Self:
from narwhals.schema import Schema
pa_schema = Schema(schema).to_arrow() if schema is not None else schema
if pa_schema and not data:
native = pa_schema.empty_table()
else:
native = pa.Table.from_pydict(data, schema=pa_schema)
return cls.from_native(native, context=context)
@staticmethod
def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
return isinstance(obj, pa.Table)
@classmethod
def from_native(cls, data: pa.Table, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version, validate_column_names=True)
@classmethod
def from_numpy(
cls,
data: _2DArray,
/,
*,
context: _LimitedContext,
schema: IntoSchema | Sequence[str] | None,
) -> Self:
from narwhals.schema import Schema
arrays = [pa.array(val) for val in data.T]
if isinstance(schema, (Mapping, Schema)):
native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
else:
native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
return cls.from_native(native, context=context)
def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace
return ArrowNamespace(version=self._version)
def __native_namespace__(self) -> ModuleType:
if self._implementation is Implementation.PYARROW:
return self._implementation.to_native_namespace()
msg = f"Expected pyarrow, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)
def __narwhals_dataframe__(self) -> Self:
return self
def __narwhals_lazyframe__(self) -> Self:
return self
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version, validate_column_names=False)
def _with_native(self, df: pa.Table, *, validate_column_names: bool = True) -> Self:
return self.__class__(
df, version=self._version, validate_column_names=validate_column_names
)
@property
def shape(self) -> tuple[int, int]:
return self.native.shape
def __len__(self) -> int:
return len(self.native)
def row(self, index: int) -> tuple[Any, ...]:
return tuple(col[index] for col in self.native.itercolumns())
@overload
def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ...
@overload
def rows(self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ...
@overload
def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...
def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
if not named:
return list(self.iter_rows(named=False, buffer_size=512)) # type: ignore[return-value]
return self.native.to_pylist()
def iter_columns(self) -> Iterator[ArrowSeries]:
for name, series in zip_strict(self.columns, self.native.itercolumns()):
yield ArrowSeries.from_native(series, context=self, name=name)
_iter_columns = iter_columns
def iter_rows(
self, *, named: bool, buffer_size: int
) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
df = self.native
num_rows = df.num_rows
if not named:
for i in range(0, num_rows, buffer_size):
rows = df[i : i + buffer_size].to_pydict().values()
yield from zip_strict(*rows)
else:
for i in range(0, num_rows, buffer_size):
yield from df[i : i + buffer_size].to_pylist()
def get_column(self, name: str) -> ArrowSeries:
if not isinstance(name, str):
msg = f"Expected str, got: {type(name)}"
raise TypeError(msg)
return ArrowSeries.from_native(self.native[name], context=self, name=name)
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray:
return self.native.__array__(dtype, copy=copy)
def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self:
if len(rows) == 0:
return self._with_native(self.native.slice(0, 0))
if self._backend_version < (18,) and isinstance(rows, tuple):
rows = list(rows)
return self._with_native(self.native.take(rows))
def _gather_slice(self, rows: _SliceIndex | range) -> Self:
start = rows.start or 0
stop = rows.stop if rows.stop is not None else len(self.native)
if start < 0:
start = len(self.native) + start
if stop < 0:
stop = len(self.native) + stop
if rows.step is not None and rows.step != 1:
msg = "Slicing with step is not supported on PyArrow tables"
raise NotImplementedError(msg)
return self._with_native(self.native.slice(start, stop - start))
def _select_slice_name(self, columns: _SliceName) -> Self:
start, stop, step = convert_str_slice_to_int_slice(columns, self.columns)
return self._with_native(self.native.select(self.columns[start:stop:step]))
def _select_slice_index(self, columns: _SliceIndex | range) -> Self:
return self._with_native(
self.native.select(self.columns[columns.start : columns.stop : columns.step])
)
def _select_multi_index(
self, columns: SizedMultiIndexSelector[ChunkedArrayAny]
) -> Self:
selector: Sequence[int]
if isinstance(columns, pa.ChunkedArray):
# TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
selector = cast("Sequence[int]", columns.to_pylist())
# TODO @dangotbanned: Fix upstream, it is actually much narrower
# **Doesn't accept `ndarray`**
elif is_numpy_array_1d(columns):
selector = columns.tolist()
else:
selector = columns
return self._with_native(self.native.select(selector))
def _select_multi_name(
self, columns: SizedMultiNameSelector[ChunkedArrayAny]
) -> Self:
selector: Sequence[str] | _1DArray
if isinstance(columns, pa.ChunkedArray):
# TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
selector = cast("Sequence[str]", columns.to_pylist())
else:
selector = columns
# NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221
return self._with_native(self.native.select(selector)) # pyright: ignore[reportArgumentType]
@property
def schema(self) -> dict[str, DType]:
return {
field.name: native_to_narwhals_dtype(field.type, self._version)
for field in self.native.schema
}
def collect_schema(self) -> dict[str, DType]:
return self.schema
def estimated_size(self, unit: SizeUnit) -> int | float:
sz = self.native.nbytes
return scale_bytes(sz, unit)
explode = not_implemented()
@property
def columns(self) -> list[str]:
return self.native.column_names
def simple_select(self, *column_names: str) -> Self:
return self._with_native(
self.native.select(list(column_names)), validate_column_names=False
)
def select(self, *exprs: ArrowExpr) -> Self:
new_series = self._evaluate_into_exprs(*exprs)
if not new_series:
# return empty dataframe, like Polars does
return self._with_native(
self.native.__class__.from_arrays([]), validate_column_names=False
)
names = [s.name for s in new_series]
align = new_series[0]._align_full_broadcast
reshaped = align(*new_series)
df = pa.Table.from_arrays([s.native for s in reshaped], names=names)
return self._with_native(df, validate_column_names=True)
def _extract_comparand(self, other: ArrowSeries) -> ChunkedArrayAny:
length = len(self)
if not other._broadcast:
if (len_other := len(other)) != length:
msg = f"Expected object of length {length}, got: {len_other}."
raise ShapeError(msg)
return other.native
value = other.native[0]
return pa.chunked_array([pa.repeat(value, length)])
def with_columns(self, *exprs: ArrowExpr) -> Self:
# NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame)
# All `pyarrow` data is immutable, so this is fine
native_frame = self.native
new_columns = self._evaluate_into_exprs(*exprs)
columns = self.columns
for col_value in new_columns:
col_name = col_value.name
column = self._extract_comparand(col_value)
native_frame = (
native_frame.set_column(columns.index(col_name), col_name, column=column)
if col_name in columns
else native_frame.append_column(col_name, column=column)
)
return self._with_native(native_frame, validate_column_names=False)
def group_by(
self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool
) -> ArrowGroupBy:
from narwhals._arrow.group_by import ArrowGroupBy
return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
how_to_join_map: dict[str, JoinType] = {
"anti": "left anti",
"semi": "left semi",
"inner": "inner",
"left": "left outer",
"full": "full outer",
}
if how == "cross":
plx = self.__narwhals_namespace__()
key_token = generate_temporary_column_name(
n_bytes=8, columns=[*self.columns, *other.columns]
)
return self._with_native(
self.with_columns(
plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
)
.native.join(
other.with_columns(
plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
).native,
keys=key_token,
right_keys=key_token,
join_type="inner",
right_suffix=suffix,
)
.drop([key_token])
)
coalesce_keys = how != "full" # polars full join does not coalesce keys
return self._with_native(
self.native.join(
other.native,
keys=left_on or [], # type: ignore[arg-type]
right_keys=right_on, # type: ignore[arg-type]
join_type=how_to_join_map[how],
right_suffix=suffix,
coalesce_keys=coalesce_keys,
)
)
join_asof = not_implemented()
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(to_drop), validate_column_names=False)
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
if subset is None:
return self._with_native(self.native.drop_null(), validate_column_names=False)
plx = self.__narwhals_namespace__()
mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
return self.filter(mask)
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
order: Order = "descending" if descending else "ascending"
sorting: list[tuple[str, Order]] = [(key, order) for key in by]
else:
sorting = [
(key, "descending" if is_descending else "ascending")
for key, is_descending in zip_strict(by, descending)
]
null_placement = "at_end" if nulls_last else "at_start"
return self._with_native(
self.native.sort_by(sorting, null_placement=null_placement),
validate_column_names=False,
)
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
if isinstance(reverse, bool):
order: Order = "ascending" if reverse else "descending"
sorting: list[tuple[str, Order]] = [(key, order) for key in by]
else:
sorting = [
(key, "ascending" if is_ascending else "descending")
for key, is_ascending in zip_strict(by, reverse)
]
return self._with_native(
self.native.take(pc.select_k_unstable(self.native, k, sorting)), # type: ignore[call-overload]
validate_column_names=False,
)
def to_pandas(self) -> pd.DataFrame:
return self.native.to_pandas()
def to_polars(self) -> pl.DataFrame:
import polars as pl # ignore-banned-import
return pl.from_arrow(self.native) # type: ignore[return-value]
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
import numpy as np # ignore-banned-import
arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns])
return arr
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
it = self.iter_columns()
if as_series:
return {ser.name: ser for ser in it}
return {ser.name: ser.to_list() for ser in it}
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
plx = self.__narwhals_namespace__()
if order_by is None:
import numpy as np # ignore-banned-import
data = pa.array(np.arange(len(self), dtype=np.int64))
row_index = plx._expr._from_series(
plx._series.from_iterable(data, context=self, name=name)
)
else:
rank = plx.col(order_by[0]).rank("ordinal", descending=False)
row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
return self.select(row_index, plx.all())
def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
if isinstance(predicate, list):
mask_native: Mask | ChunkedArrayAny = predicate
else:
# `[0]` is safe as the predicate's expression only returns a single column
mask_native = self._evaluate_into_exprs(predicate)[0].native
return self._with_native(
self.native.filter(mask_native), validate_column_names=False
)
def head(self, n: int) -> Self:
df = self.native
if n >= 0:
return self._with_native(df.slice(0, n), validate_column_names=False)
num_rows = df.num_rows
return self._with_native(
df.slice(0, max(0, num_rows + n)), validate_column_names=False
)
def tail(self, n: int) -> Self:
df = self.native
if n >= 0:
num_rows = df.num_rows
return self._with_native(
df.slice(max(0, num_rows - n)), validate_column_names=False
)
return self._with_native(df.slice(abs(n)), validate_column_names=False)
def lazy(
self,
backend: _LazyAllowedImpl | None = None,
*,
session: SparkSession | None = None,
) -> CompliantLazyFrameAny:
if backend is None:
return self
if backend is Implementation.DUCKDB:
import duckdb # ignore-banned-import
from narwhals._duckdb.dataframe import DuckDBLazyFrame
_df = self.native
return DuckDBLazyFrame(
duckdb.table("_df"), validate_backend_version=True, version=self._version
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsLazyFrame
return PolarsLazyFrame(
cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(),
validate_backend_version=True,
version=self._version,
)
if backend is Implementation.DASK:
import dask.dataframe as dd # ignore-banned-import
from narwhals._dask.dataframe import DaskLazyFrame
return DaskLazyFrame(
dd.from_pandas(self.native.to_pandas()),
validate_backend_version=True,
version=self._version,
)
if backend is Implementation.IBIS:
import ibis # ignore-banned-import
from narwhals._ibis.dataframe import IbisLazyFrame
return IbisLazyFrame(
ibis.memtable(self.native, columns=self.columns),
validate_backend_version=True,
version=self._version,
)
if backend.is_spark_like():
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
if session is None:
msg = "Spark like backends require `session` to be not None."
raise ValueError(msg)
return SparkLikeLazyFrame._from_compliant_dataframe(
self, session=session, implementation=backend, version=self._version
)
raise AssertionError # pragma: no cover
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
if backend is Implementation.PYARROW or backend is None:
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
self.native, version=self._version, validate_column_names=False
)
if backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
self.native.to_pandas(),
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=False,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
cast("pl.DataFrame", pl.from_arrow(self.native)),
validate_backend_version=True,
version=self._version,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise AssertionError(msg) # pragma: no cover
def clone(self) -> Self:
return self._with_native(self.native, validate_column_names=False)
def item(self, row: int | None, column: int | str | None) -> Any:
from narwhals._arrow.series import maybe_extract_py_scalar
if row is None and column is None:
if self.shape != (1, 1):
msg = (
"can only call `.item()` if the dataframe is of shape (1, 1),"
" or if explicit row/col values are provided;"
f" frame has shape {self.shape!r}"
)
raise ValueError(msg)
return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True)
if row is None or column is None:
msg = "cannot call `.item()` with only one of `row` or `column`"
raise ValueError(msg)
_col = self.columns.index(column) if isinstance(column, str) else column
return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True)
def rename(self, mapping: Mapping[str, str]) -> Self:
names: dict[str, str] | list[str]
if self._backend_version >= (17,):
names = cast("dict[str, str]", mapping)
else: # pragma: no cover
names = [mapping.get(c, c) for c in self.columns]
return self._with_native(self.native.rename_columns(names))
def write_parquet(self, file: str | Path | BytesIO) -> None:
import pyarrow.parquet as pp
pp.write_table(self.native, file)
@overload
def write_csv(self, file: None) -> str: ...
@overload
def write_csv(self, file: str | Path | BytesIO) -> None: ...
def write_csv(self, file: str | Path | BytesIO | None) -> str | None:
import pyarrow.csv as pa_csv
if file is None:
csv_buffer = pa.BufferOutputStream()
pa_csv.write_csv(self.native, csv_buffer)
return csv_buffer.getvalue().to_pybytes().decode()
pa_csv.write_csv(self.native, file)
return None
def is_unique(self) -> ArrowSeries:
import numpy as np # ignore-banned-import
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
row_index = pa.array(np.arange(len(self)))
keep_idx = (
self.native.append_column(col_token, row_index)
.group_by(self.columns)
.aggregate([(col_token, "min"), (col_token, "max")])
)
native = pa.chunked_array(
pc.and_(
pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
)
)
return ArrowSeries.from_native(native, context=self)
def unique(
self,
subset: Sequence[str] | None,
*,
keep: UniqueKeepStrategy,
maintain_order: bool | None = None,
) -> Self:
# The param `maintain_order` is only here for compatibility with the Polars API
# and has no effect on the output.
import numpy as np # ignore-banned-import
if subset and (error := self._check_columns_exist(subset)):
raise error
subset = list(subset or self.columns)
if keep in {"any", "first", "last"}:
from narwhals._arrow.group_by import ArrowGroupBy
agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
keep_idx_native = (
self.native.append_column(col_token, pa.array(np.arange(len(self))))
.group_by(subset)
.aggregate([(col_token, agg_func)])
.column(f"{col_token}_{agg_func}")
)
return self._with_native(
self.native.take(keep_idx_native), validate_column_names=False
)
keep_idx = self.simple_select(*subset).is_unique()
plx = self.__narwhals_namespace__()
return self.filter(plx._expr._from_series(keep_idx))
def gather_every(self, n: int, offset: int) -> Self:
return self._with_native(self.native[offset::n], validate_column_names=False)
def to_arrow(self) -> pa.Table:
return self.native
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self:
import numpy as np # ignore-banned-import
num_rows = len(self)
if n is None and fraction is not None:
n = int(num_rows * fraction)
rng = np.random.default_rng(seed=seed)
idx = np.arange(num_rows)
mask = rng.choice(idx, size=n, replace=with_replacement)
return self._with_native(self.native.take(mask), validate_column_names=False)
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
n_rows = len(self)
index_ = [] if index is None else index
on_ = [c for c in self.columns if c not in index_] if on is None else on
concat = (
partial(pa.concat_tables, promote_options="permissive")
if self._backend_version >= (14, 0, 0)
else pa.concat_tables
)
names = [*index_, variable_name, value_name]
return self._with_native(
concat(
[
pa.Table.from_arrays(
[
*(self.native.column(idx_col) for idx_col in index_),
cast(
"ChunkedArrayAny",
pa.array([on_col] * n_rows, pa.string()),
),
self.native.column(on_col),
],
names=names,
)
for on_col in on_
]
)
)
# TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
# upcast numeric to non-numeric (e.g. string) datatypes
pivot = not_implemented()
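
Usage note (not part of the commit): `ArrowDataFrame` is the backend object behind narwhals frames built from `pyarrow.Table`s. A small illustrative sketch via the public API, with made-up data:

import pyarrow as pa
import narwhals as nw

tbl = pa.table({"g": ["x", "x", "y"], "v": [1, 2, 3]})
df = nw.from_native(tbl, eager_only=True)
out = df.filter(nw.col("v") > 1).sort("v", descending=True).select("g", "v")
print(out.to_native())  # back to a pyarrow.Table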

@@ -0,0 +1,170 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import pyarrow.compute as pc
from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
not_implemented,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
from narwhals._expression_parsing import ExprMetadata
from narwhals._utils import Version, _LimitedContext
class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
_implementation: Implementation = Implementation.PYARROW
def __init__(
self,
call: EvalSeries[ArrowDataFrame, ArrowSeries],
*,
depth: int,
function_name: str,
evaluate_output_names: EvalNames[ArrowDataFrame],
alias_output_names: AliasNames | None,
version: Version,
scalar_kwargs: ScalarKwargs | None = None,
implementation: Implementation | None = None,
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._scalar_kwargs = scalar_kwargs or {}
self._metadata: ExprMetadata | None = None
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[ArrowDataFrame],
/,
*,
context: _LimitedContext,
function_name: str = "",
) -> Self:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
try:
return [
ArrowSeries(
df.native[column_name], name=column_name, version=df._version
)
for column_name in evaluate_column_names(df)
]
except KeyError as e:
if error := df._check_columns_exist(evaluate_column_names(df)):
raise error from e
raise
return cls(
func,
depth=0,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
tbl = df.native
cols = df.columns
return [
ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
for i in column_indices
]
return cls(
func,
depth=0,
function_name="nth",
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
)
def __narwhals_namespace__(self) -> ArrowNamespace:
from narwhals._arrow.namespace import ArrowNamespace
return ArrowNamespace(version=self._version)
def _reuse_series_extra_kwargs(
self, *, returns_scalar: bool = False
) -> dict[str, Any]:
return {"_return_py_scalar": False} if returns_scalar else {}
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
if (
partition_by
and self._metadata is not None
and not self._metadata.is_scalar_like
):
msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
raise NotImplementedError(msg)
if not partition_by:
# e.g. `nw.col('a').cum_sum().order_by(key)`
# which we can always easily support, as it doesn't require grouping.
assert order_by # noqa: S101
def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
token = generate_temporary_column_name(8, df.columns)
df = df.with_row_index(token, order_by=None).sort(
*order_by, descending=False, nulls_last=False
)
result = self(df.drop([token], strict=True))
# TODO(marco): is there a way to do this efficiently without
# doing 2 sorts? Here we're sorting the dataframe and then
# again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
sorting_indices = pc.sort_indices(df.get_column(token).native)
return [s._with_native(s.native.take(sorting_indices)) for s in result]
else:
def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
if overlap := set(output_names).intersection(partition_by):
# E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
# we just don't support it yet.
msg = (
f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
"This is not yet supported."
)
raise NotImplementedError(msg)
tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
tmp = df.simple_select(*partition_by).join(
tmp,
how="left",
left_on=partition_by,
right_on=partition_by,
suffix="_right",
)
return [tmp.get_column(alias) for alias in aliases]
return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
ewm_mean = not_implemented()
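
Usage note (not part of the commit): on this backend, `over` with partition keys is limited to scalar-like expressions (aggregations or literals), per the check above. An illustrative sketch with made-up data:

import pyarrow as pa
import narwhals as nw

df = nw.from_native(pa.table({"g": ["a", "a", "b"], "v": [1, 2, 3]}), eager_only=True)
# Broadcast each group's sum back onto the original rows.
print(df.with_columns(nw.col("v").sum().over("g").alias("v_sum")).to_native())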

@@ -0,0 +1,159 @@
from __future__ import annotations
import collections
from typing import TYPE_CHECKING, Any, ClassVar
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import generate_temporary_column_name
if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.typing import ( # type: ignore[attr-defined]
AggregateOptions,
Aggregation,
Incomplete,
)
from narwhals._compliant.typing import NarwhalsAggregation
from narwhals.typing import UniqueKeepStrategy
class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
"sum": "sum",
"mean": "mean",
"median": "approximate_median",
"max": "max",
"min": "min",
"std": "stddev",
"var": "variance",
"len": "count",
"n_unique": "count_distinct",
"count": "count",
"all": "all",
"any": "any",
}
_REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
"any": "min",
"first": "min",
"last": "max",
}
def __init__(
self,
df: ArrowDataFrame,
keys: Sequence[ArrowExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
self._df = df
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
self._drop_null_keys = drop_null_keys
def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
self._ensure_all_simple(exprs)
aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
expected_pyarrow_column_names: list[str] = self._keys.copy()
new_column_names: list[str] = self._keys.copy()
exclude = (*self._keys, *self._output_key_names)
for expr in exprs:
output_names, aliases = evaluate_output_names_and_aliases(
expr, self.compliant, exclude
)
if expr._depth == 0:
# e.g. `agg(nw.len())`
if expr._function_name != "len": # pragma: no cover
msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
raise AssertionError(msg)
new_column_names.append(aliases[0])
expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
continue
function_name = self._leaf_name(expr)
if function_name in {"std", "var"}:
assert "ddof" in expr._scalar_kwargs # noqa: S101
option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
elif function_name in {"len", "n_unique"}:
option = pc.CountOptions(mode="all")
elif function_name == "count":
option = pc.CountOptions(mode="only_valid")
elif function_name in {"all", "any"}:
option = pc.ScalarAggregateOptions(min_count=0)
else:
option = None
function_name = self._remap_expr_name(function_name)
new_column_names.extend(aliases)
expected_pyarrow_column_names.extend(
[f"{output_name}_{function_name}" for output_name in output_names]
)
aggs.extend(
[(output_name, function_name, option) for output_name in output_names]
)
result_simple = self._grouped.aggregate(aggs)
# Rename columns, being very careful
expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)
for idx, item in enumerate(expected_pyarrow_column_names):
expected_old_names_indices[item].append(idx)
if not (
set(result_simple.column_names) == set(expected_pyarrow_column_names)
and len(result_simple.column_names) == len(expected_pyarrow_column_names)
): # pragma: no cover
msg = (
f"Safety assertion failed, expected {expected_pyarrow_column_names} "
f"got {result_simple.column_names}, "
"please report a bug at https://github.com/narwhals-dev/narwhals/issues"
)
raise AssertionError(msg)
index_map: list[int] = [
expected_old_names_indices[item].pop(0) for item in result_simple.column_names
]
new_column_names = [new_column_names[i] for i in index_map]
result_simple = result_simple.rename_columns(new_column_names)
return self.compliant._with_native(result_simple).rename(
dict(zip(self._keys, self._output_key_names))
)
def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
col_token = generate_temporary_column_name(
n_bytes=8, columns=self.compliant.columns
)
null_token: str = "__null_token_value__" # noqa: S105
table = self.compliant.native
it, separator_scalar = cast_to_comparable_string_types(
*(table[key] for key in self._keys), separator=""
)
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
# Reality: `str` is fine
concat_str: Incomplete = pc.binary_join_element_wise
key_values = concat_str(
*it, separator_scalar, null_handling="replace", null_replacement=null_token
)
table = table.add_column(i=0, field_=col_token, column=key_values)
for v in pc.unique(key_values):
t = self.compliant._with_native(
table.filter(pc.equal(table[col_token], v)).drop([col_token])
)
row = t.simple_select(*self._keys).row(0)
yield (
tuple(extract_py_scalar(el) for el in row),
t.simple_select(*self._df.columns),
)
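
Usage note (not part of the commit): an illustrative sketch of aggregations that map onto the `_REMAP_AGGS` table above, with made-up data:

import pyarrow as pa
import narwhals as nw

df = nw.from_native(pa.table({"g": ["a", "a", "b"], "v": [1.0, 3.0, 5.0]}), eager_only=True)
out = df.group_by("g").agg(
    nw.col("v").mean().alias("v_mean"),      # -> pyarrow "mean"
    nw.col("v").n_unique().alias("v_uniq"),  # -> pyarrow "count_distinct"
)
print(out.sort("g").to_native())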

@@ -0,0 +1,303 @@
from __future__ import annotations
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Literal
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._utils import Implementation
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
from narwhals._compliant.typing import ScalarKwargs
from narwhals._utils import Version
from narwhals.typing import IntoDType, NonNestedLiteral
class ArrowNamespace(
EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table, "ChunkedArrayAny"]
):
_implementation = Implementation.PYARROW
@property
def _dataframe(self) -> type[ArrowDataFrame]:
return ArrowDataFrame
@property
def _expr(self) -> type[ArrowExpr]:
return ArrowExpr
@property
def _series(self) -> type[ArrowSeries]:
return ArrowSeries
def __init__(self, *, version: Version) -> None:
self._version = version
def len(self) -> ArrowExpr:
# coverage bug? this is definitely hit
return self._expr( # pragma: no cover
lambda df: [
ArrowSeries.from_iterable([len(df.native)], name="len", context=self)
],
depth=0,
function_name="len",
evaluate_output_names=lambda _df: ["len"],
alias_output_names=None,
version=self._version,
)
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr:
def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
arrow_series = ArrowSeries.from_iterable(
data=[value], name="literal", context=self
)
if dtype:
return arrow_series.cast(dtype)
return arrow_series
return self._expr(
lambda df: [_lit_arrow_series(df)],
depth=0,
function_name="lit",
evaluate_output_names=lambda _df: ["literal"],
alias_output_names=None,
version=self._version,
)
def all_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
align = self._series._align_full_broadcast
if ignore_nulls:
series = (s.fill_null(True, None, None) for s in series)
return [reduce(operator.and_, align(*series))]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="all_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def any_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
align = self._series._align_full_broadcast
if ignore_nulls:
series = (s.fill_null(False, None, None) for s in series)
return [reduce(operator.or_, align(*series))]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="any_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
it = chain.from_iterable(expr(df) for expr in exprs)
series = (s.fill_null(0, strategy=None, limit=None) for s in it)
align = self._series._align_full_broadcast
return [reduce(operator.add, align(*series))]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="sum_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
int_64 = self._version.dtypes.Int64()
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
align = self._series._align_full_broadcast
series = align(
*(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
)
non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
return [reduce(operator.add, series) / reduce(operator.add, non_na)]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="mean_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
align = self._series._align_full_broadcast
init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
init_series, *series = align(init_series, *series)
native_series = reduce(
pc.min_element_wise, [s.native for s in series], init_series.native
)
return [
ArrowSeries(native_series, name=init_series.name, version=self._version)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="min_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
align = self._series._align_full_broadcast
init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
init_series, *series = align(init_series, *series)
native_series = reduce(
pc.max_element_wise, [s.native for s in series], init_series.native
)
return [
ArrowSeries(native_series, name=init_series.name, version=self._version)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="max_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
if self._backend_version >= (14,):
return pa.concat_tables(dfs, promote_options="default")
return pa.concat_tables(dfs, promote=True) # pragma: no cover
def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
names = list(chain.from_iterable(df.column_names for df in dfs))
arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
return pa.Table.from_arrays(arrays, names=names)
def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
cols_0 = dfs[0].column_names
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.column_names
if cols_current != cols_0:
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0}\n"
f" - dataframe {i}: {cols_current}\n"
)
raise TypeError(msg)
return pa.concat_tables(dfs)
@property
def selectors(self) -> ArrowSelectorNamespace:
return ArrowSelectorNamespace.from_namespace(self)
def when(self, predicate: ArrowExpr) -> ArrowWhen:
return ArrowWhen.from_expr(predicate, context=self)
def concat_str(
self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
align = self._series._align_full_broadcast
compliant_series_list = align(
*(chain.from_iterable(expr(df) for expr in exprs))
)
name = compliant_series_list[0].name
null_handling: Literal["skip", "emit_null"] = (
"skip" if ignore_nulls else "emit_null"
)
it, separator_scalar = cast_to_comparable_string_types(
*(s.native for s in compliant_series_list), separator=separator
)
# NOTE: stubs indicate `separator` must also be a `ChunkedArray`
# Reality: `str` is fine
concat_str: Incomplete = pc.binary_join_element_wise
compliant = self._series(
concat_str(*it, separator_scalar, null_handling=null_handling),
name=name,
version=self._version,
)
return [compliant]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="concat_str",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def coalesce(self, *exprs: ArrowExpr) -> ArrowExpr:
def func(df: ArrowDataFrame) -> list[ArrowSeries]:
align = self._series._align_full_broadcast
init_series, *series = align(*chain.from_iterable(expr(df) for expr in exprs))
return [
ArrowSeries(
pc.coalesce(init_series.native, *(s.native for s in series)),
name=init_series.name,
version=self._version,
)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="coalesce",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]):
@property
def _then(self) -> type[ArrowThen]:
return ArrowThen
def _if_then_else(
self,
when: ChunkedArrayAny,
then: ChunkedArrayAny,
otherwise: ArrayOrScalar | NonNestedLiteral,
/,
) -> ChunkedArrayAny:
otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
return pc.if_else(when, then, otherwise)
class ArrowThen(
CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr, ArrowWhen], ArrowExpr
):
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "whenthen"
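
Usage note (not part of the commit): `nw.when(...).then(...).otherwise(...)` and the horizontal reductions route through `ArrowWhen` and the namespace methods above. An illustrative sketch with made-up data:

import pyarrow as pa
import narwhals as nw

df = nw.from_native(pa.table({"a": [1, None, 3], "b": [10, 20, 30]}), eager_only=True)
out = df.with_columns(
    nw.when(nw.col("a") > 1).then(nw.col("b")).otherwise(0).alias("picked"),
    nw.sum_horizontal("a", "b").alias("a_plus_b"),
)
print(out.to_native())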

@@ -0,0 +1,33 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._arrow.expr import ArrowExpr
from narwhals._compliant import CompliantSelector, EagerSelectorNamespace
if TYPE_CHECKING:
from narwhals._arrow.dataframe import ArrowDataFrame # noqa: F401
from narwhals._arrow.series import ArrowSeries # noqa: F401
from narwhals._compliant.typing import ScalarKwargs
class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]):
@property
def _selector(self) -> type[ArrowSelector]:
return ArrowSelector
class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr): # type: ignore[misc]
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "selector"
def _to_expr(self) -> ArrowExpr:
return ArrowExpr(
self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
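
Usage note (not part of the commit): selectors resolve to `ArrowSelector` instances on this backend. An illustrative sketch with made-up data:

import pyarrow as pa
import narwhals as nw
import narwhals.selectors as ncs

df = nw.from_native(pa.table({"x": [1, 2], "y": [1.5, 2.5], "s": ["a", "b"]}), eager_only=True)
print(df.select(ncs.numeric()).to_native())  # keeps only the numeric columns "x" and "y"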

File diff suppressed because it is too large

@@ -0,0 +1,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pyarrow as pa
from narwhals._arrow.utils import ArrowSeriesNamespace
if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import Incomplete
class ArrowSeriesCatNamespace(ArrowSeriesNamespace):
def get_categories(self) -> ArrowSeries:
# NOTE: Should be `list[pa.DictionaryArray]`, but `DictionaryArray` has no attributes
chunks: Incomplete = self.native.chunks
return self.with_native(pa.concat_arrays(x.dictionary for x in chunks).unique())
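
Usage note (not part of the commit): an illustrative sketch, assuming `get_categories` is reached through the public `Series.cat` namespace; data is made up.

import pyarrow as pa
import narwhals as nw

tbl = pa.table({"c": pa.array(["a", "b", "a"]).dictionary_encode()})
s = nw.from_native(tbl, eager_only=True)["c"]
print(s.cat.get_categories().to_list())  # ['a', 'b']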

@@ -0,0 +1,226 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import UNITS_DICT, ArrowSeriesNamespace, floordiv_compat, lit
from narwhals._constants import (
MS_PER_MINUTE,
MS_PER_SECOND,
NS_PER_MICROSECOND,
NS_PER_MILLISECOND,
NS_PER_MINUTE,
NS_PER_SECOND,
SECONDS_PER_DAY,
SECONDS_PER_MINUTE,
US_PER_MINUTE,
US_PER_SECOND,
)
from narwhals._duration import Interval
if TYPE_CHECKING:
from collections.abc import Mapping
from typing_extensions import TypeAlias
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny
from narwhals.dtypes import Datetime
from narwhals.typing import TimeUnit
UnitCurrent: TypeAlias = TimeUnit
UnitTarget: TypeAlias = TimeUnit
BinOpBroadcast: TypeAlias = Callable[[ChunkedArrayAny, ScalarAny], ChunkedArrayAny]
IntoRhs: TypeAlias = int
class ArrowSeriesDateTimeNamespace(ArrowSeriesNamespace):
_TIMESTAMP_DATE_FACTOR: ClassVar[Mapping[TimeUnit, int]] = {
"ns": NS_PER_SECOND,
"us": US_PER_SECOND,
"ms": MS_PER_SECOND,
"s": 1,
}
_TIMESTAMP_DATETIME_OP_FACTOR: ClassVar[
Mapping[tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]]
] = {
("ns", "us"): (floordiv_compat, 1_000),
("ns", "ms"): (floordiv_compat, 1_000_000),
("us", "ns"): (pc.multiply, NS_PER_MICROSECOND),
("us", "ms"): (floordiv_compat, 1_000),
("ms", "ns"): (pc.multiply, NS_PER_MILLISECOND),
("ms", "us"): (pc.multiply, 1_000),
("s", "ns"): (pc.multiply, NS_PER_SECOND),
("s", "us"): (pc.multiply, US_PER_SECOND),
("s", "ms"): (pc.multiply, MS_PER_SECOND),
}
@property
def unit(self) -> TimeUnit: # NOTE: Unsafe (native).
return cast("pa.TimestampType[TimeUnit, Any]", self.native.type).unit
@property
def time_zone(self) -> str | None: # NOTE: Unsafe (narwhals).
return cast("Datetime", self.compliant.dtype).time_zone
def to_string(self, format: str) -> ArrowSeries:
# PyArrow differs from other libraries in that %S also prints out
# the fractional part of the second...:'(
# https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html
format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
return self.with_native(pc.strftime(self.native, format))
def replace_time_zone(self, time_zone: str | None) -> ArrowSeries:
if time_zone is not None:
result = pc.assume_timezone(pc.local_timestamp(self.native), time_zone)
else:
result = pc.local_timestamp(self.native)
return self.with_native(result)
def convert_time_zone(self, time_zone: str) -> ArrowSeries:
ser = self.replace_time_zone("UTC") if self.time_zone is None else self.compliant
return self.with_native(ser.native.cast(pa.timestamp(self.unit, time_zone)))
def timestamp(self, time_unit: TimeUnit) -> ArrowSeries:
ser = self.compliant
dtypes = ser._version.dtypes
if isinstance(ser.dtype, dtypes.Datetime):
current = ser.dtype.time_unit
s_cast = self.native.cast(pa.int64())
if current == time_unit:
result = s_cast
elif item := self._TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)):
fn, factor = item
result = fn(s_cast, lit(factor))
else: # pragma: no cover
msg = f"unexpected time unit {current}, please report an issue at https://github.com/narwhals-dev/narwhals"
raise AssertionError(msg)
return self.with_native(result)
if isinstance(ser.dtype, dtypes.Date):
time_s = pc.multiply(self.native.cast(pa.int32()), lit(SECONDS_PER_DAY))
factor = self._TIMESTAMP_DATE_FACTOR[time_unit]
return self.with_native(pc.multiply(time_s, lit(factor)))
msg = "Input should be either of Date or Datetime type"
raise TypeError(msg)
def date(self) -> ArrowSeries:
return self.with_native(self.native.cast(pa.date32()))
def year(self) -> ArrowSeries:
return self.with_native(pc.year(self.native))
def month(self) -> ArrowSeries:
return self.with_native(pc.month(self.native))
def day(self) -> ArrowSeries:
return self.with_native(pc.day(self.native))
def hour(self) -> ArrowSeries:
return self.with_native(pc.hour(self.native))
def minute(self) -> ArrowSeries:
return self.with_native(pc.minute(self.native))
def second(self) -> ArrowSeries:
return self.with_native(pc.second(self.native))
def millisecond(self) -> ArrowSeries:
return self.with_native(pc.millisecond(self.native))
def microsecond(self) -> ArrowSeries:
arr = self.native
result = pc.add(pc.multiply(pc.millisecond(arr), lit(1000)), pc.microsecond(arr))
return self.with_native(result)
def nanosecond(self) -> ArrowSeries:
result = pc.add(
pc.multiply(self.microsecond().native, lit(1000)), pc.nanosecond(self.native)
)
return self.with_native(result)
def ordinal_day(self) -> ArrowSeries:
return self.with_native(pc.day_of_year(self.native))
def weekday(self) -> ArrowSeries:
return self.with_native(pc.day_of_week(self.native, count_from_zero=False))
def total_minutes(self) -> ArrowSeries:
unit_to_minutes_factor = {
"s": SECONDS_PER_MINUTE,
"ms": MS_PER_MINUTE,
"us": US_PER_MINUTE,
"ns": NS_PER_MINUTE,
}
factor = lit(unit_to_minutes_factor[self.unit], type=pa.int64())
return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))
def total_seconds(self) -> ArrowSeries:
unit_to_seconds_factor = {
"s": 1,
"ms": MS_PER_SECOND,
"us": US_PER_SECOND,
"ns": NS_PER_SECOND,
}
factor = lit(unit_to_seconds_factor[self.unit], type=pa.int64())
return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))
def total_milliseconds(self) -> ArrowSeries:
unit_to_milli_factor = {
"s": 1e3, # seconds
"ms": 1, # milli
"us": 1e3, # micro
"ns": 1e6, # nano
}
factor = lit(unit_to_milli_factor[self.unit], type=pa.int64())
if self.unit == "s":
return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))
def total_microseconds(self) -> ArrowSeries:
unit_to_micro_factor = {
"s": 1e6, # seconds
"ms": 1e3, # milli
"us": 1, # micro
"ns": 1e3, # nano
}
factor = lit(unit_to_micro_factor[self.unit], type=pa.int64())
if self.unit in {"s", "ms"}:
return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))
def total_nanoseconds(self) -> ArrowSeries:
unit_to_nano_factor = {
"s": NS_PER_SECOND,
"ms": NS_PER_MILLISECOND,
"us": NS_PER_MICROSECOND,
"ns": 1,
}
factor = lit(unit_to_nano_factor[self.unit], type=pa.int64())
return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
def truncate(self, every: str) -> ArrowSeries:
interval = Interval.parse(every)
return self.with_native(
pc.floor_temporal(self.native, interval.multiple, UNITS_DICT[interval.unit])
)
def offset_by(self, by: str) -> ArrowSeries:
interval = Interval.parse_no_constraints(by)
native = self.native
if interval.unit in {"y", "q", "mo"}:
msg = f"Offsetting by {interval.unit} is not yet supported for pyarrow."
raise NotImplementedError(msg)
dtype = self.compliant.dtype
datetime_dtype = self.version.dtypes.Datetime
if interval.unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
offset: pa.DurationScalar[Any] = lit(interval.to_timedelta())
native_naive = pc.local_timestamp(native)
result = pc.assume_timezone(pc.add(native_naive, offset), dtype.time_zone)
return self.with_native(result)
if interval.unit == "ns": # pragma: no cover
offset = lit(interval.multiple, pa.duration("ns")) # type: ignore[assignment]
else:
offset = lit(interval.to_timedelta())
return self.with_native(pc.add(native, offset))
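
Usage note (not part of the commit): an illustrative sketch of a few of the helpers above, assuming they are exposed via the public `Expr.dt` namespace; the timestamps are made up.

from datetime import datetime
import pyarrow as pa
import narwhals as nw

tbl = pa.table({"ts": [datetime(2024, 1, 15, 10, 30), datetime(2024, 3, 2, 23, 59)]})
df = nw.from_native(tbl, eager_only=True)
print(
    df.select(
        nw.col("ts").dt.year().alias("year"),
        nw.col("ts").dt.truncate("1mo").alias("month_start"),
        nw.col("ts").dt.offset_by("2d").alias("shifted"),
    ).to_native()
)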

@@ -0,0 +1,24 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import ArrowSeriesNamespace
from narwhals._utils import not_implemented
if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
class ArrowSeriesListNamespace(ArrowSeriesNamespace):
def len(self) -> ArrowSeries:
return self.with_native(pc.list_value_length(self.native).cast(pa.uint32()))
unique = not_implemented()
contains = not_implemented()
def get(self, index: int) -> ArrowSeries:
return self.with_native(pc.list_element(self.native, index))
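
Usage note (not part of the commit): an illustrative sketch, assuming `len` is exposed via the public `Expr.list` namespace; data is made up.

import pyarrow as pa
import narwhals as nw

df = nw.from_native(pa.table({"lists": [[1, 2, 3], [4], []]}), eager_only=True)
print(df.select(nw.col("lists").list.len().alias("n")).to_native())  # 3, 1, 0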

@@ -0,0 +1,115 @@
from __future__ import annotations
import string
from typing import TYPE_CHECKING
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format
if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import Incomplete
class ArrowSeriesStringNamespace(ArrowSeriesNamespace):
def len_chars(self) -> ArrowSeries:
return self.with_native(pc.utf8_length(self.native))
def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSeries:
fn = pc.replace_substring if literal else pc.replace_substring_regex
try:
arr = fn(self.native, pattern, replacement=value, max_replacements=n)
except TypeError as e:
if not isinstance(value, str):
msg = "PyArrow backed `.str.replace` only supports str replacement values"
raise TypeError(msg) from e
raise
return self.with_native(arr)
def replace_all(self, pattern: str, value: str, *, literal: bool) -> ArrowSeries:
try:
return self.replace(pattern, value, literal=literal, n=-1)
except TypeError as e:
if not isinstance(value, str):
msg = "PyArrow backed `.str.replace_all` only supports str replacement values."
raise TypeError(msg) from e
raise
def strip_chars(self, characters: str | None) -> ArrowSeries:
return self.with_native(
pc.utf8_trim(self.native, characters or string.whitespace)
)
def starts_with(self, prefix: str) -> ArrowSeries:
return self.with_native(pc.equal(self.slice(0, len(prefix)).native, lit(prefix)))
def ends_with(self, suffix: str) -> ArrowSeries:
return self.with_native(
pc.equal(self.slice(-len(suffix), None).native, lit(suffix))
)
def contains(self, pattern: str, *, literal: bool) -> ArrowSeries:
check_func = pc.match_substring if literal else pc.match_substring_regex
return self.with_native(check_func(self.native, pattern))
def slice(self, offset: int, length: int | None) -> ArrowSeries:
stop = offset + length if length is not None else None
return self.with_native(
pc.utf8_slice_codeunits(self.native, start=offset, stop=stop)
)
def split(self, by: str) -> ArrowSeries:
split_series = pc.split_pattern(self.native, by) # type: ignore[call-overload]
return self.with_native(split_series)
def to_datetime(self, format: str | None) -> ArrowSeries:
format = parse_datetime_format(self.native) if format is None else format
timestamp_array = pc.strptime(self.native, format=format, unit="us")
return self.with_native(timestamp_array)
def to_date(self, format: str | None) -> ArrowSeries:
return self.to_datetime(format=format).dt.date()
def to_uppercase(self) -> ArrowSeries:
return self.with_native(pc.utf8_upper(self.native))
def to_lowercase(self) -> ArrowSeries:
return self.with_native(pc.utf8_lower(self.native))
def zfill(self, width: int) -> ArrowSeries:
binary_join: Incomplete = pc.binary_join_element_wise
native = self.native
hyphen, plus = lit("-"), lit("+")
first_char, remaining_chars = (
self.slice(0, 1).native,
self.slice(1, None).native,
)
# Conditions
less_than_width = pc.less(pc.utf8_length(native), lit(width))
starts_with_hyphen = pc.equal(first_char, hyphen)
starts_with_plus = pc.equal(first_char, plus)
conditions = pc.make_struct(
pc.and_(starts_with_hyphen, less_than_width),
pc.and_(starts_with_plus, less_than_width),
less_than_width,
)
# Cases
padded_remaining_chars = pc.utf8_lpad(remaining_chars, width - 1, padding="0")
result = pc.case_when(
conditions,
binary_join(
pa.repeat(hyphen, len(native)), padded_remaining_chars, ""
), # starts with hyphen and less than width
binary_join(
pa.repeat(plus, len(native)), padded_remaining_chars, ""
), # starts with plus and less than width
pc.utf8_lpad(native, width=width, padding="0"), # less than width
native,
)
return self.with_native(result)
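# Worked example (illustrative, not part of the upstream file): with width=5,
# "42" -> "00042" (plain left zero-padding), "-42" -> "-0042" and "+42" -> "+0042"
# (sign kept, remaining characters padded to width - 1), while "123456" is
# already at least `width` characters long and is returned unchanged.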

View File

@ -0,0 +1,15 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import pyarrow.compute as pc
from narwhals._arrow.utils import ArrowSeriesNamespace
if TYPE_CHECKING:
from narwhals._arrow.series import ArrowSeries
class ArrowSeriesStructNamespace(ArrowSeriesNamespace):
def field(self, name: str) -> ArrowSeries:
return self.with_native(pc.struct_field(self.native, name)).alias(name)

View File

@ -0,0 +1,72 @@
from __future__ import annotations # pragma: no cover
from typing import (
TYPE_CHECKING, # pragma: no cover
Any, # pragma: no cover
TypeVar, # pragma: no cover
)
if TYPE_CHECKING:
import sys
from typing import Generic, Literal
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
import pyarrow as pa
from pyarrow.__lib_pxi.table import (
AggregateOptions, # noqa: F401
Aggregation, # noqa: F401
)
from pyarrow._stubs_typing import (  # pyright: ignore[reportMissingModuleSource]
Indices, # noqa: F401
Mask, # noqa: F401
Order, # noqa: F401
)
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.series import ArrowSeries
IntoArrowExpr: TypeAlias = "ArrowExpr | ArrowSeries"
TieBreaker: TypeAlias = Literal["min", "max", "first", "dense"]
NullPlacement: TypeAlias = Literal["at_start", "at_end"]
NativeIntervalUnit: TypeAlias = Literal[
"year",
"quarter",
"month",
"week",
"day",
"hour",
"minute",
"second",
"millisecond",
"microsecond",
"nanosecond",
]
ChunkedArrayAny: TypeAlias = pa.ChunkedArray[Any]
ArrayAny: TypeAlias = pa.Array[Any]
ArrayOrChunkedArray: TypeAlias = "ArrayAny | ChunkedArrayAny"
ScalarAny: TypeAlias = pa.Scalar[Any]
ArrayOrScalar: TypeAlias = "ArrayOrChunkedArray | ScalarAny"
ArrayOrScalarT1 = TypeVar("ArrayOrScalarT1", ArrayAny, ChunkedArrayAny, ScalarAny)
ArrayOrScalarT2 = TypeVar("ArrayOrScalarT2", ArrayAny, ChunkedArrayAny, ScalarAny)
_AsPyType = TypeVar("_AsPyType")
class _BasicDataType(pa.DataType, Generic[_AsPyType]): ...
Incomplete: TypeAlias = Any # pragma: no cover
"""
Marker for working code that fails on the stubs.
Common issues:
- Annotated for `Array`, but not `ChunkedArray`
- Relies on typing information that the stubs don't provide statically
- Missing attributes
- Incorrect return types
- Inconsistent use of generic/concrete types
- `_clone_signature` used on signatures that are not identical
"""

View File

@ -0,0 +1,438 @@
from __future__ import annotations
from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast
import pyarrow as pa
import pyarrow.compute as pc
from narwhals._compliant import EagerSeriesNamespace
from narwhals._utils import Version, isinstance_or_issubclass
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
from typing_extensions import TypeAlias, TypeIs
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.typing import (
ArrayAny,
ArrayOrScalar,
ArrayOrScalarT1,
ArrayOrScalarT2,
ChunkedArrayAny,
NativeIntervalUnit,
ScalarAny,
)
from narwhals._duration import IntervalUnit
from narwhals.dtypes import DType
from narwhals.typing import IntoDType, PythonLiteral
# NOTE: stubs don't allow for `ChunkedArray[StructArray]`
# Intended to represent the `.chunks` property storing `list[pa.StructArray]`
ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny
def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
def extract_regex(
strings: ChunkedArrayAny,
/,
pattern: str,
*,
options: Any = None,
memory_pool: Any = None,
) -> ChunkedArrayStructArray: ...
else:
from pyarrow.compute import extract_regex
from pyarrow.types import (
is_dictionary, # noqa: F401
is_duration,
is_fixed_size_list,
is_large_list,
is_list,
is_timestamp,
)
UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
"y": "year",
"q": "quarter",
"mo": "month",
"d": "day",
"h": "hour",
"m": "minute",
"s": "second",
"ms": "millisecond",
"us": "microsecond",
"ns": "nanosecond",
}
lit = pa.scalar
"""Alias for `pyarrow.scalar`."""
def extract_py_scalar(value: Any, /) -> Any:
from narwhals._arrow.series import maybe_extract_py_scalar
return maybe_extract_py_scalar(value, return_py_scalar=True)
def is_array_or_scalar(obj: Any) -> TypeIs[ArrayOrScalar]:
"""Return True for any base `pyarrow` container."""
return isinstance(obj, (pa.ChunkedArray, pa.Array, pa.Scalar))
def chunked_array(
arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
) -> ChunkedArrayAny:
if isinstance(arr, pa.ChunkedArray):
return arr
if isinstance(arr, list):
return pa.chunked_array(arr, dtype)
return pa.chunked_array([arr], dtype)
def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
"""Create a strongly-typed Array instance with all elements null.
Uses the type of `series`, without upsetting `mypy`.
"""
return pa.nulls(n, series.native.type)
def zeros(n: int, /) -> pa.Int64Array:
return pa.repeat(0, n)
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType: # noqa: C901, PLR0912
dtypes = version.dtypes
if pa.types.is_int64(dtype):
return dtypes.Int64()
if pa.types.is_int32(dtype):
return dtypes.Int32()
if pa.types.is_int16(dtype):
return dtypes.Int16()
if pa.types.is_int8(dtype):
return dtypes.Int8()
if pa.types.is_uint64(dtype):
return dtypes.UInt64()
if pa.types.is_uint32(dtype):
return dtypes.UInt32()
if pa.types.is_uint16(dtype):
return dtypes.UInt16()
if pa.types.is_uint8(dtype):
return dtypes.UInt8()
if pa.types.is_boolean(dtype):
return dtypes.Boolean()
if pa.types.is_float64(dtype):
return dtypes.Float64()
if pa.types.is_float32(dtype):
return dtypes.Float32()
# bug in coverage? it shows `31->exit` (where `31` is currently the line number of
# the next line), even though both when the if condition is true and false are covered
if ( # pragma: no cover
pa.types.is_string(dtype)
or pa.types.is_large_string(dtype)
or getattr(pa.types, "is_string_view", lambda _: False)(dtype)
):
return dtypes.String()
if pa.types.is_date32(dtype):
return dtypes.Date()
if is_timestamp(dtype):
return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
if is_duration(dtype):
return dtypes.Duration(time_unit=dtype.unit)
if pa.types.is_dictionary(dtype):
return dtypes.Categorical()
if pa.types.is_struct(dtype):
return dtypes.Struct(
[
dtypes.Field(
dtype.field(i).name,
native_to_narwhals_dtype(dtype.field(i).type, version),
)
for i in range(dtype.num_fields)
]
)
if is_list(dtype) or is_large_list(dtype):
return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version))
if is_fixed_size_list(dtype):
return dtypes.Array(
native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size
)
if pa.types.is_decimal(dtype):
return dtypes.Decimal()
if pa.types.is_time32(dtype) or pa.types.is_time64(dtype):
return dtypes.Time()
if pa.types.is_binary(dtype):
return dtypes.Binary()
return dtypes.Unknown() # pragma: no cover
dtypes = Version.MAIN.dtypes
NW_TO_PA_DTYPES: Mapping[type[DType], pa.DataType] = {
dtypes.Float64: pa.float64(),
dtypes.Float32: pa.float32(),
dtypes.Binary: pa.binary(),
dtypes.String: pa.string(),
dtypes.Boolean: pa.bool_(),
dtypes.Categorical: pa.dictionary(pa.uint32(), pa.string()),
dtypes.Date: pa.date32(),
dtypes.Time: pa.time64("ns"),
dtypes.Int8: pa.int8(),
dtypes.Int16: pa.int16(),
dtypes.Int32: pa.int32(),
dtypes.Int64: pa.int64(),
dtypes.UInt8: pa.uint8(),
dtypes.UInt16: pa.uint16(),
dtypes.UInt32: pa.uint32(),
dtypes.UInt64: pa.uint64(),
}
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Object)
def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:
dtypes = version.dtypes
base_type = dtype.base_type()
if pa_type := NW_TO_PA_DTYPES.get(base_type):
return pa_type
if isinstance_or_issubclass(dtype, dtypes.Datetime):
unit = dtype.time_unit
return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
if isinstance_or_issubclass(dtype, dtypes.Duration):
return pa.duration(dtype.time_unit)
if isinstance_or_issubclass(dtype, dtypes.List):
return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version))
if isinstance_or_issubclass(dtype, dtypes.Struct):
return pa.struct(
[
(field.name, narwhals_to_native_dtype(field.dtype, version=version))
for field in dtype.fields
]
)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
inner = narwhals_to_native_dtype(dtype.inner, version=version)
list_size = dtype.size
return pa.list_(inner, list_size=list_size)
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for PyArrow."
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
def extract_native(
lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny
) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]:
"""Extract native objects in binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
"right-hand-side" operation (e.g. `__radd__`) can be tried.
If one of the two sides has a `_broadcast` flag, then extract the scalar
underneath it so that PyArrow can do its own broadcasting.
"""
from narwhals._arrow.series import ArrowSeries
if rhs is None: # pragma: no cover
return lhs.native, lit(None, type=lhs._type)
if isinstance(rhs, ArrowSeries):
if lhs._broadcast and not rhs._broadcast:
return lhs.native[0], rhs.native
if rhs._broadcast:
return lhs.native, rhs.native[0]
return lhs.native, rhs.native
if isinstance(rhs, list):
msg = "Expected Series or scalar, got list."
raise TypeError(msg)
return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs)
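# Sketch of the broadcasting rule (illustrative, not part of the upstream file):
# when `lhs` is a length-1 "broadcast" series (e.g. an aggregation result) and
# `rhs` is a full column, `lhs.native[0]` (a scalar) is returned next to
# `rhs.native` (a ChunkedArray), letting PyArrow kernels broadcast the scalar.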
def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any:
# The following lines are adapted from pandas' pyarrow implementation.
# Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154
if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
divided = pc.divide_checked(left, right)
# TODO @dangotbanned: Use a `TypeVar` in guards
# Narrowing to a `Union` isn't interacting well with the rest of the stubs
# https://github.com/zen-xu/pyarrow-stubs/pull/215
if pa.types.is_signed_integer(divided.type):
div_type = cast("pa._lib.Int64Type", divided.type)
has_remainder = pc.not_equal(pc.multiply(divided, right), left)
has_one_negative_operand = pc.less(
pc.bit_wise_xor(left, right), lit(0, div_type)
)
result = pc.if_else(
pc.and_(has_remainder, has_one_negative_operand),
pc.subtract(divided, lit(1, div_type)),
divided,
)
else:
result = divided # pragma: no cover
result = result.cast(left.type)
else:
divided = pc.divide(left, right)
result = pc.floor(divided)
return result
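# Worked example (illustrative, not part of the upstream file): for integer
# inputs left=-7, right=2, pc.divide_checked truncates towards zero and yields
# -3; there is a remainder and exactly one negative operand, so 1 is
# subtracted, giving -4 and matching Python's floor division (-7 // 2 == -4).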
def cast_for_truediv(
arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2
) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]:
# Lifted from:
# https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
# Ensure int / int -> float mirroring Python/Numpy behavior
# as pc.divide_checked(int, int) -> int
if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type):
# GH: 56645. # noqa: ERA001
# https://github.com/apache/arrow/issues/35563
return arrow_array.cast(pa.float64(), safe=False), pa_object.cast(
pa.float64(), safe=False
)
return arrow_array, pa_object
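# Illustrative effect (not part of the upstream file): without this cast,
# dividing the integer scalars 1 and 2 with pc.divide stays in integer
# arithmetic and yields 0; casting both sides to float64 first makes the
# true division evaluate to 0.5, mirroring Python and NumPy behaviour.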
# Regex for date, time, separator and timezone components
DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})"
SEP_RE = r"(?P<sep>\s|T)"
TIME_RE = r"(?P<time>\d{2}:\d{2}(?::\d{2})?|\d{6}?)" # \s*(?P<period>[AP]M)?)?
HMS_RE = r"^(?P<hms>\d{2}:\d{2}:\d{2})$"
HM_RE = r"^(?P<hm>\d{2}:\d{2})$"
HMS_RE_NO_SEP = r"^(?P<hms_no_sep>\d{6})$"
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})" # Matches 'Z', '+02:00', '+0200', '+02', etc.
FULL_RE = rf"{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"
# Separate regexes for different date formats
YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
YMD_RE_NO_SEP = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<month>0[1-9]|1[0-2])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DATE_FORMATS = (
(YMD_RE_NO_SEP, "%Y%m%d"),
(YMD_RE, "%Y-%m-%d"),
(DMY_RE, "%d-%m-%Y"),
(MDY_RE, "%m-%d-%Y"),
)
TIME_FORMATS = ((HMS_RE, "%H:%M:%S"), (HM_RE, "%H:%M"), (HMS_RE_NO_SEP, "%H%M%S"))
def _extract_regex_concat_arrays(
strings: ChunkedArrayAny,
/,
pattern: str,
*,
options: Any = None,
memory_pool: Any = None,
) -> pa.StructArray:
r = pa.concat_arrays(
extract_regex(strings, pattern, options=options, memory_pool=memory_pool).chunks
)
return cast("pa.StructArray", r)
def parse_datetime_format(arr: ChunkedArrayAny) -> str:
"""Try to infer datetime format from StringArray."""
matches = _extract_regex_concat_arrays(arr.drop_null().slice(0, 10), pattern=FULL_RE)
if not pc.all(matches.is_valid()).as_py():
msg = (
"Unable to infer datetime format, provided format is not supported. "
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
)
raise NotImplementedError(msg)
separators = matches.field("sep")
tz = matches.field("tz")
# separators and time zones must be unique
if pc.count(pc.unique(separators)).as_py() > 1:
msg = "Found multiple separator values while inferring datetime format."
raise ValueError(msg)
if pc.count(pc.unique(tz)).as_py() > 1:
msg = "Found multiple timezone values while inferring datetime format."
raise ValueError(msg)
date_value = _parse_date_format(cast("pc.StringArray", matches.field("date")))
time_value = _parse_time_format(cast("pc.StringArray", matches.field("time")))
sep_value = separators[0].as_py()
tz_value = "%z" if tz[0].as_py() else ""
return f"{date_value}{sep_value}{time_value}{tz_value}"
def _parse_date_format(arr: pc.StringArray) -> str:
for date_rgx, date_fmt in DATE_FORMATS:
matches = pc.extract_regex(arr, pattern=date_rgx)
if date_fmt == "%Y%m%d" and pc.all(matches.is_valid()).as_py():
return date_fmt
if (
pc.all(matches.is_valid()).as_py()
and pc.count(pc.unique(sep1 := matches.field("sep1"))).as_py() == 1
and pc.count(pc.unique(sep2 := matches.field("sep2"))).as_py() == 1
and (date_sep_value := sep1[0].as_py()) == sep2[0].as_py()
):
return date_fmt.replace("-", date_sep_value)
msg = (
"Unable to infer datetime format. "
"Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
)
raise ValueError(msg)
def _parse_time_format(arr: pc.StringArray) -> str:
for time_rgx, time_fmt in TIME_FORMATS:
matches = pc.extract_regex(arr, pattern=time_rgx)
if pc.all(matches.is_valid()).as_py():
return time_fmt
return ""
def pad_series(
series: ArrowSeries, *, window_size: int, center: bool
) -> tuple[ArrowSeries, int]:
"""Pad series with None values on the left and/or right side, depending on the specified parameters.
Arguments:
series: The input ArrowSeries to be padded.
window_size: The desired size of the window.
center: Specifies whether to center the padding or not.
Returns:
A tuple containing the padded ArrowSeries and the offset value.
"""
if not center:
return series, 0
offset_left = window_size // 2
# subtract one if window_size is even
offset_right = offset_left - (window_size % 2 == 0)
pad_left = pa.array([None] * offset_left, type=series._type)
pad_right = pa.array([None] * offset_right, type=series._type)
concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right])
return series._with_native(concat), offset_left + offset_right
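# Worked example (illustrative, not part of the upstream file): with
# window_size=3 and center=True, one null is prepended and one appended
# (offset 2); with window_size=4 the right side receives one null fewer
# (2 left, 1 right, offset 3) because an even window has no exact centre.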
def cast_to_comparable_string_types(
*chunked_arrays: ChunkedArrayAny, separator: str
) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
# Ensure `chunked_arrays` are either all `string` or all `large_string`.
dtype = (
pa.string() # (PyArrow default)
if not any(pa.types.is_large_string(ca.type) for ca in chunked_arrays)
else pa.large_string()
)
return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)
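# Illustrative note (not part of the upstream file): Arrow string kernels
# generally require their inputs to share a type, so if any chunked array is
# `large_string` everything (including the separator scalar) is cast to
# `large_string`; otherwise the PyArrow default `string` is used throughout.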
class ArrowSeriesNamespace(EagerSeriesNamespace["ArrowSeries", "ChunkedArrayAny"]): ...

View File

@ -0,0 +1,103 @@
from __future__ import annotations
from narwhals._compliant.dataframe import (
CompliantDataFrame,
CompliantFrame,
CompliantLazyFrame,
EagerDataFrame,
)
from narwhals._compliant.expr import (
CompliantExpr,
DepthTrackingExpr,
EagerExpr,
LazyExpr,
LazyExprNamespace,
)
from narwhals._compliant.group_by import (
CompliantGroupBy,
DepthTrackingGroupBy,
EagerGroupBy,
)
from narwhals._compliant.namespace import (
CompliantNamespace,
DepthTrackingNamespace,
EagerNamespace,
LazyNamespace,
)
from narwhals._compliant.selectors import (
CompliantSelector,
CompliantSelectorNamespace,
EagerSelectorNamespace,
LazySelectorNamespace,
)
from narwhals._compliant.series import (
CompliantSeries,
EagerSeries,
EagerSeriesCatNamespace,
EagerSeriesDateTimeNamespace,
EagerSeriesHist,
EagerSeriesListNamespace,
EagerSeriesNamespace,
EagerSeriesStringNamespace,
EagerSeriesStructNamespace,
)
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantSeriesOrNativeExprT_co,
CompliantSeriesT,
EagerDataFrameT,
EagerSeriesT,
EvalNames,
EvalSeries,
NativeFrameT_co,
NativeSeriesT_co,
)
from narwhals._compliant.when_then import CompliantThen, CompliantWhen, EagerWhen
from narwhals._compliant.window import WindowInputs
__all__ = [
"CompliantDataFrame",
"CompliantExpr",
"CompliantExprT",
"CompliantFrame",
"CompliantFrameT",
"CompliantGroupBy",
"CompliantLazyFrame",
"CompliantNamespace",
"CompliantSelector",
"CompliantSelectorNamespace",
"CompliantSeries",
"CompliantSeriesOrNativeExprT_co",
"CompliantSeriesT",
"CompliantThen",
"CompliantWhen",
"DepthTrackingExpr",
"DepthTrackingGroupBy",
"DepthTrackingNamespace",
"EagerDataFrame",
"EagerDataFrameT",
"EagerExpr",
"EagerGroupBy",
"EagerNamespace",
"EagerSelectorNamespace",
"EagerSeries",
"EagerSeriesCatNamespace",
"EagerSeriesDateTimeNamespace",
"EagerSeriesHist",
"EagerSeriesListNamespace",
"EagerSeriesNamespace",
"EagerSeriesStringNamespace",
"EagerSeriesStructNamespace",
"EagerSeriesT",
"EagerWhen",
"EvalNames",
"EvalSeries",
"LazyExpr",
"LazyExprNamespace",
"LazyNamespace",
"LazySelectorNamespace",
"NativeFrameT_co",
"NativeSeriesT_co",
"WindowInputs",
]

View File

@ -0,0 +1,94 @@
"""`Expr` and `Series` namespace accessor protocols."""
from __future__ import annotations
from typing import TYPE_CHECKING, Protocol
from narwhals._utils import CompliantT_co, _StoresCompliant
if TYPE_CHECKING:
from typing import Callable
from narwhals.typing import NonNestedLiteral, TimeUnit
__all__ = [
"CatNamespace",
"DateTimeNamespace",
"ListNamespace",
"NameNamespace",
"StringNamespace",
"StructNamespace",
]
class CatNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def get_categories(self) -> CompliantT_co: ...
class DateTimeNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def to_string(self, format: str) -> CompliantT_co: ...
def replace_time_zone(self, time_zone: str | None) -> CompliantT_co: ...
def convert_time_zone(self, time_zone: str) -> CompliantT_co: ...
def timestamp(self, time_unit: TimeUnit) -> CompliantT_co: ...
def date(self) -> CompliantT_co: ...
def year(self) -> CompliantT_co: ...
def month(self) -> CompliantT_co: ...
def day(self) -> CompliantT_co: ...
def hour(self) -> CompliantT_co: ...
def minute(self) -> CompliantT_co: ...
def second(self) -> CompliantT_co: ...
def millisecond(self) -> CompliantT_co: ...
def microsecond(self) -> CompliantT_co: ...
def nanosecond(self) -> CompliantT_co: ...
def ordinal_day(self) -> CompliantT_co: ...
def weekday(self) -> CompliantT_co: ...
def total_minutes(self) -> CompliantT_co: ...
def total_seconds(self) -> CompliantT_co: ...
def total_milliseconds(self) -> CompliantT_co: ...
def total_microseconds(self) -> CompliantT_co: ...
def total_nanoseconds(self) -> CompliantT_co: ...
def truncate(self, every: str) -> CompliantT_co: ...
def offset_by(self, by: str) -> CompliantT_co: ...
class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def get(self, index: int) -> CompliantT_co: ...
def len(self) -> CompliantT_co: ...
def unique(self) -> CompliantT_co: ...
def contains(self, item: NonNestedLiteral) -> CompliantT_co: ...
class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def keep(self) -> CompliantT_co: ...
def map(self, function: Callable[[str], str]) -> CompliantT_co: ...
def prefix(self, prefix: str) -> CompliantT_co: ...
def suffix(self, suffix: str) -> CompliantT_co: ...
def to_lowercase(self) -> CompliantT_co: ...
def to_uppercase(self) -> CompliantT_co: ...
class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def len_chars(self) -> CompliantT_co: ...
def replace(
self, pattern: str, value: str, *, literal: bool, n: int
) -> CompliantT_co: ...
def replace_all(
self, pattern: str, value: str, *, literal: bool
) -> CompliantT_co: ...
def strip_chars(self, characters: str | None) -> CompliantT_co: ...
def starts_with(self, prefix: str) -> CompliantT_co: ...
def ends_with(self, suffix: str) -> CompliantT_co: ...
def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ...
def slice(self, offset: int, length: int | None) -> CompliantT_co: ...
def split(self, by: str) -> CompliantT_co: ...
def to_datetime(self, format: str | None) -> CompliantT_co: ...
def to_date(self, format: str | None) -> CompliantT_co: ...
def to_lowercase(self) -> CompliantT_co: ...
def to_uppercase(self) -> CompliantT_co: ...
def zfill(self, width: int) -> CompliantT_co: ...
class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
def field(self, name: str) -> CompliantT_co: ...

View File

@ -0,0 +1,213 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing_extensions import Self
from narwhals._compliant.any_namespace import (
CatNamespace,
DateTimeNamespace,
ListNamespace,
StringNamespace,
StructNamespace,
)
from narwhals._compliant.namespace import CompliantNamespace
from narwhals._utils import Version
from narwhals.typing import (
ClosedInterval,
FillNullStrategy,
IntoDType,
ModeKeepStrategy,
NonNestedLiteral,
NumericLiteral,
RankMethod,
TemporalLiteral,
)
__all__ = ["CompliantColumn"]
class CompliantColumn(Protocol):
"""Common parts of `Expr`, `Series`."""
_version: Version
def __add__(self, other: Any) -> Self: ...
def __and__(self, other: Any) -> Self: ...
def __eq__(self, other: object) -> Self: ... # type: ignore[override]
def __floordiv__(self, other: Any) -> Self: ...
def __ge__(self, other: Any) -> Self: ...
def __gt__(self, other: Any) -> Self: ...
def __invert__(self) -> Self: ...
def __le__(self, other: Any) -> Self: ...
def __lt__(self, other: Any) -> Self: ...
def __mod__(self, other: Any) -> Self: ...
def __mul__(self, other: Any) -> Self: ...
def __ne__(self, other: object) -> Self: ... # type: ignore[override]
def __or__(self, other: Any) -> Self: ...
def __pow__(self, other: Any) -> Self: ...
def __rfloordiv__(self, other: Any) -> Self: ...
def __rmod__(self, other: Any) -> Self: ...
def __rpow__(self, other: Any) -> Self: ...
def __rsub__(self, other: Any) -> Self: ...
def __rtruediv__(self, other: Any) -> Self: ...
def __sub__(self, other: Any) -> Self: ...
def __truediv__(self, other: Any) -> Self: ...
def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ...
def abs(self) -> Self: ...
def alias(self, name: str) -> Self: ...
def cast(self, dtype: IntoDType) -> Self: ...
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self: ...
def cum_count(self, *, reverse: bool) -> Self: ...
def cum_max(self, *, reverse: bool) -> Self: ...
def cum_min(self, *, reverse: bool) -> Self: ...
def cum_prod(self, *, reverse: bool) -> Self: ...
def cum_sum(self, *, reverse: bool) -> Self: ...
def diff(self) -> Self: ...
def drop_nulls(self) -> Self: ...
def ewm_mean(
self,
*,
com: float | None,
span: float | None,
half_life: float | None,
alpha: float | None,
adjust: bool,
min_samples: int,
ignore_nulls: bool,
) -> Self: ...
def exp(self) -> Self: ...
def sqrt(self) -> Self: ...
def fill_nan(self, value: float | None) -> Self: ...
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self: ...
def is_between(
self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval
) -> Self:
if closed == "left":
return (self >= lower_bound) & (self < upper_bound)
if closed == "right":
return (self > lower_bound) & (self <= upper_bound)
if closed == "none":
return (self > lower_bound) & (self < upper_bound)
return (self >= lower_bound) & (self <= upper_bound)
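# Illustrative mapping (not part of the upstream file): with lower_bound=3 and
# upper_bound=5, closed="left" keeps 3 <= x < 5, closed="right" keeps
# 3 < x <= 5, closed="none" keeps 3 < x < 5, and the default ("both") keeps
# 3 <= x <= 5.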
def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
from decimal import Decimal
other_abs: Self | NumericLiteral
other_is_nan: Self | bool
other_is_inf: Self | bool
other_is_not_inf: Self | bool
if isinstance(other, (float, int, Decimal)):
from math import isinf, isnan
# NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447
other_abs = other.__abs__()
other_is_nan = isnan(other)
other_is_inf = isinf(other)
# Define the other_is_not_inf variable to prevent triggering the following warning:
# > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be
# > removed in Python 3.16.
other_is_not_inf = not other_is_inf
else:
other_abs, other_is_nan = other.abs(), other.is_nan()
other_is_not_inf = other.is_finite() | other_is_nan
other_is_inf = ~other_is_not_inf
rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol
tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None)
self_is_nan = self.is_nan()
self_is_not_inf = self.is_finite() | self_is_nan
# Values are close if abs_diff <= tolerance, and both finite
is_close = (
((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf
)
# Handle infinity cases: infinities are close/equal if they have the same sign
self_sign, other_sign = self > 0, other > 0
is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign)
# Handle nan cases:
# * If any value is NaN, then False (via `& ~either_nan`)
# * However, if `nans_equals = True` and if _both_ values are NaN, then True
either_nan = self_is_nan | other_is_nan
result = (is_close | is_same_inf) & ~either_nan
if nans_equal:
both_nan = self_is_nan & other_is_nan
result = result | both_nan
return result
def is_duplicated(self) -> Self:
return ~self.is_unique()
def is_finite(self) -> Self: ...
def is_first_distinct(self) -> Self: ...
def is_in(self, other: Any) -> Self: ...
def is_last_distinct(self) -> Self: ...
def is_nan(self) -> Self: ...
def is_null(self) -> Self: ...
def is_unique(self) -> Self: ...
def log(self, base: float) -> Self: ...
def mode(self, *, keep: ModeKeepStrategy) -> Self: ...
def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
def replace_strict(
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: IntoDType | None,
) -> Self: ...
def rolling_mean(
self, window_size: int, *, min_samples: int, center: bool
) -> Self: ...
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self: ...
def rolling_sum(
self, window_size: int, *, min_samples: int, center: bool
) -> Self: ...
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self: ...
def round(self, decimals: int) -> Self: ...
def shift(self, n: int) -> Self: ...
def unique(self) -> Self: ...
@property
def str(self) -> StringNamespace[Self]: ...
@property
def dt(self) -> DateTimeNamespace[Self]: ...
@property
def cat(self) -> CatNamespace[Self]: ...
@property
def list(self) -> ListNamespace[Self]: ...
@property
def struct(self) -> StructNamespace[Self]: ...

View File

@ -0,0 +1,426 @@
from __future__ import annotations
from collections.abc import Iterator, Mapping, Sequence, Sized
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantExprT_contra,
CompliantLazyFrameAny,
CompliantSeriesT,
EagerExprT,
EagerSeriesT,
NativeDataFrameT,
NativeLazyFrameT,
NativeSeriesT,
)
from narwhals._translate import (
ArrowConvertible,
DictConvertible,
FromNative,
NumpyConvertible,
ToNarwhals,
ToNarwhalsT_co,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import (
ValidateBackendVersion,
Version,
_StoresNative,
check_columns_exist,
is_compliant_series,
is_index_selector,
is_range,
is_sequence_like,
is_sized_multi_index_selector,
is_slice_index,
is_slice_none,
)
if TYPE_CHECKING:
from io import BytesIO
from pathlib import Path
from types import ModuleType
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import Self, TypeAlias
from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
from narwhals._compliant.namespace import EagerNamespace
from narwhals._spark_like.utils import SparkSession
from narwhals._translate import IntoArrowTable
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals._utils import Implementation, _LimitedContext
from narwhals.dataframe import DataFrame
from narwhals.dtypes import DType
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import (
AsofJoinStrategy,
IntoSchema,
JoinStrategy,
LazyUniqueKeepStrategy,
MultiColSelector,
MultiIndexSelector,
PivotAgg,
SingleIndexSelector,
SizedMultiIndexSelector,
SizedMultiNameSelector,
SizeUnit,
UniqueKeepStrategy,
_2DArray,
_SliceIndex,
_SliceName,
)
Incomplete: TypeAlias = Any
__all__ = ["CompliantDataFrame", "CompliantFrame", "CompliantLazyFrame", "EagerDataFrame"]
T = TypeVar("T")
_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]" # noqa: PYI047
_NativeFrameT = TypeVar("_NativeFrameT")
class CompliantFrame(
_StoresNative[_NativeFrameT],
FromNative[_NativeFrameT],
ToNarwhals[ToNarwhalsT_co],
Protocol[CompliantExprT_contra, _NativeFrameT, ToNarwhalsT_co],
):
"""Common parts of `DataFrame`, `LazyFrame`."""
_native_frame: _NativeFrameT
_implementation: Implementation
_version: Version
def __native_namespace__(self) -> ModuleType: ...
def __narwhals_namespace__(self) -> Any: ...
def _with_version(self, version: Version) -> Self: ...
@classmethod
def from_native(cls, data: _NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
@property
def columns(self) -> Sequence[str]: ...
@property
def native(self) -> _NativeFrameT:
return self._native_frame
@property
def schema(self) -> Mapping[str, DType]: ...
def collect_schema(self) -> Mapping[str, DType]: ...
def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
def explode(self, columns: Sequence[str]) -> Self: ...
def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
def group_by(
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
def head(self, n: int) -> Self: ...
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self: ...
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self: ...
def rename(self, mapping: Mapping[str, str]) -> Self: ...
def select(self, *exprs: CompliantExprT_contra) -> Self: ...
def simple_select(self, *column_names: str) -> Self:
"""`select` where all args are column names."""
...
def sort(
self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
) -> Self: ...
def tail(self, n: int) -> Self: ...
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self: ...
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self: ...
def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self: ...
class CompliantDataFrame(
NumpyConvertible["_2DArray", "_2DArray"],
DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
ArrowConvertible["pa.Table", "IntoArrowTable"],
Sized,
CompliantFrame[CompliantExprT_contra, NativeDataFrameT, ToNarwhalsT_co],
Protocol[CompliantSeriesT, CompliantExprT_contra, NativeDataFrameT, ToNarwhalsT_co],
):
def __narwhals_dataframe__(self) -> Self: ...
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_dict(
cls,
data: Mapping[str, Any],
/,
*,
context: _LimitedContext,
schema: IntoSchema | None,
) -> Self: ...
@classmethod
def from_numpy(
cls,
data: _2DArray,
/,
*,
context: _LimitedContext,
schema: IntoSchema | Sequence[str] | None,
) -> Self: ...
def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
def __getitem__(
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
MultiColSelector[CompliantSeriesT],
],
) -> Self: ...
@property
def shape(self) -> tuple[int, int]: ...
def clone(self) -> Self: ...
def estimated_size(self, unit: SizeUnit) -> int | float: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def get_column(self, name: str) -> CompliantSeriesT: ...
def group_by(
self,
keys: Sequence[str] | Sequence[CompliantExprT_contra],
*,
drop_null_keys: bool,
) -> DataFrameGroupBy[Self, Any]: ...
def item(self, row: int | None, column: int | str | None) -> Any: ...
def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
def iter_rows(
self, *, named: bool, buffer_size: int
) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
def is_unique(self) -> CompliantSeriesT: ...
def lazy(
self, backend: _LazyAllowedImpl | None, *, session: SparkSession | None
) -> CompliantLazyFrameAny: ...
def pivot(
self,
on: Sequence[str],
*,
index: Sequence[str] | None,
values: Sequence[str] | None,
aggregate_function: PivotAgg | None,
sort_columns: bool,
separator: str,
) -> Self: ...
def row(self, index: int) -> tuple[Any, ...]: ...
def rows(
self, *, named: bool
) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self: ...
def to_arrow(self) -> pa.Table: ...
def to_pandas(self) -> pd.DataFrame: ...
def to_polars(self) -> pl.DataFrame: ...
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
def unique(
self,
subset: Sequence[str] | None,
*,
keep: UniqueKeepStrategy,
maintain_order: bool | None = None,
) -> Self: ...
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
@overload
def write_csv(self, file: None) -> str: ...
@overload
def write_csv(self, file: str | Path | BytesIO) -> None: ...
def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
def write_parquet(self, file: str | Path | BytesIO) -> None: ...
class CompliantLazyFrame(
CompliantFrame[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
Protocol[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
):
def __narwhals_lazyframe__(self) -> Self: ...
# `LazySelectorNamespace._iter_columns` depends on this method
def _iter_columns(self) -> Iterator[Any]: ...
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
"""`select` where all args are aggregations or literals.
(so, no broadcasting is necessary).
"""
...
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny: ...
def sink_parquet(self, file: str | Path | BytesIO) -> None: ...
class EagerDataFrame(
CompliantDataFrame[
EagerSeriesT, EagerExprT, NativeDataFrameT, "DataFrame[NativeDataFrameT]"
],
CompliantLazyFrame[EagerExprT, "Incomplete", "DataFrame[NativeDataFrameT]"],
ValidateBackendVersion,
Protocol[EagerSeriesT, EagerExprT, NativeDataFrameT, NativeSeriesT],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
def __narwhals_namespace__(
self,
) -> EagerNamespace[
Self, EagerSeriesT, EagerExprT, NativeDataFrameT, NativeSeriesT
]: ...
def to_narwhals(self) -> DataFrame[NativeDataFrameT]:
return self._version.dataframe(self, level="full")
def aggregate(self, *exprs: EagerExprT) -> Self:
# NOTE: Ignore intermittent [False Negative]
# Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "exprs" of type "EagerExprT@EagerDataFrame" in function "select"
# Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
return self.select(*exprs) # pyright: ignore[reportArgumentType]
def _with_native(
self, df: NativeDataFrameT, *, validate_column_names: bool = True
) -> Self: ...
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
return check_columns_exist(subset, available=self.columns)
def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
"""Evaluate `expr` and ensure it has a **single** output."""
result: Sequence[EagerSeriesT] = expr(self)
assert len(result) == 1 # debug assertion # noqa: S101
return result[0]
def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
# NOTE: Ignore intermittent [False Negative]
# Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr"
# Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType]
def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
"""Return list of raw columns.
For eager backends we alias operations at each step.
As a safety precaution, here we can check that the expected result names match those
we were expecting from the various `evaluate_output_names` / `alias_output_names` calls.
Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want.
"""
aliases = expr._evaluate_aliases(self)
result = expr(self)
if list(aliases) != (
result_aliases := [s.name for s in result]
): # pragma: no cover
msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
raise AssertionError(msg)
return result
def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
"""Extract native Series, broadcasting to `len(self)` if necessary."""
...
@staticmethod
def _numpy_column_names(
data: _2DArray, columns: Sequence[str] | None, /
) -> list[str]:
return list(columns or (f"column_{x}" for x in range(data.shape[1])))
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def _select_multi_index(
self, columns: SizedMultiIndexSelector[NativeSeriesT]
) -> Self: ...
def _select_multi_name(
self, columns: SizedMultiNameSelector[NativeSeriesT]
) -> Self: ...
def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
def _select_slice_name(self, columns: _SliceName) -> Self: ...
def __getitem__( # noqa: C901, PLR0912
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
MultiColSelector[EagerSeriesT],
],
) -> Self:
rows, columns = item
compliant = self
if not is_slice_none(columns):
if isinstance(columns, Sized) and len(columns) == 0:
return compliant.select()
if is_index_selector(columns):
if is_slice_index(columns) or is_range(columns):
compliant = compliant._select_slice_index(columns)
elif is_compliant_series(columns):
compliant = self._select_multi_index(columns.native)
else:
compliant = compliant._select_multi_index(columns)
elif isinstance(columns, slice):
compliant = compliant._select_slice_name(columns)
elif is_compliant_series(columns):
compliant = self._select_multi_name(columns.native)
elif is_sequence_like(columns):
compliant = self._select_multi_name(columns)
else:
assert_never(columns)
if not is_slice_none(rows):
if isinstance(rows, int):
compliant = compliant._gather([rows])
elif isinstance(rows, (slice, range)):
compliant = compliant._gather_slice(rows)
elif is_compliant_series(rows):
compliant = compliant._gather(rows.native)
elif is_sized_multi_index_selector(rows):
compliant = compliant._gather(rows)
else:
assert_never(rows)
return compliant
def sink_parquet(self, file: str | Path | BytesIO) -> None:
return self.write_parquet(file)

File diff suppressed because it is too large

View File

@ -0,0 +1,180 @@
from __future__ import annotations
import re
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, TypeVar
from narwhals._compliant.typing import (
CompliantDataFrameT,
CompliantDataFrameT_co,
CompliantExprT_contra,
CompliantFrameT,
CompliantFrameT_co,
DepthTrackingExprAny,
DepthTrackingExprT_contra,
EagerExprT_contra,
ImplExprT_contra,
NarwhalsAggregation,
)
from narwhals._utils import is_sequence_of, zip_strict
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from narwhals._compliant.expr import ImplExpr
__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy"]
NativeAggregationT_co = TypeVar(
"NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True
)
_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)")
def _evaluate_aliases(
frame: CompliantFrameT, exprs: Iterable[ImplExpr[CompliantFrameT, Any]], /
) -> list[str]:
it = (expr._evaluate_aliases(frame) for expr in exprs)
return list(chain.from_iterable(it))
class CompliantGroupBy(Protocol[CompliantFrameT_co, CompliantExprT_contra]):
_compliant_frame: Any
@property
def compliant(self) -> CompliantFrameT_co:
return self._compliant_frame # type: ignore[no-any-return]
def __init__(
self,
compliant_frame: CompliantFrameT_co,
keys: Sequence[CompliantExprT_contra] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None: ...
def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...
class DataFrameGroupBy(
CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra],
Protocol[CompliantDataFrameT_co, CompliantExprT_contra],
):
def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...
class ParseKeysGroupBy(
CompliantGroupBy[CompliantFrameT, ImplExprT_contra],
Protocol[CompliantFrameT, ImplExprT_contra],
):
def _parse_keys(
self,
compliant_frame: CompliantFrameT,
keys: Sequence[ImplExprT_contra] | Sequence[str],
) -> tuple[CompliantFrameT, list[str], list[str]]:
if is_sequence_of(keys, str):
keys_str = list(keys)
return compliant_frame, keys_str, keys_str.copy()
return self._parse_expr_keys(compliant_frame, keys=keys)
@staticmethod
def _parse_expr_keys(
compliant_frame: CompliantFrameT, keys: Sequence[ImplExprT_contra]
) -> tuple[CompliantFrameT, list[str], list[str]]:
"""Parses key expressions to set up `.agg` operation with correct information.
Since keys are expressions, it's possible to alias any such key to match
other dataframe column names.
In order to match polars behavior and not overwrite columns when evaluating keys:
- We evaluate what the output key names should be, in order to remap temporary column
names to the expected ones, and to exclude those from unnamed expressions in
`.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520)
- Create temporary names for evaluated key expressions that are guaranteed to have
no overlap with any existing column name.
- Add these temporary columns to the compliant dataframe.
"""
tmp_name_length = max(len(str(c)) for c in compliant_frame.columns) + 1
def _temporary_name(key: str) -> str:
# 5 is the length of `__tmp`
key_str = str(key) # pandas allows non-string column names :sob:
return f"_{key_str}_tmp{'_' * (tmp_name_length - len(key_str) - 5)}"
keys_aliases = [expr._evaluate_aliases(compliant_frame) for expr in keys]
safe_keys = [
# multi-output expression cannot have duplicate names, hence it's safe to suffix
key.name.map(_temporary_name)
if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
# otherwise it's single named and we can use Expr.alias
else key.alias(_temporary_name(new_names[0]))
for key, new_names in zip_strict(keys, keys_aliases)
]
return (
compliant_frame.with_columns(*safe_keys),
_evaluate_aliases(compliant_frame, safe_keys),
list(chain.from_iterable(keys_aliases)),
)
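# Illustrative note (not part of the upstream file): `_temporary_name` yields
# names such as "_a_tmp___" that are always longer than every existing column
# name, so the evaluated key columns can be added via `with_columns` without
# clobbering user columns, and are later remapped to the expected output names.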
class DepthTrackingGroupBy(
ParseKeysGroupBy[CompliantFrameT, DepthTrackingExprT_contra],
Protocol[CompliantFrameT, DepthTrackingExprT_contra, NativeAggregationT_co],
):
"""`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`."""
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]]
"""Mapping from `narwhals` to native representation.
Note:
- `Dask` *may* return a `Callable` instead of a `str` referring to one.
"""
def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None:
for expr in exprs:
if not self._is_simple(expr):
name = self.compliant._implementation.name.lower()
msg = (
f"Non-trivial complex aggregation found.\n\n"
f"Hint: you were probably trying to apply a non-elementary aggregation with a"
f"{name!r} table.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
)
raise ValueError(msg)
@classmethod
def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool:
"""Return `True` is we can efficiently use `expr` in a native `group_by` context."""
return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS
@classmethod
def _remap_expr_name(
cls, name: NarwhalsAggregation | Any, /
) -> NativeAggregationT_co:
"""Replace `name`, with some native representation.
Arguments:
name: Name of a `nw.Expr` aggregation method.
"""
return cls._REMAP_AGGS.get(name, name)
@classmethod
def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
"""Return the last function name in the chain defined by `expr`."""
return _RE_LEAF_NAME.sub("", expr._function_name)
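# Illustrative example (not part of the upstream file): an expression built as
# nw.col("b").mean() carries the `_function_name` "col->mean"; stripping the
# "col->" prefix leaves the leaf aggregation "mean", which `_remap_expr_name`
# then translates via `_REMAP_AGGS` to the backend-native aggregation.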
class EagerGroupBy(
DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
DataFrameGroupBy[CompliantDataFrameT, EagerExprT_contra],
Protocol[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
): ...

View File

@ -0,0 +1,238 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, overload
from narwhals._compliant.typing import (
CompliantExprT,
CompliantFrameT,
CompliantLazyFrameT,
DepthTrackingExprT,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprT,
NativeFrameT,
NativeFrameT_co,
NativeSeriesT,
)
from narwhals._expression_parsing import is_expr, is_series
from narwhals._utils import (
exclude_column_names,
get_column_names,
passthrough_column_names,
)
from narwhals.dependencies import is_numpy_array, is_numpy_array_2d
if TYPE_CHECKING:
from collections.abc import Container, Iterable, Sequence
from typing_extensions import TypeAlias
from narwhals._compliant.selectors import CompliantSelectorNamespace
from narwhals._compliant.when_then import CompliantWhen, EagerWhen
from narwhals._utils import Implementation, Version
from narwhals.expr import Expr
from narwhals.series import Series
from narwhals.typing import (
ConcatMethod,
Into1DArray,
IntoDType,
IntoSchema,
NonNestedLiteral,
_1DArray,
_2DArray,
)
Incomplete: TypeAlias = Any
__all__ = [
"CompliantNamespace",
"DepthTrackingNamespace",
"EagerNamespace",
"LazyNamespace",
]
class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
# NOTE: `narwhals`
_implementation: Implementation
_version: Version
@property
def _expr(self) -> type[CompliantExprT]: ...
def parse_into_expr(
self, data: Expr | NonNestedLiteral | Any, /, *, str_as_lit: bool
) -> CompliantExprT | NonNestedLiteral:
if is_expr(data):
expr = data._to_compliant_expr(self)
assert isinstance(expr, self._expr) # noqa: S101
return expr
if isinstance(data, str) and not str_as_lit:
return self.col(data)
return data
# NOTE: `polars`
def all(self) -> CompliantExprT:
return self._expr.from_column_names(get_column_names, context=self)
def col(self, *column_names: str) -> CompliantExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), context=self
)
def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names), context=self
)
def nth(self, *column_indices: int) -> CompliantExprT:
return self._expr.from_column_indices(*column_indices, context=self)
def len(self) -> CompliantExprT: ...
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ...
def all_horizontal(
self, *exprs: CompliantExprT, ignore_nulls: bool
) -> CompliantExprT: ...
def any_horizontal(
self, *exprs: CompliantExprT, ignore_nulls: bool
) -> CompliantExprT: ...
def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
def concat(
self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
) -> CompliantFrameT: ...
def when(
self, predicate: CompliantExprT
) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...
def concat_str(
self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
) -> CompliantExprT: ...
@property
def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...
class DepthTrackingNamespace(
CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
Protocol[CompliantFrameT, DepthTrackingExprT],
):
def all(self) -> DepthTrackingExprT:
return self._expr.from_column_names(
get_column_names, function_name="all", context=self
)
def col(self, *column_names: str) -> DepthTrackingExprT:
return self._expr.from_column_names(
passthrough_column_names(column_names), function_name="col", context=self
)
def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
return self._expr.from_column_names(
partial(exclude_column_names, names=excluded_names),
function_name="exclude",
context=self,
)
class LazyNamespace(
CompliantNamespace[CompliantLazyFrameT, LazyExprT],
Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def _lazyframe(self) -> type[CompliantLazyFrameT]: ...
def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
if self._lazyframe._is_native(data):
return self._lazyframe.from_native(data, context=self)
msg = f"Unsupported type: {type(data).__name__!r}" # pragma: no cover
raise TypeError(msg)
class EagerNamespace(
DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def _dataframe(self) -> type[EagerDataFrameT]: ...
@property
def _series(self) -> type[EagerSeriesT]: ...
def when(
self, predicate: EagerExprT
) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...
@overload
def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
@overload
def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
def from_native(
self, data: NativeFrameT | NativeSeriesT | Any, /
) -> EagerDataFrameT | EagerSeriesT:
if self._dataframe._is_native(data):
return self._dataframe.from_native(data, context=self)
if self._series._is_native(data):
return self._series.from_native(data, context=self)
msg = f"Unsupported type: {type(data).__name__!r}"
raise TypeError(msg)
def parse_into_expr(
self,
data: Expr | Series[NativeSeriesT] | _1DArray | NonNestedLiteral,
/,
*,
str_as_lit: bool,
) -> EagerExprT | NonNestedLiteral:
if not (is_series(data) or is_numpy_array(data)):
return super().parse_into_expr(data, str_as_lit=str_as_lit)
return self._expr._from_series(
data._compliant_series
if is_series(data)
else self._series.from_numpy(data, context=self)
)
@overload
def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ...
@overload
def from_numpy(
self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
) -> EagerDataFrameT: ...
def from_numpy(
self,
data: Into1DArray | _2DArray,
/,
schema: IntoSchema | Sequence[str] | None = None,
) -> EagerDataFrameT | EagerSeriesT:
if is_numpy_array_2d(data):
return self._dataframe.from_numpy(data, schema=schema, context=self)
return self._series.from_numpy(data, context=self)
def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def _concat_horizontal(
self, dfs: Sequence[NativeFrameT | Any], /
) -> NativeFrameT: ...
def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
def concat(
self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
) -> EagerDataFrameT:
dfs = [item.native for item in items]
if how == "horizontal":
native = self._concat_horizontal(dfs)
elif how == "vertical":
native = self._concat_vertical(dfs)
elif how == "diagonal":
native = self._concat_diagonal(dfs)
else: # pragma: no cover
raise NotImplementedError
return self._dataframe.from_native(native, context=self)

View File

@ -0,0 +1,318 @@
"""Almost entirely complete, generic `selectors` implementation."""
from __future__ import annotations
import re
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeVar, overload
from narwhals._compliant.expr import CompliantExpr
from narwhals._utils import (
_parse_time_unit_and_time_zone,
dtype_matches_time_unit_and_time_zone,
get_column_names,
is_compliant_dataframe,
zip_strict,
)
if TYPE_CHECKING:
from collections.abc import Collection, Iterable, Iterator, Sequence
from datetime import timezone
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.expr import NativeExpr
from narwhals._compliant.typing import (
CompliantDataFrameAny,
CompliantExprAny,
CompliantFrameAny,
CompliantLazyFrameAny,
CompliantSeriesAny,
CompliantSeriesOrNativeExprAny,
EvalNames,
EvalSeries,
)
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.typing import TimeUnit
__all__ = [
"CompliantSelector",
"CompliantSelectorNamespace",
"EagerSelectorNamespace",
"LazySelectorNamespace",
]
SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeriesOrNativeExprAny")
SeriesT = TypeVar("SeriesT", bound="CompliantSeriesAny")
ExprT = TypeVar("ExprT", bound="NativeExpr")
FrameT = TypeVar("FrameT", bound="CompliantFrameAny")
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrameAny")
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrameAny")
SelectorOrExpr: TypeAlias = (
"CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
)
class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
# NOTE: `narwhals`
_implementation: Implementation
_version: Version
@property
def _selector(self) -> type[CompliantSelector[FrameT, SeriesOrExprT]]: ...
@classmethod
def from_namespace(cls, context: _LimitedContext, /) -> Self:
obj = cls.__new__(cls)
obj._implementation = context._implementation
obj._version = context._version
return obj
def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...
def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...
def _iter_columns_dtypes(
self, df: FrameT, /
) -> Iterator[tuple[SeriesOrExprT, DType]]: ...
def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
yield from zip_strict(self._iter_columns(df), df.columns)
def _is_dtype(
self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [
ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]
return self._selector.from_callables(series, names, context=self)
# NOTE: `polars`
def by_dtype(
self, dtypes: Collection[DType | type[DType]]
) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if tp in dtypes]
return self._selector.from_callables(series, names, context=self)
def matches(self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
p = re.compile(pattern)
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
if (
is_compliant_dataframe(df)
and not self._implementation.is_duckdb()
and not self._implementation.is_ibis()
):
return [df.get_column(col) for col in df.columns if p.search(col)]
return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]
def names(df: FrameT) -> Sequence[str]:
return [col for col in df.columns if p.search(col)]
return self._selector.from_callables(series, names, context=self)
def numeric(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]
return self._selector.from_callables(series, names, context=self)
def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.Categorical)
def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.String)
def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self._is_dtype(self._version.dtypes.Boolean)
def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return list(self._iter_columns(df))
return self._selector.from_callables(series, get_column_names, context=self)
def datetime(
self,
time_unit: TimeUnit | Iterable[TimeUnit] | None,
time_zone: str | timezone | Iterable[str | timezone | None] | None,
) -> CompliantSelector[FrameT, SeriesOrExprT]:
time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
matches = partial(
dtype_matches_time_unit_and_time_zone,
dtypes=self._version.dtypes,
time_units=time_units,
time_zones=time_zones,
)
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]
def names(df: FrameT) -> Sequence[str]:
return [name for name, tp in self._iter_schema(df) if matches(tp)]
return self._selector.from_callables(series, names, context=self)
class EagerSelectorNamespace(
CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
):
def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
for ser in self._iter_columns(df):
yield ser.name, ser.dtype
def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
yield from df.iter_columns()
def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
for ser in self._iter_columns(df):
yield ser, ser.dtype
class LazySelectorNamespace(
CompliantSelectorNamespace[LazyFrameT, ExprT], Protocol[LazyFrameT, ExprT]
):
def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
yield from df.schema.items()
def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
yield from df._iter_columns()
def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
yield from zip_strict(self._iter_columns(df), df.schema.values())
class CompliantSelector(
CompliantExpr[FrameT, SeriesOrExprT], Protocol[FrameT, SeriesOrExprT]
):
_call: EvalSeries[FrameT, SeriesOrExprT]
_function_name: str
_implementation: Implementation
_version: Version
@classmethod
def from_callables(
cls,
call: EvalSeries[FrameT, SeriesOrExprT],
evaluate_output_names: EvalNames[FrameT],
*,
context: _LimitedContext,
) -> Self:
obj = cls.__new__(cls)
obj._call = call
obj._evaluate_output_names = evaluate_output_names
obj._alias_output_names = None
obj._implementation = context._implementation
obj._version = context._version
return obj
@property
def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
return self.__narwhals_namespace__().selectors
def _to_expr(self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def _is_selector(
self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
return isinstance(other, type(self))
@overload
def __sub__(self, other: Self) -> Self: ...
@overload
def __sub__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __sub__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
x
for x, name in zip_strict(self(df), lhs_names)
if name not in rhs_names
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x in lhs_names if x not in rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() - other
@overload
def __or__(self, other: Self) -> Self: ...
@overload
def __or__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __or__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
*(
x
for x, name in zip_strict(self(df), lhs_names)
if name not in rhs_names
),
*other(df),
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() | other
@overload
def __and__(self, other: Self) -> Self: ...
@overload
def __and__(
self, other: CompliantExpr[FrameT, SeriesOrExprT]
) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
def __and__(
self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
if self._is_selector(other):
def series(df: FrameT) -> Sequence[SeriesOrExprT]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [
x for x, name in zip_strict(self(df), lhs_names) if name in rhs_names
]
def names(df: FrameT) -> Sequence[str]:
lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
return [x for x in lhs_names if x in rhs_names]
return self.selectors._selector.from_callables(series, names, context=self)
return self._to_expr() & other
def __invert__(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
return self.selectors.all() - self
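# Selector algebra mirrors polars' set semantics: `-` is set difference, `|` is
# union, `&` is intersection, and `~selector` is `all() - selector`. When the
# right-hand side is a plain expression rather than a selector, each operator
# falls back to the corresponding expression operation via `_to_expr()`.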
def _eval_lhs_rhs(
df: CompliantFrameAny, lhs: CompliantExprAny, rhs: CompliantExprAny
) -> tuple[Sequence[str], Sequence[str]]:
return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)

View File

@ -0,0 +1,411 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol
from narwhals._compliant.any_namespace import (
CatNamespace,
DateTimeNamespace,
ListNamespace,
StringNamespace,
StructNamespace,
)
from narwhals._compliant.column import CompliantColumn
from narwhals._compliant.typing import (
CompliantSeriesT_co,
EagerDataFrameAny,
EagerSeriesT_co,
NativeSeriesT,
NativeSeriesT_co,
)
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
_StoresCompliant,
_StoresNative,
is_compliant_series,
is_sized_multi_index_selector,
unstable,
)
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
from types import ModuleType
import pandas as pd
import polars as pl
import pyarrow as pa
from typing_extensions import NotRequired, Self, TypedDict
from narwhals._compliant.dataframe import CompliantDataFrame
from narwhals._compliant.namespace import EagerNamespace
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
Into1DArray,
IntoDType,
MultiIndexSelector,
RollingInterpolationMethod,
SizedMultiIndexSelector,
_1DArray,
_SliceIndex,
)
class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):
breakpoint: NotRequired[list[float] | _1DArray | list[Any]]
count: NativeSeriesT | _1DArray | _CountsT_co | list[Any]
_CountsT_co = TypeVar("_CountsT_co", bound="Iterable[Any]", covariant=True)
__all__ = [
"CompliantSeries",
"EagerSeries",
"EagerSeriesCatNamespace",
"EagerSeriesDateTimeNamespace",
"EagerSeriesHist",
"EagerSeriesListNamespace",
"EagerSeriesNamespace",
"EagerSeriesStringNamespace",
"EagerSeriesStructNamespace",
]
class CompliantSeries(
NumpyConvertible["_1DArray", "Into1DArray"],
FromIterable,
FromNative[NativeSeriesT],
ToNarwhals["Series[NativeSeriesT]"],
CompliantColumn,
Protocol[NativeSeriesT],
):
# NOTE: `narwhals`
_implementation: Implementation
@property
def native(self) -> NativeSeriesT: ...
def __narwhals_series__(self) -> Self:
return self
def __native_namespace__(self) -> ModuleType: ...
@classmethod
def from_native(cls, data: NativeSeriesT, /, *, context: _LimitedContext) -> Self: ...
def to_narwhals(self) -> Series[NativeSeriesT]:
return self._version.series(self, level="full")
def _with_native(self, series: Any) -> Self: ...
def _with_version(self, version: Version) -> Self: ...
# NOTE: `polars`
@property
def dtype(self) -> DType: ...
@property
def name(self) -> str: ...
def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ...
def __contains__(self, other: Any) -> bool: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ...
def __iter__(self) -> Iterator[Any]: ...
def __len__(self) -> int:
return len(self.native)
@classmethod
def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self: ...
@classmethod
def from_iterable(
cls,
data: Iterable[Any],
/,
*,
context: _LimitedContext,
name: str = "",
dtype: IntoDType | None = None,
) -> Self: ...
def __radd__(self, other: Any) -> Self: ...
def __rand__(self, other: Any) -> Self: ...
def __rmul__(self, other: Any) -> Self: ...
def __ror__(self, other: Any) -> Self: ...
def all(self) -> bool: ...
def any(self) -> bool: ...
def arg_max(self) -> int: ...
def arg_min(self) -> int: ...
def arg_true(self) -> Self: ...
def count(self) -> int: ...
def filter(self, predicate: Any) -> Self: ...
def gather_every(self, n: int, offset: int) -> Self: ...
def head(self, n: int) -> Self: ...
def is_empty(self) -> bool:
return self.len() == 0
def is_sorted(self, *, descending: bool) -> bool: ...
def item(self, index: int | None) -> Any: ...
def kurtosis(self) -> float | None: ...
def len(self) -> int: ...
def max(self) -> Any: ...
def mean(self) -> float: ...
def median(self) -> float: ...
def min(self) -> Any: ...
def n_unique(self) -> int: ...
def null_count(self) -> int: ...
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> float: ...
def sample(
self,
n: int | None,
*,
fraction: float | None,
with_replacement: bool,
seed: int | None,
) -> Self: ...
def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ...
def shift(self, n: int) -> Self: ...
def skew(self) -> float | None: ...
def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
def std(self, *, ddof: int) -> float: ...
def sum(self) -> float: ...
def tail(self, n: int) -> Self: ...
def to_arrow(self) -> pa.Array[Any]: ...
def to_dummies(
self, *, separator: str, drop_first: bool
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def to_frame(self) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def to_list(self) -> list[Any]: ...
def to_pandas(self) -> pd.Series[Any]: ...
def to_polars(self) -> pl.Series: ...
def unique(self, *, maintain_order: bool = False) -> Self: ...
def value_counts(
self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
) -> CompliantDataFrame[Self, Any, Any, Any]: ...
def var(self, *, ddof: int) -> float: ...
def zip_with(self, mask: Any, other: Any) -> Self: ...
# NOTE: *Technically* `polars`
@unstable
def hist_from_bins(
self, bins: list[float], *, include_breakpoint: bool
) -> CompliantDataFrame[Self, Any, Any, Any]:
"""`Series.hist(bins=..., bin_count=None)`."""
...
@unstable
def hist_from_bin_count(
self, bin_count: int, *, include_breakpoint: bool
) -> CompliantDataFrame[Self, Any, Any, Any]:
"""`Series.hist(bins=None, bin_count=...)`."""
...
class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]):
_native_series: Any
_implementation: Implementation
_version: Version
_broadcast: bool
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@classmethod
def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
"""Ensure all of `series` have the same length (and index if `pandas`).
        Scalars are broadcast to the full length of the longest Series.
This is useful when you need to construct a full Series anyway, such as:
DataFrame.select(...)
It should not be used in binary operations, such as:
nw.col("a") - nw.col("a").mean()
because then it's more efficient to extract the right-hand-side's single element as a scalar.
"""
...
def _from_scalar(self, value: Any) -> Self:
return self.from_iterable([value], name=self.name, context=self)
def _with_native(
self, series: NativeSeriesT, *, preserve_broadcast: bool = False
) -> Self:
"""Return a new `CompliantSeries`, wrapping the native `series`.
        In cases where operations are known not to affect whether a result should
be broadcast, we can pass `preserve_broadcast=True`.
Set this with care - it should only be set for unary expressions which don't
change length or order, such as `.alias` or `.fill_null`. If in doubt, don't
        set it; you probably don't need it.
"""
...
def __narwhals_namespace__(
self,
) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:
if isinstance(item, (slice, range)):
return self._gather_slice(item)
if is_compliant_series(item):
return self._gather(item.native)
elif is_sized_multi_index_selector(item): # noqa: RET505
return self._gather(item)
assert_never(item)
@property
def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ...
@property
def dt(self) -> EagerSeriesDateTimeNamespace[Self, NativeSeriesT]: ...
@property
def cat(self) -> EagerSeriesCatNamespace[Self, NativeSeriesT]: ...
@property
def list(self) -> EagerSeriesListNamespace[Self, NativeSeriesT]: ...
@property
def struct(self) -> EagerSeriesStructNamespace[Self, NativeSeriesT]: ...
class _SeriesNamespace( # type: ignore[misc]
_StoresCompliant[CompliantSeriesT_co],
_StoresNative[NativeSeriesT_co],
Protocol[CompliantSeriesT_co, NativeSeriesT_co],
):
_compliant_series: CompliantSeriesT_co
@property
def compliant(self) -> CompliantSeriesT_co:
return self._compliant_series
@property
def implementation(self) -> Implementation:
return self.compliant._implementation
@property
def backend_version(self) -> tuple[int, ...]:
return self.implementation._backend_version()
@property
def version(self) -> Version:
return self.compliant._version
@property
def native(self) -> NativeSeriesT_co:
return self._compliant_series.native # type: ignore[no-any-return]
def with_native(self, series: Any, /) -> CompliantSeriesT_co:
return self.compliant._with_native(series)
class EagerSeriesNamespace(
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
Generic[EagerSeriesT_co, NativeSeriesT_co],
):
_compliant_series: EagerSeriesT_co
def __init__(self, series: EagerSeriesT_co, /) -> None:
self._compliant_series = series
class EagerSeriesCatNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
CatNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesDateTimeNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
DateTimeNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesListNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
ListNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesStringNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
StringNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesStructNamespace( # type: ignore[misc]
_SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
StructNamespace[EagerSeriesT_co],
Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...
class EagerSeriesHist(Protocol[NativeSeriesT, _CountsT_co]):
_series: EagerSeries[NativeSeriesT]
_breakpoint: bool
_data: HistData[NativeSeriesT, _CountsT_co]
@property
def native(self) -> NativeSeriesT:
return self._series.native
@classmethod
def from_series(
cls, series: EagerSeries[NativeSeriesT], *, include_breakpoint: bool
) -> Self:
obj = cls.__new__(cls)
obj._series = series
obj._breakpoint = include_breakpoint
return obj
def to_frame(self) -> EagerDataFrameAny: ...
def _linear_space( # NOTE: Roughly `pl.linear_space`
self,
start: float,
end: float,
num_samples: int,
*,
closed: Literal["both", "none"] = "both",
) -> _1DArray: ...
# NOTE: *Could* be handled at narwhals-level
def is_empty_series(self) -> bool: ...
# NOTE: **Should** be handled at narwhals-level
def data_empty(self) -> HistData[NativeSeriesT, _CountsT_co]:
return {"breakpoint": [], "count": []} if self._breakpoint else {"count": []}
# NOTE: *Could* be handled at narwhals-level, **iff** we add `nw.repeat`, `nw.linear_space`
# See https://github.com/narwhals-dev/narwhals/pull/2839#discussion_r2215630696
def series_empty(
self, arg: int | list[float], /
) -> HistData[NativeSeriesT, _CountsT_co]: ...
def with_bins(self, bins: list[float], /) -> Self:
if len(bins) <= 1:
self._data = self.data_empty()
elif self.is_empty_series():
self._data = self.series_empty(bins)
else:
self._data = self._calculate_hist(bins)
return self
def with_bin_count(self, bin_count: int, /) -> Self:
if bin_count == 0:
self._data = self.data_empty()
elif self.is_empty_series():
self._data = self.series_empty(bin_count)
else:
self._data = self._calculate_hist(self._calculate_bins(bin_count))
return self
def _calculate_breakpoint(self, arg: int | list[float], /) -> list[float] | _1DArray:
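        # Breakpoints are the right (upper) edges of each bin; for an integer
        # `arg`, evenly spaced edges over [0, 1] are generated first.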
bins = self._linear_space(0, 1, arg + 1) if isinstance(arg, int) else arg
return bins[1:]
def _calculate_bins(self, bin_count: int) -> _1DArray: ...
def _calculate_hist(
self, bins: list[float] | _1DArray
) -> HistData[NativeSeriesT, _CountsT_co]: ...

View File

@ -0,0 +1,206 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar
if TYPE_CHECKING:
from typing_extensions import TypeAlias
from narwhals._compliant.dataframe import (
CompliantDataFrame,
CompliantFrame,
CompliantLazyFrame,
EagerDataFrame,
)
from narwhals._compliant.expr import (
CompliantExpr,
DepthTrackingExpr,
EagerExpr,
ImplExpr,
LazyExpr,
NativeExpr,
)
from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
from narwhals._compliant.series import CompliantSeries, EagerSeries
from narwhals._compliant.window import WindowInputs
from narwhals.typing import (
FillNullStrategy,
IntoLazyFrame,
ModeKeepStrategy,
NativeDataFrame,
NativeFrame,
NativeSeries,
RankMethod,
RollingInterpolationMethod,
)
class ScalarKwargs(TypedDict, total=False):
"""Non-expressifiable args which we may need to reuse in `agg` or `over`."""
adjust: bool
alpha: float | None
center: int
com: float | None
ddof: int
descending: bool
half_life: float | None
ignore_nulls: bool
interpolation: RollingInterpolationMethod
keep: ModeKeepStrategy
limit: int | None
method: RankMethod
min_samples: int
n: int
quantile: float
reverse: bool
span: float | None
strategy: FillNullStrategy | None
window_size: int
__all__ = [
"AliasName",
"AliasNames",
"CompliantDataFrameT",
"CompliantFrameT",
"CompliantLazyFrameT",
"CompliantSeriesT",
"EvalNames",
"EvalSeries",
"NarwhalsAggregation",
"NativeFrameT_co",
"NativeSeriesT_co",
]
CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantFrame[Any, Any, Any]"
CompliantNamespaceAny: TypeAlias = "CompliantNamespace[Any, Any]"
ImplExprAny: TypeAlias = "ImplExpr[Any, Any]"
DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"
EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]"
EagerSeriesAny: TypeAlias = "EagerSeries[Any]"
EagerExprAny: TypeAlias = "EagerExpr[Any, Any]"
EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]"
LazyExprAny: TypeAlias = "LazyExpr[Any, Any]"
NativeExprT = TypeVar("NativeExprT", bound="NativeExpr")
NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True)
NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries")
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
NativeSeriesT_contra = TypeVar(
"NativeSeriesT_contra", bound="NativeSeries", contravariant=True
)
NativeDataFrameT = TypeVar("NativeDataFrameT", bound="NativeDataFrame")
NativeLazyFrameT = TypeVar("NativeLazyFrameT", bound="IntoLazyFrame")
NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame")
NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True)
NativeFrameT_contra = TypeVar(
"NativeFrameT_contra", bound="NativeFrame", contravariant=True
)
CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny)
CompliantExprT_co = TypeVar("CompliantExprT_co", bound=CompliantExprAny, covariant=True)
CompliantExprT_contra = TypeVar(
"CompliantExprT_contra", bound=CompliantExprAny, contravariant=True
)
CompliantSeriesT = TypeVar("CompliantSeriesT", bound=CompliantSeriesAny)
CompliantSeriesT_co = TypeVar(
"CompliantSeriesT_co", bound=CompliantSeriesAny, covariant=True
)
CompliantSeriesOrNativeExprT = TypeVar(
"CompliantSeriesOrNativeExprT", bound=CompliantSeriesOrNativeExprAny
)
CompliantSeriesOrNativeExprT_co = TypeVar(
"CompliantSeriesOrNativeExprT_co",
bound=CompliantSeriesOrNativeExprAny,
covariant=True,
)
CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny)
CompliantFrameT_co = TypeVar(
"CompliantFrameT_co", bound=CompliantFrameAny, covariant=True
)
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound=CompliantDataFrameAny)
CompliantDataFrameT_co = TypeVar(
"CompliantDataFrameT_co", bound=CompliantDataFrameAny, covariant=True
)
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound=CompliantLazyFrameAny)
CompliantLazyFrameT_co = TypeVar(
"CompliantLazyFrameT_co", bound=CompliantLazyFrameAny, covariant=True
)
CompliantNamespaceT = TypeVar("CompliantNamespaceT", bound=CompliantNamespaceAny)
CompliantNamespaceT_co = TypeVar(
"CompliantNamespaceT_co", bound=CompliantNamespaceAny, covariant=True
)
ImplExprT_contra = TypeVar("ImplExprT_contra", bound=ImplExprAny, contravariant=True)
DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny)
DepthTrackingExprT_contra = TypeVar(
"DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True
)
EagerExprT = TypeVar("EagerExprT", bound=EagerExprAny)
EagerExprT_contra = TypeVar("EagerExprT_contra", bound=EagerExprAny, contravariant=True)
EagerSeriesT = TypeVar("EagerSeriesT", bound=EagerSeriesAny)
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True)
# NOTE: `pyright` gives 8 false positives if this uses `EagerDataFrameAny`?
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]")
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True)
AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
"""A function aliasing a *sequence* of column names."""
AliasName: TypeAlias = Callable[[str], str]
"""A function aliasing a *single* column name."""
EvalSeries: TypeAlias = Callable[
[CompliantFrameT], Sequence[CompliantSeriesOrNativeExprT]
]
"""A function from a `Frame` to a sequence of `Series`*.
See [underwater unicorn magic](https://narwhals-dev.github.io/narwhals/how_it_works/).
"""
EvalNames: TypeAlias = Callable[[CompliantFrameT], Sequence[str]]
"""A function from a `Frame` to a sequence of columns names *before* any aliasing takes place."""
WindowFunction: TypeAlias = (
"Callable[[CompliantFrameT, WindowInputs[NativeExprT]], Sequence[NativeExprT]]"
)
"""A function evaluated with `over(partition_by=..., order_by=...)`."""
NarwhalsAggregation: TypeAlias = Literal[
"sum",
"mean",
"median",
"max",
"min",
"mode",
"std",
"var",
"len",
"n_unique",
"count",
"quantile",
"all",
"any",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.
Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
- https://github.com/narwhals-dev/narwhals/issues/2660
"""

View File

@ -0,0 +1,130 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast
from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import (
CompliantExprAny,
CompliantFrameAny,
CompliantSeriesOrNativeExprAny,
EagerDataFrameT,
EagerExprT,
EagerSeriesT,
LazyExprAny,
NativeSeriesT,
)
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self, TypeAlias
from narwhals._compliant.typing import EvalSeries
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.typing import NonNestedLiteral
__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"]
ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)
Scalar: TypeAlias = Any
"""A native literal value."""
IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""
class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]):
_condition: ExprT
_then_value: IntoExpr[SeriesT, ExprT]
_otherwise_value: IntoExpr[SeriesT, ExprT] | None
_implementation: Implementation
_version: Version
@property
def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]: ...
def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ...
def then(
self, value: IntoExpr[SeriesT, ExprT], /
) -> CompliantThen[FrameT, SeriesT, ExprT, Self]:
return self._then.from_when(self, value)
@classmethod
def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self:
obj = cls.__new__(cls)
obj._condition = condition
obj._then_value = None
obj._otherwise_value = None
obj._implementation = context._implementation
obj._version = context._version
return obj
WhenT_contra = TypeVar(
"WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True
)
class CompliantThen(
CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra]
):
_call: EvalSeries[FrameT, SeriesT]
_when_value: CompliantWhen[FrameT, SeriesT, ExprT]
_implementation: Implementation
_version: Version
@classmethod
def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self:
when._then_value = then
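        # Output names come from `then` when it is an expression; plain scalars
        # fall back to the single output name "literal".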
obj = cls.__new__(cls)
obj._call = when
obj._when_value = when
obj._evaluate_output_names = getattr(
then, "_evaluate_output_names", lambda _df: ["literal"]
)
obj._alias_output_names = getattr(then, "_alias_output_names", None)
obj._implementation = when._implementation
obj._version = when._version
return obj
def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
self._when_value._otherwise_value = otherwise
return cast("ExprT", self)
class EagerWhen(
CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
def _if_then_else(
self,
when: NativeSeriesT,
then: NativeSeriesT,
otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
/,
) -> NativeSeriesT: ...
def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
is_expr = self._condition._is_expr
when: EagerSeriesT = self._condition(df)[0]
then: EagerSeriesT
align = when._align_full_broadcast
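        # Scalar `then` values are wrapped as a length-1 "literal" series and
        # marked for broadcasting; a scalar `otherwise` is passed straight
        # through to `_if_then_else`.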
if is_expr(self._then_value):
then = self._then_value(df)[0]
else:
then = when.alias("literal")._from_scalar(self._then_value)
then._broadcast = True
if is_expr(self._otherwise_value):
otherwise = self._otherwise_value(df)[0]
when, then, otherwise = align(when, then, otherwise)
result = self._if_then_else(when.native, then.native, otherwise.native)
else:
when, then = align(when, then)
result = self._if_then_else(when.native, then.native, self._otherwise_value)
return [then._with_native(result)]

View File

@ -0,0 +1,20 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Generic
from narwhals._compliant.typing import NativeExprT_co
if TYPE_CHECKING:
from collections.abc import Sequence
__all__ = ["WindowInputs"]
class WindowInputs(Generic[NativeExprT_co]):
__slots__ = ("order_by", "partition_by")
def __init__(
self, partition_by: Sequence[str | NativeExprT_co], order_by: Sequence[str]
) -> None:
self.partition_by = partition_by
self.order_by = order_by

View File

@ -0,0 +1,30 @@
from __future__ import annotations
import datetime as dt
# Temporal (from `polars._utils.constants`)
SECONDS_PER_DAY = 86_400
SECONDS_PER_MINUTE = 60
NS_PER_MINUTE = 60_000_000_000
"""Nanoseconds (`[ns]`) per minute."""
US_PER_MINUTE = 60_000_000
"""Microseconds (`[μs]`) per minute."""
MS_PER_MINUTE = 60_000
"""Milliseconds (`[ms]`) per minute."""
NS_PER_SECOND = 1_000_000_000
"""Nanoseconds (`[ns]`) per second (`[s]`)."""
US_PER_SECOND = 1_000_000
"""Microseconds (`[μs]`) per second (`[s]`)."""
MS_PER_SECOND = 1_000
"""Milliseconds (`[ms]`) per second (`[s]`)."""
NS_PER_MICROSECOND = 1_000
"""Nanoseconds (`[ns]`) per microsecond (`[μs]`)."""
NS_PER_MILLISECOND = 1_000_000
"""Nanoseconds (`[ns]`) per millisecond (`[ms]`).
From [polars](https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-time/src/chunkedarray/duration.rs#L7).
"""
EPOCH_YEAR = 1970
"""See [Unix time](https://en.wikipedia.org/wiki/Unix_time)."""
EPOCH = dt.datetime(EPOCH_YEAR, 1, 1).replace(tzinfo=None)
"""See [Unix time](https://en.wikipedia.org/wiki/Unix_time)."""

View File

@ -0,0 +1,502 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
import dask.dataframe as dd
from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
Implementation,
ValidateBackendVersion,
_remap_full_join_keys,
check_column_names_are_unique,
check_columns_exist,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
zip_strict,
)
from narwhals.typing import CompliantLazyFrame
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
import dask.dataframe.dask_expr as dx
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._dask.expr import DaskExpr
from narwhals._dask.group_by import DaskLazyGroupBy
from narwhals._dask.namespace import DaskNamespace
from narwhals._typing import _EagerAllowedImpl
from narwhals._utils import Version, _LimitedContext
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import DType
from narwhals.exceptions import ColumnNotFoundError
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.
Typing this correctly will complicate the `_pandas_like` side.
Very low priority until `dask` adds typing.
"""
class DaskLazyFrame(
CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
ValidateBackendVersion,
):
_implementation = Implementation.DASK
def __init__(
self,
native_dataframe: dd.DataFrame,
*,
version: Version,
validate_backend_version: bool = False,
) -> None:
self._native_frame: dd.DataFrame = native_dataframe
self._version = version
self._cached_schema: dict[str, DType] | None = None
self._cached_columns: list[str] | None = None
if validate_backend_version:
self._validate_backend_version()
@staticmethod
def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
return isinstance(obj, dd.DataFrame)
@classmethod
def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version)
def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
return self._version.lazyframe(self, level="lazy")
def __native_namespace__(self) -> ModuleType:
if self._implementation is Implementation.DASK:
return self._implementation.to_native_namespace()
msg = f"Expected dask, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)
def __narwhals_namespace__(self) -> DaskNamespace:
from narwhals._dask.namespace import DaskNamespace
return DaskNamespace(version=self._version)
def __narwhals_lazyframe__(self) -> Self:
return self
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version)
def _with_native(self, df: Any) -> Self:
return self.__class__(df, version=self._version)
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
return check_columns_exist(subset, available=self.columns)
def _iter_columns(self) -> Iterator[dx.Series]:
for _col, ser in self.native.items(): # noqa: PERF102
yield ser
def with_columns(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
return self._with_native(self.native.assign(**dict(new_series)))
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
result = self.native.compute(**kwargs)
if backend is None or backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
result,
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
pl.from_pandas(result),
validate_backend_version=True,
version=self._version,
)
if backend is Implementation.PYARROW:
import pyarrow as pa # ignore-banned-import
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
pa.Table.from_pandas(result),
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
@property
def columns(self) -> list[str]:
if self._cached_columns is None:
self._cached_columns = (
list(self.schema)
if self._cached_schema is not None
else self.native.columns.tolist()
)
return self._cached_columns
def filter(self, predicate: DaskExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = predicate(self)[0]
return self._with_native(self.native.loc[mask])
def simple_select(self, *column_names: str) -> Self:
df: Incomplete = self.native
native = select_columns_by_name(df, list(column_names), self._implementation)
return self._with_native(native)
def aggregate(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
return self._with_native(df)
def select(self, *exprs: DaskExpr) -> Self:
new_series = evaluate_exprs(self, *exprs)
df: Incomplete = self.native
df = select_columns_by_name(
df.assign(**dict(new_series)),
[s[0] for s in new_series],
self._implementation,
)
return self._with_native(df)
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
if subset is None:
return self._with_native(self.native.dropna())
plx = self.__narwhals_namespace__()
mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
return self.filter(mask)
@property
def schema(self) -> dict[str, DType]:
if self._cached_schema is None:
native_dtypes = self.native.dtypes
self._cached_schema = {
col: native_to_narwhals_dtype(
native_dtypes[col], self._version, self._implementation
)
for col in self.native.columns
}
return self._cached_schema
def collect_schema(self) -> dict[str, DType]:
return self.schema
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(columns=to_drop))
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
# Implementation is based on the following StackOverflow reply:
# https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
if order_by is None:
return self._with_native(add_row_index(self.native, name))
plx = self.__narwhals_namespace__()
columns = self.columns
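        # With `order_by`, emulate a row index by cumulatively summing a
        # constant column of ones in the requested order and subtracting 1.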
const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
row_index_expr = (
plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by)
- 1
)
return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))
def rename(self, mapping: Mapping[str, str]) -> Self:
return self._with_native(self.native.rename(columns=mapping))
def head(self, n: int) -> Self:
return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self:
if subset and (error := self._check_columns_exist(subset)):
raise error
if keep == "none":
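            # keep="none": count occurrences per key, keep only keys that occur
            # once, then inner-join back to recover those single rows.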
subset = subset or self.columns
token = generate_temporary_column_name(n_bytes=8, columns=subset)
ser = self.native.groupby(subset).size().rename(token)
ser = ser[ser == 1]
unique = ser.reset_index().drop(columns=token)
result = self.native.merge(unique, on=subset, how="inner")
else:
mapped_keep = {"any": "first"}.get(keep, keep)
result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
return self._with_native(result)
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
ascending: bool | list[bool] = not descending
else:
ascending = [not d for d in descending]
position = "last" if nulls_last else "first"
return self._with_native(
self.native.sort_values(list(by), ascending=ascending, na_position=position)
)
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
df = self.native
schema = self.schema
by = list(by)
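        # Fast path: a single bool `reverse` over all-numeric keys can use
        # dask's nsmallest/nlargest; otherwise fall back to a full sort + head.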
if isinstance(reverse, bool) and all(schema[x].is_numeric() for x in by):
if reverse:
return self._with_native(df.nsmallest(k, by))
return self._with_native(df.nlargest(k, by))
if isinstance(reverse, bool):
reverse = [reverse] * len(by)
return self._with_native(
df.sort_values(by, ascending=list(reverse)).head(
n=k, compute=False, npartitions=-1
)
)
def _join_inner(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
return self.native.merge(
other.native,
left_on=left_on,
right_on=right_on,
how="inner",
suffixes=("", suffix),
)
def _join_left(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
result_native = self.native.merge(
other.native,
how="left",
left_on=left_on,
right_on=right_on,
suffixes=("", suffix),
)
extra = [
right_key if right_key not in self.columns else f"{right_key}{suffix}"
for left_key, right_key in zip_strict(left_on, right_on)
if right_key != left_key
]
return result_native.drop(columns=extra)
def _join_full(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
) -> dd.DataFrame:
# dask does not retain keys post-join
        # we must append the suffix to each key beforehand
right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
other_native = other.native.rename(columns=right_on_mapper)
check_column_names_are_unique(other_native.columns)
right_suffixed = list(right_on_mapper.values())
return self.native.merge(
other_native,
left_on=left_on,
right_on=right_suffixed,
how="outer",
suffixes=("", suffix),
)
def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
key_token = generate_temporary_column_name(
n_bytes=8, columns=(*self.columns, *other.columns)
)
return (
self.native.assign(**{key_token: 0})
.merge(
other.native.assign(**{key_token: 0}),
how="inner",
left_on=key_token,
right_on=key_token,
suffixes=("", suffix),
)
.drop(columns=key_token)
)
def _join_semi(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
) -> dd.DataFrame:
other_native = self._join_filter_rename(
other=other,
columns_to_select=list(right_on),
columns_mapping=dict(zip(right_on, left_on)),
)
return self.native.merge(
other_native, how="inner", left_on=left_on, right_on=left_on
)
def _join_anti(
self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
) -> dd.DataFrame:
indicator_token = generate_temporary_column_name(
n_bytes=8, columns=(*self.columns, *other.columns)
)
other_native = self._join_filter_rename(
other=other,
columns_to_select=list(right_on),
columns_mapping=dict(zip(right_on, left_on)),
)
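        # `indicator` tags each row's origin; rows marked "left_only" have no
        # match in `other`, which is exactly the anti-join result.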
df = self.native.merge(
other_native,
how="left",
indicator=indicator_token, # pyright: ignore[reportArgumentType]
left_on=left_on,
right_on=left_on,
)
return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])
def _join_filter_rename(
self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
) -> dd.DataFrame:
"""Helper function to avoid creating extra columns and row duplication.
Used in `"anti"` and `"semi`" join's.
Notice that a native object is returned.
"""
other_native: Incomplete = other.native
# rename to avoid creating extra columns in join
return (
select_columns_by_name(other_native, columns_to_select, self._implementation)
.rename(columns=columns_mapping)
.drop_duplicates()
)
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
if how == "cross":
result = self._join_cross(other=other, suffix=suffix)
elif left_on is None or right_on is None: # pragma: no cover
raise ValueError(left_on, right_on)
elif how == "inner":
result = self._join_inner(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
elif how == "anti":
result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
elif how == "semi":
result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
elif how == "left":
result = self._join_left(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
elif how == "full":
result = self._join_full(
other=other, left_on=left_on, right_on=right_on, suffix=suffix
)
else:
assert_never(how)
return self._with_native(result)
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self:
plx = self.__native_namespace__()
return self._with_native(
plx.merge_asof(
self.native,
other.native,
left_on=left_on,
right_on=right_on,
left_by=by_left,
right_by=by_right,
direction=strategy,
suffixes=("", suffix),
)
)
def group_by(
self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
) -> DaskLazyGroupBy:
from narwhals._dask.group_by import DaskLazyGroupBy
return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
def tail(self, n: int) -> Self: # pragma: no cover
native_frame = self.native
n_partitions = native_frame.npartitions
if n_partitions == 1:
return self._with_native(self.native.tail(n=n, compute=False))
msg = (
"`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
)
raise NotImplementedError(msg)
def gather_every(self, n: int, offset: int) -> Self:
row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
plx = self.__narwhals_namespace__()
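        # Add a temporary row index, keep rows where (index - offset) is a
        # non-negative multiple of `n`, then drop the helper column.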
return (
self.with_row_index(row_index_token, order_by=None)
.filter(
(plx.col(row_index_token) >= offset)
& ((plx.col(row_index_token) - offset) % n == 0)
)
.drop([row_index_token], strict=False)
)
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
return self._with_native(
self.native.melt(
id_vars=index,
value_vars=on,
var_name=variable_name,
value_name=value_name,
)
)
def sink_parquet(self, file: str | Path | BytesIO) -> None:
self.native.to_parquet(file)
explode = not_implemented()

View File

@ -0,0 +1,701 @@
from __future__ import annotations
import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
import pandas as pd
from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
add_row_index,
maybe_evaluate_expr,
narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
from narwhals._utils import (
Implementation,
generate_temporary_column_name,
not_implemented,
)
from narwhals.exceptions import InvalidOperationError
if TYPE_CHECKING:
from collections.abc import Sequence
import dask.dataframe.dask_expr as dx
from typing_extensions import Self
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.namespace import DaskNamespace
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._utils import Version, _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
ModeKeepStrategy,
NonNestedLiteral,
NumericLiteral,
RollingInterpolationMethod,
TemporalLiteral,
)
class DaskExpr(
LazyExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments]
DepthTrackingExpr["DaskLazyFrame", "dx.Series"], # pyright: ignore[reportInvalidTypeArguments]
):
_implementation: Implementation = Implementation.DASK
def __init__(
self,
call: EvalSeries[DaskLazyFrame, dx.Series], # pyright: ignore[reportInvalidTypeForm]
*,
depth: int,
function_name: str,
evaluate_output_names: EvalNames[DaskLazyFrame],
alias_output_names: AliasNames | None,
version: Version,
scalar_kwargs: ScalarKwargs | None = None,
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._scalar_kwargs = scalar_kwargs or {}
self._metadata: ExprMetadata | None = None
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
return self._call(df)
def __narwhals_namespace__(self) -> DaskNamespace: # pragma: no cover
from narwhals._dask.namespace import DaskNamespace
return DaskNamespace(version=self._version)
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
# result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
# that raised a KeyError for result[0] during collection.
return [result.loc[0][0] for result in self(df)]
return self.__class__(
func,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
scalar_kwargs=self._scalar_kwargs,
)
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[DaskLazyFrame],
/,
*,
context: _LimitedContext,
function_name: str = "",
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
try:
return [
df._native_frame[column_name]
for column_name in evaluate_column_names(df)
]
except KeyError as e:
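                # Surface missing columns as narwhals' ColumnNotFoundError when
                # `_check_columns_exist` can identify them; otherwise re-raise.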
if error := df._check_columns_exist(evaluate_column_names(df)):
raise error from e
raise
return cls(
func,
depth=0,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
return [df.native.iloc[:, i] for i in column_indices]
return cls(
func,
depth=0,
function_name="nth",
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
)
def _with_callable(
self,
# First argument to `call` should be `dx.Series`
call: Callable[..., dx.Series],
/,
expr_name: str = "",
scalar_kwargs: ScalarKwargs | None = None,
**expressifiable_args: Self | Any,
) -> Self:
def func(df: DaskLazyFrame) -> list[dx.Series]:
native_results: list[dx.Series] = []
native_series_list = self._call(df)
other_native_series = {
key: maybe_evaluate_expr(df, value)
for key, value in expressifiable_args.items()
}
for native_series in native_series_list:
result_native = call(native_series, **other_native_series)
native_results.append(result_native)
return native_results
return self.__class__(
func,
depth=self._depth + 1,
function_name=f"{self._function_name}->{expr_name}",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
scalar_kwargs=scalar_kwargs,
)
def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
current_alias_output_names = self._alias_output_names
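        # Compose `func` with any aliasing function already in place so that
        # successive renames are applied in order.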
alias_output_names = (
None
if func is None
else func
if current_alias_output_names is None
else lambda output_names: func(current_alias_output_names(output_names))
)
return type(self)(
call=self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=alias_output_names,
version=self._version,
scalar_kwargs=self._scalar_kwargs,
)
def _with_binary(
self,
call: Callable[[dx.Series, Any], dx.Series],
name: str,
other: Any,
*,
reverse: bool = False,
) -> Self:
result = self._with_callable(
lambda expr, other: call(expr, other), name, other=other
)
if reverse:
result = result.alias("literal")
return result
def _binary_op(self, op_name: str, other: Any) -> Self:
return self._with_binary(
lambda expr, other: getattr(expr, op_name)(other), op_name, other
)
def _reverse_binary_op(
self, op_name: str, operator_func: Callable[..., dx.Series], other: Any
) -> Self:
return self._with_binary(
lambda expr, other: operator_func(other, expr), op_name, other, reverse=True
)
def __add__(self, other: Any) -> Self:
return self._binary_op("__add__", other)
def __sub__(self, other: Any) -> Self:
return self._binary_op("__sub__", other)
def __mul__(self, other: Any) -> Self:
return self._binary_op("__mul__", other)
def __truediv__(self, other: Any) -> Self:
return self._binary_op("__truediv__", other)
def __floordiv__(self, other: Any) -> Self:
return self._binary_op("__floordiv__", other)
def __pow__(self, other: Any) -> Self:
return self._binary_op("__pow__", other)
def __mod__(self, other: Any) -> Self:
return self._binary_op("__mod__", other)
def __eq__(self, other: object) -> Self: # type: ignore[override]
return self._binary_op("__eq__", other)
def __ne__(self, other: object) -> Self: # type: ignore[override]
return self._binary_op("__ne__", other)
def __ge__(self, other: Any) -> Self:
return self._binary_op("__ge__", other)
def __gt__(self, other: Any) -> Self:
return self._binary_op("__gt__", other)
def __le__(self, other: Any) -> Self:
return self._binary_op("__le__", other)
def __lt__(self, other: Any) -> Self:
return self._binary_op("__lt__", other)
def __and__(self, other: Any) -> Self:
return self._binary_op("__and__", other)
def __or__(self, other: Any) -> Self:
return self._binary_op("__or__", other)
def __rsub__(self, other: Any) -> Self:
return self._reverse_binary_op("__rsub__", lambda a, b: a - b, other)
def __rtruediv__(self, other: Any) -> Self:
return self._reverse_binary_op("__rtruediv__", lambda a, b: a / b, other)
def __rfloordiv__(self, other: Any) -> Self:
return self._reverse_binary_op("__rfloordiv__", lambda a, b: a // b, other)
def __rpow__(self, other: Any) -> Self:
return self._reverse_binary_op("__rpow__", lambda a, b: a**b, other)
def __rmod__(self, other: Any) -> Self:
return self._reverse_binary_op("__rmod__", lambda a, b: a % b, other)
def __invert__(self) -> Self:
return self._with_callable(lambda expr: expr.__invert__(), "__invert__")
def mean(self) -> Self:
return self._with_callable(lambda expr: expr.mean().to_series(), "mean")
def median(self) -> Self:
from narwhals.exceptions import InvalidOperationError
def func(s: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
if not dtype.is_numeric():
msg = "`median` operation not supported for non-numeric input type."
raise InvalidOperationError(msg)
return s.median_approximate().to_series()
return self._with_callable(func, "median")
def min(self) -> Self:
return self._with_callable(lambda expr: expr.min().to_series(), "min")
def max(self) -> Self:
return self._with_callable(lambda expr: expr.max().to_series(), "max")
def std(self, ddof: int) -> Self:
return self._with_callable(
lambda expr: expr.std(ddof=ddof).to_series(),
"std",
scalar_kwargs={"ddof": ddof},
)
def var(self, ddof: int) -> Self:
return self._with_callable(
lambda expr: expr.var(ddof=ddof).to_series(),
"var",
scalar_kwargs={"ddof": ddof},
)
def skew(self) -> Self:
return self._with_callable(lambda expr: expr.skew().to_series(), "skew")
def kurtosis(self) -> Self:
return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")
def shift(self, n: int) -> Self:
return self._with_callable(lambda expr: expr.shift(n), "shift")
def cum_sum(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
# https://github.com/dask/dask/issues/11802
msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")
def cum_count(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_count(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(
lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
)
def cum_min(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_min(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cummin(), "cum_min")
def cum_max(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_max(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cummax(), "cum_max")
def cum_prod(self, *, reverse: bool) -> Self:
if reverse: # pragma: no cover
msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).sum(),
"rolling_sum",
)
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).mean(),
"rolling_mean",
)
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if ddof == 1:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).var(),
"rolling_var",
)
msg = "Dask backend only supports `ddof=1` for `rolling_var`"
raise NotImplementedError(msg)
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
if ddof == 1:
return self._with_callable(
lambda expr: expr.rolling(
window=window_size, min_periods=min_samples, center=center
).std(),
"rolling_std",
)
msg = "Dask backend only supports `ddof=1` for `rolling_std`"
raise NotImplementedError(msg)
def sum(self) -> Self:
return self._with_callable(lambda expr: expr.sum().to_series(), "sum")
def count(self) -> Self:
return self._with_callable(lambda expr: expr.count().to_series(), "count")
def round(self, decimals: int) -> Self:
return self._with_callable(lambda expr: expr.round(decimals), "round")
def unique(self) -> Self:
return self._with_callable(lambda expr: expr.unique(), "unique")
def drop_nulls(self) -> Self:
return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")
def abs(self) -> Self:
return self._with_callable(lambda expr: expr.abs(), "abs")
def all(self) -> Self:
return self._with_callable(
lambda expr: expr.all(
axis=None, skipna=True, split_every=False, out=None
).to_series(),
"all",
)
def any(self) -> Self:
return self._with_callable(
lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
"any",
)
def fill_nan(self, value: float | None) -> Self:
value_nullable = pd.NA if value is None else value
value_numpy = float("nan") if value is None else value
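        # Dask columns can be NumPy-backed or use a nullable/pyarrow dtype backend;
        # pick the missing-value sentinel (NaN vs pd.NA) that matches the backend below.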
def func(expr: dx.Series) -> dx.Series:
# If/when pandas exposes an API which distinguishes NaN vs null, use that.
mask = cast("dx.Series", expr != expr) # noqa: PLR0124
mask = mask.fillna(False)
fill = (
value_nullable
if get_dtype_backend(expr.dtype, self._implementation)
else value_numpy
)
return expr.mask(mask, fill) # pyright: ignore[reportArgumentType]
return self._with_callable(func, "fill_nan")
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self:
def func(expr: dx.Series) -> dx.Series:
if value is not None:
res_ser = expr.fillna(value)
else:
res_ser = (
expr.ffill(limit=limit)
if strategy == "forward"
else expr.bfill(limit=limit)
)
return res_ser
return self._with_callable(func, "fill_null")
def clip(
self,
lower_bound: Self | NumericLiteral | TemporalLiteral | None,
upper_bound: Self | NumericLiteral | TemporalLiteral | None,
) -> Self:
return self._with_callable(
lambda expr, lower_bound, upper_bound: expr.clip(
lower=lower_bound, upper=upper_bound
),
"clip",
lower_bound=lower_bound,
upper_bound=upper_bound,
)
def diff(self) -> Self:
return self._with_callable(lambda expr: expr.diff(), "diff")
def n_unique(self) -> Self:
return self._with_callable(
lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
)
def is_null(self) -> Self:
return self._with_callable(lambda expr: expr.isna(), "is_null")
def is_nan(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
dtype = native_to_narwhals_dtype(
expr.dtype, self._version, self._implementation
)
if dtype.is_numeric():
return expr != expr # pyright: ignore[reportReturnType] # noqa: PLR0124
msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
raise InvalidOperationError(msg)
        return self._with_callable(func, "is_nan")
def len(self) -> Self:
return self._with_callable(lambda expr: expr.size.to_series(), "len")
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> Self:
if interpolation == "linear":
def func(expr: dx.Series, quantile: float) -> dx.Series:
if expr.npartitions > 1:
msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
raise NotImplementedError(msg)
return expr.quantile(
q=quantile, method="dask"
).to_series() # pragma: no cover
return self._with_callable(func, "quantile", quantile=quantile)
msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
raise NotImplementedError(msg)
def is_first_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
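            # Tag every row with its position via a temporary index column, take the
            # minimum position per value, and flag rows whose position matches it.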
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
frame = add_row_index(expr.to_frame(), col_token)
first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
return frame[col_token].isin(first_distinct_index)
return self._with_callable(func, "is_first_distinct")
def is_last_distinct(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
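            # Same approach as `is_first_distinct`, but keeping the maximum position.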
_name = expr.name
col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
frame = add_row_index(expr.to_frame(), col_token)
last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
return frame[col_token].isin(last_distinct_index)
return self._with_callable(func, "is_last_distinct")
def is_unique(self) -> Self:
def func(expr: dx.Series) -> dx.Series:
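            # Broadcast each value's group size back to its rows; a row is unique
            # when its group (nulls included) has size 1.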
_name = expr.name
return (
expr.to_frame()
.groupby(_name, dropna=False)
.transform("size", meta=(_name, int))
== 1
)
return self._with_callable(func, "is_unique")
def is_in(self, other: Any) -> Self:
return self._with_callable(lambda expr: expr.isin(other), "is_in")
def null_count(self) -> Self:
return self._with_callable(
lambda expr: expr.isna().sum().to_series(), "null_count"
)
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
# pandas is a required dependency of dask so it's safe to import this
from narwhals._pandas_like.group_by import PandasLikeGroupBy
if not partition_by:
assert order_by # noqa: S101
# This is something like `nw.col('a').cum_sum().order_by(key)`
# which we can always easily support, as it doesn't require grouping.
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
return self(df.sort(*order_by, descending=False, nulls_last=False))
elif not self._is_elementary(): # pragma: no cover
msg = (
"Only elementary expressions are supported for `.over` in dask.\n\n"
"Please see: "
"https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
)
raise NotImplementedError(msg)
elif order_by:
# Wrong results https://github.com/dask/dask/issues/11806.
msg = "`over` with `order_by` is not yet supported in Dask."
raise NotImplementedError(msg)
else:
function_name = PandasLikeGroupBy._leaf_name(self)
try:
dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
except KeyError:
# window functions are unsupported: https://github.com/dask/dask/issues/11806
msg = (
f"Unsupported function: {function_name} in `over` context.\n\n"
f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
)
raise NotImplementedError(msg) from None
def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
with warnings.catch_warnings():
# https://github.com/dask/dask/issues/11804
warnings.filterwarnings(
"ignore",
message=".*`meta` is not specified",
category=UserWarning,
)
grouped = df.native.groupby(partition_by)
if dask_function_name == "size":
if len(output_names) != 1: # pragma: no cover
msg = "Safety check failed, please report a bug."
raise AssertionError(msg)
res_native = grouped.transform(
dask_function_name, **self._scalar_kwargs
).to_frame(output_names[0])
else:
res_native = grouped[list(output_names)].transform(
dask_function_name, **self._scalar_kwargs
)
result_frame = df._with_native(
res_native.rename(columns=dict(zip(output_names, aliases)))
).native
return [result_frame[name] for name in aliases]
return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
def cast(self, dtype: IntoDType) -> Self:
def func(expr: dx.Series) -> dx.Series:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
return expr.astype(native_dtype)
return self._with_callable(func, "cast")
def is_finite(self) -> Self:
import dask.array as da
return self._with_callable(da.isfinite, "is_finite")
def log(self, base: float) -> Self:
import dask.array as da
def _log(expr: dx.Series) -> dx.Series:
return da.log(expr) / da.log(base)
return self._with_callable(_log, "log")
def exp(self) -> Self:
import dask.array as da
return self._with_callable(da.exp, "exp")
def sqrt(self) -> Self:
import dask.array as da
return self._with_callable(da.sqrt, "sqrt")
def mode(self, *, keep: ModeKeepStrategy) -> Self:
def func(expr: dx.Series) -> dx.Series:
_name = expr.name
result = expr.to_frame().mode()[_name]
return result.head(1) if keep == "any" else result
return self._with_callable(func, "mode", scalar_kwargs={"keep": keep})
@property
def str(self) -> DaskExprStringNamespace:
return DaskExprStringNamespace(self)
@property
def dt(self) -> DaskExprDateTimeNamespace:
return DaskExprDateTimeNamespace(self)
arg_max: not_implemented = not_implemented()
arg_min: not_implemented = not_implemented()
arg_true: not_implemented = not_implemented()
ewm_mean: not_implemented = not_implemented()
gather_every: not_implemented = not_implemented()
head: not_implemented = not_implemented()
map_batches: not_implemented = not_implemented()
sample: not_implemented = not_implemented()
rank: not_implemented = not_implemented()
replace_strict: not_implemented = not_implemented()
sort: not_implemented = not_implemented()
tail: not_implemented = not_implemented()
# namespaces
list: not_implemented = not_implemented() # type: ignore[assignment]
cat: not_implemented = not_implemented() # type: ignore[assignment]
struct: not_implemented = not_implemented() # type: ignore[assignment]


@@ -0,0 +1,175 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
ALIAS_DICT,
calculate_timestamp_date,
calculate_timestamp_datetime,
native_to_narwhals_dtype,
)
from narwhals._utils import Implementation
if TYPE_CHECKING:
import dask.dataframe.dask_expr as dx
from narwhals._dask.expr import DaskExpr
from narwhals.typing import TimeUnit
class DaskExprDateTimeNamespace(
LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"]
):
def date(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.date, "date")
def year(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.year, "year")
def month(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.month, "month")
def day(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.day, "day")
def hour(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour")
def minute(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute")
def second(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.dt.second, "second")
def millisecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond // 1000, "millisecond"
)
def microsecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond, "microsecond"
)
def nanosecond(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond"
)
def ordinal_day(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.dayofyear, "ordinal_day"
)
def weekday(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.weekday + 1, # Dask is 0-6
"weekday",
)
def to_string(self, format: str) -> DaskExpr:
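        # `%.f` (fractional seconds preceded by a dot) is not a valid strftime
        # directive, so rewrite it as `.%f` before handing the format to dask.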
return self.compliant._with_callable(
lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")),
"strftime",
format=format,
)
def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone)
if time_zone is not None
else expr.dt.tz_localize(None),
"tz_localize",
time_zone=time_zone,
)
def convert_time_zone(self, time_zone: str) -> DaskExpr:
def func(s: dx.Series, time_zone: str) -> dx.Series:
dtype = native_to_narwhals_dtype(
s.dtype, self.compliant._version, Implementation.DASK
)
if dtype.time_zone is None: # type: ignore[attr-defined]
return s.dt.tz_localize("UTC").dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue]
return s.dt.tz_convert(time_zone) # pyright: ignore[reportAttributeAccessIssue]
return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone)
# ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808.
def timestamp(self, time_unit: TimeUnit) -> DaskExpr: # pragma: no cover
def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
dtype = native_to_narwhals_dtype(
s.dtype, self.compliant._version, Implementation.DASK
)
is_pyarrow_dtype = "pyarrow" in str(dtype)
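            # Record which rows are null so they can be masked out again after the
            # integer conversion below.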
mask_na = s.isna()
dtypes = self.compliant._version.dtypes
if dtype == dtypes.Date:
# Date is only supported in pandas dtypes if pyarrow-backed
s_cast = s.astype("Int32[pyarrow]")
result = calculate_timestamp_date(s_cast, time_unit)
elif isinstance(dtype, dtypes.Datetime):
original_time_unit = dtype.time_unit
s_cast = (
s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64")
)
result = calculate_timestamp_datetime(
s_cast, original_time_unit, time_unit
)
else:
msg = "Input should be either of Date or Datetime type"
raise TypeError(msg)
return result.where(~mask_na) # pyright: ignore[reportReturnType]
return self.compliant._with_callable(func, "datetime", time_unit=time_unit)
def total_minutes(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() // 60, "total_minutes"
)
def total_seconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() // 1, "total_seconds"
)
def total_milliseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1,
"total_milliseconds",
)
def total_microseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1,
"total_microseconds",
)
def total_nanoseconds(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds"
)
def truncate(self, every: str) -> DaskExpr:
interval = Interval.parse(every)
unit = interval.unit
if unit in {"mo", "q", "y"}:
msg = f"Truncating to {unit} is not yet supported for dask."
raise NotImplementedError(msg)
freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}"
return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate")
def offset_by(self, by: str) -> DaskExpr:
def func(s: dx.Series, by: str) -> dx.Series:
interval = Interval.parse_no_constraints(by)
unit = interval.unit
if unit in {"y", "q", "mo", "d", "ns"}:
msg = f"Offsetting by {unit} is not yet supported for dask."
raise NotImplementedError(msg)
offset = interval.to_timedelta()
return s.add(offset)
return self.compliant._with_callable(func, "offset_by", by=by)


@@ -0,0 +1,121 @@
from __future__ import annotations
from typing import TYPE_CHECKING
import dask.dataframe as dd
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StringNamespace
from narwhals._utils import not_implemented
if TYPE_CHECKING:
import dask.dataframe.dask_expr as dx
from narwhals._dask.expr import DaskExpr
class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["DaskExpr"]):
def len_chars(self) -> DaskExpr:
return self.compliant._with_callable(lambda expr: expr.str.len(), "len")
def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> DaskExpr:
def _replace(
expr: dx.Series, pattern: str, value: str, *, literal: bool, n: int
) -> dx.Series:
try:
return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue]
pattern, value, regex=not literal, n=n
)
except TypeError as e:
if not isinstance(value, str):
msg = "dask backed `Expr.str.replace` only supports str replacement values"
raise TypeError(msg) from e
raise
return self.compliant._with_callable(
_replace, "replace", pattern=pattern, value=value, literal=literal, n=n
)
def replace_all(self, pattern: str, value: str, *, literal: bool) -> DaskExpr:
def _replace_all(
expr: dx.Series, pattern: str, value: str, *, literal: bool
) -> dx.Series:
try:
return expr.str.replace( # pyright: ignore[reportAttributeAccessIssue]
pattern, value, regex=not literal, n=-1
)
except TypeError as e:
if not isinstance(value, str):
msg = "dask backed `Expr.str.replace_all` only supports str replacement values."
raise TypeError(msg) from e
raise
return self.compliant._with_callable(
_replace_all, "replace", pattern=pattern, value=value, literal=literal
)
def strip_chars(self, characters: str | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, characters: expr.str.strip(characters),
"strip",
characters=characters,
)
def starts_with(self, prefix: str) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, prefix: expr.str.startswith(prefix), "starts_with", prefix=prefix
)
def ends_with(self, suffix: str) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, suffix: expr.str.endswith(suffix), "ends_with", suffix=suffix
)
def contains(self, pattern: str, *, literal: bool) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, pattern, literal: expr.str.contains(
pat=pattern, regex=not literal
),
"contains",
pattern=pattern,
literal=literal,
)
def slice(self, offset: int, length: int | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, offset, length: expr.str.slice(
start=offset, stop=offset + length if length else None
),
"slice",
offset=offset,
length=length,
)
def split(self, by: str) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, by: expr.str.split(pat=by), "split", by=by
)
def to_datetime(self, format: str | None) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, format: dd.to_datetime(expr, format=format),
"to_datetime",
format=format,
)
def to_uppercase(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.str.upper(), "to_uppercase"
)
def to_lowercase(self) -> DaskExpr:
return self.compliant._with_callable(
lambda expr: expr.str.lower(), "to_lowercase"
)
def zfill(self, width: int) -> DaskExpr:
return self.compliant._with_callable(
lambda expr, width: expr.str.zfill(width), "zfill", width=width
)
to_date = not_implemented()


@@ -0,0 +1,147 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, ClassVar
import dask.dataframe as dd
from narwhals._compliant import DepthTrackingGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import zip_strict
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
import pandas as pd
from dask.dataframe.api import GroupBy as _DaskGroupBy
from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy
from typing_extensions import TypeAlias
from narwhals._compliant.typing import NarwhalsAggregation
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any]
_AggFn: TypeAlias = Callable[..., Any]
else:
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx
_DaskGroupBy = dx._groupby.GroupBy
Aggregation: TypeAlias = "str | _AggFn"
"""The name of an aggregation function, or the function itself."""
def n_unique() -> dd.Aggregation:
def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
return s.nunique(dropna=False)
def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
return s0.sum()
return dd.Aggregation(name="nunique", chunk=chunk, agg=agg)
def _all() -> dd.Aggregation:
def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
return s.all(skipna=True)
def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
return s0.all(skipna=True)
return dd.Aggregation(name="all", chunk=chunk, agg=agg)
def _any() -> dd.Aggregation:
def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
return s.any(skipna=True)
def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
return s0.any(skipna=True)
return dd.Aggregation(name="any", chunk=chunk, agg=agg)
def var(ddof: int) -> _AggFn:
return partial(_DaskGroupBy.var, ddof=ddof)
def std(ddof: int) -> _AggFn:
return partial(_DaskGroupBy.std, ddof=ddof)
class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]):
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
"sum": "sum",
"mean": "mean",
"median": "median",
"max": "max",
"min": "min",
"std": std,
"var": var,
"len": "size",
"n_unique": n_unique,
"count": "count",
"quantile": "quantile",
"all": _all,
"any": _any,
}
def __init__(
self,
df: DaskLazyFrame,
keys: Sequence[DaskExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
df, keys=keys
)
self._grouped = self.compliant.native.groupby(
self._keys, dropna=drop_null_keys, observed=True
)
def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
from narwhals._dask.dataframe import DaskLazyFrame
if not exprs:
# No aggregation provided
return (
self.compliant.simple_select(*self._keys)
.unique(self._keys, keep="any")
.rename(dict(zip(self._keys, self._output_key_names)))
)
self._ensure_all_simple(exprs)
# This should be the fastpath, but cuDF is too far behind to use it.
# - https://github.com/rapidsai/cudf/issues/15118
# - https://github.com/rapidsai/cudf/issues/15084
simple_aggregations: dict[str, tuple[str, Aggregation]] = {}
exclude = (*self._keys, *self._output_key_names)
for expr in exprs:
output_names, aliases = evaluate_output_names_and_aliases(
expr, self.compliant, exclude
)
if expr._depth == 0:
# e.g. `agg(nw.len())`
column = self._keys[0]
agg_fn = self._remap_expr_name(expr._function_name)
simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn)))
continue
# e.g. `agg(nw.mean('a'))`
agg_fn = self._remap_expr_name(self._leaf_name(expr))
# deal with n_unique case in a "lazy" mode to not depend on dask globally
agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn
simple_aggregations.update(
(alias, (output_name, agg_fn))
for alias, output_name in zip_strict(aliases, output_names)
)
return DaskLazyFrame(
self._grouped.agg(**simple_aggregations).reset_index(),
version=self.compliant._version,
).rename(dict(zip(self._keys, self._output_key_names)))


@@ -0,0 +1,338 @@
from __future__ import annotations
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, cast
import dask.dataframe as dd
import pandas as pd
from narwhals._compliant import (
CompliantThen,
CompliantWhen,
DepthTrackingNamespace,
LazyNamespace,
)
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import (
align_series_full_broadcast,
narwhals_to_native_dtype,
validate_comparand,
)
from narwhals._expression_parsing import (
ExprKind,
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._utils import Implementation, zip_strict
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Sequence
import dask.dataframe.dask_expr as dx
from narwhals._compliant.typing import ScalarKwargs
from narwhals._utils import Version
from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral
class DaskNamespace(
LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
):
_implementation: Implementation = Implementation.DASK
@property
def selectors(self) -> DaskSelectorNamespace:
return DaskSelectorNamespace.from_namespace(self)
@property
def _expr(self) -> type[DaskExpr]:
return DaskExpr
@property
def _lazyframe(self) -> type[DaskLazyFrame]:
return DaskLazyFrame
def __init__(self, *, version: Version) -> None:
self._version = version
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
if dtype is not None:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
else:
native_pd_series = pd.Series([value], name="literal")
npartitions = df._native_frame.npartitions
dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
return [dask_series[0].to_series()]
return self._expr(
func,
depth=0,
function_name="lit",
evaluate_output_names=lambda _df: ["literal"],
alias_output_names=None,
version=self._version,
)
def len(self) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
# We don't allow dataframes with 0 columns, so `[0]` is safe.
return [df._native_frame[df.columns[0]].size.to_series()]
return self._expr(
func,
depth=0,
function_name="len",
evaluate_output_names=lambda _df: ["len"],
alias_output_names=None,
version=self._version,
)
def all_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs)
# Note on `ignore_nulls`: Dask doesn't support storing arbitrary Python
# objects in `object` dtype, so we don't need the same check we have for pandas-like.
if ignore_nulls:
# NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
series = (s if s.dtype == "bool" else s.fillna(True) for s in series)
return [reduce(operator.and_, align_series_full_broadcast(df, *series))]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="all_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def any_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs)
if ignore_nulls:
series = (s if s.dtype == "bool" else s.fillna(False) for s in series)
return [reduce(operator.or_, align_series_full_broadcast(df, *series))]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="any_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = align_series_full_broadcast(
df, *(s for _expr in exprs for s in _expr(df))
)
return [dd.concat(series, axis=1).sum(axis=1)]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="sum_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def concat(
self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
) -> DaskLazyFrame:
if not items:
msg = "No items to concatenate" # pragma: no cover
raise AssertionError(msg)
dfs = [i._native_frame for i in items]
cols_0 = dfs[0].columns
if how == "vertical":
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not (
(len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0.to_list()}\n"
f" - dataframe {i}: {cols_current.to_list()}\n"
)
raise TypeError(msg)
return DaskLazyFrame(
dd.concat(dfs, axis=0, join="inner"), version=self._version
)
if how == "diagonal":
return DaskLazyFrame(
dd.concat(dfs, axis=0, join="outer"), version=self._version
)
raise NotImplementedError
def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
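            # mean = sum of values (nulls treated as 0) divided by the count of
            # non-null entries, computed element-wise across the aligned series.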
expr_results = [s for _expr in exprs for s in _expr(df)]
series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
non_na = align_series_full_broadcast(
df, *(1 - s.isna() for s in expr_results)
)
num = reduce(lambda x, y: x + y, series) # pyright: ignore[reportOperatorIssue]
den = reduce(lambda x, y: x + y, non_na) # pyright: ignore[reportOperatorIssue]
return [cast("dx.Series", num / den)] # pyright: ignore[reportOperatorIssue]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="mean_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = align_series_full_broadcast(
df, *(s for _expr in exprs for s in _expr(df))
)
return [dd.concat(series, axis=1).min(axis=1)]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="min_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = align_series_full_broadcast(
df, *(s for _expr in exprs for s in _expr(df))
)
return [dd.concat(series, axis=1).max(axis=1)]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="max_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def when(self, predicate: DaskExpr) -> DaskWhen:
return DaskWhen.from_expr(predicate, context=self)
def concat_str(
self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
expr_results = [s for _expr in exprs for s in _expr(df)]
series = (
s.astype(str) for s in align_series_full_broadcast(df, *expr_results)
)
null_mask = [s.isna() for s in align_series_full_broadcast(df, *expr_results)]
if not ignore_nulls:
null_mask_result = reduce(operator.or_, null_mask)
result = reduce(lambda x, y: x + separator + y, series).where(
~null_mask_result, None
)
else:
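                # Nulls are replaced with empty strings, and the separator that
                # would follow a null element is mapped to an empty string as well.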
init_value, *values = [
s.where(~nm, "") for s, nm in zip_strict(series, null_mask)
]
separators = (
nm.map({True: "", False: separator}, meta=str)
for nm in null_mask[:-1]
)
result = reduce(
operator.add,
(s + v for s, v in zip_strict(separators, values)),
init_value,
)
return [result]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="concat_str",
evaluate_output_names=getattr(
exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
),
alias_output_names=getattr(exprs[0], "_alias_output_names", None),
version=self._version,
)
def coalesce(self, *exprs: DaskExpr) -> DaskExpr:
def func(df: DaskLazyFrame) -> list[dx.Series]:
series = align_series_full_broadcast(
df, *(s for _expr in exprs for s in _expr(df))
)
return [reduce(lambda x, y: x.fillna(y), series)]
return self._expr(
call=func,
depth=max(x._depth for x in exprs) + 1,
function_name="coalesce",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]): # pyright: ignore[reportInvalidTypeArguments]
@property
def _then(self) -> type[DaskThen]:
return DaskThen
def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
then_value = (
self._then_value(df)[0]
if isinstance(self._then_value, DaskExpr)
else self._then_value
)
otherwise_value = (
self._otherwise_value(df)[0]
if isinstance(self._otherwise_value, DaskExpr)
else self._otherwise_value
)
condition = self._condition(df)[0]
# re-evaluate DataFrame if the condition aggregates to force
# then/otherwise to be evaluated against the aggregated frame
assert self._condition._metadata is not None # noqa: S101
if self._condition._metadata.is_scalar_like:
new_df = df._with_native(condition.to_frame())
condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0]
df = new_df
if self._otherwise_value is None:
(condition, then_series) = align_series_full_broadcast(
df, condition, then_value
)
validate_comparand(condition, then_series)
return [then_series.where(condition)] # pyright: ignore[reportArgumentType]
(condition, then_series, otherwise_series) = align_series_full_broadcast(
df, condition, then_value, otherwise_value
)
validate_comparand(condition, then_series)
validate_comparand(condition, otherwise_series)
return [then_series.where(condition, otherwise_series)] # pyright: ignore[reportArgumentType]
class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr): # pyright: ignore[reportInvalidTypeArguments]
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "whenthen"


@@ -0,0 +1,34 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._dask.expr import DaskExpr
if TYPE_CHECKING:
import dask.dataframe.dask_expr as dx # noqa: F401
from narwhals._compliant.typing import ScalarKwargs
from narwhals._dask.dataframe import DaskLazyFrame # noqa: F401
class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]): # pyright: ignore[reportInvalidTypeArguments]
@property
def _selector(self) -> type[DaskSelector]:
return DaskSelector
class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr): # type: ignore[misc]
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "selector"
def _to_expr(self) -> DaskExpr:
return DaskExpr(
self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)


@@ -0,0 +1,139 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._utils import Implementation, Version, isinstance_or_issubclass
from narwhals.dependencies import get_pyarrow
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
import dask.dataframe as dd
import dask.dataframe.dask_expr as dx
from narwhals._dask.dataframe import DaskLazyFrame, Incomplete
from narwhals._dask.expr import DaskExpr
from narwhals.dtypes import DType
from narwhals.typing import IntoDType
else:
try:
import dask.dataframe.dask_expr as dx
except ModuleNotFoundError: # pragma: no cover
import dask_expr as dx
def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object:
from narwhals._dask.expr import DaskExpr
if isinstance(obj, DaskExpr):
results = obj._call(df)
assert len(results) == 1 # debug assertion # noqa: S101
return results[0]
return obj
def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]:
native_results: list[tuple[str, dx.Series]] = []
for expr in exprs:
native_series_list = expr(df)
aliases = expr._evaluate_aliases(df)
if len(aliases) != len(native_series_list): # pragma: no cover
msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.extend(zip(aliases, native_series_list))
return native_results
def align_series_full_broadcast(
df: DaskLazyFrame, *series: dx.Series | object
) -> Sequence[dx.Series]:
return [
s if isinstance(s, dx.Series) else df._native_frame.assign(_tmp=s)["_tmp"]
for s in series
] # pyright: ignore[reportReturnType]
def add_row_index(frame: dd.DataFrame, name: str) -> dd.DataFrame:
original_cols = frame.columns
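    # Assign a constant column of ones, cumulative-sum it with the parallel
    # Blelloch scan, and subtract 1 to get a 0-based row index across partitions.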
df: Incomplete = frame.assign(**{name: 1})
return select_columns_by_name(
df.assign(**{name: df[name].cumsum(method="blelloch") - 1}),
[name, *original_cols],
Implementation.DASK,
)
def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
if not dx.expr.are_co_aligned(lhs._expr, rhs._expr): # pragma: no cover
# are_co_aligned is a method which cheaply checks if two Dask expressions
# have the same index, and therefore don't require index alignment.
# If someone only operates on a Dask DataFrame via expressions, then this
# should always be the case: expression outputs (by definition) all come from the
# same input dataframe, and Dask Series does not have any operations which
# change the index. Nonetheless, we perform this safety check anyway.
# However, we still need to carefully vet which methods we support for Dask, to
# avoid issues where `are_co_aligned` doesn't do what we want it to do:
# https://github.com/dask/dask-expr/issues/1112.
msg = "Objects are not co-aligned, so this operation is not supported for Dask backend"
raise RuntimeError(msg)
dtypes = Version.MAIN.dtypes
dtypes_v1 = Version.V1.dtypes
NW_TO_DASK_DTYPES: Mapping[type[DType], str] = {
dtypes.Float64: "float64",
dtypes.Float32: "float32",
dtypes.Boolean: "bool",
dtypes.Categorical: "category",
dtypes.Date: "date32[day][pyarrow]",
dtypes.Int8: "int8",
dtypes.Int16: "int16",
dtypes.Int32: "int32",
dtypes.Int64: "int64",
dtypes.UInt8: "uint8",
dtypes.UInt16: "uint16",
dtypes.UInt32: "uint32",
dtypes.UInt64: "uint64",
dtypes.Datetime: "datetime64[us]",
dtypes.Duration: "timedelta64[ns]",
dtypes_v1.Datetime: "datetime64[us]",
dtypes_v1.Duration: "timedelta64[ns]",
}
UNSUPPORTED_DTYPES = (
dtypes.List,
dtypes.Struct,
dtypes.Array,
dtypes.Time,
dtypes.Binary,
)
def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any:
dtypes = version.dtypes
base_type = dtype.base_type()
if dask_type := NW_TO_DASK_DTYPES.get(base_type):
return dask_type
if isinstance_or_issubclass(dtype, dtypes.String):
if Implementation.PANDAS._backend_version() >= (2, 0, 0):
return "string[pyarrow]" if get_pyarrow() else "string[python]"
return "object" # pragma: no cover
if isinstance_or_issubclass(dtype, dtypes.Enum):
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
raise NotImplementedError(msg)
if isinstance(dtype, dtypes.Enum):
import pandas as pd
# NOTE: `pandas-stubs.core.dtypes.dtypes.CategoricalDtype.categories` is too narrow
# Should be one of the `ListLike*` types
# https://github.com/pandas-dev/pandas-stubs/blob/8434bde95460b996323cc8c0fea7b0a8bb00ea26/pandas-stubs/_typing.pyi#L497-L505
return pd.CategoricalDtype(dtype.categories, ordered=True) # type: ignore[arg-type]
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)
if issubclass(base_type, UNSUPPORTED_DTYPES): # pragma: no cover
msg = f"Converting to {base_type.__name__} dtype is not supported for Dask."
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)


@@ -0,0 +1,542 @@
from __future__ import annotations
from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any
import duckdb
from duckdb import StarExpression
from narwhals._duckdb.utils import (
DeferredTimeZone,
F,
catch_duckdb_exception,
col,
evaluate_exprs,
join_column_names,
lit,
native_to_narwhals_dtype,
window_expression,
)
from narwhals._sql.dataframe import SQLLazyFrame
from narwhals._utils import (
Implementation,
ValidateBackendVersion,
Version,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
requires,
zip_strict,
)
from narwhals.dependencies import get_duckdb
from narwhals.exceptions import InvalidOperationError
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
import pandas as pd
import pyarrow as pa
from duckdb import Expression
from duckdb.typing import DuckDBPyType
from typing_extensions import Self, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.group_by import DuckDBGroupBy
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._duckdb.series import DuckDBInterchangeSeries
from narwhals._typing import _EagerAllowedImpl
from narwhals._utils import _LimitedContext
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import DType
from narwhals.stable.v1 import DataFrame as DataFrameV1
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
class DuckDBLazyFrame(
SQLLazyFrame[
"DuckDBExpr",
"duckdb.DuckDBPyRelation",
"LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]",
],
ValidateBackendVersion,
):
_implementation = Implementation.DUCKDB
def __init__(
self,
df: duckdb.DuckDBPyRelation,
*,
version: Version,
validate_backend_version: bool = False,
) -> None:
self._native_frame: duckdb.DuckDBPyRelation = df
self._version = version
self._cached_native_schema: dict[str, DuckDBPyType] | None = None
self._cached_columns: list[str] | None = None
if validate_backend_version:
self._validate_backend_version()
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@staticmethod
def _is_native(obj: duckdb.DuckDBPyRelation | Any) -> TypeIs[duckdb.DuckDBPyRelation]:
return isinstance(obj, duckdb.DuckDBPyRelation)
@classmethod
def from_native(
cls, data: duckdb.DuckDBPyRelation, /, *, context: _LimitedContext
) -> Self:
return cls(data, version=context._version)
def to_narwhals(
self, *args: Any, **kwds: Any
) -> LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]:
if self._version is Version.V1:
from narwhals.stable.v1 import DataFrame as DataFrameV1
return DataFrameV1(self, level="interchange") # type: ignore[no-any-return]
return self._version.lazyframe(self, level="lazy")
def __narwhals_dataframe__(self) -> Self: # pragma: no cover
# Keep around for backcompat.
if self._version is not Version.V1:
msg = "__narwhals_dataframe__ is not implemented for DuckDBLazyFrame"
raise AttributeError(msg)
return self
def __narwhals_lazyframe__(self) -> Self:
return self
def __native_namespace__(self) -> ModuleType:
return get_duckdb() # type: ignore[no-any-return]
def __narwhals_namespace__(self) -> DuckDBNamespace:
from narwhals._duckdb.namespace import DuckDBNamespace
return DuckDBNamespace(version=self._version)
def get_column(self, name: str) -> DuckDBInterchangeSeries:
from narwhals._duckdb.series import DuckDBInterchangeSeries
return DuckDBInterchangeSeries(self.native.select(name), version=self._version)
def _iter_columns(self) -> Iterator[Expression]:
for name in self.columns:
yield col(name)
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
if backend is None or backend is Implementation.PYARROW:
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
self.native.arrow(),
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
self.native.df(),
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.POLARS:
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
self.native.pl(), validate_backend_version=True, version=self._version
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
def head(self, n: int) -> Self:
return self._with_native(self.native.limit(n))
def simple_select(self, *column_names: str) -> Self:
return self._with_native(self.native.select(*column_names))
def aggregate(self, *exprs: DuckDBExpr) -> Self:
selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)]
try:
return self._with_native(self.native.aggregate(selection)) # type: ignore[arg-type]
except Exception as e: # noqa: BLE001
raise catch_duckdb_exception(e, self) from None
def select(self, *exprs: DuckDBExpr) -> Self:
selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs))
try:
return self._with_native(self.native.select(*selection))
except Exception as e: # noqa: BLE001
raise catch_duckdb_exception(e, self) from None
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
selection = (name for name in self.columns if name not in columns_to_drop)
return self._with_native(self.native.select(*selection))
def lazy(self, backend: None = None, **_: None) -> Self:
        # The `backend` argument has no effect, but we keep it here for
        # backwards compatibility because in `narwhals.stable.v1` the
        # `.from_native()` function returns a DataFrame for DuckDB.
if backend is not None: # pragma: no cover
msg = "`backend` argument is not supported for DuckDB"
raise ValueError(msg)
return self
def with_columns(self, *exprs: DuckDBExpr) -> Self:
new_columns_map = dict(evaluate_exprs(self, *exprs))
result = [
new_columns_map.pop(name).alias(name)
if name in new_columns_map
else col(name)
for name in self.columns
]
result.extend(value.alias(name) for name, value in new_columns_map.items())
try:
return self._with_native(self.native.select(*result))
except Exception as e: # noqa: BLE001
raise catch_duckdb_exception(e, self) from None
def filter(self, predicate: DuckDBExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = predicate(self)[0]
try:
return self._with_native(self.native.filter(mask))
except Exception as e: # noqa: BLE001
raise catch_duckdb_exception(e, self) from None
@property
def schema(self) -> dict[str, DType]:
if self._cached_native_schema is None:
# Note: prefer `self._cached_native_schema` over `functools.cached_property`
# due to Python3.13 failures.
self._cached_native_schema = dict(zip(self.columns, self.native.types))
deferred_time_zone = DeferredTimeZone(self.native)
return {
column_name: native_to_narwhals_dtype(
duckdb_dtype, self._version, deferred_time_zone
)
for column_name, duckdb_dtype in zip_strict(
self.native.columns, self.native.types
)
}
@property
def columns(self) -> list[str]:
if self._cached_columns is None:
self._cached_columns = (
list(self.schema)
if self._cached_native_schema is not None
else self.native.columns
)
return self._cached_columns
def to_pandas(self) -> pd.DataFrame:
# only if version is v1, keep around for backcompat
return self.native.df()
def to_arrow(self) -> pa.Table:
# only if version is v1, keep around for backcompat
return self.native.arrow()
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version)
def _with_native(self, df: duckdb.DuckDBPyRelation) -> Self:
return self.__class__(df, version=self._version)
def group_by(
self, keys: Sequence[str] | Sequence[DuckDBExpr], *, drop_null_keys: bool
) -> DuckDBGroupBy:
from narwhals._duckdb.group_by import DuckDBGroupBy
return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys)
def rename(self, mapping: Mapping[str, str]) -> Self:
df = self.native
selection = (
col(name).alias(mapping[name]) if name in mapping else col(name)
for name in df.columns
)
return self._with_native(self.native.select(*selection))
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
native_how = "outer" if how == "full" else how
if native_how == "cross":
if self._backend_version < (1, 1, 4):
msg = f"'duckdb>=1.1.4' is required for cross-join, found version: {self._backend_version}"
raise NotImplementedError(msg)
rel = self.native.set_alias("lhs").cross(other.native.set_alias("rhs"))
else:
# help mypy
assert left_on is not None # noqa: S101
assert right_on is not None # noqa: S101
it = (
col(f'lhs."{left}"') == col(f'rhs."{right}"')
for left, right in zip_strict(left_on, right_on)
)
condition: Expression = reduce(and_, it)
rel = self.native.set_alias("lhs").join(
other.native.set_alias("rhs"),
# NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
condition=condition, # type: ignore[arg-type, unused-ignore]
how=native_how,
)
if native_how in {"inner", "left", "cross", "outer"}:
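            # Keep every lhs column; rhs columns are kept as-is, suffixed on a name
            # clash (or, for full joins, whenever they overlap), or dropped when
            # they are right join keys already represented by the lhs keys.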
select = [col(f'lhs."{x}"') for x in self.columns]
for name in other.columns:
col_in_lhs: bool = name in self.columns
if native_how == "outer" and not col_in_lhs:
select.append(col(f'rhs."{name}"'))
elif (native_how == "outer") or (
col_in_lhs and (right_on is None or name not in right_on)
):
select.append(col(f'rhs."{name}"').alias(f"{name}{suffix}"))
elif right_on is None or name not in right_on:
select.append(col(name))
res = rel.select(*select).set_alias(self.native.alias)
else: # semi, anti
res = rel.select("lhs.*").set_alias(self.native.alias)
return self._with_native(res)
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self:
lhs = self.native
rhs = other.native
conditions: list[Expression] = []
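        # The ASOF join predicate combines equality on the `by` keys with an
        # inequality on the asof key whose direction depends on the strategy.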
if by_left is not None and by_right is not None:
conditions.extend(
col(f'lhs."{left}"') == col(f'rhs."{right}"')
for left, right in zip_strict(by_left, by_right)
)
else:
by_left = by_right = []
if strategy == "backward":
conditions.append(col(f'lhs."{left_on}"') >= col(f'rhs."{right_on}"'))
elif strategy == "forward":
conditions.append(col(f'lhs."{left_on}"') <= col(f'rhs."{right_on}"'))
else:
msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB"
raise NotImplementedError(msg)
condition: Expression = reduce(and_, conditions)
select = ["lhs.*"]
for name in rhs.columns:
if name in lhs.columns and (
right_on is None or name not in {right_on, *by_right}
):
select.append(f'rhs."{name}" as "{name}{suffix}"')
elif right_on is None or name not in {right_on, *by_right}:
select.append(str(col(name)))
# Replace with Python API call once
# https://github.com/duckdb/duckdb/discussions/16947 is addressed.
query = f"""
SELECT {",".join(select)}
FROM lhs
ASOF LEFT JOIN rhs
ON {condition}
""" # noqa: S608
return self._with_native(duckdb.sql(query))
def collect_schema(self) -> dict[str, DType]:
return self.schema
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self:
if subset_ := subset if keep == "any" else (subset or self.columns):
# Sanitise input
if error := self._check_columns_exist(subset_):
raise error
idx_name = generate_temporary_column_name(8, self.columns)
count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
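            # Window trick: number the rows within each duplicate group and also
            # count the group size, then filter on whichever marker `keep` needs.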
name = count_name if keep == "none" else idx_name
idx_expr = window_expression(F("row_number"), subset_).alias(idx_name)
count_expr = window_expression(
F("count", StarExpression()), subset_, ()
).alias(count_name)
return self._with_native(
self.native.select(StarExpression(), idx_expr, count_expr)
.filter(col(name) == lit(1))
.select(StarExpression(exclude=[count_name, idx_name]))
)
return self._with_native(self.native.unique(join_column_names(*self.columns)))
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
descending = [descending] * len(by)
if nulls_last:
it = (
col(name).nulls_last() if not desc else col(name).desc().nulls_last()
for name, desc in zip_strict(by, descending)
)
else:
it = (
col(name).nulls_first() if not desc else col(name).desc().nulls_first()
for name, desc in zip_strict(by, descending)
)
return self._with_native(self.native.sort(*it))
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
_df = self.native
by = list(by)
if isinstance(reverse, bool):
descending = [not reverse] * len(by)
else:
descending = [not rev for rev in reverse]
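        # Number the rows by the requested ordering and keep those ranked <= k;
        # QUALIFY filters on the window expression without needing a subquery.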
expr = window_expression(
F("row_number"),
order_by=by,
descending=descending,
nulls_last=[True] * len(by),
)
condition = expr <= lit(k)
query = f"""
SELECT *
FROM _df
QUALIFY {condition}
""" # noqa: S608
return self._with_native(duckdb.sql(query))
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
subset_ = subset if subset is not None else self.columns
keep_condition = reduce(and_, (col(name).isnotnull() for name in subset_))
return self._with_native(self.native.filter(keep_condition))
def explode(self, columns: Sequence[str]) -> Self:
dtypes = self._version.dtypes
schema = self.collect_schema()
for name in columns:
dtype = schema[name]
if dtype != dtypes.List:
msg = (
f"`explode` operation not supported for dtype `{dtype}`, "
"expected List type"
)
raise InvalidOperationError(msg)
if len(columns) != 1:
msg = (
"Exploding on multiple columns is not supported with DuckDB backend since "
"we cannot guarantee that the exploded columns have matching element counts."
)
raise NotImplementedError(msg)
col_to_explode = col(columns[0])
rel = self.native
original_columns = self.columns
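        # Split the relation: rows with a non-null, non-empty list are unnested,
        # while the remaining rows contribute a single row with NULL in the
        # exploded column; the two parts are then unioned back together.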
not_null_condition = col_to_explode.isnotnull() & F("len", col_to_explode) > lit(
0
)
non_null_rel = rel.filter(not_null_condition).select(
*(
F("unnest", col_to_explode).alias(name) if name in columns else name
for name in original_columns
)
)
null_rel = rel.filter(~not_null_condition).select(
*(
lit(None).alias(name) if name in columns else name
for name in original_columns
)
)
return self._with_native(non_null_rel.union(null_rel))
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
index_ = [] if index is None else index
on_ = [c for c in self.columns if c not in index_] if on is None else on
if variable_name == "":
msg = "`variable_name` cannot be empty string for duckdb backend."
raise NotImplementedError(msg)
if value_name == "":
msg = "`value_name` cannot be empty string for duckdb backend."
raise NotImplementedError(msg)
unpivot_on = join_column_names(*on_)
rel = self.native # noqa: F841
# Replace with Python API once
# https://github.com/duckdb/duckdb/discussions/16980 is addressed.
query = f"""
unpivot rel
on {unpivot_on}
into
name {col(variable_name)}
value {col(value_name)}
"""
return self._with_native(
duckdb.sql(query).select(*[*index_, variable_name, value_name])
)
@requires.backend_version((1, 3))
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
if order_by is None:
            msg = "`order_by` must be provided to `with_row_index` for the DuckDB backend"
raise TypeError(msg)
expr = (window_expression(F("row_number"), order_by=order_by) - lit(1)).alias(
name
)
return self._with_native(self.native.select(expr, StarExpression()))
def sink_parquet(self, file: str | Path | BytesIO) -> None:
df = self.native # noqa: F841
query = f"""
COPY (SELECT * FROM df)
TO '{file}'
(FORMAT parquet)
""" # noqa: S608
duckdb.sql(query)
gather_every = not_implemented.deprecated(
"`LazyFrame.gather_every` is deprecated and will be removed in a future version."
)
tail = not_implemented.deprecated(
"`LazyFrame.tail` is deprecated and will be removed in a future version."
)

View File

@ -0,0 +1,303 @@
from __future__ import annotations
import operator
from typing import TYPE_CHECKING, Any, Callable, Literal, cast
from duckdb import CoalesceOperator, StarExpression
from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
from narwhals._duckdb.utils import (
DeferredTimeZone,
F,
col,
lit,
narwhals_to_native_dtype,
when,
window_expression,
)
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version
if TYPE_CHECKING:
from collections.abc import Sequence
from duckdb import Expression
from typing_extensions import Self
from narwhals._compliant import WindowInputs
from narwhals._compliant.typing import (
AliasNames,
EvalNames,
EvalSeries,
WindowFunction,
)
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._utils import _LimitedContext
from narwhals.typing import (
FillNullStrategy,
IntoDType,
NonNestedLiteral,
RollingInterpolationMethod,
)
DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression]
DuckDBWindowInputs = WindowInputs[Expression]
class DuckDBExpr(SQLExpr["DuckDBLazyFrame", "Expression"]):
_implementation = Implementation.DUCKDB
def __init__(
self,
call: EvalSeries[DuckDBLazyFrame, Expression],
window_function: DuckDBWindowFunction | None = None,
*,
evaluate_output_names: EvalNames[DuckDBLazyFrame],
alias_output_names: AliasNames | None,
version: Version,
implementation: Implementation = Implementation.DUCKDB,
) -> None:
self._call = call
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._metadata: ExprMetadata | None = None
self._window_function: DuckDBWindowFunction | None = window_function
def _count_star(self) -> Expression:
return F("count", StarExpression())
def _window_expression(
self,
expr: Expression,
partition_by: Sequence[str | Expression] = (),
order_by: Sequence[str | Expression] = (),
rows_start: int | None = None,
rows_end: int | None = None,
*,
descending: Sequence[bool] | None = None,
nulls_last: Sequence[bool] | None = None,
) -> Expression:
return window_expression(
expr,
partition_by,
order_by,
rows_start,
rows_end,
descending=descending,
nulls_last=nulls_last,
)
def __narwhals_namespace__(self) -> DuckDBNamespace: # pragma: no cover
from narwhals._duckdb.namespace import DuckDBNamespace
return DuckDBNamespace(version=self._version)
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
if kind is ExprKind.LITERAL:
return self
if self._backend_version < (1, 3):
msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns."
raise NotImplementedError(msg)
return self.over([lit(1)], [])
@classmethod
def from_column_names(
cls,
evaluate_column_names: EvalNames[DuckDBLazyFrame],
/,
*,
context: _LimitedContext,
) -> Self:
def func(df: DuckDBLazyFrame) -> list[Expression]:
return [col(name) for name in evaluate_column_names(df)]
return cls(
func,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: DuckDBLazyFrame) -> list[Expression]:
columns = df.columns
return [col(columns[i]) for i in column_indices]
return cls(
func,
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
)
@classmethod
def _alias_native(cls, expr: Expression, name: str) -> Expression:
return expr.alias(name)
def __invert__(self) -> Self:
invert = cast("Callable[..., Expression]", operator.invert)
return self._with_elementwise(invert)
def skew(self) -> Self:
def func(expr: Expression) -> Expression:
count = F("count", expr)
# Adjust population skewness by correction factor to get sample skewness
sample_skewness = (
F("skewness", expr)
* (count - lit(2))
/ F("sqrt", count * (count - lit(1)))
)
return when(count == lit(0), lit(None)).otherwise(
when(count == lit(1), lit(float("nan"))).otherwise(
when(count == lit(2), lit(0.0)).otherwise(sample_skewness)
)
)
return self._with_callable(func)
def kurtosis(self) -> Self:
return self._with_callable(lambda expr: F("kurtosis_pop", expr))
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> Self:
def func(expr: Expression) -> Expression:
if interpolation == "linear":
return F("quantile_cont", expr, lit(quantile))
msg = "Only linear interpolation methods are supported for DuckDB quantile."
raise NotImplementedError(msg)
return self._with_callable(func)
def n_unique(self) -> Self:
def func(expr: Expression) -> Expression:
# https://stackoverflow.com/a/79338887/4451315
return F("array_unique", F("array_agg", expr)) + F(
"max", when(expr.isnotnull(), lit(0)).otherwise(lit(1))
)
return self._with_callable(func)
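    # The `n_unique` trick above (see the linked StackOverflow answer): the second term
    # evaluates to 1 when the column contains at least one null and to 0 otherwise, so
    # nulls contribute exactly one extra "distinct" value on top of the
    # `array_unique`/`array_agg` count of the remaining values.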
def len(self) -> Self:
return self._with_callable(lambda _expr: F("count"))
def std(self, ddof: int) -> Self:
if ddof == 0:
return self._with_callable(lambda expr: F("stddev_pop", expr))
if ddof == 1:
return self._with_callable(lambda expr: F("stddev_samp", expr))
def _std(expr: Expression) -> Expression:
n_samples = F("count", expr)
return (
F("stddev_pop", expr)
* F("sqrt", n_samples)
/ (F("sqrt", (n_samples - lit(ddof))))
)
return self._with_callable(_std)
def var(self, ddof: int) -> Self:
if ddof == 0:
return self._with_callable(lambda expr: F("var_pop", expr))
if ddof == 1:
return self._with_callable(lambda expr: F("var_samp", expr))
def _var(expr: Expression) -> Expression:
n_samples = F("count", expr)
return F("var_pop", expr) * n_samples / (n_samples - lit(ddof))
return self._with_callable(_var)
def null_count(self) -> Self:
return self._with_callable(lambda expr: F("sum", expr.isnull().cast("int")))
def is_nan(self) -> Self:
return self._with_elementwise(lambda expr: F("isnan", expr))
def is_finite(self) -> Self:
return self._with_elementwise(lambda expr: F("isfinite", expr))
def is_in(self, other: Sequence[Any]) -> Self:
return self._with_elementwise(lambda expr: F("contains", lit(other), expr))
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self:
if strategy is not None:
if self._backend_version < (1, 3): # pragma: no cover
msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'."
raise NotImplementedError(msg)
def _fill_with_strategy(
df: DuckDBLazyFrame, inputs: DuckDBWindowInputs
) -> Sequence[Expression]:
fill_func = "last_value" if strategy == "forward" else "first_value"
rows_start, rows_end = (
(-limit if limit is not None else None, 0)
if strategy == "forward"
else (0, limit)
)
return [
window_expression(
F(fill_func, expr),
inputs.partition_by,
inputs.order_by,
rows_start=rows_start,
rows_end=rows_end,
ignore_nulls=True,
)
for expr in self(df)
]
return self._with_window_function(_fill_with_strategy)
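        # The strategy-based fill above is expressed as a window function: "forward"
        # takes `last_value(...)` (ignoring nulls) over a trailing frame, "backward"
        # takes `first_value(...)` over the corresponding leading frame, with `limit`
        # (when given) bounding the number of rows the fill may reach across.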
def _fill_constant(expr: Expression, value: Any) -> Expression:
return CoalesceOperator(expr, value)
return self._with_elementwise(_fill_constant, value=value)
def cast(self, dtype: IntoDType) -> Self:
def func(df: DuckDBLazyFrame) -> list[Expression]:
tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
return [expr.cast(native_dtype) for expr in self(df)]
def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
tz = DeferredTimeZone(df.native)
native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]
return self.__class__(
func,
window_f,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
@property
def str(self) -> DuckDBExprStringNamespace:
return DuckDBExprStringNamespace(self)
@property
def dt(self) -> DuckDBExprDateTimeNamespace:
return DuckDBExprDateTimeNamespace(self)
@property
def list(self) -> DuckDBExprListNamespace:
return DuckDBExprListNamespace(self)
@property
def struct(self) -> DuckDBExprStructNamespace:
return DuckDBExprStructNamespace(self)

View File

@ -0,0 +1,132 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._constants import (
MS_PER_MINUTE,
MS_PER_SECOND,
NS_PER_SECOND,
SECONDS_PER_MINUTE,
US_PER_MINUTE,
US_PER_SECOND,
)
from narwhals._duckdb.utils import UNITS_DICT, F, fetch_rel_time_zone, lit
from narwhals._duration import Interval
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
from narwhals._utils import not_implemented
if TYPE_CHECKING:
from collections.abc import Sequence
from duckdb import Expression
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.expr import DuckDBExpr
class DuckDBExprDateTimeNamespace(SQLExprDateTimeNamesSpace["DuckDBExpr"]):
def millisecond(self) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("millisecond", expr) - F("second", expr) * lit(MS_PER_SECOND)
)
def microsecond(self) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("microsecond", expr) - F("second", expr) * lit(US_PER_SECOND)
)
def nanosecond(self) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("nanosecond", expr) - F("second", expr) * lit(NS_PER_SECOND)
)
def to_string(self, format: str) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("strftime", expr, lit(format))
)
def weekday(self) -> DuckDBExpr:
return self.compliant._with_elementwise(lambda expr: F("isodow", expr))
def date(self) -> DuckDBExpr:
return self.compliant._with_elementwise(lambda expr: expr.cast("date"))
def total_minutes(self) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("datepart", lit("minute"), expr)
)
def total_seconds(self) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: lit(SECONDS_PER_MINUTE) * F("datepart", lit("minute"), expr)
+ F("datepart", lit("second"), expr)
)
def total_milliseconds(self) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: lit(MS_PER_MINUTE) * F("datepart", lit("minute"), expr)
+ F("datepart", lit("millisecond"), expr)
)
def total_microseconds(self) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: lit(US_PER_MINUTE) * F("datepart", lit("minute"), expr)
+ F("datepart", lit("microsecond"), expr)
)
def truncate(self, every: str) -> DuckDBExpr:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
if multiple != 1:
# https://github.com/duckdb/duckdb/issues/17554
msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}."
raise ValueError(msg)
if unit == "ns":
msg = "Truncating to nanoseconds is not yet supported for DuckDB."
raise NotImplementedError(msg)
format = lit(UNITS_DICT[unit])
def _truncate(expr: Expression) -> Expression:
return F("date_trunc", format, expr)
return self.compliant._with_elementwise(_truncate)
def offset_by(self, by: str) -> DuckDBExpr:
interval = Interval.parse_no_constraints(by)
format = lit(f"{interval.multiple!s} {UNITS_DICT[interval.unit]}")
def _offset_by(expr: Expression) -> Expression:
return F("date_add", format, expr)
return self.compliant._with_callable(_offset_by)
def _no_op_time_zone(self, time_zone: str) -> DuckDBExpr:
def func(df: DuckDBLazyFrame) -> Sequence[Expression]:
native_series_list = self.compliant(df)
conn_time_zone = fetch_rel_time_zone(df.native)
if conn_time_zone != time_zone:
                msg = (
                    "DuckDB stores the time zone in the connection, rather than in the "
                    f"data type, so changing the time zone to anything other than {conn_time_zone} "
                    "(the current connection time zone) is not supported."
                )
raise NotImplementedError(msg)
return native_series_list
return self.compliant.__class__(
func,
evaluate_output_names=self.compliant._evaluate_output_names,
alias_output_names=self.compliant._alias_output_names,
version=self.compliant._version,
)
def convert_time_zone(self, time_zone: str) -> DuckDBExpr:
return self._no_op_time_zone(time_zone)
def replace_time_zone(self, time_zone: str | None) -> DuckDBExpr:
if time_zone is None:
return self.compliant._with_elementwise(lambda expr: expr.cast("timestamp"))
return self._no_op_time_zone(time_zone)
total_nanoseconds = not_implemented()
timestamp = not_implemented()

View File

@ -0,0 +1,40 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace
from narwhals._duckdb.utils import F, lit, when
if TYPE_CHECKING:
from duckdb import Expression
from narwhals._duckdb.expr import DuckDBExpr
from narwhals.typing import NonNestedLiteral
class DuckDBExprListNamespace(
LazyExprNamespace["DuckDBExpr"], ListNamespace["DuckDBExpr"]
):
def len(self) -> DuckDBExpr:
return self.compliant._with_elementwise(lambda expr: F("len", expr))
def unique(self) -> DuckDBExpr:
def func(expr: Expression) -> Expression:
expr_distinct = F("list_distinct", expr)
return when(
F("array_position", expr, lit(None)).isnotnull(),
F("list_append", expr_distinct, lit(None)),
).otherwise(expr_distinct)
return self.compliant._with_callable(func)
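    # Note: `list_distinct` drops null elements, so `unique` above re-appends a single
    # null whenever the original list contained one (detected via `array_position`).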
def contains(self, item: NonNestedLiteral) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("list_contains", expr, lit(item))
)
def get(self, index: int) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("list_extract", expr, lit(index + 1))
)

View File

@ -0,0 +1,30 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._duckdb.utils import F, lit
from narwhals._sql.expr_str import SQLExprStringNamespace
from narwhals._utils import not_implemented
if TYPE_CHECKING:
from narwhals._duckdb.expr import DuckDBExpr
class DuckDBExprStringNamespace(SQLExprStringNamespace["DuckDBExpr"]):
def to_datetime(self, format: str | None) -> DuckDBExpr:
if format is None:
msg = "Cannot infer format with DuckDB backend, please specify `format` explicitly."
raise NotImplementedError(msg)
return self.compliant._with_elementwise(
lambda expr: F("strptime", expr, lit(format))
)
def to_date(self, format: str | None) -> DuckDBExpr:
if format is not None:
return self.to_datetime(format=format).dt.date()
compliant_expr = self.compliant
return compliant_expr.cast(compliant_expr._version.dtypes.Date())
replace = not_implemented()

View File

@ -0,0 +1,19 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StructNamespace
from narwhals._duckdb.utils import F, lit
if TYPE_CHECKING:
from narwhals._duckdb.expr import DuckDBExpr
class DuckDBExprStructNamespace(
LazyExprNamespace["DuckDBExpr"], StructNamespace["DuckDBExpr"]
):
def field(self, name: str) -> DuckDBExpr:
return self.compliant._with_elementwise(
lambda expr: F("struct_extract", expr, lit(name))
).alias(name)

View File

@ -0,0 +1,33 @@
from __future__ import annotations
from itertools import chain
from typing import TYPE_CHECKING
from narwhals._sql.group_by import SQLGroupBy
if TYPE_CHECKING:
from collections.abc import Sequence
from duckdb import Expression # noqa: F401
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.expr import DuckDBExpr
class DuckDBGroupBy(SQLGroupBy["DuckDBLazyFrame", "DuckDBExpr", "Expression"]):
def __init__(
self,
df: DuckDBLazyFrame,
keys: Sequence[DuckDBExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
def agg(self, *exprs: DuckDBExpr) -> DuckDBLazyFrame:
agg_columns = list(chain(self._keys, self._evaluate_exprs(exprs)))
return self.compliant._with_native(
self.compliant.native.aggregate(agg_columns) # type: ignore[arg-type]
).rename(dict(zip(self._keys, self._output_key_names)))

View File

@ -0,0 +1,164 @@
from __future__ import annotations
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any
from duckdb import CoalesceOperator, Expression
from duckdb.typing import BIGINT, VARCHAR
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.selectors import DuckDBSelectorNamespace
from narwhals._duckdb.utils import (
DeferredTimeZone,
F,
concat_str,
function,
lit,
narwhals_to_native_dtype,
when,
)
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._sql.namespace import SQLNamespace
from narwhals._sql.when_then import SQLThen, SQLWhen
from narwhals._utils import Implementation
if TYPE_CHECKING:
from collections.abc import Iterable
from duckdb import DuckDBPyRelation # noqa: F401
from narwhals._utils import Version
from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral
class DuckDBNamespace(
SQLNamespace[DuckDBLazyFrame, DuckDBExpr, "DuckDBPyRelation", Expression]
):
_implementation: Implementation = Implementation.DUCKDB
def __init__(self, *, version: Version) -> None:
self._version = version
@property
def selectors(self) -> DuckDBSelectorNamespace:
return DuckDBSelectorNamespace.from_namespace(self)
@property
def _expr(self) -> type[DuckDBExpr]:
return DuckDBExpr
@property
def _lazyframe(self) -> type[DuckDBLazyFrame]:
return DuckDBLazyFrame
def _function(self, name: str, *args: Expression) -> Expression: # type: ignore[override]
return function(name, *args)
def _lit(self, value: Any) -> Expression:
return lit(value)
def _when(
self,
condition: Expression,
value: Expression,
otherwise: Expression | None = None,
) -> Expression:
if otherwise is None:
return when(condition, value)
return when(condition, value).otherwise(otherwise)
def _coalesce(self, *exprs: Expression) -> Expression:
return CoalesceOperator(*exprs)
def concat(
self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
) -> DuckDBLazyFrame:
        items = list(items)
        native_items = [item._native_frame for item in items]
first = items[0]
schema = first.schema
if how == "vertical" and not all(x.schema == schema for x in items[1:]):
msg = "inputs should all have the same schema"
raise TypeError(msg)
res = reduce(lambda x, y: x.union(y), native_items)
return first._with_native(res)
def concat_str(
self, *exprs: DuckDBExpr, separator: str, ignore_nulls: bool
) -> DuckDBExpr:
def func(df: DuckDBLazyFrame) -> list[Expression]:
cols = list(chain.from_iterable(expr(df) for expr in exprs))
if not ignore_nulls:
null_mask_result = reduce(operator.or_, (s.isnull() for s in cols))
cols_separated = [
y
for x in [
(col.cast(VARCHAR),)
if i == len(cols) - 1
else (col.cast(VARCHAR), lit(separator))
for i, col in enumerate(cols)
]
for y in x
]
return [when(~null_mask_result, concat_str(*cols_separated))]
return [concat_str(*cols, separator=separator)]
return self._expr(
call=func,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
def func(cols: Iterable[Expression]) -> Expression:
cols = list(cols)
return reduce(
operator.add, (CoalesceOperator(col, lit(0)) for col in cols)
) / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols))
return self._expr._from_elementwise_horizontal_op(func, *exprs)
def when(self, predicate: DuckDBExpr) -> DuckDBWhen:
return DuckDBWhen.from_expr(predicate, context=self)
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr:
def func(df: DuckDBLazyFrame) -> list[Expression]:
tz = DeferredTimeZone(df.native)
if dtype is not None:
target = narwhals_to_native_dtype(dtype, self._version, tz)
return [lit(value).cast(target)]
return [lit(value)]
return self._expr(
func,
evaluate_output_names=lambda _df: ["literal"],
alias_output_names=None,
version=self._version,
)
def len(self) -> DuckDBExpr:
def func(_df: DuckDBLazyFrame) -> list[Expression]:
return [F("count")]
return self._expr(
call=func,
evaluate_output_names=lambda _df: ["len"],
alias_output_names=None,
version=self._version,
)
class DuckDBWhen(SQLWhen["DuckDBLazyFrame", Expression, DuckDBExpr]):
@property
def _then(self) -> type[DuckDBThen]:
return DuckDBThen
class DuckDBThen(SQLThen["DuckDBLazyFrame", Expression, DuckDBExpr], DuckDBExpr): ...

View File

@ -0,0 +1,33 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._duckdb.expr import DuckDBExpr
if TYPE_CHECKING:
from duckdb import Expression # noqa: F401
from narwhals._duckdb.dataframe import DuckDBLazyFrame # noqa: F401
from narwhals._duckdb.expr import DuckDBWindowFunction
class DuckDBSelectorNamespace(LazySelectorNamespace["DuckDBLazyFrame", "Expression"]):
@property
def _selector(self) -> type[DuckDBSelector]:
return DuckDBSelector
class DuckDBSelector( # type: ignore[misc]
CompliantSelector["DuckDBLazyFrame", "Expression"], DuckDBExpr
):
_window_function: DuckDBWindowFunction | None = None
def _to_expr(self) -> DuckDBExpr:
return DuckDBExpr(
self._call,
self._window_function,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)

View File

@ -0,0 +1,44 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._duckdb.utils import DeferredTimeZone, native_to_narwhals_dtype
from narwhals.dependencies import get_duckdb
if TYPE_CHECKING:
from types import ModuleType
import duckdb
from typing_extensions import Never, Self
from narwhals._utils import Version
from narwhals.dtypes import DType
class DuckDBInterchangeSeries:
def __init__(self, df: duckdb.DuckDBPyRelation, version: Version) -> None:
self._native_series = df
self._version = version
def __narwhals_series__(self) -> Self:
return self
def __native_namespace__(self) -> ModuleType:
return get_duckdb() # type: ignore[no-any-return]
@property
def dtype(self) -> DType:
return native_to_narwhals_dtype(
self._native_series.types[0],
self._version,
DeferredTimeZone(self._native_series),
)
def __getattr__(self, attr: str) -> Never:
msg = ( # pragma: no cover
f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
"If you would like to see this kind of object better supported in "
"Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg) # pragma: no cover

View File

@ -0,0 +1,18 @@
from __future__ import annotations
from typing import TYPE_CHECKING, TypedDict
if TYPE_CHECKING:
from collections.abc import Sequence
from duckdb import Expression
class WindowExpressionKwargs(TypedDict, total=False):
partition_by: Sequence[str | Expression]
order_by: Sequence[str | Expression]
rows_start: int | None
rows_end: int | None
descending: Sequence[bool]
nulls_last: Sequence[bool]
ignore_nulls: bool

View File

@ -0,0 +1,370 @@
from __future__ import annotations
from functools import lru_cache
from typing import TYPE_CHECKING
import duckdb
import duckdb.typing as duckdb_dtypes
from duckdb.typing import DuckDBPyType
from narwhals._utils import Version, isinstance_or_issubclass, zip_strict
from narwhals.exceptions import ColumnNotFoundError
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from duckdb import DuckDBPyRelation, Expression
from narwhals._compliant.typing import CompliantLazyFrameAny
from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.expr import DuckDBExpr
from narwhals.dtypes import DType
from narwhals.typing import IntoDType, TimeUnit
UNITS_DICT = {
"y": "year",
"q": "quarter",
"mo": "month",
"d": "day",
"h": "hour",
"m": "minute",
"s": "second",
"ms": "millisecond",
"us": "microsecond",
"ns": "nanosecond",
}
DESCENDING_TO_ORDER = {True: "desc", False: "asc"}
NULLS_LAST_TO_NULLS_POS = {True: "nulls last", False: "nulls first"}
col = duckdb.ColumnExpression
"""Alias for `duckdb.ColumnExpression`."""
lit = duckdb.ConstantExpression
"""Alias for `duckdb.ConstantExpression`."""
when = duckdb.CaseExpression
"""Alias for `duckdb.CaseExpression`."""
F = duckdb.FunctionExpression
"""Alias for `duckdb.FunctionExpression`."""
def concat_str(*exprs: Expression, separator: str = "") -> Expression:
"""Concatenate many strings, NULL inputs are skipped.
Wraps [concat] and [concat_ws] `FunctionExpression`(s).
Arguments:
exprs: Native columns.
separator: String that will be used to separate the values of each column.
Returns:
A new native expression.
[concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring-
[concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string-
"""
return F("concat_ws", lit(separator), *exprs) if separator else F("concat", *exprs)
def evaluate_exprs(
df: DuckDBLazyFrame, /, *exprs: DuckDBExpr
) -> list[tuple[str, Expression]]:
native_results: list[tuple[str, Expression]] = []
for expr in exprs:
native_series_list = expr._call(df)
output_names = expr._evaluate_output_names(df)
if expr._alias_output_names is not None:
output_names = expr._alias_output_names(output_names)
if len(output_names) != len(native_series_list): # pragma: no cover
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.extend(zip(output_names, native_series_list))
return native_results
class DeferredTimeZone:
"""Object which gets passed between `native_to_narwhals_dtype` calls.
DuckDB stores the time zone in the connection, rather than in the dtypes, so
this ensures that when calculating the schema of a dataframe with multiple
    timezone-aware columns, the connection's time zone is only fetched once.
Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because
the time zone can be modified after `DuckDBLazyFrame` creation:
```python
df = nw.from_native(rel)
print(df.collect_schema())
rel.query("set timezone = 'Asia/Kolkata'")
print(df.collect_schema()) # should change to reflect new time zone
```
"""
_cached_time_zone: str | None = None
def __init__(self, rel: DuckDBPyRelation) -> None:
self._rel = rel
@property
def time_zone(self) -> str:
"""Fetch relation time zone (if it wasn't calculated already)."""
if self._cached_time_zone is None:
self._cached_time_zone = fetch_rel_time_zone(self._rel)
return self._cached_time_zone
def native_to_narwhals_dtype(
duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DType:
duckdb_dtype_id = duckdb_dtype.id
dtypes = version.dtypes
# Handle nested data types first
if duckdb_dtype_id == "list":
return dtypes.List(
native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone)
)
if duckdb_dtype_id == "struct":
children = duckdb_dtype.children
return dtypes.Struct(
[
dtypes.Field(
name=child[0],
dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone),
)
for child in children
]
)
if duckdb_dtype_id == "array":
child, size = duckdb_dtype.children
shape: list[int] = [size[1]]
while child[1].id == "array":
child, size = child[1].children
shape.insert(0, size[1])
inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone)
return dtypes.Array(inner=inner, shape=tuple(shape))
if duckdb_dtype_id == "enum":
if version is Version.V1:
return dtypes.Enum() # type: ignore[call-arg]
categories = duckdb_dtype.children[0][1]
return dtypes.Enum(categories=categories)
if duckdb_dtype_id == "timestamp with time zone":
return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)
return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)
def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str:
result = rel.query(
"duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'"
).fetchone()
assert result is not None # noqa: S101
return result[0] # type: ignore[no-any-return]
@lru_cache(maxsize=16)
def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
dtypes = version.dtypes
return {
"hugeint": dtypes.Int128(),
"bigint": dtypes.Int64(),
"integer": dtypes.Int32(),
"smallint": dtypes.Int16(),
"tinyint": dtypes.Int8(),
"uhugeint": dtypes.UInt128(),
"ubigint": dtypes.UInt64(),
"uinteger": dtypes.UInt32(),
"usmallint": dtypes.UInt16(),
"utinyint": dtypes.UInt8(),
"double": dtypes.Float64(),
"float": dtypes.Float32(),
"varchar": dtypes.String(),
"date": dtypes.Date(),
"timestamp_s": dtypes.Datetime("s"),
"timestamp_ms": dtypes.Datetime("ms"),
"timestamp": dtypes.Datetime(),
"timestamp_ns": dtypes.Datetime("ns"),
"boolean": dtypes.Boolean(),
"interval": dtypes.Duration(),
"decimal": dtypes.Decimal(),
"time": dtypes.Time(),
"blob": dtypes.Binary(),
}.get(duckdb_dtype_id, dtypes.Unknown())
dtypes = Version.MAIN.dtypes
NW_TO_DUCKDB_DTYPES: Mapping[type[DType], DuckDBPyType] = {
dtypes.Float64: duckdb_dtypes.DOUBLE,
dtypes.Float32: duckdb_dtypes.FLOAT,
dtypes.Binary: duckdb_dtypes.BLOB,
dtypes.String: duckdb_dtypes.VARCHAR,
dtypes.Boolean: duckdb_dtypes.BOOLEAN,
dtypes.Date: duckdb_dtypes.DATE,
dtypes.Time: duckdb_dtypes.TIME,
dtypes.Int8: duckdb_dtypes.TINYINT,
dtypes.Int16: duckdb_dtypes.SMALLINT,
dtypes.Int32: duckdb_dtypes.INTEGER,
dtypes.Int64: duckdb_dtypes.BIGINT,
dtypes.Int128: DuckDBPyType("INT128"),
dtypes.UInt8: duckdb_dtypes.UTINYINT,
dtypes.UInt16: duckdb_dtypes.USMALLINT,
dtypes.UInt32: duckdb_dtypes.UINTEGER,
dtypes.UInt64: duckdb_dtypes.UBIGINT,
dtypes.UInt128: DuckDBPyType("UINT128"),
}
TIME_UNIT_TO_TIMESTAMP: Mapping[TimeUnit, DuckDBPyType] = {
"s": duckdb_dtypes.TIMESTAMP_S,
"ms": duckdb_dtypes.TIMESTAMP_MS,
"us": duckdb_dtypes.TIMESTAMP,
"ns": duckdb_dtypes.TIMESTAMP_NS,
}
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Categorical)
def narwhals_to_native_dtype( # noqa: PLR0912, C901
dtype: IntoDType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DuckDBPyType:
dtypes = version.dtypes
base_type = dtype.base_type()
if duckdb_type := NW_TO_DUCKDB_DTYPES.get(base_type):
return duckdb_type
if isinstance_or_issubclass(dtype, dtypes.Enum):
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
raise NotImplementedError(msg)
if isinstance(dtype, dtypes.Enum):
return DuckDBPyType(f"ENUM{dtype.categories!r}")
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)
if isinstance_or_issubclass(dtype, dtypes.Datetime):
tu = dtype.time_unit
tz = dtype.time_zone
if not tz:
return TIME_UNIT_TO_TIMESTAMP[tu]
if tu != "us":
msg = f"Only microsecond precision is supported for timezone-aware `Datetime` in DuckDB, got {tu} precision"
raise ValueError(msg)
if tz != (rel_tz := deferred_time_zone.time_zone): # pragma: no cover
msg = f"Only the connection time zone {rel_tz} is supported, got: {tz}."
raise ValueError(msg)
# TODO(unassigned): cover once https://github.com/narwhals-dev/narwhals/issues/2742 addressed
return duckdb_dtypes.TIMESTAMP_TZ # pragma: no cover
if isinstance_or_issubclass(dtype, dtypes.Duration):
        if (tu := dtype.time_unit) != "us":  # pragma: no cover
            msg = f"Only microsecond-precision Duration is supported, got {tu} precision"
            raise NotImplementedError(msg)
        return duckdb_dtypes.INTERVAL
if isinstance_or_issubclass(dtype, dtypes.List):
inner = narwhals_to_native_dtype(dtype.inner, version, deferred_time_zone)
return duckdb.list_type(inner)
if isinstance_or_issubclass(dtype, dtypes.Struct):
fields = {
field.name: narwhals_to_native_dtype(field.dtype, version, deferred_time_zone)
for field in dtype.fields
}
return duckdb.struct_type(fields)
if isinstance(dtype, dtypes.Array):
nw_inner: IntoDType = dtype
while isinstance(nw_inner, dtypes.Array):
nw_inner = nw_inner.inner
duckdb_inner = narwhals_to_native_dtype(nw_inner, version, deferred_time_zone)
duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape)
return DuckDBPyType(f"{duckdb_inner}{duckdb_shape_fmt}")
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for DuckDB."
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
def parse_into_expression(into_expression: str | Expression) -> Expression:
return col(into_expression) if isinstance(into_expression, str) else into_expression
def generate_partition_by_sql(*partition_by: str | Expression) -> str:
if not partition_by:
return ""
by_sql = ", ".join([f"{parse_into_expression(x)}" for x in partition_by])
return f"partition by {by_sql}"
def join_column_names(*names: str) -> str:
return ", ".join(str(col(name)) for name in names)
def generate_order_by_sql(
*order_by: str | Expression, descending: Sequence[bool], nulls_last: Sequence[bool]
) -> str:
if not order_by:
return ""
by_sql = ",".join(
f"{parse_into_expression(x)} {DESCENDING_TO_ORDER[_descending]} {NULLS_LAST_TO_NULLS_POS[_nulls_last]}"
for x, _descending, _nulls_last in zip_strict(order_by, descending, nulls_last)
)
return f"order by {by_sql}"
def window_expression(
expr: Expression,
partition_by: Sequence[str | Expression] = (),
order_by: Sequence[str | Expression] = (),
rows_start: int | None = None,
rows_end: int | None = None,
*,
descending: Sequence[bool] | None = None,
nulls_last: Sequence[bool] | None = None,
ignore_nulls: bool = False,
) -> Expression:
# TODO(unassigned): Replace with `duckdb.WindowExpression` when they release it.
# https://github.com/duckdb/duckdb/discussions/14725#discussioncomment-11200348
try:
from duckdb import SQLExpression
    except ImportError as exc:  # pragma: no cover
msg = f"DuckDB>=1.3.0 is required for this operation. Found: DuckDB {duckdb.__version__}"
raise NotImplementedError(msg) from exc
pb = generate_partition_by_sql(*partition_by)
descending = descending or [False] * len(order_by)
nulls_last = nulls_last or [False] * len(order_by)
ob = generate_order_by_sql(*order_by, descending=descending, nulls_last=nulls_last)
if rows_start is not None and rows_end is not None:
rows = f"rows between {-rows_start} preceding and {rows_end} following"
elif rows_end is not None:
rows = f"rows between unbounded preceding and {rows_end} following"
elif rows_start is not None:
rows = f"rows between {-rows_start} preceding and unbounded following"
else:
rows = ""
func = f"{str(expr).removesuffix(')')} ignore nulls)" if ignore_nulls else str(expr)
return SQLExpression(f"{func} over ({pb} {ob} {rows})")
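# Rough example of the SQL this builds (illustrative only, up to identifier quoting):
#   window_expression(F("sum", col("x")), partition_by=["g"], order_by=["t"], rows_start=-3, rows_end=0)
# -> sum(x) over (partition by g order by t asc nulls first rows between 3 preceding and 0 following)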
def catch_duckdb_exception(
exception: Exception, frame: CompliantLazyFrameAny, /
) -> ColumnNotFoundError | Exception:
if isinstance(exception, duckdb.BinderException) and any(
msg in str(exception)
for msg in (
"not found in FROM clause",
"this column cannot be referenced before it is defined",
)
):
return ColumnNotFoundError.from_available_column_names(
available_columns=frame.columns
)
# Just return exception as-is.
return exception
def function(name: str, *args: Expression) -> Expression:
if name == "isnull":
return args[0].isnull()
return F(name, *args)

View File

@ -0,0 +1,94 @@
"""Tools for working with the Polars duration string language."""
from __future__ import annotations
import datetime as dt
import re
from typing import TYPE_CHECKING, Literal, cast, get_args
if TYPE_CHECKING:
from collections.abc import Container, Mapping
from typing_extensions import TypeAlias
__all__ = ["IntervalUnit"]
IntervalUnit: TypeAlias = Literal["ns", "us", "ms", "s", "m", "h", "d", "mo", "q", "y"]
"""A Polars duration string interval unit.
- 'ns': nanosecond.
- 'us': microsecond.
- 'ms': millisecond.
- 's': second.
- 'm': minute.
- 'h': hour.
- 'd': day.
- 'mo': month.
- 'q': quarter.
- 'y': year.
"""
TimedeltaKwd: TypeAlias = Literal[
"days", "hours", "minutes", "seconds", "milliseconds", "microseconds"
]
PATTERN_INTERVAL: re.Pattern[str] = re.compile(
r"^(?P<multiple>-?\d+)(?P<unit>ns|us|ms|mo|m|s|h|d|q|y)\Z"
)
MONTH_MULTIPLES = frozenset([1, 2, 3, 4, 6, 12])
QUARTER_MULTIPLES = frozenset([1, 2, 4])
UNIT_TO_TIMEDELTA: Mapping[IntervalUnit, TimedeltaKwd] = {
"d": "days",
"h": "hours",
"m": "minutes",
"s": "seconds",
"ms": "milliseconds",
"us": "microseconds",
}
class Interval:
def __init__(self, multiple: int, unit: IntervalUnit, /) -> None:
self.multiple: int = multiple
self.unit: IntervalUnit = unit
def to_timedelta(
self, *, unsupported: Container[IntervalUnit] = frozenset(("ns", "mo", "q", "y"))
) -> dt.timedelta:
if self.unit in unsupported: # pragma: no cover
msg = f"Creating timedelta with {self.unit} unit is not supported."
raise NotImplementedError(msg)
kwd = UNIT_TO_TIMEDELTA[self.unit]
# error: Keywords must be strings (bad mypy)
return dt.timedelta(**{kwd: self.multiple}) # type: ignore[misc]
@classmethod
def parse(cls, every: str) -> Interval:
multiple, unit = cls._parse(every)
if unit == "mo" and multiple not in MONTH_MULTIPLES:
msg = f"Only the following multiples are supported for 'mo' unit: {MONTH_MULTIPLES}.\nGot: {multiple}."
raise ValueError(msg)
if unit == "q" and multiple not in QUARTER_MULTIPLES:
msg = f"Only the following multiples are supported for 'q' unit: {QUARTER_MULTIPLES}.\nGot: {multiple}."
raise ValueError(msg)
if unit == "y" and multiple != 1:
msg = (
f"Only multiple 1 is currently supported for 'y' unit.\nGot: {multiple}."
)
raise ValueError(msg)
return cls(multiple, unit)
@classmethod
def parse_no_constraints(cls, every: str) -> Interval:
return cls(*cls._parse(every))
@staticmethod
def _parse(every: str) -> tuple[int, IntervalUnit]:
if match := PATTERN_INTERVAL.match(every):
multiple = int(match["multiple"])
unit = cast("IntervalUnit", match["unit"])
return multiple, unit
msg = (
f"Invalid `every` string: {every}. Expected string of kind <number><unit>, "
f"where 'unit' is one of: {get_args(IntervalUnit)}."
)
raise ValueError(msg)
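# Usage sketch:
#   Interval.parse("3d").to_timedelta()  -> datetime.timedelta(days=3)
#   Interval.parse("2q")                 -> an Interval with multiple=2, unit="q"
#   Interval.parse("5mo")                -> ValueError (only 1, 2, 3, 4, 6 or 12 allowed for 'mo')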

View File

@ -0,0 +1,42 @@
from __future__ import annotations
# ruff: noqa: ARG004
from enum import Enum
from typing import Any
class NoAutoEnum(Enum):
"""Enum base class that prohibits the use of enum.auto() for value assignment.
This behavior is achieved by overriding the value generation mechanism.
Examples:
>>> from enum import auto
>>> from narwhals._enum import NoAutoEnum
>>>
>>> class Colors(NoAutoEnum):
... RED = 1
... GREEN = 2
>>> Colors.RED
<Colors.RED: 1>
>>> class ColorsWithAuto(NoAutoEnum):
... RED = 1
... GREEN = auto()
Traceback (most recent call last):
...
ValueError: Creating values with `auto()` is not allowed. Please provide a value manually instead.
Raises:
ValueError: If `auto()` is attempted to be used for any enum member value.
"""
@staticmethod
def _generate_next_value_(
name: str, start: int, count: int, last_values: list[Any]
) -> Any:
msg = "Creating values with `auto()` is not allowed. Please provide a value manually instead."
raise ValueError(msg)
__all__ = ["NoAutoEnum"]

View File

@ -0,0 +1,60 @@
from __future__ import annotations
from warnings import warn
def find_stacklevel() -> int:
"""Find the first place in the stack that is not inside narwhals.
Returns:
Stacklevel.
Taken from:
https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51
"""
import inspect
from pathlib import Path
import narwhals as nw
pkg_dir = str(Path(nw.__file__).parent)
# https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
frame = inspect.currentframe()
n = 0
try:
while frame:
fname = inspect.getfile(frame)
if fname.startswith(pkg_dir) or (
(qualname := getattr(frame.f_code, "co_qualname", None))
# ignore @singledispatch wrappers
and qualname.startswith("singledispatch.")
):
frame = frame.f_back
n += 1
else: # pragma: no cover
break
else: # pragma: no cover
pass
finally:
# https://docs.python.org/3/library/inspect.html
# > Though the cycle detector will catch these, destruction of the frames
# > (and local variables) can be made deterministic by removing the cycle
# > in a finally clause.
del frame
return n
def issue_deprecation_warning(message: str, _version: str) -> None: # pragma: no cover
"""Issue a deprecation warning.
Arguments:
message: The message associated with the warning.
_version: Narwhals version when the warning was introduced. Just used for internal
bookkeeping.
"""
warn(message=message, category=DeprecationWarning, stacklevel=find_stacklevel())
def issue_warning(message: str, category: type[Warning]) -> None:
warn(message=message, category=category, stacklevel=find_stacklevel())

View File

@ -0,0 +1,615 @@
# Utilities for expression parsing
# Useful for backends which don't have any concept of expressions, such
# as pandas or PyArrow.
# ! Any change to this module will trigger the pyspark and pyspark-connect tests in CI
from __future__ import annotations
from enum import Enum, auto
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
from narwhals._utils import is_compliant_expr, zip_strict
from narwhals.dependencies import is_narwhals_series, is_numpy_array, is_numpy_array_1d
from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Never, TypeIs
from narwhals._compliant import CompliantExpr, CompliantFrameT
from narwhals._compliant.typing import (
AliasNames,
CompliantExprAny,
CompliantFrameAny,
CompliantNamespaceAny,
EvalNames,
)
from narwhals.expr import Expr
from narwhals.series import Series
from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray
T = TypeVar("T")
def is_expr(obj: Any) -> TypeIs[Expr]:
"""Check whether `obj` is a Narwhals Expr."""
from narwhals.expr import Expr
return isinstance(obj, Expr)
def is_series(obj: Any) -> TypeIs[Series[Any]]:
"""Check whether `obj` is a Narwhals Expr."""
from narwhals.series import Series
return isinstance(obj, Series)
def is_into_expr_eager(obj: Any) -> TypeIs[Expr | Series[Any] | str | _1DArray]:
from narwhals.expr import Expr
from narwhals.series import Series
return isinstance(obj, (Series, Expr, str)) or is_numpy_array_1d(obj)
def combine_evaluate_output_names(
*exprs: CompliantExpr[CompliantFrameT, Any],
) -> EvalNames[CompliantFrameT]:
# Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the
# first name of `expr1`.
if not is_compliant_expr(exprs[0]): # pragma: no cover
msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug."
raise AssertionError(msg)
def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
return exprs[0]._evaluate_output_names(df)[:1]
return evaluate_output_names
def combine_alias_output_names(*exprs: CompliantExprAny) -> AliasNames | None:
# Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1.alias(alias), expr2)` takes the
    # aliasing function of `expr1` and applies it to the first output name of `expr1`.
if exprs[0]._alias_output_names is None:
return None
def alias_output_names(names: Sequence[str]) -> Sequence[str]:
return exprs[0]._alias_output_names(names)[:1] # type: ignore[misc]
return alias_output_names
def evaluate_output_names_and_aliases(
expr: CompliantExprAny, df: CompliantFrameAny, exclude: Sequence[str]
) -> tuple[Sequence[str], Sequence[str]]:
output_names = expr._evaluate_output_names(df)
aliases = (
output_names
if expr._alias_output_names is None
else expr._alias_output_names(output_names)
)
if exclude:
assert expr._metadata is not None # noqa: S101
if expr._metadata.expansion_kind.is_multi_unnamed():
output_names, aliases = zip_strict(
*[
(x, alias)
for x, alias in zip_strict(output_names, aliases)
if x not in exclude
]
)
return output_names, aliases
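# Example sketch: for an expression along the lines of `nw.all().name.suffix("_x")`
# evaluated on a frame with columns ["a", "b"] and exclude=["b"], this yields output
# name "a" with alias "a_x"; excluded names are filtered out only for multi-output
# unnamed expansions such as `nw.all()`.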
class ExprKind(Enum):
"""Describe which kind of expression we are dealing with."""
LITERAL = auto()
"""e.g. `nw.lit(1)`"""
AGGREGATION = auto()
"""Reduces to a single value, not affected by row order, e.g. `nw.col('a').mean()`"""
ORDERABLE_AGGREGATION = auto()
"""Reduces to a single value, affected by row order, e.g. `nw.col('a').arg_max()`"""
ELEMENTWISE = auto()
"""Preserves length, can operate without context for surrounding rows, e.g. `nw.col('a').abs()`."""
ORDERABLE_WINDOW = auto()
"""Depends on the rows around it and on their order, e.g. `diff`."""
WINDOW = auto()
"""Depends on the rows around it and possibly their order, e.g. `rank`."""
FILTRATION = auto()
"""Changes length, not affected by row order, e.g. `drop_nulls`."""
ORDERABLE_FILTRATION = auto()
"""Changes length, affected by row order, e.g. `tail`."""
OVER = auto()
"""Results from calling `.over` on expression."""
UNKNOWN = auto()
"""Based on the information we have, we can't determine the ExprKind."""
@property
def is_scalar_like(self) -> bool:
return self in {ExprKind.LITERAL, ExprKind.AGGREGATION}
@property
def is_orderable_window(self) -> bool:
return self in {ExprKind.ORDERABLE_WINDOW, ExprKind.ORDERABLE_AGGREGATION}
@classmethod
def from_expr(cls, obj: Expr) -> ExprKind:
meta = obj._metadata
if meta.is_literal:
return ExprKind.LITERAL
if meta.is_scalar_like:
return ExprKind.AGGREGATION
if meta.is_elementwise:
return ExprKind.ELEMENTWISE
return ExprKind.UNKNOWN
@classmethod
def from_into_expr(
cls, obj: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool
) -> ExprKind:
if is_expr(obj):
return cls.from_expr(obj)
if (
is_narwhals_series(obj)
or is_numpy_array(obj)
or (isinstance(obj, str) and not str_as_lit)
):
return ExprKind.ELEMENTWISE
return ExprKind.LITERAL
def is_scalar_like(
obj: ExprKind,
) -> TypeIs[Literal[ExprKind.LITERAL, ExprKind.AGGREGATION]]:
return obj.is_scalar_like
class ExpansionKind(Enum):
"""Describe what kind of expansion the expression performs."""
SINGLE = auto()
"""e.g. `nw.col('a'), nw.sum_horizontal(nw.all())`"""
MULTI_NAMED = auto()
"""e.g. `nw.col('a', 'b')`"""
MULTI_UNNAMED = auto()
"""e.g. `nw.all()`, nw.nth(0, 1)"""
def is_multi_unnamed(self) -> bool:
return self is ExpansionKind.MULTI_UNNAMED
def is_multi_output(self) -> bool:
return self in {ExpansionKind.MULTI_NAMED, ExpansionKind.MULTI_UNNAMED}
def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]:
if self is ExpansionKind.MULTI_UNNAMED and other is ExpansionKind.MULTI_UNNAMED:
# e.g. nw.selectors.all() - nw.selectors.numeric().
return ExpansionKind.MULTI_UNNAMED
# Don't attempt anything more complex, keep it simple and raise in the face of ambiguity.
msg = f"Unsupported ExpansionKind combination, got {self} and {other}, please report a bug." # pragma: no cover
raise AssertionError(msg) # pragma: no cover
class ExprMetadata:
"""Expression metadata.
Parameters:
expansion_kind: What kind of expansion the expression performs.
has_windows: Whether it already contains window functions.
is_elementwise: Whether it can operate row-by-row without context
of the other rows around it.
is_literal: Whether it is just a literal wrapped in an expression.
is_scalar_like: Whether it is a literal or an aggregation.
last_node: The ExprKind of the last node.
n_orderable_ops: The number of order-dependent operations. In the
lazy case, this number must be `0` by the time the expression
is evaluated.
preserves_length: Whether the expression preserves the input length.
"""
__slots__ = (
"expansion_kind",
"has_windows",
"is_elementwise",
"is_literal",
"is_scalar_like",
"last_node",
"n_orderable_ops",
"preserves_length",
)
def __init__(
self,
expansion_kind: ExpansionKind,
last_node: ExprKind,
*,
has_windows: bool = False,
n_orderable_ops: int = 0,
preserves_length: bool = True,
is_elementwise: bool = True,
is_scalar_like: bool = False,
is_literal: bool = False,
) -> None:
if is_literal:
assert is_scalar_like # noqa: S101 # debug assertion
if is_elementwise:
assert preserves_length # noqa: S101 # debug assertion
self.expansion_kind: ExpansionKind = expansion_kind
self.last_node: ExprKind = last_node
self.has_windows: bool = has_windows
self.n_orderable_ops: int = n_orderable_ops
self.is_elementwise: bool = is_elementwise
self.preserves_length: bool = preserves_length
self.is_scalar_like: bool = is_scalar_like
self.is_literal: bool = is_literal
def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never: # pragma: no cover
msg = f"Cannot subclass {cls.__name__!r}"
raise TypeError(msg)
def __repr__(self) -> str: # pragma: no cover
return (
f"ExprMetadata(\n"
f" expansion_kind: {self.expansion_kind},\n"
f" last_node: {self.last_node},\n"
f" has_windows: {self.has_windows},\n"
f" n_orderable_ops: {self.n_orderable_ops},\n"
f" is_elementwise: {self.is_elementwise},\n"
f" preserves_length: {self.preserves_length},\n"
f" is_scalar_like: {self.is_scalar_like},\n"
f" is_literal: {self.is_literal},\n"
")"
)
@property
def is_filtration(self) -> bool:
return not self.preserves_length and not self.is_scalar_like
def with_aggregation(self) -> ExprMetadata:
if self.is_scalar_like:
msg = "Can't apply aggregations to scalar-like expressions."
raise InvalidOperationError(msg)
return ExprMetadata(
self.expansion_kind,
ExprKind.AGGREGATION,
has_windows=self.has_windows,
n_orderable_ops=self.n_orderable_ops,
preserves_length=False,
is_elementwise=False,
is_scalar_like=True,
is_literal=False,
)
def with_orderable_aggregation(self) -> ExprMetadata:
# Deprecated, used only in stable.v1.
if self.is_scalar_like: # pragma: no cover
msg = "Can't apply aggregations to scalar-like expressions."
raise InvalidOperationError(msg)
return ExprMetadata(
self.expansion_kind,
ExprKind.ORDERABLE_AGGREGATION,
has_windows=self.has_windows,
n_orderable_ops=self.n_orderable_ops + 1,
preserves_length=False,
is_elementwise=False,
is_scalar_like=True,
is_literal=False,
)
def with_elementwise_op(self) -> ExprMetadata:
return ExprMetadata(
self.expansion_kind,
ExprKind.ELEMENTWISE,
has_windows=self.has_windows,
n_orderable_ops=self.n_orderable_ops,
preserves_length=self.preserves_length,
is_elementwise=self.is_elementwise,
is_scalar_like=self.is_scalar_like,
is_literal=self.is_literal,
)
def with_window(self) -> ExprMetadata:
# Window function which may (but doesn't have to) be used with `over(order_by=...)`.
if self.is_scalar_like:
msg = "Can't apply window (e.g. `rank`) to scalar-like expression."
raise InvalidOperationError(msg)
return ExprMetadata(
self.expansion_kind,
ExprKind.WINDOW,
has_windows=self.has_windows,
            # The function isn't order-dependent (but users can still use `order_by` if they wish!),
# so we don't increment `n_orderable_ops`.
n_orderable_ops=self.n_orderable_ops,
preserves_length=self.preserves_length,
is_elementwise=False,
is_scalar_like=False,
is_literal=False,
)
def with_orderable_window(self) -> ExprMetadata:
# Window function which must be used with `over(order_by=...)`.
if self.is_scalar_like:
msg = "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression."
raise InvalidOperationError(msg)
return ExprMetadata(
self.expansion_kind,
ExprKind.ORDERABLE_WINDOW,
has_windows=self.has_windows,
n_orderable_ops=self.n_orderable_ops + 1,
preserves_length=self.preserves_length,
is_elementwise=False,
is_scalar_like=False,
is_literal=False,
)
def with_ordered_over(self) -> ExprMetadata:
if self.has_windows:
msg = "Cannot nest `over` statements."
raise InvalidOperationError(msg)
if self.is_elementwise or self.is_filtration:
msg = (
"Cannot use `over` on expressions which are elementwise\n"
"(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
)
raise InvalidOperationError(msg)
n_orderable_ops = self.n_orderable_ops
if not n_orderable_ops and self.last_node is not ExprKind.WINDOW:
msg = (
"Cannot use `order_by` in `over` on expression which isn't orderable.\n"
"If your expression is orderable, then make sure that `over(order_by=...)`\n"
"comes immediately after the order-dependent expression.\n\n"
"Hint: instead of\n"
" - `(nw.col('price').diff() + 1).over(order_by='date')`\n"
"write:\n"
" + `nw.col('price').diff().over(order_by='date') + 1`\n"
)
raise InvalidOperationError(msg)
if self.last_node.is_orderable_window:
n_orderable_ops -= 1
return ExprMetadata(
self.expansion_kind,
ExprKind.OVER,
has_windows=True,
n_orderable_ops=n_orderable_ops,
preserves_length=True,
is_elementwise=False,
is_scalar_like=False,
is_literal=False,
)
def with_partitioned_over(self) -> ExprMetadata:
if self.has_windows:
msg = "Cannot nest `over` statements."
raise InvalidOperationError(msg)
if self.is_elementwise or self.is_filtration:
msg = (
"Cannot use `over` on expressions which are elementwise\n"
"(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
)
raise InvalidOperationError(msg)
return ExprMetadata(
self.expansion_kind,
ExprKind.OVER,
has_windows=True,
n_orderable_ops=self.n_orderable_ops,
preserves_length=True,
is_elementwise=False,
is_scalar_like=False,
is_literal=False,
)
def with_filtration(self) -> ExprMetadata:
if self.is_scalar_like:
msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
raise InvalidOperationError(msg)
return ExprMetadata(
self.expansion_kind,
ExprKind.FILTRATION,
has_windows=self.has_windows,
n_orderable_ops=self.n_orderable_ops,
preserves_length=False,
is_elementwise=False,
is_scalar_like=False,
is_literal=False,
)
def with_orderable_filtration(self) -> ExprMetadata:
if self.is_scalar_like:
msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
raise InvalidOperationError(msg)
return ExprMetadata(
self.expansion_kind,
ExprKind.ORDERABLE_FILTRATION,
has_windows=self.has_windows,
n_orderable_ops=self.n_orderable_ops + 1,
preserves_length=False,
is_elementwise=False,
is_scalar_like=False,
is_literal=False,
)
@staticmethod
def aggregation() -> ExprMetadata:
return ExprMetadata(
ExpansionKind.SINGLE,
ExprKind.AGGREGATION,
is_elementwise=False,
preserves_length=False,
is_scalar_like=True,
)
@staticmethod
def literal() -> ExprMetadata:
return ExprMetadata(
ExpansionKind.SINGLE,
ExprKind.LITERAL,
is_elementwise=False,
preserves_length=False,
is_literal=True,
is_scalar_like=True,
)
@staticmethod
def selector_single() -> ExprMetadata:
# e.g. `nw.col('a')`, `nw.nth(0)`
return ExprMetadata(ExpansionKind.SINGLE, ExprKind.ELEMENTWISE)
@staticmethod
def selector_multi_named() -> ExprMetadata:
# e.g. `nw.col('a', 'b')`
return ExprMetadata(ExpansionKind.MULTI_NAMED, ExprKind.ELEMENTWISE)
@staticmethod
def selector_multi_unnamed() -> ExprMetadata:
# e.g. `nw.all()`
return ExprMetadata(ExpansionKind.MULTI_UNNAMED, ExprKind.ELEMENTWISE)
@classmethod
def from_binary_op(cls, lhs: Expr, rhs: IntoExpr, /) -> ExprMetadata:
# We may be able to allow multi-output rhs in the future:
# https://github.com/narwhals-dev/narwhals/issues/2244.
return combine_metadata(
lhs, rhs, str_as_lit=True, allow_multi_output=False, to_single_output=False
)
@classmethod
def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata:
return combine_metadata(
*exprs, str_as_lit=False, allow_multi_output=True, to_single_output=True
)
def combine_metadata(
*args: IntoExpr | object | None,
str_as_lit: bool,
allow_multi_output: bool,
to_single_output: bool,
) -> ExprMetadata:
"""Combine metadata from `args`.
Arguments:
args: Arguments, maybe expressions, literals, or Series.
str_as_lit: Whether to interpret strings as literals or as column names.
allow_multi_output: Whether to allow multi-output inputs.
to_single_output: Whether the result is always single-output, regardless
of the inputs (e.g. `nw.sum_horizontal`).
"""
n_filtrations = 0
result_expansion_kind = ExpansionKind.SINGLE
result_has_windows = False
result_n_orderable_ops = 0
# result preserves length if at least one input does
result_preserves_length = False
# result is elementwise if all inputs are elementwise
result_is_elementwise = True
# result is scalar-like if all inputs are scalar-like
result_is_scalar_like = True
# result is literal if all inputs are literal
result_is_literal = True
for i, arg in enumerate(args):
if (isinstance(arg, str) and not str_as_lit) or is_series(arg):
result_preserves_length = True
result_is_scalar_like = False
result_is_literal = False
elif is_expr(arg):
metadata = arg._metadata
if metadata.expansion_kind.is_multi_output():
expansion_kind = metadata.expansion_kind
if i > 0 and not allow_multi_output:
# Left-most argument is always allowed to be multi-output.
msg = (
"Multi-output expressions (e.g. nw.col('a', 'b'), nw.all()) "
"are not supported in this context."
)
raise MultiOutputExpressionError(msg)
if not to_single_output:
result_expansion_kind = (
result_expansion_kind & expansion_kind
if i > 0
else expansion_kind
)
result_has_windows |= metadata.has_windows
result_n_orderable_ops += metadata.n_orderable_ops
result_preserves_length |= metadata.preserves_length
result_is_elementwise &= metadata.is_elementwise
result_is_scalar_like &= metadata.is_scalar_like
result_is_literal &= metadata.is_literal
n_filtrations += int(metadata.is_filtration)
if n_filtrations > 1:
msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
raise InvalidOperationError(msg)
if result_preserves_length and n_filtrations:
msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
raise InvalidOperationError(msg)
return ExprMetadata(
result_expansion_kind,
# n-ary operations align positionally, and so the last node is elementwise.
ExprKind.ELEMENTWISE,
has_windows=result_has_windows,
n_orderable_ops=result_n_orderable_ops,
preserves_length=result_preserves_length,
is_elementwise=result_is_elementwise,
is_scalar_like=result_is_scalar_like,
is_literal=result_is_literal,
)
def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None:
# Raise if any argument in `args` isn't length-preserving.
# For Series input, we don't raise (yet), we let such checks happen later,
# as this function works lazily and so can't evaluate lengths.
from narwhals.series import Series
if not all(
(is_expr(x) and x._metadata.preserves_length) or isinstance(x, (str, Series))
for x in args
):
msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'."
raise InvalidOperationError(msg)
def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
# Return True only if every argument in `args` and `kwargs` is scalar-like
# (an aggregation or a literal). Series and other non-expression inputs make
# this return False; any further validation happens later, as this function
# works lazily and so can't evaluate lengths.
exprs = chain(args, kwargs.values())
return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs)
def apply_n_ary_operation(
plx: CompliantNamespaceAny,
n_ary_function: Callable[..., CompliantExprAny],
*comparands: IntoExpr | NonNestedLiteral | _1DArray,
str_as_lit: bool,
) -> CompliantExprAny:
parse = plx.parse_into_expr
compliant_exprs = (parse(into, str_as_lit=str_as_lit) for into in comparands)
kinds = [
ExprKind.from_into_expr(comparand, str_as_lit=str_as_lit)
for comparand in comparands
]
broadcast = any(not kind.is_scalar_like for kind in kinds)
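# e.g. in `nw.col('a') + nw.col('a').mean()`, the scalar-like aggregation is
# broadcast to the length of the elementwise operand; if every comparand is
# scalar-like, no broadcasting is needed.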
compliant_exprs = (
compliant_expr.broadcast(kind)
if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind)
else compliant_expr
for compliant_expr, kind in zip_strict(compliant_exprs, kinds)
)
return n_ary_function(*compliant_exprs)

View File

@ -0,0 +1,432 @@
from __future__ import annotations
import operator
from io import BytesIO
from typing import TYPE_CHECKING, Any, Literal, cast
import ibis
import ibis.expr.types as ir
from narwhals._ibis.utils import evaluate_exprs, native_to_narwhals_dtype
from narwhals._sql.dataframe import SQLLazyFrame
from narwhals._utils import (
Implementation,
ValidateBackendVersion,
Version,
not_implemented,
parse_columns_to_drop,
zip_strict,
)
from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from pathlib import Path
from types import ModuleType
import pandas as pd
import pyarrow as pa
from ibis.expr.operations import Binary
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._ibis.expr import IbisExpr
from narwhals._ibis.group_by import IbisGroupBy
from narwhals._ibis.namespace import IbisNamespace
from narwhals._ibis.series import IbisInterchangeSeries
from narwhals._typing import _EagerAllowedImpl
from narwhals._utils import _LimitedContext
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import DType
from narwhals.stable.v1 import DataFrame as DataFrameV1
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
JoinPredicates: TypeAlias = "Sequence[ir.BooleanColumn] | Sequence[str]"
class IbisLazyFrame(
SQLLazyFrame["IbisExpr", "ir.Table", "LazyFrame[ir.Table] | DataFrameV1[ir.Table]"],
ValidateBackendVersion,
):
_implementation = Implementation.IBIS
def __init__(
self, df: ir.Table, *, version: Version, validate_backend_version: bool = False
) -> None:
self._native_frame: ir.Table = df
self._version = version
self._cached_schema: dict[str, DType] | None = None
self._cached_columns: list[str] | None = None
if validate_backend_version:
self._validate_backend_version()
@staticmethod
def _is_native(obj: ir.Table | Any) -> TypeIs[ir.Table]:
return isinstance(obj, ir.Table)
@classmethod
def from_native(cls, data: ir.Table, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version)
def to_narwhals(self) -> LazyFrame[ir.Table] | DataFrameV1[ir.Table]:
if self._version is Version.V1:
from narwhals.stable.v1 import DataFrame
return DataFrame(self, level="interchange")
return self._version.lazyframe(self, level="lazy")
def __narwhals_dataframe__(self) -> Self: # pragma: no cover
# Keep around for backcompat.
if self._version is not Version.V1:
msg = "__narwhals_dataframe__ is not implemented for IbisLazyFrame"
raise AttributeError(msg)
return self
def __narwhals_lazyframe__(self) -> Self:
return self
def __native_namespace__(self) -> ModuleType:
return ibis
def __narwhals_namespace__(self) -> IbisNamespace:
from narwhals._ibis.namespace import IbisNamespace
return IbisNamespace(version=self._version)
def get_column(self, name: str) -> IbisInterchangeSeries:
from narwhals._ibis.series import IbisInterchangeSeries
return IbisInterchangeSeries(self.native.select(name), version=self._version)
def _iter_columns(self) -> Iterator[ir.Expr]:
for name in self.columns:
yield self.native[name]
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
if backend is None or backend is Implementation.PYARROW:
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
self.native.to_pyarrow(),
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
self.native.to_pandas(),
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.POLARS:
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
self.native.to_polars(),
validate_backend_version=True,
version=self._version,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
def head(self, n: int) -> Self:
return self._with_native(self.native.head(n))
def simple_select(self, *column_names: str) -> Self:
return self._with_native(self.native.select(*column_names))
def aggregate(self, *exprs: IbisExpr) -> Self:
selection = [
cast("ir.Scalar", val.name(name))
for name, val in evaluate_exprs(self, *exprs)
]
return self._with_native(self.native.aggregate(selection))
def select(self, *exprs: IbisExpr) -> Self:
selection = [val.name(name) for name, val in evaluate_exprs(self, *exprs)]
if not selection:
msg = "At least one expression must be provided to `select` with the Ibis backend."
raise ValueError(msg)
t = self.native.select(*selection)
return self._with_native(t)
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
selection = (col for col in self.columns if col not in columns_to_drop)
return self._with_native(self.native.select(*selection))
def lazy(self, backend: None = None, **_: None) -> Self:
# The `backend` argument has no effect but we keep it here for
# backwards compatibility because in `narwhals.stable.v1`
# the `.from_native()` function will return a DataFrame for Ibis.
if backend is not None: # pragma: no cover
msg = "`backend` argument is not supported for Ibis"
raise ValueError(msg)
return self
def with_columns(self, *exprs: IbisExpr) -> Self:
new_columns_map = dict(evaluate_exprs(self, *exprs))
return self._with_native(self.native.mutate(**new_columns_map))
def filter(self, predicate: IbisExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
mask = cast("ir.BooleanValue", predicate(self)[0])
return self._with_native(self.native.filter(mask))
@property
def schema(self) -> dict[str, DType]:
if self._cached_schema is None:
# Note: prefer `self._cached_schema` over `functools.cached_property`
# due to Python3.13 failures.
self._cached_schema = {
name: native_to_narwhals_dtype(dtype, self._version)
for name, dtype in self.native.schema().fields.items()
}
return self._cached_schema
@property
def columns(self) -> list[str]:
if self._cached_columns is None:
self._cached_columns = (
list(self.schema)
if self._cached_schema is not None
else list(self.native.columns)
)
return self._cached_columns
def to_pandas(self) -> pd.DataFrame:
# only if version is v1, keep around for backcompat
return self.native.to_pandas()
def to_arrow(self) -> pa.Table:
# only if version is v1, keep around for backcompat
return self.native.to_pyarrow()
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version)
def _with_native(self, df: ir.Table) -> Self:
return self.__class__(df, version=self._version)
def group_by(
self, keys: Sequence[str] | Sequence[IbisExpr], *, drop_null_keys: bool
) -> IbisGroupBy:
from narwhals._ibis.group_by import IbisGroupBy
return IbisGroupBy(self, keys, drop_null_keys=drop_null_keys)
def rename(self, mapping: Mapping[str, str]) -> Self:
def _rename(col: str) -> str:
return mapping.get(col, col)
return self._with_native(self.native.rename(_rename))
@staticmethod
def _join_drop_duplicate_columns(df: ir.Table, columns: Iterable[str], /) -> ir.Table:
"""Ibis adds a suffix to the right table col, even when it matches the left during a join."""
duplicates = set(df.columns).intersection(columns)
return df.drop(*duplicates) if duplicates else df
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
how_native = "outer" if how == "full" else how
rname = "{name}" + suffix
if other == self:
# Ibis does not support self-references unless created as a view
other = self._with_native(other.native.view())
if how_native == "cross":
joined = self.native.join(other.native, how=how_native, rname=rname)
return self._with_native(joined)
# help mypy
assert left_on is not None # noqa: S101
assert right_on is not None # noqa: S101
predicates = self._convert_predicates(other, left_on, right_on)
joined = self.native.join(other.native, predicates, how=how_native, rname=rname)
if how_native == "left":
right_names = (n + suffix for n in right_on)
joined = self._join_drop_duplicate_columns(joined, right_names)
it = (cast("Binary", p.op()) for p in predicates if not isinstance(p, str))
to_drop = []
for pred in it:
right = pred.right.name
# Mirrors how polars works.
if right not in self.columns and pred.left.name != right:
to_drop.append(right)
if to_drop:
joined = joined.drop(*to_drop)
return self._with_native(joined)
def join_asof(
self,
other: Self,
*,
left_on: str,
right_on: str,
by_left: Sequence[str] | None,
by_right: Sequence[str] | None,
strategy: AsofJoinStrategy,
suffix: str,
) -> Self:
rname = "{name}" + suffix
strategy_op = {"backward": operator.ge, "forward": operator.le}
predicates: JoinPredicates = []
if op := strategy_op.get(strategy):
on: ir.BooleanColumn = op(self.native[left_on], other.native[right_on])
else:
msg = "Only `backward` and `forward` strategies are currently supported for Ibis"
raise NotImplementedError(msg)
if by_left is not None and by_right is not None:
predicates = self._convert_predicates(other, by_left, by_right)
joined = self.native.asof_join(other.native, on, predicates, rname=rname)
joined = self._join_drop_duplicate_columns(joined, [right_on + suffix])
if by_right is not None:
right_names = (n + suffix for n in by_right)
joined = self._join_drop_duplicate_columns(joined, right_names)
return self._with_native(joined)
def _convert_predicates(
self, other: Self, left_on: Sequence[str], right_on: Sequence[str]
) -> JoinPredicates:
if left_on == right_on:
return left_on
return [
cast("ir.BooleanColumn", (self.native[left] == other.native[right]))
for left, right in zip_strict(left_on, right_on)
]
def collect_schema(self) -> dict[str, DType]:
return {
name: native_to_narwhals_dtype(dtype, self._version)
for name, dtype in self.native.schema().fields.items()
}
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self:
if subset_ := subset if keep == "any" else (subset or self.columns):
# Sanitise input
if any(x not in self.columns for x in subset_):
msg = f"Columns {set(subset_).difference(self.columns)} not found in {self.columns}."
raise ColumnNotFoundError(msg)
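# keep="any" maps to Ibis' keep="first" (one arbitrary row per duplicate
# group); keep="none" maps to keep=None, which drops every duplicated row.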
mapped_keep: dict[str, Literal["first"] | None] = {
"any": "first",
"none": None,
}
to_keep = mapped_keep[keep]
return self._with_native(self.native.distinct(on=subset_, keep=to_keep))
return self._with_native(self.native.distinct(on=subset))
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
descending = [descending for _ in range(len(by))]
sort_cols: list[Any] = []
for i in range(len(by)):
direction_fn = ibis.desc if descending[i] else ibis.asc
col = direction_fn(by[i], nulls_first=not nulls_last)
sort_cols.append(col)
return self._with_native(self.native.order_by(*sort_cols))
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
if isinstance(reverse, bool):
reverse = [reverse] * len(list(by))
sort_cols = []
for is_reverse, by_col in zip_strict(reverse, by):
direction_fn = ibis.asc if is_reverse else ibis.desc
col = direction_fn(by_col, nulls_first=False)
sort_cols.append(cast("ir.Column", col))
return self._with_native(self.native.order_by(*sort_cols).head(k))
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
subset_ = subset if subset is not None else self.columns
return self._with_native(self.native.drop_null(subset_))
def explode(self, columns: Sequence[str]) -> Self:
dtypes = self._version.dtypes
schema = self.collect_schema()
for col in columns:
dtype = schema[col]
if dtype != dtypes.List:
msg = (
f"`explode` operation not supported for dtype `{dtype}`, "
"expected List type"
)
raise InvalidOperationError(msg)
if len(columns) != 1:
msg = (
"Exploding on multiple columns is not supported with Ibis backend since "
"we cannot guarantee that the exploded columns have matching element counts."
)
raise NotImplementedError(msg)
return self._with_native(self.native.unnest(columns[0], keep_empty=True))
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
import ibis.selectors as s
index_: Sequence[str] = [] if index is None else index
on_: Sequence[str] = (
[c for c in self.columns if c not in index_] if on is None else on
)
# Discard columns not in the index
final_columns = list(dict.fromkeys([*index_, variable_name, value_name]))
unpivoted = self.native.pivot_longer(
s.cols(*on_), names_to=variable_name, values_to=value_name
)
return self._with_native(unpivoted.select(*final_columns))
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
to_select = [
ibis.row_number().over(ibis.window(order_by=order_by)).name(name),
ibis.selectors.all(),
]
return self._with_native(self.native.select(*to_select))
def sink_parquet(self, file: str | Path | BytesIO) -> None:
if isinstance(file, BytesIO): # pragma: no cover
msg = "Writing to BytesIO is not supported for Ibis backend."
raise NotImplementedError(msg)
self.native.to_parquet(file)
gather_every = not_implemented.deprecated(
"`LazyFrame.gather_every` is deprecated and will be removed in a future version."
)
tail = not_implemented.deprecated(
"`LazyFrame.tail` is deprecated and will be removed in a future version."
)
# Intentionally not implemented, as Ibis does its own expression rewriting.
_evaluate_window_expr = not_implemented()

View File

@ -0,0 +1,347 @@
from __future__ import annotations
import operator
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast
import ibis
from narwhals._ibis.expr_dt import IbisExprDateTimeNamespace
from narwhals._ibis.expr_list import IbisExprListNamespace
from narwhals._ibis.expr_str import IbisExprStringNamespace
from narwhals._ibis.expr_struct import IbisExprStructNamespace
from narwhals._ibis.utils import is_floating, lit, narwhals_to_native_dtype
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, not_implemented, zip_strict
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
import ibis.expr.types as ir
from typing_extensions import Self
from narwhals._compliant import WindowInputs
from narwhals._compliant.typing import (
AliasNames,
EvalNames,
EvalSeries,
WindowFunction,
)
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._ibis.dataframe import IbisLazyFrame
from narwhals._ibis.namespace import IbisNamespace
from narwhals._utils import _LimitedContext
from narwhals.typing import IntoDType, RankMethod, RollingInterpolationMethod
ExprT = TypeVar("ExprT", bound=ir.Value)
IbisWindowFunction = WindowFunction[IbisLazyFrame, ir.Value]
IbisWindowInputs = WindowInputs[ir.Value]
class IbisExpr(SQLExpr["IbisLazyFrame", "ir.Value"]):
_implementation = Implementation.IBIS
def __init__(
self,
call: EvalSeries[IbisLazyFrame, ir.Value],
window_function: IbisWindowFunction | None = None,
*,
evaluate_output_names: EvalNames[IbisLazyFrame],
alias_output_names: AliasNames | None,
version: Version,
implementation: Implementation = Implementation.IBIS,
) -> None:
self._call = call
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._metadata: ExprMetadata | None = None
self._window_function: IbisWindowFunction | None = window_function
@property
def window_function(self) -> IbisWindowFunction:
def default_window_func(
df: IbisLazyFrame, window_inputs: IbisWindowInputs
) -> Sequence[ir.Value]:
return [
expr.over(
ibis.window(
group_by=window_inputs.partition_by,
order_by=self._sort(*window_inputs.order_by),
)
)
for expr in self(df)
]
return self._window_function or default_window_func
def _window_expression(
self,
expr: ir.Value,
partition_by: Sequence[str | ir.Value] = (),
order_by: Sequence[str | ir.Column] = (),
rows_start: int | None = None,
rows_end: int | None = None,
*,
descending: Sequence[bool] | None = None,
nulls_last: Sequence[bool] | None = None,
) -> ir.Value:
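# `rows_start`, when provided, is assumed to be a non-positive offset relative
# to the current row (e.g. -3 meaning "3 preceding"), hence the sign flip when
# translating to Ibis' `preceding` bound below.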
if rows_start is not None and rows_end is not None:
rows_between = {"preceding": -rows_start, "following": rows_end}
elif rows_end is not None:
rows_between = {"following": rows_end}
elif rows_start is not None: # pragma: no cover
rows_between = {"preceding": -rows_start}
else:
rows_between = {}
window = ibis.window(
group_by=partition_by,
order_by=self._sort(*order_by, descending=descending, nulls_last=nulls_last),
**rows_between,
)
return expr.over(window)
def __narwhals_namespace__(self) -> IbisNamespace: # pragma: no cover
from narwhals._ibis.namespace import IbisNamespace
return IbisNamespace(version=self._version)
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
# Ibis does its own broadcasting.
return self
def _sort(
self,
*cols: ir.Column | str,
descending: Sequence[bool] | None = None,
nulls_last: Sequence[bool] | None = None,
) -> Iterator[ir.Column]:
descending = descending or [False] * len(cols)
nulls_last = nulls_last or [False] * len(cols)
mapping = {
(False, False): partial(ibis.asc, nulls_first=True),
(False, True): partial(ibis.asc, nulls_first=False),
(True, False): partial(ibis.desc, nulls_first=True),
(True, True): partial(ibis.desc, nulls_first=False),
}
yield from (
cast("ir.Column", mapping[(_desc, _nulls_last)](col))
for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last)
)
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[IbisLazyFrame],
/,
*,
context: _LimitedContext,
) -> Self:
def func(df: IbisLazyFrame) -> Sequence[ir.Column]:
return [df.native[name] for name in evaluate_column_names(df)]
return cls(
func,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: IbisLazyFrame) -> Sequence[ir.Column]:
return [df.native[i] for i in column_indices]
return cls(
func,
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
)
def _with_binary(self, op: Callable[..., ir.Value], other: Self | Any) -> Self:
return self._with_callable(op, other=other)
def _with_elementwise(
self, op: Callable[..., ir.Value], /, **expressifiable_args: Self | Any
) -> Self:
return self._with_callable(op, **expressifiable_args)
@classmethod
def _alias_native(cls, expr: ExprT, name: str, /) -> ExprT:
return cast("ExprT", expr.name(name))
def __invert__(self) -> Self:
invert = cast("Callable[..., ir.Value]", operator.invert)
return self._with_callable(invert)
def all(self) -> Self:
return self._with_callable(lambda expr: expr.all().fill_null(lit(True)))
def any(self) -> Self:
return self._with_callable(lambda expr: expr.any().fill_null(lit(False)))
def quantile(
self, quantile: float, interpolation: RollingInterpolationMethod
) -> Self:
if interpolation != "linear":
msg = "Only linear interpolation methods are supported for Ibis quantile."
raise NotImplementedError(msg)
return self._with_callable(lambda expr: expr.quantile(quantile))
def clip(self, lower_bound: Any, upper_bound: Any) -> Self:
def _clip(
expr: ir.NumericValue, lower: Any | None = None, upper: Any | None = None
) -> ir.NumericValue:
return expr.clip(lower=lower, upper=upper)
if lower_bound is None:
return self._with_callable(_clip, upper=upper_bound)
if upper_bound is None:
return self._with_callable(_clip, lower=lower_bound)
return self._with_callable(_clip, lower=lower_bound, upper=upper_bound)
def n_unique(self) -> Self:
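# Ibis' `nunique` ignores nulls, so add 1 whenever any null is present to
# match Polars, which counts null as a distinct value.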
return self._with_callable(
lambda expr: expr.nunique() + expr.isnull().any().cast("int8")
)
def len(self) -> Self:
def func(df: IbisLazyFrame) -> Sequence[ir.IntegerScalar]:
return [df.native.count() for _ in self._evaluate_output_names(df)]
return self.__class__(
func,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)
def std(self, ddof: int) -> Self:
def _std(expr: ir.NumericColumn, ddof: int) -> ir.Value:
if ddof == 0:
return expr.std(how="pop")
if ddof == 1:
return expr.std(how="sample")
n_samples = expr.count()
std_pop = expr.std(how="pop")
ddof_lit = lit(ddof)
return std_pop * n_samples.sqrt() / (n_samples - ddof_lit).sqrt()
return self._with_callable(lambda expr: _std(expr, ddof))
def var(self, ddof: int) -> Self:
def _var(expr: ir.NumericColumn, ddof: int) -> ir.Value:
if ddof == 0:
return expr.var(how="pop")
if ddof == 1:
return expr.var(how="sample")
n_samples = expr.count()
var_pop = expr.var(how="pop")
ddof_lit = lit(ddof)
return var_pop * n_samples / (n_samples - ddof_lit)
return self._with_callable(lambda expr: _var(expr, ddof))
def null_count(self) -> Self:
return self._with_callable(lambda expr: expr.isnull().sum())
def is_nan(self) -> Self:
def func(expr: ir.FloatingValue | Any) -> ir.Value:
otherwise = expr.isnan() if is_floating(expr.type()) else False
return ibis.ifelse(expr.isnull(), None, otherwise)
return self._with_callable(func)
def is_finite(self) -> Self:
return self._with_callable(lambda expr: ~(expr.isinf() | expr.isnan()))
def is_in(self, other: Sequence[Any]) -> Self:
return self._with_callable(lambda expr: expr.isin(other))
def fill_null(self, value: Self | Any, strategy: Any, limit: int | None) -> Self:
# Ibis doesn't yet allow ignoring nulls in first/last with window functions, which makes forward/backward
# strategies inconsistent when there are nulls present: https://github.com/ibis-project/ibis/issues/9539
if strategy is not None:
msg = "`strategy` is not supported for the Ibis backend"
raise NotImplementedError(msg)
if limit is not None:
msg = "`limit` is not supported for the Ibis backend" # pragma: no cover
raise NotImplementedError(msg)
def _fill_null(expr: ir.Value, value: ir.Scalar) -> ir.Value:
return expr.fill_null(value)
return self._with_callable(_fill_null, value=value)
def cast(self, dtype: IntoDType) -> Self:
def _func(expr: ir.Column) -> ir.Value:
native_dtype = narwhals_to_native_dtype(dtype, self._version)
# ibis `cast` overloads do not include DataType, only literals
return expr.cast(native_dtype) # type: ignore[unused-ignore]
return self._with_callable(_func)
def is_unique(self) -> Self:
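# A value is unique iff the partition of rows sharing that value has exactly
# one row; `isnull().count()` simply counts every row in the partition, since
# the boolean column itself is never null.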
return self._with_callable(
lambda expr: expr.isnull().count().over(ibis.window(group_by=(expr))) == 1
)
def rank(self, method: RankMethod, *, descending: bool) -> Self:
def _rank(expr: ir.Column) -> ir.Value:
order_by = next(self._sort(expr, descending=[descending], nulls_last=[True]))
window = ibis.window(order_by=order_by)
if method == "dense":
rank_ = order_by.dense_rank()
elif method == "ordinal":
rank_ = ibis.row_number().over(window)
else:
rank_ = order_by.rank()
# Ibis uses 0-based ranking. Add 1 to match polars 1-based rank.
rank_ = rank_ + lit(1)
# For "max" and "average", adjust using the count of rows in the partition.
if method == "max":
# Define a window partitioned by expr (i.e. each distinct value)
partition = ibis.window(group_by=[expr])
cnt = expr.count().over(partition)
rank_ = rank_ + cnt - lit(1)
elif method == "average":
partition = ibis.window(group_by=[expr])
cnt = expr.count().over(partition)
avg = cast("ir.NumericValue", (cnt - lit(1)) / lit(2.0))
rank_ = rank_ + avg
return ibis.cases((expr.notnull(), rank_))
return self._with_callable(_rank)
@property
def str(self) -> IbisExprStringNamespace:
return IbisExprStringNamespace(self)
@property
def dt(self) -> IbisExprDateTimeNamespace:
return IbisExprDateTimeNamespace(self)
@property
def list(self) -> IbisExprListNamespace:
return IbisExprListNamespace(self)
@property
def struct(self) -> IbisExprStructNamespace:
return IbisExprStructNamespace(self)
# NOTE: https://github.com/ibis-project/ibis/issues/10542
cum_prod = not_implemented()
# NOTE: https://github.com/ibis-project/ibis/issues/11176
skew = not_implemented()
kurtosis = not_implemented()
_count_star = not_implemented()
# Intentionally not implemented, as Ibis does its own expression rewriting.
_push_down_window_function = not_implemented()

View File

@ -0,0 +1,83 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
from narwhals._duration import Interval
from narwhals._ibis.utils import (
UNITS_DICT_BUCKET,
UNITS_DICT_TRUNCATE,
timedelta_to_ibis_interval,
)
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
from narwhals._utils import not_implemented
if TYPE_CHECKING:
import ibis.expr.types as ir
from narwhals._ibis.expr import IbisExpr
from narwhals._ibis.utils import BucketUnit, TruncateUnit
class IbisExprDateTimeNamespace(SQLExprDateTimeNamesSpace["IbisExpr"]):
def millisecond(self) -> IbisExpr:
return self.compliant._with_callable(lambda expr: expr.millisecond())
def microsecond(self) -> IbisExpr:
return self.compliant._with_callable(lambda expr: expr.microsecond())
def to_string(self, format: str) -> IbisExpr:
return self.compliant._with_callable(lambda expr: expr.strftime(format))
def weekday(self) -> IbisExpr:
# Ibis uses 0-6 for Monday-Sunday. Add 1 to match polars.
return self.compliant._with_callable(lambda expr: expr.day_of_week.index() + 1)
def _bucket(self, kwds: dict[BucketUnit, Any], /) -> Callable[..., ir.TimestampValue]:
def fn(expr: ir.TimestampValue) -> ir.TimestampValue:
return expr.bucket(**kwds)
return fn
def _truncate(self, unit: TruncateUnit, /) -> Callable[..., ir.TimestampValue]:
def fn(expr: ir.TimestampValue) -> ir.TimestampValue:
return expr.truncate(unit)
return fn
def truncate(self, every: str) -> IbisExpr:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
if unit == "q":
multiple, unit = 3 * multiple, "mo"
if multiple != 1:
if self.compliant._backend_version < (7, 1): # pragma: no cover
msg = "Truncating datetimes with multiples of the unit is only supported in Ibis >= 7.1."
raise NotImplementedError(msg)
fn = self._bucket({UNITS_DICT_BUCKET[unit]: multiple})
else:
fn = self._truncate(UNITS_DICT_TRUNCATE[unit])
return self.compliant._with_callable(fn)
def offset_by(self, every: str) -> IbisExpr:
interval = Interval.parse_no_constraints(every)
unit = interval.unit
if unit in {"y", "q", "mo", "d", "ns"}:
msg = f"Offsetting by {unit} is not yet supported for ibis."
raise NotImplementedError(msg)
offset = timedelta_to_ibis_interval(interval.to_timedelta())
return self.compliant._with_callable(lambda expr: expr.add(offset))
def replace_time_zone(self, time_zone: str | None) -> IbisExpr:
if time_zone is None:
return self.compliant._with_callable(lambda expr: expr.cast("timestamp"))
msg = "`replace_time_zone` with non-null `time_zone` not yet implemented for Ibis" # pragma: no cover
raise NotImplementedError(msg)
nanosecond = not_implemented()
total_minutes = not_implemented()
total_seconds = not_implemented()
total_milliseconds = not_implemented()
total_microseconds = not_implemented()
total_nanoseconds = not_implemented()
convert_time_zone = not_implemented()
timestamp = not_implemented()

View File

@ -0,0 +1,29 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace
if TYPE_CHECKING:
import ibis.expr.types as ir
from narwhals._ibis.expr import IbisExpr
from narwhals.typing import NonNestedLiteral
class IbisExprListNamespace(LazyExprNamespace["IbisExpr"], ListNamespace["IbisExpr"]):
def len(self) -> IbisExpr:
return self.compliant._with_callable(lambda expr: expr.length())
def unique(self) -> IbisExpr:
return self.compliant._with_callable(lambda expr: expr.unique())
def contains(self, item: NonNestedLiteral) -> IbisExpr:
return self.compliant._with_callable(lambda expr: expr.contains(item))
def get(self, index: int) -> IbisExpr:
def _get(expr: ir.ArrayColumn) -> ir.Column:
return expr[index]
return self.compliant._with_callable(_get)

View File

@ -0,0 +1,83 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
from ibis.expr.datatypes import Timestamp
from narwhals._sql.expr_str import SQLExprStringNamespace
from narwhals._utils import _is_naive_format, not_implemented
if TYPE_CHECKING:
import ibis.expr.types as ir
from typing_extensions import TypeAlias
from narwhals._ibis.expr import IbisExpr
IntoStringValue: TypeAlias = "str | ir.StringValue"
class IbisExprStringNamespace(SQLExprStringNamespace["IbisExpr"]):
def strip_chars(self, characters: str | None) -> IbisExpr:
if characters is not None:
msg = "Ibis does not support `characters` argument in `str.strip_chars`"
raise NotImplementedError(msg)
return self.compliant._with_callable(lambda expr: expr.strip())
def _replace_all(
self, pattern: IntoStringValue, value: IntoStringValue
) -> Callable[..., ir.StringValue]:
def fn(expr: ir.StringColumn) -> ir.StringValue:
return expr.re_replace(pattern, value)
return fn
def _replace_all_literal(
self, pattern: IntoStringValue, value: IntoStringValue
) -> Callable[..., ir.StringValue]:
def fn(expr: ir.StringColumn) -> ir.StringValue:
return expr.replace(pattern, value) # pyright: ignore[reportArgumentType]
return fn
def replace_all(
self, pattern: str, value: str | IbisExpr, *, literal: bool
) -> IbisExpr:
fn = self._replace_all_literal if literal else self._replace_all
if isinstance(value, str):
return self.compliant._with_callable(fn(pattern, value))
return self.compliant._with_elementwise(
lambda expr, value: fn(pattern, value)(expr), value=value
)
def _to_datetime(self, format: str) -> Callable[..., ir.TimestampValue]:
def fn(expr: ir.StringColumn) -> ir.TimestampValue:
return expr.as_timestamp(format)
return fn
def _to_datetime_naive(self, format: str) -> Callable[..., ir.TimestampValue]:
def fn(expr: ir.StringColumn) -> ir.TimestampValue:
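# Assumption: for formats without a timezone directive, some backends may
# attach a timezone to the parsed timestamp, so cast back to a timezone-free
# timestamp to keep the result naive.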
dtype: Any = Timestamp(timezone=None)
return expr.as_timestamp(format).cast(dtype)
return fn
def to_datetime(self, format: str | None) -> IbisExpr:
if format is None:
msg = "Cannot infer format with Ibis backend"
raise NotImplementedError(msg)
fn = self._to_datetime_naive if _is_naive_format(format) else self._to_datetime
return self.compliant._with_callable(fn(format))
def to_date(self, format: str | None) -> IbisExpr:
if format is None:
msg = "Cannot infer format with Ibis backend"
raise NotImplementedError(msg)
def fn(expr: ir.StringColumn) -> ir.DateValue:
return expr.as_date(format)
return self.compliant._with_callable(fn)
replace = not_implemented()

View File

@ -0,0 +1,19 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StructNamespace
if TYPE_CHECKING:
import ibis.expr.types as ir
from narwhals._ibis.expr import IbisExpr
class IbisExprStructNamespace(LazyExprNamespace["IbisExpr"], StructNamespace["IbisExpr"]):
def field(self, name: str) -> IbisExpr:
def func(expr: ir.StructColumn) -> ir.Column:
return expr[name]
return self.compliant._with_callable(func).alias(name)

View File

@ -0,0 +1,32 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._sql.group_by import SQLGroupBy
if TYPE_CHECKING:
from collections.abc import Sequence
import ibis.expr.types as ir # noqa: F401
from narwhals._ibis.dataframe import IbisLazyFrame
from narwhals._ibis.expr import IbisExpr
class IbisGroupBy(SQLGroupBy["IbisLazyFrame", "IbisExpr", "ir.Value"]):
def __init__(
self,
df: IbisLazyFrame,
keys: Sequence[str] | Sequence[IbisExpr],
/,
*,
drop_null_keys: bool,
) -> None:
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
def agg(self, *exprs: IbisExpr) -> IbisLazyFrame:
native = self.compliant.native
return self.compliant._with_native(
native.group_by(self._keys).aggregate(*self._evaluate_exprs(exprs))
).rename(dict(zip(self._keys, self._output_key_names)))

View File

@ -0,0 +1,160 @@
from __future__ import annotations
import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any
import ibis
import ibis.expr.types as ir
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._ibis.dataframe import IbisLazyFrame
from narwhals._ibis.expr import IbisExpr
from narwhals._ibis.selectors import IbisSelectorNamespace
from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype
from narwhals._sql.namespace import SQLNamespace
from narwhals._sql.when_then import SQLThen, SQLWhen
from narwhals._utils import Implementation, requires
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from narwhals._utils import Version
from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral
class IbisNamespace(SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"]):
_implementation: Implementation = Implementation.IBIS
def __init__(self, *, version: Version) -> None:
self._version = version
@property
def selectors(self) -> IbisSelectorNamespace:
return IbisSelectorNamespace.from_namespace(self)
@property
def _expr(self) -> type[IbisExpr]:
return IbisExpr
@property
def _lazyframe(self) -> type[IbisLazyFrame]:
return IbisLazyFrame
def _function(self, name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
return function(name, *args)
def _lit(self, value: Any) -> ir.Value:
return lit(value)
def _when(
self, condition: ir.Value, value: ir.Value, otherwise: ir.Expr | None = None
) -> ir.Value:
if otherwise is None:
return ibis.cases((condition, value))
return ibis.cases((condition, value), else_=otherwise) # pragma: no cover
def _coalesce(self, *exprs: ir.Value) -> ir.Value:
return ibis.coalesce(*exprs)
def concat(
self, items: Iterable[IbisLazyFrame], *, how: ConcatMethod
) -> IbisLazyFrame:
if how == "diagonal":
msg = "diagonal concat not supported for Ibis. Please join instead."
raise NotImplementedError(msg)
items = list(items)
native_items = [item.native for item in items]
schema = items[0].schema
if not all(x.schema == schema for x in items[1:]):
msg = "inputs should all have the same schema"
raise TypeError(msg)
return self._lazyframe.from_native(ibis.union(*native_items), context=self)
def concat_str(
self, *exprs: IbisExpr, separator: str, ignore_nulls: bool
) -> IbisExpr:
def func(df: IbisLazyFrame) -> list[ir.Value]:
cols = list(chain.from_iterable(expr(df) for expr in exprs))
cols_casted = [s.cast("string") for s in cols]
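# With `ignore_nulls=False`, plain `+` concatenation propagates nulls (any null
# input nulls the whole result); with `ignore_nulls=True`, joining on the
# separator is assumed to compile to a `concat_ws`-style call which skips nulls.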
if not ignore_nulls:
result = cols_casted[0]
for col in cols_casted[1:]:
result = result + separator + col
else:
result = lit(separator).join(cols_casted)
return [result]
return self._expr(
call=func,
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
version=self._version,
)
def mean_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
def func(cols: Iterable[ir.Value]) -> ir.Value:
cols = list(cols)
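# Row-wise mean that ignores nulls: the numerator treats nulls as 0 and the
# denominator counts only the non-null entries in each row.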
return reduce(operator.add, (col.fill_null(lit(0)) for col in cols)) / reduce(
operator.add, (col.isnull().ifelse(lit(0), lit(1)) for col in cols)
)
return self._expr._from_elementwise_horizontal_op(func, *exprs)
@requires.backend_version((10, 0))
def when(self, predicate: IbisExpr) -> IbisWhen:
return IbisWhen.from_expr(predicate, context=self)
def lit(self, value: Any, dtype: IntoDType | None) -> IbisExpr:
def func(_df: IbisLazyFrame) -> Sequence[ir.Value]:
ibis_dtype = narwhals_to_native_dtype(dtype, self._version) if dtype else None
return [lit(value, ibis_dtype)]
return self._expr(
func,
evaluate_output_names=lambda _df: ["literal"],
alias_output_names=None,
version=self._version,
)
def len(self) -> IbisExpr:
def func(_df: IbisLazyFrame) -> list[ir.Value]:
return [_df.native.count()]
return self._expr(
call=func,
evaluate_output_names=lambda _df: ["len"],
alias_output_names=None,
version=self._version,
)
class IbisWhen(SQLWhen["IbisLazyFrame", "ir.Value", IbisExpr]):
lit = lit
@property
def _then(self) -> type[IbisThen]:
return IbisThen
def __call__(self, df: IbisLazyFrame) -> Sequence[ir.Value]:
is_expr = self._condition._is_expr
condition = df._evaluate_expr(self._condition)
then_ = self._then_value
then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_)
other_ = self._otherwise_value
if other_ is None:
result = ibis.cases((condition, then))
else:
otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_)
result = ibis.cases((condition, then), else_=otherwise)
return [result]
class IbisThen(SQLThen["IbisLazyFrame", "ir.Value", IbisExpr], IbisExpr): ...

View File

@ -0,0 +1,32 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._ibis.expr import IbisExpr
if TYPE_CHECKING:
import ibis.expr.types as ir # noqa: F401
from narwhals._ibis.dataframe import IbisLazyFrame # noqa: F401
from narwhals._ibis.expr import IbisWindowFunction
class IbisSelectorNamespace(LazySelectorNamespace["IbisLazyFrame", "ir.Value"]):
@property
def _selector(self) -> type[IbisSelector]:
return IbisSelector
class IbisSelector( # type: ignore[misc]
CompliantSelector["IbisLazyFrame", "ir.Value"], IbisExpr
):
_window_function: IbisWindowFunction | None = None
def _to_expr(self) -> IbisExpr:
return IbisExpr(
self._call,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
)

View File

@ -0,0 +1,41 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NoReturn
from narwhals._ibis.utils import native_to_narwhals_dtype
from narwhals.dependencies import get_ibis
if TYPE_CHECKING:
from types import ModuleType
from typing_extensions import Self
from narwhals._utils import Version
from narwhals.dtypes import DType
class IbisInterchangeSeries:
def __init__(self, df: Any, version: Version) -> None:
self._native_series = df
self._version = version
def __narwhals_series__(self) -> Self:
return self
def __native_namespace__(self) -> ModuleType:
return get_ibis()
@property
def dtype(self) -> DType:
return native_to_narwhals_dtype(
self._native_series.schema().types[0], self._version
)
def __getattr__(self, attr: str) -> NoReturn:
msg = (
f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
"If you would like to see this kind of object better supported in "
"Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)

View File

@ -0,0 +1,270 @@
from __future__ import annotations
from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast, overload
import ibis
import ibis.expr.datatypes as ibis_dtypes
from narwhals._utils import Version, isinstance_or_issubclass
if TYPE_CHECKING:
from collections.abc import Mapping
from datetime import timedelta
import ibis.expr.types as ir
from ibis.common.temporal import TimestampUnit
from ibis.expr.datatypes import DataType as IbisDataType
from typing_extensions import TypeAlias, TypeIs
from narwhals._duration import IntervalUnit
from narwhals._ibis.dataframe import IbisLazyFrame
from narwhals._ibis.expr import IbisExpr
from narwhals.dtypes import DType
from narwhals.typing import IntoDType, PythonLiteral
Incomplete: TypeAlias = Any
"""Marker for upstream issues."""
@overload
def lit(value: bool, dtype: None = ...) -> ir.BooleanScalar: ... # noqa: FBT001
@overload
def lit(value: int, dtype: None = ...) -> ir.IntegerScalar: ...
@overload
def lit(value: float, dtype: None = ...) -> ir.FloatingScalar: ...
@overload
def lit(value: str, dtype: None = ...) -> ir.StringScalar: ...
@overload
def lit(value: PythonLiteral | ir.Value, dtype: None = ...) -> ir.Scalar: ...
@overload
def lit(value: Any, dtype: Any) -> Incomplete: ...
def lit(value: Any, dtype: Any | None = None) -> Incomplete:
"""Alias for `ibis.literal`."""
literal: Incomplete = ibis.literal
return literal(value, dtype)
BucketUnit: TypeAlias = Literal[
"years",
"quarters",
"months",
"days",
"hours",
"minutes",
"seconds",
"milliseconds",
"microseconds",
"nanoseconds",
]
TruncateUnit: TypeAlias = Literal[
"Y", "Q", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"
]
UNITS_DICT_BUCKET: Mapping[IntervalUnit, BucketUnit] = {
"y": "years",
"q": "quarters",
"mo": "months",
"d": "days",
"h": "hours",
"m": "minutes",
"s": "seconds",
"ms": "milliseconds",
"us": "microseconds",
"ns": "nanoseconds",
}
UNITS_DICT_TRUNCATE: Mapping[IntervalUnit, TruncateUnit] = {
"y": "Y",
"q": "Q",
"mo": "M",
"d": "D",
"h": "h",
"m": "m",
"s": "s",
"ms": "ms",
"us": "us",
"ns": "ns",
}
FUNCTION_REMAPPING = {
"starts_with": "startswith",
"ends_with": "endswith",
"regexp_matches": "re_search",
"str_split": "split",
"dayofyear": "day_of_year",
"to_date": "date",
}
def evaluate_exprs(df: IbisLazyFrame, /, *exprs: IbisExpr) -> list[tuple[str, ir.Value]]:
native_results: list[tuple[str, ir.Value]] = []
for expr in exprs:
native_series_list = expr(df)
output_names = expr._evaluate_output_names(df)
if expr._alias_output_names is not None:
output_names = expr._alias_output_names(output_names)
if len(output_names) != len(native_series_list): # pragma: no cover
msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
raise AssertionError(msg)
native_results.extend(zip(output_names, native_series_list))
return native_results
@lru_cache(maxsize=16)
def native_to_narwhals_dtype(ibis_dtype: IbisDataType, version: Version) -> DType: # noqa: C901, PLR0912
dtypes = version.dtypes
if ibis_dtype.is_int64():
return dtypes.Int64()
if ibis_dtype.is_int32():
return dtypes.Int32()
if ibis_dtype.is_int16():
return dtypes.Int16()
if ibis_dtype.is_int8():
return dtypes.Int8()
if ibis_dtype.is_uint64():
return dtypes.UInt64()
if ibis_dtype.is_uint32():
return dtypes.UInt32()
if ibis_dtype.is_uint16():
return dtypes.UInt16()
if ibis_dtype.is_uint8():
return dtypes.UInt8()
if ibis_dtype.is_boolean():
return dtypes.Boolean()
if ibis_dtype.is_float64():
return dtypes.Float64()
if ibis_dtype.is_float32():
return dtypes.Float32()
if ibis_dtype.is_string():
return dtypes.String()
if ibis_dtype.is_date():
return dtypes.Date()
if is_timestamp(ibis_dtype):
_unit = cast("TimestampUnit", ibis_dtype.unit)
return dtypes.Datetime(time_unit=_unit.value, time_zone=ibis_dtype.timezone)
if is_interval(ibis_dtype):
_time_unit = ibis_dtype.unit.value
if _time_unit not in {"ns", "us", "ms", "s"}: # pragma: no cover
msg = f"Unsupported interval unit: {_time_unit}"
raise NotImplementedError(msg)
return dtypes.Duration(_time_unit)
if is_array(ibis_dtype):
if ibis_dtype.length:
return dtypes.Array(
native_to_narwhals_dtype(ibis_dtype.value_type, version),
ibis_dtype.length,
)
return dtypes.List(native_to_narwhals_dtype(ibis_dtype.value_type, version))
if is_struct(ibis_dtype):
return dtypes.Struct(
[
dtypes.Field(name, native_to_narwhals_dtype(dtype, version))
for name, dtype in ibis_dtype.items()
]
)
if ibis_dtype.is_decimal(): # pragma: no cover
return dtypes.Decimal()
if ibis_dtype.is_time():
return dtypes.Time()
if ibis_dtype.is_binary():
return dtypes.Binary()
return dtypes.Unknown() # pragma: no cover
def is_timestamp(obj: IbisDataType) -> TypeIs[ibis_dtypes.Timestamp]:
return obj.is_timestamp()
def is_interval(obj: IbisDataType) -> TypeIs[ibis_dtypes.Interval]:
return obj.is_interval()
def is_array(obj: IbisDataType) -> TypeIs[ibis_dtypes.Array[Any]]:
return obj.is_array()
def is_struct(obj: IbisDataType) -> TypeIs[ibis_dtypes.Struct]:
return obj.is_struct()
def is_floating(obj: IbisDataType) -> TypeIs[ibis_dtypes.Floating]:
return obj.is_floating()
dtypes = Version.MAIN.dtypes
NW_TO_IBIS_DTYPES: Mapping[type[DType], IbisDataType] = {
dtypes.Float64: ibis_dtypes.Float64(),
dtypes.Float32: ibis_dtypes.Float32(),
dtypes.Binary: ibis_dtypes.Binary(),
dtypes.String: ibis_dtypes.String(),
dtypes.Boolean: ibis_dtypes.Boolean(),
dtypes.Date: ibis_dtypes.Date(),
dtypes.Time: ibis_dtypes.Time(),
dtypes.Int8: ibis_dtypes.Int8(),
dtypes.Int16: ibis_dtypes.Int16(),
dtypes.Int32: ibis_dtypes.Int32(),
dtypes.Int64: ibis_dtypes.Int64(),
dtypes.UInt8: ibis_dtypes.UInt8(),
dtypes.UInt16: ibis_dtypes.UInt16(),
dtypes.UInt32: ibis_dtypes.UInt32(),
dtypes.UInt64: ibis_dtypes.UInt64(),
dtypes.Decimal: ibis_dtypes.Decimal(),
}
# Enum support: https://github.com/ibis-project/ibis/issues/10991
UNSUPPORTED_DTYPES = (dtypes.Int128, dtypes.UInt128, dtypes.Categorical, dtypes.Enum)
def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> IbisDataType:
dtypes = version.dtypes
base_type = dtype.base_type()
if ibis_type := NW_TO_IBIS_DTYPES.get(base_type):
return ibis_type
if isinstance_or_issubclass(dtype, dtypes.Datetime):
return ibis_dtypes.Timestamp.from_unit(dtype.time_unit, timezone=dtype.time_zone)
if isinstance_or_issubclass(dtype, dtypes.Duration):
return ibis_dtypes.Interval(unit=dtype.time_unit) # pyright: ignore[reportArgumentType]
if isinstance_or_issubclass(dtype, dtypes.List):
inner = narwhals_to_native_dtype(dtype.inner, version)
return ibis_dtypes.Array(value_type=inner)
if isinstance_or_issubclass(dtype, dtypes.Struct):
fields = [
(field.name, narwhals_to_native_dtype(field.dtype, version))
for field in dtype.fields
]
return ibis_dtypes.Struct.from_tuples(fields)
if isinstance_or_issubclass(dtype, dtypes.Array):
inner = narwhals_to_native_dtype(dtype.inner, version)
return ibis_dtypes.Array(value_type=inner, length=dtype.size)
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for Ibis."
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
def timedelta_to_ibis_interval(td: timedelta) -> ibis.expr.types.temporal.IntervalScalar:
return ibis.interval(days=td.days, seconds=td.seconds, microseconds=td.microseconds)
def function(name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
# Workaround SQL vs Ibis differences.
if name == "row_number":
return ibis.row_number() + lit(1)
if name == "least":
return ibis.least(*args)
if name == "greatest":
return ibis.greatest(*args)
expr = args[0]
if name == "var_pop":
return cast("ir.NumericColumn", expr).var(how="pop")
if name == "var_samp":
return cast("ir.NumericColumn", expr).var(how="sample")
if name == "stddev_pop":
return cast("ir.NumericColumn", expr).std(how="pop")
if name == "stddev_samp":
return cast("ir.NumericColumn", expr).std(how="sample")
if name == "substr":
# Ibis is 0-indexed here, SQL is 1-indexed
return cast("ir.StringColumn", expr).substr(args[1] - 1, *args[2:]) # type: ignore[operator] # pyright: ignore[reportArgumentType]
return getattr(expr, FUNCTION_REMAPPING.get(name, name))(*args[1:])

View File

@ -0,0 +1,159 @@
from __future__ import annotations
import enum
from typing import TYPE_CHECKING, Any, NoReturn
from narwhals._utils import Version, parse_version
if TYPE_CHECKING:
import pandas as pd
import pyarrow as pa
from typing_extensions import Self, TypeIs
from narwhals._interchange.series import InterchangeSeries
from narwhals.dtypes import DType
from narwhals.stable.v1.typing import DataFrameLike
class DtypeKind(enum.IntEnum):
# https://data-apis.org/dataframe-protocol/latest/API.html
INT = 0
UINT = 1
FLOAT = 2
BOOL = 20
STRING = 21 # UTF-8
DATETIME = 22
CATEGORICAL = 23
def map_interchange_dtype_to_narwhals_dtype( # noqa: C901, PLR0911, PLR0912
interchange_dtype: tuple[DtypeKind, int, Any, Any],
) -> DType:
dtypes = Version.V1.dtypes
if interchange_dtype[0] == DtypeKind.INT:
if interchange_dtype[1] == 64:
return dtypes.Int64()
if interchange_dtype[1] == 32:
return dtypes.Int32()
if interchange_dtype[1] == 16:
return dtypes.Int16()
if interchange_dtype[1] == 8:
return dtypes.Int8()
msg = "Invalid bit width for INT" # pragma: no cover
raise AssertionError(msg)
if interchange_dtype[0] == DtypeKind.UINT:
if interchange_dtype[1] == 64:
return dtypes.UInt64()
if interchange_dtype[1] == 32:
return dtypes.UInt32()
if interchange_dtype[1] == 16:
return dtypes.UInt16()
if interchange_dtype[1] == 8:
return dtypes.UInt8()
msg = "Invalid bit width for UINT" # pragma: no cover
raise AssertionError(msg)
if interchange_dtype[0] == DtypeKind.FLOAT:
if interchange_dtype[1] == 64:
return dtypes.Float64()
if interchange_dtype[1] == 32:
return dtypes.Float32()
msg = "Invalid bit width for FLOAT" # pragma: no cover
raise AssertionError(msg)
if interchange_dtype[0] == DtypeKind.BOOL:
return dtypes.Boolean()
if interchange_dtype[0] == DtypeKind.STRING:
return dtypes.String()
if interchange_dtype[0] == DtypeKind.DATETIME:
return dtypes.Datetime()
if interchange_dtype[0] == DtypeKind.CATEGORICAL: # pragma: no cover
# upstream issue: https://github.com/ibis-project/ibis/issues/9570
return dtypes.Categorical()
msg = f"Invalid dtype, got: {interchange_dtype}" # pragma: no cover
raise AssertionError(msg)
class InterchangeFrame:
_version = Version.V1
def __init__(self, df: DataFrameLike) -> None:
self._interchange_frame = df.__dataframe__()
def __narwhals_dataframe__(self) -> Self:
return self
def __native_namespace__(self) -> NoReturn:
msg = (
"Cannot access native namespace for interchange-level dataframes with unknown backend."
"If you would like to see this kind of object supported in Narwhals, please "
"open a feature request at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)
def get_column(self, name: str) -> InterchangeSeries:
from narwhals._interchange.series import InterchangeSeries
return InterchangeSeries(self._interchange_frame.get_column_by_name(name))
def to_pandas(self) -> pd.DataFrame:
import pandas as pd # ignore-banned-import()
if parse_version(pd) < (1, 5, 0): # pragma: no cover
msg = (
"Conversion to pandas is achieved via interchange protocol which requires"
f" 'pandas>=1.5.0' to be installed, found {pd.__version__}"
)
raise NotImplementedError(msg)
return pd.api.interchange.from_dataframe(self._interchange_frame)
def to_arrow(self) -> pa.Table:
from pyarrow.interchange.from_dataframe import ( # ignore-banned-import()
from_dataframe,
)
return from_dataframe(self._interchange_frame)
@property
def schema(self) -> dict[str, DType]:
return {
column_name: map_interchange_dtype_to_narwhals_dtype(
self._interchange_frame.get_column_by_name(column_name).dtype
)
for column_name in self._interchange_frame.column_names()
}
@property
def columns(self) -> list[str]:
return list(self._interchange_frame.column_names())
def __getattr__(self, attr: str) -> NoReturn:
msg = (
f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
"Hint: you probably called `nw.from_native` on an object which isn't fully "
"supported by Narwhals, yet implements `__dataframe__`. If you would like to "
"see this kind of object supported in Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)
def simple_select(self, *column_names: str) -> Self:
frame = self._interchange_frame.select_columns_by_name(list(column_names))
if not hasattr(frame, "_df"): # pragma: no cover
msg = (
"Expected interchange object to implement `_df` property to allow for recovering original object.\n"
"See https://github.com/data-apis/dataframe-api/issues/360."
)
raise NotImplementedError(msg)
return self.__class__(frame._df)
def select(self, *exprs: str) -> Self: # pragma: no cover
msg = (
"`select`-ing not by name is not supported for interchange-only level.\n\n"
"If you would like to see this kind of object better supported in "
"Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)
def supports_dataframe_interchange(obj: Any) -> TypeIs[DataFrameLike]:
return hasattr(obj, "__dataframe__")

View File

@ -0,0 +1,47 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, NoReturn
from narwhals._interchange.dataframe import map_interchange_dtype_to_narwhals_dtype
from narwhals._utils import Version
if TYPE_CHECKING:
from typing_extensions import Self
from narwhals.dtypes import DType
class InterchangeSeries:
_version = Version.V1
def __init__(self, df: Any) -> None:
self._native_series = df
def __narwhals_series__(self) -> Self:
return self
def __native_namespace__(self) -> NoReturn:
msg = (
"Cannot access native namespace for interchange-level series with unknown backend. "
"If you would like to see this kind of object supported in Narwhals, please "
"open a feature request at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)
@property
def dtype(self) -> DType:
return map_interchange_dtype_to_narwhals_dtype(self._native_series.dtype)
@property
def native(self) -> Any:
return self._native_series
def __getattr__(self, attr: str) -> NoReturn:
msg = ( # pragma: no cover
f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
"Hint: you probably called `nw.from_native` on an object which isn't fully "
"supported by Narwhals, yet implements `__dataframe__`. If you would like to "
"see this kind of object supported in Narwhals, please open a feature request "
"at https://github.com/narwhals-dev/narwhals/issues."
)
raise NotImplementedError(msg)

View File

@ -0,0 +1,409 @@
"""Narwhals-level equivalent of `CompliantNamespace`."""
from __future__ import annotations
from typing import (
TYPE_CHECKING,
Any,
Callable,
Generic,
Literal,
Protocol,
TypeVar,
overload,
)
from narwhals._compliant.typing import CompliantNamespaceAny, CompliantNamespaceT_co
from narwhals._utils import Implementation, Version
from narwhals.dependencies import (
get_cudf,
get_modin,
get_pandas,
get_polars,
get_pyarrow,
is_dask_dataframe,
is_duckdb_relation,
is_ibis_table,
is_pyspark_connect_dataframe,
is_pyspark_dataframe,
is_sqlframe_dataframe,
)
if TYPE_CHECKING:
from collections.abc import Collection, Sized
from typing import ClassVar
import duckdb
import pandas as pd
import polars as pl
import pyarrow as pa
import pyspark.sql as pyspark_sql
from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._arrow.namespace import ArrowNamespace
from narwhals._dask.namespace import DaskNamespace
from narwhals._duckdb.namespace import DuckDBNamespace
from narwhals._ibis.namespace import IbisNamespace
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._polars.namespace import PolarsNamespace
from narwhals._spark_like.dataframe import SQLFrameDataFrame
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals._typing import (
Arrow,
Backend,
Dask,
DuckDB,
EagerAllowed,
Ibis,
IntoBackend,
PandasLike,
Polars,
SparkLike,
)
from narwhals.typing import NativeDataFrame, NativeLazyFrame, NativeSeries
T = TypeVar("T")
_Guard: TypeAlias = "Callable[[Any], TypeIs[T]]"
EagerAllowedNamespace: TypeAlias = "Namespace[PandasLikeNamespace] | Namespace[ArrowNamespace] | Namespace[PolarsNamespace]"
class _BasePandasLike(Sized, Protocol):
index: Any
"""`mypy` doesn't like the asymmetric `property` setter in `pandas`."""
def __getitem__(self, key: Any, /) -> Any: ...
def __mul__(self, other: float | Collection[float] | Self) -> Self: ...
def __floordiv__(self, other: float | Collection[float] | Self) -> Self: ...
@property
def loc(self) -> Any: ...
@property
def shape(self) -> tuple[int, ...]: ...
def set_axis(self, labels: Any, *, axis: Any = ..., copy: bool = ...) -> Self: ...
def copy(self, deep: bool = ...) -> Self: ... # noqa: FBT001
def rename(self, *args: Any, inplace: Literal[False], **kwds: Any) -> Self:
"""`inplace=False` is required to avoid (incorrect?) default overloads."""
...
class _BasePandasLikeFrame(NativeDataFrame, _BasePandasLike, Protocol): ...
class _BasePandasLikeSeries(NativeSeries, _BasePandasLike, Protocol):
def where(self, cond: Any, other: Any = ..., **kwds: Any) -> Any: ...
class _NativeDask(Protocol):
_partition_type: type[pd.DataFrame]
class _CuDFDataFrame(_BasePandasLikeFrame, Protocol):
def to_pylibcudf(self, *args: Any, **kwds: Any) -> Any: ...
class _CuDFSeries(_BasePandasLikeSeries, Protocol):
def to_pylibcudf(self, *args: Any, **kwds: Any) -> Any: ...
class _NativeIbis(Protocol):
def sql(self, *args: Any, **kwds: Any) -> Any: ...
def __pyarrow_result__(self, *args: Any, **kwds: Any) -> Any: ...
def __pandas_result__(self, *args: Any, **kwds: Any) -> Any: ...
def __polars_result__(self, *args: Any, **kwds: Any) -> Any: ...
class _ModinDataFrame(_BasePandasLikeFrame, Protocol):
_pandas_class: type[pd.DataFrame]
class _ModinSeries(_BasePandasLikeSeries, Protocol):
_pandas_class: type[pd.Series[Any]]
_NativePolars: TypeAlias = "pl.DataFrame | pl.LazyFrame | pl.Series"
_NativeArrow: TypeAlias = "pa.Table | pa.ChunkedArray[Any]"
_NativeDuckDB: TypeAlias = "duckdb.DuckDBPyRelation"
_NativePandas: TypeAlias = "pd.DataFrame | pd.Series[Any]"
_NativeModin: TypeAlias = "_ModinDataFrame | _ModinSeries"
_NativeCuDF: TypeAlias = "_CuDFDataFrame | _CuDFSeries"
_NativePandasLikeSeries: TypeAlias = "pd.Series[Any] | _CuDFSeries | _ModinSeries"
_NativePandasLikeDataFrame: TypeAlias = (
"pd.DataFrame | _CuDFDataFrame | _ModinDataFrame"
)
_NativePandasLike: TypeAlias = "_NativePandasLikeDataFrame | _NativePandasLikeSeries"
_NativeSQLFrame: TypeAlias = "SQLFrameDataFrame"
_NativePySpark: TypeAlias = "pyspark_sql.DataFrame"
_NativePySparkConnect: TypeAlias = "PySparkConnectDataFrame"
_NativeSparkLike: TypeAlias = (
"_NativeSQLFrame | _NativePySpark | _NativePySparkConnect"
)
NativeKnown: TypeAlias = "_NativePolars | _NativeArrow | _NativePandasLike | _NativeSparkLike | _NativeDuckDB | _NativeDask | _NativeIbis"
NativeUnknown: TypeAlias = "NativeDataFrame | NativeSeries | NativeLazyFrame"
NativeAny: TypeAlias = "NativeKnown | NativeUnknown"
__all__ = ["Namespace"]
class Namespace(Generic[CompliantNamespaceT_co]):
_compliant_namespace: CompliantNamespaceT_co
_version: ClassVar[Version] = Version.MAIN
def __init__(self, namespace: CompliantNamespaceT_co, /) -> None:
self._compliant_namespace = namespace
def __init_subclass__(cls, *args: Any, version: Version, **kwds: Any) -> None:
super().__init_subclass__(*args, **kwds)
if isinstance(version, Version):
cls._version = version
else:
msg = f"Expected {Version} but got {type(version).__name__!r}"
raise TypeError(msg)
def __repr__(self) -> str:
return f"Namespace[{type(self.compliant).__name__}]"
@property
def compliant(self) -> CompliantNamespaceT_co:
return self._compliant_namespace
@property
def implementation(self) -> Implementation:
return self.compliant._implementation
@property
def version(self) -> Version:
return self._version
@overload
@classmethod
def from_backend(cls, backend: PandasLike, /) -> Namespace[PandasLikeNamespace]: ...
@overload
@classmethod
def from_backend(cls, backend: Polars, /) -> Namespace[PolarsNamespace]: ...
@overload
@classmethod
def from_backend(cls, backend: Arrow, /) -> Namespace[ArrowNamespace]: ...
@overload
@classmethod
def from_backend(cls, backend: SparkLike, /) -> Namespace[SparkLikeNamespace]: ...
@overload
@classmethod
def from_backend(cls, backend: DuckDB, /) -> Namespace[DuckDBNamespace]: ...
@overload
@classmethod
def from_backend(cls, backend: Dask, /) -> Namespace[DaskNamespace]: ...
@overload
@classmethod
def from_backend(cls, backend: Ibis, /) -> Namespace[IbisNamespace]: ...
@overload
@classmethod
def from_backend(cls, backend: EagerAllowed, /) -> EagerAllowedNamespace: ...
@overload
@classmethod
def from_backend(
cls, backend: IntoBackend[Backend], /
) -> Namespace[CompliantNamespaceAny]: ...
@classmethod
def from_backend(
cls: type[Namespace[Any]], backend: IntoBackend[Backend], /
) -> Namespace[Any]:
"""Instantiate from native namespace module, string, or Implementation.
Arguments:
backend: native namespace module, string, or Implementation.
Examples:
>>> from narwhals._namespace import Namespace
>>> Namespace.from_backend("polars")
Namespace[PolarsNamespace]
"""
impl = Implementation.from_backend(backend)
backend_version = impl._backend_version() # noqa: F841
version = cls._version
ns: CompliantNamespaceAny
if impl.is_pandas_like():
from narwhals._pandas_like.namespace import PandasLikeNamespace
ns = PandasLikeNamespace(implementation=impl, version=version)
elif impl.is_polars():
from narwhals._polars.namespace import PolarsNamespace
ns = PolarsNamespace(version=version)
elif impl.is_pyarrow():
from narwhals._arrow.namespace import ArrowNamespace
ns = ArrowNamespace(version=version)
elif impl.is_spark_like():
from narwhals._spark_like.namespace import SparkLikeNamespace
ns = SparkLikeNamespace(implementation=impl, version=version)
elif impl.is_duckdb():
from narwhals._duckdb.namespace import DuckDBNamespace
ns = DuckDBNamespace(version=version)
elif impl.is_dask():
from narwhals._dask.namespace import DaskNamespace
ns = DaskNamespace(version=version)
elif impl.is_ibis():
from narwhals._ibis.namespace import IbisNamespace
ns = IbisNamespace(version=version)
else:
msg = "Unsupported implementation" # pragma: no cover
raise AssertionError(msg)
return cls(ns)
@overload
@classmethod
def from_native_object(
cls, native: _NativePolars, /
) -> Namespace[PolarsNamespace]: ...
@overload
@classmethod
def from_native_object(
cls, native: _NativePandas, /
) -> Namespace[PandasLikeNamespace[pd.DataFrame, pd.Series[Any]]]: ...
@overload
@classmethod
def from_native_object(cls, native: _NativeArrow, /) -> Namespace[ArrowNamespace]: ...
@overload
@classmethod
def from_native_object(
cls, native: _NativeSparkLike, /
) -> Namespace[SparkLikeNamespace]: ...
@overload
@classmethod
def from_native_object(
cls, native: _NativeDuckDB, /
) -> Namespace[DuckDBNamespace]: ...
@overload
@classmethod
def from_native_object(cls, native: _NativeDask, /) -> Namespace[DaskNamespace]: ...
@overload
@classmethod
def from_native_object(cls, native: _NativeIbis, /) -> Namespace[IbisNamespace]: ...
@overload
@classmethod
def from_native_object(
cls, native: _NativeModin, /
) -> Namespace[PandasLikeNamespace[_ModinDataFrame, _ModinSeries]]: ...
@overload
@classmethod
def from_native_object(
cls, native: _NativeCuDF, /
) -> Namespace[PandasLikeNamespace[_CuDFDataFrame, _CuDFSeries]]: ...
@overload
@classmethod
def from_native_object(
cls, native: _NativePandasLike, /
) -> Namespace[PandasLikeNamespace[Any, Any]]: ...
@overload
@classmethod
def from_native_object(
cls, native: NativeUnknown, /
) -> Namespace[CompliantNamespaceAny]: ...
@classmethod
def from_native_object(
cls: type[Namespace[Any]], native: NativeAny, /
) -> Namespace[Any]:
impl: Backend
if is_native_polars(native):
impl = Implementation.POLARS
elif is_native_pandas(native):
impl = Implementation.PANDAS
elif is_native_arrow(native):
impl = Implementation.PYARROW
elif is_native_spark_like(native):
impl = (
Implementation.SQLFRAME
if is_native_sqlframe(native)
else Implementation.PYSPARK_CONNECT
if is_native_pyspark_connect(native)
else Implementation.PYSPARK
)
elif is_native_dask(native): # pragma: no cover
impl = Implementation.DASK
elif is_native_duckdb(native):
impl = Implementation.DUCKDB
elif is_native_cudf(native): # pragma: no cover
impl = Implementation.CUDF
elif is_native_modin(native): # pragma: no cover
impl = Implementation.MODIN
elif is_native_ibis(native):
impl = Implementation.IBIS
else:
msg = f"Unsupported type: {type(native).__qualname__!r}"
raise TypeError(msg)
return cls.from_backend(impl)
def is_native_polars(obj: Any) -> TypeIs[_NativePolars]:
return (pl := get_polars()) is not None and isinstance(
obj, (pl.DataFrame, pl.Series, pl.LazyFrame)
)
def is_native_arrow(obj: Any) -> TypeIs[_NativeArrow]:
return (pa := get_pyarrow()) is not None and isinstance(
obj, (pa.Table, pa.ChunkedArray)
)
def is_native_dask(obj: Any) -> TypeIs[_NativeDask]:
return is_dask_dataframe(obj)
is_native_duckdb: _Guard[_NativeDuckDB] = is_duckdb_relation
is_native_sqlframe: _Guard[_NativeSQLFrame] = is_sqlframe_dataframe
is_native_pyspark: _Guard[_NativePySpark] = is_pyspark_dataframe
is_native_pyspark_connect: _Guard[_NativePySparkConnect] = is_pyspark_connect_dataframe
def is_native_pandas(obj: Any) -> TypeIs[_NativePandas]:
return (pd := get_pandas()) is not None and isinstance(obj, (pd.DataFrame, pd.Series))
def is_native_modin(obj: Any) -> TypeIs[_NativeModin]:
return (mpd := get_modin()) is not None and isinstance(
obj, (mpd.DataFrame, mpd.Series)
) # pragma: no cover
def is_native_cudf(obj: Any) -> TypeIs[_NativeCuDF]:
return (cudf := get_cudf()) is not None and isinstance(
obj, (cudf.DataFrame, cudf.Series)
) # pragma: no cover
def is_native_pandas_like(obj: Any) -> TypeIs[_NativePandasLike]:
return (
is_native_pandas(obj) or is_native_cudf(obj) or is_native_modin(obj)
) # pragma: no cover
def is_native_spark_like(obj: Any) -> TypeIs[_NativeSparkLike]:
return (
is_native_sqlframe(obj)
or is_native_pyspark(obj)
or is_native_pyspark_connect(obj)
)
def is_native_ibis(obj: Any) -> TypeIs[_NativeIbis]:
return is_ibis_table(obj)
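# Illustrative sketch (assumes pandas is installed; `_namespace` is a private module, so
# this only illustrates the dispatch above rather than recommending public usage):
# `from_native_object` resolves the backend of a native object, then defers to
# `from_backend` to build the matching compliant namespace.
import pandas as pd

ns = Namespace.from_native_object(pd.DataFrame({"a": [1, 2]}))
print(ns)  # Namespace[PandasLikeNamespace]
assert ns.implementation.is_pandas()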

File diff suppressed because it is too large

View File

@ -0,0 +1,345 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._pandas_like.group_by import PandasLikeGroupBy
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._utils import generate_temporary_column_name
if TYPE_CHECKING:
from collections.abc import Sequence
from typing_extensions import Self
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
from narwhals._expression_parsing import ExprMetadata
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.namespace import PandasLikeNamespace
from narwhals._utils import Implementation, Version, _LimitedContext
from narwhals.typing import PythonLiteral
WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT = {
"cum_sum": "cumsum",
"cum_min": "cummin",
"cum_max": "cummax",
"cum_prod": "cumprod",
# Pandas cumcount starts counting from 0 while Polars starts from 1
# Pandas cumcount counts nulls while Polars does not
# So, instead of using "cumcount" we use "cumsum" on notna() to get the same result
"cum_count": "cumsum",
"rolling_sum": "sum",
"rolling_mean": "mean",
"rolling_std": "std",
"rolling_var": "var",
"shift": "shift",
"rank": "rank",
"diff": "diff",
"fill_null": "fillna",
"quantile": "quantile",
"ewm_mean": "mean",
}
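# Illustrative sketch (plain pandas, not narwhals code): why `cum_count` maps to `cumsum`.
# Polars' `cum_count` is a 1-based running count of non-null values, while pandas'
# `cumcount` is 0-based and counts nulls, so the same result comes from `cumsum` on `notna()`.
import pandas as pd

s = pd.Series([1.0, None, 3.0])
print(s.notna().cumsum().tolist())  # [1, 1, 2] - matches Polars' cum_count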
def window_kwargs_to_pandas_equivalent(
function_name: str, kwargs: ScalarKwargs
) -> dict[str, PythonLiteral]:
if function_name == "shift":
assert "n" in kwargs # noqa: S101
pandas_kwargs: dict[str, PythonLiteral] = {"periods": kwargs["n"]}
elif function_name == "rank":
assert "method" in kwargs # noqa: S101
assert "descending" in kwargs # noqa: S101
_method = kwargs["method"]
pandas_kwargs = {
"method": "first" if _method == "ordinal" else _method,
"ascending": not kwargs["descending"],
"na_option": "keep",
"pct": False,
}
elif function_name.startswith("cum_"): # Cumulative operation
pandas_kwargs = {"skipna": True}
elif function_name.startswith("rolling_"): # Rolling operation
assert "min_samples" in kwargs # noqa: S101
assert "window_size" in kwargs # noqa: S101
assert "center" in kwargs # noqa: S101
pandas_kwargs = {
"min_periods": kwargs["min_samples"],
"window": kwargs["window_size"],
"center": kwargs["center"],
}
elif function_name in {"std", "var"}:
assert "ddof" in kwargs # noqa: S101
pandas_kwargs = {"ddof": kwargs["ddof"]}
elif function_name == "fill_null":
assert "strategy" in kwargs # noqa: S101
assert "limit" in kwargs # noqa: S101
pandas_kwargs = {"strategy": kwargs["strategy"], "limit": kwargs["limit"]}
elif function_name == "quantile":
assert "quantile" in kwargs # noqa: S101
assert "interpolation" in kwargs # noqa: S101
pandas_kwargs = {
"q": kwargs["quantile"],
"interpolation": kwargs["interpolation"],
}
elif function_name.startswith("ewm_"):
assert "com" in kwargs # noqa: S101
assert "span" in kwargs # noqa: S101
assert "half_life" in kwargs # noqa: S101
assert "alpha" in kwargs # noqa: S101
assert "adjust" in kwargs # noqa: S101
assert "min_samples" in kwargs # noqa: S101
assert "ignore_nulls" in kwargs # noqa: S101
pandas_kwargs = {
"com": kwargs["com"],
"span": kwargs["span"],
"halflife": kwargs["half_life"],
"alpha": kwargs["alpha"],
"adjust": kwargs["adjust"],
"min_periods": kwargs["min_samples"],
"ignore_na": kwargs["ignore_nulls"],
}
else: # sum, len, ...
pandas_kwargs = {}
return pandas_kwargs
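# Illustrative sketch (hypothetical call): translating narwhals rolling kwargs into the
# keyword names that pandas' rolling API expects.
example = window_kwargs_to_pandas_equivalent(
    "rolling_mean", {"window_size": 3, "min_samples": 1, "center": False}
)
assert example == {"min_periods": 1, "window": 3, "center": False}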
class PandasLikeExpr(EagerExpr["PandasLikeDataFrame", PandasLikeSeries]):
def __init__(
self,
call: EvalSeries[PandasLikeDataFrame, PandasLikeSeries],
*,
depth: int,
function_name: str,
evaluate_output_names: EvalNames[PandasLikeDataFrame],
alias_output_names: AliasNames | None,
implementation: Implementation,
version: Version,
scalar_kwargs: ScalarKwargs | None = None,
) -> None:
self._call = call
self._depth = depth
self._function_name = function_name
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._implementation = implementation
self._version = version
self._scalar_kwargs = scalar_kwargs or {}
self._metadata: ExprMetadata | None = None
def __narwhals_namespace__(self) -> PandasLikeNamespace:
from narwhals._pandas_like.namespace import PandasLikeNamespace
return PandasLikeNamespace(self._implementation, version=self._version)
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[PandasLikeDataFrame],
/,
*,
context: _LimitedContext,
function_name: str = "",
) -> Self:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
try:
return [
PandasLikeSeries(
df._native_frame[column_name],
implementation=df._implementation,
version=df._version,
)
for column_name in evaluate_column_names(df)
]
except KeyError as e:
if error := df._check_columns_exist(evaluate_column_names(df)):
raise error from e
raise
return cls(
func,
depth=0,
function_name=function_name,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
implementation=context._implementation,
version=context._version,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
native = df.native
return [
PandasLikeSeries.from_native(native.iloc[:, i], context=df)
for i in column_indices
]
return cls(
func,
depth=0,
function_name="nth",
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
implementation=context._implementation,
version=context._version,
)
def ewm_mean(
self,
*,
com: float | None,
span: float | None,
half_life: float | None,
alpha: float | None,
adjust: bool,
min_samples: int,
ignore_nulls: bool,
) -> Self:
return self._reuse_series(
"ewm_mean",
scalar_kwargs={
"com": com,
"span": span,
"half_life": half_life,
"alpha": alpha,
"adjust": adjust,
"min_samples": min_samples,
"ignore_nulls": ignore_nulls,
},
)
def over( # noqa: C901, PLR0915
self, partition_by: Sequence[str], order_by: Sequence[str]
) -> Self:
if not partition_by:
# e.g. `nw.col('a').cum_sum().over(order_by=key)`
# We can always easily support this as it doesn't require grouping.
assert order_by # noqa: S101
def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
token = generate_temporary_column_name(8, df.columns)
df = df.with_row_index(token, order_by=None).sort(
*order_by, descending=False, nulls_last=False
)
results = self(df.drop([token], strict=True))
sorting_indices = df.get_column(token)
for s in results:
s._scatter_in_place(sorting_indices, s)
return results
elif not self._is_elementary():
msg = (
"Only elementary expressions are supported for `.over` in pandas-like backends.\n\n"
"Please see: "
"https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
)
raise NotImplementedError(msg)
else:
function_name = PandasLikeGroupBy._leaf_name(self)
pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get(
function_name, PandasLikeGroupBy._REMAP_AGGS.get(function_name)
)
if pandas_function_name is None:
msg = (
f"Unsupported function: {function_name} in `over` context.\n\n"
f"Supported functions are {', '.join(WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT)}\n"
f"and {', '.join(PandasLikeGroupBy._REMAP_AGGS)}."
)
raise NotImplementedError(msg)
pandas_kwargs = window_kwargs_to_pandas_equivalent(
function_name, self._scalar_kwargs
)
def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]: # noqa: C901, PLR0912, PLR0914, PLR0915
output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
if function_name == "cum_count":
plx = self.__narwhals_namespace__()
df = df.with_columns(~plx.col(*output_names).is_null())
if function_name.startswith("cum_"):
assert "reverse" in self._scalar_kwargs # noqa: S101
reverse = self._scalar_kwargs["reverse"]
else:
assert "reverse" not in self._scalar_kwargs # noqa: S101
reverse = False
if order_by:
columns = list(set(partition_by).union(output_names).union(order_by))
token = generate_temporary_column_name(8, columns)
df = (
df.simple_select(*columns)
.with_row_index(token, order_by=None)
.sort(*order_by, descending=reverse, nulls_last=reverse)
)
sorting_indices = df.get_column(token)
elif reverse:
columns = list(set(partition_by).union(output_names))
df = df.simple_select(*columns)._gather_slice(slice(None, None, -1))
grouped = df._native_frame.groupby(partition_by)
if function_name.startswith("rolling"):
rolling = grouped[list(output_names)].rolling(**pandas_kwargs)
assert pandas_function_name is not None # help mypy # noqa: S101
if pandas_function_name in {"std", "var"}:
assert "ddof" in self._scalar_kwargs # noqa: S101
res_native = getattr(rolling, pandas_function_name)(
ddof=self._scalar_kwargs["ddof"]
)
else:
res_native = getattr(rolling, pandas_function_name)()
elif function_name.startswith("ewm"):
if self._implementation.is_pandas() and (
self._implementation._backend_version()
) < (1, 2): # pragma: no cover
msg = (
"Exponentially weighted calculation is not available in over "
f"context for pandas versions older than 1.2.0, found {self._implementation._backend_version()}."
)
raise NotImplementedError(msg)
ewm = grouped[list(output_names)].ewm(**pandas_kwargs)
assert pandas_function_name is not None # help mypy # noqa: S101
res_native = getattr(ewm, pandas_function_name)()
elif function_name == "fill_null":
assert "strategy" in self._scalar_kwargs # noqa: S101
assert "limit" in self._scalar_kwargs # noqa: S101
df_grouped = grouped[list(output_names)]
if self._scalar_kwargs["strategy"] == "forward":
res_native = df_grouped.ffill(limit=self._scalar_kwargs["limit"])
elif self._scalar_kwargs["strategy"] == "backward":
res_native = df_grouped.bfill(limit=self._scalar_kwargs["limit"])
else: # pragma: no cover
# This is deprecated in pandas. Indeed, `nw.col('a').fill_null(3).over('b')`
# does not seem very useful, and DuckDB doesn't support it either.
msg = "`fill_null` with `over` without `strategy` specified is not supported."
raise NotImplementedError(msg)
elif function_name == "len":
if len(output_names) != 1: # pragma: no cover
msg = "Safety check failed, please report a bug."
raise AssertionError(msg)
res_native = grouped.transform("size").to_frame(aliases[0])
else:
res_native = grouped[list(output_names)].transform(
pandas_function_name, **pandas_kwargs
)
result_frame = df._with_native(res_native).rename(
dict(zip(output_names, aliases))
)
results = [result_frame.get_column(name) for name in aliases]
if order_by:
for s in results:
s._scatter_in_place(sorting_indices, s)
return results
if reverse:
return [s._gather_slice(slice(None, None, -1)) for s in results]
return results
return self.__class__(
func,
depth=self._depth + 1,
function_name=self._function_name + "->over",
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
implementation=self._implementation,
version=self._version,
)
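# Illustrative sketch (plain pandas, not narwhals code): one way to picture the
# no-partition_by branch of `over` above - tag rows with a temporary position column,
# sort by the ordering columns, compute the window function, then restore original order.
import pandas as pd

df = pd.DataFrame({"order": [2, 0, 1], "a": [10, 20, 30]}, index=["r0", "r1", "r2"])
token = "_tmp_row_index"  # hypothetical temporary column name
tagged = df.assign(**{token: range(len(df))}).sort_values("order")
cum = tagged["a"].cumsum()
restored = cum.set_axis(tagged[token]).sort_index()
print(restored.tolist())  # [60, 20, 50] - cum_sum ordered by "order", in original row order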

View File

@ -0,0 +1,365 @@
from __future__ import annotations
import warnings
from functools import lru_cache
from itertools import chain
from operator import methodcaller
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from narwhals._compliant import EagerGroupBy
from narwhals._exceptions import issue_warning
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import zip_strict
from narwhals.dependencies import is_pandas_like_dataframe
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence
import pandas as pd
from pandas.api.typing import DataFrameGroupBy as _NativeGroupBy
from typing_extensions import TypeAlias, Unpack
from narwhals._compliant.typing import NarwhalsAggregation, ScalarKwargs
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.expr import PandasLikeExpr
NativeGroupBy: TypeAlias = "_NativeGroupBy[tuple[str, ...], Literal[True]]"
NativeApply: TypeAlias = "Callable[[pd.DataFrame], pd.Series[Any]]"
InefficientNativeAggregation: TypeAlias = Literal["cov", "skew"]
NativeAggregation: TypeAlias = Literal[
"any",
"all",
"count",
"first",
"idxmax",
"idxmin",
"last",
"max",
"mean",
"median",
"min",
"mode",
"nunique",
"prod",
"quantile",
"sem",
"size",
"std",
"sum",
"var",
InefficientNativeAggregation,
]
"""https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html#built-in-aggregation-methods"""
_NativeAgg: TypeAlias = "Callable[[Any], pd.DataFrame | pd.Series[Any]]"
"""Equivalent to a partial method call on `DataFrameGroupBy`."""
NonStrHashable: TypeAlias = Any
"""Because `pandas` allows *"names"* like that 😭"""
@lru_cache(maxsize=32)
def _native_agg(name: NativeAggregation, /, **kwds: Unpack[ScalarKwargs]) -> _NativeAgg:
if name == "nunique":
return methodcaller(name, dropna=False)
if not kwds or kwds.get("ddof") == 1:
return methodcaller(name)
return methodcaller(name, **kwds)
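# Illustrative sketch (standard library plus pandas): `methodcaller("sum")` behaves like
# `lambda obj: obj.sum()`, which is what lets the aggregation callables above be cached
# per (name, kwargs) combination.
import pandas as pd
from operator import methodcaller

agg = methodcaller("sum")
print(agg(pd.Series([1, 2, 3])))  # 6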
class AggExpr:
"""Wrapper storing the intermediate state per-`PandasLikeExpr`.
There are a lot of edge cases to handle, so aim to evaluate as little
as possible - and store anything that's needed twice.
Warning:
While a `PandasLikeExpr` can be reused - this wrapper is valid **only**
in a single `.agg(...)` operation.
"""
expr: PandasLikeExpr
output_names: Sequence[str]
aliases: Sequence[str]
def __init__(self, expr: PandasLikeExpr) -> None:
self.expr = expr
self.output_names = ()
self.aliases = ()
self._leaf_name: NarwhalsAggregation | Any = ""
def with_expand_names(self, group_by: PandasLikeGroupBy, /) -> AggExpr:
"""**Mutating operation**.
Stores the results of `evaluate_output_names_and_aliases`.
"""
df = group_by.compliant
exclude = group_by.exclude
self.output_names, self.aliases = evaluate_output_names_and_aliases(
self.expr, df, exclude
)
return self
def _getitem_aggs(
self, group_by: PandasLikeGroupBy, /
) -> pd.DataFrame | pd.Series[Any]:
"""Evaluate the wrapped expression as a group_by operation."""
result: pd.DataFrame | pd.Series[Any]
names = self.output_names
if self.is_len() and self.is_top_level_function():
result = group_by._grouped.size()
elif self.is_len():
result_single = group_by._grouped.size()
ns = group_by.compliant.__narwhals_namespace__()
result = ns._concat_horizontal(
[ns.from_native(result_single).alias(name).native for name in names]
)
elif self.is_mode():
compliant = group_by.compliant
if (keep := self.kwargs.get("keep")) != "any": # pragma: no cover
msg = (
f"`Expr.mode(keep='{keep}')` is not implemented in group by context for "
f"backend {compliant._implementation}\n\n"
"Hint: Use `nw.col(...).mode(keep='any')` instead."
)
raise NotImplementedError(msg)
cols = list(names)
native = compliant.native
keys, kwargs = group_by._keys, group_by._kwargs
# Implementation based on the following suggestion:
# https://github.com/pandas-dev/pandas/issues/19254#issuecomment-778661578
ns = compliant.__narwhals_namespace__()
result = ns._concat_horizontal(
[
native.groupby([*keys, col], **kwargs)
.size()
.sort_values(ascending=False)
.reset_index(col)
.groupby(keys, **kwargs)[col]
.head(1)
.sort_index()
for col in cols
]
)
else:
select = names[0] if len(names) == 1 else list(names)
result = self.native_agg()(group_by._grouped[select])
if is_pandas_like_dataframe(result):
result.columns = list(self.aliases)
else:
result.name = self.aliases[0]
return result
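# Illustrative sketch (plain pandas, mirroring the pandas-dev/pandas#19254 suggestion
# referenced above): per-group mode computed as size -> sort -> head(1) per group.
import pandas as pd

df = pd.DataFrame({"g": ["a", "a", "a", "b"], "x": [1, 1, 2, 3]})
mode_per_group = (
    df.groupby(["g", "x"])
    .size()
    .sort_values(ascending=False)
    .reset_index("x")
    .groupby("g")["x"]
    .head(1)
    .sort_index()
)
print(mode_per_group.tolist())  # [1, 3]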
def is_len(self) -> bool:
return self.leaf_name == "len"
def is_mode(self) -> bool:
return self.leaf_name == "mode"
def is_top_level_function(self) -> bool:
# e.g. `nw.len()`.
return self.expr._depth == 0
@property
def kwargs(self) -> ScalarKwargs:
return self.expr._scalar_kwargs
@property
def leaf_name(self) -> NarwhalsAggregation | Any:
if name := self._leaf_name:
return name
self._leaf_name = PandasLikeGroupBy._leaf_name(self.expr)
return self._leaf_name
def native_agg(self) -> _NativeAgg:
"""Return a partial `DataFrameGroupBy` method, missing only `self`."""
return _native_agg(
PandasLikeGroupBy._remap_expr_name(self.leaf_name), **self.kwargs
)
class PandasLikeGroupBy(
EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", NativeAggregation]
):
_REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, NativeAggregation]] = {
"sum": "sum",
"mean": "mean",
"median": "median",
"max": "max",
"min": "min",
"mode": "mode",
"std": "std",
"var": "var",
"len": "size",
"n_unique": "nunique",
"count": "count",
"quantile": "quantile",
"all": "all",
"any": "any",
}
_original_columns: tuple[str, ...]
"""Column names *prior* to any aliasing in `ParseKeysGroupBy`."""
_keys: list[str]
"""Stores the **aliased** version of group keys from `ParseKeysGroupBy`."""
_output_key_names: list[str]
"""Stores the **original** version of group keys."""
_kwargs: Mapping[str, bool]
"""Stores keyword arguments for `DataFrame.groupby` other than `by`."""
@property
def exclude(self) -> tuple[str, ...]:
"""Group keys to ignore when expanding multi-output aggregations."""
return self._exclude
def __init__(
self,
df: PandasLikeDataFrame,
keys: Sequence[PandasLikeExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
self._original_columns = tuple(df.columns)
self._drop_null_keys = drop_null_keys
self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
df, keys
)
self._exclude: tuple[str, ...] = (*self._keys, *self._output_key_names)
# Drop index to avoid potential collisions:
# https://github.com/narwhals-dev/narwhals/issues/1907.
native = self.compliant.native
if set(native.index.names).intersection(self.compliant.columns):
native = native.reset_index(drop=True)
self._kwargs = {
"sort": False,
"as_index": True,
"dropna": drop_null_keys,
"observed": True,
}
self._grouped: NativeGroupBy = native.groupby(self._keys.copy(), **self._kwargs)
def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:
all_aggs_are_simple = True
agg_exprs: list[AggExpr] = []
for expr in exprs:
agg_exprs.append(AggExpr(expr).with_expand_names(self))
if not self._is_simple(expr):
all_aggs_are_simple = False
if all_aggs_are_simple:
result: pd.DataFrame
if agg_exprs:
ns = self.compliant.__narwhals_namespace__()
result = ns._concat_horizontal(self._getitem_aggs(agg_exprs))
else:
result = self.compliant.__native_namespace__().DataFrame(
list(self._grouped.groups), columns=self._keys
)
elif self.compliant.native.empty:
raise empty_results_error()
else:
result = self._apply_aggs(exprs)
# NOTE: Keep `inplace=True` to avoid making a redundant copy.
# This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
result.reset_index(inplace=True) # noqa: PD002
return self._select_results(result, agg_exprs)
def _select_results(
self, df: pd.DataFrame, /, agg_exprs: Sequence[AggExpr]
) -> PandasLikeDataFrame:
"""Responsible for remapping temp column names back to original.
See `ParseKeysGroupBy`.
"""
new_names = chain.from_iterable(e.aliases for e in agg_exprs)
return (
self.compliant._with_native(df, validate_column_names=False)
.simple_select(*self._keys, *new_names)
.rename(dict(zip(self._keys, self._output_key_names)))
)
def _getitem_aggs(
self, exprs: Iterable[AggExpr], /
) -> list[pd.DataFrame | pd.Series[Any]]:
return [e._getitem_aggs(self) for e in exprs]
def _apply_aggs(self, exprs: Iterable[PandasLikeExpr]) -> pd.DataFrame:
"""Stub issue for `include_groups` [pandas-dev/pandas-stubs#1270].
- [User guide] mentions `include_groups` 4 times without deprecation.
- [`DataFrameGroupBy.apply`] doc says the default value of `True` is deprecated since `2.2.0`.
- `False` is explicitly the only *non-deprecated* option, but entirely omitted since [pandas-dev/pandas-stubs#1268].
[pandas-dev/pandas-stubs#1270]: https://github.com/pandas-dev/pandas-stubs/issues/1270
[User guide]: https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html
[`DataFrameGroupBy.apply`]: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.core.groupby.DataFrameGroupBy.apply.html
[pandas-dev/pandas-stubs#1268]: https://github.com/pandas-dev/pandas-stubs/pull/1268
"""
warn_complex_group_by()
impl = self.compliant._implementation
func = self._apply_exprs_function(exprs)
apply = self._grouped.apply
if impl.is_pandas() and impl._backend_version() >= (2, 2):
return apply(func, include_groups=False) # type: ignore[call-overload]
return apply(func) # pragma: no cover
def _apply_exprs_function(self, exprs: Iterable[PandasLikeExpr]) -> NativeApply:
ns = self.compliant.__narwhals_namespace__()
into_series = ns._series.from_iterable
def fn(df: pd.DataFrame) -> pd.Series[Any]:
compliant = self.compliant._with_native(df)
results = (
(keys.native.iloc[0], keys.name)
for expr in exprs
for keys in expr(compliant)
)
out_group, out_names = zip_strict(*results) if results else ([], [])
return into_series(out_group, index=out_names, context=ns).native
return fn
def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message=".*a length 1 tuple will be returned",
category=FutureWarning,
)
with_native = self.compliant._with_native
for key, group in self._grouped:
yield (key, with_native(group).simple_select(*self._original_columns))
def empty_results_error() -> ValueError:
"""Don't even attempt this, it's way too inconsistent across pandas versions."""
msg = (
"No results for group-by aggregation.\n\n"
"Hint: you were probably trying to apply a non-elementary aggregation with a "
"pandas-like API.\n"
"Please rewrite your query such that group-by aggregations "
"are elementary. For example, instead of:\n\n"
" df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
"use:\n\n"
" df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
)
return ValueError(msg)
def warn_complex_group_by() -> None:
issue_warning(
"Found complex group-by expression, which can't be expressed efficiently with the "
"pandas API. If you can, please rewrite your query such that group-by aggregations "
"are simple (e.g. mean, std, min, max, ...). \n\n"
"Please see: "
"https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/",
UserWarning,
)

View File

@ -0,0 +1,441 @@
from __future__ import annotations
import operator
import warnings
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Protocol, overload
from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
combine_alias_output_names,
combine_evaluate_output_names,
)
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.selectors import PandasSelectorNamespace
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT
from narwhals._pandas_like.utils import is_non_nullable_boolean
from narwhals._utils import zip_strict
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from typing_extensions import TypeAlias
from narwhals._compliant.typing import ScalarKwargs
from narwhals._utils import Implementation, Version
from narwhals.typing import IntoDType, NonNestedLiteral
Incomplete: TypeAlias = Any
"""Escape hatch, but leaving a trace that this isn't ideal."""
_Vertical: TypeAlias = Literal[0]
_Horizontal: TypeAlias = Literal[1]
Axis: TypeAlias = Literal[_Vertical, _Horizontal]
VERTICAL: _Vertical = 0
HORIZONTAL: _Horizontal = 1
class PandasLikeNamespace(
EagerNamespace[
PandasLikeDataFrame,
PandasLikeSeries,
PandasLikeExpr,
NativeDataFrameT,
NativeSeriesT,
]
):
@property
def _dataframe(self) -> type[PandasLikeDataFrame]:
return PandasLikeDataFrame
@property
def _expr(self) -> type[PandasLikeExpr]:
return PandasLikeExpr
@property
def _series(self) -> type[PandasLikeSeries]:
return PandasLikeSeries
@property
def selectors(self) -> PandasSelectorNamespace:
return PandasSelectorNamespace.from_namespace(self)
def __init__(self, implementation: Implementation, version: Version) -> None:
self._implementation = implementation
self._version = version
def coalesce(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
align = self._series._align_full_broadcast
series = align(*(s for _expr in exprs for s in _expr(df)))
return [
reduce(lambda x, y: x.fill_null(y, strategy=None, limit=None), series)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="coalesce",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> PandasLikeExpr:
def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries:
pandas_series = self._series.from_iterable(
data=[value],
name="literal",
index=df._native_frame.index[0:1],
context=self,
)
if dtype:
return pandas_series.cast(dtype)
return pandas_series
return PandasLikeExpr(
lambda df: [_lit_pandas_series(df)],
depth=0,
function_name="lit",
evaluate_output_names=lambda _df: ["literal"],
alias_output_names=None,
implementation=self._implementation,
version=self._version,
)
def len(self) -> PandasLikeExpr:
return PandasLikeExpr(
lambda df: [
self._series.from_iterable(
[len(df._native_frame)], name="len", index=[0], context=self
)
],
depth=0,
function_name="len",
evaluate_output_names=lambda _df: ["len"],
alias_output_names=None,
implementation=self._implementation,
version=self._version,
)
# --- horizontal ---
def sum_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
align = self._series._align_full_broadcast
it = chain.from_iterable(expr(df) for expr in exprs)
series = align(*it)
native_series = (s.fill_null(0, None, None) for s in series)
return [reduce(operator.add, native_series)]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="sum_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def all_horizontal(
self, *exprs: PandasLikeExpr, ignore_nulls: bool
) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
align = self._series._align_full_broadcast
series = [s for _expr in exprs for s in _expr(df)]
if not ignore_nulls and any(
s.native.dtype == "object" and s.is_null().any() for s in series
):
# classical NumPy boolean columns don't support missing values, so
# only do the full scan with `is_null` if we have `object` dtype.
msg = "Cannot use `ignore_nulls=False` in `all_horizontal` for non-nullable NumPy-backed pandas Series when nulls are present."
raise ValueError(msg)
it = (
(
# NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
s if is_non_nullable_boolean(s) else s.fill_null(True, None, None)
for s in series
)
if ignore_nulls
else iter(series)
)
return [reduce(operator.and_, align(*it))]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="all_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
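# Illustrative sketch (plain pandas, not narwhals code): the object-dtype guard above.
# A NumPy 'bool' column cannot hold missing values, so nulls only appear in 'object'
# (or nullable) boolean columns; with ignore_nulls=True they are filled with True
# before being combined with `&`.
import pandas as pd

s = pd.Series([True, None, False], dtype="object")
print(s.isna().any())                          # True -> the ignore_nulls=False path raises
print(s.fillna(True).astype(bool).tolist())    # [True, True, False] -> the ignore_nulls=True path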
def any_horizontal(
self, *exprs: PandasLikeExpr, ignore_nulls: bool
) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
align = self._series._align_full_broadcast
series = [s for _expr in exprs for s in _expr(df)]
if not ignore_nulls and any(
s.native.dtype == "object" and s.is_null().any() for s in series
):
# classical NumPy boolean columns don't support missing values, so
# only do the full scan with `is_null` if we have `object` dtype.
msg = "Cannot use `ignore_nulls=False` in `any_horizontal` for non-nullable NumPy-backed pandas Series when nulls are present."
raise ValueError(msg)
it = (
(
# NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
s if is_non_nullable_boolean(s) else s.fill_null(False, None, None)
for s in series
)
if ignore_nulls
else iter(series)
)
return [reduce(operator.or_, align(*it))]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="any_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def mean_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
expr_results = [s for _expr in exprs for s in _expr(df)]
align = self._series._align_full_broadcast
series = align(
*(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
)
non_na = align(*(1 - s.is_null() for s in expr_results))
return [reduce(operator.add, series) / reduce(operator.add, non_na)]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="mean_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def min_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
it = chain.from_iterable(expr(df) for expr in exprs)
align = self._series._align_full_broadcast
series = align(*it)
return [
PandasLikeSeries(
self.concat(
(s.to_frame() for s in series), how="horizontal"
)._native_frame.min(axis=1),
implementation=self._implementation,
version=self._version,
).alias(series[0].name)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="min_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
def max_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
it = chain.from_iterable(expr(df) for expr in exprs)
align = self._series._align_full_broadcast
series = align(*it)
return [
PandasLikeSeries(
self.concat(
(s.to_frame() for s in series), how="horizontal"
).native.max(axis=1),
implementation=self._implementation,
version=self._version,
).alias(series[0].name)
]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="max_horizontal",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
@property
def _concat(self) -> _NativeConcat[NativeDataFrameT, NativeSeriesT]:
"""Concatenate pandas objects along a particular axis.
Return the **native** equivalent of `pd.concat`.
"""
return self._implementation.to_native_namespace().concat
def _concat_diagonal(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFrameT:
if self._implementation.is_pandas() and self._backend_version < (3,):
return self._concat(dfs, axis=VERTICAL, copy=False)
return self._concat(dfs, axis=VERTICAL)
def _concat_horizontal(
self, dfs: Sequence[NativeDataFrameT | NativeSeriesT], /
) -> NativeDataFrameT:
if self._implementation.is_cudf():
with warnings.catch_warnings():
warnings.filterwarnings(
"ignore",
message="The behavior of array concatenation with empty entries is deprecated",
category=FutureWarning,
)
return self._concat(dfs, axis=HORIZONTAL)
elif self._implementation.is_pandas() and self._backend_version < (3,):
return self._concat(dfs, axis=HORIZONTAL, copy=False)
return self._concat(dfs, axis=HORIZONTAL)
def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFrameT:
cols_0 = dfs[0].columns
for i, df in enumerate(dfs[1:], start=1):
cols_current = df.columns
if not (
(len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
):
msg = (
"unable to vstack, column names don't match:\n"
f" - dataframe 0: {cols_0.to_list()}\n"
f" - dataframe {i}: {cols_current.to_list()}\n"
)
raise TypeError(msg)
if self._implementation.is_pandas() and self._backend_version < (3,):
return self._concat(dfs, axis=VERTICAL, copy=False)
return self._concat(dfs, axis=VERTICAL)
def when(self, predicate: PandasLikeExpr) -> PandasWhen[NativeSeriesT]:
return PandasWhen[NativeSeriesT].from_expr(predicate, context=self)
def concat_str(
self, *exprs: PandasLikeExpr, separator: str, ignore_nulls: bool
) -> PandasLikeExpr:
string = self._version.dtypes.String()
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
expr_results = [s for _expr in exprs for s in _expr(df)]
align = self._series._align_full_broadcast
series = align(*(s.cast(string) for s in expr_results))
null_mask = align(*(s.is_null() for s in expr_results))
if not ignore_nulls:
null_mask_result = reduce(operator.or_, null_mask)
result = reduce(lambda x, y: x + separator + y, series).zip_with(
~null_mask_result, None
)
else:
# NOTE: Trying to help `mypy` later
# error: Cannot determine type of "values" [has-type]
values: list[PandasLikeSeries]
init_value, *values = [
s.zip_with(~nm, "") for s, nm in zip_strict(series, null_mask)
]
sep_array = init_value.from_iterable(
data=[separator] * len(init_value),
name="sep",
index=init_value.native.index,
context=self,
)
separators = (sep_array.zip_with(~nm, "") for nm in null_mask[:-1])
result = reduce(
operator.add,
(s + v for s, v in zip_strict(separators, values)),
init_value,
)
return [result]
return self._expr._from_callable(
func=func,
depth=max(x._depth for x in exprs) + 1,
function_name="concat_str",
evaluate_output_names=combine_evaluate_output_names(*exprs),
alias_output_names=combine_alias_output_names(*exprs),
context=self,
)
class _NativeConcat(Protocol[NativeDataFrameT, NativeSeriesT]):
@overload
def __call__(
self,
objs: Iterable[NativeDataFrameT],
*,
axis: _Vertical,
copy: bool | None = ...,
) -> NativeDataFrameT: ...
@overload
def __call__(
self, objs: Iterable[NativeSeriesT], *, axis: _Vertical, copy: bool | None = ...
) -> NativeSeriesT: ...
@overload
def __call__(
self,
objs: Iterable[NativeDataFrameT | NativeSeriesT],
*,
axis: _Horizontal,
copy: bool | None = ...,
) -> NativeDataFrameT: ...
@overload
def __call__(
self,
objs: Iterable[NativeDataFrameT | NativeSeriesT],
*,
axis: Axis,
copy: bool | None = ...,
) -> NativeDataFrameT | NativeSeriesT: ...
def __call__(
self,
objs: Iterable[NativeDataFrameT | NativeSeriesT],
*,
axis: Axis,
copy: bool | None = None,
) -> NativeDataFrameT | NativeSeriesT: ...
class PandasWhen(
EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, NativeSeriesT]
):
@property
# Signature of "_then" incompatible with supertype "CompliantWhen"
# ArrowWhen seems to follow the same pattern, but no mypy complaint there?
def _then(self) -> type[PandasThen]: # type: ignore[override]
return PandasThen
def _if_then_else(
self,
when: NativeSeriesT,
then: NativeSeriesT,
otherwise: NativeSeriesT | NonNestedLiteral,
) -> NativeSeriesT:
where: Incomplete = then.where
return where(when) if otherwise is None else where(when, otherwise)
class PandasThen(
CompliantThen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, PandasWhen],
PandasLikeExpr,
):
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "whenthen"

View File

@ -0,0 +1,38 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import CompliantSelector, EagerSelectorNamespace
from narwhals._pandas_like.expr import PandasLikeExpr
if TYPE_CHECKING:
from narwhals._compliant.typing import ScalarKwargs
from narwhals._pandas_like.dataframe import PandasLikeDataFrame # noqa: F401
from narwhals._pandas_like.series import PandasLikeSeries # noqa: F401
class PandasSelectorNamespace(
EagerSelectorNamespace["PandasLikeDataFrame", "PandasLikeSeries"]
):
@property
def _selector(self) -> type[PandasSelector]:
return PandasSelector
class PandasSelector( # type: ignore[misc]
CompliantSelector["PandasLikeDataFrame", "PandasLikeSeries"], PandasLikeExpr
):
_depth: int = 0
_scalar_kwargs: ScalarKwargs = {} # noqa: RUF012
_function_name: str = "selector"
def _to_expr(self) -> PandasLikeExpr:
return PandasLikeExpr(
self._call,
depth=self._depth,
function_name=self._function_name,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
implementation=self._implementation,
version=self._version,
)

File diff suppressed because it is too large

View File

@ -0,0 +1,17 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant.any_namespace import CatNamespace
from narwhals._pandas_like.utils import PandasLikeSeriesNamespace
if TYPE_CHECKING:
from narwhals._pandas_like.series import PandasLikeSeries
class PandasLikeSeriesCatNamespace(
PandasLikeSeriesNamespace, CatNamespace["PandasLikeSeries"]
):
def get_categories(self) -> PandasLikeSeries:
s = self.native
return self.with_native(type(s)(s.cat.categories, name=s.name))

View File

@ -0,0 +1,290 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import (
EPOCH_YEAR,
MS_PER_SECOND,
NS_PER_SECOND,
SECONDS_PER_DAY,
US_PER_SECOND,
)
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
ALIAS_DICT,
UNITS_DICT,
PandasLikeSeriesNamespace,
calculate_timestamp_date,
calculate_timestamp_datetime,
get_dtype_backend,
int_dtype_mapper,
is_dtype_pyarrow,
)
if TYPE_CHECKING:
from datetime import timedelta
import pandas as pd
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals.typing import TimeUnit
class PandasLikeSeriesDateTimeNamespace(
PandasLikeSeriesNamespace, DateTimeNamespace["PandasLikeSeries"]
):
def date(self) -> PandasLikeSeries:
result = self.with_native(self.native.dt.date)
if str(result.dtype).lower() == "object":
msg = (
"Accessing `date` on the default pandas backend "
"will return a Series of type `object`."
"\nThis differs from the Polars API and will prevent `.dt` chaining. "
"Please switch to the `pyarrow` backend:"
'\ndf.convert_dtypes(dtype_backend="pyarrow")'
)
raise NotImplementedError(msg)
return result
def year(self) -> PandasLikeSeries:
return self.with_native(self.native.dt.year)
def month(self) -> PandasLikeSeries:
return self.with_native(self.native.dt.month)
def day(self) -> PandasLikeSeries:
return self.with_native(self.native.dt.day)
def hour(self) -> PandasLikeSeries:
return self.with_native(self.native.dt.hour)
def minute(self) -> PandasLikeSeries:
return self.with_native(self.native.dt.minute)
def second(self) -> PandasLikeSeries:
return self.with_native(self.native.dt.second)
def millisecond(self) -> PandasLikeSeries:
return self.microsecond() // 1000
def microsecond(self) -> PandasLikeSeries:
if self.backend_version < (3, 0, 0) and self._is_pyarrow():
# crazy workaround for https://github.com/pandas-dev/pandas/issues/59154
import pyarrow.compute as pc # ignore-banned-import()
from narwhals._arrow.utils import lit
arr_ns = self.native.array
arr = arr_ns.__arrow_array__()
result_arr = pc.add(
pc.multiply(pc.millisecond(arr), lit(1_000)), pc.microsecond(arr)
)
result = type(self.native)(type(arr_ns)(result_arr), name=self.native.name)
return self.with_native(result)
return self.with_native(self.native.dt.microsecond)
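# Illustrative sketch (assumes pyarrow is installed; plain pyarrow, not narwhals code):
# the workaround above rebuilds the full microsecond value, since Arrow splits the
# sub-second part into separate millisecond and microsecond components (each 0-999).
import pyarrow as pa
import pyarrow.compute as pc
from datetime import datetime

arr = pa.array([datetime(2020, 1, 1, 0, 0, 0, 123_456)])
full = pc.add(pc.multiply(pc.millisecond(arr), 1000), pc.microsecond(arr))
print(full.to_pylist())  # [123456]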
def nanosecond(self) -> PandasLikeSeries:
return self.microsecond() * 1_000 + self.native.dt.nanosecond
def ordinal_day(self) -> PandasLikeSeries:
year_start = self.native.dt.year
result = (
self.native.to_numpy().astype("datetime64[D]")
- (year_start.to_numpy() - EPOCH_YEAR).astype("datetime64[Y]")
).astype("int32") + 1
dtype = "Int64[pyarrow]" if self._is_pyarrow() else "int32"
return self.with_native(
type(self.native)(result, dtype=dtype, name=year_start.name)
)
def weekday(self) -> PandasLikeSeries:
# Pandas is 0-6 while Polars is 1-7
return self.with_native(self.native.dt.weekday) + 1
def _is_pyarrow(self) -> bool:
return is_dtype_pyarrow(self.native.dtype)
def _get_total_seconds(self) -> Any:
if hasattr(self.native.dt, "total_seconds"):
return self.native.dt.total_seconds()
return ( # pragma: no cover
self.native.dt.days * SECONDS_PER_DAY
+ self.native.dt.seconds
+ (self.native.dt.microseconds / US_PER_SECOND)
+ (self.native.dt.nanoseconds / NS_PER_SECOND)
)
def total_minutes(self) -> PandasLikeSeries:
s = self._get_total_seconds()
# this calculates the sign of each series element
s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
s_abs = s.abs() // 60
if ~s.isna().any():
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self.with_native(s_abs * s_sign)
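# Illustrative sketch (plain pandas): the sign trick used in the `total_*` methods above -
# `2 * (s > 0) - 1` maps positive values to +1 and non-positive values to -1, so the
# truncation via floor-division can be done on absolute values and the sign re-applied.
import pandas as pd

s = pd.Series([-90, 0, 150])
sign = 2 * (s > 0).astype("int64") - 1
print((s.abs() // 60 * sign).tolist())  # [-1, 0, 2]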
def total_seconds(self) -> PandasLikeSeries:
s = self._get_total_seconds()
# this calculates the sign of each series element
s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
s_abs = s.abs() // 1
if ~s.isna().any():
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self.with_native(s_abs * s_sign)
def total_milliseconds(self) -> PandasLikeSeries:
s = self._get_total_seconds() * MS_PER_SECOND
# this calculates the sign of each series element
s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
s_abs = s.abs() // 1
if ~s.isna().any():
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self.with_native(s_abs * s_sign)
def total_microseconds(self) -> PandasLikeSeries:
s = self._get_total_seconds() * US_PER_SECOND
# this calculates the sign of each series element
s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
s_abs = s.abs() // 1
if ~s.isna().any():
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self.with_native(s_abs * s_sign)
def total_nanoseconds(self) -> PandasLikeSeries:
s = self._get_total_seconds() * NS_PER_SECOND
# this calculates the sign of each series element
s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
s_abs = s.abs() // 1
if ~s.isna().any():
s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
return self.with_native(s_abs * s_sign)
def to_string(self, format: str) -> PandasLikeSeries:
# Polars' parser treats `'%.f'` as pandas does `'.%f'`
# PyArrow interprets `'%S'` as "seconds, plus fractional seconds"
# and doesn't support `%f`
if not self._is_pyarrow():
format = format.replace("%S%.f", "%S.%f")
else:
format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
return self.with_native(self.native.dt.strftime(format))
def replace_time_zone(self, time_zone: str | None) -> PandasLikeSeries:
de_zone = self.native.dt.tz_localize(None)
result = de_zone.dt.tz_localize(time_zone) if time_zone is not None else de_zone
return self.with_native(result)
def convert_time_zone(self, time_zone: str) -> PandasLikeSeries:
if self.compliant.dtype.time_zone is None: # type: ignore[attr-defined]
result = self.native.dt.tz_localize("UTC").dt.tz_convert(time_zone)
else:
result = self.native.dt.tz_convert(time_zone)
return self.with_native(result)
def timestamp(self, time_unit: TimeUnit) -> PandasLikeSeries:
s = self.native
dtype = self.compliant.dtype
mask_na = s.isna()
dtypes = self.version.dtypes
if dtype == dtypes.Date:
# Date is only supported in pandas dtypes if pyarrow-backed
s_cast = s.astype("Int32[pyarrow]")
result = calculate_timestamp_date(s_cast, time_unit)
elif isinstance(dtype, dtypes.Datetime):
fn = (
s.view
if (self.implementation.is_pandas() and self.backend_version < (2,))
else s.astype
)
s_cast = fn("Int64[pyarrow]") if self._is_pyarrow() else fn("int64")
result = calculate_timestamp_datetime(s_cast, dtype.time_unit, time_unit)
else:
msg = "Input should be either Date or Datetime type"
raise TypeError(msg)
result[mask_na] = None
return self.with_native(result)
def truncate(self, every: str) -> PandasLikeSeries:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
native = self.native
if self.implementation.is_cudf():
if multiple != 1:
msg = f"Only multiple `1` is supported for cuDF, got: {multiple}."
raise NotImplementedError(msg)
return self.with_native(self.native.dt.floor(ALIAS_DICT.get(unit, unit)))
dtype_backend = get_dtype_backend(native.dtype, self.compliant._implementation)
if unit in {"mo", "q", "y"}:
if self.implementation.is_cudf():
msg = f"Truncating to {unit} is not supported yet for cuDF."
raise NotImplementedError(msg)
if dtype_backend == "pyarrow":
import pyarrow.compute as pc # ignore-banned-import
ca = native.array._pa_array
result_arr = pc.floor_temporal(ca, multiple, UNITS_DICT[unit])
else:
if unit == "q":
multiple *= 3
np_unit = "M"
elif unit == "mo":
np_unit = "M"
else:
np_unit = "Y"
arr = native.values # noqa: PD011
arr_dtype = arr.dtype
result_arr = arr.astype(f"datetime64[{multiple}{np_unit}]").astype(
arr_dtype
)
result_native = type(native)(
result_arr, dtype=native.dtype, index=native.index, name=native.name
)
return self.with_native(result_native)
return self.with_native(
self.native.dt.floor(f"{multiple}{ALIAS_DICT.get(unit, unit)}")
)
def offset_by(self, by: str) -> PandasLikeSeries:
native = self.native
pdx = self.compliant.__native_namespace__()
if self._is_pyarrow():
import pyarrow as pa # ignore-banned-import
compliant = self.compliant
ca = pa.chunked_array([compliant.to_arrow()]) # type: ignore[arg-type]
result = (
compliant._version.namespace.from_backend("pyarrow")
.compliant.from_native(ca)
.dt.offset_by(by)
.native
)
result_pd = native.__class__(
result, dtype=native.dtype, index=native.index, name=native.name
)
else:
interval = Interval.parse_no_constraints(by)
multiple, unit = interval.multiple, interval.unit
if unit == "q":
multiple *= 3
unit = "mo"
offset: pd.DateOffset | timedelta
if unit == "y":
offset = pdx.DateOffset(years=multiple)
elif unit == "mo":
offset = pdx.DateOffset(months=multiple)
elif unit == "ns":
offset = pdx.Timedelta(multiple, unit=UNITS_DICT[unit])
else:
offset = interval.to_timedelta()
dtype = self.compliant.dtype
datetime_dtype = self.version.dtypes.Datetime
if unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
native_without_timezone = native.dt.tz_localize(None)
result_pd = native_without_timezone + offset
result_pd = result_pd.dt.tz_localize(dtype.time_zone)
else:
result_pd = native + offset
return self.with_native(result_pd)

View File

@ -0,0 +1,42 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant.any_namespace import ListNamespace
from narwhals._pandas_like.utils import (
PandasLikeSeriesNamespace,
get_dtype_backend,
narwhals_to_native_dtype,
)
from narwhals._utils import not_implemented
if TYPE_CHECKING:
from narwhals._pandas_like.series import PandasLikeSeries
class PandasLikeSeriesListNamespace(
PandasLikeSeriesNamespace, ListNamespace["PandasLikeSeries"]
):
def len(self) -> PandasLikeSeries:
result = self.native.list.len()
implementation = self.implementation
backend_version = self.backend_version
if implementation.is_pandas() and backend_version < (3, 0): # pragma: no cover
# `result` is a new object so it's safe to do this inplace.
result.index = self.native.index
dtype = narwhals_to_native_dtype(
self.version.dtypes.UInt32(),
get_dtype_backend(result.dtype, implementation),
implementation,
self.version,
)
return self.with_native(result.astype(dtype)).alias(self.native.name)
unique = not_implemented()
contains = not_implemented()
def get(self, index: int) -> PandasLikeSeries:
result = self.native.list[index]
result.name = self.native.name
return self.with_native(result)

View File

@ -0,0 +1,92 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from narwhals._compliant.any_namespace import StringNamespace
from narwhals._pandas_like.utils import PandasLikeSeriesNamespace, is_dtype_pyarrow
if TYPE_CHECKING:
from narwhals._pandas_like.series import PandasLikeSeries
class PandasLikeSeriesStringNamespace(
PandasLikeSeriesNamespace, StringNamespace["PandasLikeSeries"]
):
def len_chars(self) -> PandasLikeSeries:
return self.with_native(self.native.str.len())
def replace(
self, pattern: str, value: str, *, literal: bool, n: int
) -> PandasLikeSeries:
try:
series = self.native.str.replace(
pat=pattern, repl=value, n=n, regex=not literal
)
except TypeError as e:
if not isinstance(value, str):
msg = f"{self.compliant._implementation} backed `.str.replace` only supports str replacement values"
raise TypeError(msg) from e
raise
return self.with_native(series)
def replace_all(self, pattern: str, value: str, *, literal: bool) -> PandasLikeSeries:
return self.replace(pattern, value, literal=literal, n=-1)
def strip_chars(self, characters: str | None) -> PandasLikeSeries:
return self.with_native(self.native.str.strip(characters))
def starts_with(self, prefix: str) -> PandasLikeSeries:
return self.with_native(self.native.str.startswith(prefix))
def ends_with(self, suffix: str) -> PandasLikeSeries:
return self.with_native(self.native.str.endswith(suffix))
def contains(self, pattern: str, *, literal: bool) -> PandasLikeSeries:
return self.with_native(self.native.str.contains(pat=pattern, regex=not literal))
def slice(self, offset: int, length: int | None) -> PandasLikeSeries:
stop = offset + length if length else None
return self.with_native(self.native.str.slice(start=offset, stop=stop))
def split(self, by: str) -> PandasLikeSeries:
implementation = self.implementation
if not implementation.is_cudf() and not is_dtype_pyarrow(self.native.dtype):
msg = (
"This operation requires a pyarrow-backed series. "
"Please refer to https://narwhals-dev.github.io/narwhals/api-reference/narwhals/#narwhals.maybe_convert_dtypes "
"and ensure you are using dtype_backend='pyarrow'. "
"Additionally, make sure you have pandas version 1.5+ and pyarrow installed. "
)
raise TypeError(msg)
return self.with_native(self.native.str.split(pat=by))
def to_datetime(self, format: str | None) -> PandasLikeSeries:
# If we know inputs are timezone-aware, we can pass `utc=True` for better performance.
if format and any(x in format for x in ("%z", "Z")):
return self.with_native(self._to_datetime(format, utc=True))
result = self.with_native(self._to_datetime(format, utc=False))
if (tz := getattr(result.dtype, "time_zone", None)) and tz != "UTC":
return result.dt.convert_time_zone("UTC")
return result
def _to_datetime(self, format: str | None, *, utc: bool) -> Any:
result = self.implementation.to_native_namespace().to_datetime(
self.native, format=format, utc=utc
)
return (
result.convert_dtypes(dtype_backend="pyarrow")
if is_dtype_pyarrow(self.native.dtype)
else result
)
def to_date(self, format: str | None) -> PandasLikeSeries:
return self.to_datetime(format=format).dt.date()
def to_uppercase(self) -> PandasLikeSeries:
return self.with_native(self.native.str.upper())
def to_lowercase(self) -> PandasLikeSeries:
return self.with_native(self.native.str.lower())
def zfill(self, width: int) -> PandasLikeSeries:
return self.with_native(self.native.str.zfill(width))

View File

@ -0,0 +1,16 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant.any_namespace import StructNamespace
from narwhals._pandas_like.utils import PandasLikeSeriesNamespace
if TYPE_CHECKING:
from narwhals._pandas_like.series import PandasLikeSeries
class PandasLikeSeriesStructNamespace(
PandasLikeSeriesNamespace, StructNamespace["PandasLikeSeries"]
):
def field(self, name: str) -> PandasLikeSeries:
return self.with_native(self.native.struct.field(name)).alias(name)

View File

@ -0,0 +1,43 @@
from __future__ import annotations # pragma: no cover
from typing import TYPE_CHECKING # pragma: no cover
from narwhals._typing_compat import TypeVar
if TYPE_CHECKING:
from typing import Any
import pandas as pd
from typing_extensions import TypeAlias
from narwhals._namespace import (
_CuDFDataFrame,
_CuDFSeries,
_ModinDataFrame,
_ModinSeries,
_NativePandasLikeDataFrame,
)
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.series import PandasLikeSeries
IntoPandasLikeExpr: TypeAlias = "PandasLikeExpr | PandasLikeSeries"
NativeSeriesT = TypeVar(
"NativeSeriesT",
"pd.Series[Any]",
"_CuDFSeries",
"_ModinSeries",
default="pd.Series[Any]",
)
NativeDataFrameT = TypeVar(
"NativeDataFrameT", bound="_NativePandasLikeDataFrame", default="pd.DataFrame"
)
NativeNDFrameT = TypeVar(
"NativeNDFrameT",
"pd.DataFrame",
"pd.Series[Any]",
"_CuDFDataFrame",
"_CuDFSeries",
"_ModinDataFrame",
"_ModinSeries",
)

View File

@ -0,0 +1,668 @@
from __future__ import annotations
import functools
import operator
import re
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar
import pandas as pd
from narwhals._compliant import EagerSeriesNamespace
from narwhals._constants import (
MS_PER_SECOND,
NS_PER_MICROSECOND,
NS_PER_MILLISECOND,
NS_PER_SECOND,
SECONDS_PER_DAY,
US_PER_SECOND,
)
from narwhals._utils import (
Implementation,
Version,
_DeferredIterable,
check_columns_exist,
isinstance_or_issubclass,
)
from narwhals.exceptions import ShapeError
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping
from types import ModuleType
from pandas._typing import Dtype as PandasDtype
from pandas.core.dtypes.dtypes import BaseMaskedDtype
from typing_extensions import TypeAlias, TypeIs
from narwhals._duration import IntervalUnit
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.typing import (
NativeDataFrameT,
NativeNDFrameT,
NativeSeriesT,
)
from narwhals.dtypes import DType
from narwhals.typing import DTypeBackend, IntoDType, TimeUnit, _1DArray
ExprT = TypeVar("ExprT", bound=PandasLikeExpr)
UnitCurrent: TypeAlias = TimeUnit
UnitTarget: TypeAlias = TimeUnit
BinOpBroadcast: TypeAlias = Callable[[Any, int], Any]
IntoRhs: TypeAlias = int
PANDAS_LIKE_IMPLEMENTATION = {
Implementation.PANDAS,
Implementation.CUDF,
Implementation.MODIN,
}
PD_DATETIME_RGX = r"""^
datetime64\[
(?P<time_unit>s|ms|us|ns) # Match time unit: s, ms, us, or ns
(?:, # Begin non-capturing group for optional timezone
\s* # Optional whitespace after comma
(?P<time_zone> # Start named group for timezone
[a-zA-Z\/]+ # Match timezone name, e.g., UTC, America/New_York
(?:[+-]\d{2}:\d{2})? # Optional offset in format +HH:MM or -HH:MM
| # OR
pytz\.FixedOffset\(\d+\) # Match pytz.FixedOffset with integer offset in parentheses
) # End time_zone group
)? # End optional timezone group
\] # Closing bracket for datetime64
$"""
PATTERN_PD_DATETIME = re.compile(PD_DATETIME_RGX, re.VERBOSE)
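# e.g. "datetime64[ns, Europe/Amsterdam]" matches with time_unit="ns" and time_zone="Europe/Amsterdam".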
PA_DATETIME_RGX = r"""^
timestamp\[
(?P<time_unit>s|ms|us|ns) # Match time unit: s, ms, us, or ns
(?:, # Begin non-capturing group for optional timezone
\s?tz= # Match "tz=" prefix
(?P<time_zone> # Start named group for timezone
[a-zA-Z\/]* # Match timezone name (e.g., UTC, America/New_York)
(?: # Begin optional non-capturing group for offset
[+-]\d{2}:\d{2} # Match offset in format +HH:MM or -HH:MM
)? # End optional offset group
) # End time_zone group
)? # End optional timezone group
\] # Closing bracket for timestamp
\[pyarrow\] # Literal string "[pyarrow]"
$"""
PATTERN_PA_DATETIME = re.compile(PA_DATETIME_RGX, re.VERBOSE)
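# e.g. "timestamp[us, tz=UTC][pyarrow]" matches with time_unit="us" and time_zone="UTC".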
PD_DURATION_RGX = r"""^
timedelta64\[
(?P<time_unit>s|ms|us|ns) # Match time unit: s, ms, us, or ns
\] # Closing bracket for timedelta64
$"""
PATTERN_PD_DURATION = re.compile(PD_DURATION_RGX, re.VERBOSE)
PA_DURATION_RGX = r"""^
duration\[
(?P<time_unit>s|ms|us|ns) # Match time unit: s, ms, us, or ns
\] # Closing bracket for duration
\[pyarrow\] # Literal string "[pyarrow]"
$"""
PATTERN_PA_DURATION = re.compile(PA_DURATION_RGX, re.VERBOSE)
NativeIntervalUnit: TypeAlias = Literal[
"year",
"quarter",
"month",
"week",
"day",
"hour",
"minute",
"second",
"millisecond",
"microsecond",
"nanosecond",
]
ALIAS_DICT = {"d": "D", "m": "min"}
UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
"y": "year",
"q": "quarter",
"mo": "month",
"d": "day",
"h": "hour",
"m": "minute",
"s": "second",
"ms": "millisecond",
"us": "microsecond",
"ns": "nanosecond",
}
PANDAS_VERSION = Implementation.PANDAS._backend_version()
"""Static backend version for `pandas`.
Always available if we reached here, due to a module-level import.
"""
def is_pandas_or_modin(implementation: Implementation) -> bool:
return implementation in {Implementation.PANDAS, Implementation.MODIN}
def align_and_extract_native(
lhs: PandasLikeSeries, rhs: PandasLikeSeries | object
) -> tuple[pd.Series[Any] | object, pd.Series[Any] | object]:
"""Validate RHS of binary operation.
If the comparison isn't supported, return `NotImplemented` so that the
"right-hand-side" operation (e.g. `__radd__`) can be tried.
"""
from narwhals._pandas_like.series import PandasLikeSeries
lhs_index = lhs.native.index
if lhs._broadcast and isinstance(rhs, PandasLikeSeries) and not rhs._broadcast:
return lhs.native.iloc[0], rhs.native
if isinstance(rhs, PandasLikeSeries):
if rhs._broadcast:
return (lhs.native, rhs.native.iloc[0])
if rhs.native.index is not lhs_index:
return (
lhs.native,
set_index(rhs.native, lhs_index, implementation=rhs._implementation),
)
return (lhs.native, rhs.native)
if isinstance(rhs, list):
msg = "Expected Series or scalar, got list."
raise TypeError(msg)
# `rhs` must be scalar, so just leave it as-is
return lhs.native, rhs
def set_index(
obj: NativeNDFrameT, index: Any, *, implementation: Implementation
) -> NativeNDFrameT:
"""Wrapper around pandas' set_axis to set object index.
We can set `copy` / `inplace` based on implementation/version.
"""
if isinstance(index, implementation.to_native_namespace().Index) and (
expected_len := len(index)
) != (actual_len := len(obj)):
msg = f"Expected object of length {expected_len}, got length: {actual_len}"
raise ShapeError(msg)
if implementation is Implementation.CUDF:
obj = obj.copy(deep=False)
obj.index = index
return obj
if implementation is Implementation.PANDAS and (
(1, 5) <= implementation._backend_version() < (3,)
): # pragma: no cover
return obj.set_axis(index, axis=0, copy=False)
return obj.set_axis(index, axis=0) # pragma: no cover
def rename(
obj: NativeNDFrameT, *args: Any, implementation: Implementation, **kwargs: Any
) -> NativeNDFrameT:
"""Wrapper around pandas' rename so that we can set `copy` based on implementation/version."""
if implementation is Implementation.PANDAS and (
implementation._backend_version() >= (3,)
): # pragma: no cover
return obj.rename(*args, **kwargs, inplace=False)
return obj.rename(*args, **kwargs, copy=False, inplace=False)
@functools.lru_cache(maxsize=16)
def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType: # noqa: C901, PLR0912
dtype = str(native_dtype)
dtypes = version.dtypes
if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}:
return dtypes.Int64()
if dtype in {"int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"}:
return dtypes.Int32()
if dtype in {"int16", "Int16", "Int16[pyarrow]", "int16[pyarrow]"}:
return dtypes.Int16()
if dtype in {"int8", "Int8", "Int8[pyarrow]", "int8[pyarrow]"}:
return dtypes.Int8()
if dtype in {"uint64", "UInt64", "UInt64[pyarrow]", "uint64[pyarrow]"}:
return dtypes.UInt64()
if dtype in {"uint32", "UInt32", "UInt32[pyarrow]", "uint32[pyarrow]"}:
return dtypes.UInt32()
if dtype in {"uint16", "UInt16", "UInt16[pyarrow]", "uint16[pyarrow]"}:
return dtypes.UInt16()
if dtype in {"uint8", "UInt8", "UInt8[pyarrow]", "uint8[pyarrow]"}:
return dtypes.UInt8()
if dtype in {
"float64",
"Float64",
"Float64[pyarrow]",
"float64[pyarrow]",
"double[pyarrow]",
}:
return dtypes.Float64()
if dtype in {
"float32",
"Float32",
"Float32[pyarrow]",
"float32[pyarrow]",
"float[pyarrow]",
}:
return dtypes.Float32()
if dtype in {
# "there is no problem which can't be solved by adding an extra string type" pandas
"string",
"string[python]",
"string[pyarrow]",
"string[pyarrow_numpy]",
"large_string[pyarrow]",
"str",
}:
return dtypes.String()
if dtype in {"bool", "boolean", "boolean[pyarrow]", "bool[pyarrow]"}:
return dtypes.Boolean()
if dtype.startswith("dictionary<"):
return dtypes.Categorical()
if dtype == "category":
return native_categorical_to_narwhals_dtype(native_dtype, version)
if (match_ := PATTERN_PD_DATETIME.match(dtype)) or (
match_ := PATTERN_PA_DATETIME.match(dtype)
):
dt_time_unit: TimeUnit = match_.group("time_unit") # type: ignore[assignment]
dt_time_zone: str | None = match_.group("time_zone")
return dtypes.Datetime(dt_time_unit, dt_time_zone)
if (match_ := PATTERN_PD_DURATION.match(dtype)) or (
match_ := PATTERN_PA_DURATION.match(dtype)
):
du_time_unit: TimeUnit = match_.group("time_unit") # type: ignore[assignment]
return dtypes.Duration(du_time_unit)
if dtype == "date32[day][pyarrow]":
return dtypes.Date()
if dtype.startswith("decimal") and dtype.endswith("[pyarrow]"):
return dtypes.Decimal()
if dtype.startswith("time") and dtype.endswith("[pyarrow]"):
return dtypes.Time()
if dtype.startswith("binary") and dtype.endswith("[pyarrow]"):
return dtypes.Binary()
return dtypes.Unknown() # pragma: no cover
def object_native_to_narwhals_dtype(
series: PandasLikeSeries | None, version: Version, implementation: Implementation
) -> DType:
dtypes = version.dtypes
if implementation is Implementation.CUDF:
# Per conversations with their maintainers, they don't support arbitrary
# objects, so we can just return String.
return dtypes.String()
infer = pd.api.types.infer_dtype
# Arbitrary limit of 100 elements used to sniff the dtype.
inferred_dtype = "empty" if series is None else infer(series.head(100), skipna=True)
if inferred_dtype == "string":
return dtypes.String()
if inferred_dtype == "empty" and version is not Version.V1:
# Default to String for empty Series.
return dtypes.String()
if inferred_dtype == "empty":
# But preserve returning Object in V1.
return dtypes.Object()
return dtypes.Object()
def native_categorical_to_narwhals_dtype(
native_dtype: pd.CategoricalDtype,
version: Version,
implementation: Literal[Implementation.CUDF] | None = None,
) -> DType:
dtypes = version.dtypes
if version is Version.V1:
return dtypes.Categorical()
if native_dtype.ordered:
into_iter = (
_cudf_categorical_to_list(native_dtype)
if implementation is Implementation.CUDF
else native_dtype.categories.to_list
)
return dtypes.Enum(_DeferredIterable(into_iter))
return dtypes.Categorical()
def _cudf_categorical_to_list(
native_dtype: Any,
) -> Callable[[], list[Any]]: # pragma: no cover
# NOTE: https://docs.rapids.ai/api/cudf/stable/user_guide/api_docs/api/cudf.core.dtypes.categoricaldtype/#cudf.core.dtypes.CategoricalDtype
def fn() -> list[Any]:
return native_dtype.categories.to_arrow().to_pylist()
return fn
def native_to_narwhals_dtype(
native_dtype: Any,
version: Version,
implementation: Implementation,
*,
allow_object: bool = False,
) -> DType:
str_dtype = str(native_dtype)
if str_dtype.startswith(("large_list", "list", "struct", "fixed_size_list")):
from narwhals._arrow.utils import (
native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
)
if hasattr(native_dtype, "to_arrow"): # pragma: no cover
# cudf, cudf.pandas
return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version)
return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version)
if str_dtype == "category" and implementation.is_cudf():
# https://github.com/rapidsai/cudf/issues/18536
# https://github.com/rapidsai/cudf/issues/14027
return native_categorical_to_narwhals_dtype(
native_dtype, version, Implementation.CUDF
)
if str_dtype != "object":
return non_object_native_to_narwhals_dtype(native_dtype, version)
if implementation is Implementation.DASK:
# Per conversations with their maintainers, they don't support arbitrary
# objects, so we can just return String.
return version.dtypes.String()
if allow_object:
return object_native_to_narwhals_dtype(None, version, implementation)
msg = (
"Unreachable code, object dtype should be handled separately" # pragma: no cover
)
raise AssertionError(msg)
if Implementation.PANDAS._backend_version() >= (1, 2):
def is_dtype_numpy_nullable(dtype: Any) -> TypeIs[BaseMaskedDtype]:
"""Return `True` if `dtype` is `"numpy_nullable"`."""
# NOTE: We need a sentinel as the positive case is `BaseMaskedDtype.base = None`
# See https://github.com/narwhals-dev/narwhals/pull/2740#discussion_r2171667055
sentinel = object()
return (
isinstance(dtype, pd.api.extensions.ExtensionDtype)
and getattr(dtype, "base", sentinel) is None
)
else: # pragma: no cover
def is_dtype_numpy_nullable(dtype: Any) -> TypeIs[BaseMaskedDtype]:
# NOTE: `base` attribute was added between 1.1-1.2
# Checking by isinstance requires using an import path that is no longer valid
# `1.1`: https://github.com/pandas-dev/pandas/blob/b5958ee1999e9aead1938c0bba2b674378807b3d/pandas/core/arrays/masked.py#L37
# `1.2`: https://github.com/pandas-dev/pandas/blob/7c48ff4409c622c582c56a5702373f726de08e96/pandas/core/arrays/masked.py#L41
# `1.5`: https://github.com/pandas-dev/pandas/blob/35b0d1dcadf9d60722c055ee37442dc76a29e64c/pandas/core/dtypes/dtypes.py#L1609
if isinstance(dtype, pd.api.extensions.ExtensionDtype):
from pandas.core.arrays.masked import ( # type: ignore[attr-defined]
BaseMaskedDtype as OldBaseMaskedDtype, # pyright: ignore[reportAttributeAccessIssue]
)
return isinstance(dtype, OldBaseMaskedDtype)
return False
def get_dtype_backend(dtype: Any, implementation: Implementation) -> DTypeBackend:
"""Get dtype backend for pandas type.
Matches pandas' `dtype_backend` argument in `convert_dtypes`.
"""
if implementation is Implementation.CUDF:
return None
if is_dtype_pyarrow(dtype):
return "pyarrow"
return "numpy_nullable" if is_dtype_numpy_nullable(dtype) else None
# NOTE: Use this to avoid annotating inline
def iter_dtype_backends(
dtypes: Iterable[Any], implementation: Implementation
) -> Iterator[DTypeBackend]:
"""Yield a `DTypeBackend` per-dtype.
Matches pandas' `dtype_backend` argument in `convert_dtypes`.
"""
return (get_dtype_backend(dtype, implementation) for dtype in dtypes)
@functools.lru_cache(maxsize=16)
def is_dtype_pyarrow(dtype: Any) -> TypeIs[pd.ArrowDtype]:
return hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype)
dtypes = Version.MAIN.dtypes
NW_TO_PD_DTYPES_INVARIANT: Mapping[type[DType], str] = {
# TODO(Unassigned): is there no pyarrow-backed categorical?
# or at least, convert_dtypes(dtype_backend='pyarrow') doesn't
# convert to it?
dtypes.Categorical: "category",
dtypes.Object: "object",
}
NW_TO_PD_DTYPES_BACKEND: Mapping[type[DType], Mapping[DTypeBackend, str | type[Any]]] = {
dtypes.Float64: {
"pyarrow": "Float64[pyarrow]",
"numpy_nullable": "Float64",
None: "float64",
},
dtypes.Float32: {
"pyarrow": "Float32[pyarrow]",
"numpy_nullable": "Float32",
None: "float32",
},
dtypes.Int64: {"pyarrow": "Int64[pyarrow]", "numpy_nullable": "Int64", None: "int64"},
dtypes.Int32: {"pyarrow": "Int32[pyarrow]", "numpy_nullable": "Int32", None: "int32"},
dtypes.Int16: {"pyarrow": "Int16[pyarrow]", "numpy_nullable": "Int16", None: "int16"},
dtypes.Int8: {"pyarrow": "Int8[pyarrow]", "numpy_nullable": "Int8", None: "int8"},
dtypes.UInt64: {
"pyarrow": "UInt64[pyarrow]",
"numpy_nullable": "UInt64",
None: "uint64",
},
dtypes.UInt32: {
"pyarrow": "UInt32[pyarrow]",
"numpy_nullable": "UInt32",
None: "uint32",
},
dtypes.UInt16: {
"pyarrow": "UInt16[pyarrow]",
"numpy_nullable": "UInt16",
None: "uint16",
},
dtypes.UInt8: {"pyarrow": "UInt8[pyarrow]", "numpy_nullable": "UInt8", None: "uint8"},
dtypes.String: {"pyarrow": "string[pyarrow]", "numpy_nullable": "string", None: str},
dtypes.Boolean: {
"pyarrow": "boolean[pyarrow]",
"numpy_nullable": "boolean",
None: "bool",
},
}
UNSUPPORTED_DTYPES = (dtypes.Decimal,)
def narwhals_to_native_dtype( # noqa: C901, PLR0912
dtype: IntoDType,
dtype_backend: DTypeBackend,
implementation: Implementation,
version: Version,
) -> str | PandasDtype:
if dtype_backend not in {None, "pyarrow", "numpy_nullable"}:
msg = f"Expected one of {{None, 'pyarrow', 'numpy_nullable'}}, got: '{dtype_backend}'"
raise ValueError(msg)
dtypes = version.dtypes
base_type = dtype.base_type()
if pd_type := NW_TO_PD_DTYPES_INVARIANT.get(base_type):
return pd_type
if into_pd_type := NW_TO_PD_DTYPES_BACKEND.get(base_type):
return into_pd_type[dtype_backend]
if isinstance_or_issubclass(dtype, dtypes.Datetime):
# Pandas does not support "ms" or "us" time units before version 2.0
if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
2,
): # pragma: no cover
dt_time_unit = "ns"
else:
dt_time_unit = dtype.time_unit
if dtype_backend == "pyarrow":
tz_part = f", tz={tz}" if (tz := dtype.time_zone) else ""
return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]"
tz_part = f", {tz}" if (tz := dtype.time_zone) else ""
return f"datetime64[{dt_time_unit}{tz_part}]"
if isinstance_or_issubclass(dtype, dtypes.Duration):
if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
2,
): # pragma: no cover
du_time_unit = "ns"
else:
du_time_unit = dtype.time_unit
return (
f"duration[{du_time_unit}][pyarrow]"
if dtype_backend == "pyarrow"
else f"timedelta64[{du_time_unit}]"
)
if isinstance_or_issubclass(dtype, dtypes.Date):
try:
import pyarrow as pa # ignore-banned-import # noqa: F401
except ModuleNotFoundError as exc: # pragma: no cover
# BUG: Never re-raised?
msg = "'pyarrow>=13.0.0' is required for `Date` dtype."
raise ModuleNotFoundError(msg) from exc
return "date32[pyarrow]"
if isinstance_or_issubclass(dtype, dtypes.Enum):
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
raise NotImplementedError(msg)
if isinstance(dtype, dtypes.Enum):
ns = implementation.to_native_namespace()
return ns.CategoricalDtype(dtype.categories, ordered=True)
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)
if issubclass(
base_type, (dtypes.Struct, dtypes.Array, dtypes.List, dtypes.Time, dtypes.Binary)
):
return narwhals_to_native_arrow_dtype(dtype, implementation, version)
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for {implementation}."
raise NotImplementedError(msg)
msg = f"Unknown dtype: {dtype}" # pragma: no cover
raise AssertionError(msg)
def narwhals_to_native_arrow_dtype(
dtype: IntoDType, implementation: Implementation, version: Version
) -> pd.ArrowDtype:
if is_pandas_or_modin(implementation) and PANDAS_VERSION >= (2, 2):
try:
import pyarrow as pa # ignore-banned-import # noqa: F401
except ImportError as exc: # pragma: no cover
msg = f"Unable to convert to {dtype} to to the following exception: {exc.msg}"
raise ImportError(msg) from exc
from narwhals._arrow.utils import narwhals_to_native_dtype as _to_arrow_dtype
return pd.ArrowDtype(_to_arrow_dtype(dtype, version))
msg = ( # pragma: no cover
f"Converting to {dtype} dtype is not supported for implementation "
f"{implementation} and version {version}."
)
raise NotImplementedError(msg)
def int_dtype_mapper(dtype: Any) -> str:
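# e.g. "int32[pyarrow]" -> "Int64[pyarrow]", nullable "Int32" -> "Int64", plain "int32" -> "int64".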
if "pyarrow" in str(dtype):
return "Int64[pyarrow]"
if str(dtype).lower() != str(dtype): # pragma: no cover
return "Int64"
return "int64"
_TIMESTAMP_DATETIME_OP_FACTOR: Mapping[
tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]
] = {
("ns", "us"): (operator.floordiv, 1_000),
("ns", "ms"): (operator.floordiv, 1_000_000),
("us", "ns"): (operator.mul, NS_PER_MICROSECOND),
("us", "ms"): (operator.floordiv, 1_000),
("ms", "ns"): (operator.mul, NS_PER_MILLISECOND),
("ms", "us"): (operator.mul, 1_000),
("s", "ns"): (operator.mul, NS_PER_SECOND),
("s", "us"): (operator.mul, US_PER_SECOND),
("s", "ms"): (operator.mul, MS_PER_SECOND),
}
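# e.g. "ns" -> "ms" floor-divides the underlying integers by 1_000_000, while "s" -> "ns"
# multiplies by NS_PER_SECOND; equal units fall through to the early return below.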
def calculate_timestamp_datetime(
s: NativeSeriesT, current: TimeUnit, time_unit: TimeUnit
) -> NativeSeriesT:
if current == time_unit:
return s
if item := _TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)):
fn, factor = item
return fn(s, factor)
msg = ( # pragma: no cover
f"unexpected time unit {current}, please report an issue at "
"https://github.com/narwhals-dev/narwhals"
)
raise AssertionError(msg)
_TIMESTAMP_DATE_FACTOR: Mapping[TimeUnit, int] = {
"ns": NS_PER_SECOND,
"us": US_PER_SECOND,
"ms": MS_PER_SECOND,
"s": 1,
}
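# e.g. a date stored as days since the epoch, requested with time_unit="us", becomes
# days * SECONDS_PER_DAY * US_PER_SECOND microseconds.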
def calculate_timestamp_date(s: NativeSeriesT, time_unit: TimeUnit) -> NativeSeriesT:
return s * SECONDS_PER_DAY * _TIMESTAMP_DATE_FACTOR[time_unit]
def select_columns_by_name(
df: NativeDataFrameT,
column_names: list[str] | _1DArray, # NOTE: Cannot be a tuple!
implementation: Implementation,
) -> NativeDataFrameT | Any:
"""Select columns by name.
Prefer this over `df.loc[:, column_names]` as it's
generally more performant.
"""
if len(column_names) == df.shape[1] and (df.columns == column_names).all():
return df
if (df.columns.dtype.kind == "b") or (
implementation is Implementation.PANDAS
and implementation._backend_version() < (1, 5)
):
# See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122
# for why we need this
if error := check_columns_exist(column_names, available=df.columns.tolist()):
raise error
return df.loc[:, column_names]
try:
return df[column_names]
except KeyError as e:
if error := check_columns_exist(column_names, available=df.columns.tolist()):
raise error from e
raise
def is_non_nullable_boolean(s: PandasLikeSeries) -> bool:
# cuDF booleans are nullable but the native dtype is still 'bool'.
return (
s._implementation
in {Implementation.PANDAS, Implementation.MODIN, Implementation.DASK}
and s.native.dtype == "bool"
)
def import_array_module(implementation: Implementation, /) -> ModuleType:
"""Returns numpy or cupy module depending on the given implementation."""
if implementation in {Implementation.PANDAS, Implementation.MODIN}:
import numpy as np
return np
if implementation is Implementation.CUDF:
import cupy as cp # ignore-banned-import # cuDF dependency.
return cp
msg = f"Expected pandas/modin/cudf, got: {implementation}" # pragma: no cover
raise AssertionError(msg)
class PandasLikeSeriesNamespace(EagerSeriesNamespace["PandasLikeSeries", Any]): ...

View File

@ -0,0 +1,678 @@
from __future__ import annotations
from collections.abc import Iterator, Mapping, Sequence, Sized
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload
import polars as pl
from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import (
catch_polars_exception,
extract_args_kwargs,
native_to_narwhals_dtype,
)
from narwhals._utils import (
Implementation,
_into_arrow_table,
convert_str_slice_to_int_slice,
is_compliant_series,
is_index_selector,
is_range,
is_sequence_like,
is_slice_index,
is_slice_none,
parse_columns_to_drop,
requires,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ColumnNotFoundError
if TYPE_CHECKING:
from collections.abc import Iterable
from types import ModuleType
from typing import Callable
import pandas as pd
import pyarrow as pa
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy
from narwhals._spark_like.utils import SparkSession
from narwhals._translate import IntoArrowTable
from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
from narwhals._utils import Version, _LimitedContext
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dtypes import DType
from narwhals.typing import (
IntoSchema,
JoinStrategy,
MultiColSelector,
MultiIndexSelector,
PivotAgg,
SingleIndexSelector,
_2DArray,
)
T = TypeVar("T")
R = TypeVar("R")
Method: TypeAlias = "Callable[..., R]"
"""Generic alias representing all methods implemented via `__getattr__`.
Where `R` is the return type.
"""
# DataFrame methods where PolarsDataFrame just defers to Polars.DataFrame directly.
INHERITED_METHODS = frozenset(
[
"clone",
"drop_nulls",
"estimated_size",
"explode",
"filter",
"gather_every",
"head",
"is_unique",
"item",
"iter_rows",
"join_asof",
"rename",
"row",
"rows",
"sample",
"select",
"sink_parquet",
"sort",
"tail",
"to_arrow",
"to_pandas",
"unique",
"with_columns",
"write_csv",
"write_parquet",
]
)
NativePolarsFrame = TypeVar("NativePolarsFrame", pl.DataFrame, pl.LazyFrame)
class PolarsBaseFrame(Generic[NativePolarsFrame]):
drop_nulls: Method[Self]
explode: Method[Self]
filter: Method[Self]
gather_every: Method[Self]
head: Method[Self]
join_asof: Method[Self]
rename: Method[Self]
select: Method[Self]
sort: Method[Self]
tail: Method[Self]
unique: Method[Self]
with_columns: Method[Self]
_native_frame: NativePolarsFrame
_implementation = Implementation.POLARS
_version: Version
def __init__(
self,
df: NativePolarsFrame,
*,
version: Version,
validate_backend_version: bool = False,
) -> None:
self._native_frame = df
self._version = version
if validate_backend_version:
self._validate_backend_version()
def _validate_backend_version(self) -> None:
"""Raise if installed version below `nw._utils.MIN_VERSIONS`.
**Only use this when moving between backends.**
Otherwise, the validation will have taken place already.
"""
_ = self._implementation._backend_version()
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def native(self) -> NativePolarsFrame:
return self._native_frame
@property
def columns(self) -> list[str]:
return self.native.columns
def __narwhals_namespace__(self) -> PolarsNamespace:
return PolarsNamespace(version=self._version)
def __native_namespace__(self) -> ModuleType:
if self._implementation is Implementation.POLARS:
return self._implementation.to_native_namespace()
msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)
def _with_native(self, df: NativePolarsFrame) -> Self:
return self.__class__(df, version=self._version)
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version)
@classmethod
def from_native(cls, data: NativePolarsFrame, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version)
def simple_select(self, *column_names: str) -> Self:
return self._with_native(self.native.select(*column_names))
def aggregate(self, *exprs: Any) -> Self:
return self.select(*exprs)
@property
def schema(self) -> dict[str, DType]:
return self.collect_schema()
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
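# Polars only accepts how="full" from 0.20.29 onwards; older versions call the same strategy "outer".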
how_native = (
"outer" if (self._backend_version < (0, 20, 29) and how == "full") else how
)
return self._with_native(
self.native.join(
other=other.native,
how=how_native, # type: ignore[arg-type]
left_on=left_on,
right_on=right_on,
suffix=suffix,
)
)
def top_k(
self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
) -> Self:
if self._backend_version < (1, 0, 0):
return self._with_native(
self.native.top_k(
k=k,
by=by,
descending=reverse, # type: ignore[call-arg]
)
)
return self._with_native(self.native.top_k(k=k, by=by, reverse=reverse))
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
if self._backend_version < (1, 0, 0):
return self._with_native(
self.native.melt(
id_vars=index,
value_vars=on,
variable_name=variable_name,
value_name=value_name,
)
)
return self._with_native(
self.native.unpivot(
on=on, index=index, variable_name=variable_name, value_name=value_name
)
)
def collect_schema(self) -> dict[str, DType]:
df = self.native
schema = df.schema if self._backend_version < (1,) else df.collect_schema()
return {
name: native_to_narwhals_dtype(dtype, self._version)
for name, dtype in schema.items()
}
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
frame = self.native
if order_by is None:
result = frame.with_row_index(name)
else:
end = pl.count() if self._backend_version < (0, 20, 5) else pl.len()
result = frame.select(
pl.int_range(start=0, end=end).sort_by(order_by).alias(name), pl.all()
)
return self._with_native(result)
class PolarsDataFrame(PolarsBaseFrame[pl.DataFrame]):
clone: Method[Self]
collect: Method[CompliantDataFrameAny]
estimated_size: Method[int | float]
gather_every: Method[Self]
item: Method[Any]
iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
is_unique: Method[PolarsSeries]
row: Method[tuple[Any, ...]]
rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
sample: Method[Self]
to_arrow: Method[pa.Table]
to_pandas: Method[pd.DataFrame]
# NOTE: `write_csv` requires an `@overload` for `str | None`
# Can't do that here 😟
write_csv: Method[Any]
write_parquet: Method[None]
@classmethod
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
if context._implementation._backend_version() >= (1, 3):
native = pl.DataFrame(data)
else: # pragma: no cover
native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
return cls.from_native(native, context=context)
@classmethod
def from_dict(
cls,
data: Mapping[str, Any],
/,
*,
context: _LimitedContext,
schema: IntoSchema | None,
) -> Self:
from narwhals.schema import Schema
pl_schema = Schema(schema).to_polars() if schema is not None else schema
return cls.from_native(pl.from_dict(data, pl_schema), context=context)
@staticmethod
def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]:
return isinstance(obj, pl.DataFrame)
@classmethod
def from_numpy(
cls,
data: _2DArray,
/,
*,
context: _LimitedContext, # NOTE: Maybe only `Implementation`?
schema: IntoSchema | Sequence[str] | None,
) -> Self:
from narwhals.schema import Schema
pl_schema = (
Schema(schema).to_polars()
if isinstance(schema, (Mapping, Schema))
else schema
)
return cls.from_native(pl.from_numpy(data, pl_schema), context=context)
def to_narwhals(self) -> DataFrame[pl.DataFrame]:
return self._version.dataframe(self, level="full")
def __repr__(self) -> str: # pragma: no cover
return "PolarsDataFrame"
def __narwhals_dataframe__(self) -> Self:
return self
@overload
def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ...
@overload
def _from_native_object(self, obj: pl.DataFrame) -> Self: ...
@overload
def _from_native_object(self, obj: T) -> T: ...
def _from_native_object(
self, obj: pl.Series | pl.DataFrame | T
) -> Self | PolarsSeries | T:
if isinstance(obj, pl.Series):
return PolarsSeries.from_native(obj, context=self)
if self._is_native(obj):
return self._with_native(obj)
# scalar
return obj
def __len__(self) -> int:
return len(self.native)
def __getattr__(self, attr: str) -> Any:
if attr not in INHERITED_METHODS: # pragma: no cover
msg = f"{self.__class__.__name__} has not attribute '{attr}'."
raise AttributeError(msg)
def func(*args: Any, **kwargs: Any) -> Any:
pos, kwds = extract_args_kwargs(args, kwargs)
try:
return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover
msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
raise ColumnNotFoundError(msg) from e
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None
return func
def __array__(
self, dtype: Any | None = None, *, copy: bool | None = None
) -> _2DArray:
if self._backend_version < (0, 20, 28) and copy is not None:
msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'"
raise NotImplementedError(msg)
if self._backend_version < (0, 20, 28):
return self.native.__array__(dtype)
return self.native.__array__(dtype, copy=copy)
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
return self.native.to_numpy()
@property
def shape(self) -> tuple[int, int]:
return self.native.shape
def __getitem__( # noqa: C901, PLR0912
self,
item: tuple[
SingleIndexSelector | MultiIndexSelector[PolarsSeries],
MultiColSelector[PolarsSeries],
],
) -> Any:
rows, columns = item
if self._backend_version > (0, 20, 30):
rows_native = rows.native if is_compliant_series(rows) else rows
columns_native = columns.native if is_compliant_series(columns) else columns
selector = rows_native, columns_native
selected = self.native.__getitem__(selector) # type: ignore[index]
return self._from_native_object(selected)
else: # pragma: no cover # noqa: RET505
# TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
# Polars version we support
# This mostly mirrors the logic in `EagerDataFrame.__getitem__`.
rows = list(rows) if isinstance(rows, tuple) else rows
columns = list(columns) if isinstance(columns, tuple) else columns
if is_numpy_array_1d(columns):
columns = columns.tolist()
native = self.native
if not is_slice_none(columns):
if isinstance(columns, Sized) and len(columns) == 0:
return self.select()
if is_index_selector(columns):
if is_slice_index(columns) or is_range(columns):
native = native.select(
self.columns[slice(columns.start, columns.stop, columns.step)]
)
# NOTE: `mypy` loses track of `PolarsSeries` when `is_compliant_series` is used here
# `pyright` is fine
elif isinstance(columns, PolarsSeries):
native = native[:, columns.native.to_list()]
else:
native = native[:, columns]
elif isinstance(columns, slice):
native = native.select(
self.columns[
slice(*convert_str_slice_to_int_slice(columns, self.columns))
]
)
elif is_compliant_series(columns):
native = native.select(columns.native.to_list())
elif is_sequence_like(columns):
native = native.select(columns)
else:
msg = f"Unreachable code, got unexpected type: {type(columns)}"
raise AssertionError(msg)
if not is_slice_none(rows):
if isinstance(rows, int):
native = native[[rows], :]
elif isinstance(rows, (slice, range)):
native = native[rows, :]
elif is_compliant_series(rows):
native = native[rows.native, :]
elif is_sequence_like(rows):
native = native[rows, :]
else:
msg = f"Unreachable code, got unexpected type: {type(rows)}"
raise AssertionError(msg)
return self._with_native(native)
def get_column(self, name: str) -> PolarsSeries:
return PolarsSeries.from_native(self.native.get_column(name), context=self)
def iter_columns(self) -> Iterator[PolarsSeries]:
for series in self.native.iter_columns():
yield PolarsSeries.from_native(series, context=self)
def lazy(
self,
backend: _LazyAllowedImpl | None = None,
*,
session: SparkSession | None = None,
) -> CompliantLazyFrameAny:
if backend is None or backend is Implementation.POLARS:
return PolarsLazyFrame.from_native(self.native.lazy(), context=self)
if backend is Implementation.DUCKDB:
import duckdb # ignore-banned-import
from narwhals._duckdb.dataframe import DuckDBLazyFrame
_df = self.native
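# DuckDB's replacement scan resolves the local variable `_df` by name in `duckdb.table("_df")`.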
return DuckDBLazyFrame(
duckdb.table("_df"), validate_backend_version=True, version=self._version
)
if backend is Implementation.DASK:
import dask.dataframe as dd # ignore-banned-import
from narwhals._dask.dataframe import DaskLazyFrame
return DaskLazyFrame(
dd.from_pandas(self.native.to_pandas()),
validate_backend_version=True,
version=self._version,
)
if backend is Implementation.IBIS:
import ibis # ignore-banned-import
from narwhals._ibis.dataframe import IbisLazyFrame
return IbisLazyFrame(
ibis.memtable(self.native, columns=self.columns),
validate_backend_version=True,
version=self._version,
)
if backend.is_spark_like():
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
if session is None:
msg = "Spark like backends require `session` to be not None."
raise ValueError(msg)
return SparkLikeLazyFrame._from_compliant_dataframe(
self, # pyright: ignore[reportArgumentType]
session=session,
implementation=backend,
version=self._version,
)
raise AssertionError # pragma: no cover
@overload
def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...
@overload
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
def to_dict(
self, *, as_series: bool
) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
if as_series:
return {
name: PolarsSeries.from_native(col, context=self)
for name, col in self.native.to_dict().items()
}
return self.native.to_dict(as_series=False)
def group_by(
self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
) -> PolarsGroupBy:
from narwhals._polars.group_by import PolarsGroupBy
return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(to_drop))
@requires.backend_version((1,))
def pivot(
self,
on: Sequence[str],
*,
index: Sequence[str] | None,
values: Sequence[str] | None,
aggregate_function: PivotAgg | None,
sort_columns: bool,
separator: str,
) -> Self:
try:
result = self.native.pivot(
on,
index=index,
values=values,
aggregate_function=aggregate_function,
sort_columns=sort_columns,
separator=separator,
)
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None
return self._from_native_object(result)
def to_polars(self) -> pl.DataFrame:
return self.native
def join(
self,
other: Self,
*,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
try:
return super().join(
other=other, how=how, left_on=left_on, right_on=right_on, suffix=suffix
)
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None
def top_k(
self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
) -> Self:
try:
return super().top_k(k=k, by=by, reverse=reverse)
except Exception as e: # noqa: BLE001 # pragma: no cover
raise catch_polars_exception(e) from None
class PolarsLazyFrame(PolarsBaseFrame[pl.LazyFrame]):
sink_parquet: Method[None]
@staticmethod
def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]:
return isinstance(obj, pl.LazyFrame)
def to_narwhals(self) -> LazyFrame[pl.LazyFrame]:
return self._version.lazyframe(self, level="lazy")
def __repr__(self) -> str: # pragma: no cover
return "PolarsLazyFrame"
def __narwhals_lazyframe__(self) -> Self:
return self
def __getattr__(self, attr: str) -> Any:
if attr not in INHERITED_METHODS: # pragma: no cover
msg = f"{self.__class__.__name__} has not attribute '{attr}'."
raise AttributeError(msg)
def func(*args: Any, **kwargs: Any) -> Any:
pos, kwds = extract_args_kwargs(args, kwargs)
try:
return self._with_native(getattr(self.native, attr)(*pos, **kwds))
except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover
raise ColumnNotFoundError(str(e)) from e
return func
def _iter_columns(self) -> Iterator[PolarsSeries]: # pragma: no cover
yield from self.collect(Implementation.POLARS).iter_columns()
def collect_schema(self) -> dict[str, DType]:
try:
return super().collect_schema()
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
try:
result = self.native.collect(**kwargs)
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None
if backend is None or backend is Implementation.POLARS:
return PolarsDataFrame.from_native(result, context=self)
if backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
result.to_pandas(),
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=False,
)
if backend is Implementation.PYARROW:
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
result.to_arrow(),
validate_backend_version=True,
version=self._version,
validate_column_names=False,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
def group_by(
self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
) -> PolarsLazyGroupBy:
from narwhals._polars.group_by import PolarsLazyGroupBy
return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
if self._backend_version < (1, 0, 0):
return self._with_native(self.native.drop(columns))
return self._with_native(self.native.drop(columns, strict=strict))

View File

@ -0,0 +1,479 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, Literal
import polars as pl
from narwhals._polars.utils import (
PolarsAnyNamespace,
PolarsCatNamespace,
PolarsDateTimeNamespace,
PolarsListNamespace,
PolarsStringNamespace,
PolarsStructNamespace,
extract_args_kwargs,
extract_native,
narwhals_to_native_dtype,
)
from narwhals._utils import Implementation, requires
if TYPE_CHECKING:
from collections.abc import Mapping, Sequence
from typing_extensions import Self
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._polars.dataframe import Method
from narwhals._polars.namespace import PolarsNamespace
from narwhals._utils import Version
from narwhals.typing import IntoDType, ModeKeepStrategy, NumericLiteral
class PolarsExpr:
# CompliantExpr
_implementation: Implementation = Implementation.POLARS
_version: Version
_native_expr: pl.Expr
_metadata: ExprMetadata | None = None
_evaluate_output_names: Any
_alias_output_names: Any
__call__: Any
# CompliantExpr + builtin descriptor
# TODO @dangotbanned: Remove in #2713
@classmethod
def from_column_names(cls, *_: Any, **__: Any) -> Self:
raise NotImplementedError
@classmethod
def from_column_indices(cls, *_: Any, **__: Any) -> Self:
raise NotImplementedError
@staticmethod
def _eval_names_indices(*_: Any) -> Any:
raise NotImplementedError
def __narwhals_expr__(self) -> Self: # pragma: no cover
return self
def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
from narwhals._polars.namespace import PolarsNamespace
return PolarsNamespace(version=self._version)
def __init__(self, expr: pl.Expr, version: Version) -> None:
self._native_expr = expr
self._version = version
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
@property
def native(self) -> pl.Expr:
return self._native_expr
def __repr__(self) -> str: # pragma: no cover
return "PolarsExpr"
def _with_native(self, expr: pl.Expr) -> Self:
return self.__class__(expr, self._version)
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
# Let Polars do its thing.
return self
def __getattr__(self, attr: str) -> Any:
def func(*args: Any, **kwargs: Any) -> Any:
pos, kwds = extract_args_kwargs(args, kwargs)
return self._with_native(getattr(self.native, attr)(*pos, **kwds))
return func
def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
return {name: min_samples}
def cast(self, dtype: IntoDType) -> Self:
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
return self._with_native(self.native.cast(dtype_pl))
def ewm_mean(
self,
*,
com: float | None,
span: float | None,
half_life: float | None,
alpha: float | None,
adjust: bool,
min_samples: int,
ignore_nulls: bool,
) -> Self:
native = self.native.ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
ignore_nulls=ignore_nulls,
**self._renamed_min_periods(min_samples),
)
if self._backend_version < (1,): # pragma: no cover
native = pl.when(~self.native.is_null()).then(native).otherwise(None)
return self._with_native(native)
def is_nan(self) -> Self:
if self._backend_version >= (1, 18):
native = self.native.is_nan()
else: # pragma: no cover
native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
return self._with_native(native)
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
if self._backend_version < (1, 9):
if order_by:
msg = "`order_by` in Polars requires version 1.10 or greater"
raise NotImplementedError(msg)
native = self.native.over(partition_by or pl.lit(1))
else:
native = self.native.over(
partition_by or pl.lit(1), order_by=order_by or None
)
return self._with_native(native)
@requires.backend_version((1,))
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
kwds = self._renamed_min_periods(min_samples)
native = self.native.rolling_var(
window_size=window_size, center=center, ddof=ddof, **kwds
)
return self._with_native(native)
@requires.backend_version((1,))
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
kwds = self._renamed_min_periods(min_samples)
native = self.native.rolling_std(
window_size=window_size, center=center, ddof=ddof, **kwds
)
return self._with_native(native)
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
kwds = self._renamed_min_periods(min_samples)
native = self.native.rolling_sum(window_size=window_size, center=center, **kwds)
return self._with_native(native)
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
kwds = self._renamed_min_periods(min_samples)
native = self.native.rolling_mean(window_size=window_size, center=center, **kwds)
return self._with_native(native)
def map_batches(
self,
function: Callable[[Any], Any],
return_dtype: IntoDType | None,
*,
returns_scalar: bool,
) -> Self:
pl_version = self._backend_version
return_dtype_pl = (
narwhals_to_native_dtype(return_dtype, self._version)
if return_dtype is not None
else None
if pl_version < (1, 32)
else pl.self_dtype()
)
kwargs = {} if pl_version < (0, 20, 31) else {"returns_scalar": returns_scalar}
native = self.native.map_batches(function, return_dtype_pl, **kwargs)
return self._with_native(native)
@requires.backend_version((1,))
def replace_strict(
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: IntoDType | None,
) -> Self:
return_dtype_pl = (
narwhals_to_native_dtype(return_dtype, self._version)
if return_dtype
else None
)
native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl)
return self._with_native(native)
def __eq__(self, other: object) -> Self: # type: ignore[override]
return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator]
def __ne__(self, other: object) -> Self: # type: ignore[override]
return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator]
def __ge__(self, other: Any) -> Self:
return self._with_native(self.native.__ge__(extract_native(other)))
def __gt__(self, other: Any) -> Self:
return self._with_native(self.native.__gt__(extract_native(other)))
def __le__(self, other: Any) -> Self:
return self._with_native(self.native.__le__(extract_native(other)))
def __lt__(self, other: Any) -> Self:
return self._with_native(self.native.__lt__(extract_native(other)))
def __and__(self, other: PolarsExpr | bool | Any) -> Self:
return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator]
def __or__(self, other: PolarsExpr | bool | Any) -> Self:
return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator]
def __add__(self, other: Any) -> Self:
return self._with_native(self.native.__add__(extract_native(other)))
def __sub__(self, other: Any) -> Self:
return self._with_native(self.native.__sub__(extract_native(other)))
def __mul__(self, other: Any) -> Self:
return self._with_native(self.native.__mul__(extract_native(other)))
def __pow__(self, other: Any) -> Self:
return self._with_native(self.native.__pow__(extract_native(other)))
def __truediv__(self, other: Any) -> Self:
return self._with_native(self.native.__truediv__(extract_native(other)))
def __floordiv__(self, other: Any) -> Self:
return self._with_native(self.native.__floordiv__(extract_native(other)))
def __mod__(self, other: Any) -> Self:
return self._with_native(self.native.__mod__(extract_native(other)))
def __invert__(self) -> Self:
return self._with_native(self.native.__invert__())
def cum_count(self, *, reverse: bool) -> Self:
return self._with_native(self.native.cum_count(reverse=reverse))
def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> Self:
left = self.native
right = other.native if isinstance(other, PolarsExpr) else pl.lit(other)
if self._backend_version < (1, 32, 0):
lower_bound = right.abs()
tolerance = (left.abs().clip(lower_bound) * rel_tol).clip(abs_tol)
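# e.g. with rel_tol=1e-9 and abs_tol=0, values 1.0 and 1.0 + 5e-10 give tolerance ~1e-9
# and abs_diff 5e-10, so they compare as close.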
# Values are close if abs_diff <= tolerance, and both finite
abs_diff = (left - right).abs()
all_ = pl.all_horizontal
is_close = all_((abs_diff <= tolerance), left.is_finite(), right.is_finite())
# Handle infinity cases: infinities are "close" only if they have the same sign
is_same_inf = all_(
left.is_infinite(), right.is_infinite(), (left.sign() == right.sign())
)
# Handle nan cases:
# * nans_equals = True => if both values are NaN, then True
# * nans_equals = False => if any value is NaN, then False
left_is_nan, right_is_nan = left.is_nan(), right.is_nan()
either_nan = left_is_nan | right_is_nan
result = (is_close | is_same_inf) & either_nan.not_()
if nans_equal:
result = result | (left_is_nan & right_is_nan)
else:
result = left.is_close(
right, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
)
return self._with_native(result)
def mode(self, *, keep: ModeKeepStrategy) -> Self:
result = self.native.mode()
return self._with_native(result.first() if keep == "any" else result)
@property
def dt(self) -> PolarsExprDateTimeNamespace:
return PolarsExprDateTimeNamespace(self)
@property
def str(self) -> PolarsExprStringNamespace:
return PolarsExprStringNamespace(self)
@property
def cat(self) -> PolarsExprCatNamespace:
return PolarsExprCatNamespace(self)
@property
def name(self) -> PolarsExprNameNamespace:
return PolarsExprNameNamespace(self)
@property
def list(self) -> PolarsExprListNamespace:
return PolarsExprListNamespace(self)
@property
def struct(self) -> PolarsExprStructNamespace:
return PolarsExprStructNamespace(self)
# Polars
abs: Method[Self]
all: Method[Self]
any: Method[Self]
alias: Method[Self]
arg_max: Method[Self]
arg_min: Method[Self]
arg_true: Method[Self]
clip: Method[Self]
count: Method[Self]
cum_max: Method[Self]
cum_min: Method[Self]
cum_prod: Method[Self]
cum_sum: Method[Self]
diff: Method[Self]
drop_nulls: Method[Self]
exp: Method[Self]
fill_null: Method[Self]
fill_nan: Method[Self]
gather_every: Method[Self]
head: Method[Self]
is_between: Method[Self]
is_duplicated: Method[Self]
is_finite: Method[Self]
is_first_distinct: Method[Self]
is_in: Method[Self]
is_last_distinct: Method[Self]
is_null: Method[Self]
is_unique: Method[Self]
kurtosis: Method[Self]
len: Method[Self]
log: Method[Self]
max: Method[Self]
mean: Method[Self]
median: Method[Self]
min: Method[Self]
n_unique: Method[Self]
null_count: Method[Self]
quantile: Method[Self]
rank: Method[Self]
round: Method[Self]
sample: Method[Self]
shift: Method[Self]
skew: Method[Self]
sqrt: Method[Self]
std: Method[Self]
sum: Method[Self]
sort: Method[Self]
tail: Method[Self]
unique: Method[Self]
var: Method[Self]
__rfloordiv__: Method[Self]
__rsub__: Method[Self]
__rmod__: Method[Self]
__rpow__: Method[Self]
__rtruediv__: Method[Self]
class PolarsExprNamespace(PolarsAnyNamespace[PolarsExpr, pl.Expr]):
def __init__(self, expr: PolarsExpr) -> None:
self._expr = expr
@property
def compliant(self) -> PolarsExpr:
return self._expr
@property
def native(self) -> pl.Expr:
return self._expr.native
class PolarsExprDateTimeNamespace(
PolarsExprNamespace, PolarsDateTimeNamespace[PolarsExpr, pl.Expr]
): ...
class PolarsExprStringNamespace(
PolarsExprNamespace, PolarsStringNamespace[PolarsExpr, pl.Expr]
):
def zfill(self, width: int) -> PolarsExpr:
backend_version = self.compliant._backend_version
native_result = self.native.str.zfill(width)
if backend_version < (0, 20, 5): # pragma: no cover
# Reason:
# `TypeError: argument 'length': 'Expr' object cannot be interpreted as an integer`
# in `native_expr.str.slice(1, length)`
msg = "`zfill` is only available in 'polars>=0.20.5', found version '0.20.4'."
raise NotImplementedError(msg)
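# Workaround for older Polars mishandling a leading "+": e.g. with width=5, "+12" should
# zero-fill to "+0012" (sign kept in front), matching Python's str.zfill.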
if backend_version <= (1, 30, 0):
length = self.native.str.len_chars()
less_than_width = length < width
plus = "+"
starts_with_plus = self.native.str.starts_with(plus)
native_result = (
pl.when(starts_with_plus & less_than_width)
.then(
self.native.str.slice(1, length)
.str.zfill(width - 1)
.str.pad_start(width, plus)
)
.otherwise(native_result)
)
return self.compliant._with_native(native_result)
class PolarsExprCatNamespace(
PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr]
): ...
class PolarsExprNameNamespace(PolarsExprNamespace):
_accessor = "name"
keep: Method[PolarsExpr]
map: Method[PolarsExpr]
prefix: Method[PolarsExpr]
suffix: Method[PolarsExpr]
to_lowercase: Method[PolarsExpr]
to_uppercase: Method[PolarsExpr]
class PolarsExprListNamespace(
PolarsExprNamespace, PolarsListNamespace[PolarsExpr, pl.Expr]
):
def len(self) -> PolarsExpr:
native_expr = self.native
native_result = native_expr.list.len()
if self.compliant._backend_version < (1, 16): # pragma: no cover
native_result = (
pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32())
)
elif self.compliant._backend_version < (1, 17): # pragma: no cover
native_result = native_result.cast(pl.UInt32())
return self.compliant._with_native(native_result)
def contains(self, item: Any) -> PolarsExpr:
if self.compliant._backend_version < (1, 28):
result: pl.Expr = pl.when(self.native.is_not_null()).then(
self.native.list.contains(item)
)
else:
result = self.native.list.contains(item)
return self.compliant._with_native(result)
class PolarsExprStructNamespace(
PolarsExprNamespace, PolarsStructNamespace[PolarsExpr, pl.Expr]
): ...

View File

@ -0,0 +1,76 @@
from __future__ import annotations
from typing import TYPE_CHECKING, cast
from narwhals._utils import is_sequence_of
if TYPE_CHECKING:
from collections.abc import Iterator, Sequence
from polars.dataframe.group_by import GroupBy as NativeGroupBy
from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
from narwhals._polars.expr import PolarsExpr
class PolarsGroupBy:
_compliant_frame: PolarsDataFrame
_grouped: NativeGroupBy
@property
def compliant(self) -> PolarsDataFrame:
return self._compliant_frame
def __init__(
self,
df: PolarsDataFrame,
keys: Sequence[PolarsExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
self._keys = list(keys)
self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df
self._grouped = (
self.compliant.native.group_by(keys)
if is_sequence_of(keys, str)
else self.compliant.native.group_by(arg.native for arg in keys)
)
def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame:
agg_result = self._grouped.agg(arg.native for arg in aggs)
return self.compliant._with_native(agg_result)
def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
for key, df in self._grouped:
yield tuple(cast("str", key)), self.compliant._with_native(df)
class PolarsLazyGroupBy:
_compliant_frame: PolarsLazyFrame
_grouped: NativeLazyGroupBy
@property
def compliant(self) -> PolarsLazyFrame:
return self._compliant_frame
def __init__(
self,
df: PolarsLazyFrame,
keys: Sequence[PolarsExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
self._keys = list(keys)
self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df
self._grouped = (
self.compliant.native.group_by(keys)
if is_sequence_of(keys, str)
else self.compliant.native.group_by(arg.native for arg in keys)
)
def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame:
agg_result = self._grouped.agg(arg.native for arg in aggs)
return self.compliant._with_native(agg_result)

View File

@ -0,0 +1,281 @@
from __future__ import annotations
import operator
from typing import TYPE_CHECKING, Any, Literal, cast, overload
import polars as pl
from narwhals._expression_parsing import is_expr, is_series
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype
from narwhals._utils import Implementation, requires, zip_strict
from narwhals.dependencies import is_numpy_array_2d
from narwhals.dtypes import DType
if TYPE_CHECKING:
from collections.abc import Iterable, Sequence
from datetime import timezone
from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen
from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame
from narwhals._polars.typing import FrameT
from narwhals._utils import Version, _LimitedContext
from narwhals.expr import Expr
from narwhals.series import Series
from narwhals.typing import (
Into1DArray,
IntoDType,
IntoSchema,
NonNestedLiteral,
TimeUnit,
_1DArray,
_2DArray,
)
class PolarsNamespace:
all: Method[PolarsExpr]
coalesce: Method[PolarsExpr]
col: Method[PolarsExpr]
exclude: Method[PolarsExpr]
sum_horizontal: Method[PolarsExpr]
min_horizontal: Method[PolarsExpr]
max_horizontal: Method[PolarsExpr]
when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]]
_implementation: Implementation = Implementation.POLARS
_version: Version
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
def __init__(self, *, version: Version) -> None:
self._version = version
def __getattr__(self, attr: str) -> Any:
def func(*args: Any, **kwargs: Any) -> Any:
pos, kwds = extract_args_kwargs(args, kwargs)
return self._expr(getattr(pl, attr)(*pos, **kwds), version=self._version)
return func
@property
def _dataframe(self) -> type[PolarsDataFrame]:
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame
@property
def _lazyframe(self) -> type[PolarsLazyFrame]:
from narwhals._polars.dataframe import PolarsLazyFrame
return PolarsLazyFrame
@property
def _expr(self) -> type[PolarsExpr]:
return PolarsExpr
@property
def _series(self) -> type[PolarsSeries]:
return PolarsSeries
def parse_into_expr(
self,
data: Expr | NonNestedLiteral | Series[pl.Series] | _1DArray,
/,
*,
str_as_lit: bool,
) -> PolarsExpr | None:
if data is None:
# NOTE: To avoid `pl.lit(None)` failing this `None` check
# https://github.com/pola-rs/polars/blob/58dd8e5770f16a9bef9009a1c05f00e15a5263c7/py-polars/polars/expr/expr.py#L2870-L2872
return data
if is_expr(data):
expr = data._to_compliant_expr(self)
assert isinstance(expr, self._expr) # noqa: S101
return expr
if isinstance(data, str) and not str_as_lit:
return self.col(data)
return self.lit(data.to_native() if is_series(data) else data, None)
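# Editor's note (illustrative, not part of the library): given the branches above,
#   parse_into_expr(None, str_as_lit=False)  # -> None (so callers can tell it apart from pl.lit(None))
#   parse_into_expr("a", str_as_lit=False)   # -> column expression, i.e. self.col("a")
#   parse_into_expr("a", str_as_lit=True)    # -> literal expression, i.e. self.lit("a", None)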
@overload
def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ...
@overload
def from_native(self, data: pl.LazyFrame, /) -> PolarsLazyFrame: ...
@overload
def from_native(self, data: pl.Series, /) -> PolarsSeries: ...
def from_native(
self, data: pl.DataFrame | pl.LazyFrame | pl.Series | Any, /
) -> PolarsDataFrame | PolarsLazyFrame | PolarsSeries:
if self._dataframe._is_native(data):
return self._dataframe.from_native(data, context=self)
if self._series._is_native(data):
return self._series.from_native(data, context=self)
if self._lazyframe._is_native(data):
return self._lazyframe.from_native(data, context=self)
msg = f"Unsupported type: {type(data).__name__!r}" # pragma: no cover
raise TypeError(msg) # pragma: no cover
@overload
def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries: ...
@overload
def from_numpy(
self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
) -> PolarsDataFrame: ...
def from_numpy(
self,
data: Into1DArray | _2DArray,
/,
schema: IntoSchema | Sequence[str] | None = None,
) -> PolarsDataFrame | PolarsSeries:
if is_numpy_array_2d(data):
return self._dataframe.from_numpy(data, schema=schema, context=self)
return self._series.from_numpy(data, context=self) # pragma: no cover
@requires.backend_version(
(1, 0, 0), "Please use `col` for columns selection instead."
)
def nth(self, *indices: int) -> PolarsExpr:
return self._expr(pl.nth(*indices), version=self._version)
def len(self) -> PolarsExpr:
if self._backend_version < (0, 20, 5):
return self._expr(pl.count().alias("len"), self._version)
return self._expr(pl.len(), self._version)
def all_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
it = (expr.fill_null(True) for expr in exprs) if ignore_nulls else iter(exprs)
return self._expr(pl.all_horizontal(*(expr.native for expr in it)), self._version)
def any_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
it = (expr.fill_null(False) for expr in exprs) if ignore_nulls else iter(exprs)
return self._expr(pl.any_horizontal(*(expr.native for expr in it)), self._version)
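# Editor's note (illustrative): with `ignore_nulls=True` each null is replaced by the
# reduction identity before the horizontal fold, so a row of all nulls yields True for
# `all_horizontal` and False for `any_horizontal`; without it, the nulls propagate
# through Polars' Kleene logic.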
def concat(
self,
items: Iterable[FrameT],
*,
how: Literal["vertical", "horizontal", "diagonal"],
) -> PolarsDataFrame | PolarsLazyFrame:
result = pl.concat((item.native for item in items), how=how)
if isinstance(result, pl.DataFrame):
return self._dataframe(result, version=self._version)
return self._lazyframe.from_native(result, context=self)
def lit(self, value: Any, dtype: IntoDType | None) -> PolarsExpr:
if dtype is not None:
return self._expr(
pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._version)),
version=self._version,
)
return self._expr(pl.lit(value), version=self._version)
def mean_horizontal(self, *exprs: PolarsExpr) -> PolarsExpr:
if self._backend_version < (0, 20, 8):
return self._expr(
pl.sum_horizontal(e._native_expr for e in exprs)
/ pl.sum_horizontal(1 - e.is_null()._native_expr for e in exprs),
version=self._version,
)
return self._expr(
pl.mean_horizontal(e._native_expr for e in exprs), version=self._version
)
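# Editor's note (illustrative): the pre-0.20.8 fallback above mirrors `pl.mean_horizontal`
# by dividing the horizontal sum by the number of non-null values per row, e.g. a row
# (2.0, null, 4.0) -> (2.0 + 4.0) / 2 = 3.0.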
def concat_str(
self, *exprs: PolarsExpr, separator: str, ignore_nulls: bool
) -> PolarsExpr:
pl_exprs: list[pl.Expr] = [expr._native_expr for expr in exprs]
if self._backend_version < (0, 20, 6):
null_mask = [expr.is_null() for expr in pl_exprs]
sep = pl.lit(separator)
if not ignore_nulls:
null_mask_result = pl.any_horizontal(*null_mask)
output_expr = pl.reduce(
lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()), # type: ignore[arg-type,return-value]
pl_exprs,
)
result = pl.when(~null_mask_result).then(output_expr)
else:
init_value, *values = [
pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String()))
for expr, nm in zip_strict(pl_exprs, null_mask)
]
separators = [
pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1]
]
result = pl.fold( # type: ignore[assignment]
acc=init_value,
function=operator.add,
exprs=[s + v for s, v in zip_strict(separators, values)],
)
return self._expr(result, version=self._version)
return self._expr(
pl.concat_str(pl_exprs, separator=separator, ignore_nulls=ignore_nulls),
version=self._version,
)
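# Editor's note (illustrative): worked example for the pre-0.20.6 fallback above with
# separator="-" and row values ("a", null, "c"):
#   ignore_nulls=False -> null (any null in the row nulls the whole result)
#   ignore_nulls=True  -> "a-c" (the null element and the separator that would follow it
#                               collapse to empty strings)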
# NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`)
# 1. Others have lots of private stuff for code reuse
# i. None of that is useful here
# 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr`
@property
def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]:
return cast(
"CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]",
PolarsSelectorNamespace(self),
)
class PolarsSelectorNamespace:
_implementation = Implementation.POLARS
def __init__(self, context: _LimitedContext, /) -> None:
self._version = context._version
def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
native_dtypes = [
narwhals_to_native_dtype(dtype, self._version).__class__
if isinstance(dtype, type) and issubclass(dtype, DType)
else narwhals_to_native_dtype(dtype, self._version)
for dtype in dtypes
]
return PolarsExpr(pl.selectors.by_dtype(native_dtypes), version=self._version)
def matches(self, pattern: str) -> PolarsExpr:
return PolarsExpr(pl.selectors.matches(pattern=pattern), version=self._version)
def numeric(self) -> PolarsExpr:
return PolarsExpr(pl.selectors.numeric(), version=self._version)
def boolean(self) -> PolarsExpr:
return PolarsExpr(pl.selectors.boolean(), version=self._version)
def string(self) -> PolarsExpr:
return PolarsExpr(pl.selectors.string(), version=self._version)
def categorical(self) -> PolarsExpr:
return PolarsExpr(pl.selectors.categorical(), version=self._version)
def all(self) -> PolarsExpr:
return PolarsExpr(pl.selectors.all(), version=self._version)
def datetime(
self,
time_unit: TimeUnit | Iterable[TimeUnit] | None,
time_zone: str | timezone | Iterable[str | timezone | None] | None,
) -> PolarsExpr:
return PolarsExpr(
pl.selectors.datetime(time_unit=time_unit, time_zone=time_zone), # type: ignore[arg-type]
version=self._version,
)

View File

@ -0,0 +1,795 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar, cast, overload
import polars as pl
from narwhals._polars.utils import (
BACKEND_VERSION,
SERIES_ACCEPTS_PD_INDEX,
SERIES_RESPECTS_DTYPE,
PolarsAnyNamespace,
PolarsCatNamespace,
PolarsDateTimeNamespace,
PolarsListNamespace,
PolarsStringNamespace,
PolarsStructNamespace,
catch_polars_exception,
extract_args_kwargs,
extract_native,
narwhals_to_native_dtype,
native_to_narwhals_dtype,
)
from narwhals._utils import Implementation, requires
from narwhals.dependencies import is_numpy_array_1d, is_pandas_index
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from types import ModuleType
from typing import Literal, TypeVar
import pandas as pd
import pyarrow as pa
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._polars.dataframe import Method, PolarsDataFrame
from narwhals._polars.namespace import PolarsNamespace
from narwhals._utils import Version, _LimitedContext
from narwhals.dtypes import DType
from narwhals.series import Series
from narwhals.typing import (
Into1DArray,
IntoDType,
ModeKeepStrategy,
MultiIndexSelector,
NonNestedLiteral,
NumericLiteral,
_1DArray,
)
T = TypeVar("T")
IncludeBreakpoint: TypeAlias = Literal[False, True]
Incomplete: TypeAlias = Any
# Series methods where PolarsSeries just defers to Polars.Series directly.
INHERITED_METHODS = frozenset(
[
"__add__",
"__and__",
"__floordiv__",
"__invert__",
"__iter__",
"__mod__",
"__mul__",
"__or__",
"__pow__",
"__radd__",
"__rand__",
"__rfloordiv__",
"__rmod__",
"__rmul__",
"__ror__",
"__rsub__",
"__rtruediv__",
"__sub__",
"__truediv__",
"abs",
"all",
"any",
"arg_max",
"arg_min",
"arg_true",
"clip",
"count",
"cum_max",
"cum_min",
"cum_prod",
"cum_sum",
"diff",
"drop_nulls",
"exp",
"fill_null",
"fill_nan",
"filter",
"gather_every",
"head",
"is_between",
"is_close",
"is_duplicated",
"is_empty",
"is_finite",
"is_first_distinct",
"is_in",
"is_last_distinct",
"is_null",
"is_sorted",
"is_unique",
"item",
"kurtosis",
"len",
"log",
"max",
"mean",
"min",
"mode",
"n_unique",
"null_count",
"quantile",
"rank",
"round",
"sample",
"shift",
"skew",
"sqrt",
"std",
"sum",
"tail",
"to_arrow",
"to_frame",
"to_list",
"to_pandas",
"unique",
"var",
"zip_with",
]
)
class PolarsSeries:
_implementation: Implementation = Implementation.POLARS
_native_series: pl.Series
_version: Version
_HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
True: ["breakpoint", "count"],
False: ["count"],
}
def __init__(self, series: pl.Series, *, version: Version) -> None:
self._native_series = series
self._version = version
@property
def _backend_version(self) -> tuple[int, ...]:
return self._implementation._backend_version()
def __repr__(self) -> str: # pragma: no cover
return "PolarsSeries"
def __narwhals_namespace__(self) -> PolarsNamespace:
from narwhals._polars.namespace import PolarsNamespace
return PolarsNamespace(version=self._version)
def __narwhals_series__(self) -> Self:
return self
def __native_namespace__(self) -> ModuleType:
if self._implementation is Implementation.POLARS:
return self._implementation.to_native_namespace()
msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover
raise AssertionError(msg)
def _with_version(self, version: Version) -> Self:
return self.__class__(self.native, version=version)
@classmethod
def from_iterable(
cls,
data: Iterable[Any],
*,
context: _LimitedContext,
name: str = "",
dtype: IntoDType | None = None,
) -> Self:
version = context._version
dtype_pl = narwhals_to_native_dtype(dtype, version) if dtype else None
values: Incomplete = data
if SERIES_RESPECTS_DTYPE:
native = pl.Series(name, values, dtype=dtype_pl)
else: # pragma: no cover
if (not SERIES_ACCEPTS_PD_INDEX) and is_pandas_index(values):
values = values.to_series()
native = pl.Series(name, values)
if dtype_pl:
native = native.cast(dtype_pl)
return cls.from_native(native, context=context)
@staticmethod
def _is_native(obj: pl.Series | Any) -> TypeIs[pl.Series]:
return isinstance(obj, pl.Series)
@classmethod
def from_native(cls, data: pl.Series, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version)
@classmethod
def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self:
native = pl.Series(data if is_numpy_array_1d(data) else [data])
return cls.from_native(native, context=context)
def to_narwhals(self) -> Series[pl.Series]:
return self._version.series(self, level="full")
def _with_native(self, series: pl.Series) -> Self:
return self.__class__(series, version=self._version)
@overload
def _from_native_object(self, series: pl.Series) -> Self: ...
@overload
def _from_native_object(self, series: pl.DataFrame) -> PolarsDataFrame: ...
@overload
def _from_native_object(self, series: T) -> T: ...
def _from_native_object(
self, series: pl.Series | pl.DataFrame | T
) -> Self | PolarsDataFrame | T:
if self._is_native(series):
return self._with_native(series)
if isinstance(series, pl.DataFrame):
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame.from_native(series, context=self)
# scalar
return series
def __getattr__(self, attr: str) -> Any:
if attr not in INHERITED_METHODS:
msg = f"{self.__class__.__name__} has not attribute '{attr}'."
raise AttributeError(msg)
def func(*args: Any, **kwargs: Any) -> Any:
pos, kwds = extract_args_kwargs(args, kwargs)
return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
return func
def __len__(self) -> int:
return len(self.native)
@property
def name(self) -> str:
return self.native.name
@property
def dtype(self) -> DType:
return native_to_narwhals_dtype(self.native.dtype, self._version)
@property
def native(self) -> pl.Series:
return self._native_series
def alias(self, name: str) -> Self:
return self._from_native_object(self.native.alias(name))
def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self:
if isinstance(item, PolarsSeries):
return self._from_native_object(self.native.__getitem__(item.native))
return self._from_native_object(self.native.__getitem__(item))
def cast(self, dtype: IntoDType) -> Self:
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
return self._with_native(self.native.cast(dtype_pl))
@requires.backend_version((1,))
def replace_strict(
self,
old: Sequence[Any] | Mapping[Any, Any],
new: Sequence[Any],
*,
return_dtype: IntoDType | None,
) -> Self:
ser = self.native
dtype = (
narwhals_to_native_dtype(return_dtype, self._version)
if return_dtype
else None
)
return self._with_native(ser.replace_strict(old, new, return_dtype=dtype))
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray:
return self.__array__(dtype, copy=copy)
def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray:
if self._backend_version < (0, 20, 29):
return self.native.__array__(dtype=dtype)
return self.native.__array__(dtype=dtype, copy=copy)
def __eq__(self, other: object) -> Self: # type: ignore[override]
return self._with_native(self.native.__eq__(extract_native(other)))
def __ne__(self, other: object) -> Self: # type: ignore[override]
return self._with_native(self.native.__ne__(extract_native(other)))
# NOTE: These need to be anything that can't match `PolarsExpr`, due to overload order
def __ge__(self, other: Self) -> Self:
return self._with_native(self.native.__ge__(extract_native(other)))
def __gt__(self, other: Self) -> Self:
return self._with_native(self.native.__gt__(extract_native(other)))
def __le__(self, other: Self) -> Self:
return self._with_native(self.native.__le__(extract_native(other)))
def __lt__(self, other: Self) -> Self:
return self._with_native(self.native.__lt__(extract_native(other)))
def __rpow__(self, other: PolarsSeries | Any) -> Self:
result = self.native.__rpow__(extract_native(other))
if self._backend_version < (1, 16, 1):
# Explicitly set alias to work around https://github.com/pola-rs/polars/issues/20071
result = result.alias(self.name)
return self._with_native(result)
def is_nan(self) -> Self:
try:
native_is_nan = self.native.is_nan()
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None
if self._backend_version < (1, 18): # pragma: no cover
select = pl.when(self.native.is_not_null()).then(native_is_nan)
return self._with_native(pl.select(select)[self.name])
return self._with_native(native_is_nan)
def median(self) -> Any:
from narwhals.exceptions import InvalidOperationError
if not self.dtype.is_numeric():
msg = "`median` operation not supported for non-numeric input type."
raise InvalidOperationError(msg)
return self.native.median()
def to_dummies(self, *, separator: str, drop_first: bool) -> PolarsDataFrame:
from narwhals._polars.dataframe import PolarsDataFrame
if self._backend_version < (0, 20, 15):
has_nulls = self.native.is_null().any()
result = self.native.to_dummies(separator=separator)
output_columns = result.columns
if drop_first:
_ = output_columns.pop(int(has_nulls))
result = result.select(output_columns)
else:
result = self.native.to_dummies(separator=separator, drop_first=drop_first)
result = result.with_columns(pl.all().cast(pl.Int8))
return PolarsDataFrame.from_native(result, context=self)
def ewm_mean(
self,
*,
com: float | None,
span: float | None,
half_life: float | None,
alpha: float | None,
adjust: bool,
min_samples: int,
ignore_nulls: bool,
) -> Self:
extra_kwargs = (
{"min_periods": min_samples}
if self._backend_version < (1, 21, 0)
else {"min_samples": min_samples}
)
native_result = self.native.ewm_mean(
com=com,
span=span,
half_life=half_life,
alpha=alpha,
adjust=adjust,
ignore_nulls=ignore_nulls,
**extra_kwargs,
)
if self._backend_version < (1,): # pragma: no cover
return self._with_native(
pl.select(
pl.when(~self.native.is_null()).then(native_result).otherwise(None)
)[self.native.name]
)
return self._with_native(native_result)
@requires.backend_version((1,))
def rolling_var(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
extra_kwargs: dict[str, Any] = (
{"min_periods": min_samples}
if self._backend_version < (1, 21, 0)
else {"min_samples": min_samples}
)
return self._with_native(
self.native.rolling_var(
window_size=window_size, center=center, ddof=ddof, **extra_kwargs
)
)
@requires.backend_version((1,))
def rolling_std(
self, window_size: int, *, min_samples: int, center: bool, ddof: int
) -> Self:
extra_kwargs: dict[str, Any] = (
{"min_periods": min_samples}
if self._backend_version < (1, 21, 0)
else {"min_samples": min_samples}
)
return self._with_native(
self.native.rolling_std(
window_size=window_size, center=center, ddof=ddof, **extra_kwargs
)
)
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
extra_kwargs: dict[str, Any] = (
{"min_periods": min_samples}
if self._backend_version < (1, 21, 0)
else {"min_samples": min_samples}
)
return self._with_native(
self.native.rolling_sum(
window_size=window_size, center=center, **extra_kwargs
)
)
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
extra_kwargs: dict[str, Any] = (
{"min_periods": min_samples}
if self._backend_version < (1, 21, 0)
else {"min_samples": min_samples}
)
return self._with_native(
self.native.rolling_mean(
window_size=window_size, center=center, **extra_kwargs
)
)
def sort(self, *, descending: bool, nulls_last: bool) -> Self:
if self._backend_version < (0, 20, 6):
result = self.native.sort(descending=descending)
if nulls_last:
is_null = result.is_null()
result = pl.concat([result.filter(~is_null), result.filter(is_null)])
else:
result = self.native.sort(descending=descending, nulls_last=nulls_last)
return self._with_native(result)
def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
s = self.native.clone().scatter(indices, extract_native(values))
return self._with_native(s)
def value_counts(
self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
) -> PolarsDataFrame:
from narwhals._polars.dataframe import PolarsDataFrame
if self._backend_version < (1, 0, 0):
value_name_ = name or ("proportion" if normalize else "count")
result = self.native.value_counts(sort=sort, parallel=parallel).select(
**{
(self.native.name): pl.col(self.native.name),
value_name_: pl.col("count") / pl.sum("count")
if normalize
else pl.col("count"),
}
)
else:
result = self.native.value_counts(
sort=sort, parallel=parallel, name=name, normalize=normalize
)
return PolarsDataFrame.from_native(result, context=self)
def cum_count(self, *, reverse: bool) -> Self:
return self._with_native(self.native.cum_count(reverse=reverse))
def __contains__(self, other: Any) -> bool:
try:
return self.native.__contains__(other)
except Exception as e: # noqa: BLE001
raise catch_polars_exception(e) from None
def is_close(
self,
other: Self | NumericLiteral,
*,
abs_tol: float,
rel_tol: float,
nans_equal: bool,
) -> PolarsSeries:
if self._backend_version < (1, 32, 0):
name = self.name
ns = self.__narwhals_namespace__()
other_expr = (
ns.lit(other.native, None) if isinstance(other, PolarsSeries) else other
)
expr = ns.col(name).is_close(
other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
)
return self.to_frame().select(expr).get_column(name)
other_series = other.native if isinstance(other, PolarsSeries) else other
result = self.native.is_close(
other_series, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
)
return self._with_native(result)
def mode(self, *, keep: ModeKeepStrategy) -> Self:
result = self.native.mode()
return self._with_native(result.head(1) if keep == "any" else result)
def hist_from_bins(
self, bins: list[float], *, include_breakpoint: bool
) -> PolarsDataFrame:
if len(bins) <= 1:
native = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
elif self.native.is_empty():
if include_breakpoint:
native = (
pl.Series(bins[1:])
.to_frame("breakpoint")
.with_columns(count=pl.lit(0, pl.Int64))
)
else:
native = pl.select(count=pl.zeros(len(bins) - 1, pl.Int64))
else:
return self._hist_from_data(
bins=bins, bin_count=None, include_breakpoint=include_breakpoint
)
return self.__narwhals_namespace__()._dataframe.from_native(native, context=self)
def hist_from_bin_count(
self, bin_count: int, *, include_breakpoint: bool
) -> PolarsDataFrame:
if bin_count == 0:
native = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
elif self.native.is_empty():
if include_breakpoint:
native = pl.select(
breakpoint=pl.int_range(1, bin_count + 1) / bin_count,
count=pl.lit(0, pl.Int64),
)
else:
native = pl.select(count=pl.zeros(bin_count, pl.Int64))
else:
count: int | None
if BACKEND_VERSION < (1, 15): # pragma: no cover
count = None
bins = self._bins_from_bin_count(bin_count=bin_count)
else:
count = bin_count
bins = None
return self._hist_from_data(
bins=bins, # type: ignore[arg-type]
bin_count=count,
include_breakpoint=include_breakpoint,
)
return self.__narwhals_namespace__()._dataframe.from_native(native, context=self)
def _bins_from_bin_count(self, bin_count: int) -> pl.Series: # pragma: no cover
"""Prepare bins based on backend version compatibility.
polars <1.15 does not adjust the bins when they have equivalent min/max.
polars <1.5 with `bin_count=...` returns bins that range from -inf to +inf and
has `bin_count + 1` bins; for compat, convert the `bin_count=` call to `bins=`.
"""
lower = cast("float", self.native.min())
upper = cast("float", self.native.max())
if lower == upper:
lower -= 0.5
upper += 0.5
width = (upper - lower) / bin_count
return pl.int_range(0, bin_count + 1, eager=True) * width + lower
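# Editor's note, worked example for the conversion above: for data with min == max == 2.0
# and bin_count=4, the bounds are widened to lower=1.5 and upper=2.5, width=0.25, giving
# bin edges [1.5, 1.75, 2.0, 2.25, 2.5].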
def _hist_from_data(
self, bins: list[float] | None, bin_count: int | None, *, include_breakpoint: bool
) -> PolarsDataFrame:
"""Calculate histogram from non-empty data and post-process the results based on the backend version."""
from narwhals._polars.dataframe import PolarsDataFrame
series = self.native
# Polars inconsistently handles NaN values when computing histograms
# against predefined bins: https://github.com/pola-rs/polars/issues/21082
if BACKEND_VERSION < (1, 15) or bins is not None:
series = series.fill_nan(None)
df = series.hist(
bins,
bin_count=bin_count,
include_category=False,
include_breakpoint=include_breakpoint,
)
# Apply post-processing corrections
# Handle column naming
if not include_breakpoint:
col_name = df.columns[0]
df = df.select(pl.col(col_name).alias("count"))
elif BACKEND_VERSION < (1, 0): # pragma: no cover
df = df.rename({"break_point": "breakpoint"})
if bins is not None: # pragma: no cover
# polars<1.6 implicitly adds -inf and inf to either end of bins
if BACKEND_VERSION < (1, 6):
r = pl.int_range(0, len(df))
df = df.filter((r > 0) & (r < len(df) - 1))
# polars<1.27 makes the lowest bin a left/right closed interval
if BACKEND_VERSION < (1, 27):
df = (
df.slice(0, 1)
.with_columns(pl.col("count") + ((pl.lit(series) == bins[0]).sum()))
.vstack(df.slice(1))
)
return PolarsDataFrame.from_native(df, context=self)
def to_polars(self) -> pl.Series:
return self.native
@property
def dt(self) -> PolarsSeriesDateTimeNamespace:
return PolarsSeriesDateTimeNamespace(self)
@property
def str(self) -> PolarsSeriesStringNamespace:
return PolarsSeriesStringNamespace(self)
@property
def cat(self) -> PolarsSeriesCatNamespace:
return PolarsSeriesCatNamespace(self)
@property
def struct(self) -> PolarsSeriesStructNamespace:
return PolarsSeriesStructNamespace(self)
__add__: Method[Self]
__and__: Method[Self]
__floordiv__: Method[Self]
__invert__: Method[Self]
__iter__: Method[Iterator[Any]]
__mod__: Method[Self]
__mul__: Method[Self]
__or__: Method[Self]
__pow__: Method[Self]
__radd__: Method[Self]
__rand__: Method[Self]
__rfloordiv__: Method[Self]
__rmod__: Method[Self]
__rmul__: Method[Self]
__ror__: Method[Self]
__rsub__: Method[Self]
__rtruediv__: Method[Self]
__sub__: Method[Self]
__truediv__: Method[Self]
abs: Method[Self]
all: Method[bool]
any: Method[bool]
arg_max: Method[int]
arg_min: Method[int]
arg_true: Method[Self]
clip: Method[Self]
count: Method[int]
cum_max: Method[Self]
cum_min: Method[Self]
cum_prod: Method[Self]
cum_sum: Method[Self]
diff: Method[Self]
drop_nulls: Method[Self]
exp: Method[Self]
fill_null: Method[Self]
fill_nan: Method[Self]
filter: Method[Self]
gather_every: Method[Self]
head: Method[Self]
is_between: Method[Self]
is_duplicated: Method[Self]
is_empty: Method[bool]
is_finite: Method[Self]
is_first_distinct: Method[Self]
is_in: Method[Self]
is_last_distinct: Method[Self]
is_null: Method[Self]
is_sorted: Method[bool]
is_unique: Method[Self]
item: Method[Any]
kurtosis: Method[float | None]
len: Method[int]
log: Method[Self]
max: Method[Any]
mean: Method[float]
min: Method[Any]
n_unique: Method[int]
null_count: Method[int]
quantile: Method[float]
rank: Method[Self]
round: Method[Self]
sample: Method[Self]
shift: Method[Self]
skew: Method[float | None]
sqrt: Method[Self]
std: Method[float]
sum: Method[float]
tail: Method[Self]
to_arrow: Method[pa.Array[Any]]
to_frame: Method[PolarsDataFrame]
to_list: Method[list[Any]]
to_pandas: Method[pd.Series[Any]]
unique: Method[Self]
var: Method[float]
zip_with: Method[Self]
@property
def list(self) -> PolarsSeriesListNamespace:
return PolarsSeriesListNamespace(self)
class PolarsSeriesNamespace(PolarsAnyNamespace[PolarsSeries, pl.Series]):
def __init__(self, series: PolarsSeries) -> None:
self._series = series
@property
def compliant(self) -> PolarsSeries:
return self._series
@property
def native(self) -> pl.Series:
return self._series.native
@property
def name(self) -> str:
return self.compliant.name
def __narwhals_namespace__(self) -> PolarsNamespace:
return self.compliant.__narwhals_namespace__()
def to_frame(self) -> PolarsDataFrame:
return self.compliant.to_frame()
class PolarsSeriesDateTimeNamespace(
PolarsSeriesNamespace, PolarsDateTimeNamespace[PolarsSeries, pl.Series]
): ...
class PolarsSeriesStringNamespace(
PolarsSeriesNamespace, PolarsStringNamespace[PolarsSeries, pl.Series]
):
def zfill(self, width: int) -> PolarsSeries:
name = self.name
ns = self.__narwhals_namespace__()
return self.to_frame().select(ns.col(name).str.zfill(width)).get_column(name)
class PolarsSeriesCatNamespace(
PolarsSeriesNamespace, PolarsCatNamespace[PolarsSeries, pl.Series]
): ...
class PolarsSeriesListNamespace(
PolarsSeriesNamespace, PolarsListNamespace[PolarsSeries, pl.Series]
):
def len(self) -> PolarsSeries:
name = self.name
ns = self.__narwhals_namespace__()
return self.to_frame().select(ns.col(name).list.len()).get_column(name)
def contains(self, item: NonNestedLiteral) -> PolarsSeries:
name = self.name
ns = self.__narwhals_namespace__()
return self.to_frame().select(ns.col(name).list.contains(item)).get_column(name)
class PolarsSeriesStructNamespace(
PolarsSeriesNamespace, PolarsStructNamespace[PolarsSeries, pl.Series]
): ...

View File

@ -0,0 +1,25 @@
from __future__ import annotations # pragma: no cover
from typing import (
TYPE_CHECKING, # pragma: no cover
Union, # pragma: no cover
)
if TYPE_CHECKING:
import sys
from typing import Literal, TypeVar
if sys.version_info >= (3, 10):
from typing import TypeAlias
else:
from typing_extensions import TypeAlias
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
IntoPolarsExpr: TypeAlias = Union[PolarsExpr, PolarsSeries]
FrameT = TypeVar("FrameT", PolarsDataFrame, PolarsLazyFrame)
NativeAccessor: TypeAlias = Literal[
"arr", "cat", "dt", "list", "meta", "name", "str", "bin", "struct"
]

View File

@ -0,0 +1,351 @@
from __future__ import annotations
import abc
from functools import lru_cache
from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, TypeVar, overload
import polars as pl
from narwhals._duration import Interval
from narwhals._utils import (
Implementation,
Version,
_DeferredIterable,
_StoresCompliant,
_StoresNative,
deep_getattr,
isinstance_or_issubclass,
)
from narwhals.exceptions import (
ColumnNotFoundError,
ComputeError,
DuplicateError,
InvalidOperationError,
NarwhalsError,
ShapeError,
)
if TYPE_CHECKING:
from collections.abc import Callable, Iterable, Iterator, Mapping
from typing_extensions import TypeIs
from narwhals._polars.dataframe import Method
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals._polars.typing import NativeAccessor
from narwhals.dtypes import DType
from narwhals.typing import IntoDType
T = TypeVar("T")
NativeT = TypeVar(
"NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
)
NativeT_co = TypeVar("NativeT_co", "pl.Series", "pl.Expr", covariant=True)
CompliantT_co = TypeVar("CompliantT_co", "PolarsSeries", "PolarsExpr", covariant=True)
CompliantT = TypeVar("CompliantT", "PolarsSeries", "PolarsExpr")
BACKEND_VERSION = Implementation.POLARS._backend_version()
"""Static backend version for `polars`."""
SERIES_RESPECTS_DTYPE: Final[bool] = BACKEND_VERSION >= (0, 20, 26)
"""`pl.Series(dtype=...)` fixed in https://github.com/pola-rs/polars/pull/15962
Includes `SERIES_ACCEPTS_PD_INDEX`.
"""
SERIES_ACCEPTS_PD_INDEX: Final[bool] = BACKEND_VERSION >= (0, 20, 7)
"""`pl.Series(values: pd.Index)` fixed in https://github.com/pola-rs/polars/pull/14087"""
@overload
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
@overload
def extract_native(obj: T) -> T: ...
def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
return obj.native if _is_compliant_polars(obj) else obj
def _is_compliant_polars(
obj: _StoresNative[NativeT] | Any,
) -> TypeIs[_StoresNative[NativeT]]:
from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
return isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr))
def extract_args_kwargs(
args: Iterable[Any], kwds: Mapping[str, Any], /
) -> tuple[Iterator[Any], dict[str, Any]]:
it_args = (extract_native(arg) for arg in args)
return it_args, {k: extract_native(v) for k, v in kwds.items()}
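# Editor's note (illustrative, not part of the library): `extract_native` unwraps any
# compliant Polars wrapper and passes other objects through untouched, e.g.
#   extract_native(PolarsSeries(pl.Series("a", [1]), version=Version.MAIN))  # -> pl.Series("a", [1])
#   extract_native(42)                                                        # -> 42
# `extract_args_kwargs` simply applies the same unwrapping to every argument.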
@lru_cache(maxsize=16)
def native_to_narwhals_dtype( # noqa: C901, PLR0912
dtype: pl.DataType, version: Version
) -> DType:
dtypes = version.dtypes
if dtype == pl.Float64:
return dtypes.Float64()
if dtype == pl.Float32:
return dtypes.Float32()
if hasattr(pl, "Int128") and dtype == pl.Int128: # pragma: no cover
# Not available for Polars pre 1.8.0
return dtypes.Int128()
if dtype == pl.Int64:
return dtypes.Int64()
if dtype == pl.Int32:
return dtypes.Int32()
if dtype == pl.Int16:
return dtypes.Int16()
if dtype == pl.Int8:
return dtypes.Int8()
if hasattr(pl, "UInt128") and dtype == pl.UInt128: # pragma: no cover
# Not available for Polars pre 1.8.0
return dtypes.UInt128()
if dtype == pl.UInt64:
return dtypes.UInt64()
if dtype == pl.UInt32:
return dtypes.UInt32()
if dtype == pl.UInt16:
return dtypes.UInt16()
if dtype == pl.UInt8:
return dtypes.UInt8()
if dtype == pl.String:
return dtypes.String()
if dtype == pl.Boolean:
return dtypes.Boolean()
if dtype == pl.Object:
return dtypes.Object()
if dtype == pl.Categorical:
return dtypes.Categorical()
if isinstance_or_issubclass(dtype, pl.Enum):
if version is Version.V1:
return dtypes.Enum() # type: ignore[call-arg]
categories = _DeferredIterable(dtype.categories.to_list)
return dtypes.Enum(categories)
if dtype == pl.Date:
return dtypes.Date()
if isinstance_or_issubclass(dtype, pl.Datetime):
return (
dtypes.Datetime()
if dtype is pl.Datetime
else dtypes.Datetime(dtype.time_unit, dtype.time_zone)
)
if isinstance_or_issubclass(dtype, pl.Duration):
return (
dtypes.Duration()
if dtype is pl.Duration
else dtypes.Duration(dtype.time_unit)
)
if isinstance_or_issubclass(dtype, pl.Struct):
fields = [
dtypes.Field(name, native_to_narwhals_dtype(tp, version))
for name, tp in dtype
]
return dtypes.Struct(fields)
if isinstance_or_issubclass(dtype, pl.List):
return dtypes.List(native_to_narwhals_dtype(dtype.inner, version))
if isinstance_or_issubclass(dtype, pl.Array):
outer_shape = dtype.width if BACKEND_VERSION < (0, 20, 30) else dtype.size
return dtypes.Array(native_to_narwhals_dtype(dtype.inner, version), outer_shape)
if dtype == pl.Decimal:
return dtypes.Decimal()
if dtype == pl.Time:
return dtypes.Time()
if dtype == pl.Binary:
return dtypes.Binary()
return dtypes.Unknown()
dtypes = Version.MAIN.dtypes
NW_TO_PL_DTYPES: Mapping[type[DType], pl.DataType] = {
dtypes.Float64: pl.Float64(),
dtypes.Float32: pl.Float32(),
dtypes.Binary: pl.Binary(),
dtypes.String: pl.String(),
dtypes.Boolean: pl.Boolean(),
dtypes.Categorical: pl.Categorical(),
dtypes.Date: pl.Date(),
dtypes.Time: pl.Time(),
dtypes.Int8: pl.Int8(),
dtypes.Int16: pl.Int16(),
dtypes.Int32: pl.Int32(),
dtypes.Int64: pl.Int64(),
dtypes.UInt8: pl.UInt8(),
dtypes.UInt16: pl.UInt16(),
dtypes.UInt32: pl.UInt32(),
dtypes.UInt64: pl.UInt64(),
dtypes.Object: pl.Object(),
dtypes.Unknown: pl.Unknown(),
}
UNSUPPORTED_DTYPES = (dtypes.Decimal,)
def narwhals_to_native_dtype( # noqa: C901
dtype: IntoDType, version: Version
) -> pl.DataType:
dtypes = version.dtypes
base_type = dtype.base_type()
if pl_type := NW_TO_PL_DTYPES.get(base_type):
return pl_type
if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
# Not available for Polars pre 1.8.0
return pl.Int128()
if isinstance_or_issubclass(dtype, dtypes.Enum):
if version is Version.V1:
msg = "Converting to Enum is not supported in narwhals.stable.v1"
raise NotImplementedError(msg)
if isinstance(dtype, dtypes.Enum):
return pl.Enum(dtype.categories)
msg = "Can not cast / initialize Enum without categories present"
raise ValueError(msg)
if isinstance_or_issubclass(dtype, dtypes.Datetime):
return pl.Datetime(dtype.time_unit, dtype.time_zone) # type: ignore[arg-type]
if isinstance_or_issubclass(dtype, dtypes.Duration):
return pl.Duration(dtype.time_unit) # type: ignore[arg-type]
if isinstance_or_issubclass(dtype, dtypes.List):
return pl.List(narwhals_to_native_dtype(dtype.inner, version))
if isinstance_or_issubclass(dtype, dtypes.Struct):
fields = [
pl.Field(field.name, narwhals_to_native_dtype(field.dtype, version))
for field in dtype.fields
]
return pl.Struct(fields)
if isinstance_or_issubclass(dtype, dtypes.Array): # pragma: no cover
size = dtype.size
kwargs = {"width": size} if BACKEND_VERSION < (0, 20, 30) else {"shape": size}
return pl.Array(narwhals_to_native_dtype(dtype.inner, version), **kwargs)
if issubclass(base_type, UNSUPPORTED_DTYPES):
msg = f"Converting to {base_type.__name__} dtype is not supported for Polars."
raise NotImplementedError(msg)
return pl.Unknown() # pragma: no cover
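# Editor's note (illustrative, not part of the library): simple dtypes resolve through
# `NW_TO_PL_DTYPES`, nested dtypes recurse, and unsupported ones (e.g. Decimal) raise:
#   narwhals_to_native_dtype(dtypes.Int64(), Version.MAIN)                # -> pl.Int64()
#   narwhals_to_native_dtype(dtypes.List(dtypes.String()), Version.MAIN)  # -> pl.List(pl.String())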
def _is_polars_exception(exception: Exception) -> bool:
if BACKEND_VERSION >= (1,):
# Old versions of Polars didn't have PolarsError.
return isinstance(exception, pl.exceptions.PolarsError)
# Last attempt, for old Polars versions.
return "polars.exceptions" in str(type(exception)) # pragma: no cover
def _is_cudf_exception(exception: Exception) -> bool:
# These exceptions are raised when running polars on GPUs via cuDF
return str(exception).startswith("CUDF failure")
def catch_polars_exception(exception: Exception) -> NarwhalsError | Exception:
if isinstance(exception, pl.exceptions.ColumnNotFoundError):
return ColumnNotFoundError(str(exception))
if isinstance(exception, pl.exceptions.ShapeError):
return ShapeError(str(exception))
if isinstance(exception, pl.exceptions.InvalidOperationError):
return InvalidOperationError(str(exception))
if isinstance(exception, pl.exceptions.DuplicateError):
return DuplicateError(str(exception))
if isinstance(exception, pl.exceptions.ComputeError):
return ComputeError(str(exception))
if _is_polars_exception(exception) or _is_cudf_exception(exception):
return NarwhalsError(str(exception)) # pragma: no cover
# Just return exception as-is.
return exception
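# Editor's note (illustrative): callers wrap native Polars calls and re-raise the
# translated error, as in `PolarsSeries.__contains__`:
#   try:
#       ...  # some native Polars call
#   except Exception as e:  # noqa: BLE001
#       raise catch_polars_exception(e) from None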
class PolarsAnyNamespace(
_StoresCompliant[CompliantT_co],
_StoresNative[NativeT_co],
Protocol[CompliantT_co, NativeT_co],
):
_accessor: ClassVar[NativeAccessor]
def __getattr__(self, attr: str) -> Callable[..., CompliantT_co]:
def func(*args: Any, **kwargs: Any) -> CompliantT_co:
pos, kwds = extract_args_kwargs(args, kwargs)
method = deep_getattr(self.native, self._accessor, attr)
return self.compliant._with_native(method(*pos, **kwds))
return func
class PolarsDateTimeNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
_accessor: ClassVar[NativeAccessor] = "dt"
def truncate(self, every: str) -> CompliantT:
# Ensure consistent error message is raised.
Interval.parse(every)
return self.__getattr__("truncate")(every)
def offset_by(self, by: str) -> CompliantT:
# Ensure consistent error message is raised.
Interval.parse_no_constraints(by)
return self.__getattr__("offset_by")(by)
to_string: Method[CompliantT]
replace_time_zone: Method[CompliantT]
convert_time_zone: Method[CompliantT]
timestamp: Method[CompliantT]
date: Method[CompliantT]
year: Method[CompliantT]
month: Method[CompliantT]
day: Method[CompliantT]
hour: Method[CompliantT]
minute: Method[CompliantT]
second: Method[CompliantT]
millisecond: Method[CompliantT]
microsecond: Method[CompliantT]
nanosecond: Method[CompliantT]
ordinal_day: Method[CompliantT]
weekday: Method[CompliantT]
total_minutes: Method[CompliantT]
total_seconds: Method[CompliantT]
total_milliseconds: Method[CompliantT]
total_microseconds: Method[CompliantT]
total_nanoseconds: Method[CompliantT]
class PolarsStringNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
_accessor: ClassVar[NativeAccessor] = "str"
# NOTE: Use `abstractmethod` if we have defs to implement, but also `Method` usage
@abc.abstractmethod
def zfill(self, width: int) -> CompliantT: ...
len_chars: Method[CompliantT]
replace: Method[CompliantT]
replace_all: Method[CompliantT]
strip_chars: Method[CompliantT]
starts_with: Method[CompliantT]
ends_with: Method[CompliantT]
contains: Method[CompliantT]
slice: Method[CompliantT]
split: Method[CompliantT]
to_date: Method[CompliantT]
to_datetime: Method[CompliantT]
to_lowercase: Method[CompliantT]
to_uppercase: Method[CompliantT]
class PolarsCatNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
_accessor: ClassVar[NativeAccessor] = "cat"
get_categories: Method[CompliantT]
class PolarsListNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
_accessor: ClassVar[NativeAccessor] = "list"
@abc.abstractmethod
def len(self) -> CompliantT: ...
get: Method[CompliantT]
unique: Method[CompliantT]
class PolarsStructNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
_accessor: ClassVar[NativeAccessor] = "struct"
field: Method[CompliantT]

View File

@ -0,0 +1 @@
# ! Any change to this module will trigger the pyspark and pyspark-connect tests in CI

View File

@ -0,0 +1,601 @@
from __future__ import annotations
from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any
from narwhals._exceptions import issue_warning
from narwhals._namespace import is_native_spark_like
from narwhals._spark_like.utils import (
catch_pyspark_connect_exception,
catch_pyspark_sql_exception,
evaluate_exprs,
import_functions,
import_native_dtypes,
import_window,
native_to_narwhals_dtype,
)
from narwhals._sql.dataframe import SQLLazyFrame
from narwhals._utils import (
Implementation,
ValidateBackendVersion,
generate_temporary_column_name,
not_implemented,
parse_columns_to_drop,
zip_strict,
)
from narwhals.exceptions import InvalidOperationError
if TYPE_CHECKING:
from collections.abc import Iterable, Iterator, Mapping, Sequence
from io import BytesIO
from pathlib import Path
from types import ModuleType
import pyarrow as pa
from sqlframe.base.column import Column
from sqlframe.base.dataframe import BaseDataFrame
from sqlframe.base.window import Window
from typing_extensions import Self, TypeAlias, TypeIs
from narwhals._compliant.typing import CompliantDataFrameAny
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals._spark_like.utils import SparkSession
from narwhals._typing import _EagerAllowedImpl
from narwhals._utils import Version, _LimitedContext
from narwhals.dataframe import LazyFrame
from narwhals.dtypes import DType
from narwhals.typing import JoinStrategy, LazyUniqueKeepStrategy
SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]
Incomplete: TypeAlias = Any # pragma: no cover
"""Marker for working code that fails type checking."""
class SparkLikeLazyFrame(
SQLLazyFrame["SparkLikeExpr", "SQLFrameDataFrame", "LazyFrame[SQLFrameDataFrame]"],
ValidateBackendVersion,
):
def __init__(
self,
native_dataframe: SQLFrameDataFrame,
*,
version: Version,
implementation: Implementation,
validate_backend_version: bool = False,
) -> None:
self._native_frame: SQLFrameDataFrame = native_dataframe
self._implementation = implementation
self._version = version
self._cached_schema: dict[str, DType] | None = None
self._cached_columns: list[str] | None = None
if validate_backend_version: # pragma: no cover
self._validate_backend_version()
@property
def _backend_version(self) -> tuple[int, ...]: # pragma: no cover
return self._implementation._backend_version()
@property
def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import functions
return functions
return import_functions(self._implementation)
@property
def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import types
return types
return import_native_dtypes(self._implementation)
@property
def _Window(self) -> type[Window]:
if TYPE_CHECKING:
from sqlframe.base.window import Window
return Window
return import_window(self._implementation)
@staticmethod
def _is_native(obj: SQLFrameDataFrame | Any) -> TypeIs[SQLFrameDataFrame]:
return is_native_spark_like(obj)
@classmethod
def from_native(cls, data: SQLFrameDataFrame, /, *, context: _LimitedContext) -> Self:
return cls(data, version=context._version, implementation=context._implementation)
def to_narwhals(self) -> LazyFrame[SQLFrameDataFrame]:
return self._version.lazyframe(self, level="lazy")
def __native_namespace__(self) -> ModuleType: # pragma: no cover
return self._implementation.to_native_namespace()
def __narwhals_namespace__(self) -> SparkLikeNamespace:
from narwhals._spark_like.namespace import SparkLikeNamespace
return SparkLikeNamespace(
version=self._version, implementation=self._implementation
)
def __narwhals_lazyframe__(self) -> Self:
return self
def _with_version(self, version: Version) -> Self:
return self.__class__(
self.native, version=version, implementation=self._implementation
)
def _with_native(self, df: SQLFrameDataFrame) -> Self:
return self.__class__(
df, version=self._version, implementation=self._implementation
)
def _to_arrow_schema(self) -> pa.Schema: # pragma: no cover
import pyarrow as pa # ignore-banned-import
from narwhals._arrow.utils import narwhals_to_native_dtype
schema: list[tuple[str, pa.DataType]] = []
nw_schema = self.collect_schema()
native_schema = self.native.schema
for key, value in nw_schema.items():
try:
native_dtype = narwhals_to_native_dtype(value, self._version)
except Exception as exc: # noqa: BLE001,PERF203
native_spark_dtype = native_schema[key].dataType # type: ignore[index]
# If we can't convert the type, just set it to `pa.null`, and warn.
# Avoid the warning if we're starting from PySpark's void type.
# We can avoid the check when we introduce `nw.Null` dtype.
null_type = self._native_dtypes.NullType # pyright: ignore[reportAttributeAccessIssue]
if not isinstance(native_spark_dtype, null_type):
issue_warning(
f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
UserWarning,
)
schema.append((key, pa.null()))
else:
schema.append((key, native_dtype))
return pa.schema(schema)
def _collect_to_arrow(self) -> pa.Table:
if self._implementation.is_pyspark() and self._backend_version < (4,):
import pyarrow as pa # ignore-banned-import
try:
return pa.Table.from_batches(self.native._collect_as_arrow())
except ValueError as exc:
if "at least one RecordBatch" in str(exc):
# Empty dataframe
data: dict[str, list[Any]] = {k: [] for k in self.columns}
pa_schema = self._to_arrow_schema()
return pa.Table.from_pydict(data, schema=pa_schema)
raise # pragma: no cover
elif self._implementation.is_pyspark_connect() and self._backend_version < (4,):
import pyarrow as pa # ignore-banned-import
pa_schema = self._to_arrow_schema()
return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema)
else:
return self.native.toArrow()
def _iter_columns(self) -> Iterator[Column]:
for col in self.columns:
yield self._F.col(col)
@property
def columns(self) -> list[str]:
if self._cached_columns is None:
self._cached_columns = (
list(self.schema)
if self._cached_schema is not None
else self.native.columns
)
return self._cached_columns
def _collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
if backend is Implementation.PANDAS:
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
return PandasLikeDataFrame(
self.native.toPandas(),
implementation=Implementation.PANDAS,
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is None or backend is Implementation.PYARROW:
from narwhals._arrow.dataframe import ArrowDataFrame
return ArrowDataFrame(
self._collect_to_arrow(),
validate_backend_version=True,
version=self._version,
validate_column_names=True,
)
if backend is Implementation.POLARS:
import polars as pl # ignore-banned-import
from narwhals._polars.dataframe import PolarsDataFrame
return PolarsDataFrame(
pl.from_arrow(self._collect_to_arrow()), # type: ignore[arg-type]
validate_backend_version=True,
version=self._version,
)
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
raise ValueError(msg) # pragma: no cover
def collect(
self, backend: _EagerAllowedImpl | None, **kwargs: Any
) -> CompliantDataFrameAny:
if self._implementation.is_pyspark_connect():
try:
return self._collect(backend, **kwargs)
except Exception as e: # noqa: BLE001
raise catch_pyspark_connect_exception(e) from None
return self._collect(backend, **kwargs)
def simple_select(self, *column_names: str) -> Self:
return self._with_native(self.native.select(*column_names))
def aggregate(self, *exprs: SparkLikeExpr) -> Self:
new_columns = evaluate_exprs(self, *exprs)
new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
if self._implementation.is_pyspark():
try:
return self._with_native(self.native.agg(*new_columns_list))
except Exception as e: # noqa: BLE001
raise catch_pyspark_sql_exception(e, self) from None
return self._with_native(self.native.agg(*new_columns_list))
def select(self, *exprs: SparkLikeExpr) -> Self:
new_columns = evaluate_exprs(self, *exprs)
new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
if self._implementation.is_pyspark(): # pragma: no cover
try:
return self._with_native(self.native.select(*new_columns_list))
except Exception as e: # noqa: BLE001
raise catch_pyspark_sql_exception(e, self) from None
return self._with_native(self.native.select(*new_columns_list))
def with_columns(self, *exprs: SparkLikeExpr) -> Self:
new_columns = evaluate_exprs(self, *exprs)
if self._implementation.is_pyspark(): # pragma: no cover
try:
return self._with_native(self.native.withColumns(dict(new_columns)))
except Exception as e: # noqa: BLE001
raise catch_pyspark_sql_exception(e, self) from None
return self._with_native(self.native.withColumns(dict(new_columns)))
def filter(self, predicate: SparkLikeExpr) -> Self:
# `[0]` is safe as the predicate's expression only returns a single column
condition = predicate._call(self)[0]
if self._implementation.is_pyspark():
try:
return self._with_native(self.native.where(condition))
except Exception as e: # noqa: BLE001
raise catch_pyspark_sql_exception(e, self) from None
return self._with_native(self.native.where(condition))
@property
def schema(self) -> dict[str, DType]:
if self._cached_schema is None:
self._cached_schema = {
field.name: native_to_narwhals_dtype(
field.dataType,
self._version,
self._native_dtypes,
self.native.sparkSession,
)
for field in self.native.schema
}
return self._cached_schema
def collect_schema(self) -> dict[str, DType]:
return self.schema
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
return self._with_native(self.native.drop(*columns_to_drop))
def head(self, n: int) -> Self:
return self._with_native(self.native.limit(n))
def group_by(
self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool
) -> SparkLikeLazyGroupBy:
from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
if isinstance(descending, bool):
descending = [descending] * len(by)
if nulls_last:
sort_funcs = (
self._F.desc_nulls_last if d else self._F.asc_nulls_last
for d in descending
)
else:
sort_funcs = (
self._F.desc_nulls_first if d else self._F.asc_nulls_first
for d in descending
)
sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
return self._with_native(self.native.sort(*sort_cols))
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
by = list(by)
if isinstance(reverse, bool):
reverse = [reverse] * len(by)
sort_funcs = (
self._F.desc_nulls_last if not d else self._F.asc_nulls_last for d in reverse
)
sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
return self._with_native(self.native.sort(*sort_cols).limit(k))
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
subset = list(subset) if subset else None
return self._with_native(self.native.dropna(subset=subset))
def rename(self, mapping: Mapping[str, str]) -> Self:
rename_mapping = {
colname: mapping.get(colname, colname) for colname in self.columns
}
return self._with_native(
self.native.select(
[self._F.col(old).alias(new) for old, new in rename_mapping.items()]
)
)
def unique(
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
) -> Self:
if subset and (error := self._check_columns_exist(subset)):
raise error
subset = list(subset) if subset else None
if keep == "none":
tmp = generate_temporary_column_name(8, self.columns)
window = self._Window.partitionBy(subset or self.columns)
df = (
self.native.withColumn(tmp, self._F.count("*").over(window))
.filter(self._F.col(tmp) == self._F.lit(1))
.drop(self._F.col(tmp))
)
return self._with_native(df)
return self._with_native(self.native.dropDuplicates(subset=subset))
def join(
self,
other: Self,
how: JoinStrategy,
left_on: Sequence[str] | None,
right_on: Sequence[str] | None,
suffix: str,
) -> Self:
left_columns = self.columns
right_columns = other.columns
right_on_: list[str] = list(right_on) if right_on is not None else []
left_on_: list[str] = list(left_on) if left_on is not None else []
# Create a mapping for the columns of `other`:
# `right_on` columns will be renamed as `left_on`,
# the remaining columns will either have the suffix appended or be left unchanged.
right_cols_to_rename = (
[c for c in right_columns if c not in right_on_]
if how != "full"
else right_columns
)
rename_mapping = {
**dict(zip(right_on_, left_on_)),
**{
colname: f"{colname}{suffix}" if colname in left_columns else colname
for colname in right_cols_to_rename
},
}
other_native = other.native.select(
[self._F.col(old).alias(new) for old, new in rename_mapping.items()]
)
# If `how` is in {"semi", "anti"}, the resulting columns are the same as the left columns.
# Otherwise, we add the right columns with the new mapping, while keeping the
# original order of right_columns.
col_order = left_columns.copy()
if how in {"inner", "left", "cross"}:
col_order.extend(
rename_mapping[colname]
for colname in right_columns
if colname not in right_on_
)
elif how == "full":
col_order.extend(rename_mapping.values())
right_on_remapped = [rename_mapping[c] for c in right_on_]
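# "full" joins keep the key columns from both sides, so an explicit equality condition
# is built over the renamed right keys; "cross" joins take no keys; otherwise we join
# on the shared (already renamed) column names.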
on_ = (
reduce(
and_,
(
getattr(self.native, left_key) == getattr(other_native, right_key)
for left_key, right_key in zip_strict(left_on_, right_on_remapped)
),
)
if how == "full"
else None
if how == "cross"
else left_on_
)
how_native = "full_outer" if how == "full" else how
return self._with_native(
self.native.join(other_native, on=on_, how=how_native).select(col_order)
)
def explode(self, columns: Sequence[str]) -> Self:
dtypes = self._version.dtypes
schema = self.collect_schema()
for col_to_explode in columns:
dtype = schema[col_to_explode]
if dtype != dtypes.List:
msg = (
f"`explode` operation not supported for dtype `{dtype}`, "
"expected List type"
)
raise InvalidOperationError(msg)
column_names = self.columns
if len(columns) != 1:
msg = (
"Exploding on multiple columns is not supported with SparkLike backend since "
"we cannot guarantee that the exploded columns have matching element counts."
)
raise NotImplementedError(msg)
if self._implementation.is_pyspark() or self._implementation.is_pyspark_connect():
return self._with_native(
self.native.select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.explode_outer(col_name).alias(col_name)
for col_name in column_names
]
)
)
if self._implementation.is_sqlframe():
# Not every sqlframe dialect supports the `explode_outer` function
# (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289),
# therefore we simply explode the array column, which ignores nulls and
# zero-sized arrays, and then append the rows matching those conditions
# with nulls (to match Polars' behavior).
def null_condition(col_name: str) -> Column:
return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)
return self._with_native(
self.native.select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.explode(col_name).alias(col_name)
for col_name in column_names
]
).union(
self.native.filter(null_condition(columns[0])).select(
*[
self._F.col(col_name).alias(col_name)
if col_name != columns[0]
else self._F.lit(None).alias(col_name)
for col_name in column_names
]
)
)
)
msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues" # pragma: no cover
raise AssertionError(msg)
def unpivot(
self,
on: Sequence[str] | None,
index: Sequence[str] | None,
variable_name: str,
value_name: str,
) -> Self:
if self._implementation.is_sqlframe():
if variable_name == "":
msg = "`variable_name` cannot be empty string for sqlframe backend."
raise NotImplementedError(msg)
if value_name == "":
msg = "`value_name` cannot be empty string for sqlframe backend."
raise NotImplementedError(msg)
else: # pragma: no cover
pass
ids = tuple(index) if index else ()
values = (
tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on)
)
unpivoted_native_frame = self.native.unpivot(
ids=ids,
values=values,
variableColumnName=variable_name,
valueColumnName=value_name,
)
if index is None:
unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
return self._with_native(unpivoted_native_frame)
def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
if order_by is None:
msg = "Cannot pass `order_by` to `with_row_index` for PySpark-like"
raise TypeError(msg)
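# row_number() is 1-based, so subtract 1 to produce a 0-based row index, matching Polars.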
row_index_expr = (
self._F.row_number().over(
self._Window.partitionBy(self._F.lit(1)).orderBy(*order_by)
)
- 1
).alias(name)
return self._with_native(self.native.select(row_index_expr, *self.columns))
def sink_parquet(self, file: str | Path | BytesIO) -> None:
self.native.write.parquet(file)
@classmethod
def _from_compliant_dataframe(
cls,
frame: CompliantDataFrameAny,
/,
*,
session: SparkSession,
implementation: Implementation,
version: Version,
) -> SparkLikeLazyFrame:
from importlib.util import find_spec
impl = implementation
is_spark_v4 = (not impl.is_sqlframe()) and impl._backend_version() >= (4, 0, 0)
if is_spark_v4: # pragma: no cover
# pyspark.sql requires pyarrow to be installed from v4.0.0
# and since v4.0.0 the input to `createDataFrame` can be a PyArrow Table.
data: Any = frame.to_arrow()
elif find_spec("pandas"):
data = frame.to_pandas()
else: # pragma: no cover
data = tuple(frame.iter_rows(named=True, buffer_size=512))
return cls(
session.createDataFrame(data),
version=version,
implementation=implementation,
validate_backend_version=True,
)
gather_every = not_implemented.deprecated(
"`LazyFrame.gather_every` is deprecated and will be removed in a future version."
)
join_asof = not_implemented()
tail = not_implemented.deprecated(
"`LazyFrame.tail` is deprecated and will be removed in a future version."
)

View File

@ -0,0 +1,391 @@
from __future__ import annotations
import operator
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
from narwhals._spark_like.utils import (
import_functions,
import_native_dtypes,
import_window,
narwhals_to_native_dtype,
true_divide,
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, not_implemented, zip_strict
if TYPE_CHECKING:
from collections.abc import Iterator, Mapping, Sequence
from sqlframe.base.column import Column
from sqlframe.base.window import Window, WindowSpec
from typing_extensions import Self, TypeAlias
from narwhals._compliant import WindowInputs
from narwhals._compliant.typing import (
AliasNames,
EvalNames,
EvalSeries,
WindowFunction,
)
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.namespace import SparkLikeNamespace
from narwhals._utils import _LimitedContext
from narwhals.typing import FillNullStrategy, IntoDType, NonNestedLiteral, RankMethod
NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"]
SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column]
SparkWindowInputs = WindowInputs[Column]
class SparkLikeExpr(SQLExpr["SparkLikeLazyFrame", "Column"]):
def __init__(
self,
call: EvalSeries[SparkLikeLazyFrame, Column],
window_function: SparkWindowFunction | None = None,
*,
evaluate_output_names: EvalNames[SparkLikeLazyFrame],
alias_output_names: AliasNames | None,
version: Version,
implementation: Implementation,
) -> None:
self._call = call
self._evaluate_output_names = evaluate_output_names
self._alias_output_names = alias_output_names
self._version = version
self._implementation = implementation
self._metadata: ExprMetadata | None = None
self._window_function: SparkWindowFunction | None = window_function
_REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = {
"min": "rank",
"max": "rank",
"average": "rank",
"dense": "dense_rank",
"ordinal": "row_number",
}
def _count_star(self) -> Column:
return self._F.count("*")
def _window_expression(
self,
expr: Column,
partition_by: Sequence[str | Column] = (),
order_by: Sequence[str | Column] = (),
rows_start: int | None = None,
rows_end: int | None = None,
*,
descending: Sequence[bool] | None = None,
nulls_last: Sequence[bool] | None = None,
) -> Column:
window = self.partition_by(*partition_by)
if order_by:
window = window.orderBy(
*self._sort(*order_by, descending=descending, nulls_last=nulls_last)
)
if rows_start is not None and rows_end is not None:
window = window.rowsBetween(rows_start, rows_end)
elif rows_end is not None:
window = window.rowsBetween(self._Window.unboundedPreceding, rows_end)
elif rows_start is not None: # pragma: no cover
window = window.rowsBetween(rows_start, self._Window.unboundedFollowing)
return expr.over(window)
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
if kind is ExprKind.LITERAL:
return self
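# Aggregations are re-evaluated as window functions over a single constant partition,
# so the scalar result is broadcast to every row.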
return self.over([self._F.lit(1)], [])
@property
def _F(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import functions
return functions
return import_functions(self._implementation)
@property
def _native_dtypes(self): # type: ignore[no-untyped-def] # noqa: ANN202
if TYPE_CHECKING:
from sqlframe.base import types
return types
return import_native_dtypes(self._implementation)
@property
def _Window(self) -> type[Window]:
if TYPE_CHECKING:
from sqlframe.base.window import Window
return Window
return import_window(self._implementation)
def _sort(
self,
*cols: Column | str,
descending: Sequence[bool] | None = None,
nulls_last: Sequence[bool] | None = None,
) -> Iterator[Column]:
F = self._F
descending = descending or [False] * len(cols)
nulls_last = nulls_last or [False] * len(cols)
mapping = {
(False, False): F.asc_nulls_first,
(False, True): F.asc_nulls_last,
(True, False): F.desc_nulls_first,
(True, True): F.desc_nulls_last,
}
yield from (
mapping[(_desc, _nulls_last)](col)
for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last)
)
def partition_by(self, *cols: Column | str) -> WindowSpec:
"""Wraps `Window().partitionBy`, with default and `WindowInputs` handling."""
return self._Window.partitionBy(*cols or [self._F.lit(1)])
def __narwhals_namespace__(self) -> SparkLikeNamespace: # pragma: no cover
from narwhals._spark_like.namespace import SparkLikeNamespace
return SparkLikeNamespace(
version=self._version, implementation=self._implementation
)
@classmethod
def _alias_native(cls, expr: Column, name: str) -> Column:
return expr.alias(name)
@classmethod
def from_column_names(
cls: type[Self],
evaluate_column_names: EvalNames[SparkLikeLazyFrame],
/,
*,
context: _LimitedContext,
) -> Self:
def func(df: SparkLikeLazyFrame) -> list[Column]:
return [df._F.col(col_name) for col_name in evaluate_column_names(df)]
return cls(
func,
evaluate_output_names=evaluate_column_names,
alias_output_names=None,
version=context._version,
implementation=context._implementation,
)
@classmethod
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
def func(df: SparkLikeLazyFrame) -> list[Column]:
columns = df.columns
return [df._F.col(columns[i]) for i in column_indices]
return cls(
func,
evaluate_output_names=cls._eval_names_indices(column_indices),
alias_output_names=None,
version=context._version,
implementation=context._implementation,
)
def __truediv__(self, other: SparkLikeExpr) -> Self:
def _truediv(expr: Column, other: Column) -> Column:
return true_divide(self._F, expr, other)
return self._with_binary(_truediv, other)
def __rtruediv__(self, other: SparkLikeExpr) -> Self:
def _rtruediv(expr: Column, other: Column) -> Column:
return true_divide(self._F, other, expr)
return self._with_binary(_rtruediv, other).alias("literal")
def __floordiv__(self, other: SparkLikeExpr) -> Self:
def _floordiv(expr: Column, other: Column) -> Column:
return self._F.floor(true_divide(self._F, expr, other))
return self._with_binary(_floordiv, other)
def __rfloordiv__(self, other: SparkLikeExpr) -> Self:
def _rfloordiv(expr: Column, other: Column) -> Column:
return self._F.floor(true_divide(self._F, other, expr))
return self._with_binary(_rfloordiv, other).alias("literal")
def __invert__(self) -> Self:
invert = cast("Callable[..., Column]", operator.invert)
return self._with_elementwise(invert)
def cast(self, dtype: IntoDType) -> Self:
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self(df)]
def window_f(
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
) -> Sequence[Column]:
spark_dtype = narwhals_to_native_dtype(
dtype, self._version, self._native_dtypes, df.native.sparkSession
)
return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)]
return self.__class__(
func,
window_f,
evaluate_output_names=self._evaluate_output_names,
alias_output_names=self._alias_output_names,
version=self._version,
implementation=self._implementation,
)
def median(self) -> Self:
def _median(expr: Column) -> Column:
if self._implementation in {
Implementation.PYSPARK,
Implementation.PYSPARK_CONNECT,
} and Implementation.PYSPARK._backend_version() < (3, 4): # pragma: no cover
# Use percentile_approx with default accuracy parameter (10000)
return self._F.percentile_approx(expr.cast("double"), 0.5)
return self._F.median(expr)
return self._with_callable(_median)
def null_count(self) -> Self:
def _null_count(expr: Column) -> Column:
return self._F.count_if(self._F.isnull(expr))
return self._with_callable(_null_count)
def std(self, ddof: int) -> Self:
F = self._F
if ddof == 0:
return self._with_callable(F.stddev_pop)
if ddof == 1:
return self._with_callable(F.stddev_samp)
def func(expr: Column) -> Column:
n_rows = F.count(expr)
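# Rescale the ddof=1 sample statistic: stddev_ddof = stddev_samp * sqrt((n - 1) / (n - ddof)).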
return F.stddev_samp(expr) * F.sqrt((n_rows - 1) / (n_rows - ddof))
return self._with_callable(func)
def var(self, ddof: int) -> Self:
F = self._F
if ddof == 0:
return self._with_callable(F.var_pop)
if ddof == 1:
return self._with_callable(F.var_samp)
def func(expr: Column) -> Column:
n_rows = F.count(expr)
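# Rescale the ddof=1 sample variance: var_ddof = var_samp * (n - 1) / (n - ddof).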
return F.var_samp(expr) * (n_rows - 1) / (n_rows - ddof)
return self._with_callable(func)
def is_finite(self) -> Self:
def _is_finite(expr: Column) -> Column:
# A value is finite if it's not NaN, and not infinite, while NULLs should be
# preserved
is_finite_condition = (
~self._F.isnan(expr)
& (expr != self._F.lit(float("inf")))
& (expr != self._F.lit(float("-inf")))
)
return self._F.when(~self._F.isnull(expr), is_finite_condition).otherwise(
None
)
return self._with_elementwise(_is_finite)
def is_in(self, values: Sequence[Any]) -> Self:
def _is_in(expr: Column) -> Column:
return expr.isin(values) if values else self._F.lit(False)
return self._with_elementwise(_is_in)
def len(self) -> Self:
def _len(_expr: Column) -> Column:
# Use count(*) to count all rows including nulls
return self._F.count("*")
return self._with_callable(_len)
def skew(self) -> Self:
return self._with_callable(self._F.skewness)
def kurtosis(self) -> Self:
return self._with_callable(self._F.kurtosis)
def n_unique(self) -> Self:
def _n_unique(expr: Column) -> Column:
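# count_distinct ignores nulls, so add 1 (the max of the null indicator) whenever any null is present.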
return self._F.count_distinct(expr) + self._F.max(
self._F.isnull(expr).cast(self._native_dtypes.IntegerType())
)
return self._with_callable(_n_unique)
def is_nan(self) -> Self:
def _is_nan(expr: Column) -> Column:
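# Nulls stay null rather than becoming False; only non-null values are checked for NaN.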
return self._F.when(self._F.isnull(expr), None).otherwise(self._F.isnan(expr))
return self._with_elementwise(_is_nan)
def fill_null(
self,
value: Self | NonNestedLiteral,
strategy: FillNullStrategy | None,
limit: int | None,
) -> Self:
if strategy is not None:
def _fill_with_strategy(
df: SparkLikeLazyFrame, inputs: SparkWindowInputs
) -> Sequence[Column]:
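# Forward fill takes the last non-null value over a window ending at the current row;
# backward fill takes the first non-null value over a window starting at the current row.
# `limit`, when given, bounds how far the window reaches.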
fn = self._F.last_value if strategy == "forward" else self._F.first_value
if strategy == "forward":
start = self._Window.unboundedPreceding if limit is None else -limit
end = self._Window.currentRow
else:
start = self._Window.currentRow
end = self._Window.unboundedFollowing if limit is None else limit
return [
fn(expr, ignoreNulls=True).over(
self.partition_by(*inputs.partition_by)
.orderBy(*self._sort(*inputs.order_by))
.rowsBetween(start, end)
)
for expr in self(df)
]
return self._with_window_function(_fill_with_strategy)
def _fill_constant(expr: Column, value: Column) -> Column:
return self._F.ifnull(expr, value)
return self._with_elementwise(_fill_constant, value=value)
@property
def str(self) -> SparkLikeExprStringNamespace:
return SparkLikeExprStringNamespace(self)
@property
def dt(self) -> SparkLikeExprDateTimeNamespace:
return SparkLikeExprDateTimeNamespace(self)
@property
def list(self) -> SparkLikeExprListNamespace:
return SparkLikeExprListNamespace(self)
@property
def struct(self) -> SparkLikeExprStructNamespace:
return SparkLikeExprStructNamespace(self)
quantile = not_implemented()

View File

@ -0,0 +1,192 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._constants import US_PER_SECOND
from narwhals._duration import Interval
from narwhals._spark_like.utils import (
UNITS_DICT,
fetch_session_time_zone,
strptime_to_pyspark_format,
)
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
from narwhals._utils import not_implemented
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlframe.base.column import Column
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeExprDateTimeNamespace(SQLExprDateTimeNamesSpace["SparkLikeExpr"]):
def _weekday(self, expr: Column) -> Column:
# PySpark's dayofweek returns 1-7 for Sunday-Saturday; shift it to the
# ISO convention (1 = Monday, ..., 7 = Sunday) expected by `dt.weekday`.
return (self.compliant._F.dayofweek(expr) + 5) % 7 + 1
def to_string(self, format: str) -> SparkLikeExpr:
F = self.compliant._F
def _to_string(expr: Column) -> Column:
# Handle special formats
if format == "%G-W%V":
return self._format_iso_week(expr)
if format == "%G-W%V-%u":
return self._format_iso_week_with_day(expr)
format_, suffix = self._format_microseconds(expr, format)
# Convert Python format to PySpark format
pyspark_fmt = strptime_to_pyspark_format(format_)
result = F.date_format(expr, pyspark_fmt)
if "T" in format_:
# `strptime_to_pyspark_format` replaces "T" with " " since pyspark
# does not support the literal "T" in `date_format`.
# If no other spaces are in the given format, then we can revert this
# operation, otherwise we raise an exception.
if " " not in format_:
result = F.replace(result, F.lit(" "), F.lit("T"))
else: # pragma: no cover
msg = (
"`dt.to_string` with a format that contains both spaces and "
" the literal 'T' is not supported for spark-like backends."
)
raise NotImplementedError(msg)
return F.concat(result, *suffix)
return self.compliant._with_elementwise(_to_string)
def millisecond(self) -> SparkLikeExpr:
def _millisecond(expr: Column) -> Column:
return self.compliant._F.floor(
(self.compliant._F.unix_micros(expr) % US_PER_SECOND) / 1000
)
return self.compliant._with_elementwise(_millisecond)
def microsecond(self) -> SparkLikeExpr:
def _microsecond(expr: Column) -> Column:
return self.compliant._F.unix_micros(expr) % US_PER_SECOND
return self.compliant._with_elementwise(_microsecond)
def nanosecond(self) -> SparkLikeExpr:
def _nanosecond(expr: Column) -> Column:
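# Spark timestamps carry microsecond precision, so the nanosecond component is the
# sub-second microseconds scaled by 1000.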
return (self.compliant._F.unix_micros(expr) % US_PER_SECOND) * 1000
return self.compliant._with_elementwise(_nanosecond)
def weekday(self) -> SparkLikeExpr:
return self.compliant._with_elementwise(self._weekday)
def truncate(self, every: str) -> SparkLikeExpr:
interval = Interval.parse(every)
multiple, unit = interval.multiple, interval.unit
if multiple != 1:
msg = f"Only multiple 1 is currently supported for Spark-like.\nGot {multiple!s}."
raise ValueError(msg)
if unit == "ns":
msg = "Truncating to nanoseconds is not yet supported for Spark-like."
raise NotImplementedError(msg)
format = UNITS_DICT[unit]
def _truncate(expr: Column) -> Column:
return self.compliant._F.date_trunc(format, expr)
return self.compliant._with_elementwise(_truncate)
def offset_by(self, by: str) -> SparkLikeExpr:
interval = Interval.parse_no_constraints(by)
multiple, unit = interval.multiple, interval.unit
if unit == "ns": # pragma: no cover
msg = "Offsetting by nanoseconds is not yet supported for Spark-like."
raise NotImplementedError(msg)
F = self.compliant._F
def _offset_by(expr: Column) -> Column:
# https://github.com/eakmanrq/sqlframe/issues/441
return F.timestamp_add( # pyright: ignore[reportAttributeAccessIssue]
UNITS_DICT[unit], F.lit(multiple), expr
)
return self.compliant._with_callable(_offset_by)
def _no_op_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover
def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
native_series_list = self.compliant(df)
conn_time_zone = fetch_session_time_zone(df.native.sparkSession)
if conn_time_zone != time_zone:
msg = (
"PySpark stores the time zone in the session, rather than in the "
f"data type, so changing the timezone to anything other than {conn_time_zone} "
" (the current session time zone) is not supported."
)
raise NotImplementedError(msg)
return native_series_list
return self.compliant.__class__(
func,
evaluate_output_names=self.compliant._evaluate_output_names,
alias_output_names=self.compliant._alias_output_names,
version=self.compliant._version,
implementation=self.compliant._implementation,
)
def convert_time_zone(self, time_zone: str) -> SparkLikeExpr: # pragma: no cover
return self._no_op_time_zone(time_zone)
def replace_time_zone(
self, time_zone: str | None
) -> SparkLikeExpr: # pragma: no cover
if time_zone is None:
return self.compliant._with_elementwise(
lambda expr: expr.cast("timestamp_ntz")
)
return self._no_op_time_zone(time_zone)
def _format_iso_week_with_day(self, expr: Column) -> Column:
"""Format datetime as ISO week string with day."""
F = self.compliant._F
year = F.date_format(expr, "yyyy")
week = F.lpad(F.weekofyear(expr).cast("string"), 2, "0")
day = self._weekday(expr)
return F.concat(year, F.lit("-W"), week, F.lit("-"), day.cast("string"))
def _format_iso_week(self, expr: Column) -> Column:
"""Format datetime as ISO week string."""
F = self.compliant._F
year = F.date_format(expr, "yyyy")
week = F.lpad(F.weekofyear(expr).cast("string"), 2, "0")
return F.concat(year, F.lit("-W"), week)
def _format_microseconds(
self, expr: Column, format: str
) -> tuple[str, tuple[Column, ...]]:
"""Format microseconds if present in format, else it's a no-op."""
F = self.compliant._F
suffix: tuple[Column, ...]
if format.endswith((".%f", "%.f")):
import re
micros = F.unix_micros(expr) % US_PER_SECOND
micros_str = F.lpad(micros.cast("string"), 6, "0")
suffix = (F.lit("."), micros_str)
format_ = re.sub(r"(.%|%.)f$", "", format)
return format_, suffix
return format, ()
timestamp = not_implemented()
total_seconds = not_implemented()
total_minutes = not_implemented()
total_milliseconds = not_implemented()
total_microseconds = not_implemented()
total_nanoseconds = not_implemented()

View File

@ -0,0 +1,35 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace
if TYPE_CHECKING:
from sqlframe.base.column import Column
from narwhals._spark_like.expr import SparkLikeExpr
from narwhals.typing import NonNestedLiteral
class SparkLikeExprListNamespace(
LazyExprNamespace["SparkLikeExpr"], ListNamespace["SparkLikeExpr"]
):
def len(self) -> SparkLikeExpr:
return self.compliant._with_elementwise(self.compliant._F.array_size)
def unique(self) -> SparkLikeExpr:
return self.compliant._with_elementwise(self.compliant._F.array_distinct)
def contains(self, item: NonNestedLiteral) -> SparkLikeExpr:
def func(expr: Column) -> Column:
F = self.compliant._F
return F.array_contains(expr, F.lit(item))
return self.compliant._with_elementwise(func)
def get(self, index: int) -> SparkLikeExpr:
def _get(expr: Column) -> Column:
return expr.getItem(index)
return self.compliant._with_elementwise(_get)

View File

@ -0,0 +1,36 @@
from __future__ import annotations
from functools import partial
from typing import TYPE_CHECKING
from narwhals._spark_like.utils import strptime_to_pyspark_format
from narwhals._sql.expr_str import SQLExprStringNamespace
from narwhals._utils import _is_naive_format, not_implemented
if TYPE_CHECKING:
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeExprStringNamespace(SQLExprStringNamespace["SparkLikeExpr"]):
def to_datetime(self, format: str | None) -> SparkLikeExpr:
F = self.compliant._F
if not format:
function = F.to_timestamp
elif _is_naive_format(format):
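# Naive formats (as detected by `_is_naive_format`) are parsed with `to_timestamp_ntz`,
# keeping the result timezone-naive.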
function = partial(
F.to_timestamp_ntz, format=F.lit(strptime_to_pyspark_format(format))
)
else:
format = strptime_to_pyspark_format(format)
function = partial(F.to_timestamp, format=format)
return self.compliant._with_elementwise(
lambda expr: function(F.replace(expr, F.lit("T"), F.lit(" ")))
)
def to_date(self, format: str | None) -> SparkLikeExpr:
F = self.compliant._F
return self.compliant._with_elementwise(
lambda expr: F.to_date(expr, format=strptime_to_pyspark_format(format))
)
replace = not_implemented()

View File

@ -0,0 +1,21 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StructNamespace
if TYPE_CHECKING:
from sqlframe.base.column import Column
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeExprStructNamespace(
LazyExprNamespace["SparkLikeExpr"], StructNamespace["SparkLikeExpr"]
):
def field(self, name: str) -> SparkLikeExpr:
def func(expr: Column) -> Column:
return expr.getField(name)
return self.compliant._with_elementwise(func).alias(name)

View File

@ -0,0 +1,37 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from narwhals._sql.group_by import SQLGroupBy
if TYPE_CHECKING:
from collections.abc import Sequence
from sqlframe.base.column import Column # noqa: F401
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
from narwhals._spark_like.expr import SparkLikeExpr
class SparkLikeLazyGroupBy(SQLGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]):
def __init__(
self,
df: SparkLikeLazyFrame,
keys: Sequence[SparkLikeExpr] | Sequence[str],
/,
*,
drop_null_keys: bool,
) -> None:
frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
def agg(self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
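# With no aggregation expressions, the result is simply the distinct combinations of the key columns.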
result = (
self.compliant.native.groupBy(*self._keys).agg(*agg_columns)
if (agg_columns := list(self._evaluate_exprs(exprs)))
else self.compliant.native.select(*self._keys).dropDuplicates()
)
return self.compliant._with_native(result).rename(
dict(zip(self._keys, self._output_key_names))
)

Some files were not shown because too many files have changed in this diff.