lib/python3.11/site-packages/narwhals/__init__.py (new file, 185 lines)
@@ -0,0 +1,185 @@
from __future__ import annotations

import typing as _t

from narwhals import dependencies, dtypes, exceptions, selectors
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    is_ordered_categorical,
    maybe_align_index,
    maybe_convert_dtypes,
    maybe_get_index,
    maybe_reset_index,
    maybe_set_index,
)
from narwhals.dataframe import DataFrame, LazyFrame
from narwhals.dtypes import (
    Array,
    Binary,
    Boolean,
    Categorical,
    Date,
    Datetime,
    Decimal,
    Duration,
    Enum,
    Field,
    Float32,
    Float64,
    Int8,
    Int16,
    Int32,
    Int64,
    Int128,
    List,
    Object,
    String,
    Struct,
    Time,
    UInt8,
    UInt16,
    UInt32,
    UInt64,
    UInt128,
    Unknown,
)
from narwhals.expr import Expr
from narwhals.functions import (
    all_ as all,
    all_horizontal,
    any_horizontal,
    coalesce,
    col,
    concat,
    concat_str,
    exclude,
    from_arrow,
    from_dict,
    from_numpy,
    len_ as len,
    lit,
    max,
    max_horizontal,
    mean,
    mean_horizontal,
    median,
    min,
    min_horizontal,
    new_series,
    nth,
    read_csv,
    read_parquet,
    scan_csv,
    scan_parquet,
    show_versions,
    sum,
    sum_horizontal,
    when,
)
from narwhals.schema import Schema
from narwhals.series import Series
from narwhals.translate import (
    from_native,
    get_native_namespace,
    narwhalify,
    to_native,
    to_py_scalar,
)

__version__: str

__all__ = [
    "Array",
    "Binary",
    "Boolean",
    "Categorical",
    "DataFrame",
    "Date",
    "Datetime",
    "Decimal",
    "Duration",
    "Enum",
    "Expr",
    "Field",
    "Float32",
    "Float64",
    "Implementation",
    "Int8",
    "Int16",
    "Int32",
    "Int64",
    "Int128",
    "LazyFrame",
    "List",
    "Object",
    "Schema",
    "Series",
    "String",
    "Struct",
    "Time",
    "UInt8",
    "UInt16",
    "UInt32",
    "UInt64",
    "UInt128",
    "Unknown",
    "all",
    "all_horizontal",
    "any_horizontal",
    "coalesce",
    "col",
    "concat",
    "concat_str",
    "dependencies",
    "dtypes",
    "exceptions",
    "exclude",
    "from_arrow",
    "from_dict",
    "from_native",
    "from_numpy",
    "generate_temporary_column_name",
    "get_native_namespace",
    "is_ordered_categorical",
    "len",
    "lit",
    "max",
    "max_horizontal",
    "maybe_align_index",
    "maybe_convert_dtypes",
    "maybe_get_index",
    "maybe_reset_index",
    "maybe_set_index",
    "mean",
    "mean_horizontal",
    "median",
    "min",
    "min_horizontal",
    "narwhalify",
    "new_series",
    "nth",
    "read_csv",
    "read_parquet",
    "scan_csv",
    "scan_parquet",
    "selectors",
    "show_versions",
    "sum",
    "sum_horizontal",
    "to_native",
    "to_py_scalar",
    "when",
]


def __getattr__(name: _t.Literal["__version__"]) -> str:  # type: ignore[misc]
    if name == "__version__":
        global __version__  # noqa: PLW0603

        from importlib import metadata

        __version__ = metadata.version(__name__)
        return __version__
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)
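Note how `__version__` above is resolved lazily through the module-level `__getattr__` hook (PEP 562): importing narwhals skips the `importlib.metadata` lookup until someone actually reads `narwhals.__version__`, and the `global` assignment caches the result so the hook fires at most once. A minimal self-contained sketch of the same pattern; `my_installed_dist` is a hypothetical distribution name, not taken from narwhals:

from importlib import metadata

def __getattr__(name: str) -> str:
    if name == "__version__":
        # Cache in the module namespace so later lookups bypass this hook.
        version = metadata.version("my_installed_dist")  # hypothetical dist name
        globals()["__version__"] = version
        return version
    msg = f"module {__name__!r} has no attribute {name!r}"
    raise AttributeError(msg)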
lib/python3.11/site-packages/narwhals/_arrow/dataframe.py (new file, 792 lines)
@@ -0,0 +1,792 @@
from __future__ import annotations

from collections.abc import Collection, Iterator, Mapping, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import native_to_narwhals_dtype
from narwhals._compliant import EagerDataFrame
from narwhals._expression_parsing import ExprKind
from narwhals._utils import (
    Implementation,
    Version,
    check_column_names_are_unique,
    convert_str_slice_to_int_slice,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    scale_bytes,
    supports_arrow_c_stream,
    zip_strict,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ShapeError

if TYPE_CHECKING:
    from collections.abc import Iterable
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import pandas as pd
    import polars as pl
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.group_by import ArrowGroupBy
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        ChunkedArrayAny,
        Mask,
        Order,
    )
    from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
    from narwhals._spark_like.utils import SparkSession
    from narwhals._translate import IntoArrowTable
    from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.typing import (
        IntoSchema,
        JoinStrategy,
        SizedMultiIndexSelector,
        SizedMultiNameSelector,
        SizeUnit,
        UniqueKeepStrategy,
        _1DArray,
        _2DArray,
        _SliceIndex,
        _SliceName,
    )

JoinType: TypeAlias = Literal[
    "left semi",
    "right semi",
    "left anti",
    "right anti",
    "inner",
    "left outer",
    "right outer",
    "full outer",
]
PromoteOptions: TypeAlias = Literal["none", "default", "permissive"]


class ArrowDataFrame(
    EagerDataFrame["ArrowSeries", "ArrowExpr", "pa.Table", "ChunkedArrayAny"]
):
    _implementation = Implementation.PYARROW

    def __init__(
        self,
        native_dataframe: pa.Table,
        *,
        version: Version,
        validate_column_names: bool,
        validate_backend_version: bool = False,
    ) -> None:
        if validate_column_names:
            check_column_names_are_unique(native_dataframe.column_names)
        if validate_backend_version:
            self._validate_backend_version()
        self._native_frame = native_dataframe
        self._version = version

    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
        backend_version = context._implementation._backend_version()
        if cls._is_native(data):
            native = data
        elif backend_version >= (14,) or isinstance(data, Collection):
            native = pa.table(data)
        elif supports_arrow_c_stream(data):  # pragma: no cover
            msg = f"'pyarrow>=14.0.0' is required for `from_arrow` for object of type {type(data).__name__!r}."
            raise ModuleNotFoundError(msg)
        else:  # pragma: no cover
            msg = f"`from_arrow` is not supported for object of type {type(data).__name__!r}."
            raise TypeError(msg)
        return cls.from_native(native, context=context)

    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _LimitedContext,
        schema: IntoSchema | None,
    ) -> Self:
        from narwhals.schema import Schema

        pa_schema = Schema(schema).to_arrow() if schema is not None else schema
        if pa_schema and not data:
            native = pa_schema.empty_table()
        else:
            native = pa.Table.from_pydict(data, schema=pa_schema)
        return cls.from_native(native, context=context)

    @staticmethod
    def _is_native(obj: pa.Table | Any) -> TypeIs[pa.Table]:
        return isinstance(obj, pa.Table)

    @classmethod
    def from_native(cls, data: pa.Table, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version, validate_column_names=True)

    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _LimitedContext,
        schema: IntoSchema | Sequence[str] | None,
    ) -> Self:
        from narwhals.schema import Schema

        arrays = [pa.array(val) for val in data.T]
        if isinstance(schema, (Mapping, Schema)):
            native = pa.Table.from_arrays(arrays, schema=Schema(schema).to_arrow())
        else:
            native = pa.Table.from_arrays(arrays, cls._numpy_column_names(data, schema))
        return cls.from_native(native, context=context)
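`from_dict` above defers to `pa.Table.from_pydict`, falling back to `Schema.empty_table()` when a schema is supplied with no data. A quick pyarrow-only illustration of those two code paths:

import pyarrow as pa

# Plain dict -> table, letting pyarrow infer column types.
tbl = pa.Table.from_pydict({"a": [1, 2], "b": ["x", "y"]})

# Schema with no data -> zero-row table with the declared columns.
schema = pa.schema([("a", pa.int64()), ("b", pa.string())])
empty = schema.empty_table()
assert empty.num_rows == 0 and empty.column_names == ["a", "b"]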
    def __narwhals_namespace__(self) -> ArrowNamespace:
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(version=self._version)

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.PYARROW:
            return self._implementation.to_native_namespace()

        msg = f"Expected pyarrow, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_dataframe__(self) -> Self:
        return self

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version, validate_column_names=False)

    def _with_native(self, df: pa.Table, *, validate_column_names: bool = True) -> Self:
        return self.__class__(
            df, version=self._version, validate_column_names=validate_column_names
        )

    @property
    def shape(self) -> tuple[int, int]:
        return self.native.shape

    def __len__(self) -> int:
        return len(self.native)

    def row(self, index: int) -> tuple[Any, ...]:
        return tuple(col[index] for col in self.native.itercolumns())

    @overload
    def rows(self, *, named: Literal[True]) -> list[dict[str, Any]]: ...

    @overload
    def rows(self, *, named: Literal[False]) -> list[tuple[Any, ...]]: ...

    @overload
    def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]: ...

    def rows(self, *, named: bool) -> list[tuple[Any, ...]] | list[dict[str, Any]]:
        if not named:
            return list(self.iter_rows(named=False, buffer_size=512))  # type: ignore[return-value]
        return self.native.to_pylist()

    def iter_columns(self) -> Iterator[ArrowSeries]:
        for name, series in zip_strict(self.columns, self.native.itercolumns()):
            yield ArrowSeries.from_native(series, context=self, name=name)

    _iter_columns = iter_columns

    def iter_rows(
        self, *, named: bool, buffer_size: int
    ) -> Iterator[tuple[Any, ...]] | Iterator[dict[str, Any]]:
        df = self.native
        num_rows = df.num_rows

        if not named:
            for i in range(0, num_rows, buffer_size):
                rows = df[i : i + buffer_size].to_pydict().values()
                yield from zip_strict(*rows)
        else:
            for i in range(0, num_rows, buffer_size):
                yield from df[i : i + buffer_size].to_pylist()
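`iter_rows` materializes the table in `buffer_size`-row slices rather than converting everything to Python objects at once; slicing a `pa.Table` is zero-copy, so only the current buffer pays the conversion cost. The same buffering idea in plain pyarrow, using the explicit `slice` method:

import pyarrow as pa

tbl = pa.table({"a": list(range(10)), "b": [str(i) for i in range(10)]})

def iter_rows_buffered(tbl: pa.Table, buffer_size: int = 4):
    # tbl.slice is zero-copy; only each buffer is converted to Python objects.
    for i in range(0, tbl.num_rows, buffer_size):
        yield from tbl.slice(i, buffer_size).to_pylist()

assert len(list(iter_rows_buffered(tbl))) == 10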
    def get_column(self, name: str) -> ArrowSeries:
        if not isinstance(name, str):
            msg = f"Expected str, got: {type(name)}"
            raise TypeError(msg)
        return ArrowSeries.from_native(self.native[name], context=self, name=name)

    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray:
        return self.native.__array__(dtype, copy=copy)

    def _gather(self, rows: SizedMultiIndexSelector[ChunkedArrayAny]) -> Self:
        if len(rows) == 0:
            return self._with_native(self.native.slice(0, 0))
        if self._backend_version < (18,) and isinstance(rows, tuple):
            rows = list(rows)
        return self._with_native(self.native.take(rows))

    def _gather_slice(self, rows: _SliceIndex | range) -> Self:
        start = rows.start or 0
        stop = rows.stop if rows.stop is not None else len(self.native)
        if start < 0:
            start = len(self.native) + start
        if stop < 0:
            stop = len(self.native) + stop
        if rows.step is not None and rows.step != 1:
            msg = "Slicing with step is not supported on PyArrow tables"
            raise NotImplementedError(msg)
        return self._with_native(self.native.slice(start, stop - start))

    def _select_slice_name(self, columns: _SliceName) -> Self:
        start, stop, step = convert_str_slice_to_int_slice(columns, self.columns)
        return self._with_native(self.native.select(self.columns[start:stop:step]))

    def _select_slice_index(self, columns: _SliceIndex | range) -> Self:
        return self._with_native(
            self.native.select(self.columns[columns.start : columns.stop : columns.step])
        )

    def _select_multi_index(
        self, columns: SizedMultiIndexSelector[ChunkedArrayAny]
    ) -> Self:
        selector: Sequence[int]
        if isinstance(columns, pa.ChunkedArray):
            # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
            selector = cast("Sequence[int]", columns.to_pylist())
        # TODO @dangotbanned: Fix upstream, it is actually much narrower
        # **Doesn't accept `ndarray`**
        elif is_numpy_array_1d(columns):
            selector = columns.tolist()
        else:
            selector = columns
        return self._with_native(self.native.select(selector))

    def _select_multi_name(
        self, columns: SizedMultiNameSelector[ChunkedArrayAny]
    ) -> Self:
        selector: Sequence[str] | _1DArray
        if isinstance(columns, pa.ChunkedArray):
            # TODO @dangotbanned: Fix upstream with `pa.ChunkedArray.to_pylist(self) -> list[Any]:`
            selector = cast("Sequence[str]", columns.to_pylist())
        else:
            selector = columns
        # NOTE: Fixed in https://github.com/zen-xu/pyarrow-stubs/pull/221
        return self._with_native(self.native.select(selector))  # pyright: ignore[reportArgumentType]

    @property
    def schema(self) -> dict[str, DType]:
        return {
            field.name: native_to_narwhals_dtype(field.type, self._version)
            for field in self.native.schema
        }

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def estimated_size(self, unit: SizeUnit) -> int | float:
        sz = self.native.nbytes
        return scale_bytes(sz, unit)

    explode = not_implemented()

    @property
    def columns(self) -> list[str]:
        return self.native.column_names

    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(
            self.native.select(list(column_names)), validate_column_names=False
        )

    def select(self, *exprs: ArrowExpr) -> Self:
        new_series = self._evaluate_into_exprs(*exprs)
        if not new_series:
            # return empty dataframe, like Polars does
            return self._with_native(
                self.native.__class__.from_arrays([]), validate_column_names=False
            )
        names = [s.name for s in new_series]
        align = new_series[0]._align_full_broadcast
        reshaped = align(*new_series)
        df = pa.Table.from_arrays([s.native for s in reshaped], names=names)
        return self._with_native(df, validate_column_names=True)

    def _extract_comparand(self, other: ArrowSeries) -> ChunkedArrayAny:
        length = len(self)
        if not other._broadcast:
            if (len_other := len(other)) != length:
                msg = f"Expected object of length {length}, got: {len_other}."
                raise ShapeError(msg)
            return other.native

        value = other.native[0]
        return pa.chunked_array([pa.repeat(value, length)])

    def with_columns(self, *exprs: ArrowExpr) -> Self:
        # NOTE: We use a faux-mutable variable and repeatedly "overwrite" (native_frame)
        # All `pyarrow` data is immutable, so this is fine
        native_frame = self.native
        new_columns = self._evaluate_into_exprs(*exprs)
        columns = self.columns

        for col_value in new_columns:
            col_name = col_value.name
            column = self._extract_comparand(col_value)
            native_frame = (
                native_frame.set_column(columns.index(col_name), col_name, column=column)
                if col_name in columns
                else native_frame.append_column(col_name, column=column)
            )

        return self._with_native(native_frame, validate_column_names=False)

    def group_by(
        self, keys: Sequence[str] | Sequence[ArrowExpr], *, drop_null_keys: bool
    ) -> ArrowGroupBy:
        from narwhals._arrow.group_by import ArrowGroupBy

        return ArrowGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        how_to_join_map: dict[str, JoinType] = {
            "anti": "left anti",
            "semi": "left semi",
            "inner": "inner",
            "left": "left outer",
            "full": "full outer",
        }

        if how == "cross":
            plx = self.__narwhals_namespace__()
            key_token = generate_temporary_column_name(
                n_bytes=8, columns=[*self.columns, *other.columns]
            )

            return self._with_native(
                self.with_columns(
                    plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
                )
                .native.join(
                    other.with_columns(
                        plx.lit(0, None).alias(key_token).broadcast(ExprKind.LITERAL)
                    ).native,
                    keys=key_token,
                    right_keys=key_token,
                    join_type="inner",
                    right_suffix=suffix,
                )
                .drop([key_token])
            )

        coalesce_keys = how != "full"  # polars full join does not coalesce keys
        return self._with_native(
            self.native.join(
                other.native,
                keys=left_on or [],  # type: ignore[arg-type]
                right_keys=right_on,  # type: ignore[arg-type]
                join_type=how_to_join_map[how],
                right_suffix=suffix,
                coalesce_keys=coalesce_keys,
            )
        )
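PyArrow has no native cross join, so `join` above emulates one: append the same constant key column to both tables, inner-join on it, then drop it. The same trick in plain pyarrow (the `__cross_key__` name here is illustrative; narwhals generates a random temporary name):

import pyarrow as pa

left = pa.table({"a": [1, 2]})
right = pa.table({"b": ["x", "y", "z"]})

# Constant join key on both sides -> inner join = Cartesian product.
key = "__cross_key__"
left_k = left.append_column(key, pa.array([0] * left.num_rows))
right_k = right.append_column(key, pa.array([0] * right.num_rows))
crossed = left_k.join(right_k, keys=key, join_type="inner").drop([key])
assert crossed.num_rows == left.num_rows * right.num_rows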
    join_asof = not_implemented()

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(to_drop), validate_column_names=False)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.drop_null(), validate_column_names=False)
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            order: Order = "descending" if descending else "ascending"
            sorting: list[tuple[str, Order]] = [(key, order) for key in by]
        else:
            sorting = [
                (key, "descending" if is_descending else "ascending")
                for key, is_descending in zip_strict(by, descending)
            ]

        null_placement = "at_end" if nulls_last else "at_start"

        return self._with_native(
            self.native.sort_by(sorting, null_placement=null_placement),
            validate_column_names=False,
        )

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        if isinstance(reverse, bool):
            order: Order = "ascending" if reverse else "descending"
            sorting: list[tuple[str, Order]] = [(key, order) for key in by]
        else:
            sorting = [
                (key, "ascending" if is_ascending else "descending")
                for key, is_ascending in zip_strict(by, reverse)
            ]
        return self._with_native(
            self.native.take(pc.select_k_unstable(self.native, k, sorting)),  # type: ignore[call-overload]
            validate_column_names=False,
        )

    def to_pandas(self) -> pd.DataFrame:
        return self.native.to_pandas()

    def to_polars(self) -> pl.DataFrame:
        import polars as pl  # ignore-banned-import

        return pl.from_arrow(self.native)  # type: ignore[return-value]

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
        import numpy as np  # ignore-banned-import

        arr: Any = np.column_stack([col.to_numpy() for col in self.native.columns])
        return arr

    @overload
    def to_dict(self, *, as_series: Literal[True]) -> dict[str, ArrowSeries]: ...

    @overload
    def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...

    def to_dict(
        self, *, as_series: bool
    ) -> dict[str, ArrowSeries] | dict[str, list[Any]]:
        it = self.iter_columns()
        if as_series:
            return {ser.name: ser for ser in it}
        return {ser.name: ser.to_list() for ser in it}

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        plx = self.__narwhals_namespace__()
        if order_by is None:
            import numpy as np  # ignore-banned-import

            data = pa.array(np.arange(len(self), dtype=np.int64))
            row_index = plx._expr._from_series(
                plx._series.from_iterable(data, context=self, name=name)
            )
        else:
            rank = plx.col(order_by[0]).rank("ordinal", descending=False)
            row_index = (rank.over(partition_by=[], order_by=order_by) - 1).alias(name)
        return self.select(row_index, plx.all())

    def filter(self, predicate: ArrowExpr | list[bool | None]) -> Self:
        if isinstance(predicate, list):
            mask_native: Mask | ChunkedArrayAny = predicate
        else:
            # `[0]` is safe as the predicate's expression only returns a single column
            mask_native = self._evaluate_into_exprs(predicate)[0].native
        return self._with_native(
            self.native.filter(mask_native), validate_column_names=False
        )

    def head(self, n: int) -> Self:
        df = self.native
        if n >= 0:
            return self._with_native(df.slice(0, n), validate_column_names=False)
        num_rows = df.num_rows
        return self._with_native(
            df.slice(0, max(0, num_rows + n)), validate_column_names=False
        )

    def tail(self, n: int) -> Self:
        df = self.native
        if n >= 0:
            num_rows = df.num_rows
            return self._with_native(
                df.slice(max(0, num_rows - n)), validate_column_names=False
            )
        return self._with_native(df.slice(abs(n)), validate_column_names=False)

    def lazy(
        self,
        backend: _LazyAllowedImpl | None = None,
        *,
        session: SparkSession | None = None,
    ) -> CompliantLazyFrameAny:
        if backend is None:
            return self
        if backend is Implementation.DUCKDB:
            import duckdb  # ignore-banned-import

            from narwhals._duckdb.dataframe import DuckDBLazyFrame

            _df = self.native
            return DuckDBLazyFrame(
                duckdb.table("_df"), validate_backend_version=True, version=self._version
            )
        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsLazyFrame

            return PolarsLazyFrame(
                cast("pl.DataFrame", pl.from_arrow(self.native)).lazy(),
                validate_backend_version=True,
                version=self._version,
            )
        if backend is Implementation.DASK:
            import dask.dataframe as dd  # ignore-banned-import

            from narwhals._dask.dataframe import DaskLazyFrame

            return DaskLazyFrame(
                dd.from_pandas(self.native.to_pandas()),
                validate_backend_version=True,
                version=self._version,
            )
        if backend is Implementation.IBIS:
            import ibis  # ignore-banned-import

            from narwhals._ibis.dataframe import IbisLazyFrame

            return IbisLazyFrame(
                ibis.memtable(self.native, columns=self.columns),
                validate_backend_version=True,
                version=self._version,
            )

        if backend.is_spark_like():
            from narwhals._spark_like.dataframe import SparkLikeLazyFrame

            if session is None:
                msg = "Spark like backends require `session` to be not None."
                raise ValueError(msg)

            return SparkLikeLazyFrame._from_compliant_dataframe(
                self, session=session, implementation=backend, version=self._version
            )

        raise AssertionError  # pragma: no cover

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if backend is Implementation.PYARROW or backend is None:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self.native, version=self._version, validate_column_names=False
            )

        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.to_pandas(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=False,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                cast("pl.DataFrame", pl.from_arrow(self.native)),
                validate_backend_version=True,
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise AssertionError(msg)  # pragma: no cover

    def clone(self) -> Self:
        return self._with_native(self.native, validate_column_names=False)

    def item(self, row: int | None, column: int | str | None) -> Any:
        from narwhals._arrow.series import maybe_extract_py_scalar

        if row is None and column is None:
            if self.shape != (1, 1):
                msg = (
                    "can only call `.item()` if the dataframe is of shape (1, 1),"
                    " or if explicit row/col values are provided;"
                    f" frame has shape {self.shape!r}"
                )
                raise ValueError(msg)
            return maybe_extract_py_scalar(self.native[0][0], return_py_scalar=True)

        if row is None or column is None:
            msg = "cannot call `.item()` with only one of `row` or `column`"
            raise ValueError(msg)

        _col = self.columns.index(column) if isinstance(column, str) else column
        return maybe_extract_py_scalar(self.native[_col][row], return_py_scalar=True)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        names: dict[str, str] | list[str]
        if self._backend_version >= (17,):
            names = cast("dict[str, str]", mapping)
        else:  # pragma: no cover
            names = [mapping.get(c, c) for c in self.columns]
        return self._with_native(self.native.rename_columns(names))

    def write_parquet(self, file: str | Path | BytesIO) -> None:
        import pyarrow.parquet as pp

        pp.write_table(self.native, file)

    @overload
    def write_csv(self, file: None) -> str: ...

    @overload
    def write_csv(self, file: str | Path | BytesIO) -> None: ...

    def write_csv(self, file: str | Path | BytesIO | None) -> str | None:
        import pyarrow.csv as pa_csv

        if file is None:
            csv_buffer = pa.BufferOutputStream()
            pa_csv.write_csv(self.native, csv_buffer)
            return csv_buffer.getvalue().to_pybytes().decode()
        pa_csv.write_csv(self.native, file)
        return None

    def is_unique(self) -> ArrowSeries:
        import numpy as np  # ignore-banned-import

        col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        row_index = pa.array(np.arange(len(self)))
        keep_idx = (
            self.native.append_column(col_token, row_index)
            .group_by(self.columns)
            .aggregate([(col_token, "min"), (col_token, "max")])
        )
        native = pa.chunked_array(
            pc.and_(
                pc.is_in(row_index, keep_idx[f"{col_token}_min"]),
                pc.is_in(row_index, keep_idx[f"{col_token}_max"]),
            )
        )
        return ArrowSeries.from_native(native, context=self)
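`is_unique` marks a row as unique exactly when it is both the first and the last occurrence of its value combination: it appends a row index, aggregates the `min` and `max` of that index per group, and keeps rows whose index appears in both result sets. The core idea in plain pyarrow (`_idx` is an illustrative name):

import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"a": [1, 1, 2, 3, 3, 3]})
idx = pa.array(range(tbl.num_rows))

agg = (
    tbl.append_column("_idx", idx)
    .group_by(["a"])
    .aggregate([("_idx", "min"), ("_idx", "max")])
)
# A row is unique iff its index is both the group's min and max.
unique = pc.and_(
    pc.is_in(idx, value_set=agg["_idx_min"]),
    pc.is_in(idx, value_set=agg["_idx_max"]),
)
assert unique.to_pylist() == [False, False, True, False, False, False]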
    def unique(
        self,
        subset: Sequence[str] | None,
        *,
        keep: UniqueKeepStrategy,
        maintain_order: bool | None = None,
    ) -> Self:
        # The param `maintain_order` is only here for compatibility with the Polars API
        # and has no effect on the output.
        import numpy as np  # ignore-banned-import

        if subset and (error := self._check_columns_exist(subset)):
            raise error
        subset = list(subset or self.columns)

        if keep in {"any", "first", "last"}:
            from narwhals._arrow.group_by import ArrowGroupBy

            agg_func = ArrowGroupBy._REMAP_UNIQUE[keep]
            col_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
            keep_idx_native = (
                self.native.append_column(col_token, pa.array(np.arange(len(self))))
                .group_by(subset)
                .aggregate([(col_token, agg_func)])
                .column(f"{col_token}_{agg_func}")
            )
            return self._with_native(
                self.native.take(keep_idx_native), validate_column_names=False
            )

        keep_idx = self.simple_select(*subset).is_unique()
        plx = self.__narwhals_namespace__()
        return self.filter(plx._expr._from_series(keep_idx))

    def gather_every(self, n: int, offset: int) -> Self:
        return self._with_native(self.native[offset::n], validate_column_names=False)

    def to_arrow(self) -> pa.Table:
        return self.native

    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self:
        import numpy as np  # ignore-banned-import

        num_rows = len(self)
        if n is None and fraction is not None:
            n = int(num_rows * fraction)
        rng = np.random.default_rng(seed=seed)
        idx = np.arange(num_rows)
        mask = rng.choice(idx, size=n, replace=with_replacement)
        return self._with_native(self.native.take(mask), validate_column_names=False)

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        n_rows = len(self)
        index_ = [] if index is None else index
        on_ = [c for c in self.columns if c not in index_] if on is None else on
        concat = (
            partial(pa.concat_tables, promote_options="permissive")
            if self._backend_version >= (14, 0, 0)
            else pa.concat_tables
        )
        names = [*index_, variable_name, value_name]
        return self._with_native(
            concat(
                [
                    pa.Table.from_arrays(
                        [
                            *(self.native.column(idx_col) for idx_col in index_),
                            cast(
                                "ChunkedArrayAny",
                                pa.array([on_col] * n_rows, pa.string()),
                            ),
                            self.native.column(on_col),
                        ],
                        names=names,
                    )
                    for on_col in on_
                ]
            )
        )
        # TODO(Unassigned): Even with promote_options="permissive", pyarrow does not
        # upcast numeric to non-numeric (e.g. string) datatypes

    pivot = not_implemented()
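`unpivot` builds one small table per value column (the index columns, a constant string column naming the source column, and the values) and vertically concatenates them, relying on `promote_options="permissive"` to reconcile compatible value dtypes. A pyarrow-only sketch of that construction:

import pyarrow as pa

tbl = pa.table({"id": [1, 2], "x": [10, 20], "y": [30, 40]})
pieces = [
    pa.Table.from_arrays(
        [tbl["id"], pa.array([col] * tbl.num_rows), tbl[col]],
        names=["id", "variable", "value"],
    )
    for col in ("x", "y")
]
# "permissive" lets pyarrow unify compatible value dtypes across pieces.
long = pa.concat_tables(pieces, promote_options="permissive")
assert long.column_names == ["id", "variable", "value"] and long.num_rows == 4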
lib/python3.11/site-packages/narwhals/_arrow/expr.py (new file, 170 lines)
@@ -0,0 +1,170 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import pyarrow.compute as pc

from narwhals._arrow.series import ArrowSeries
from narwhals._compliant import EagerExpr
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.namespace import ArrowNamespace
    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._expression_parsing import ExprMetadata
    from narwhals._utils import Version, _LimitedContext


class ArrowExpr(EagerExpr["ArrowDataFrame", ArrowSeries]):
    _implementation: Implementation = Implementation.PYARROW

    def __init__(
        self,
        call: EvalSeries[ArrowDataFrame, ArrowSeries],
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[ArrowDataFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
        implementation: Implementation | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        self._metadata: ExprMetadata | None = None
    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[ArrowDataFrame],
        /,
        *,
        context: _LimitedContext,
        function_name: str = "",
    ) -> Self:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            try:
                return [
                    ArrowSeries(
                        df.native[column_name], name=column_name, version=df._version
                    )
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            tbl = df.native
            cols = df.columns
            return [
                ArrowSeries.from_native(tbl[i], name=cols[i], context=df)
                for i in column_indices
            ]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def __narwhals_namespace__(self) -> ArrowNamespace:
        from narwhals._arrow.namespace import ArrowNamespace

        return ArrowNamespace(version=self._version)

    def _reuse_series_extra_kwargs(
        self, *, returns_scalar: bool = False
    ) -> dict[str, Any]:
        return {"_return_py_scalar": False} if returns_scalar else {}

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        if (
            partition_by
            and self._metadata is not None
            and not self._metadata.is_scalar_like
        ):
            msg = "Only aggregation or literal operations are supported in grouped `over` context for PyArrow."
            raise NotImplementedError(msg)

        if not partition_by:
            # e.g. `nw.col('a').cum_sum().order_by(key)`
            # which we can always easily support, as it doesn't require grouping.
            assert order_by  # noqa: S101

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                token = generate_temporary_column_name(8, df.columns)
                df = df.with_row_index(token, order_by=None).sort(
                    *order_by, descending=False, nulls_last=False
                )
                result = self(df.drop([token], strict=True))
                # TODO(marco): is there a way to do this efficiently without
                # doing 2 sorts? Here we're sorting the dataframe and then
                # again calling `sort_indices`. `ArrowSeries.scatter` would also sort.
                sorting_indices = pc.sort_indices(df.get_column(token).native)
                return [s._with_native(s.native.take(sorting_indices)) for s in result]
        else:

            def func(df: ArrowDataFrame) -> Sequence[ArrowSeries]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if overlap := set(output_names).intersection(partition_by):
                    # E.g. `df.select(nw.all().sum().over('a'))`. This is well-defined,
                    # we just don't support it yet.
                    msg = (
                        f"Column names {overlap} appear in both expression output names and in `over` keys.\n"
                        "This is not yet supported."
                    )
                    raise NotImplementedError(msg)

                tmp = df.group_by(partition_by, drop_null_keys=False).agg(self)
                tmp = df.simple_select(*partition_by).join(
                    tmp,
                    how="left",
                    left_on=partition_by,
                    right_on=partition_by,
                    suffix="_right",
                )
                return [tmp.get_column(alias) for alias in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    ewm_mean = not_implemented()
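The un-partitioned `over(order_by=...)` path above uses a sort-and-restore trick: tag rows with a row index, sort by the ordering keys, compute the expression on the sorted frame, then `take` with `pc.sort_indices` on the saved index column to put the results back in the original row order. The trick in plain pyarrow, with a cumulative sum as the order-dependent computation:

import pyarrow as pa
import pyarrow.compute as pc

tbl = pa.table({"key": [3, 1, 2], "x": [10, 20, 30]})
indexed = tbl.append_column("_idx", pa.array(range(tbl.num_rows)))
by_key = indexed.sort_by([("key", "ascending")])

# Order-dependent computation on the sorted rows.
csum = pc.cumulative_sum(by_key["x"].combine_chunks())

# Sorting the saved row indices tells us, for each original row,
# where it ended up after the sort -- take() restores original order.
restore = pc.sort_indices(by_key["_idx"])
assert csum.take(restore).to_pylist() == [60, 20, 50]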
lib/python3.11/site-packages/narwhals/_arrow/group_by.py (new file, 159 lines)
@@ -0,0 +1,159 @@
from __future__ import annotations

import collections
from typing import TYPE_CHECKING, Any, ClassVar

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import cast_to_comparable_string_types, extract_py_scalar
from narwhals._compliant import EagerGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import generate_temporary_column_name

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence

    from narwhals._arrow.dataframe import ArrowDataFrame
    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.typing import (  # type: ignore[attr-defined]
        AggregateOptions,
        Aggregation,
        Incomplete,
    )
    from narwhals._compliant.typing import NarwhalsAggregation
    from narwhals.typing import UniqueKeepStrategy


class ArrowGroupBy(EagerGroupBy["ArrowDataFrame", "ArrowExpr", "Aggregation"]):
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "approximate_median",
        "max": "max",
        "min": "min",
        "std": "stddev",
        "var": "variance",
        "len": "count",
        "n_unique": "count_distinct",
        "count": "count",
        "all": "all",
        "any": "any",
    }
    _REMAP_UNIQUE: ClassVar[Mapping[UniqueKeepStrategy, Aggregation]] = {
        "any": "min",
        "first": "min",
        "last": "max",
    }

    def __init__(
        self,
        df: ArrowDataFrame,
        keys: Sequence[ArrowExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._df = df
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame
        self._grouped = pa.TableGroupBy(self.compliant.native, self._keys)
        self._drop_null_keys = drop_null_keys

    def agg(self, *exprs: ArrowExpr) -> ArrowDataFrame:
        self._ensure_all_simple(exprs)
        aggs: list[tuple[str, Aggregation, AggregateOptions | None]] = []
        expected_pyarrow_column_names: list[str] = self._keys.copy()
        new_column_names: list[str] = self._keys.copy()
        exclude = (*self._keys, *self._output_key_names)

        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )

            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                if expr._function_name != "len":  # pragma: no cover
                    msg = "Safety assertion failed, please report a bug to https://github.com/narwhals-dev/narwhals/issues"
                    raise AssertionError(msg)

                new_column_names.append(aliases[0])
                expected_pyarrow_column_names.append(f"{self._keys[0]}_count")
                aggs.append((self._keys[0], "count", pc.CountOptions(mode="all")))
                continue

            function_name = self._leaf_name(expr)
            if function_name in {"std", "var"}:
                assert "ddof" in expr._scalar_kwargs  # noqa: S101
                option: Any = pc.VarianceOptions(ddof=expr._scalar_kwargs["ddof"])
            elif function_name in {"len", "n_unique"}:
                option = pc.CountOptions(mode="all")
            elif function_name == "count":
                option = pc.CountOptions(mode="only_valid")
            elif function_name in {"all", "any"}:
                option = pc.ScalarAggregateOptions(min_count=0)
            else:
                option = None

            function_name = self._remap_expr_name(function_name)
            new_column_names.extend(aliases)
            expected_pyarrow_column_names.extend(
                [f"{output_name}_{function_name}" for output_name in output_names]
            )
            aggs.extend(
                [(output_name, function_name, option) for output_name in output_names]
            )

        result_simple = self._grouped.aggregate(aggs)

        # Rename columns, being very careful
        expected_old_names_indices: dict[str, list[int]] = collections.defaultdict(list)
        for idx, item in enumerate(expected_pyarrow_column_names):
            expected_old_names_indices[item].append(idx)
        if not (
            set(result_simple.column_names) == set(expected_pyarrow_column_names)
            and len(result_simple.column_names) == len(expected_pyarrow_column_names)
        ):  # pragma: no cover
            msg = (
                f"Safety assertion failed, expected {expected_pyarrow_column_names} "
                f"got {result_simple.column_names}, "
                "please report a bug at https://github.com/narwhals-dev/narwhals/issues"
            )
            raise AssertionError(msg)
        index_map: list[int] = [
            expected_old_names_indices[item].pop(0) for item in result_simple.column_names
        ]
        new_column_names = [new_column_names[i] for i in index_map]
        result_simple = result_simple.rename_columns(new_column_names)
        return self.compliant._with_native(result_simple).rename(
            dict(zip(self._keys, self._output_key_names))
        )

    def __iter__(self) -> Iterator[tuple[Any, ArrowDataFrame]]:
        col_token = generate_temporary_column_name(
            n_bytes=8, columns=self.compliant.columns
        )
        null_token: str = "__null_token_value__"  # noqa: S105

        table = self.compliant.native
        it, separator_scalar = cast_to_comparable_string_types(
            *(table[key] for key in self._keys), separator=""
        )
        # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
        # Reality: `str` is fine
        concat_str: Incomplete = pc.binary_join_element_wise
        key_values = concat_str(
            *it, separator_scalar, null_handling="replace", null_replacement=null_token
        )
        table = table.add_column(i=0, field_=col_token, column=key_values)

        for v in pc.unique(key_values):
            t = self.compliant._with_native(
                table.filter(pc.equal(table[col_token], v)).drop([col_token])
            )
            row = t.simple_select(*self._keys).row(0)
            yield (
                tuple(extract_py_scalar(el) for el in row),
                t.simple_select(*self._df.columns),
            )
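The `agg` implementation above depends on `pa.TableGroupBy.aggregate` naming its output columns `{column}_{aggregation}`; narwhals builds the expected names up front and then maps them back to the requested aliases. The naming convention, demonstrated in plain pyarrow:

import pyarrow as pa

tbl = pa.table({"g": ["a", "a", "b"], "x": [1, 2, 3]})
out = tbl.group_by(["g"]).aggregate([("x", "sum"), ("x", "max")])
# Aggregated columns are named "<column>_<aggregation>"; the keys are kept too.
assert set(out.column_names) == {"x_sum", "x_max", "g"}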
lib/python3.11/site-packages/narwhals/_arrow/namespace.py (new file, 303 lines)
@@ -0,0 +1,303 @@
from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Literal

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.dataframe import ArrowDataFrame
from narwhals._arrow.expr import ArrowExpr
from narwhals._arrow.selectors import ArrowSelectorNamespace
from narwhals._arrow.series import ArrowSeries
from narwhals._arrow.utils import cast_to_comparable_string_types
from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence

    from narwhals._arrow.typing import ArrayOrScalar, ChunkedArrayAny, Incomplete
    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._utils import Version
    from narwhals.typing import IntoDType, NonNestedLiteral


class ArrowNamespace(
    EagerNamespace[ArrowDataFrame, ArrowSeries, ArrowExpr, pa.Table, "ChunkedArrayAny"]
):
    _implementation = Implementation.PYARROW

    @property
    def _dataframe(self) -> type[ArrowDataFrame]:
        return ArrowDataFrame

    @property
    def _expr(self) -> type[ArrowExpr]:
        return ArrowExpr

    @property
    def _series(self) -> type[ArrowSeries]:
        return ArrowSeries

    def __init__(self, *, version: Version) -> None:
        self._version = version

    def len(self) -> ArrowExpr:
        # coverage bug? this is definitely hit
        return self._expr(  # pragma: no cover
            lambda df: [
                ArrowSeries.from_iterable([len(df.native)], name="len", context=self)
            ],
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> ArrowExpr:
        def _lit_arrow_series(_: ArrowDataFrame) -> ArrowSeries:
            arrow_series = ArrowSeries.from_iterable(
                data=[value], name="literal", context=self
            )
            if dtype:
                return arrow_series.cast(dtype)
            return arrow_series

        return self._expr(
            lambda df: [_lit_arrow_series(df)],
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

    def all_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
            align = self._series._align_full_broadcast
            if ignore_nulls:
                series = (s.fill_null(True, None, None) for s in series)
            return [reduce(operator.and_, align(*series))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def any_horizontal(self, *exprs: ArrowExpr, ignore_nulls: bool) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            series: Iterator[ArrowSeries] = chain.from_iterable(e(df) for e in exprs)
            align = self._series._align_full_broadcast
            if ignore_nulls:
                series = (s.fill_null(False, None, None) for s in series)
            return [reduce(operator.or_, align(*series))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def sum_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            it = chain.from_iterable(expr(df) for expr in exprs)
            series = (s.fill_null(0, strategy=None, limit=None) for s in it)
            align = self._series._align_full_broadcast
            return [reduce(operator.add, align(*series))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def mean_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        int_64 = self._version.dtypes.Int64()

        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            expr_results = list(chain.from_iterable(expr(df) for expr in exprs))
            align = self._series._align_full_broadcast
            series = align(
                *(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
            )
            non_na = align(*(1 - s.is_null().cast(int_64) for s in expr_results))
            return [reduce(operator.add, series) / reduce(operator.add, non_na)]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def min_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
            init_series, *series = align(init_series, *series)
            native_series = reduce(
                pc.min_element_wise, [s.native for s in series], init_series.native
            )
            return [
                ArrowSeries(native_series, name=init_series.name, version=self._version)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def max_horizontal(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = list(chain.from_iterable(expr(df) for expr in exprs))
            init_series, *series = align(init_series, *series)
            native_series = reduce(
                pc.max_element_wise, [s.native for s in series], init_series.native
            )
            return [
                ArrowSeries(native_series, name=init_series.name, version=self._version)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def _concat_diagonal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        if self._backend_version >= (14,):
            return pa.concat_tables(dfs, promote_options="default")
        return pa.concat_tables(dfs, promote=True)  # pragma: no cover

    def _concat_horizontal(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        names = list(chain.from_iterable(df.column_names for df in dfs))
        arrays = list(chain.from_iterable(df.itercolumns() for df in dfs))
        return pa.Table.from_arrays(arrays, names=names)

    def _concat_vertical(self, dfs: Sequence[pa.Table], /) -> pa.Table:
        cols_0 = dfs[0].column_names
        for i, df in enumerate(dfs[1:], start=1):
            cols_current = df.column_names
            if cols_current != cols_0:
                msg = (
                    "unable to vstack, column names don't match:\n"
                    f" - dataframe 0: {cols_0}\n"
                    f" - dataframe {i}: {cols_current}\n"
                )
                raise TypeError(msg)
        return pa.concat_tables(dfs)

    @property
    def selectors(self) -> ArrowSelectorNamespace:
        return ArrowSelectorNamespace.from_namespace(self)

    def when(self, predicate: ArrowExpr) -> ArrowWhen:
        return ArrowWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: ArrowExpr, separator: str, ignore_nulls: bool
    ) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            compliant_series_list = align(
                *(chain.from_iterable(expr(df) for expr in exprs))
            )
            name = compliant_series_list[0].name
            null_handling: Literal["skip", "emit_null"] = (
                "skip" if ignore_nulls else "emit_null"
            )
            it, separator_scalar = cast_to_comparable_string_types(
                *(s.native for s in compliant_series_list), separator=separator
            )
            # NOTE: stubs indicate `separator` must also be a `ChunkedArray`
            # Reality: `str` is fine
            concat_str: Incomplete = pc.binary_join_element_wise
            compliant = self._series(
                concat_str(*it, separator_scalar, null_handling=null_handling),
                name=name,
                version=self._version,
            )
            return [compliant]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def coalesce(self, *exprs: ArrowExpr) -> ArrowExpr:
        def func(df: ArrowDataFrame) -> list[ArrowSeries]:
            align = self._series._align_full_broadcast
            init_series, *series = align(*chain.from_iterable(expr(df) for expr in exprs))
            return [
                ArrowSeries(
                    pc.coalesce(init_series.native, *(s.native for s in series)),
                    name=init_series.name,
                    version=self._version,
                )
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="coalesce",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )


class ArrowWhen(EagerWhen[ArrowDataFrame, ArrowSeries, ArrowExpr, "ChunkedArrayAny"]):
    @property
    def _then(self) -> type[ArrowThen]:
        return ArrowThen

    def _if_then_else(
        self,
        when: ChunkedArrayAny,
        then: ChunkedArrayAny,
        otherwise: ArrayOrScalar | NonNestedLiteral,
        /,
    ) -> ChunkedArrayAny:
        otherwise = pa.nulls(len(when), then.type) if otherwise is None else otherwise
        return pc.if_else(when, then, otherwise)


class ArrowThen(
    CompliantThen[ArrowDataFrame, ArrowSeries, ArrowExpr, ArrowWhen], ArrowExpr
):
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "whenthen"
33
lib/python3.11/site-packages/narwhals/_arrow/selectors.py
Normal file
@ -0,0 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._arrow.expr import ArrowExpr
from narwhals._compliant import CompliantSelector, EagerSelectorNamespace

if TYPE_CHECKING:
    from narwhals._arrow.dataframe import ArrowDataFrame  # noqa: F401
    from narwhals._arrow.series import ArrowSeries  # noqa: F401
    from narwhals._compliant.typing import ScalarKwargs


class ArrowSelectorNamespace(EagerSelectorNamespace["ArrowDataFrame", "ArrowSeries"]):
    @property
    def _selector(self) -> type[ArrowSelector]:
        return ArrowSelector


class ArrowSelector(CompliantSelector["ArrowDataFrame", "ArrowSeries"], ArrowExpr):  # type: ignore[misc]
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "selector"

    def _to_expr(self) -> ArrowExpr:
        return ArrowExpr(
            self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
1191
lib/python3.11/site-packages/narwhals/_arrow/series.py
Normal file
File diff suppressed because it is too large
18
lib/python3.11/site-packages/narwhals/_arrow/series_cat.py
Normal file
@ -0,0 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow as pa

from narwhals._arrow.utils import ArrowSeriesNamespace

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import Incomplete


class ArrowSeriesCatNamespace(ArrowSeriesNamespace):
    def get_categories(self) -> ArrowSeries:
        # NOTE: Should be `list[pa.DictionaryArray]`, but `DictionaryArray` has no attributes
        chunks: Incomplete = self.native.chunks
        return self.with_native(pa.concat_arrays(x.dictionary for x in chunks).unique())
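
    # Illustrative only: each chunk of a dictionary-encoded array carries its
    # categories in `.dictionary`; concatenating those and calling `.unique()`
    # yields the category set, as `get_categories` does above.
    #
    # >>> import pyarrow as pa
    # >>> chunk = pa.array(["a", "b", "a"]).dictionary_encode()
    # >>> chunk.dictionary.to_pylist()
    # ['a', 'b']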
226
lib/python3.11/site-packages/narwhals/_arrow/series_dt.py
Normal file
@ -0,0 +1,226 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, ClassVar, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import UNITS_DICT, ArrowSeriesNamespace, floordiv_compat, lit
from narwhals._constants import (
    MS_PER_MINUTE,
    MS_PER_SECOND,
    NS_PER_MICROSECOND,
    NS_PER_MILLISECOND,
    NS_PER_MINUTE,
    NS_PER_SECOND,
    SECONDS_PER_DAY,
    SECONDS_PER_MINUTE,
    US_PER_MINUTE,
    US_PER_SECOND,
)
from narwhals._duration import Interval

if TYPE_CHECKING:
    from collections.abc import Mapping

    from typing_extensions import TypeAlias

    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import ChunkedArrayAny, ScalarAny
    from narwhals.dtypes import Datetime
    from narwhals.typing import TimeUnit

UnitCurrent: TypeAlias = TimeUnit
UnitTarget: TypeAlias = TimeUnit
BinOpBroadcast: TypeAlias = Callable[[ChunkedArrayAny, ScalarAny], ChunkedArrayAny]
IntoRhs: TypeAlias = int


class ArrowSeriesDateTimeNamespace(ArrowSeriesNamespace):
    _TIMESTAMP_DATE_FACTOR: ClassVar[Mapping[TimeUnit, int]] = {
        "ns": NS_PER_SECOND,
        "us": US_PER_SECOND,
        "ms": MS_PER_SECOND,
        "s": 1,
    }
    _TIMESTAMP_DATETIME_OP_FACTOR: ClassVar[
        Mapping[tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]]
    ] = {
        ("ns", "us"): (floordiv_compat, 1_000),
        ("ns", "ms"): (floordiv_compat, 1_000_000),
        ("us", "ns"): (pc.multiply, NS_PER_MICROSECOND),
        ("us", "ms"): (floordiv_compat, 1_000),
        ("ms", "ns"): (pc.multiply, NS_PER_MILLISECOND),
        ("ms", "us"): (pc.multiply, 1_000),
        ("s", "ns"): (pc.multiply, NS_PER_SECOND),
        ("s", "us"): (pc.multiply, US_PER_SECOND),
        ("s", "ms"): (pc.multiply, MS_PER_SECOND),
    }

    @property
    def unit(self) -> TimeUnit:  # NOTE: Unsafe (native).
        return cast("pa.TimestampType[TimeUnit, Any]", self.native.type).unit

    @property
    def time_zone(self) -> str | None:  # NOTE: Unsafe (narwhals).
        return cast("Datetime", self.compliant.dtype).time_zone

    def to_string(self, format: str) -> ArrowSeries:
        # PyArrow differs from other libraries in that %S also prints out
        # the fractional part of the second...:'(
        # https://arrow.apache.org/docs/python/generated/pyarrow.compute.strftime.html
        format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
        return self.with_native(pc.strftime(self.native, format))

    def replace_time_zone(self, time_zone: str | None) -> ArrowSeries:
        if time_zone is not None:
            result = pc.assume_timezone(pc.local_timestamp(self.native), time_zone)
        else:
            result = pc.local_timestamp(self.native)
        return self.with_native(result)

    def convert_time_zone(self, time_zone: str) -> ArrowSeries:
        ser = self.replace_time_zone("UTC") if self.time_zone is None else self.compliant
        return self.with_native(ser.native.cast(pa.timestamp(self.unit, time_zone)))

    def timestamp(self, time_unit: TimeUnit) -> ArrowSeries:
        ser = self.compliant
        dtypes = ser._version.dtypes
        if isinstance(ser.dtype, dtypes.Datetime):
            current = ser.dtype.time_unit
            s_cast = self.native.cast(pa.int64())
            if current == time_unit:
                result = s_cast
            elif item := self._TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)):
                fn, factor = item
                result = fn(s_cast, lit(factor))
            else:  # pragma: no cover
                msg = f"unexpected time unit {current}, please report an issue at https://github.com/narwhals-dev/narwhals"
                raise AssertionError(msg)
            return self.with_native(result)
        if isinstance(ser.dtype, dtypes.Date):
            time_s = pc.multiply(self.native.cast(pa.int32()), lit(SECONDS_PER_DAY))
            factor = self._TIMESTAMP_DATE_FACTOR[time_unit]
            return self.with_native(pc.multiply(time_s, lit(factor)))
        msg = "Input should be either of Date or Datetime type"
        raise TypeError(msg)
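
    # Illustrative only: converting between time units is integer arithmetic on
    # the underlying epoch values. Upscaling multiplies (e.g. "ms" -> "us" by
    # 1_000), while downscaling floor-divides (e.g. "ns" -> "us" by 1_000),
    # matching the `_TIMESTAMP_DATETIME_OP_FACTOR` table above.
    #
    # >>> 1_700_000_000_123 * 1_000       # ms -> us
    # 1700000000123000
    # >>> 1_700_000_000_123_456 // 1_000  # ns -> us
    # 1700000000123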

    def date(self) -> ArrowSeries:
        return self.with_native(self.native.cast(pa.date32()))

    def year(self) -> ArrowSeries:
        return self.with_native(pc.year(self.native))

    def month(self) -> ArrowSeries:
        return self.with_native(pc.month(self.native))

    def day(self) -> ArrowSeries:
        return self.with_native(pc.day(self.native))

    def hour(self) -> ArrowSeries:
        return self.with_native(pc.hour(self.native))

    def minute(self) -> ArrowSeries:
        return self.with_native(pc.minute(self.native))

    def second(self) -> ArrowSeries:
        return self.with_native(pc.second(self.native))

    def millisecond(self) -> ArrowSeries:
        return self.with_native(pc.millisecond(self.native))

    def microsecond(self) -> ArrowSeries:
        arr = self.native
        result = pc.add(pc.multiply(pc.millisecond(arr), lit(1000)), pc.microsecond(arr))
        return self.with_native(result)

    def nanosecond(self) -> ArrowSeries:
        result = pc.add(
            pc.multiply(self.microsecond().native, lit(1000)), pc.nanosecond(self.native)
        )
        return self.with_native(result)

    def ordinal_day(self) -> ArrowSeries:
        return self.with_native(pc.day_of_year(self.native))

    def weekday(self) -> ArrowSeries:
        return self.with_native(pc.day_of_week(self.native, count_from_zero=False))

    def total_minutes(self) -> ArrowSeries:
        unit_to_minutes_factor = {
            "s": SECONDS_PER_MINUTE,
            "ms": MS_PER_MINUTE,
            "us": US_PER_MINUTE,
            "ns": NS_PER_MINUTE,
        }
        factor = lit(unit_to_minutes_factor[self.unit], type=pa.int64())
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_seconds(self) -> ArrowSeries:
        unit_to_seconds_factor = {
            "s": 1,
            "ms": MS_PER_SECOND,
            "us": US_PER_SECOND,
            "ns": NS_PER_SECOND,
        }
        factor = lit(unit_to_seconds_factor[self.unit], type=pa.int64())
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_milliseconds(self) -> ArrowSeries:
        unit_to_milli_factor = {
            "s": 1e3,  # seconds
            "ms": 1,  # milli
            "us": 1e3,  # micro
            "ns": 1e6,  # nano
        }
        factor = lit(unit_to_milli_factor[self.unit], type=pa.int64())
        if self.unit == "s":
            return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_microseconds(self) -> ArrowSeries:
        unit_to_micro_factor = {
            "s": 1e6,  # seconds
            "ms": 1e3,  # milli
            "us": 1,  # micro
            "ns": 1e3,  # nano
        }
        factor = lit(unit_to_micro_factor[self.unit], type=pa.int64())
        if self.unit in {"s", "ms"}:
            return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
        return self.with_native(pc.divide(self.native, factor).cast(pa.int64()))

    def total_nanoseconds(self) -> ArrowSeries:
        unit_to_nano_factor = {
            "s": NS_PER_SECOND,
            "ms": NS_PER_MILLISECOND,
            "us": NS_PER_MICROSECOND,
            "ns": 1,
        }
        factor = lit(unit_to_nano_factor[self.unit], type=pa.int64())
        return self.with_native(pc.multiply(self.native, factor).cast(pa.int64()))
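
    # Illustrative only: the `total_*` methods above rescale a duration's
    # integer representation by a per-unit factor, multiplying toward finer
    # units and dividing toward coarser ones. For a duration stored in seconds:
    #
    # >>> 90 * 1_000_000_000  # total_nanoseconds for a 90s duration, unit="s"
    # 90000000000
    # >>> 90 // 60            # total_minutes for the same duration
    # 1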

    def truncate(self, every: str) -> ArrowSeries:
        interval = Interval.parse(every)
        return self.with_native(
            pc.floor_temporal(self.native, interval.multiple, UNITS_DICT[interval.unit])
        )

    def offset_by(self, by: str) -> ArrowSeries:
        interval = Interval.parse_no_constraints(by)
        native = self.native
        if interval.unit in {"y", "q", "mo"}:
            msg = f"Offsetting by {interval.unit} is not yet supported for pyarrow."
            raise NotImplementedError(msg)
        dtype = self.compliant.dtype
        datetime_dtype = self.version.dtypes.Datetime
        if interval.unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
            offset: pa.DurationScalar[Any] = lit(interval.to_timedelta())
            native_naive = pc.local_timestamp(native)
            result = pc.assume_timezone(pc.add(native_naive, offset), dtype.time_zone)
            return self.with_native(result)
        if interval.unit == "ns":  # pragma: no cover
            offset = lit(interval.multiple, pa.duration("ns"))  # type: ignore[assignment]
        else:
            offset = lit(interval.to_timedelta())
        return self.with_native(pc.add(native, offset))
24
lib/python3.11/site-packages/narwhals/_arrow/series_list.py
Normal file
@ -0,0 +1,24 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries


class ArrowSeriesListNamespace(ArrowSeriesNamespace):
    def len(self) -> ArrowSeries:
        return self.with_native(pc.list_value_length(self.native).cast(pa.uint32()))

    unique = not_implemented()

    contains = not_implemented()

    def get(self, index: int) -> ArrowSeries:
        return self.with_native(pc.list_element(self.native, index))
115
lib/python3.11/site-packages/narwhals/_arrow/series_str.py
Normal file
@ -0,0 +1,115 @@
from __future__ import annotations

import string
from typing import TYPE_CHECKING

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace, lit, parse_datetime_format

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import Incomplete


class ArrowSeriesStringNamespace(ArrowSeriesNamespace):
    def len_chars(self) -> ArrowSeries:
        return self.with_native(pc.utf8_length(self.native))

    def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> ArrowSeries:
        fn = pc.replace_substring if literal else pc.replace_substring_regex
        try:
            arr = fn(self.native, pattern, replacement=value, max_replacements=n)
        except TypeError as e:
            if not isinstance(value, str):
                msg = "PyArrow backed `.str.replace` only supports str replacement values"
                raise TypeError(msg) from e
            raise
        return self.with_native(arr)

    def replace_all(self, pattern: str, value: str, *, literal: bool) -> ArrowSeries:
        try:
            return self.replace(pattern, value, literal=literal, n=-1)
        except TypeError as e:
            if not isinstance(value, str):
                msg = "PyArrow backed `.str.replace_all` only supports str replacement values."
                raise TypeError(msg) from e
            raise

    def strip_chars(self, characters: str | None) -> ArrowSeries:
        return self.with_native(
            pc.utf8_trim(self.native, characters or string.whitespace)
        )

    def starts_with(self, prefix: str) -> ArrowSeries:
        return self.with_native(pc.equal(self.slice(0, len(prefix)).native, lit(prefix)))

    def ends_with(self, suffix: str) -> ArrowSeries:
        return self.with_native(
            pc.equal(self.slice(-len(suffix), None).native, lit(suffix))
        )

    def contains(self, pattern: str, *, literal: bool) -> ArrowSeries:
        check_func = pc.match_substring if literal else pc.match_substring_regex
        return self.with_native(check_func(self.native, pattern))

    def slice(self, offset: int, length: int | None) -> ArrowSeries:
        stop = offset + length if length is not None else None
        return self.with_native(
            pc.utf8_slice_codeunits(self.native, start=offset, stop=stop)
        )

    def split(self, by: str) -> ArrowSeries:
        split_series = pc.split_pattern(self.native, by)  # type: ignore[call-overload]
        return self.with_native(split_series)

    def to_datetime(self, format: str | None) -> ArrowSeries:
        format = parse_datetime_format(self.native) if format is None else format
        timestamp_array = pc.strptime(self.native, format=format, unit="us")
        return self.with_native(timestamp_array)

    def to_date(self, format: str | None) -> ArrowSeries:
        return self.to_datetime(format=format).dt.date()

    def to_uppercase(self) -> ArrowSeries:
        return self.with_native(pc.utf8_upper(self.native))

    def to_lowercase(self) -> ArrowSeries:
        return self.with_native(pc.utf8_lower(self.native))

    def zfill(self, width: int) -> ArrowSeries:
        binary_join: Incomplete = pc.binary_join_element_wise
        native = self.native
        hyphen, plus = lit("-"), lit("+")
        first_char, remaining_chars = (
            self.slice(0, 1).native,
            self.slice(1, None).native,
        )

        # Conditions
        less_than_width = pc.less(pc.utf8_length(native), lit(width))
        starts_with_hyphen = pc.equal(first_char, hyphen)
        starts_with_plus = pc.equal(first_char, plus)

        conditions = pc.make_struct(
            pc.and_(starts_with_hyphen, less_than_width),
            pc.and_(starts_with_plus, less_than_width),
            less_than_width,
        )

        # Cases
        padded_remaining_chars = pc.utf8_lpad(remaining_chars, width - 1, padding="0")

        result = pc.case_when(
            conditions,
            binary_join(
                pa.repeat(hyphen, len(native)), padded_remaining_chars, ""
            ),  # starts with hyphen and less than width
            binary_join(
                pa.repeat(plus, len(native)), padded_remaining_chars, ""
            ),  # starts with plus and less than width
            pc.utf8_lpad(native, width=width, padding="0"),  # less than width
            native,
        )
        return self.with_native(result)
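
    # Illustrative only: the branching above mirrors Python's `str.zfill`,
    # which pads with zeros *after* a leading sign and leaves strings that
    # already meet the width untouched.
    #
    # >>> "-7".zfill(4), "+7".zfill(4), "7".zfill(4), "12345".zfill(4)
    # ('-007', '+007', '0007', '12345')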
@ -0,0 +1,15 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import pyarrow.compute as pc

from narwhals._arrow.utils import ArrowSeriesNamespace

if TYPE_CHECKING:
    from narwhals._arrow.series import ArrowSeries


class ArrowSeriesStructNamespace(ArrowSeriesNamespace):
    def field(self, name: str) -> ArrowSeries:
        return self.with_native(pc.struct_field(self.native, name)).alias(name)
72
lib/python3.11/site-packages/narwhals/_arrow/typing.py
Normal file
@ -0,0 +1,72 @@
from __future__ import annotations  # pragma: no cover

from typing import (
    TYPE_CHECKING,  # pragma: no cover
    Any,  # pragma: no cover
    TypeVar,  # pragma: no cover
)

if TYPE_CHECKING:
    import sys
    from typing import Generic, Literal

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    import pyarrow as pa
    from pyarrow.__lib_pxi.table import (
        AggregateOptions,  # noqa: F401
        Aggregation,  # noqa: F401
    )
    from pyarrow._stubs_typing import (  # pyright: ignore[reportMissingModuleSource]
        Indices,  # noqa: F401
        Mask,  # noqa: F401
        Order,  # noqa: F401
    )

    from narwhals._arrow.expr import ArrowExpr
    from narwhals._arrow.series import ArrowSeries

    IntoArrowExpr: TypeAlias = "ArrowExpr | ArrowSeries"
    TieBreaker: TypeAlias = Literal["min", "max", "first", "dense"]
    NullPlacement: TypeAlias = Literal["at_start", "at_end"]
    NativeIntervalUnit: TypeAlias = Literal[
        "year",
        "quarter",
        "month",
        "week",
        "day",
        "hour",
        "minute",
        "second",
        "millisecond",
        "microsecond",
        "nanosecond",
    ]

    ChunkedArrayAny: TypeAlias = pa.ChunkedArray[Any]
    ArrayAny: TypeAlias = pa.Array[Any]
    ArrayOrChunkedArray: TypeAlias = "ArrayAny | ChunkedArrayAny"
    ScalarAny: TypeAlias = pa.Scalar[Any]
    ArrayOrScalar: TypeAlias = "ArrayOrChunkedArray | ScalarAny"
    ArrayOrScalarT1 = TypeVar("ArrayOrScalarT1", ArrayAny, ChunkedArrayAny, ScalarAny)
    ArrayOrScalarT2 = TypeVar("ArrayOrScalarT2", ArrayAny, ChunkedArrayAny, ScalarAny)
    _AsPyType = TypeVar("_AsPyType")

    class _BasicDataType(pa.DataType, Generic[_AsPyType]): ...


Incomplete: TypeAlias = Any  # pragma: no cover
"""
Marker for working code that fails on the stubs.

Common issues:
- Annotated for `Array`, but not `ChunkedArray`
- Relies on typing information that the stubs don't provide statically
- Missing attributes
- Incorrect return types
- Inconsistent use of generic/concrete types
- `_clone_signature` used on signatures that are not identical
"""
438
lib/python3.11/site-packages/narwhals/_arrow/utils.py
Normal file
@ -0,0 +1,438 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, cast

import pyarrow as pa
import pyarrow.compute as pc

from narwhals._compliant import EagerSeriesNamespace
from narwhals._utils import Version, isinstance_or_issubclass

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping

    from typing_extensions import TypeAlias, TypeIs

    from narwhals._arrow.series import ArrowSeries
    from narwhals._arrow.typing import (
        ArrayAny,
        ArrayOrScalar,
        ArrayOrScalarT1,
        ArrayOrScalarT2,
        ChunkedArrayAny,
        NativeIntervalUnit,
        ScalarAny,
    )
    from narwhals._duration import IntervalUnit
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType, PythonLiteral

    # NOTE: stubs don't allow for `ChunkedArray[StructArray]`.
    # Intended to represent the `.chunks` property storing `list[pa.StructArray]`.
    ChunkedArrayStructArray: TypeAlias = ChunkedArrayAny

    def is_timestamp(t: Any) -> TypeIs[pa.TimestampType[Any, Any]]: ...
    def is_duration(t: Any) -> TypeIs[pa.DurationType[Any]]: ...
    def is_list(t: Any) -> TypeIs[pa.ListType[Any]]: ...
    def is_large_list(t: Any) -> TypeIs[pa.LargeListType[Any]]: ...
    def is_fixed_size_list(t: Any) -> TypeIs[pa.FixedSizeListType[Any, Any]]: ...
    def is_dictionary(t: Any) -> TypeIs[pa.DictionaryType[Any, Any, Any]]: ...
    def extract_regex(
        strings: ChunkedArrayAny,
        /,
        pattern: str,
        *,
        options: Any = None,
        memory_pool: Any = None,
    ) -> ChunkedArrayStructArray: ...
else:
    from pyarrow.compute import extract_regex
    from pyarrow.types import (
        is_dictionary,  # noqa: F401
        is_duration,
        is_fixed_size_list,
        is_large_list,
        is_list,
        is_timestamp,
    )

UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}

lit = pa.scalar
"""Alias for `pyarrow.scalar`."""


def extract_py_scalar(value: Any, /) -> Any:
    from narwhals._arrow.series import maybe_extract_py_scalar

    return maybe_extract_py_scalar(value, return_py_scalar=True)


def is_array_or_scalar(obj: Any) -> TypeIs[ArrayOrScalar]:
    """Return True for any base `pyarrow` container."""
    return isinstance(obj, (pa.ChunkedArray, pa.Array, pa.Scalar))


def chunked_array(
    arr: ArrayOrScalar | list[Iterable[Any]], dtype: pa.DataType | None = None, /
) -> ChunkedArrayAny:
    if isinstance(arr, pa.ChunkedArray):
        return arr
    if isinstance(arr, list):
        return pa.chunked_array(arr, dtype)
    return pa.chunked_array([arr], dtype)
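

# Illustrative only: `chunked_array` normalises whatever PyArrow container it
# is handed into a `ChunkedArray`, wrapping a plain `Array` as a single chunk
# and passing an existing `ChunkedArray` through untouched.
#
# >>> import pyarrow as pa
# >>> chunked_array(pa.array([1, 2])).num_chunks
# 1
# >>> chunked_array(pa.chunked_array([[1], [2]])).num_chunks
# 2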


def nulls_like(n: int, series: ArrowSeries) -> ArrayAny:
    """Create a strongly-typed Array instance with all elements null.

    Uses the type of `series`, without upsetting `mypy`.
    """
    return pa.nulls(n, series.native.type)


def zeros(n: int, /) -> pa.Int64Array:
    return pa.repeat(0, n)


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(dtype: pa.DataType, version: Version) -> DType:  # noqa: C901, PLR0912
    dtypes = version.dtypes
    if pa.types.is_int64(dtype):
        return dtypes.Int64()
    if pa.types.is_int32(dtype):
        return dtypes.Int32()
    if pa.types.is_int16(dtype):
        return dtypes.Int16()
    if pa.types.is_int8(dtype):
        return dtypes.Int8()
    if pa.types.is_uint64(dtype):
        return dtypes.UInt64()
    if pa.types.is_uint32(dtype):
        return dtypes.UInt32()
    if pa.types.is_uint16(dtype):
        return dtypes.UInt16()
    if pa.types.is_uint8(dtype):
        return dtypes.UInt8()
    if pa.types.is_boolean(dtype):
        return dtypes.Boolean()
    if pa.types.is_float64(dtype):
        return dtypes.Float64()
    if pa.types.is_float32(dtype):
        return dtypes.Float32()
    # bug in coverage? it shows `31->exit` (where `31` is currently the line number
    # of the next line), even though both branches of the condition are covered
    if (  # pragma: no cover
        pa.types.is_string(dtype)
        or pa.types.is_large_string(dtype)
        or getattr(pa.types, "is_string_view", lambda _: False)(dtype)
    ):
        return dtypes.String()
    if pa.types.is_date32(dtype):
        return dtypes.Date()
    if is_timestamp(dtype):
        return dtypes.Datetime(time_unit=dtype.unit, time_zone=dtype.tz)
    if is_duration(dtype):
        return dtypes.Duration(time_unit=dtype.unit)
    if pa.types.is_dictionary(dtype):
        return dtypes.Categorical()
    if pa.types.is_struct(dtype):
        return dtypes.Struct(
            [
                dtypes.Field(
                    dtype.field(i).name,
                    native_to_narwhals_dtype(dtype.field(i).type, version),
                )
                for i in range(dtype.num_fields)
            ]
        )
    if is_list(dtype) or is_large_list(dtype):
        return dtypes.List(native_to_narwhals_dtype(dtype.value_type, version))
    if is_fixed_size_list(dtype):
        return dtypes.Array(
            native_to_narwhals_dtype(dtype.value_type, version), dtype.list_size
        )
    if pa.types.is_decimal(dtype):
        return dtypes.Decimal()
    if pa.types.is_time32(dtype) or pa.types.is_time64(dtype):
        return dtypes.Time()
    if pa.types.is_binary(dtype):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover
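

# Illustrative only: the chain above is a straightforward dispatch from native
# PyArrow types to narwhals dtypes, recursing into nested (list/struct) types.
# Roughly (exact reprs may differ):
#
# >>> import pyarrow as pa
# >>> native_to_narwhals_dtype(pa.int64(), Version.MAIN)          # -> Int64
# >>> native_to_narwhals_dtype(pa.list_(pa.string()), Version.MAIN)  # -> List(String)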


dtypes = Version.MAIN.dtypes
NW_TO_PA_DTYPES: Mapping[type[DType], pa.DataType] = {
    dtypes.Float64: pa.float64(),
    dtypes.Float32: pa.float32(),
    dtypes.Binary: pa.binary(),
    dtypes.String: pa.string(),
    dtypes.Boolean: pa.bool_(),
    dtypes.Categorical: pa.dictionary(pa.uint32(), pa.string()),
    dtypes.Date: pa.date32(),
    dtypes.Time: pa.time64("ns"),
    dtypes.Int8: pa.int8(),
    dtypes.Int16: pa.int16(),
    dtypes.Int32: pa.int32(),
    dtypes.Int64: pa.int64(),
    dtypes.UInt8: pa.uint8(),
    dtypes.UInt16: pa.uint16(),
    dtypes.UInt32: pa.uint32(),
    dtypes.UInt64: pa.uint64(),
}
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Object)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> pa.DataType:
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if pa_type := NW_TO_PA_DTYPES.get(base_type):
        return pa_type
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        unit = dtype.time_unit
        return pa.timestamp(unit, tz) if (tz := dtype.time_zone) else pa.timestamp(unit)
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pa.duration(dtype.time_unit)
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pa.list_(value_type=narwhals_to_native_dtype(dtype.inner, version=version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        return pa.struct(
            [
                (field.name, narwhals_to_native_dtype(field.dtype, version=version))
                for field in dtype.fields
            ]
        )
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        inner = narwhals_to_native_dtype(dtype.inner, version=version)
        list_size = dtype.size
        return pa.list_(inner, list_size=list_size)
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for PyArrow."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def extract_native(
    lhs: ArrowSeries, rhs: ArrowSeries | PythonLiteral | ScalarAny
) -> tuple[ChunkedArrayAny | ScalarAny, ChunkedArrayAny | ScalarAny]:
    """Extract native objects in a binary operation.

    If the comparison isn't supported, return `NotImplemented` so that the
    "right-hand-side" operation (e.g. `__radd__`) can be tried.

    If one of the two sides has a `_broadcast` flag, then extract the scalar
    underneath it so that PyArrow can do its own broadcasting.
    """
    from narwhals._arrow.series import ArrowSeries

    if rhs is None:  # pragma: no cover
        return lhs.native, lit(None, type=lhs._type)

    if isinstance(rhs, ArrowSeries):
        if lhs._broadcast and not rhs._broadcast:
            return lhs.native[0], rhs.native
        if rhs._broadcast:
            return lhs.native, rhs.native[0]
        return lhs.native, rhs.native

    if isinstance(rhs, list):
        msg = "Expected Series or scalar, got list."
        raise TypeError(msg)

    return lhs.native, rhs if isinstance(rhs, pa.Scalar) else lit(rhs)


def floordiv_compat(left: ArrayOrScalar, right: ArrayOrScalar, /) -> Any:
    # The following lines are adapted from pandas' pyarrow implementation.
    # Ref: https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L124-L154

    if pa.types.is_integer(left.type) and pa.types.is_integer(right.type):
        divided = pc.divide_checked(left, right)
        # TODO @dangotbanned: Use a `TypeVar` in guards.
        # Narrowing to a `Union` isn't interacting well with the rest of the stubs
        # https://github.com/zen-xu/pyarrow-stubs/pull/215
        if pa.types.is_signed_integer(divided.type):
            div_type = cast("pa._lib.Int64Type", divided.type)
            has_remainder = pc.not_equal(pc.multiply(divided, right), left)
            has_one_negative_operand = pc.less(
                pc.bit_wise_xor(left, right), lit(0, div_type)
            )
            result = pc.if_else(
                pc.and_(has_remainder, has_one_negative_operand),
                pc.subtract(divided, lit(1, div_type)),
                divided,
            )
        else:
            result = divided  # pragma: no cover
        result = result.cast(left.type)
    else:
        divided = pc.divide(left, right)
        result = pc.floor(divided)
    return result
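

# Illustrative only: Arrow's integer division truncates toward zero, while
# Python's `//` floors toward negative infinity; the correction above
# subtracts one when there is a remainder and the operands' signs differ.
#
# >>> -7 // 2                 # Python floor division
# -4
# >>> import math
# >>> math.trunc(-7 / 2)      # truncating division, as Arrow computes it
# -3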


def cast_for_truediv(
    arrow_array: ArrayOrScalarT1, pa_object: ArrayOrScalarT2
) -> tuple[ArrayOrScalarT1, ArrayOrScalarT2]:
    # Lifted from:
    # https://github.com/pandas-dev/pandas/blob/262fcfbffcee5c3116e86a951d8b693f90411e68/pandas/core/arrays/arrow/array.py#L108-L122
    # Ensure int / int -> float, mirroring Python/NumPy behavior,
    # as pc.divide_checked(int, int) -> int
    if pa.types.is_integer(arrow_array.type) and pa.types.is_integer(pa_object.type):
        # GH: 56645.  # noqa: ERA001
        # https://github.com/apache/arrow/issues/35563
        return arrow_array.cast(pa.float64(), safe=False), pa_object.cast(
            pa.float64(), safe=False
        )

    return arrow_array, pa_object


# Regex for date, time, separator and timezone components
DATE_RE = r"(?P<date>\d{1,4}[-/.]\d{1,2}[-/.]\d{1,4}|\d{8})"
SEP_RE = r"(?P<sep>\s|T)"
TIME_RE = r"(?P<time>\d{2}:\d{2}(?::\d{2})?|\d{6}?)"  # \s*(?P<period>[AP]M)?)?
HMS_RE = r"^(?P<hms>\d{2}:\d{2}:\d{2})$"
HM_RE = r"^(?P<hm>\d{2}:\d{2})$"
HMS_RE_NO_SEP = r"^(?P<hms_no_sep>\d{6})$"
TZ_RE = r"(?P<tz>Z|[+-]\d{2}:?\d{2})"  # Matches 'Z', '+02:00', '+0200', '+02', etc.
FULL_RE = rf"{DATE_RE}{SEP_RE}?{TIME_RE}?{TZ_RE}?$"

# Separate regexes for different date formats
YMD_RE = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])$"
DMY_RE = r"^(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep1>[-/.])(?P<month>0[1-9]|1[0-2])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
MDY_RE = r"^(?P<month>0[1-9]|1[0-2])(?P<sep1>[-/.])(?P<day>0[1-9]|[12][0-9]|3[01])(?P<sep2>[-/.])(?P<year>(?:[12][0-9])?[0-9]{2})$"
YMD_RE_NO_SEP = r"^(?P<year>(?:[12][0-9])?[0-9]{2})(?P<month>0[1-9]|1[0-2])(?P<day>0[1-9]|[12][0-9]|3[01])$"

DATE_FORMATS = (
    (YMD_RE_NO_SEP, "%Y%m%d"),
    (YMD_RE, "%Y-%m-%d"),
    (DMY_RE, "%d-%m-%Y"),
    (MDY_RE, "%m-%d-%Y"),
)
TIME_FORMATS = ((HMS_RE, "%H:%M:%S"), (HM_RE, "%H:%M"), (HMS_RE_NO_SEP, "%H%M%S"))


def _extract_regex_concat_arrays(
    strings: ChunkedArrayAny,
    /,
    pattern: str,
    *,
    options: Any = None,
    memory_pool: Any = None,
) -> pa.StructArray:
    r = pa.concat_arrays(
        extract_regex(strings, pattern, options=options, memory_pool=memory_pool).chunks
    )
    return cast("pa.StructArray", r)


def parse_datetime_format(arr: ChunkedArrayAny) -> str:
    """Try to infer the datetime format from a StringArray."""
    matches = _extract_regex_concat_arrays(arr.drop_null().slice(0, 10), pattern=FULL_RE)
    if not pc.all(matches.is_valid()).as_py():
        msg = (
            "Unable to infer datetime format, provided format is not supported. "
            "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
        )
        raise NotImplementedError(msg)

    separators = matches.field("sep")
    tz = matches.field("tz")

    # Separators and time zones must be unique.
    if pc.count(pc.unique(separators)).as_py() > 1:
        msg = "Found multiple separator values while inferring datetime format."
        raise ValueError(msg)

    if pc.count(pc.unique(tz)).as_py() > 1:
        msg = "Found multiple timezone values while inferring datetime format."
        raise ValueError(msg)

    date_value = _parse_date_format(cast("pc.StringArray", matches.field("date")))
    time_value = _parse_time_format(cast("pc.StringArray", matches.field("time")))

    sep_value = separators[0].as_py()
    tz_value = "%z" if tz[0].as_py() else ""

    return f"{date_value}{sep_value}{time_value}{tz_value}"


def _parse_date_format(arr: pc.StringArray) -> str:
    for date_rgx, date_fmt in DATE_FORMATS:
        matches = pc.extract_regex(arr, pattern=date_rgx)
        if date_fmt == "%Y%m%d" and pc.all(matches.is_valid()).as_py():
            return date_fmt
        if (
            pc.all(matches.is_valid()).as_py()
            and pc.count(pc.unique(sep1 := matches.field("sep1"))).as_py() == 1
            and pc.count(pc.unique(sep2 := matches.field("sep2"))).as_py() == 1
            and (date_sep_value := sep1[0].as_py()) == sep2[0].as_py()
        ):
            return date_fmt.replace("-", date_sep_value)

    msg = (
        "Unable to infer datetime format. "
        "Please report a bug to https://github.com/narwhals-dev/narwhals/issues"
    )
    raise ValueError(msg)


def _parse_time_format(arr: pc.StringArray) -> str:
    for time_rgx, time_fmt in TIME_FORMATS:
        matches = pc.extract_regex(arr, pattern=time_rgx)
        if pc.all(matches.is_valid()).as_py():
            return time_fmt
    return ""


def pad_series(
    series: ArrowSeries, *, window_size: int, center: bool
) -> tuple[ArrowSeries, int]:
    """Pad series with None values on the left and/or right side, depending on the specified parameters.

    Arguments:
        series: The input ArrowSeries to be padded.
        window_size: The desired size of the window.
        center: Specifies whether to center the padding or not.

    Returns:
        A tuple containing the padded ArrowSeries and the offset value.
    """
    if not center:
        return series, 0
    offset_left = window_size // 2
    # subtract one if window_size is even
    offset_right = offset_left - (window_size % 2 == 0)
    pad_left = pa.array([None] * offset_left, type=series._type)
    pad_right = pa.array([None] * offset_right, type=series._type)
    concat = pa.concat_arrays([pad_left, *series.native.chunks, pad_right])
    return series._with_native(concat), offset_left + offset_right
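

# Illustrative only: for a centered rolling window the series is padded so the
# window can sit symmetrically around each element; an even window gets one
# fewer pad on the right.
#
# >>> window_size = 4
# >>> offset_left = window_size // 2
# >>> offset_right = offset_left - (window_size % 2 == 0)
# >>> offset_left, offset_right
# (2, 1)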


def cast_to_comparable_string_types(
    *chunked_arrays: ChunkedArrayAny, separator: str
) -> tuple[Iterator[ChunkedArrayAny], ScalarAny]:
    # Ensure `chunked_arrays` are either all `string` or all `large_string`.
    dtype = (
        pa.string()  # (PyArrow default)
        if not any(pa.types.is_large_string(ca.type) for ca in chunked_arrays)
        else pa.large_string()
    )
    return (ca.cast(dtype) for ca in chunked_arrays), lit(separator, dtype)


class ArrowSeriesNamespace(EagerSeriesNamespace["ArrowSeries", "ChunkedArrayAny"]): ...
103
lib/python3.11/site-packages/narwhals/_compliant/__init__.py
Normal file
@ -0,0 +1,103 @@
from __future__ import annotations

from narwhals._compliant.dataframe import (
    CompliantDataFrame,
    CompliantFrame,
    CompliantLazyFrame,
    EagerDataFrame,
)
from narwhals._compliant.expr import (
    CompliantExpr,
    DepthTrackingExpr,
    EagerExpr,
    LazyExpr,
    LazyExprNamespace,
)
from narwhals._compliant.group_by import (
    CompliantGroupBy,
    DepthTrackingGroupBy,
    EagerGroupBy,
)
from narwhals._compliant.namespace import (
    CompliantNamespace,
    DepthTrackingNamespace,
    EagerNamespace,
    LazyNamespace,
)
from narwhals._compliant.selectors import (
    CompliantSelector,
    CompliantSelectorNamespace,
    EagerSelectorNamespace,
    LazySelectorNamespace,
)
from narwhals._compliant.series import (
    CompliantSeries,
    EagerSeries,
    EagerSeriesCatNamespace,
    EagerSeriesDateTimeNamespace,
    EagerSeriesHist,
    EagerSeriesListNamespace,
    EagerSeriesNamespace,
    EagerSeriesStringNamespace,
    EagerSeriesStructNamespace,
)
from narwhals._compliant.typing import (
    CompliantExprT,
    CompliantFrameT,
    CompliantSeriesOrNativeExprT_co,
    CompliantSeriesT,
    EagerDataFrameT,
    EagerSeriesT,
    EvalNames,
    EvalSeries,
    NativeFrameT_co,
    NativeSeriesT_co,
)
from narwhals._compliant.when_then import CompliantThen, CompliantWhen, EagerWhen
from narwhals._compliant.window import WindowInputs

__all__ = [
    "CompliantDataFrame",
    "CompliantExpr",
    "CompliantExprT",
    "CompliantFrame",
    "CompliantFrameT",
    "CompliantGroupBy",
    "CompliantLazyFrame",
    "CompliantNamespace",
    "CompliantSelector",
    "CompliantSelectorNamespace",
    "CompliantSeries",
    "CompliantSeriesOrNativeExprT_co",
    "CompliantSeriesT",
    "CompliantThen",
    "CompliantWhen",
    "DepthTrackingExpr",
    "DepthTrackingGroupBy",
    "DepthTrackingNamespace",
    "EagerDataFrame",
    "EagerDataFrameT",
    "EagerExpr",
    "EagerGroupBy",
    "EagerNamespace",
    "EagerSelectorNamespace",
    "EagerSeries",
    "EagerSeriesCatNamespace",
    "EagerSeriesDateTimeNamespace",
    "EagerSeriesHist",
    "EagerSeriesListNamespace",
    "EagerSeriesNamespace",
    "EagerSeriesStringNamespace",
    "EagerSeriesStructNamespace",
    "EagerSeriesT",
    "EagerWhen",
    "EvalNames",
    "EvalSeries",
    "LazyExpr",
    "LazyExprNamespace",
    "LazyNamespace",
    "LazySelectorNamespace",
    "NativeFrameT_co",
    "NativeSeriesT_co",
    "WindowInputs",
]
@ -0,0 +1,94 @@
"""`Expr` and `Series` namespace accessor protocols."""

from __future__ import annotations

from typing import TYPE_CHECKING, Protocol

from narwhals._utils import CompliantT_co, _StoresCompliant

if TYPE_CHECKING:
    from typing import Callable

    from narwhals.typing import NonNestedLiteral, TimeUnit

__all__ = [
    "CatNamespace",
    "DateTimeNamespace",
    "ListNamespace",
    "NameNamespace",
    "StringNamespace",
    "StructNamespace",
]


class CatNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def get_categories(self) -> CompliantT_co: ...


class DateTimeNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def to_string(self, format: str) -> CompliantT_co: ...
    def replace_time_zone(self, time_zone: str | None) -> CompliantT_co: ...
    def convert_time_zone(self, time_zone: str) -> CompliantT_co: ...
    def timestamp(self, time_unit: TimeUnit) -> CompliantT_co: ...
    def date(self) -> CompliantT_co: ...
    def year(self) -> CompliantT_co: ...
    def month(self) -> CompliantT_co: ...
    def day(self) -> CompliantT_co: ...
    def hour(self) -> CompliantT_co: ...
    def minute(self) -> CompliantT_co: ...
    def second(self) -> CompliantT_co: ...
    def millisecond(self) -> CompliantT_co: ...
    def microsecond(self) -> CompliantT_co: ...
    def nanosecond(self) -> CompliantT_co: ...
    def ordinal_day(self) -> CompliantT_co: ...
    def weekday(self) -> CompliantT_co: ...
    def total_minutes(self) -> CompliantT_co: ...
    def total_seconds(self) -> CompliantT_co: ...
    def total_milliseconds(self) -> CompliantT_co: ...
    def total_microseconds(self) -> CompliantT_co: ...
    def total_nanoseconds(self) -> CompliantT_co: ...
    def truncate(self, every: str) -> CompliantT_co: ...
    def offset_by(self, by: str) -> CompliantT_co: ...


class ListNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def get(self, index: int) -> CompliantT_co: ...
    def len(self) -> CompliantT_co: ...
    def unique(self) -> CompliantT_co: ...
    def contains(self, item: NonNestedLiteral) -> CompliantT_co: ...


class NameNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def keep(self) -> CompliantT_co: ...
    def map(self, function: Callable[[str], str]) -> CompliantT_co: ...
    def prefix(self, prefix: str) -> CompliantT_co: ...
    def suffix(self, suffix: str) -> CompliantT_co: ...
    def to_lowercase(self) -> CompliantT_co: ...
    def to_uppercase(self) -> CompliantT_co: ...


class StringNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def len_chars(self) -> CompliantT_co: ...
    def replace(
        self, pattern: str, value: str, *, literal: bool, n: int
    ) -> CompliantT_co: ...
    def replace_all(
        self, pattern: str, value: str, *, literal: bool
    ) -> CompliantT_co: ...
    def strip_chars(self, characters: str | None) -> CompliantT_co: ...
    def starts_with(self, prefix: str) -> CompliantT_co: ...
    def ends_with(self, suffix: str) -> CompliantT_co: ...
    def contains(self, pattern: str, *, literal: bool) -> CompliantT_co: ...
    def slice(self, offset: int, length: int | None) -> CompliantT_co: ...
    def split(self, by: str) -> CompliantT_co: ...
    def to_datetime(self, format: str | None) -> CompliantT_co: ...
    def to_date(self, format: str | None) -> CompliantT_co: ...
    def to_lowercase(self) -> CompliantT_co: ...
    def to_uppercase(self) -> CompliantT_co: ...
    def zfill(self, width: int) -> CompliantT_co: ...


class StructNamespace(_StoresCompliant[CompliantT_co], Protocol[CompliantT_co]):
    def field(self, name: str) -> CompliantT_co: ...
213
lib/python3.11/site-packages/narwhals/_compliant/column.py
Normal file
@ -0,0 +1,213 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    from typing_extensions import Self

    from narwhals._compliant.any_namespace import (
        CatNamespace,
        DateTimeNamespace,
        ListNamespace,
        StringNamespace,
        StructNamespace,
    )
    from narwhals._compliant.namespace import CompliantNamespace
    from narwhals._utils import Version
    from narwhals.typing import (
        ClosedInterval,
        FillNullStrategy,
        IntoDType,
        ModeKeepStrategy,
        NonNestedLiteral,
        NumericLiteral,
        RankMethod,
        TemporalLiteral,
    )

__all__ = ["CompliantColumn"]


class CompliantColumn(Protocol):
    """Common parts of `Expr`, `Series`."""

    _version: Version

    def __add__(self, other: Any) -> Self: ...
    def __and__(self, other: Any) -> Self: ...
    def __eq__(self, other: object) -> Self: ...  # type: ignore[override]
    def __floordiv__(self, other: Any) -> Self: ...
    def __ge__(self, other: Any) -> Self: ...
    def __gt__(self, other: Any) -> Self: ...
    def __invert__(self) -> Self: ...
    def __le__(self, other: Any) -> Self: ...
    def __lt__(self, other: Any) -> Self: ...
    def __mod__(self, other: Any) -> Self: ...
    def __mul__(self, other: Any) -> Self: ...
    def __ne__(self, other: object) -> Self: ...  # type: ignore[override]
    def __or__(self, other: Any) -> Self: ...
    def __pow__(self, other: Any) -> Self: ...
    def __rfloordiv__(self, other: Any) -> Self: ...
    def __rmod__(self, other: Any) -> Self: ...
    def __rpow__(self, other: Any) -> Self: ...
    def __rsub__(self, other: Any) -> Self: ...
    def __rtruediv__(self, other: Any) -> Self: ...
    def __sub__(self, other: Any) -> Self: ...
    def __truediv__(self, other: Any) -> Self: ...

    def __narwhals_namespace__(self) -> CompliantNamespace[Any, Any]: ...

    def abs(self) -> Self: ...
    def alias(self, name: str) -> Self: ...
    def cast(self, dtype: IntoDType) -> Self: ...
    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self: ...
    def cum_count(self, *, reverse: bool) -> Self: ...
    def cum_max(self, *, reverse: bool) -> Self: ...
    def cum_min(self, *, reverse: bool) -> Self: ...
    def cum_prod(self, *, reverse: bool) -> Self: ...
    def cum_sum(self, *, reverse: bool) -> Self: ...
    def diff(self) -> Self: ...
    def drop_nulls(self) -> Self: ...
    def ewm_mean(
        self,
        *,
        com: float | None,
        span: float | None,
        half_life: float | None,
        alpha: float | None,
        adjust: bool,
        min_samples: int,
        ignore_nulls: bool,
    ) -> Self: ...
    def exp(self) -> Self: ...
    def sqrt(self) -> Self: ...
    def fill_nan(self, value: float | None) -> Self: ...
    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self: ...
    def is_between(
        self, lower_bound: Self, upper_bound: Self, closed: ClosedInterval
    ) -> Self:
        if closed == "left":
            return (self >= lower_bound) & (self < upper_bound)
        if closed == "right":
            return (self > lower_bound) & (self <= upper_bound)
        if closed == "none":
            return (self > lower_bound) & (self < upper_bound)
        return (self >= lower_bound) & (self <= upper_bound)
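
    # Illustrative only: the four `closed` modes map onto interval conventions
    # as plain comparisons. For `lower_bound=0`, `upper_bound=10`, value 10:
    #
    # >>> x = 10
    # >>> (0 <= x <= 10, 0 <= x < 10, 0 < x <= 10, 0 < x < 10)
    # (True, False, True, False)
    # # i.e. "both", "left", "right", "none" respectively.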

    def is_close(
        self,
        other: Self | NumericLiteral,
        *,
        abs_tol: float,
        rel_tol: float,
        nans_equal: bool,
    ) -> Self:
        from decimal import Decimal

        other_abs: Self | NumericLiteral
        other_is_nan: Self | bool
        other_is_inf: Self | bool
        other_is_not_inf: Self | bool

        if isinstance(other, (float, int, Decimal)):
            from math import isinf, isnan

            # NOTE: See https://discuss.python.org/t/inferred-type-of-function-that-calls-dunder-abs-abs/101447
            other_abs = other.__abs__()
            other_is_nan = isnan(other)
            other_is_inf = isinf(other)

            # Define the other_is_not_inf variable to prevent triggering the following warning:
            # > DeprecationWarning: Bitwise inversion '~' on bool is deprecated and will be
            # > removed in Python 3.16.
            other_is_not_inf = not other_is_inf

        else:
            other_abs, other_is_nan = other.abs(), other.is_nan()
            other_is_not_inf = other.is_finite() | other_is_nan
            other_is_inf = ~other_is_not_inf

        rel_threshold = self.abs().clip(lower_bound=other_abs, upper_bound=None) * rel_tol
        tolerance = rel_threshold.clip(lower_bound=abs_tol, upper_bound=None)

        self_is_nan = self.is_nan()
        self_is_not_inf = self.is_finite() | self_is_nan

        # Values are close if abs_diff <= tolerance, and both are finite
        is_close = (
            ((self - other).abs() <= tolerance) & self_is_not_inf & other_is_not_inf
        )

        # Handle infinity cases: infinities are close/equal if they have the same sign
        self_sign, other_sign = self > 0, other > 0
        is_same_inf = (~self_is_not_inf) & other_is_inf & (self_sign == other_sign)

        # Handle NaN cases:
        # * If any value is NaN, then False (via `& ~either_nan`)
        # * However, if `nans_equal=True` and _both_ values are NaN, then True
        either_nan = self_is_nan | other_is_nan
        result = (is_close | is_same_inf) & ~either_nan

        if nans_equal:
            both_nan = self_is_nan & other_is_nan
            result = result | both_nan

        return result
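
    # Illustrative only: the tolerance construction above mirrors
    # `math.isclose`, where values are close when
    # |self - other| <= max(rel_tol * max(|self|, |other|), abs_tol),
    # extended with explicit handling for infinities and (optionally) NaNs.
    #
    # >>> import math
    # >>> math.isclose(100.0, 100.1, rel_tol=0.01, abs_tol=0.0)
    # True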

    def is_duplicated(self) -> Self:
        return ~self.is_unique()

    def is_finite(self) -> Self: ...
    def is_first_distinct(self) -> Self: ...
    def is_in(self, other: Any) -> Self: ...
    def is_last_distinct(self) -> Self: ...
    def is_nan(self) -> Self: ...
    def is_null(self) -> Self: ...
    def is_unique(self) -> Self: ...
    def log(self, base: float) -> Self: ...
    def mode(self, *, keep: ModeKeepStrategy) -> Self: ...
    def rank(self, method: RankMethod, *, descending: bool) -> Self: ...
    def replace_strict(
        self,
        old: Sequence[Any] | Mapping[Any, Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self: ...
    def rolling_mean(
        self, window_size: int, *, min_samples: int, center: bool
    ) -> Self: ...
    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self: ...
    def rolling_sum(
        self, window_size: int, *, min_samples: int, center: bool
    ) -> Self: ...
    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self: ...
    def round(self, decimals: int) -> Self: ...
    def shift(self, n: int) -> Self: ...
    def unique(self) -> Self: ...

    @property
    def str(self) -> StringNamespace[Self]: ...
    @property
    def dt(self) -> DateTimeNamespace[Self]: ...
    @property
    def cat(self) -> CatNamespace[Self]: ...
    @property
    def list(self) -> ListNamespace[Self]: ...
    @property
    def struct(self) -> StructNamespace[Self]: ...
426
lib/python3.11/site-packages/narwhals/_compliant/dataframe.py
Normal file
@ -0,0 +1,426 @@
from __future__ import annotations

from collections.abc import Iterator, Mapping, Sequence, Sized
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Protocol, TypeVar, overload

from narwhals._compliant.typing import (
    CompliantDataFrameAny,
    CompliantExprT_contra,
    CompliantLazyFrameAny,
    CompliantSeriesT,
    EagerExprT,
    EagerSeriesT,
    NativeDataFrameT,
    NativeLazyFrameT,
    NativeSeriesT,
)
from narwhals._translate import (
    ArrowConvertible,
    DictConvertible,
    FromNative,
    NumpyConvertible,
    ToNarwhals,
    ToNarwhalsT_co,
)
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    ValidateBackendVersion,
    Version,
    _StoresNative,
    check_columns_exist,
    is_compliant_series,
    is_index_selector,
    is_range,
    is_sequence_like,
    is_sized_multi_index_selector,
    is_slice_index,
    is_slice_none,
)

if TYPE_CHECKING:
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import pandas as pd
    import polars as pl
    import pyarrow as pa
    from typing_extensions import Self, TypeAlias

    from narwhals._compliant.group_by import CompliantGroupBy, DataFrameGroupBy
    from narwhals._compliant.namespace import EagerNamespace
    from narwhals._spark_like.utils import SparkSession
    from narwhals._translate import IntoArrowTable
    from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
    from narwhals._utils import Implementation, _LimitedContext
    from narwhals.dataframe import DataFrame
    from narwhals.dtypes import DType
    from narwhals.exceptions import ColumnNotFoundError
    from narwhals.typing import (
        AsofJoinStrategy,
        IntoSchema,
        JoinStrategy,
        LazyUniqueKeepStrategy,
        MultiColSelector,
        MultiIndexSelector,
        PivotAgg,
        SingleIndexSelector,
        SizedMultiIndexSelector,
        SizedMultiNameSelector,
        SizeUnit,
        UniqueKeepStrategy,
        _2DArray,
        _SliceIndex,
        _SliceName,
    )

Incomplete: TypeAlias = Any

__all__ = ["CompliantDataFrame", "CompliantFrame", "CompliantLazyFrame", "EagerDataFrame"]

T = TypeVar("T")

_ToDict: TypeAlias = "dict[str, CompliantSeriesT] | dict[str, list[Any]]"  # noqa: PYI047

_NativeFrameT = TypeVar("_NativeFrameT")


class CompliantFrame(
    _StoresNative[_NativeFrameT],
    FromNative[_NativeFrameT],
    ToNarwhals[ToNarwhalsT_co],
    Protocol[CompliantExprT_contra, _NativeFrameT, ToNarwhalsT_co],
):
    """Common parts of `DataFrame`, `LazyFrame`."""

    _native_frame: _NativeFrameT
    _implementation: Implementation
    _version: Version

    def __native_namespace__(self) -> ModuleType: ...
    def __narwhals_namespace__(self) -> Any: ...
    def _with_version(self, version: Version) -> Self: ...
    @classmethod
    def from_native(cls, data: _NativeFrameT, /, *, context: _LimitedContext) -> Self: ...
    @property
    def columns(self) -> Sequence[str]: ...
    @property
    def native(self) -> _NativeFrameT:
        return self._native_frame

    @property
    def schema(self) -> Mapping[str, DType]: ...

    def collect_schema(self) -> Mapping[str, DType]: ...
    def drop(self, columns: Sequence[str], *, strict: bool) -> Self: ...
    def drop_nulls(self, subset: Sequence[str] | None) -> Self: ...
    def explode(self, columns: Sequence[str]) -> Self: ...
    def filter(self, predicate: CompliantExprT_contra | Incomplete) -> Self: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> CompliantGroupBy[Self, CompliantExprT_contra]: ...
    def head(self, n: int) -> Self: ...
    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self: ...
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self: ...
    def rename(self, mapping: Mapping[str, str]) -> Self: ...
    def select(self, *exprs: CompliantExprT_contra) -> Self: ...
    def simple_select(self, *column_names: str) -> Self:
        """`select` where all args are column names."""
        ...

    def sort(
        self, *by: str, descending: bool | Sequence[bool], nulls_last: bool
    ) -> Self: ...
    def tail(self, n: int) -> Self: ...
    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self: ...
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self: ...
    def with_columns(self, *exprs: CompliantExprT_contra) -> Self: ...
    def with_row_index(self, name: str, order_by: Sequence[str]) -> Self: ...


class CompliantDataFrame(
    NumpyConvertible["_2DArray", "_2DArray"],
    DictConvertible["_ToDict[CompliantSeriesT]", Mapping[str, Any]],
    ArrowConvertible["pa.Table", "IntoArrowTable"],
    Sized,
    CompliantFrame[CompliantExprT_contra, NativeDataFrameT, ToNarwhalsT_co],
    Protocol[CompliantSeriesT, CompliantExprT_contra, NativeDataFrameT, ToNarwhalsT_co],
):
    def __narwhals_dataframe__(self) -> Self: ...
    @classmethod
    def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_dict(
        cls,
        data: Mapping[str, Any],
        /,
        *,
        context: _LimitedContext,
        schema: IntoSchema | None,
    ) -> Self: ...
    @classmethod
    def from_numpy(
        cls,
        data: _2DArray,
        /,
        *,
        context: _LimitedContext,
        schema: IntoSchema | Sequence[str] | None,
    ) -> Self: ...
    def __array__(self, dtype: Any, *, copy: bool | None) -> _2DArray: ...
    def __getitem__(
        self,
        item: tuple[
            SingleIndexSelector | MultiIndexSelector[CompliantSeriesT],
            MultiColSelector[CompliantSeriesT],
        ],
    ) -> Self: ...

    @property
    def shape(self) -> tuple[int, int]: ...
    def clone(self) -> Self: ...
    def estimated_size(self, unit: SizeUnit) -> int | float: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def get_column(self, name: str) -> CompliantSeriesT: ...
    def group_by(
        self,
        keys: Sequence[str] | Sequence[CompliantExprT_contra],
        *,
        drop_null_keys: bool,
    ) -> DataFrameGroupBy[Self, Any]: ...
    def item(self, row: int | None, column: int | str | None) -> Any: ...
|
||||
def iter_columns(self) -> Iterator[CompliantSeriesT]: ...
|
||||
def iter_rows(
|
||||
self, *, named: bool, buffer_size: int
|
||||
) -> Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]: ...
|
||||
def is_unique(self) -> CompliantSeriesT: ...
|
||||
def lazy(
|
||||
self, backend: _LazyAllowedImpl | None, *, session: SparkSession | None
|
||||
) -> CompliantLazyFrameAny: ...
|
||||
def pivot(
|
||||
self,
|
||||
on: Sequence[str],
|
||||
*,
|
||||
index: Sequence[str] | None,
|
||||
values: Sequence[str] | None,
|
||||
aggregate_function: PivotAgg | None,
|
||||
sort_columns: bool,
|
||||
separator: str,
|
||||
) -> Self: ...
|
||||
def row(self, index: int) -> tuple[Any, ...]: ...
|
||||
def rows(
|
||||
self, *, named: bool
|
||||
) -> Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]: ...
|
||||
def sample(
|
||||
self,
|
||||
n: int | None,
|
||||
*,
|
||||
fraction: float | None,
|
||||
with_replacement: bool,
|
||||
seed: int | None,
|
||||
) -> Self: ...
|
||||
def to_arrow(self) -> pa.Table: ...
|
||||
def to_pandas(self) -> pd.DataFrame: ...
|
||||
def to_polars(self) -> pl.DataFrame: ...
|
||||
@overload
|
||||
def to_dict(self, *, as_series: Literal[True]) -> dict[str, CompliantSeriesT]: ...
|
||||
@overload
|
||||
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
|
||||
def to_dict(
|
||||
self, *, as_series: bool
|
||||
) -> dict[str, CompliantSeriesT] | dict[str, list[Any]]: ...
|
||||
def unique(
|
||||
self,
|
||||
subset: Sequence[str] | None,
|
||||
*,
|
||||
keep: UniqueKeepStrategy,
|
||||
maintain_order: bool | None = None,
|
||||
) -> Self: ...
|
||||
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self: ...
|
||||
@overload
|
||||
def write_csv(self, file: None) -> str: ...
|
||||
@overload
|
||||
def write_csv(self, file: str | Path | BytesIO) -> None: ...
|
||||
def write_csv(self, file: str | Path | BytesIO | None) -> str | None: ...
|
||||
def write_parquet(self, file: str | Path | BytesIO) -> None: ...
|
||||
|
||||
|
||||
class CompliantLazyFrame(
|
||||
CompliantFrame[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
|
||||
Protocol[CompliantExprT_contra, NativeLazyFrameT, ToNarwhalsT_co],
|
||||
):
|
||||
def __narwhals_lazyframe__(self) -> Self: ...
|
||||
# `LazySelectorNamespace._iter_columns` depends
|
||||
def _iter_columns(self) -> Iterator[Any]: ...
|
||||
def aggregate(self, *exprs: CompliantExprT_contra) -> Self:
|
||||
"""`select` where all args are aggregations or literals.
|
||||
|
||||
(so, no broadcasting is necessary).
|
||||
"""
|
||||
...
|
||||
|
||||
def collect(
|
||||
self, backend: _EagerAllowedImpl | None, **kwargs: Any
|
||||
) -> CompliantDataFrameAny: ...
|
||||
def sink_parquet(self, file: str | Path | BytesIO) -> None: ...
|
||||
|
||||
|
||||
class EagerDataFrame(
|
||||
CompliantDataFrame[
|
||||
EagerSeriesT, EagerExprT, NativeDataFrameT, "DataFrame[NativeDataFrameT]"
|
||||
],
|
||||
CompliantLazyFrame[EagerExprT, "Incomplete", "DataFrame[NativeDataFrameT]"],
|
||||
ValidateBackendVersion,
|
||||
Protocol[EagerSeriesT, EagerExprT, NativeDataFrameT, NativeSeriesT],
|
||||
):
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
def __narwhals_namespace__(
|
||||
self,
|
||||
) -> EagerNamespace[
|
||||
Self, EagerSeriesT, EagerExprT, NativeDataFrameT, NativeSeriesT
|
||||
]: ...
|
||||
|
||||
def to_narwhals(self) -> DataFrame[NativeDataFrameT]:
|
||||
return self._version.dataframe(self, level="full")
|
||||
|
||||
def aggregate(self, *exprs: EagerExprT) -> Self:
|
||||
# NOTE: Ignore intermittent [False Negative]
|
||||
# Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "exprs" of type "EagerExprT@EagerDataFrame" in function "select"
|
||||
# Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
|
||||
return self.select(*exprs) # pyright: ignore[reportArgumentType]
|
||||
|
||||
def _with_native(
|
||||
self, df: NativeDataFrameT, *, validate_column_names: bool = True
|
||||
) -> Self: ...
|
||||
|
||||
def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
|
||||
return check_columns_exist(subset, available=self.columns)
|
||||
|
||||
def _evaluate_expr(self, expr: EagerExprT, /) -> EagerSeriesT:
|
||||
"""Evaluate `expr` and ensure it has a **single** output."""
|
||||
result: Sequence[EagerSeriesT] = expr(self)
|
||||
assert len(result) == 1 # debug assertion # noqa: S101
|
||||
return result[0]
|
||||
|
||||
def _evaluate_into_exprs(self, *exprs: EagerExprT) -> Sequence[EagerSeriesT]:
|
||||
# NOTE: Ignore intermittent [False Negative]
|
||||
# Argument of type "EagerExprT@EagerDataFrame" cannot be assigned to parameter "expr" of type "EagerExprT@EagerDataFrame" in function "_evaluate_into_expr"
|
||||
# Type "EagerExprT@EagerDataFrame" is not assignable to type "EagerExprT@EagerDataFrame"
|
||||
return list(chain.from_iterable(self._evaluate_into_expr(expr) for expr in exprs)) # pyright: ignore[reportArgumentType]
|
||||
|
||||
def _evaluate_into_expr(self, expr: EagerExprT, /) -> Sequence[EagerSeriesT]:
|
||||
"""Return list of raw columns.
|
||||
|
||||
For eager backends we alias operations at each step.
|
||||
|
||||
As a safety precaution, here we can check that the expected result names match those
|
||||
we were expecting from the various `evaluate_output_names` / `alias_output_names` calls.
|
||||
|
||||
Note that for PySpark / DuckDB, we are less free to liberally set aliases whenever we want.
|
||||
"""
|
||||
aliases = expr._evaluate_aliases(self)
|
||||
result = expr(self)
|
||||
if list(aliases) != (
|
||||
result_aliases := [s.name for s in result]
|
||||
): # pragma: no cover
|
||||
msg = f"Safety assertion failed, expected {aliases}, got {result_aliases}"
|
||||
raise AssertionError(msg)
|
||||
return result
|
||||
|
||||
def _extract_comparand(self, other: EagerSeriesT, /) -> Any:
|
||||
"""Extract native Series, broadcasting to `len(self)` if necessary."""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
def _numpy_column_names(
|
||||
data: _2DArray, columns: Sequence[str] | None, /
|
||||
) -> list[str]:
|
||||
return list(columns or (f"column_{x}" for x in range(data.shape[1])))
|
||||
|
||||
def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
|
||||
def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
|
||||
def _select_multi_index(
|
||||
self, columns: SizedMultiIndexSelector[NativeSeriesT]
|
||||
) -> Self: ...
|
||||
def _select_multi_name(
|
||||
self, columns: SizedMultiNameSelector[NativeSeriesT]
|
||||
) -> Self: ...
|
||||
def _select_slice_index(self, columns: _SliceIndex | range) -> Self: ...
|
||||
def _select_slice_name(self, columns: _SliceName) -> Self: ...
|
||||
def __getitem__( # noqa: C901, PLR0912
|
||||
self,
|
||||
item: tuple[
|
||||
SingleIndexSelector | MultiIndexSelector[EagerSeriesT],
|
||||
MultiColSelector[EagerSeriesT],
|
||||
],
|
||||
) -> Self:
|
||||
rows, columns = item
|
||||
compliant = self
|
||||
if not is_slice_none(columns):
|
||||
if isinstance(columns, Sized) and len(columns) == 0:
|
||||
return compliant.select()
|
||||
if is_index_selector(columns):
|
||||
if is_slice_index(columns) or is_range(columns):
|
||||
compliant = compliant._select_slice_index(columns)
|
||||
elif is_compliant_series(columns):
|
||||
compliant = self._select_multi_index(columns.native)
|
||||
else:
|
||||
compliant = compliant._select_multi_index(columns)
|
||||
elif isinstance(columns, slice):
|
||||
compliant = compliant._select_slice_name(columns)
|
||||
elif is_compliant_series(columns):
|
||||
compliant = self._select_multi_name(columns.native)
|
||||
elif is_sequence_like(columns):
|
||||
compliant = self._select_multi_name(columns)
|
||||
else:
|
||||
assert_never(columns)
|
||||
|
||||
if not is_slice_none(rows):
|
||||
if isinstance(rows, int):
|
||||
compliant = compliant._gather([rows])
|
||||
elif isinstance(rows, (slice, range)):
|
||||
compliant = compliant._gather_slice(rows)
|
||||
elif is_compliant_series(rows):
|
||||
compliant = compliant._gather(rows.native)
|
||||
elif is_sized_multi_index_selector(rows):
|
||||
compliant = compliant._gather(rows)
|
||||
else:
|
||||
assert_never(rows)
|
||||
|
||||
return compliant
|
||||
|
||||
def sink_parquet(self, file: str | Path | BytesIO) -> None:
|
||||
return self.write_parquet(file)
|
||||
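As a quick illustration of the selection protocol above: `EagerDataFrame.__getitem__` always narrows columns first and gathers rows second, routing each selector shape to a dedicated helper. The following is a minimal, self-contained sketch of that two-phase dispatch; `TinyFrame` and its dict storage are hypothetical stand-ins for a concrete backend, not the narwhals implementation.

from __future__ import annotations


class TinyFrame:
    """Hypothetical stand-in for a concrete EagerDataFrame backend."""

    def __init__(self, data: dict[str, list[object]]) -> None:
        self._data = data

    def __getitem__(self, item: tuple[object, object]) -> "TinyFrame":
        rows, columns = item
        result = self
        # Phase 1: narrow columns (index sequences resolve to names first).
        if columns is not None:
            if isinstance(columns, (list, tuple)) and all(
                isinstance(c, int) for c in columns
            ):
                names = list(result._data)
                columns = [names[i] for i in columns]
            result = TinyFrame({k: result._data[k] for k in columns})
        # Phase 2: gather rows (int, slice, or sized index sequence).
        if rows is not None:
            if isinstance(rows, int):
                rows = [rows]
            if isinstance(rows, slice):
                result = TinyFrame({k: v[rows] for k, v in result._data.items()})
            else:
                result = TinyFrame(
                    {k: [v[i] for i in rows] for k, v in result._data.items()}
                )
        return result


frame = TinyFrame({"a": [1, 2, 3], "b": [4, 5, 6]})
assert frame[[0, 2], ["b"]]._data == {"b": [4, 6]}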
1140
lib/python3.11/site-packages/narwhals/_compliant/expr.py
Normal file
File diff suppressed because it is too large
180
lib/python3.11/site-packages/narwhals/_compliant/group_by.py
Normal file
@@ -0,0 +1,180 @@
from __future__ import annotations

import re
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Protocol, TypeVar

from narwhals._compliant.typing import (
    CompliantDataFrameT,
    CompliantDataFrameT_co,
    CompliantExprT_contra,
    CompliantFrameT,
    CompliantFrameT_co,
    DepthTrackingExprAny,
    DepthTrackingExprT_contra,
    EagerExprT_contra,
    ImplExprT_contra,
    NarwhalsAggregation,
)
from narwhals._utils import is_sequence_of, zip_strict

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence

    from narwhals._compliant.expr import ImplExpr


__all__ = ["CompliantGroupBy", "DepthTrackingGroupBy", "EagerGroupBy"]

NativeAggregationT_co = TypeVar(
    "NativeAggregationT_co", bound="str | Callable[..., Any]", covariant=True
)


_RE_LEAF_NAME: re.Pattern[str] = re.compile(r"(\w+->)")


def _evaluate_aliases(
    frame: CompliantFrameT, exprs: Iterable[ImplExpr[CompliantFrameT, Any]], /
) -> list[str]:
    it = (expr._evaluate_aliases(frame) for expr in exprs)
    return list(chain.from_iterable(it))


class CompliantGroupBy(Protocol[CompliantFrameT_co, CompliantExprT_contra]):
    _compliant_frame: Any

    @property
    def compliant(self) -> CompliantFrameT_co:
        return self._compliant_frame  # type: ignore[no-any-return]

    def __init__(
        self,
        compliant_frame: CompliantFrameT_co,
        keys: Sequence[CompliantExprT_contra] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None: ...

    def agg(self, *exprs: CompliantExprT_contra) -> CompliantFrameT_co: ...


class DataFrameGroupBy(
    CompliantGroupBy[CompliantDataFrameT_co, CompliantExprT_contra],
    Protocol[CompliantDataFrameT_co, CompliantExprT_contra],
):
    def __iter__(self) -> Iterator[tuple[Any, CompliantDataFrameT_co]]: ...


class ParseKeysGroupBy(
    CompliantGroupBy[CompliantFrameT, ImplExprT_contra],
    Protocol[CompliantFrameT, ImplExprT_contra],
):
    def _parse_keys(
        self,
        compliant_frame: CompliantFrameT,
        keys: Sequence[ImplExprT_contra] | Sequence[str],
    ) -> tuple[CompliantFrameT, list[str], list[str]]:
        if is_sequence_of(keys, str):
            keys_str = list(keys)
            return compliant_frame, keys_str, keys_str.copy()
        return self._parse_expr_keys(compliant_frame, keys=keys)

    @staticmethod
    def _parse_expr_keys(
        compliant_frame: CompliantFrameT, keys: Sequence[ImplExprT_contra]
    ) -> tuple[CompliantFrameT, list[str], list[str]]:
        """Parses key expressions to set up `.agg` operation with correct information.

        Since keys are expressions, it's possible to alias any such key to match
        other dataframe column names.

        In order to match polars behavior and not overwrite columns when evaluating keys:

        - We evaluate what the output key names should be, in order to remap temporary column
          names to the expected ones, and to exclude those from unnamed expressions in
          `.agg(...)` context (see https://github.com/narwhals-dev/narwhals/pull/2325#issuecomment-2800004520)
        - Create temporary names for evaluated key expressions that are guaranteed to have
          no overlap with any existing column name.
        - Add these temporary columns to the compliant dataframe.
        """
        tmp_name_length = max(len(str(c)) for c in compliant_frame.columns) + 1

        def _temporary_name(key: str) -> str:
            # 5 is the length of `__tmp`
            key_str = str(key)  # pandas allows non-string column names :sob:
            return f"_{key_str}_tmp{'_' * (tmp_name_length - len(key_str) - 5)}"

        keys_aliases = [expr._evaluate_aliases(compliant_frame) for expr in keys]
        safe_keys = [
            # multi-output expression cannot have duplicate names, hence it's safe to suffix
            key.name.map(_temporary_name)
            if (metadata := key._metadata) and metadata.expansion_kind.is_multi_output()
            # otherwise it's single named and we can use Expr.alias
            else key.alias(_temporary_name(new_names[0]))
            for key, new_names in zip_strict(keys, keys_aliases)
        ]
        return (
            compliant_frame.with_columns(*safe_keys),
            _evaluate_aliases(compliant_frame, safe_keys),
            list(chain.from_iterable(keys_aliases)),
        )


class DepthTrackingGroupBy(
    ParseKeysGroupBy[CompliantFrameT, DepthTrackingExprT_contra],
    Protocol[CompliantFrameT, DepthTrackingExprT_contra, NativeAggregationT_co],
):
    """`CompliantGroupBy` variant, deals with `Eager` and other backends that utilize `CompliantExpr._depth`."""

    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Any]]
    """Mapping from `narwhals` to native representation.

    Note:
        - `Dask` *may* return a `Callable` instead of a `str` referring to one.
    """

    def _ensure_all_simple(self, exprs: Sequence[DepthTrackingExprT_contra]) -> None:
        for expr in exprs:
            if not self._is_simple(expr):
                name = self.compliant._implementation.name.lower()
                msg = (
                    f"Non-trivial complex aggregation found.\n\n"
                    f"Hint: you were probably trying to apply a non-elementary aggregation with a "
                    f"{name!r} table.\n"
                    "Please rewrite your query such that group-by aggregations "
                    "are elementary. For example, instead of:\n\n"
                    "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
                    "use:\n\n"
                    "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
                )
                raise ValueError(msg)

    @classmethod
    def _is_simple(cls, expr: DepthTrackingExprAny, /) -> bool:
        """Return `True` if we can efficiently use `expr` in a native `group_by` context."""
        return expr._is_elementary() and cls._leaf_name(expr) in cls._REMAP_AGGS

    @classmethod
    def _remap_expr_name(
        cls, name: NarwhalsAggregation | Any, /
    ) -> NativeAggregationT_co:
        """Replace `name` with some native representation.

        Arguments:
            name: Name of a `nw.Expr` aggregation method.
        """
        return cls._REMAP_AGGS.get(name, name)

    @classmethod
    def _leaf_name(cls, expr: DepthTrackingExprAny, /) -> NarwhalsAggregation | Any:
        """Return the last function name in the chain defined by `expr`."""
        return _RE_LEAF_NAME.sub("", expr._function_name)


class EagerGroupBy(
    DepthTrackingGroupBy[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
    DataFrameGroupBy[CompliantDataFrameT, EagerExprT_contra],
    Protocol[CompliantDataFrameT, EagerExprT_contra, NativeAggregationT_co],
): ...
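To see why the temporary-name scheme in `_parse_expr_keys` cannot collide with an existing column, note that every generated name is padded so its length exceeds the longest existing column name by exactly one character, and no existing name has that length. A standalone re-derivation of the helper shown above (the column names here are illustrative):

columns = ["a", "shared", "key"]  # illustrative existing column names
tmp_name_length = max(len(str(c)) for c in columns) + 1


def _temporary_name(key: str) -> str:
    # 5 is the length of "__tmp"; the padding pushes the result past every existing name
    key_str = str(key)
    return f"_{key_str}_tmp{'_' * (tmp_name_length - len(key_str) - 5)}"


tmp = _temporary_name("a")
assert len(tmp) > max(len(c) for c in columns)  # strictly longer than any column name
assert tmp not in columns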
238
lib/python3.11/site-packages/narwhals/_compliant/namespace.py
Normal file
@@ -0,0 +1,238 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, Protocol, overload

from narwhals._compliant.typing import (
    CompliantExprT,
    CompliantFrameT,
    CompliantLazyFrameT,
    DepthTrackingExprT,
    EagerDataFrameT,
    EagerExprT,
    EagerSeriesT,
    LazyExprT,
    NativeFrameT,
    NativeFrameT_co,
    NativeSeriesT,
)
from narwhals._expression_parsing import is_expr, is_series
from narwhals._utils import (
    exclude_column_names,
    get_column_names,
    passthrough_column_names,
)
from narwhals.dependencies import is_numpy_array, is_numpy_array_2d

if TYPE_CHECKING:
    from collections.abc import Container, Iterable, Sequence

    from typing_extensions import TypeAlias

    from narwhals._compliant.selectors import CompliantSelectorNamespace
    from narwhals._compliant.when_then import CompliantWhen, EagerWhen
    from narwhals._utils import Implementation, Version
    from narwhals.expr import Expr
    from narwhals.series import Series
    from narwhals.typing import (
        ConcatMethod,
        Into1DArray,
        IntoDType,
        IntoSchema,
        NonNestedLiteral,
        _1DArray,
        _2DArray,
    )

    Incomplete: TypeAlias = Any

__all__ = [
    "CompliantNamespace",
    "DepthTrackingNamespace",
    "EagerNamespace",
    "LazyNamespace",
]


class CompliantNamespace(Protocol[CompliantFrameT, CompliantExprT]):
    # NOTE: `narwhals`
    _implementation: Implementation
    _version: Version

    @property
    def _expr(self) -> type[CompliantExprT]: ...
    def parse_into_expr(
        self, data: Expr | NonNestedLiteral | Any, /, *, str_as_lit: bool
    ) -> CompliantExprT | NonNestedLiteral:
        if is_expr(data):
            expr = data._to_compliant_expr(self)
            assert isinstance(expr, self._expr)  # noqa: S101
            return expr
        if isinstance(data, str) and not str_as_lit:
            return self.col(data)
        return data

    # NOTE: `polars`
    def all(self) -> CompliantExprT:
        return self._expr.from_column_names(get_column_names, context=self)

    def col(self, *column_names: str) -> CompliantExprT:
        return self._expr.from_column_names(
            passthrough_column_names(column_names), context=self
        )

    def exclude(self, excluded_names: Container[str]) -> CompliantExprT:
        return self._expr.from_column_names(
            partial(exclude_column_names, names=excluded_names), context=self
        )

    def nth(self, *column_indices: int) -> CompliantExprT:
        return self._expr.from_column_indices(*column_indices, context=self)

    def len(self) -> CompliantExprT: ...
    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> CompliantExprT: ...
    def all_horizontal(
        self, *exprs: CompliantExprT, ignore_nulls: bool
    ) -> CompliantExprT: ...
    def any_horizontal(
        self, *exprs: CompliantExprT, ignore_nulls: bool
    ) -> CompliantExprT: ...
    def sum_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
    def mean_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
    def min_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
    def max_horizontal(self, *exprs: CompliantExprT) -> CompliantExprT: ...
    def concat(
        self, items: Iterable[CompliantFrameT], *, how: ConcatMethod
    ) -> CompliantFrameT: ...
    def when(
        self, predicate: CompliantExprT
    ) -> CompliantWhen[CompliantFrameT, Incomplete, CompliantExprT]: ...
    def concat_str(
        self, *exprs: CompliantExprT, separator: str, ignore_nulls: bool
    ) -> CompliantExprT: ...
    @property
    def selectors(self) -> CompliantSelectorNamespace[Any, Any]: ...
    def coalesce(self, *exprs: CompliantExprT) -> CompliantExprT: ...


class DepthTrackingNamespace(
    CompliantNamespace[CompliantFrameT, DepthTrackingExprT],
    Protocol[CompliantFrameT, DepthTrackingExprT],
):
    def all(self) -> DepthTrackingExprT:
        return self._expr.from_column_names(
            get_column_names, function_name="all", context=self
        )

    def col(self, *column_names: str) -> DepthTrackingExprT:
        return self._expr.from_column_names(
            passthrough_column_names(column_names), function_name="col", context=self
        )

    def exclude(self, excluded_names: Container[str]) -> DepthTrackingExprT:
        return self._expr.from_column_names(
            partial(exclude_column_names, names=excluded_names),
            function_name="exclude",
            context=self,
        )


class LazyNamespace(
    CompliantNamespace[CompliantLazyFrameT, LazyExprT],
    Protocol[CompliantLazyFrameT, LazyExprT, NativeFrameT_co],
):
    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    @property
    def _lazyframe(self) -> type[CompliantLazyFrameT]: ...

    def from_native(self, data: NativeFrameT_co | Any, /) -> CompliantLazyFrameT:
        if self._lazyframe._is_native(data):
            return self._lazyframe.from_native(data, context=self)
        msg = f"Unsupported type: {type(data).__name__!r}"  # pragma: no cover
        raise TypeError(msg)


class EagerNamespace(
    DepthTrackingNamespace[EagerDataFrameT, EagerExprT],
    Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeFrameT, NativeSeriesT],
):
    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    @property
    def _dataframe(self) -> type[EagerDataFrameT]: ...
    @property
    def _series(self) -> type[EagerSeriesT]: ...
    def when(
        self, predicate: EagerExprT
    ) -> EagerWhen[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT]: ...

    @overload
    def from_native(self, data: NativeFrameT, /) -> EagerDataFrameT: ...
    @overload
    def from_native(self, data: NativeSeriesT, /) -> EagerSeriesT: ...
    def from_native(
        self, data: NativeFrameT | NativeSeriesT | Any, /
    ) -> EagerDataFrameT | EagerSeriesT:
        if self._dataframe._is_native(data):
            return self._dataframe.from_native(data, context=self)
        if self._series._is_native(data):
            return self._series.from_native(data, context=self)
        msg = f"Unsupported type: {type(data).__name__!r}"
        raise TypeError(msg)

    def parse_into_expr(
        self,
        data: Expr | Series[NativeSeriesT] | _1DArray | NonNestedLiteral,
        /,
        *,
        str_as_lit: bool,
    ) -> EagerExprT | NonNestedLiteral:
        if not (is_series(data) or is_numpy_array(data)):
            return super().parse_into_expr(data, str_as_lit=str_as_lit)
        return self._expr._from_series(
            data._compliant_series
            if is_series(data)
            else self._series.from_numpy(data, context=self)
        )

    @overload
    def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> EagerSeriesT: ...

    @overload
    def from_numpy(
        self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
    ) -> EagerDataFrameT: ...

    def from_numpy(
        self,
        data: Into1DArray | _2DArray,
        /,
        schema: IntoSchema | Sequence[str] | None = None,
    ) -> EagerDataFrameT | EagerSeriesT:
        if is_numpy_array_2d(data):
            return self._dataframe.from_numpy(data, schema=schema, context=self)
        return self._series.from_numpy(data, context=self)

    def _concat_diagonal(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
    def _concat_horizontal(
        self, dfs: Sequence[NativeFrameT | Any], /
    ) -> NativeFrameT: ...
    def _concat_vertical(self, dfs: Sequence[NativeFrameT], /) -> NativeFrameT: ...
    def concat(
        self, items: Iterable[EagerDataFrameT], *, how: ConcatMethod
    ) -> EagerDataFrameT:
        dfs = [item.native for item in items]
        if how == "horizontal":
            native = self._concat_horizontal(dfs)
        elif how == "vertical":
            native = self._concat_vertical(dfs)
        elif how == "diagonal":
            native = self._concat_diagonal(dfs)
        else:  # pragma: no cover
            raise NotImplementedError
        return self._dataframe.from_native(native, context=self)
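One consequence of `parse_into_expr` above worth spelling out: a bare string is parsed as a column reference unless the caller passes `str_as_lit=True`, which is why `nw.lit` exists as the explicit opt-in for string literals. A hedged usage sketch through the public narwhals API (assumes polars is installed; the frame and column names are made up):

import narwhals as nw
import polars as pl

df = nw.from_native(pl.DataFrame({"a": [1, 2], "b": [10, 20]}))

# In expression positions, a string means "that column"...
wide = df.with_columns(c=nw.max_horizontal("a", "b"))
# ...while lit() opts into the literal interpretation explicitly.
lit = df.with_columns(c=nw.lit("a"))

assert wide.to_native()["c"].to_list() == [10, 20]
assert lit.to_native()["c"].to_list() == ["a", "a"]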
318
lib/python3.11/site-packages/narwhals/_compliant/selectors.py
Normal file
@@ -0,0 +1,318 @@
"""Almost entirely complete, generic `selectors` implementation."""

from __future__ import annotations

import re
from functools import partial
from typing import TYPE_CHECKING, Protocol, TypeVar, overload

from narwhals._compliant.expr import CompliantExpr
from narwhals._utils import (
    _parse_time_unit_and_time_zone,
    dtype_matches_time_unit_and_time_zone,
    get_column_names,
    is_compliant_dataframe,
    zip_strict,
)

if TYPE_CHECKING:
    from collections.abc import Collection, Iterable, Iterator, Sequence
    from datetime import timezone

    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.expr import NativeExpr
    from narwhals._compliant.typing import (
        CompliantDataFrameAny,
        CompliantExprAny,
        CompliantFrameAny,
        CompliantLazyFrameAny,
        CompliantSeriesAny,
        CompliantSeriesOrNativeExprAny,
        EvalNames,
        EvalSeries,
    )
    from narwhals._utils import Implementation, Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.typing import TimeUnit

__all__ = [
    "CompliantSelector",
    "CompliantSelectorNamespace",
    "EagerSelectorNamespace",
    "LazySelectorNamespace",
]


SeriesOrExprT = TypeVar("SeriesOrExprT", bound="CompliantSeriesOrNativeExprAny")
SeriesT = TypeVar("SeriesT", bound="CompliantSeriesAny")
ExprT = TypeVar("ExprT", bound="NativeExpr")
FrameT = TypeVar("FrameT", bound="CompliantFrameAny")
DataFrameT = TypeVar("DataFrameT", bound="CompliantDataFrameAny")
LazyFrameT = TypeVar("LazyFrameT", bound="CompliantLazyFrameAny")
SelectorOrExpr: TypeAlias = (
    "CompliantSelector[FrameT, SeriesOrExprT] | CompliantExpr[FrameT, SeriesOrExprT]"
)


class CompliantSelectorNamespace(Protocol[FrameT, SeriesOrExprT]):
    # NOTE: `narwhals`
    _implementation: Implementation
    _version: Version

    @property
    def _selector(self) -> type[CompliantSelector[FrameT, SeriesOrExprT]]: ...
    @classmethod
    def from_namespace(cls, context: _LimitedContext, /) -> Self:
        obj = cls.__new__(cls)
        obj._implementation = context._implementation
        obj._version = context._version
        return obj

    def _iter_columns(self, df: FrameT, /) -> Iterator[SeriesOrExprT]: ...
    def _iter_schema(self, df: FrameT, /) -> Iterator[tuple[str, DType]]: ...
    def _iter_columns_dtypes(
        self, df: FrameT, /
    ) -> Iterator[tuple[SeriesOrExprT, DType]]: ...
    def _iter_columns_names(self, df: FrameT, /) -> Iterator[tuple[SeriesOrExprT, str]]:
        yield from zip_strict(self._iter_columns(df), df.columns)

    def _is_dtype(
        self: CompliantSelectorNamespace[FrameT, SeriesOrExprT], dtype: type[DType], /
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [
                ser for ser, tp in self._iter_columns_dtypes(df) if isinstance(tp, dtype)
            ]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if isinstance(tp, dtype)]

        return self._selector.from_callables(series, names, context=self)

    # NOTE: `polars`
    def by_dtype(
        self, dtypes: Collection[DType | type[DType]]
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [ser for ser, tp in self._iter_columns_dtypes(df) if tp in dtypes]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if tp in dtypes]

        return self._selector.from_callables(series, names, context=self)

    def matches(self, pattern: str) -> CompliantSelector[FrameT, SeriesOrExprT]:
        p = re.compile(pattern)

        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            if (
                is_compliant_dataframe(df)
                and not self._implementation.is_duckdb()
                and not self._implementation.is_ibis()
            ):
                return [df.get_column(col) for col in df.columns if p.search(col)]

            return [ser for ser, name in self._iter_columns_names(df) if p.search(name)]

        def names(df: FrameT) -> Sequence[str]:
            return [col for col in df.columns if p.search(col)]

        return self._selector.from_callables(series, names, context=self)

    def numeric(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [ser for ser, tp in self._iter_columns_dtypes(df) if tp.is_numeric()]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if tp.is_numeric()]

        return self._selector.from_callables(series, names, context=self)

    def categorical(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self._is_dtype(self._version.dtypes.Categorical)

    def string(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self._is_dtype(self._version.dtypes.String)

    def boolean(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self._is_dtype(self._version.dtypes.Boolean)

    def all(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return list(self._iter_columns(df))

        return self._selector.from_callables(series, get_column_names, context=self)

    def datetime(
        self,
        time_unit: TimeUnit | Iterable[TimeUnit] | None,
        time_zone: str | timezone | Iterable[str | timezone | None] | None,
    ) -> CompliantSelector[FrameT, SeriesOrExprT]:
        time_units, time_zones = _parse_time_unit_and_time_zone(time_unit, time_zone)
        matches = partial(
            dtype_matches_time_unit_and_time_zone,
            dtypes=self._version.dtypes,
            time_units=time_units,
            time_zones=time_zones,
        )

        def series(df: FrameT) -> Sequence[SeriesOrExprT]:
            return [ser for ser, tp in self._iter_columns_dtypes(df) if matches(tp)]

        def names(df: FrameT) -> Sequence[str]:
            return [name for name, tp in self._iter_schema(df) if matches(tp)]

        return self._selector.from_callables(series, names, context=self)


class EagerSelectorNamespace(
    CompliantSelectorNamespace[DataFrameT, SeriesT], Protocol[DataFrameT, SeriesT]
):
    def _iter_schema(self, df: DataFrameT, /) -> Iterator[tuple[str, DType]]:
        for ser in self._iter_columns(df):
            yield ser.name, ser.dtype

    def _iter_columns(self, df: DataFrameT, /) -> Iterator[SeriesT]:
        yield from df.iter_columns()

    def _iter_columns_dtypes(self, df: DataFrameT, /) -> Iterator[tuple[SeriesT, DType]]:
        for ser in self._iter_columns(df):
            yield ser, ser.dtype


class LazySelectorNamespace(
    CompliantSelectorNamespace[LazyFrameT, ExprT], Protocol[LazyFrameT, ExprT]
):
    def _iter_schema(self, df: LazyFrameT) -> Iterator[tuple[str, DType]]:
        yield from df.schema.items()

    def _iter_columns(self, df: LazyFrameT) -> Iterator[ExprT]:
        yield from df._iter_columns()

    def _iter_columns_dtypes(self, df: LazyFrameT, /) -> Iterator[tuple[ExprT, DType]]:
        yield from zip_strict(self._iter_columns(df), df.schema.values())


class CompliantSelector(
    CompliantExpr[FrameT, SeriesOrExprT], Protocol[FrameT, SeriesOrExprT]
):
    _call: EvalSeries[FrameT, SeriesOrExprT]
    _function_name: str
    _implementation: Implementation
    _version: Version

    @classmethod
    def from_callables(
        cls,
        call: EvalSeries[FrameT, SeriesOrExprT],
        evaluate_output_names: EvalNames[FrameT],
        *,
        context: _LimitedContext,
    ) -> Self:
        obj = cls.__new__(cls)
        obj._call = call
        obj._evaluate_output_names = evaluate_output_names
        obj._alias_output_names = None
        obj._implementation = context._implementation
        obj._version = context._version
        return obj

    @property
    def selectors(self) -> CompliantSelectorNamespace[FrameT, SeriesOrExprT]:
        return self.__narwhals_namespace__().selectors

    def _to_expr(self) -> CompliantExpr[FrameT, SeriesOrExprT]: ...

    def _is_selector(
        self, other: Self | CompliantExpr[FrameT, SeriesOrExprT]
    ) -> TypeIs[CompliantSelector[FrameT, SeriesOrExprT]]:
        return isinstance(other, type(self))

    @overload
    def __sub__(self, other: Self) -> Self: ...
    @overload
    def __sub__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __sub__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [
                    x
                    for x, name in zip_strict(self(df), lhs_names)
                    if name not in rhs_names
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [x for x in lhs_names if x not in rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() - other

    @overload
    def __or__(self, other: Self) -> Self: ...
    @overload
    def __or__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __or__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [
                    *(
                        x
                        for x, name in zip_strict(self(df), lhs_names)
                        if name not in rhs_names
                    ),
                    *other(df),
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [*(x for x in lhs_names if x not in rhs_names), *rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() | other

    @overload
    def __and__(self, other: Self) -> Self: ...
    @overload
    def __and__(
        self, other: CompliantExpr[FrameT, SeriesOrExprT]
    ) -> CompliantExpr[FrameT, SeriesOrExprT]: ...
    def __and__(
        self, other: SelectorOrExpr[FrameT, SeriesOrExprT]
    ) -> SelectorOrExpr[FrameT, SeriesOrExprT]:
        if self._is_selector(other):

            def series(df: FrameT) -> Sequence[SeriesOrExprT]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [
                    x for x, name in zip_strict(self(df), lhs_names) if name in rhs_names
                ]

            def names(df: FrameT) -> Sequence[str]:
                lhs_names, rhs_names = _eval_lhs_rhs(df, self, other)
                return [x for x in lhs_names if x in rhs_names]

            return self.selectors._selector.from_callables(series, names, context=self)
        return self._to_expr() & other

    def __invert__(self) -> CompliantSelector[FrameT, SeriesOrExprT]:
        return self.selectors.all() - self


def _eval_lhs_rhs(
    df: CompliantFrameAny, lhs: CompliantExprAny, rhs: CompliantExprAny
) -> tuple[Sequence[str], Sequence[str]]:
    return lhs._evaluate_output_names(df), rhs._evaluate_output_names(df)
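All three selector operators above reduce to the same set algebra over evaluated output names; only the paired `series`/`names` closures differ per backend, and `__invert__` is just `all() - self`. A minimal name-level sketch of those semantics, independent of any backend (the names are illustrative):

lhs_names = ["a", "b", "c"]  # what the left selector evaluates to
rhs_names = ["b", "d"]       # what the right selector evaluates to

difference = [x for x in lhs_names if x not in rhs_names]            # lhs - rhs
union = [*(x for x in lhs_names if x not in rhs_names), *rhs_names]  # lhs | rhs
intersection = [x for x in lhs_names if x in rhs_names]              # lhs & rhs

assert difference == ["a", "c"]
assert union == ["a", "c", "b", "d"]
assert intersection == ["b"]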
411
lib/python3.11/site-packages/narwhals/_compliant/series.py
Normal file
@@ -0,0 +1,411 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Generic, Literal, Protocol

from narwhals._compliant.any_namespace import (
    CatNamespace,
    DateTimeNamespace,
    ListNamespace,
    StringNamespace,
    StructNamespace,
)
from narwhals._compliant.column import CompliantColumn
from narwhals._compliant.typing import (
    CompliantSeriesT_co,
    EagerDataFrameAny,
    EagerSeriesT_co,
    NativeSeriesT,
    NativeSeriesT_co,
)
from narwhals._translate import FromIterable, FromNative, NumpyConvertible, ToNarwhals
from narwhals._typing_compat import TypeVar, assert_never
from narwhals._utils import (
    _StoresCompliant,
    _StoresNative,
    is_compliant_series,
    is_sized_multi_index_selector,
    unstable,
)

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Sequence
    from types import ModuleType

    import pandas as pd
    import polars as pl
    import pyarrow as pa
    from typing_extensions import NotRequired, Self, TypedDict

    from narwhals._compliant.dataframe import CompliantDataFrame
    from narwhals._compliant.namespace import EagerNamespace
    from narwhals._utils import Implementation, Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.series import Series
    from narwhals.typing import (
        Into1DArray,
        IntoDType,
        MultiIndexSelector,
        RollingInterpolationMethod,
        SizedMultiIndexSelector,
        _1DArray,
        _SliceIndex,
    )

    class HistData(TypedDict, Generic[NativeSeriesT, "_CountsT_co"]):
        breakpoint: NotRequired[list[float] | _1DArray | list[Any]]
        count: NativeSeriesT | _1DArray | _CountsT_co | list[Any]


_CountsT_co = TypeVar("_CountsT_co", bound="Iterable[Any]", covariant=True)

__all__ = [
    "CompliantSeries",
    "EagerSeries",
    "EagerSeriesCatNamespace",
    "EagerSeriesDateTimeNamespace",
    "EagerSeriesHist",
    "EagerSeriesListNamespace",
    "EagerSeriesNamespace",
    "EagerSeriesStringNamespace",
    "EagerSeriesStructNamespace",
]


class CompliantSeries(
    NumpyConvertible["_1DArray", "Into1DArray"],
    FromIterable,
    FromNative[NativeSeriesT],
    ToNarwhals["Series[NativeSeriesT]"],
    CompliantColumn,
    Protocol[NativeSeriesT],
):
    # NOTE: `narwhals`
    _implementation: Implementation

    @property
    def native(self) -> NativeSeriesT: ...
    def __narwhals_series__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType: ...
    @classmethod
    def from_native(cls, data: NativeSeriesT, /, *, context: _LimitedContext) -> Self: ...
    def to_narwhals(self) -> Series[NativeSeriesT]:
        return self._version.series(self, level="full")

    def _with_native(self, series: Any) -> Self: ...
    def _with_version(self, version: Version) -> Self: ...

    # NOTE: `polars`
    @property
    def dtype(self) -> DType: ...
    @property
    def name(self) -> str: ...
    def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray: ...
    def __contains__(self, other: Any) -> bool: ...
    def __getitem__(self, item: MultiIndexSelector[Self]) -> Any: ...
    def __iter__(self) -> Iterator[Any]: ...
    def __len__(self) -> int:
        return len(self.native)

    @classmethod
    def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self: ...
    @classmethod
    def from_iterable(
        cls,
        data: Iterable[Any],
        /,
        *,
        context: _LimitedContext,
        name: str = "",
        dtype: IntoDType | None = None,
    ) -> Self: ...
    def __radd__(self, other: Any) -> Self: ...
    def __rand__(self, other: Any) -> Self: ...
    def __rmul__(self, other: Any) -> Self: ...
    def __ror__(self, other: Any) -> Self: ...
    def all(self) -> bool: ...
    def any(self) -> bool: ...
    def arg_max(self) -> int: ...
    def arg_min(self) -> int: ...
    def arg_true(self) -> Self: ...
    def count(self) -> int: ...
    def filter(self, predicate: Any) -> Self: ...
    def gather_every(self, n: int, offset: int) -> Self: ...
    def head(self, n: int) -> Self: ...
    def is_empty(self) -> bool:
        return self.len() == 0

    def is_sorted(self, *, descending: bool) -> bool: ...
    def item(self, index: int | None) -> Any: ...
    def kurtosis(self) -> float | None: ...
    def len(self) -> int: ...
    def max(self) -> Any: ...
    def mean(self) -> float: ...
    def median(self) -> float: ...
    def min(self) -> Any: ...
    def n_unique(self) -> int: ...
    def null_count(self) -> int: ...
    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> float: ...
    def sample(
        self,
        n: int | None,
        *,
        fraction: float | None,
        with_replacement: bool,
        seed: int | None,
    ) -> Self: ...
    def scatter(self, indices: int | Sequence[int], values: Any) -> Self: ...
    def shift(self, n: int) -> Self: ...
    def skew(self) -> float | None: ...
    def sort(self, *, descending: bool, nulls_last: bool) -> Self: ...
    def std(self, *, ddof: int) -> float: ...
    def sum(self) -> float: ...
    def tail(self, n: int) -> Self: ...
    def to_arrow(self) -> pa.Array[Any]: ...
    def to_dummies(
        self, *, separator: str, drop_first: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def to_frame(self) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def to_list(self) -> list[Any]: ...
    def to_pandas(self) -> pd.Series[Any]: ...
    def to_polars(self) -> pl.Series: ...
    def unique(self, *, maintain_order: bool = False) -> Self: ...
    def value_counts(
        self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]: ...
    def var(self, *, ddof: int) -> float: ...
    def zip_with(self, mask: Any, other: Any) -> Self: ...

    # NOTE: *Technically* `polars`
    @unstable
    def hist_from_bins(
        self, bins: list[float], *, include_breakpoint: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]:
        """`Series.hist(bins=..., bin_count=None)`."""
        ...

    @unstable
    def hist_from_bin_count(
        self, bin_count: int, *, include_breakpoint: bool
    ) -> CompliantDataFrame[Self, Any, Any, Any]:
        """`Series.hist(bins=None, bin_count=...)`."""
        ...


class EagerSeries(CompliantSeries[NativeSeriesT], Protocol[NativeSeriesT]):
    _native_series: Any
    _implementation: Implementation
    _version: Version
    _broadcast: bool

    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    @classmethod
    def _align_full_broadcast(cls, *series: Self) -> Sequence[Self]:
        """Ensure all of `series` have the same length (and index if `pandas`).

        Scalars get broadcasted to the full length of the longest Series.

        This is useful when you need to construct a full Series anyway, such as:

            DataFrame.select(...)

        It should not be used in binary operations, such as:

            nw.col("a") - nw.col("a").mean()

        because then it's more efficient to extract the right-hand-side's single element as a scalar.
        """
        ...

    def _from_scalar(self, value: Any) -> Self:
        return self.from_iterable([value], name=self.name, context=self)

    def _with_native(
        self, series: NativeSeriesT, *, preserve_broadcast: bool = False
    ) -> Self:
        """Return a new `CompliantSeries`, wrapping the native `series`.

        In cases when operations are known to not affect whether a result should
        be broadcast, we can pass `preserve_broadcast=True`.
        Set this with care - it should only be set for unary expressions which don't
        change length or order, such as `.alias` or `.fill_null`. If in doubt, don't
        set it, you probably don't need it.
        """
        ...

    def __narwhals_namespace__(
        self,
    ) -> EagerNamespace[Any, Self, Any, Any, NativeSeriesT]: ...

    def _gather(self, rows: SizedMultiIndexSelector[NativeSeriesT]) -> Self: ...
    def _gather_slice(self, rows: _SliceIndex | range) -> Self: ...
    def __getitem__(self, item: MultiIndexSelector[Self]) -> Self:
        if isinstance(item, (slice, range)):
            return self._gather_slice(item)
        if is_compliant_series(item):
            return self._gather(item.native)
        elif is_sized_multi_index_selector(item):  # noqa: RET505
            return self._gather(item)
        assert_never(item)

    @property
    def str(self) -> EagerSeriesStringNamespace[Self, NativeSeriesT]: ...
    @property
    def dt(self) -> EagerSeriesDateTimeNamespace[Self, NativeSeriesT]: ...
    @property
    def cat(self) -> EagerSeriesCatNamespace[Self, NativeSeriesT]: ...
    @property
    def list(self) -> EagerSeriesListNamespace[Self, NativeSeriesT]: ...
    @property
    def struct(self) -> EagerSeriesStructNamespace[Self, NativeSeriesT]: ...


class _SeriesNamespace(  # type: ignore[misc]
    _StoresCompliant[CompliantSeriesT_co],
    _StoresNative[NativeSeriesT_co],
    Protocol[CompliantSeriesT_co, NativeSeriesT_co],
):
    _compliant_series: CompliantSeriesT_co

    @property
    def compliant(self) -> CompliantSeriesT_co:
        return self._compliant_series

    @property
    def implementation(self) -> Implementation:
        return self.compliant._implementation

    @property
    def backend_version(self) -> tuple[int, ...]:
        return self.implementation._backend_version()

    @property
    def version(self) -> Version:
        return self.compliant._version

    @property
    def native(self) -> NativeSeriesT_co:
        return self._compliant_series.native  # type: ignore[no-any-return]

    def with_native(self, series: Any, /) -> CompliantSeriesT_co:
        return self.compliant._with_native(series)


class EagerSeriesNamespace(
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    Generic[EagerSeriesT_co, NativeSeriesT_co],
):
    _compliant_series: EagerSeriesT_co

    def __init__(self, series: EagerSeriesT_co, /) -> None:
        self._compliant_series = series


class EagerSeriesCatNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    CatNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesDateTimeNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    DateTimeNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesListNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    ListNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesStringNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    StringNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesStructNamespace(  # type: ignore[misc]
    _SeriesNamespace[EagerSeriesT_co, NativeSeriesT_co],
    StructNamespace[EagerSeriesT_co],
    Protocol[EagerSeriesT_co, NativeSeriesT_co],
): ...


class EagerSeriesHist(Protocol[NativeSeriesT, _CountsT_co]):
    _series: EagerSeries[NativeSeriesT]
    _breakpoint: bool
    _data: HistData[NativeSeriesT, _CountsT_co]

    @property
    def native(self) -> NativeSeriesT:
        return self._series.native

    @classmethod
    def from_series(
        cls, series: EagerSeries[NativeSeriesT], *, include_breakpoint: bool
    ) -> Self:
        obj = cls.__new__(cls)
        obj._series = series
        obj._breakpoint = include_breakpoint
        return obj

    def to_frame(self) -> EagerDataFrameAny: ...
    def _linear_space(  # NOTE: Roughly `pl.linear_space`
        self,
        start: float,
        end: float,
        num_samples: int,
        *,
        closed: Literal["both", "none"] = "both",
    ) -> _1DArray: ...

    # NOTE: *Could* be handled at narwhals-level
    def is_empty_series(self) -> bool: ...

    # NOTE: **Should** be handled at narwhals-level
    def data_empty(self) -> HistData[NativeSeriesT, _CountsT_co]:
        return {"breakpoint": [], "count": []} if self._breakpoint else {"count": []}

    # NOTE: *Could* be handled at narwhals-level, **iff** we add `nw.repeat`, `nw.linear_space`
    # See https://github.com/narwhals-dev/narwhals/pull/2839#discussion_r2215630696
    def series_empty(
        self, arg: int | list[float], /
    ) -> HistData[NativeSeriesT, _CountsT_co]: ...

    def with_bins(self, bins: list[float], /) -> Self:
        if len(bins) <= 1:
            self._data = self.data_empty()
        elif self.is_empty_series():
            self._data = self.series_empty(bins)
        else:
            self._data = self._calculate_hist(bins)
        return self

    def with_bin_count(self, bin_count: int, /) -> Self:
        if bin_count == 0:
            self._data = self.data_empty()
        elif self.is_empty_series():
            self._data = self.series_empty(bin_count)
        else:
            self._data = self._calculate_hist(self._calculate_bins(bin_count))
        return self

    def _calculate_breakpoint(self, arg: int | list[float], /) -> list[float] | _1DArray:
        bins = self._linear_space(0, 1, arg + 1) if isinstance(arg, int) else arg
        return bins[1:]

    def _calculate_bins(self, bin_count: int) -> _1DArray: ...
    def _calculate_hist(
        self, bins: list[float] | _1DArray
    ) -> HistData[NativeSeriesT, _CountsT_co]: ...
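`with_bins` and `with_bin_count` above share the same three-way branch: degenerate bins, an empty input series, or a real histogram computation. Here is a toy, list-backed sketch of the `with_bins` path; the binning rule used (right-closed intervals, first bin including its left edge) is an assumption for illustration and not taken from the backend implementations:

def hist_from_bins(values: list[float], bins: list[float]) -> dict[str, list[float]]:
    # Mirrors EagerSeriesHist.with_bins: degenerate bins, empty data, or real work.
    if len(bins) <= 1:
        return {"breakpoint": [], "count": []}
    if not values:
        return {"breakpoint": list(bins[1:]), "count": [0] * (len(bins) - 1)}
    counts = [0] * (len(bins) - 1)
    for v in values:
        for i in range(len(bins) - 1):
            # right-closed intervals; the first bin also picks up its left edge
            if bins[i] < v <= bins[i + 1] or (i == 0 and v == bins[0]):
                counts[i] += 1
                break
    return {"breakpoint": list(bins[1:]), "count": counts}


assert hist_from_bins([], [0.0]) == {"breakpoint": [], "count": []}
assert hist_from_bins([0.5, 1.5, 1.6], [0.0, 1.0, 2.0])["count"] == [1, 2]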
206
lib/python3.11/site-packages/narwhals/_compliant/typing.py
Normal file
@@ -0,0 +1,206 @@
|
||||
from __future__ import annotations

from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Callable, Literal, TypedDict, TypeVar

if TYPE_CHECKING:
    from typing_extensions import TypeAlias

    from narwhals._compliant.dataframe import (
        CompliantDataFrame,
        CompliantFrame,
        CompliantLazyFrame,
        EagerDataFrame,
    )
    from narwhals._compliant.expr import (
        CompliantExpr,
        DepthTrackingExpr,
        EagerExpr,
        ImplExpr,
        LazyExpr,
        NativeExpr,
    )
    from narwhals._compliant.namespace import CompliantNamespace, EagerNamespace
    from narwhals._compliant.series import CompliantSeries, EagerSeries
    from narwhals._compliant.window import WindowInputs
    from narwhals.typing import (
        FillNullStrategy,
        IntoLazyFrame,
        ModeKeepStrategy,
        NativeDataFrame,
        NativeFrame,
        NativeSeries,
        RankMethod,
        RollingInterpolationMethod,
    )


class ScalarKwargs(TypedDict, total=False):
    """Non-expressifiable args which we may need to reuse in `agg` or `over`."""

    adjust: bool
    alpha: float | None
    center: int
    com: float | None
    ddof: int
    descending: bool
    half_life: float | None
    ignore_nulls: bool
    interpolation: RollingInterpolationMethod
    keep: ModeKeepStrategy
    limit: int | None
    method: RankMethod
    min_samples: int
    n: int
    quantile: float
    reverse: bool
    span: float | None
    strategy: FillNullStrategy | None
    window_size: int

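Since `ScalarKwargs` is declared with `total=False`, each expression stores only the keys relevant to it. For example, a rolling-mean expression might stash its window parameters like this (a sketch, not taken from the file):

kwargs: ScalarKwargs = {"window_size": 3, "min_samples": 1, "center": False}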
__all__ = [
    "AliasName",
    "AliasNames",
    "CompliantDataFrameT",
    "CompliantFrameT",
    "CompliantLazyFrameT",
    "CompliantSeriesT",
    "EvalNames",
    "EvalSeries",
    "NarwhalsAggregation",
    "NativeFrameT_co",
    "NativeSeriesT_co",
]

CompliantExprAny: TypeAlias = "CompliantExpr[Any, Any]"
CompliantSeriesAny: TypeAlias = "CompliantSeries[Any]"
CompliantSeriesOrNativeExprAny: TypeAlias = "CompliantSeriesAny | NativeExpr"
CompliantDataFrameAny: TypeAlias = "CompliantDataFrame[Any, Any, Any, Any]"
CompliantLazyFrameAny: TypeAlias = "CompliantLazyFrame[Any, Any, Any]"
CompliantFrameAny: TypeAlias = "CompliantFrame[Any, Any, Any]"
CompliantNamespaceAny: TypeAlias = "CompliantNamespace[Any, Any]"

ImplExprAny: TypeAlias = "ImplExpr[Any, Any]"

DepthTrackingExprAny: TypeAlias = "DepthTrackingExpr[Any, Any]"

EagerDataFrameAny: TypeAlias = "EagerDataFrame[Any, Any, Any, Any]"
EagerSeriesAny: TypeAlias = "EagerSeries[Any]"
EagerExprAny: TypeAlias = "EagerExpr[Any, Any]"
EagerNamespaceAny: TypeAlias = "EagerNamespace[EagerDataFrameAny, EagerSeriesAny, EagerExprAny, NativeFrame, NativeSeries]"

LazyExprAny: TypeAlias = "LazyExpr[Any, Any]"

NativeExprT = TypeVar("NativeExprT", bound="NativeExpr")
NativeExprT_co = TypeVar("NativeExprT_co", bound="NativeExpr", covariant=True)
NativeSeriesT = TypeVar("NativeSeriesT", bound="NativeSeries")
NativeSeriesT_co = TypeVar("NativeSeriesT_co", bound="NativeSeries", covariant=True)
NativeSeriesT_contra = TypeVar(
    "NativeSeriesT_contra", bound="NativeSeries", contravariant=True
)
NativeDataFrameT = TypeVar("NativeDataFrameT", bound="NativeDataFrame")
NativeLazyFrameT = TypeVar("NativeLazyFrameT", bound="IntoLazyFrame")
NativeFrameT = TypeVar("NativeFrameT", bound="NativeFrame")
NativeFrameT_co = TypeVar("NativeFrameT_co", bound="NativeFrame", covariant=True)
NativeFrameT_contra = TypeVar(
    "NativeFrameT_contra", bound="NativeFrame", contravariant=True
)

CompliantExprT = TypeVar("CompliantExprT", bound=CompliantExprAny)
CompliantExprT_co = TypeVar("CompliantExprT_co", bound=CompliantExprAny, covariant=True)
CompliantExprT_contra = TypeVar(
    "CompliantExprT_contra", bound=CompliantExprAny, contravariant=True
)
CompliantSeriesT = TypeVar("CompliantSeriesT", bound=CompliantSeriesAny)
CompliantSeriesT_co = TypeVar(
    "CompliantSeriesT_co", bound=CompliantSeriesAny, covariant=True
)
CompliantSeriesOrNativeExprT = TypeVar(
    "CompliantSeriesOrNativeExprT", bound=CompliantSeriesOrNativeExprAny
)
CompliantSeriesOrNativeExprT_co = TypeVar(
    "CompliantSeriesOrNativeExprT_co",
    bound=CompliantSeriesOrNativeExprAny,
    covariant=True,
)
CompliantFrameT = TypeVar("CompliantFrameT", bound=CompliantFrameAny)
CompliantFrameT_co = TypeVar(
    "CompliantFrameT_co", bound=CompliantFrameAny, covariant=True
)
CompliantDataFrameT = TypeVar("CompliantDataFrameT", bound=CompliantDataFrameAny)
CompliantDataFrameT_co = TypeVar(
    "CompliantDataFrameT_co", bound=CompliantDataFrameAny, covariant=True
)
CompliantLazyFrameT = TypeVar("CompliantLazyFrameT", bound=CompliantLazyFrameAny)
CompliantLazyFrameT_co = TypeVar(
    "CompliantLazyFrameT_co", bound=CompliantLazyFrameAny, covariant=True
)
CompliantNamespaceT = TypeVar("CompliantNamespaceT", bound=CompliantNamespaceAny)
CompliantNamespaceT_co = TypeVar(
    "CompliantNamespaceT_co", bound=CompliantNamespaceAny, covariant=True
)

ImplExprT_contra = TypeVar("ImplExprT_contra", bound=ImplExprAny, contravariant=True)

DepthTrackingExprT = TypeVar("DepthTrackingExprT", bound=DepthTrackingExprAny)
DepthTrackingExprT_contra = TypeVar(
    "DepthTrackingExprT_contra", bound=DepthTrackingExprAny, contravariant=True
)

EagerExprT = TypeVar("EagerExprT", bound=EagerExprAny)
EagerExprT_contra = TypeVar("EagerExprT_contra", bound=EagerExprAny, contravariant=True)
EagerSeriesT = TypeVar("EagerSeriesT", bound=EagerSeriesAny)
EagerSeriesT_co = TypeVar("EagerSeriesT_co", bound=EagerSeriesAny, covariant=True)

# NOTE: `pyright` gives false (8) positives if this uses `EagerDataFrameAny`?
EagerDataFrameT = TypeVar("EagerDataFrameT", bound="EagerDataFrame[Any, Any, Any, Any]")

LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
LazyExprT_contra = TypeVar("LazyExprT_contra", bound=LazyExprAny, contravariant=True)

AliasNames: TypeAlias = Callable[[Sequence[str]], Sequence[str]]
"""A function aliasing a *sequence* of column names."""

AliasName: TypeAlias = Callable[[str], str]
"""A function aliasing a *single* column name."""

EvalSeries: TypeAlias = Callable[
    [CompliantFrameT], Sequence[CompliantSeriesOrNativeExprT]
]
"""A function from a `Frame` to a sequence of `Series`*.

See [underwater unicorn magic](https://narwhals-dev.github.io/narwhals/how_it_works/).
"""

EvalNames: TypeAlias = Callable[[CompliantFrameT], Sequence[str]]
"""A function from a `Frame` to a sequence of column names *before* any aliasing takes place."""

WindowFunction: TypeAlias = (
    "Callable[[CompliantFrameT, WindowInputs[NativeExprT]], Sequence[NativeExprT]]"
)
"""A function evaluated with `over(partition_by=..., order_by=...)`."""

NarwhalsAggregation: TypeAlias = Literal[
    "sum",
    "mean",
    "median",
    "max",
    "min",
    "mode",
    "std",
    "var",
    "len",
    "n_unique",
    "count",
    "quantile",
    "all",
    "any",
]
"""`Expr` methods we aim to support in `DepthTrackingGroupBy`.

Be sure to update me if you're working on one of these:
- https://github.com/narwhals-dev/narwhals/issues/981
- https://github.com/narwhals-dev/narwhals/issues/2385
- https://github.com/narwhals-dev/narwhals/issues/2484
- https://github.com/narwhals-dev/narwhals/issues/2526
- https://github.com/narwhals-dev/narwhals/issues/2660
"""
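Backends typically map these narwhals-level names onto their native group-by aggregations. A hedged sketch of such a remapping table (illustrative entries only; the real tables live in each backend's group-by module, e.g. `PandasLikeGroupBy._REMAP_AGGS`):

NARWHALS_TO_NATIVE: dict[NarwhalsAggregation, str] = {
    "len": "size",        # narwhals `len` -> pandas-style `size`
    "n_unique": "nunique",
    "sum": "sum",
}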
130
lib/python3.11/site-packages/narwhals/_compliant/when_then.py
Normal file
@ -0,0 +1,130 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Protocol, TypeVar, cast

from narwhals._compliant.expr import CompliantExpr
from narwhals._compliant.typing import (
    CompliantExprAny,
    CompliantFrameAny,
    CompliantSeriesOrNativeExprAny,
    EagerDataFrameT,
    EagerExprT,
    EagerSeriesT,
    LazyExprAny,
    NativeSeriesT,
)

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Self, TypeAlias

    from narwhals._compliant.typing import EvalSeries
    from narwhals._utils import Implementation, Version, _LimitedContext
    from narwhals.typing import NonNestedLiteral


__all__ = ["CompliantThen", "CompliantWhen", "EagerWhen"]

ExprT = TypeVar("ExprT", bound=CompliantExprAny)
LazyExprT = TypeVar("LazyExprT", bound=LazyExprAny)
SeriesT = TypeVar("SeriesT", bound=CompliantSeriesOrNativeExprAny)
FrameT = TypeVar("FrameT", bound=CompliantFrameAny)

Scalar: TypeAlias = Any
"""A native literal value."""

IntoExpr: TypeAlias = "SeriesT | ExprT | NonNestedLiteral | Scalar"
"""Anything that is convertible into a `CompliantExpr`."""


class CompliantWhen(Protocol[FrameT, SeriesT, ExprT]):
    _condition: ExprT
    _then_value: IntoExpr[SeriesT, ExprT]
    _otherwise_value: IntoExpr[SeriesT, ExprT] | None
    _implementation: Implementation
    _version: Version

    @property
    def _then(self) -> type[CompliantThen[FrameT, SeriesT, ExprT, Self]]: ...
    def __call__(self, compliant_frame: FrameT, /) -> Sequence[SeriesT]: ...
    def then(
        self, value: IntoExpr[SeriesT, ExprT], /
    ) -> CompliantThen[FrameT, SeriesT, ExprT, Self]:
        return self._then.from_when(self, value)

    @classmethod
    def from_expr(cls, condition: ExprT, /, *, context: _LimitedContext) -> Self:
        obj = cls.__new__(cls)
        obj._condition = condition
        obj._then_value = None
        obj._otherwise_value = None
        obj._implementation = context._implementation
        obj._version = context._version
        return obj


WhenT_contra = TypeVar(
    "WhenT_contra", bound=CompliantWhen[Any, Any, Any], contravariant=True
)


class CompliantThen(
    CompliantExpr[FrameT, SeriesT], Protocol[FrameT, SeriesT, ExprT, WhenT_contra]
):
    _call: EvalSeries[FrameT, SeriesT]
    _when_value: CompliantWhen[FrameT, SeriesT, ExprT]
    _implementation: Implementation
    _version: Version

    @classmethod
    def from_when(cls, when: WhenT_contra, then: IntoExpr[SeriesT, ExprT], /) -> Self:
        when._then_value = then
        obj = cls.__new__(cls)
        obj._call = when
        obj._when_value = when
        obj._evaluate_output_names = getattr(
            then, "_evaluate_output_names", lambda _df: ["literal"]
        )
        obj._alias_output_names = getattr(then, "_alias_output_names", None)
        obj._implementation = when._implementation
        obj._version = when._version
        return obj

    def otherwise(self, otherwise: IntoExpr[SeriesT, ExprT], /) -> ExprT:
        self._when_value._otherwise_value = otherwise
        return cast("ExprT", self)


class EagerWhen(
    CompliantWhen[EagerDataFrameT, EagerSeriesT, EagerExprT],
    Protocol[EagerDataFrameT, EagerSeriesT, EagerExprT, NativeSeriesT],
):
    def _if_then_else(
        self,
        when: NativeSeriesT,
        then: NativeSeriesT,
        otherwise: NativeSeriesT | NonNestedLiteral | Scalar,
        /,
    ) -> NativeSeriesT: ...

    def __call__(self, df: EagerDataFrameT, /) -> Sequence[EagerSeriesT]:
        is_expr = self._condition._is_expr
        when: EagerSeriesT = self._condition(df)[0]
        then: EagerSeriesT
        align = when._align_full_broadcast

        if is_expr(self._then_value):
            then = self._then_value(df)[0]
        else:
            then = when.alias("literal")._from_scalar(self._then_value)
            then._broadcast = True

        if is_expr(self._otherwise_value):
            otherwise = self._otherwise_value(df)[0]
            when, then, otherwise = align(when, then, otherwise)
            result = self._if_then_else(when.native, then.native, otherwise.native)
        else:
            when, then = align(when, then)
            result = self._if_then_else(when.native, then.native, self._otherwise_value)
        return [then._with_native(result)]
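`CompliantWhen.then` returns a `CompliantThen`, which is itself an expression whose `otherwise` mutates the shared `_otherwise_value`; this machinery backs the public chained API. A short example of that user-level chain:

import narwhals as nw
import pandas as pd

df = nw.from_native(pd.DataFrame({"a": [1, 5, 10]}))
# Keep `a` where the condition holds, fall back to 0 elsewhere.
result = df.with_columns(
    nw.when(nw.col("a") > 4).then(nw.col("a")).otherwise(0).alias("b")
)
print(result.to_native())  # b == [0, 5, 10]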
20
lib/python3.11/site-packages/narwhals/_compliant/window.py
Normal file
@ -0,0 +1,20 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Generic

from narwhals._compliant.typing import NativeExprT_co

if TYPE_CHECKING:
    from collections.abc import Sequence

__all__ = ["WindowInputs"]


class WindowInputs(Generic[NativeExprT_co]):
    __slots__ = ("order_by", "partition_by")

    def __init__(
        self, partition_by: Sequence[str | NativeExprT_co], order_by: Sequence[str]
    ) -> None:
        self.partition_by = partition_by
        self.order_by = order_by
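A `WindowFunction` (see the alias in `typing.py` above) receives these inputs alongside the frame; constructing them is a one-liner (column names here are illustrative):

inputs = WindowInputs(partition_by=["group"], order_by=["ts"])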
30
lib/python3.11/site-packages/narwhals/_constants.py
Normal file
@ -0,0 +1,30 @@
from __future__ import annotations

import datetime as dt

# Temporal (from `polars._utils.constants`)
SECONDS_PER_DAY = 86_400
SECONDS_PER_MINUTE = 60
NS_PER_MINUTE = 60_000_000_000
"""Nanoseconds (`[ns]`) per minute."""
US_PER_MINUTE = 60_000_000
"""Microseconds (`[μs]`) per minute."""
MS_PER_MINUTE = 60_000
"""Milliseconds (`[ms]`) per minute."""
NS_PER_SECOND = 1_000_000_000
"""Nanoseconds (`[ns]`) per second (`[s]`)."""
US_PER_SECOND = 1_000_000
"""Microseconds (`[μs]`) per second (`[s]`)."""
MS_PER_SECOND = 1_000
"""Milliseconds (`[ms]`) per second (`[s]`)."""
NS_PER_MICROSECOND = 1_000
"""Nanoseconds (`[ns]`) per microsecond (`[μs]`)."""
NS_PER_MILLISECOND = 1_000_000
"""Nanoseconds (`[ns]`) per millisecond (`[ms]`).

From [polars](https://github.com/pola-rs/polars/blob/2c7a3e77f0faa37c86a3745db4ef7707ae50c72e/crates/polars-time/src/chunkedarray/duration.rs#L7).
"""
EPOCH_YEAR = 1970
"""See [Unix time](https://en.wikipedia.org/wiki/Unix_time)."""
EPOCH = dt.datetime(EPOCH_YEAR, 1, 1).replace(tzinfo=None)
"""See [Unix time](https://en.wikipedia.org/wiki/Unix_time)."""
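These constants keep time-unit conversions explicit rather than scattering magic numbers. For instance, converting 90 seconds to each sub-second unit (a worked example, not part of the file):

duration_s = 90
print(duration_s * MS_PER_SECOND)  # 90_000 milliseconds
print(duration_s * US_PER_SECOND)  # 90_000_000 microseconds
print(duration_s * NS_PER_SECOND)  # 90_000_000_000 nanoseconds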
502
lib/python3.11/site-packages/narwhals/_dask/dataframe.py
Normal file
@ -0,0 +1,502 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

import dask.dataframe as dd

from narwhals._dask.utils import add_row_index, evaluate_exprs
from narwhals._expression_parsing import ExprKind
from narwhals._pandas_like.utils import native_to_narwhals_dtype, select_columns_by_name
from narwhals._typing_compat import assert_never
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    _remap_full_join_keys,
    check_column_names_are_unique,
    check_columns_exist,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    zip_strict,
)
from narwhals.typing import CompliantLazyFrame

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._dask.expr import DaskExpr
    from narwhals._dask.group_by import DaskLazyGroupBy
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._typing import _EagerAllowedImpl
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.exceptions import ColumnNotFoundError
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

Incomplete: TypeAlias = "Any"
"""Using `_pandas_like` utils with `_dask`.

Typing this correctly will complicate the `_pandas_like`-side.
Very low priority until `dask` adds typing.
"""


class DaskLazyFrame(
    CompliantLazyFrame["DaskExpr", "dd.DataFrame", "LazyFrame[dd.DataFrame]"],
    ValidateBackendVersion,
):
    _implementation = Implementation.DASK

    def __init__(
        self,
        native_dataframe: dd.DataFrame,
        *,
        version: Version,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: dd.DataFrame = native_dataframe
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: dd.DataFrame | Any) -> TypeIs[dd.DataFrame]:
        return isinstance(obj, dd.DataFrame)

    @classmethod
    def from_native(cls, data: dd.DataFrame, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    def to_narwhals(self) -> LazyFrame[dd.DataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.DASK:
            return self._implementation.to_native_namespace()

        msg = f"Expected dask, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def __narwhals_namespace__(self) -> DaskNamespace:
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: Any) -> Self:
        return self.__class__(df, version=self._version)

    def _check_columns_exist(self, subset: Sequence[str]) -> ColumnNotFoundError | None:
        return check_columns_exist(subset, available=self.columns)

    def _iter_columns(self) -> Iterator[dx.Series]:
        for _col, ser in self.native.items():  # noqa: PERF102
            yield ser

    def with_columns(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        return self._with_native(self.native.assign(**dict(new_series)))

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        result = self.native.compute(**kwargs)

        if backend is None or backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                result,
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
            )

        if backend is Implementation.PYARROW:
            import pyarrow as pa  # ignore-banned-import

            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                pa.Table.from_pandas(result),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

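`collect` materialises the Dask graph with `compute()` and then wraps the resulting pandas frame in the requested eager backend. From the user's side this is driven by the public `collect` method; a sketch, assuming dask and pyarrow are installed:

import dask.dataframe as dd
import narwhals as nw
import pandas as pd

lf = nw.from_native(dd.from_pandas(pd.DataFrame({"a": [1, 2]}), npartitions=1))
# Materialise into a pyarrow-backed eager DataFrame instead of the pandas default.
df = lf.collect(backend="pyarrow")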
    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns.tolist()
            )
        return self._cached_columns

    def filter(self, predicate: DaskExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = predicate(self)[0]
        return self._with_native(self.native.loc[mask])

    def simple_select(self, *column_names: str) -> Self:
        df: Incomplete = self.native
        native = select_columns_by_name(df, list(column_names), self._implementation)
        return self._with_native(native)

    def aggregate(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df = dd.concat([val.rename(name) for name, val in new_series], axis=1)
        return self._with_native(df)

    def select(self, *exprs: DaskExpr) -> Self:
        new_series = evaluate_exprs(self, *exprs)
        df: Incomplete = self.native
        df = select_columns_by_name(
            df.assign(**dict(new_series)),
            [s[0] for s in new_series],
            self._implementation,
        )
        return self._with_native(df)

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        if subset is None:
            return self._with_native(self.native.dropna())
        plx = self.__narwhals_namespace__()
        mask = ~plx.any_horizontal(plx.col(*subset).is_null(), ignore_nulls=True)
        return self.filter(mask)

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            native_dtypes = self.native.dtypes
            self._cached_schema = {
                col: native_to_narwhals_dtype(
                    native_dtypes[col], self._version, self._implementation
                )
                for col in self.native.columns
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        to_drop = parse_columns_to_drop(self, columns, strict=strict)

        return self._with_native(self.native.drop(columns=to_drop))

    def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
        # Implementation is based on the following StackOverflow reply:
        # https://stackoverflow.com/questions/60831518/in-dask-how-does-one-add-a-range-of-integersauto-increment-to-a-new-column/60852409#60852409
        if order_by is None:
            return self._with_native(add_row_index(self.native, name))
        plx = self.__narwhals_namespace__()
        columns = self.columns
        const_expr = plx.lit(value=1, dtype=None).alias(name).broadcast(ExprKind.LITERAL)
        row_index_expr = (
            plx.col(name).cum_sum(reverse=False).over(partition_by=[], order_by=order_by)
            - 1
        )
        return self.with_columns(const_expr).select(row_index_expr, plx.col(*columns))

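The `order_by` branch above builds the row index from a cumulative sum of a constant column rather than a global index, which Dask cannot provide cheaply. A minimal sketch of the same idea in plain dask/pandas terms (column names hypothetical):

import dask.dataframe as dd
import pandas as pd

ddf = dd.from_pandas(pd.DataFrame({"ts": [3, 1, 2]}), npartitions=1)
# Sort by the ordering key, then cumulative-sum a column of ones and
# subtract 1 to obtain a 0-based row index.
ordered = ddf.sort_values("ts")
ordered["index"] = ordered.assign(_one=1)["_one"].cumsum() - 1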
    def rename(self, mapping: Mapping[str, str]) -> Self:
        return self._with_native(self.native.rename(columns=mapping))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n=n, compute=False, npartitions=-1))

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        if keep == "none":
            subset = subset or self.columns
            token = generate_temporary_column_name(n_bytes=8, columns=subset)
            ser = self.native.groupby(subset).size().rename(token)
            ser = ser[ser == 1]
            unique = ser.reset_index().drop(columns=token)
            result = self.native.merge(unique, on=subset, how="inner")
        else:
            mapped_keep = {"any": "first"}.get(keep, keep)
            result = self.native.drop_duplicates(subset=subset, keep=mapped_keep)
        return self._with_native(result)

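`keep="none"` drops every row whose key appears more than once, rather than keeping one representative, which is why it needs the group-size trick above instead of `drop_duplicates`. A short illustration via the public narwhals API:

import narwhals as nw
import pandas as pd

df = nw.from_native(pd.DataFrame({"a": [1, 1, 2]}))
# Rows where `a` is duplicated are removed entirely: only a == 2 survives.
print(df.unique(subset=["a"], keep="none").to_native())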
    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            ascending: bool | list[bool] = not descending
        else:
            ascending = [not d for d in descending]
        position = "last" if nulls_last else "first"
        return self._with_native(
            self.native.sort_values(list(by), ascending=ascending, na_position=position)
        )

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        df = self.native
        schema = self.schema
        by = list(by)
        if isinstance(reverse, bool) and all(schema[x].is_numeric() for x in by):
            if reverse:
                return self._with_native(df.nsmallest(k, by))
            return self._with_native(df.nlargest(k, by))
        if isinstance(reverse, bool):
            reverse = [reverse] * len(by)
        return self._with_native(
            df.sort_values(by, ascending=list(reverse)).head(
                n=k, compute=False, npartitions=-1
            )
        )

    def _join_inner(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        return self.native.merge(
            other.native,
            left_on=left_on,
            right_on=right_on,
            how="inner",
            suffixes=("", suffix),
        )

    def _join_left(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        result_native = self.native.merge(
            other.native,
            how="left",
            left_on=left_on,
            right_on=right_on,
            suffixes=("", suffix),
        )
        extra = [
            right_key if right_key not in self.columns else f"{right_key}{suffix}"
            for left_key, right_key in zip_strict(left_on, right_on)
            if right_key != left_key
        ]
        return result_native.drop(columns=extra)

    def _join_full(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str], suffix: str
    ) -> dd.DataFrame:
        # dask does not retain keys post-join
        # we must append the suffix to each key before-hand

        right_on_mapper = _remap_full_join_keys(left_on, right_on, suffix)
        other_native = other.native.rename(columns=right_on_mapper)
        check_column_names_are_unique(other_native.columns)
        right_suffixed = list(right_on_mapper.values())
        return self.native.merge(
            other_native,
            left_on=left_on,
            right_on=right_suffixed,
            how="outer",
            suffixes=("", suffix),
        )

    def _join_cross(self, other: Self, *, suffix: str) -> dd.DataFrame:
        key_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        return (
            self.native.assign(**{key_token: 0})
            .merge(
                other.native.assign(**{key_token: 0}),
                how="inner",
                left_on=key_token,
                right_on=key_token,
                suffixes=("", suffix),
            )
            .drop(columns=key_token)
        )

    def _join_semi(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        return self.native.merge(
            other_native, how="inner", left_on=left_on, right_on=left_on
        )

    def _join_anti(
        self, other: Self, *, left_on: Sequence[str], right_on: Sequence[str]
    ) -> dd.DataFrame:
        indicator_token = generate_temporary_column_name(
            n_bytes=8, columns=(*self.columns, *other.columns)
        )
        other_native = self._join_filter_rename(
            other=other,
            columns_to_select=list(right_on),
            columns_mapping=dict(zip(right_on, left_on)),
        )
        df = self.native.merge(
            other_native,
            how="left",
            indicator=indicator_token,  # pyright: ignore[reportArgumentType]
            left_on=left_on,
            right_on=left_on,
        )
        return df[df[indicator_token] == "left_only"].drop(columns=[indicator_token])

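The anti-join above relies on `merge(..., indicator=...)`: rows marked "left_only" have no match on the right side. The same idea in plain pandas, with toy data and hypothetical column names:

import pandas as pd

left = pd.DataFrame({"key": [1, 2, 3]})
right = pd.DataFrame({"key": [2]})
# `indicator` adds a column recording where each row came from;
# keeping "left_only" rows yields the anti-join (keys 1 and 3).
merged = left.merge(right, on="key", how="left", indicator="source")
anti = merged[merged["source"] == "left_only"].drop(columns="source")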
    def _join_filter_rename(
        self, other: Self, columns_to_select: list[str], columns_mapping: dict[str, str]
    ) -> dd.DataFrame:
        """Helper function to avoid creating extra columns and row duplication.

        Used in `"anti"` and `"semi"` joins.

        Notice that a native object is returned.
        """
        other_native: Incomplete = other.native
        # rename to avoid creating extra columns in join
        return (
            select_columns_by_name(other_native, columns_to_select, self._implementation)
            .rename(columns=columns_mapping)
            .drop_duplicates()
        )

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        if how == "cross":
            result = self._join_cross(other=other, suffix=suffix)

        elif left_on is None or right_on is None:  # pragma: no cover
            raise ValueError(left_on, right_on)

        elif how == "inner":
            result = self._join_inner(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "anti":
            result = self._join_anti(other=other, left_on=left_on, right_on=right_on)
        elif how == "semi":
            result = self._join_semi(other=other, left_on=left_on, right_on=right_on)
        elif how == "left":
            result = self._join_left(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        elif how == "full":
            result = self._join_full(
                other=other, left_on=left_on, right_on=right_on, suffix=suffix
            )
        else:
            assert_never(how)
        return self._with_native(result)

    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        plx = self.__native_namespace__()
        return self._with_native(
            plx.merge_asof(
                self.native,
                other.native,
                left_on=left_on,
                right_on=right_on,
                left_by=by_left,
                right_by=by_right,
                direction=strategy,
                suffixes=("", suffix),
            )
        )

    def group_by(
        self, keys: Sequence[str] | Sequence[DaskExpr], *, drop_null_keys: bool
    ) -> DaskLazyGroupBy:
        from narwhals._dask.group_by import DaskLazyGroupBy

        return DaskLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def tail(self, n: int) -> Self:  # pragma: no cover
        native_frame = self.native
        n_partitions = native_frame.npartitions

        if n_partitions == 1:
            return self._with_native(self.native.tail(n=n, compute=False))
        msg = (
            "`LazyFrame.tail` is not supported for Dask backend with multiple partitions."
        )
        raise NotImplementedError(msg)

    def gather_every(self, n: int, offset: int) -> Self:
        row_index_token = generate_temporary_column_name(n_bytes=8, columns=self.columns)
        plx = self.__narwhals_namespace__()
        return (
            self.with_row_index(row_index_token, order_by=None)
            .filter(
                (plx.col(row_index_token) >= offset)
                & ((plx.col(row_index_token) - offset) % n == 0)
            )
            .drop([row_index_token], strict=False)
        )

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        return self._with_native(
            self.native.melt(
                id_vars=index,
                value_vars=on,
                var_name=variable_name,
                value_name=value_name,
            )
        )

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        self.native.to_parquet(file)

    explode = not_implemented()
701
lib/python3.11/site-packages/narwhals/_dask/expr.py
Normal file
@ -0,0 +1,701 @@
from __future__ import annotations

import warnings
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

import pandas as pd

from narwhals._compliant import DepthTrackingExpr, LazyExpr
from narwhals._dask.expr_dt import DaskExprDateTimeNamespace
from narwhals._dask.expr_str import DaskExprStringNamespace
from narwhals._dask.utils import (
    add_row_index,
    maybe_evaluate_expr,
    narwhals_to_native_dtype,
)
from narwhals._expression_parsing import ExprKind, evaluate_output_names_and_aliases
from narwhals._pandas_like.utils import get_dtype_backend, native_to_narwhals_dtype
from narwhals._utils import (
    Implementation,
    generate_temporary_column_name,
    not_implemented,
)
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Sequence

    import dask.dataframe.dask_expr as dx
    from typing_extensions import Self

    from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.namespace import DaskNamespace
    from narwhals._expression_parsing import ExprKind, ExprMetadata
    from narwhals._utils import Version, _LimitedContext
    from narwhals.typing import (
        FillNullStrategy,
        IntoDType,
        ModeKeepStrategy,
        NonNestedLiteral,
        NumericLiteral,
        RollingInterpolationMethod,
        TemporalLiteral,
    )


class DaskExpr(
    LazyExpr["DaskLazyFrame", "dx.Series"],  # pyright: ignore[reportInvalidTypeArguments]
    DepthTrackingExpr["DaskLazyFrame", "dx.Series"],  # pyright: ignore[reportInvalidTypeArguments]
):
    _implementation: Implementation = Implementation.DASK

    def __init__(
        self,
        call: EvalSeries[DaskLazyFrame, dx.Series],  # pyright: ignore[reportInvalidTypeForm]
        *,
        depth: int,
        function_name: str,
        evaluate_output_names: EvalNames[DaskLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        scalar_kwargs: ScalarKwargs | None = None,
    ) -> None:
        self._call = call
        self._depth = depth
        self._function_name = function_name
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._scalar_kwargs = scalar_kwargs or {}
        self._metadata: ExprMetadata | None = None

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        return self._call(df)

    def __narwhals_namespace__(self) -> DaskNamespace:  # pragma: no cover
        from narwhals._dask.namespace import DaskNamespace

        return DaskNamespace(version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # result.loc[0][0] is a workaround for dask~<=2024.10.0/dask_expr~<=1.1.16
            # that raised a KeyError for result[0] during collection.
            return [result.loc[0][0] for result in self(df)]

        return self.__class__(
            func,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[DaskLazyFrame],
        /,
        *,
        context: _LimitedContext,
        function_name: str = "",
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            try:
                return [
                    df._native_frame[column_name]
                    for column_name in evaluate_column_names(df)
                ]
            except KeyError as e:
                if error := df._check_columns_exist(evaluate_column_names(df)):
                    raise error from e
                raise

        return cls(
            func,
            depth=0,
            function_name=function_name,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            return [df.native.iloc[:, i] for i in column_indices]

        return cls(
            func,
            depth=0,
            function_name="nth",
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def _with_callable(
        self,
        # First argument to `call` should be `dx.Series`
        call: Callable[..., dx.Series],
        /,
        expr_name: str = "",
        scalar_kwargs: ScalarKwargs | None = None,
        **expressifiable_args: Self | Any,
    ) -> Self:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            native_results: list[dx.Series] = []
            native_series_list = self._call(df)
            other_native_series = {
                key: maybe_evaluate_expr(df, value)
                for key, value in expressifiable_args.items()
            }
            for native_series in native_series_list:
                result_native = call(native_series, **other_native_series)
                native_results.append(result_native)
            return native_results

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=f"{self._function_name}->{expr_name}",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            scalar_kwargs=scalar_kwargs,
        )

    def _with_alias_output_names(self, func: AliasNames | None, /) -> Self:
        current_alias_output_names = self._alias_output_names
        alias_output_names = (
            None
            if func is None
            else func
            if current_alias_output_names is None
            else lambda output_names: func(current_alias_output_names(output_names))
        )
        return type(self)(
            call=self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=alias_output_names,
            version=self._version,
            scalar_kwargs=self._scalar_kwargs,
        )

    def _with_binary(
        self,
        call: Callable[[dx.Series, Any], dx.Series],
        name: str,
        other: Any,
        *,
        reverse: bool = False,
    ) -> Self:
        result = self._with_callable(
            lambda expr, other: call(expr, other), name, other=other
        )
        if reverse:
            result = result.alias("literal")
        return result

    def _binary_op(self, op_name: str, other: Any) -> Self:
        return self._with_binary(
            lambda expr, other: getattr(expr, op_name)(other), op_name, other
        )

    def _reverse_binary_op(
        self, op_name: str, operator_func: Callable[..., dx.Series], other: Any
    ) -> Self:
        return self._with_binary(
            lambda expr, other: operator_func(other, expr), op_name, other, reverse=True
        )

    def __add__(self, other: Any) -> Self:
        return self._binary_op("__add__", other)

    def __sub__(self, other: Any) -> Self:
        return self._binary_op("__sub__", other)

    def __mul__(self, other: Any) -> Self:
        return self._binary_op("__mul__", other)

    def __truediv__(self, other: Any) -> Self:
        return self._binary_op("__truediv__", other)

    def __floordiv__(self, other: Any) -> Self:
        return self._binary_op("__floordiv__", other)

    def __pow__(self, other: Any) -> Self:
        return self._binary_op("__pow__", other)

    def __mod__(self, other: Any) -> Self:
        return self._binary_op("__mod__", other)

    def __eq__(self, other: object) -> Self:  # type: ignore[override]
        return self._binary_op("__eq__", other)

    def __ne__(self, other: object) -> Self:  # type: ignore[override]
        return self._binary_op("__ne__", other)

    def __ge__(self, other: Any) -> Self:
        return self._binary_op("__ge__", other)

    def __gt__(self, other: Any) -> Self:
        return self._binary_op("__gt__", other)

    def __le__(self, other: Any) -> Self:
        return self._binary_op("__le__", other)

    def __lt__(self, other: Any) -> Self:
        return self._binary_op("__lt__", other)

    def __and__(self, other: Any) -> Self:
        return self._binary_op("__and__", other)

    def __or__(self, other: Any) -> Self:
        return self._binary_op("__or__", other)

    def __rsub__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rsub__", lambda a, b: a - b, other)

    def __rtruediv__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rtruediv__", lambda a, b: a / b, other)

    def __rfloordiv__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rfloordiv__", lambda a, b: a // b, other)

    def __rpow__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rpow__", lambda a, b: a**b, other)

    def __rmod__(self, other: Any) -> Self:
        return self._reverse_binary_op("__rmod__", lambda a, b: a % b, other)

    def __invert__(self) -> Self:
        return self._with_callable(lambda expr: expr.__invert__(), "__invert__")

    def mean(self) -> Self:
        return self._with_callable(lambda expr: expr.mean().to_series(), "mean")

    def median(self) -> Self:
        from narwhals.exceptions import InvalidOperationError

        def func(s: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(s.dtype, self._version, Implementation.DASK)
            if not dtype.is_numeric():
                msg = "`median` operation not supported for non-numeric input type."
                raise InvalidOperationError(msg)
            return s.median_approximate().to_series()

        return self._with_callable(func, "median")

    def min(self) -> Self:
        return self._with_callable(lambda expr: expr.min().to_series(), "min")

    def max(self) -> Self:
        return self._with_callable(lambda expr: expr.max().to_series(), "max")

    def std(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.std(ddof=ddof).to_series(),
            "std",
            scalar_kwargs={"ddof": ddof},
        )

    def var(self, ddof: int) -> Self:
        return self._with_callable(
            lambda expr: expr.var(ddof=ddof).to_series(),
            "var",
            scalar_kwargs={"ddof": ddof},
        )

    def skew(self) -> Self:
        return self._with_callable(lambda expr: expr.skew().to_series(), "skew")

    def kurtosis(self) -> Self:
        return self._with_callable(lambda expr: expr.kurtosis().to_series(), "kurtosis")

    def shift(self, n: int) -> Self:
        return self._with_callable(lambda expr: expr.shift(n), "shift")

    def cum_sum(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            # https://github.com/dask/dask/issues/11802
            msg = "`cum_sum(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumsum(), "cum_sum")

    def cum_count(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_count(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(
            lambda expr: (~expr.isna()).astype(int).cumsum(), "cum_count"
        )

    def cum_min(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_min(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummin(), "cum_min")

    def cum_max(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_max(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cummax(), "cum_max")

    def cum_prod(self, *, reverse: bool) -> Self:
        if reverse:  # pragma: no cover
            msg = "`cum_prod(reverse=True)` is not supported with Dask backend"
            raise NotImplementedError(msg)

        return self._with_callable(lambda expr: expr.cumprod(), "cum_prod")

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).sum(),
            "rolling_sum",
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        return self._with_callable(
            lambda expr: expr.rolling(
                window=window_size, min_periods=min_samples, center=center
            ).mean(),
            "rolling_mean",
        )

    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).var(),
                "rolling_var",
            )
        msg = "Dask backend only supports `ddof=1` for `rolling_var`"
        raise NotImplementedError(msg)

    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        if ddof == 1:
            return self._with_callable(
                lambda expr: expr.rolling(
                    window=window_size, min_periods=min_samples, center=center
                ).std(),
                "rolling_std",
            )
        msg = "Dask backend only supports `ddof=1` for `rolling_std`"
        raise NotImplementedError(msg)

    def sum(self) -> Self:
        return self._with_callable(lambda expr: expr.sum().to_series(), "sum")

    def count(self) -> Self:
        return self._with_callable(lambda expr: expr.count().to_series(), "count")

    def round(self, decimals: int) -> Self:
        return self._with_callable(lambda expr: expr.round(decimals), "round")

    def unique(self) -> Self:
        return self._with_callable(lambda expr: expr.unique(), "unique")

    def drop_nulls(self) -> Self:
        return self._with_callable(lambda expr: expr.dropna(), "drop_nulls")

    def abs(self) -> Self:
        return self._with_callable(lambda expr: expr.abs(), "abs")

    def all(self) -> Self:
        return self._with_callable(
            lambda expr: expr.all(
                axis=None, skipna=True, split_every=False, out=None
            ).to_series(),
            "all",
        )

    def any(self) -> Self:
        return self._with_callable(
            lambda expr: expr.any(axis=0, skipna=True, split_every=False).to_series(),
            "any",
        )

    def fill_nan(self, value: float | None) -> Self:
        value_nullable = pd.NA if value is None else value
        value_numpy = float("nan") if value is None else value

        def func(expr: dx.Series) -> dx.Series:
            # If/when pandas exposes an API which distinguishes NaN vs null, use that.
            mask = cast("dx.Series", expr != expr)  # noqa: PLR0124
            mask = mask.fillna(False)
            fill = (
                value_nullable
                if get_dtype_backend(expr.dtype, self._implementation)
                else value_numpy
            )
            return expr.mask(mask, fill)  # pyright: ignore[reportArgumentType]

        return self._with_callable(func, "fill_nan")

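The `expr != expr` trick in `fill_nan` is the standard NaN test: NaN is the only value not equal to itself, while missing values propagate through the comparison and are then filled with `False`. A minimal pandas illustration (for nullable dtypes the fill value differs, as the code above handles):

import numpy as np
import pandas as pd

s = pd.Series([1.0, np.nan, 3.0])
mask = s != s  # True only where the value is NaN
print(s.mask(mask, 0.0))  # -> [1.0, 0.0, 3.0]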
    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            if value is not None:
                res_ser = expr.fillna(value)
            else:
                res_ser = (
                    expr.ffill(limit=limit)
                    if strategy == "forward"
                    else expr.bfill(limit=limit)
                )
            return res_ser

        return self._with_callable(func, "fill_null")

    def clip(
        self,
        lower_bound: Self | NumericLiteral | TemporalLiteral | None,
        upper_bound: Self | NumericLiteral | TemporalLiteral | None,
    ) -> Self:
        return self._with_callable(
            lambda expr, lower_bound, upper_bound: expr.clip(
                lower=lower_bound, upper=upper_bound
            ),
            "clip",
            lower_bound=lower_bound,
            upper_bound=upper_bound,
        )

    def diff(self) -> Self:
        return self._with_callable(lambda expr: expr.diff(), "diff")

    def n_unique(self) -> Self:
        return self._with_callable(
            lambda expr: expr.nunique(dropna=False).to_series(), "n_unique"
        )

    def is_null(self) -> Self:
        return self._with_callable(lambda expr: expr.isna(), "is_null")

    def is_nan(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                expr.dtype, self._version, self._implementation
            )
            if dtype.is_numeric():
                return expr != expr  # pyright: ignore[reportReturnType]  # noqa: PLR0124
            msg = f"`.is_nan` only supported for numeric dtypes and not {dtype}, did you mean `.is_null`?"
            raise InvalidOperationError(msg)

        return self._with_callable(func, "is_nan")

    def len(self) -> Self:
        return self._with_callable(lambda expr: expr.size.to_series(), "len")

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        if interpolation == "linear":

            def func(expr: dx.Series, quantile: float) -> dx.Series:
                if expr.npartitions > 1:
                    msg = "`Expr.quantile` is not supported for Dask backend with multiple partitions."
                    raise NotImplementedError(msg)
                return expr.quantile(
                    q=quantile, method="dask"
                ).to_series()  # pragma: no cover

            return self._with_callable(func, "quantile", quantile=quantile)
        msg = "`higher`, `lower`, `midpoint`, `nearest` - interpolation methods are not supported by Dask. Please use `linear` instead."
        raise NotImplementedError(msg)

    def is_first_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            first_distinct_index = frame.groupby(_name).agg({col_token: "min"})[col_token]
            return frame[col_token].isin(first_distinct_index)

        return self._with_callable(func, "is_first_distinct")

    def is_last_distinct(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            col_token = generate_temporary_column_name(n_bytes=8, columns=[_name])
            frame = add_row_index(expr.to_frame(), col_token)
            last_distinct_index = frame.groupby(_name).agg({col_token: "max"})[col_token]
            return frame[col_token].isin(last_distinct_index)

        return self._with_callable(func, "is_last_distinct")

    def is_unique(self) -> Self:
        def func(expr: dx.Series) -> dx.Series:
            _name = expr.name
            return (
                expr.to_frame()
                .groupby(_name, dropna=False)
                .transform("size", meta=(_name, int))
                == 1
            )

        return self._with_callable(func, "is_unique")

    def is_in(self, other: Any) -> Self:
        return self._with_callable(lambda expr: expr.isin(other), "is_in")

    def null_count(self) -> Self:
        return self._with_callable(
            lambda expr: expr.isna().sum().to_series(), "null_count"
        )

    def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
        # pandas is a required dependency of dask so it's safe to import this
        from narwhals._pandas_like.group_by import PandasLikeGroupBy

        if not partition_by:
            assert order_by  # noqa: S101

            # This is something like `nw.col('a').cum_sum().over(order_by=key)`
            # which we can always easily support, as it doesn't require grouping.
            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                return self(df.sort(*order_by, descending=False, nulls_last=False))
        elif not self._is_elementary():  # pragma: no cover
            msg = (
                "Only elementary expressions are supported for `.over` in dask.\n\n"
                "Please see: "
                "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
            )
            raise NotImplementedError(msg)
        elif order_by:
            # Wrong results https://github.com/dask/dask/issues/11806.
            msg = "`over` with `order_by` is not yet supported in Dask."
            raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            try:
                dask_function_name = PandasLikeGroupBy._REMAP_AGGS[function_name]
            except KeyError:
                # window functions are unsupported: https://github.com/dask/dask/issues/11806
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(PandasLikeGroupBy._REMAP_AGGS)}\n"
                )
                raise NotImplementedError(msg) from None

            def func(df: DaskLazyFrame) -> Sequence[dx.Series]:
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])

                with warnings.catch_warnings():
                    # https://github.com/dask/dask/issues/11804
                    warnings.filterwarnings(
                        "ignore",
                        message=".*`meta` is not specified",
                        category=UserWarning,
                    )
                    grouped = df.native.groupby(partition_by)
                    if dask_function_name == "size":
                        if len(output_names) != 1:  # pragma: no cover
                            msg = "Safety check failed, please report a bug."
                            raise AssertionError(msg)
                        res_native = grouped.transform(
                            dask_function_name, **self._scalar_kwargs
                        ).to_frame(output_names[0])
                    else:
                        res_native = grouped[list(output_names)].transform(
                            dask_function_name, **self._scalar_kwargs
                        )
                result_frame = df._with_native(
                    res_native.rename(columns=dict(zip(output_names, aliases)))
                ).native
                return [result_frame[name] for name in aliases]

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

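With a non-empty `partition_by` and a supported aggregation, the branch above lowers to a grouped `transform` that broadcasts the aggregate back onto each row. A sketch of the user-level call this serves (hypothetical column names):

import narwhals as nw
import pandas as pd

df = nw.from_native(pd.DataFrame({"g": ["a", "a", "b"], "x": [1, 2, 3]}))
# Per-group sum broadcast to each row, like SQL's SUM(x) OVER (PARTITION BY g).
print(df.with_columns(nw.col("x").sum().over("g").alias("x_sum")).to_native())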
def cast(self, dtype: IntoDType) -> Self:
|
||||
def func(expr: dx.Series) -> dx.Series:
|
||||
native_dtype = narwhals_to_native_dtype(dtype, self._version)
|
||||
return expr.astype(native_dtype)
|
||||
|
||||
return self._with_callable(func, "cast")
|
||||
|
||||
def is_finite(self) -> Self:
|
||||
import dask.array as da
|
||||
|
||||
return self._with_callable(da.isfinite, "is_finite")
|
||||
|
||||
def log(self, base: float) -> Self:
|
||||
import dask.array as da
|
||||
|
||||
def _log(expr: dx.Series) -> dx.Series:
|
||||
return da.log(expr) / da.log(base)
|
||||
|
||||
return self._with_callable(_log, "log")
|
||||
|
||||
def exp(self) -> Self:
|
||||
import dask.array as da
|
||||
|
||||
return self._with_callable(da.exp, "exp")
|
||||
|
||||
def sqrt(self) -> Self:
|
||||
import dask.array as da
|
||||
|
||||
return self._with_callable(da.sqrt, "sqrt")
|
||||
|
||||
def mode(self, *, keep: ModeKeepStrategy) -> Self:
|
||||
def func(expr: dx.Series) -> dx.Series:
|
||||
_name = expr.name
|
||||
result = expr.to_frame().mode()[_name]
|
||||
return result.head(1) if keep == "any" else result
|
||||
|
||||
return self._with_callable(func, "mode", scalar_kwargs={"keep": keep})
|
||||
|
||||
@property
|
||||
def str(self) -> DaskExprStringNamespace:
|
||||
return DaskExprStringNamespace(self)
|
||||
|
||||
@property
|
||||
def dt(self) -> DaskExprDateTimeNamespace:
|
||||
return DaskExprDateTimeNamespace(self)
|
||||
|
||||
arg_max: not_implemented = not_implemented()
|
||||
arg_min: not_implemented = not_implemented()
|
||||
arg_true: not_implemented = not_implemented()
|
||||
ewm_mean: not_implemented = not_implemented()
|
||||
gather_every: not_implemented = not_implemented()
|
||||
head: not_implemented = not_implemented()
|
||||
map_batches: not_implemented = not_implemented()
|
||||
sample: not_implemented = not_implemented()
|
||||
rank: not_implemented = not_implemented()
|
||||
replace_strict: not_implemented = not_implemented()
|
||||
sort: not_implemented = not_implemented()
|
||||
tail: not_implemented = not_implemented()
|
||||
|
||||
# namespaces
|
||||
list: not_implemented = not_implemented() # type: ignore[assignment]
|
||||
cat: not_implemented = not_implemented() # type: ignore[assignment]
|
||||
struct: not_implemented = not_implemented() # type: ignore[assignment]
|
||||
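As a hedged illustration of the transform-based `over` branch above (not part of the vendored file; data and names are hypothetical, assuming narwhals and dask are installed):

import dask.dataframe as dd
import pandas as pd

import narwhals as nw

# Hypothetical frame; "g" partitions, "x" is aggregated per group.
pdf = pd.DataFrame({"g": ["a", "a", "b"], "x": [1, 2, 3]})
ldf = nw.from_native(dd.from_pandas(pdf, npartitions=1))
# `sum` is in `PandasLikeGroupBy._REMAP_AGGS`, so this should lower to
# `groupby("g").transform("sum")` under the warning filter shown above.
result = ldf.with_columns(x_sum=nw.col("x").sum().over("g"))
print(result.to_native().compute())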
175
lib/python3.11/site-packages/narwhals/_dask/expr_dt.py
Normal file
@ -0,0 +1,175 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import MS_PER_SECOND, NS_PER_SECOND, US_PER_SECOND
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
    ALIAS_DICT,
    calculate_timestamp_date,
    calculate_timestamp_datetime,
    native_to_narwhals_dtype,
)
from narwhals._utils import Implementation

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.expr import DaskExpr
    from narwhals.typing import TimeUnit


class DaskExprDateTimeNamespace(
    LazyExprNamespace["DaskExpr"], DateTimeNamespace["DaskExpr"]
):
    def date(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.date, "date")

    def year(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.year, "year")

    def month(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.month, "month")

    def day(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.day, "day")

    def hour(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.hour, "hour")

    def minute(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.minute, "minute")

    def second(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.dt.second, "second")

    def millisecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond // 1000, "millisecond"
        )

    def microsecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond, "microsecond"
        )

    def nanosecond(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.microsecond * 1000 + expr.dt.nanosecond, "nanosecond"
        )

    def ordinal_day(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.dayofyear, "ordinal_day"
        )

    def weekday(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.weekday + 1,  # Dask is 0-6
            "weekday",
        )

    def to_string(self, format: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, format: expr.dt.strftime(format.replace("%.f", ".%f")),
            "strftime",
            format=format,
        )

    def replace_time_zone(self, time_zone: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, time_zone: expr.dt.tz_localize(None).dt.tz_localize(time_zone)
            if time_zone is not None
            else expr.dt.tz_localize(None),
            "tz_localize",
            time_zone=time_zone,
        )

    def convert_time_zone(self, time_zone: str) -> DaskExpr:
        def func(s: dx.Series, time_zone: str) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self.compliant._version, Implementation.DASK
            )
            if dtype.time_zone is None:  # type: ignore[attr-defined]
                return s.dt.tz_localize("UTC").dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]
            return s.dt.tz_convert(time_zone)  # pyright: ignore[reportAttributeAccessIssue]

        return self.compliant._with_callable(func, "tz_convert", time_zone=time_zone)

    # ignoring coverage due to https://github.com/narwhals-dev/narwhals/issues/2808.
    def timestamp(self, time_unit: TimeUnit) -> DaskExpr:  # pragma: no cover
        def func(s: dx.Series, time_unit: TimeUnit) -> dx.Series:
            dtype = native_to_narwhals_dtype(
                s.dtype, self.compliant._version, Implementation.DASK
            )
            is_pyarrow_dtype = "pyarrow" in str(dtype)
            mask_na = s.isna()
            dtypes = self.compliant._version.dtypes
            if dtype == dtypes.Date:
                # Date is only supported in pandas dtypes if pyarrow-backed
                s_cast = s.astype("Int32[pyarrow]")
                result = calculate_timestamp_date(s_cast, time_unit)
            elif isinstance(dtype, dtypes.Datetime):
                original_time_unit = dtype.time_unit
                s_cast = (
                    s.astype("Int64[pyarrow]") if is_pyarrow_dtype else s.astype("int64")
                )
                result = calculate_timestamp_datetime(
                    s_cast, original_time_unit, time_unit
                )
            else:
                msg = "Input should be either of Date or Datetime type"
                raise TypeError(msg)
            return result.where(~mask_na)  # pyright: ignore[reportReturnType]

        return self.compliant._with_callable(func, "datetime", time_unit=time_unit)

    def total_minutes(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() // 60, "total_minutes"
        )

    def total_seconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() // 1, "total_seconds"
        )

    def total_milliseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * MS_PER_SECOND // 1,
            "total_milliseconds",
        )

    def total_microseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * US_PER_SECOND // 1,
            "total_microseconds",
        )

    def total_nanoseconds(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.dt.total_seconds() * NS_PER_SECOND // 1, "total_nanoseconds"
        )

    def truncate(self, every: str) -> DaskExpr:
        interval = Interval.parse(every)
        unit = interval.unit
        if unit in {"mo", "q", "y"}:
            msg = f"Truncating to {unit} is not yet supported for dask."
            raise NotImplementedError(msg)
        freq = f"{interval.multiple}{ALIAS_DICT.get(unit, unit)}"
        return self.compliant._with_callable(lambda expr: expr.dt.floor(freq), "truncate")

    def offset_by(self, by: str) -> DaskExpr:
        def func(s: dx.Series, by: str) -> dx.Series:
            interval = Interval.parse_no_constraints(by)
            unit = interval.unit
            if unit in {"y", "q", "mo", "d", "ns"}:
                msg = f"Offsetting by {unit} is not yet supported for dask."
                raise NotImplementedError(msg)
            offset = interval.to_timedelta()
            return s.add(offset)

        return self.compliant._with_callable(func, "offset_by", by=by)
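A short, hedged usage sketch of the datetime namespace above (illustrative data; assumes narwhals and dask are installed, and that `dt.truncate` is available in the installed narwhals version):

import dask.dataframe as dd
import pandas as pd

import narwhals as nw

pdf = pd.DataFrame({"ts": pd.to_datetime(["2024-01-01 10:30", "2024-01-07 23:59"])})
ldf = nw.from_native(dd.from_pandas(pdf, npartitions=1))
out = ldf.with_columns(
    wd=nw.col("ts").dt.weekday(),        # 1-7, via the `+ 1` shift over Dask's 0-6
    day=nw.col("ts").dt.truncate("1d"),  # lowered to `dt.floor` using ALIAS_DICT
)
print(out.to_native().compute())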
121
lib/python3.11/site-packages/narwhals/_dask/expr_str.py
Normal file
@ -0,0 +1,121 @@
from __future__ import annotations

from typing import TYPE_CHECKING

import dask.dataframe as dd

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StringNamespace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.expr import DaskExpr


class DaskExprStringNamespace(LazyExprNamespace["DaskExpr"], StringNamespace["DaskExpr"]):
    def len_chars(self) -> DaskExpr:
        return self.compliant._with_callable(lambda expr: expr.str.len(), "len")

    def replace(self, pattern: str, value: str, *, literal: bool, n: int) -> DaskExpr:
        def _replace(
            expr: dx.Series, pattern: str, value: str, *, literal: bool, n: int
        ) -> dx.Series:
            try:
                return expr.str.replace(  # pyright: ignore[reportAttributeAccessIssue]
                    pattern, value, regex=not literal, n=n
                )
            except TypeError as e:
                if not isinstance(value, str):
                    msg = "dask backed `Expr.str.replace` only supports str replacement values"
                    raise TypeError(msg) from e
                raise

        return self.compliant._with_callable(
            _replace, "replace", pattern=pattern, value=value, literal=literal, n=n
        )

    def replace_all(self, pattern: str, value: str, *, literal: bool) -> DaskExpr:
        def _replace_all(
            expr: dx.Series, pattern: str, value: str, *, literal: bool
        ) -> dx.Series:
            try:
                return expr.str.replace(  # pyright: ignore[reportAttributeAccessIssue]
                    pattern, value, regex=not literal, n=-1
                )
            except TypeError as e:
                if not isinstance(value, str):
                    msg = "dask backed `Expr.str.replace_all` only supports str replacement values."
                    raise TypeError(msg) from e
                raise

        return self.compliant._with_callable(
            _replace_all, "replace", pattern=pattern, value=value, literal=literal
        )

    def strip_chars(self, characters: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, characters: expr.str.strip(characters),
            "strip",
            characters=characters,
        )

    def starts_with(self, prefix: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, prefix: expr.str.startswith(prefix), "starts_with", prefix=prefix
        )

    def ends_with(self, suffix: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, suffix: expr.str.endswith(suffix), "ends_with", suffix=suffix
        )

    def contains(self, pattern: str, *, literal: bool) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, pattern, literal: expr.str.contains(
                pat=pattern, regex=not literal
            ),
            "contains",
            pattern=pattern,
            literal=literal,
        )

    def slice(self, offset: int, length: int | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, offset, length: expr.str.slice(
                start=offset, stop=offset + length if length else None
            ),
            "slice",
            offset=offset,
            length=length,
        )

    def split(self, by: str) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, by: expr.str.split(pat=by), "split", by=by
        )

    def to_datetime(self, format: str | None) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, format: dd.to_datetime(expr, format=format),
            "to_datetime",
            format=format,
        )

    def to_uppercase(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.str.upper(), "to_uppercase"
        )

    def to_lowercase(self) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr: expr.str.lower(), "to_lowercase"
        )

    def zfill(self, width: int) -> DaskExpr:
        return self.compliant._with_callable(
            lambda expr, width: expr.str.zfill(width), "zfill", width=width
        )

    to_date = not_implemented()
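A hedged sketch of the string namespace above (data illustrative; assumes narwhals and dask are installed):

import dask.dataframe as dd
import pandas as pd

import narwhals as nw

pdf = pd.DataFrame({"s": ["a.b", "c.d", None]})
ldf = nw.from_native(dd.from_pandas(pdf, npartitions=1))
out = ldf.with_columns(
    dashed=nw.col("s").str.replace_all(".", "-", literal=True),  # the regex=False path
    upper=nw.col("s").str.to_uppercase(),
)
print(out.to_native().compute())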
147
lib/python3.11/site-packages/narwhals/_dask/group_by.py
Normal file
@ -0,0 +1,147 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING, Any, Callable, ClassVar

import dask.dataframe as dd

from narwhals._compliant import DepthTrackingGroupBy
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import zip_strict

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    import pandas as pd
    from dask.dataframe.api import GroupBy as _DaskGroupBy
    from pandas.core.groupby import SeriesGroupBy as _PandasSeriesGroupBy
    from typing_extensions import TypeAlias

    from narwhals._compliant.typing import NarwhalsAggregation
    from narwhals._dask.dataframe import DaskLazyFrame
    from narwhals._dask.expr import DaskExpr

    PandasSeriesGroupBy: TypeAlias = _PandasSeriesGroupBy[Any, Any]
    _AggFn: TypeAlias = Callable[..., Any]

else:
    try:
        import dask.dataframe.dask_expr as dx
    except ModuleNotFoundError:  # pragma: no cover
        import dask_expr as dx
    _DaskGroupBy = dx._groupby.GroupBy

Aggregation: TypeAlias = "str | _AggFn"
"""The name of an aggregation function, or the function itself."""


def n_unique() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.nunique(dropna=False)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.sum()

    return dd.Aggregation(name="nunique", chunk=chunk, agg=agg)


def _all() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.all(skipna=True)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.all(skipna=True)

    return dd.Aggregation(name="all", chunk=chunk, agg=agg)


def _any() -> dd.Aggregation:
    def chunk(s: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s.any(skipna=True)

    def agg(s0: PandasSeriesGroupBy) -> pd.Series[Any]:
        return s0.any(skipna=True)

    return dd.Aggregation(name="any", chunk=chunk, agg=agg)


def var(ddof: int) -> _AggFn:
    return partial(_DaskGroupBy.var, ddof=ddof)


def std(ddof: int) -> _AggFn:
    return partial(_DaskGroupBy.std, ddof=ddof)


class DaskLazyGroupBy(DepthTrackingGroupBy["DaskLazyFrame", "DaskExpr", Aggregation]):
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, Aggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "median",
        "max": "max",
        "min": "min",
        "std": std,
        "var": var,
        "len": "size",
        "n_unique": n_unique,
        "count": "count",
        "quantile": "quantile",
        "all": _all,
        "any": _any,
    }

    def __init__(
        self,
        df: DaskLazyFrame,
        keys: Sequence[DaskExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
            df, keys=keys
        )
        self._grouped = self.compliant.native.groupby(
            self._keys, dropna=drop_null_keys, observed=True
        )

    def agg(self, *exprs: DaskExpr) -> DaskLazyFrame:
        from narwhals._dask.dataframe import DaskLazyFrame

        if not exprs:
            # No aggregation provided
            return (
                self.compliant.simple_select(*self._keys)
                .unique(self._keys, keep="any")
                .rename(dict(zip(self._keys, self._output_key_names)))
            )

        self._ensure_all_simple(exprs)
        # This should be the fastpath, but cuDF is too far behind to use it.
        # - https://github.com/rapidsai/cudf/issues/15118
        # - https://github.com/rapidsai/cudf/issues/15084
        simple_aggregations: dict[str, tuple[str, Aggregation]] = {}
        exclude = (*self._keys, *self._output_key_names)
        for expr in exprs:
            output_names, aliases = evaluate_output_names_and_aliases(
                expr, self.compliant, exclude
            )
            if expr._depth == 0:
                # e.g. `agg(nw.len())`
                column = self._keys[0]
                agg_fn = self._remap_expr_name(expr._function_name)
                simple_aggregations.update(dict.fromkeys(aliases, (column, agg_fn)))
                continue

            # e.g. `agg(nw.mean('a'))`
            agg_fn = self._remap_expr_name(self._leaf_name(expr))
            # deal with n_unique case in a "lazy" mode to not depend on dask globally
            agg_fn = agg_fn(**expr._scalar_kwargs) if callable(agg_fn) else agg_fn
            simple_aggregations.update(
                (alias, (output_name, agg_fn))
                for alias, output_name in zip_strict(aliases, output_names)
            )
        return DaskLazyFrame(
            self._grouped.agg(**simple_aggregations).reset_index(),
            version=self.compliant._version,
        ).rename(dict(zip(self._keys, self._output_key_names)))
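A hedged sketch of the remapped aggregations above (frame and names hypothetical; assumes narwhals and dask are installed):

import dask.dataframe as dd
import pandas as pd

import narwhals as nw

pdf = pd.DataFrame({"g": ["a", "a", "b"], "x": [1.0, 2.0, 4.0]})
ldf = nw.from_native(dd.from_pandas(pdf, npartitions=1))
out = ldf.group_by("g").agg(
    nw.col("x").std(ddof=0).alias("x_std"),  # remapped via `std(ddof)` -> partial(_DaskGroupBy.std, ddof=0)
    nw.col("x").n_unique().alias("n"),       # remapped via the custom `dd.Aggregation` above
    nw.len().alias("size"),                  # depth-0 expression, aggregated as "size"
)
print(out.to_native().compute())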
338
lib/python3.11/site-packages/narwhals/_dask/namespace.py
Normal file
@ -0,0 +1,338 @@
from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, cast

import dask.dataframe as dd
import pandas as pd

from narwhals._compliant import (
    CompliantThen,
    CompliantWhen,
    DepthTrackingNamespace,
    LazyNamespace,
)
from narwhals._dask.dataframe import DaskLazyFrame
from narwhals._dask.expr import DaskExpr
from narwhals._dask.selectors import DaskSelectorNamespace
from narwhals._dask.utils import (
    align_series_full_broadcast,
    narwhals_to_native_dtype,
    validate_comparand,
)
from narwhals._expression_parsing import (
    ExprKind,
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._utils import Implementation, zip_strict

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Sequence

    import dask.dataframe.dask_expr as dx

    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._utils import Version
    from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral


class DaskNamespace(
    LazyNamespace[DaskLazyFrame, DaskExpr, dd.DataFrame],
    DepthTrackingNamespace[DaskLazyFrame, DaskExpr],
):
    _implementation: Implementation = Implementation.DASK

    @property
    def selectors(self) -> DaskSelectorNamespace:
        return DaskSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[DaskExpr]:
        return DaskExpr

    @property
    def _lazyframe(self) -> type[DaskLazyFrame]:
        return DaskLazyFrame

    def __init__(self, *, version: Version) -> None:
        self._version = version

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            if dtype is not None:
                native_dtype = narwhals_to_native_dtype(dtype, self._version)
                native_pd_series = pd.Series([value], dtype=native_dtype, name="literal")
            else:
                native_pd_series = pd.Series([value], name="literal")
            npartitions = df._native_frame.npartitions
            dask_series = dd.from_pandas(native_pd_series, npartitions=npartitions)
            return [dask_series[0].to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

    def len(self) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            # We don't allow dataframes with 0 columns, so `[0]` is safe.
            return [df._native_frame[df.columns[0]].size.to_series()]

        return self._expr(
            func,
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )

    def all_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs)
            # Note on `ignore_nulls`: Dask doesn't support storing arbitrary Python
            # objects in `object` dtype, so we don't need the same check we have for pandas-like.
            if ignore_nulls:
                # NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
                series = (s if s.dtype == "bool" else s.fillna(True) for s in series)
            return [reduce(operator.and_, align_series_full_broadcast(df, *series))]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def any_horizontal(self, *exprs: DaskExpr, ignore_nulls: bool) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series: Iterator[dx.Series] = chain.from_iterable(e(df) for e in exprs)
            if ignore_nulls:
                series = (s if s.dtype == "bool" else s.fillna(False) for s in series)
            return [reduce(operator.or_, align_series_full_broadcast(df, *series))]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def sum_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [dd.concat(series, axis=1).sum(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def concat(
        self, items: Iterable[DaskLazyFrame], *, how: ConcatMethod
    ) -> DaskLazyFrame:
        if not items:
            msg = "No items to concatenate"  # pragma: no cover
            raise AssertionError(msg)
        dfs = [i._native_frame for i in items]
        cols_0 = dfs[0].columns
        if how == "vertical":
            for i, df in enumerate(dfs[1:], start=1):
                cols_current = df.columns
                if not (
                    (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
                ):
                    msg = (
                        "unable to vstack, column names don't match:\n"
                        f"   - dataframe 0: {cols_0.to_list()}\n"
                        f"   - dataframe {i}: {cols_current.to_list()}\n"
                    )
                    raise TypeError(msg)
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="inner"), version=self._version
            )
        if how == "diagonal":
            return DaskLazyFrame(
                dd.concat(dfs, axis=0, join="outer"), version=self._version
            )

        raise NotImplementedError

    def mean_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            series = align_series_full_broadcast(df, *(s.fillna(0) for s in expr_results))
            non_na = align_series_full_broadcast(
                df, *(1 - s.isna() for s in expr_results)
            )
            num = reduce(lambda x, y: x + y, series)  # pyright: ignore[reportOperatorIssue]
            den = reduce(lambda x, y: x + y, non_na)  # pyright: ignore[reportOperatorIssue]
            return [cast("dx.Series", num / den)]  # pyright: ignore[reportOperatorIssue]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def min_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )

            return [dd.concat(series, axis=1).min(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def max_horizontal(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )

            return [dd.concat(series, axis=1).max(axis=1)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def when(self, predicate: DaskExpr) -> DaskWhen:
        return DaskWhen.from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: DaskExpr, separator: str, ignore_nulls: bool
    ) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            series = (
                s.astype(str) for s in align_series_full_broadcast(df, *expr_results)
            )
            null_mask = [s.isna() for s in align_series_full_broadcast(df, *expr_results)]

            if not ignore_nulls:
                null_mask_result = reduce(operator.or_, null_mask)
                result = reduce(lambda x, y: x + separator + y, series).where(
                    ~null_mask_result, None
                )
            else:
                init_value, *values = [
                    s.where(~nm, "") for s, nm in zip_strict(series, null_mask)
                ]

                separators = (
                    nm.map({True: "", False: separator}, meta=str)
                    for nm in null_mask[:-1]
                )
                result = reduce(
                    operator.add,
                    (s + v for s, v in zip_strict(separators, values)),
                    init_value,
                )

            return [result]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=getattr(
                exprs[0], "_evaluate_output_names", lambda _df: ["literal"]
            ),
            alias_output_names=getattr(exprs[0], "_alias_output_names", None),
            version=self._version,
        )

    def coalesce(self, *exprs: DaskExpr) -> DaskExpr:
        def func(df: DaskLazyFrame) -> list[dx.Series]:
            series = align_series_full_broadcast(
                df, *(s for _expr in exprs for s in _expr(df))
            )
            return [reduce(lambda x, y: x.fillna(y), series)]

        return self._expr(
            call=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="coalesce",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )


class DaskWhen(CompliantWhen[DaskLazyFrame, "dx.Series", DaskExpr]):  # pyright: ignore[reportInvalidTypeArguments]
    @property
    def _then(self) -> type[DaskThen]:
        return DaskThen

    def __call__(self, df: DaskLazyFrame) -> Sequence[dx.Series]:
        then_value = (
            self._then_value(df)[0]
            if isinstance(self._then_value, DaskExpr)
            else self._then_value
        )
        otherwise_value = (
            self._otherwise_value(df)[0]
            if isinstance(self._otherwise_value, DaskExpr)
            else self._otherwise_value
        )

        condition = self._condition(df)[0]
        # re-evaluate DataFrame if the condition aggregates to force
        # then/otherwise to be evaluated against the aggregated frame
        assert self._condition._metadata is not None  # noqa: S101
        if self._condition._metadata.is_scalar_like:
            new_df = df._with_native(condition.to_frame())
            condition = self._condition.broadcast(ExprKind.AGGREGATION)(df)[0]
            df = new_df

        if self._otherwise_value is None:
            (condition, then_series) = align_series_full_broadcast(
                df, condition, then_value
            )
            validate_comparand(condition, then_series)
            return [then_series.where(condition)]  # pyright: ignore[reportArgumentType]
        (condition, then_series, otherwise_series) = align_series_full_broadcast(
            df, condition, then_value, otherwise_value
        )
        validate_comparand(condition, then_series)
        validate_comparand(condition, otherwise_series)
        return [then_series.where(condition, otherwise_series)]  # pyright: ignore[reportArgumentType]


class DaskThen(CompliantThen[DaskLazyFrame, "dx.Series", DaskExpr, DaskWhen], DaskExpr):  # pyright: ignore[reportInvalidTypeArguments]
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "whenthen"
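A hedged sketch exercising `when`/`then` and `concat_str` above (illustrative data; assumes narwhals and dask are installed):

import dask.dataframe as dd
import pandas as pd

import narwhals as nw

pdf = pd.DataFrame({"a": [1, None, 3], "b": ["x", "y", None]})
ldf = nw.from_native(dd.from_pandas(pdf, npartitions=1))
out = ldf.with_columns(
    label=nw.when(nw.col("a") > 1).then(nw.lit("big")).otherwise(nw.lit("small")),
    joined=nw.concat_str(nw.col("a"), nw.col("b"), separator="-", ignore_nulls=True),
)
print(out.to_native().compute())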
34
lib/python3.11/site-packages/narwhals/_dask/selectors.py
Normal file
@ -0,0 +1,34 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._dask.expr import DaskExpr

if TYPE_CHECKING:
    import dask.dataframe.dask_expr as dx  # noqa: F401

    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._dask.dataframe import DaskLazyFrame  # noqa: F401


class DaskSelectorNamespace(LazySelectorNamespace["DaskLazyFrame", "dx.Series"]):  # pyright: ignore[reportInvalidTypeArguments]
    @property
    def _selector(self) -> type[DaskSelector]:
        return DaskSelector


class DaskSelector(CompliantSelector["DaskLazyFrame", "dx.Series"], DaskExpr):  # type: ignore[misc]
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "selector"

    def _to_expr(self) -> DaskExpr:
        return DaskExpr(
            self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
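A hedged sketch of selector routing through `DaskSelectorNamespace` above (assumed data):

import dask.dataframe as dd
import pandas as pd

import narwhals as nw
import narwhals.selectors as ncs

pdf = pd.DataFrame({"x": [1, 2], "y": ["a", "b"]})
ldf = nw.from_native(dd.from_pandas(pdf, npartitions=1))
print(ldf.select(ncs.numeric()).to_native().compute())  # keeps only "x"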
139
lib/python3.11/site-packages/narwhals/_dask/utils.py
Normal file
@ -0,0 +1,139 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from narwhals._pandas_like.utils import select_columns_by_name
from narwhals._utils import Implementation, Version, isinstance_or_issubclass
from narwhals.dependencies import get_pyarrow

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    import dask.dataframe as dd
    import dask.dataframe.dask_expr as dx

    from narwhals._dask.dataframe import DaskLazyFrame, Incomplete
    from narwhals._dask.expr import DaskExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType
else:
    try:
        import dask.dataframe.dask_expr as dx
    except ModuleNotFoundError:  # pragma: no cover
        import dask_expr as dx


def maybe_evaluate_expr(df: DaskLazyFrame, obj: DaskExpr | object) -> dx.Series | object:
    from narwhals._dask.expr import DaskExpr

    if isinstance(obj, DaskExpr):
        results = obj._call(df)
        assert len(results) == 1  # debug assertion  # noqa: S101
        return results[0]
    return obj


def evaluate_exprs(df: DaskLazyFrame, /, *exprs: DaskExpr) -> list[tuple[str, dx.Series]]:
    native_results: list[tuple[str, dx.Series]] = []
    for expr in exprs:
        native_series_list = expr(df)
        aliases = expr._evaluate_aliases(df)
        if len(aliases) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got aliases {aliases}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(aliases, native_series_list))
    return native_results


def align_series_full_broadcast(
    df: DaskLazyFrame, *series: dx.Series | object
) -> Sequence[dx.Series]:
    return [
        s if isinstance(s, dx.Series) else df._native_frame.assign(_tmp=s)["_tmp"]
        for s in series
    ]  # pyright: ignore[reportReturnType]


def add_row_index(frame: dd.DataFrame, name: str) -> dd.DataFrame:
    original_cols = frame.columns
    df: Incomplete = frame.assign(**{name: 1})
    return select_columns_by_name(
        df.assign(**{name: df[name].cumsum(method="blelloch") - 1}),
        [name, *original_cols],
        Implementation.DASK,
    )


def validate_comparand(lhs: dx.Series, rhs: dx.Series) -> None:
    if not dx.expr.are_co_aligned(lhs._expr, rhs._expr):  # pragma: no cover
        # are_co_aligned is a method which cheaply checks if two Dask expressions
        # have the same index, and therefore don't require index alignment.
        # If someone only operates on a Dask DataFrame via expressions, then this
        # should always be the case: expression outputs (by definition) all come from the
        # same input dataframe, and Dask Series does not have any operations which
        # change the index. Nonetheless, we perform this safety check anyway.

        # However, we still need to carefully vet which methods we support for Dask, to
        # avoid issues where `are_co_aligned` doesn't do what we want it to do:
        # https://github.com/dask/dask-expr/issues/1112.
        msg = "Objects are not co-aligned, so this operation is not supported for Dask backend"
        raise RuntimeError(msg)


dtypes = Version.MAIN.dtypes
dtypes_v1 = Version.V1.dtypes
NW_TO_DASK_DTYPES: Mapping[type[DType], str] = {
    dtypes.Float64: "float64",
    dtypes.Float32: "float32",
    dtypes.Boolean: "bool",
    dtypes.Categorical: "category",
    dtypes.Date: "date32[day][pyarrow]",
    dtypes.Int8: "int8",
    dtypes.Int16: "int16",
    dtypes.Int32: "int32",
    dtypes.Int64: "int64",
    dtypes.UInt8: "uint8",
    dtypes.UInt16: "uint16",
    dtypes.UInt32: "uint32",
    dtypes.UInt64: "uint64",
    dtypes.Datetime: "datetime64[us]",
    dtypes.Duration: "timedelta64[ns]",
    dtypes_v1.Datetime: "datetime64[us]",
    dtypes_v1.Duration: "timedelta64[ns]",
}
UNSUPPORTED_DTYPES = (
    dtypes.List,
    dtypes.Struct,
    dtypes.Array,
    dtypes.Time,
    dtypes.Binary,
)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> Any:
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if dask_type := NW_TO_DASK_DTYPES.get(base_type):
        return dask_type
    if isinstance_or_issubclass(dtype, dtypes.String):
        if Implementation.PANDAS._backend_version() >= (2, 0, 0):
            return "string[pyarrow]" if get_pyarrow() else "string[python]"
        return "object"  # pragma: no cover
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            import pandas as pd

            # NOTE: `pandas-stubs.core.dtypes.dtypes.CategoricalDtype.categories` is too narrow
            # Should be one of the `ListLike*` types
            # https://github.com/pandas-dev/pandas-stubs/blob/8434bde95460b996323cc8c0fea7b0a8bb00ea26/pandas-stubs/_typing.pyi#L497-L505
            return pd.CategoricalDtype(dtype.categories, ordered=True)  # type: ignore[arg-type]
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if issubclass(base_type, UNSUPPORTED_DTYPES):  # pragma: no cover
        msg = f"Converting to {base_type.__name__} dtype is not supported for Dask."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
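`add_row_index` above builds a row index by assigning a constant column and taking a parallel cumulative sum minus one. A standalone, hedged sketch of the same trick in plain dask (hypothetical data):

import dask.dataframe as dd
import pandas as pd

pdf = pd.DataFrame({"x": list("abcd")})
frame = dd.from_pandas(pdf, npartitions=2)
with_ones = frame.assign(idx=1)
# Cumulative sum of ones, shifted to start at 0; "blelloch" selects the
# parallel prefix-scan implementation, as in `add_row_index` above.
indexed = with_ones.assign(idx=with_ones["idx"].cumsum(method="blelloch") - 1)
print(indexed.compute())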
542
lib/python3.11/site-packages/narwhals/_duckdb/dataframe.py
Normal file
@ -0,0 +1,542 @@
from __future__ import annotations
|
||||
|
||||
from functools import reduce
|
||||
from operator import and_
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import duckdb
|
||||
from duckdb import StarExpression
|
||||
|
||||
from narwhals._duckdb.utils import (
|
||||
DeferredTimeZone,
|
||||
F,
|
||||
catch_duckdb_exception,
|
||||
col,
|
||||
evaluate_exprs,
|
||||
join_column_names,
|
||||
lit,
|
||||
native_to_narwhals_dtype,
|
||||
window_expression,
|
||||
)
|
||||
from narwhals._sql.dataframe import SQLLazyFrame
|
||||
from narwhals._utils import (
|
||||
Implementation,
|
||||
ValidateBackendVersion,
|
||||
Version,
|
||||
generate_temporary_column_name,
|
||||
not_implemented,
|
||||
parse_columns_to_drop,
|
||||
requires,
|
||||
zip_strict,
|
||||
)
|
||||
from narwhals.dependencies import get_duckdb
|
||||
from narwhals.exceptions import InvalidOperationError
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Iterable, Iterator, Mapping, Sequence
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
from duckdb import Expression
|
||||
from duckdb.typing import DuckDBPyType
|
||||
from typing_extensions import Self, TypeIs
|
||||
|
||||
from narwhals._compliant.typing import CompliantDataFrameAny
|
||||
from narwhals._duckdb.expr import DuckDBExpr
|
||||
from narwhals._duckdb.group_by import DuckDBGroupBy
|
||||
from narwhals._duckdb.namespace import DuckDBNamespace
|
||||
from narwhals._duckdb.series import DuckDBInterchangeSeries
|
||||
from narwhals._typing import _EagerAllowedImpl
|
||||
from narwhals._utils import _LimitedContext
|
||||
from narwhals.dataframe import LazyFrame
|
||||
from narwhals.dtypes import DType
|
||||
from narwhals.stable.v1 import DataFrame as DataFrameV1
|
||||
from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy
|
||||
|
||||
|
||||
class DuckDBLazyFrame(
|
||||
SQLLazyFrame[
|
||||
"DuckDBExpr",
|
||||
"duckdb.DuckDBPyRelation",
|
||||
"LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]",
|
||||
],
|
||||
ValidateBackendVersion,
|
||||
):
|
||||
_implementation = Implementation.DUCKDB
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: duckdb.DuckDBPyRelation,
|
||||
*,
|
||||
version: Version,
|
||||
validate_backend_version: bool = False,
|
||||
) -> None:
|
||||
self._native_frame: duckdb.DuckDBPyRelation = df
|
||||
self._version = version
|
||||
self._cached_native_schema: dict[str, DuckDBPyType] | None = None
|
||||
self._cached_columns: list[str] | None = None
|
||||
if validate_backend_version:
|
||||
self._validate_backend_version()
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: duckdb.DuckDBPyRelation | Any) -> TypeIs[duckdb.DuckDBPyRelation]:
|
||||
return isinstance(obj, duckdb.DuckDBPyRelation)
|
||||
|
||||
@classmethod
|
||||
def from_native(
|
||||
cls, data: duckdb.DuckDBPyRelation, /, *, context: _LimitedContext
|
||||
) -> Self:
|
||||
return cls(data, version=context._version)
|
||||
|
||||
def to_narwhals(
|
||||
self, *args: Any, **kwds: Any
|
||||
) -> LazyFrame[duckdb.DuckDBPyRelation] | DataFrameV1[duckdb.DuckDBPyRelation]:
|
||||
if self._version is Version.V1:
|
||||
from narwhals.stable.v1 import DataFrame as DataFrameV1
|
||||
|
||||
return DataFrameV1(self, level="interchange") # type: ignore[no-any-return]
|
||||
return self._version.lazyframe(self, level="lazy")
|
||||
|
||||
def __narwhals_dataframe__(self) -> Self: # pragma: no cover
|
||||
# Keep around for backcompat.
|
||||
if self._version is not Version.V1:
|
||||
msg = "__narwhals_dataframe__ is not implemented for DuckDBLazyFrame"
|
||||
raise AttributeError(msg)
|
||||
return self
|
||||
|
||||
def __narwhals_lazyframe__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __native_namespace__(self) -> ModuleType:
|
||||
return get_duckdb() # type: ignore[no-any-return]
|
||||
|
||||
def __narwhals_namespace__(self) -> DuckDBNamespace:
|
||||
from narwhals._duckdb.namespace import DuckDBNamespace
|
||||
|
||||
return DuckDBNamespace(version=self._version)
|
||||
|
||||
def get_column(self, name: str) -> DuckDBInterchangeSeries:
|
||||
from narwhals._duckdb.series import DuckDBInterchangeSeries
|
||||
|
||||
return DuckDBInterchangeSeries(self.native.select(name), version=self._version)
|
||||
|
||||
def _iter_columns(self) -> Iterator[Expression]:
|
||||
for name in self.columns:
|
||||
yield col(name)
|
||||
|
||||
def collect(
|
||||
self, backend: _EagerAllowedImpl | None, **kwargs: Any
|
||||
) -> CompliantDataFrameAny:
|
||||
if backend is None or backend is Implementation.PYARROW:
|
||||
from narwhals._arrow.dataframe import ArrowDataFrame
|
||||
|
||||
return ArrowDataFrame(
|
||||
self.native.arrow(),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
validate_column_names=True,
|
||||
)
|
||||
|
||||
if backend is Implementation.PANDAS:
|
||||
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
|
||||
|
||||
return PandasLikeDataFrame(
|
||||
self.native.df(),
|
||||
implementation=Implementation.PANDAS,
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
validate_column_names=True,
|
||||
)
|
||||
|
||||
if backend is Implementation.POLARS:
|
||||
from narwhals._polars.dataframe import PolarsDataFrame
|
||||
|
||||
return PolarsDataFrame(
|
||||
self.native.pl(), validate_backend_version=True, version=self._version
|
||||
)
|
||||
|
||||
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
|
||||
raise ValueError(msg) # pragma: no cover
|
||||
|
||||
def head(self, n: int) -> Self:
|
||||
return self._with_native(self.native.limit(n))
|
||||
|
||||
def simple_select(self, *column_names: str) -> Self:
|
||||
return self._with_native(self.native.select(*column_names))
|
||||
|
||||
def aggregate(self, *exprs: DuckDBExpr) -> Self:
|
||||
selection = [val.alias(name) for name, val in evaluate_exprs(self, *exprs)]
|
||||
try:
|
||||
return self._with_native(self.native.aggregate(selection)) # type: ignore[arg-type]
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_duckdb_exception(e, self) from None
|
||||
|
||||
def select(self, *exprs: DuckDBExpr) -> Self:
|
||||
selection = (val.alias(name) for name, val in evaluate_exprs(self, *exprs))
|
||||
try:
|
||||
return self._with_native(self.native.select(*selection))
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_duckdb_exception(e, self) from None
|
||||
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
|
||||
columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
|
||||
selection = (name for name in self.columns if name not in columns_to_drop)
|
||||
return self._with_native(self.native.select(*selection))
|
||||
|
||||
def lazy(self, backend: None = None, **_: None) -> Self:
|
||||
# The `backend`` argument has no effect but we keep it here for
|
||||
# backwards compatibility because in `narwhals.stable.v1`
|
||||
# function `.from_native()` will return a DataFrame for DuckDB.
|
||||
|
||||
if backend is not None: # pragma: no cover
|
||||
msg = "`backend` argument is not supported for DuckDB"
|
||||
raise ValueError(msg)
|
||||
return self
|
||||
|
||||
def with_columns(self, *exprs: DuckDBExpr) -> Self:
|
||||
new_columns_map = dict(evaluate_exprs(self, *exprs))
|
||||
result = [
|
||||
new_columns_map.pop(name).alias(name)
|
||||
if name in new_columns_map
|
||||
else col(name)
|
||||
for name in self.columns
|
||||
]
|
||||
result.extend(value.alias(name) for name, value in new_columns_map.items())
|
||||
try:
|
||||
return self._with_native(self.native.select(*result))
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_duckdb_exception(e, self) from None
|
||||
|
||||
def filter(self, predicate: DuckDBExpr) -> Self:
|
||||
# `[0]` is safe as the predicate's expression only returns a single column
|
||||
mask = predicate(self)[0]
|
||||
try:
|
||||
return self._with_native(self.native.filter(mask))
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_duckdb_exception(e, self) from None
|
||||
|
||||
@property
|
||||
def schema(self) -> dict[str, DType]:
|
||||
if self._cached_native_schema is None:
|
||||
# Note: prefer `self._cached_native_schema` over `functools.cached_property`
|
||||
# due to Python3.13 failures.
|
||||
self._cached_native_schema = dict(zip(self.columns, self.native.types))
|
||||
|
||||
deferred_time_zone = DeferredTimeZone(self.native)
|
||||
return {
|
||||
column_name: native_to_narwhals_dtype(
|
||||
duckdb_dtype, self._version, deferred_time_zone
|
||||
)
|
||||
for column_name, duckdb_dtype in zip_strict(
|
||||
self.native.columns, self.native.types
|
||||
)
|
||||
}
|
||||
|
||||
@property
|
||||
def columns(self) -> list[str]:
|
||||
if self._cached_columns is None:
|
||||
self._cached_columns = (
|
||||
list(self.schema)
|
||||
if self._cached_native_schema is not None
|
||||
else self.native.columns
|
||||
)
|
||||
return self._cached_columns
|
||||
|
||||
def to_pandas(self) -> pd.DataFrame:
|
||||
# only if version is v1, keep around for backcompat
|
||||
return self.native.df()
|
||||
|
||||
def to_arrow(self) -> pa.Table:
|
||||
# only if version is v1, keep around for backcompat
|
||||
return self.native.arrow()
|
||||
|
||||
def _with_version(self, version: Version) -> Self:
|
||||
return self.__class__(self.native, version=version)
|
||||
|
||||
def _with_native(self, df: duckdb.DuckDBPyRelation) -> Self:
|
||||
return self.__class__(df, version=self._version)
|
||||
|
||||
def group_by(
|
||||
self, keys: Sequence[str] | Sequence[DuckDBExpr], *, drop_null_keys: bool
|
||||
) -> DuckDBGroupBy:
|
||||
from narwhals._duckdb.group_by import DuckDBGroupBy
|
||||
|
||||
return DuckDBGroupBy(self, keys, drop_null_keys=drop_null_keys)
|
||||
|
||||
def rename(self, mapping: Mapping[str, str]) -> Self:
|
||||
df = self.native
|
||||
selection = (
|
||||
col(name).alias(mapping[name]) if name in mapping else col(name)
|
||||
for name in df.columns
|
||||
)
|
||||
return self._with_native(self.native.select(*selection))
|
||||
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self:
|
||||
native_how = "outer" if how == "full" else how
|
||||
|
||||
if native_how == "cross":
|
||||
if self._backend_version < (1, 1, 4):
|
||||
msg = f"'duckdb>=1.1.4' is required for cross-join, found version: {self._backend_version}"
|
||||
raise NotImplementedError(msg)
|
||||
rel = self.native.set_alias("lhs").cross(other.native.set_alias("rhs"))
|
||||
else:
|
||||
# help mypy
|
||||
assert left_on is not None # noqa: S101
|
||||
assert right_on is not None # noqa: S101
|
||||
it = (
|
||||
col(f'lhs."{left}"') == col(f'rhs."{right}"')
|
||||
for left, right in zip_strict(left_on, right_on)
|
||||
)
|
||||
condition: Expression = reduce(and_, it)
|
||||
rel = self.native.set_alias("lhs").join(
|
||||
other.native.set_alias("rhs"),
|
||||
# NOTE: Fixed in `--pre` https://github.com/duckdb/duckdb/pull/16933
|
||||
condition=condition, # type: ignore[arg-type, unused-ignore]
|
||||
how=native_how,
|
||||
)
|
||||
|
||||
if native_how in {"inner", "left", "cross", "outer"}:
|
||||
select = [col(f'lhs."{x}"') for x in self.columns]
|
||||
for name in other.columns:
|
||||
col_in_lhs: bool = name in self.columns
|
||||
if native_how == "outer" and not col_in_lhs:
|
||||
select.append(col(f'rhs."{name}"'))
|
||||
elif (native_how == "outer") or (
|
||||
col_in_lhs and (right_on is None or name not in right_on)
|
||||
):
|
||||
select.append(col(f'rhs."{name}"').alias(f"{name}{suffix}"))
|
||||
elif right_on is None or name not in right_on:
|
||||
select.append(col(name))
|
||||
res = rel.select(*select).set_alias(self.native.alias)
|
||||
else: # semi, anti
|
||||
res = rel.select("lhs.*").set_alias(self.native.alias)
|
||||
|
||||
return self._with_native(res)
|
||||
|
||||
def join_asof(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
left_on: str,
|
||||
right_on: str,
|
||||
by_left: Sequence[str] | None,
|
||||
by_right: Sequence[str] | None,
|
||||
strategy: AsofJoinStrategy,
|
||||
suffix: str,
|
||||
) -> Self:
|
||||
lhs = self.native
|
||||
rhs = other.native
|
||||
conditions: list[Expression] = []
|
||||
if by_left is not None and by_right is not None:
|
||||
conditions.extend(
|
||||
col(f'lhs."{left}"') == col(f'rhs."{right}"')
|
||||
for left, right in zip_strict(by_left, by_right)
|
||||
)
|
||||
else:
|
||||
by_left = by_right = []
|
||||
if strategy == "backward":
|
||||
conditions.append(col(f'lhs."{left_on}"') >= col(f'rhs."{right_on}"'))
|
||||
elif strategy == "forward":
|
||||
conditions.append(col(f'lhs."{left_on}"') <= col(f'rhs."{right_on}"'))
|
||||
else:
|
||||
msg = "Only 'backward' and 'forward' strategies are currently supported for DuckDB"
|
||||
raise NotImplementedError(msg)
|
||||
condition: Expression = reduce(and_, conditions)
|
||||
select = ["lhs.*"]
|
||||
for name in rhs.columns:
|
||||
if name in lhs.columns and (
|
||||
right_on is None or name not in {right_on, *by_right}
|
||||
):
|
||||
select.append(f'rhs."{name}" as "{name}{suffix}"')
|
||||
elif right_on is None or name not in {right_on, *by_right}:
|
||||
select.append(str(col(name)))
|
||||
# Replace with Python API call once
|
||||
# https://github.com/duckdb/duckdb/discussions/16947 is addressed.
|
||||
query = f"""
|
||||
SELECT {",".join(select)}
|
||||
FROM lhs
|
||||
ASOF LEFT JOIN rhs
|
||||
ON {condition}
|
||||
""" # noqa: S608
|
||||
return self._with_native(duckdb.sql(query))
|
||||
|
||||
def collect_schema(self) -> dict[str, DType]:
|
||||
return self.schema
|
||||
|
||||
def unique(
|
||||
self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
|
||||
) -> Self:
|
||||
if subset_ := subset if keep == "any" else (subset or self.columns):
|
||||
# Sanitise input
|
||||
if error := self._check_columns_exist(subset_):
|
||||
raise error
|
||||
idx_name = generate_temporary_column_name(8, self.columns)
|
||||
count_name = generate_temporary_column_name(8, [*self.columns, idx_name])
|
||||
name = count_name if keep == "none" else idx_name
|
||||
idx_expr = window_expression(F("row_number"), subset_).alias(idx_name)
|
||||
count_expr = window_expression(
|
||||
F("count", StarExpression()), subset_, ()
|
||||
).alias(count_name)
|
||||
return self._with_native(
|
||||
self.native.select(StarExpression(), idx_expr, count_expr)
|
||||
.filter(col(name) == lit(1))
|
||||
.select(StarExpression(exclude=[count_name, idx_name]))
|
||||
)
|
||||
return self._with_native(self.native.unique(join_column_names(*self.columns)))
|
||||
|
||||
def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
|
||||
if isinstance(descending, bool):
|
||||
descending = [descending] * len(by)
|
||||
if nulls_last:
|
||||
it = (
|
||||
col(name).nulls_last() if not desc else col(name).desc().nulls_last()
|
||||
for name, desc in zip_strict(by, descending)
|
||||
)
|
||||
else:
|
||||
it = (
|
||||
col(name).nulls_first() if not desc else col(name).desc().nulls_first()
|
||||
for name, desc in zip_strict(by, descending)
|
||||
)
|
||||
return self._with_native(self.native.sort(*it))
|
||||
|
||||
def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
|
||||
_df = self.native
|
||||
by = list(by)
|
||||
if isinstance(reverse, bool):
|
||||
descending = [not reverse] * len(by)
|
||||
else:
|
||||
descending = [not rev for rev in reverse]
|
||||
expr = window_expression(
|
||||
F("row_number"),
|
||||
order_by=by,
|
||||
descending=descending,
|
||||
nulls_last=[True] * len(by),
|
||||
)
|
||||
condition = expr <= lit(k)
|
||||
query = f"""
|
||||
SELECT *
|
||||
FROM _df
|
||||
QUALIFY {condition}
|
||||
""" # noqa: S608
|
||||
return self._with_native(duckdb.sql(query))
|
||||
|
||||
def drop_nulls(self, subset: Sequence[str] | None) -> Self:
|
||||
subset_ = subset if subset is not None else self.columns
|
||||
keep_condition = reduce(and_, (col(name).isnotnull() for name in subset_))
|
||||
return self._with_native(self.native.filter(keep_condition))
|
||||
|
||||
def explode(self, columns: Sequence[str]) -> Self:
|
||||
dtypes = self._version.dtypes
|
||||
schema = self.collect_schema()
|
||||
for name in columns:
|
||||
dtype = schema[name]
|
||||
if dtype != dtypes.List:
|
||||
msg = (
|
||||
f"`explode` operation not supported for dtype `{dtype}`, "
|
||||
"expected List type"
|
||||
)
|
||||
raise InvalidOperationError(msg)
|
||||
|
||||
if len(columns) != 1:
|
||||
msg = (
|
||||
"Exploding on multiple columns is not supported with DuckDB backend since "
|
||||
"we cannot guarantee that the exploded columns have matching element counts."
|
||||
)
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
col_to_explode = col(columns[0])
|
||||
rel = self.native
|
||||
original_columns = self.columns
|
||||
|
||||
not_null_condition = col_to_explode.isnotnull() & F("len", col_to_explode) > lit(
|
||||
0
|
||||
)
|
||||
non_null_rel = rel.filter(not_null_condition).select(
|
||||
*(
|
||||
F("unnest", col_to_explode).alias(name) if name in columns else name
|
||||
for name in original_columns
|
||||
)
|
||||
)
|
||||
|
||||
null_rel = rel.filter(~not_null_condition).select(
|
||||
*(
|
||||
lit(None).alias(name) if name in columns else name
|
||||
for name in original_columns
|
||||
)
|
||||
)
|
||||
|
||||
return self._with_native(non_null_rel.union(null_rel))
|
||||
|
||||

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        index_ = [] if index is None else index
        on_ = [c for c in self.columns if c not in index_] if on is None else on

        if variable_name == "":
            msg = "`variable_name` cannot be empty string for duckdb backend."
            raise NotImplementedError(msg)

        if value_name == "":
            msg = "`value_name` cannot be empty string for duckdb backend."
            raise NotImplementedError(msg)

        unpivot_on = join_column_names(*on_)
        rel = self.native  # noqa: F841
        # Replace with Python API once
        # https://github.com/duckdb/duckdb/discussions/16980 is addressed.
        query = f"""
            unpivot rel
            on {unpivot_on}
            into
                name {col(variable_name)}
                value {col(value_name)}
            """
        return self._with_native(
            duckdb.sql(query).select(*[*index_, variable_name, value_name])
        )
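
    # DuckDB relations are unordered, so a deterministic row index requires
    # `order_by`; `row_number()` is 1-based, hence the `- lit(1)` below.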
    @requires.backend_version((1, 3))
    def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
        if order_by is None:
            msg = "`order_by` must be specified for `with_row_index` with the DuckDB backend"
            raise TypeError(msg)
        expr = (window_expression(F("row_number"), order_by=order_by) - lit(1)).alias(
            name
        )
        return self._with_native(self.native.select(expr, StarExpression()))

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        df = self.native  # noqa: F841
        query = f"""
            COPY (SELECT * FROM df)
            TO '{file}'
            (FORMAT parquet)
            """  # noqa: S608
        duckdb.sql(query)

    gather_every = not_implemented.deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    tail = not_implemented.deprecated(
        "`LazyFrame.tail` is deprecated and will be removed in a future version."
    )

303
lib/python3.11/site-packages/narwhals/_duckdb/expr.py
Normal file
@ -0,0 +1,303 @@
from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Any, Callable, Literal, cast

from duckdb import CoalesceOperator, StarExpression

from narwhals._duckdb.expr_dt import DuckDBExprDateTimeNamespace
from narwhals._duckdb.expr_list import DuckDBExprListNamespace
from narwhals._duckdb.expr_str import DuckDBExprStringNamespace
from narwhals._duckdb.expr_struct import DuckDBExprStructNamespace
from narwhals._duckdb.utils import (
    DeferredTimeZone,
    F,
    col,
    lit,
    narwhals_to_native_dtype,
    when,
    window_expression,
)
from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version

if TYPE_CHECKING:
    from collections.abc import Sequence

    from duckdb import Expression
    from typing_extensions import Self

    from narwhals._compliant import WindowInputs
    from narwhals._compliant.typing import (
        AliasNames,
        EvalNames,
        EvalSeries,
        WindowFunction,
    )
    from narwhals._duckdb.dataframe import DuckDBLazyFrame
    from narwhals._duckdb.namespace import DuckDBNamespace
    from narwhals._utils import _LimitedContext
    from narwhals.typing import (
        FillNullStrategy,
        IntoDType,
        NonNestedLiteral,
        RollingInterpolationMethod,
    )

    DuckDBWindowFunction = WindowFunction[DuckDBLazyFrame, Expression]
    DuckDBWindowInputs = WindowInputs[Expression]


class DuckDBExpr(SQLExpr["DuckDBLazyFrame", "Expression"]):
    _implementation = Implementation.DUCKDB

    def __init__(
        self,
        call: EvalSeries[DuckDBLazyFrame, Expression],
        window_function: DuckDBWindowFunction | None = None,
        *,
        evaluate_output_names: EvalNames[DuckDBLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        implementation: Implementation = Implementation.DUCKDB,
    ) -> None:
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._metadata: ExprMetadata | None = None
        self._window_function: DuckDBWindowFunction | None = window_function

    def _count_star(self) -> Expression:
        return F("count", StarExpression())

    def _window_expression(
        self,
        expr: Expression,
        partition_by: Sequence[str | Expression] = (),
        order_by: Sequence[str | Expression] = (),
        rows_start: int | None = None,
        rows_end: int | None = None,
        *,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Expression:
        return window_expression(
            expr,
            partition_by,
            order_by,
            rows_start,
            rows_end,
            descending=descending,
            nulls_last=nulls_last,
        )

    def __narwhals_namespace__(self) -> DuckDBNamespace:  # pragma: no cover
        from narwhals._duckdb.namespace import DuckDBNamespace

        return DuckDBNamespace(version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        if kind is ExprKind.LITERAL:
            return self
        if self._backend_version < (1, 3):
            msg = "At least version 1.3 of DuckDB is required for binary operations between aggregates and columns."
            raise NotImplementedError(msg)
        return self.over([lit(1)], [])

    @classmethod
    def from_column_names(
        cls,
        evaluate_column_names: EvalNames[DuckDBLazyFrame],
        /,
        *,
        context: _LimitedContext,
    ) -> Self:
        def func(df: DuckDBLazyFrame) -> list[Expression]:
            return [col(name) for name in evaluate_column_names(df)]

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: DuckDBLazyFrame) -> list[Expression]:
            columns = df.columns
            return [col(columns[i]) for i in column_indices]

        return cls(
            func,
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def _alias_native(cls, expr: Expression, name: str) -> Expression:
        return expr.alias(name)

    def __invert__(self) -> Self:
        invert = cast("Callable[..., Expression]", operator.invert)
        return self._with_elementwise(invert)
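
    # DuckDB's `skewness` aggregate is the bias-corrected sample statistic G1;
    # multiplying by (n - 2) / sqrt(n * (n - 1)) undoes the correction, giving
    # the population coefficient g1 that Polars' `skew` returns by default.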
    def skew(self) -> Self:
        def func(expr: Expression) -> Expression:
            count = F("count", expr)
            # Undo DuckDB's sample-skewness correction to match Polars' default.
            biased_skewness = (
                F("skewness", expr)
                * (count - lit(2))
                / F("sqrt", count * (count - lit(1)))
            )
            return when(count == lit(0), lit(None)).otherwise(
                when(count == lit(1), lit(float("nan"))).otherwise(
                    when(count == lit(2), lit(0.0)).otherwise(biased_skewness)
                )
            )

        return self._with_callable(func)

    def kurtosis(self) -> Self:
        return self._with_callable(lambda expr: F("kurtosis_pop", expr))

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        def func(expr: Expression) -> Expression:
            if interpolation == "linear":
                return F("quantile_cont", expr, lit(quantile))
            msg = "Only linear interpolation methods are supported for DuckDB quantile."
            raise NotImplementedError(msg)

        return self._with_callable(func)
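
    # `array_unique` counts distinct non-NULL values; the `max` term adds 1 when
    # any NULL is present, so NULL counts as one more distinct value.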
    def n_unique(self) -> Self:
        def func(expr: Expression) -> Expression:
            # https://stackoverflow.com/a/79338887/4451315
            return F("array_unique", F("array_agg", expr)) + F(
                "max", when(expr.isnotnull(), lit(0)).otherwise(lit(1))
            )

        return self._with_callable(func)

    def len(self) -> Self:
        return self._with_callable(lambda _expr: F("count"))
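
    # For arbitrary `ddof`, rescale the population statistic:
    #   stddev_ddof = stddev_pop * sqrt(n / (n - ddof))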
    def std(self, ddof: int) -> Self:
        if ddof == 0:
            return self._with_callable(lambda expr: F("stddev_pop", expr))
        if ddof == 1:
            return self._with_callable(lambda expr: F("stddev_samp", expr))

        def _std(expr: Expression) -> Expression:
            n_samples = F("count", expr)
            return (
                F("stddev_pop", expr)
                * F("sqrt", n_samples)
                / (F("sqrt", (n_samples - lit(ddof))))
            )

        return self._with_callable(_std)

    def var(self, ddof: int) -> Self:
        if ddof == 0:
            return self._with_callable(lambda expr: F("var_pop", expr))
        if ddof == 1:
            return self._with_callable(lambda expr: F("var_samp", expr))

        def _var(expr: Expression) -> Expression:
            n_samples = F("count", expr)
            return F("var_pop", expr) * n_samples / (n_samples - lit(ddof))

        return self._with_callable(_var)

    def null_count(self) -> Self:
        return self._with_callable(lambda expr: F("sum", expr.isnull().cast("int")))

    def is_nan(self) -> Self:
        return self._with_elementwise(lambda expr: F("isnan", expr))

    def is_finite(self) -> Self:
        return self._with_elementwise(lambda expr: F("isfinite", expr))

    def is_in(self, other: Sequence[Any]) -> Self:
        return self._with_elementwise(lambda expr: F("contains", lit(other), expr))
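
    # Fill strategies map onto window frames: "forward" takes `last_value` over a
    # trailing frame, "backward" takes `first_value` over a leading frame, both
    # with `ignore nulls` so the nearest non-null within `limit` rows wins.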
    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        if strategy is not None:
            if self._backend_version < (1, 3):  # pragma: no cover
                msg = f"`fill_null` with `strategy={strategy}` is only available in 'duckdb>=1.3.0'."
                raise NotImplementedError(msg)

            def _fill_with_strategy(
                df: DuckDBLazyFrame, inputs: DuckDBWindowInputs
            ) -> Sequence[Expression]:
                fill_func = "last_value" if strategy == "forward" else "first_value"
                rows_start, rows_end = (
                    (-limit if limit is not None else None, 0)
                    if strategy == "forward"
                    else (0, limit)
                )
                return [
                    window_expression(
                        F(fill_func, expr),
                        inputs.partition_by,
                        inputs.order_by,
                        rows_start=rows_start,
                        rows_end=rows_end,
                        ignore_nulls=True,
                    )
                    for expr in self(df)
                ]

            return self._with_window_function(_fill_with_strategy)

        def _fill_constant(expr: Expression, value: Any) -> Expression:
            return CoalesceOperator(expr, value)

        return self._with_elementwise(_fill_constant, value=value)

    def cast(self, dtype: IntoDType) -> Self:
        def func(df: DuckDBLazyFrame) -> list[Expression]:
            tz = DeferredTimeZone(df.native)
            native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
            return [expr.cast(native_dtype) for expr in self(df)]

        def window_f(df: DuckDBLazyFrame, inputs: DuckDBWindowInputs) -> list[Expression]:
            tz = DeferredTimeZone(df.native)
            native_dtype = narwhals_to_native_dtype(dtype, self._version, tz)
            return [expr.cast(native_dtype) for expr in self.window_function(df, inputs)]

        return self.__class__(
            func,
            window_f,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    @property
    def str(self) -> DuckDBExprStringNamespace:
        return DuckDBExprStringNamespace(self)

    @property
    def dt(self) -> DuckDBExprDateTimeNamespace:
        return DuckDBExprDateTimeNamespace(self)

    @property
    def list(self) -> DuckDBExprListNamespace:
        return DuckDBExprListNamespace(self)

    @property
    def struct(self) -> DuckDBExprStructNamespace:
        return DuckDBExprStructNamespace(self)

132
lib/python3.11/site-packages/narwhals/_duckdb/expr_dt.py
Normal file
@ -0,0 +1,132 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._constants import (
    MS_PER_MINUTE,
    MS_PER_SECOND,
    NS_PER_SECOND,
    SECONDS_PER_MINUTE,
    US_PER_MINUTE,
    US_PER_SECOND,
)
from narwhals._duckdb.utils import UNITS_DICT, F, fetch_rel_time_zone, lit
from narwhals._duration import Interval
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    from collections.abc import Sequence

    from duckdb import Expression

    from narwhals._duckdb.dataframe import DuckDBLazyFrame
    from narwhals._duckdb.expr import DuckDBExpr


class DuckDBExprDateTimeNamespace(SQLExprDateTimeNamesSpace["DuckDBExpr"]):
    # DuckDB's millisecond/microsecond/nanosecond parts include whole seconds,
    # so the second component is subtracted to match Polars' field semantics.
    def millisecond(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: F("millisecond", expr) - F("second", expr) * lit(MS_PER_SECOND)
        )

    def microsecond(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: F("microsecond", expr) - F("second", expr) * lit(US_PER_SECOND)
        )

    def nanosecond(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: F("nanosecond", expr) - F("second", expr) * lit(NS_PER_SECOND)
        )

    def to_string(self, format: str) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: F("strftime", expr, lit(format))
        )
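
    # `isodow` numbers days Monday=1 through Sunday=7, matching Polars' `weekday`.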
    def weekday(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(lambda expr: F("isodow", expr))

    def date(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(lambda expr: expr.cast("date"))

    def total_minutes(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: F("datepart", lit("minute"), expr)
        )

    def total_seconds(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: lit(SECONDS_PER_MINUTE) * F("datepart", lit("minute"), expr)
            + F("datepart", lit("second"), expr)
        )

    def total_milliseconds(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: lit(MS_PER_MINUTE) * F("datepart", lit("minute"), expr)
            + F("datepart", lit("millisecond"), expr)
        )

    def total_microseconds(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: lit(US_PER_MINUTE) * F("datepart", lit("minute"), expr)
            + F("datepart", lit("microsecond"), expr)
        )

    def truncate(self, every: str) -> DuckDBExpr:
        interval = Interval.parse(every)
        multiple, unit = interval.multiple, interval.unit
        if multiple != 1:
            # https://github.com/duckdb/duckdb/issues/17554
            msg = f"Only multiple 1 is currently supported for DuckDB.\nGot {multiple!s}."
            raise ValueError(msg)
        if unit == "ns":
            msg = "Truncating to nanoseconds is not yet supported for DuckDB."
            raise NotImplementedError(msg)
        format = lit(UNITS_DICT[unit])

        def _truncate(expr: Expression) -> Expression:
            return F("date_trunc", format, expr)

        return self.compliant._with_elementwise(_truncate)

    def offset_by(self, by: str) -> DuckDBExpr:
        interval = Interval.parse_no_constraints(by)
        format = lit(f"{interval.multiple!s} {UNITS_DICT[interval.unit]}")

        def _offset_by(expr: Expression) -> Expression:
            return F("date_add", format, expr)

        return self.compliant._with_callable(_offset_by)

    def _no_op_time_zone(self, time_zone: str) -> DuckDBExpr:
        def func(df: DuckDBLazyFrame) -> Sequence[Expression]:
            native_series_list = self.compliant(df)
            conn_time_zone = fetch_rel_time_zone(df.native)
            if conn_time_zone != time_zone:
                msg = (
                    "DuckDB stores the time zone in the connection, rather than in the "
                    f"data type, so changing the timezone to anything other than "
                    f"{conn_time_zone} (the current connection time zone) is not supported."
                )
                raise NotImplementedError(msg)
            return native_series_list

        return self.compliant.__class__(
            func,
            evaluate_output_names=self.compliant._evaluate_output_names,
            alias_output_names=self.compliant._alias_output_names,
            version=self.compliant._version,
        )

    def convert_time_zone(self, time_zone: str) -> DuckDBExpr:
        return self._no_op_time_zone(time_zone)

    def replace_time_zone(self, time_zone: str | None) -> DuckDBExpr:
        if time_zone is None:
            return self.compliant._with_elementwise(lambda expr: expr.cast("timestamp"))
        return self._no_op_time_zone(time_zone)

    total_nanoseconds = not_implemented()
    timestamp = not_implemented()

40
lib/python3.11/site-packages/narwhals/_duckdb/expr_list.py
Normal file
@ -0,0 +1,40 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace
from narwhals._duckdb.utils import F, lit, when

if TYPE_CHECKING:
    from duckdb import Expression

    from narwhals._duckdb.expr import DuckDBExpr
    from narwhals.typing import NonNestedLiteral


class DuckDBExprListNamespace(
    LazyExprNamespace["DuckDBExpr"], ListNamespace["DuckDBExpr"]
):
    def len(self) -> DuckDBExpr:
        return self.compliant._with_elementwise(lambda expr: F("len", expr))
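
    # DuckDB's `list_distinct` drops NULLs, so a NULL is re-appended whenever the
    # source list contained one, matching Polars' `list.unique`.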
    def unique(self) -> DuckDBExpr:
        def func(expr: Expression) -> Expression:
            expr_distinct = F("list_distinct", expr)
            return when(
                F("array_position", expr, lit(None)).isnotnull(),
                F("list_append", expr_distinct, lit(None)),
            ).otherwise(expr_distinct)

        return self.compliant._with_callable(func)

    def contains(self, item: NonNestedLiteral) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: F("list_contains", expr, lit(item))
        )
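
    # DuckDB lists are 1-indexed, hence `index + 1`.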
    def get(self, index: int) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: F("list_extract", expr, lit(index + 1))
        )

30
lib/python3.11/site-packages/narwhals/_duckdb/expr_str.py
Normal file
@ -0,0 +1,30 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._duckdb.utils import F, lit
from narwhals._sql.expr_str import SQLExprStringNamespace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    from narwhals._duckdb.expr import DuckDBExpr


class DuckDBExprStringNamespace(SQLExprStringNamespace["DuckDBExpr"]):
    def to_datetime(self, format: str | None) -> DuckDBExpr:
        if format is None:
            msg = "Cannot infer format with DuckDB backend, please specify `format` explicitly."
            raise NotImplementedError(msg)

        return self.compliant._with_elementwise(
            lambda expr: F("strptime", expr, lit(format))
        )

    def to_date(self, format: str | None) -> DuckDBExpr:
        if format is not None:
            return self.to_datetime(format=format).dt.date()

        compliant_expr = self.compliant
        return compliant_expr.cast(compliant_expr._version.dtypes.Date())

    replace = not_implemented()

19
lib/python3.11/site-packages/narwhals/_duckdb/expr_struct.py
Normal file
@ -0,0 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StructNamespace
from narwhals._duckdb.utils import F, lit

if TYPE_CHECKING:
    from narwhals._duckdb.expr import DuckDBExpr


class DuckDBExprStructNamespace(
    LazyExprNamespace["DuckDBExpr"], StructNamespace["DuckDBExpr"]
):
    def field(self, name: str) -> DuckDBExpr:
        return self.compliant._with_elementwise(
            lambda expr: F("struct_extract", expr, lit(name))
        ).alias(name)

33
lib/python3.11/site-packages/narwhals/_duckdb/group_by.py
Normal file
@ -0,0 +1,33 @@
from __future__ import annotations

from itertools import chain
from typing import TYPE_CHECKING

from narwhals._sql.group_by import SQLGroupBy

if TYPE_CHECKING:
    from collections.abc import Sequence

    from duckdb import Expression  # noqa: F401

    from narwhals._duckdb.dataframe import DuckDBLazyFrame
    from narwhals._duckdb.expr import DuckDBExpr


class DuckDBGroupBy(SQLGroupBy["DuckDBLazyFrame", "DuckDBExpr", "Expression"]):
    def __init__(
        self,
        df: DuckDBLazyFrame,
        keys: Sequence[DuckDBExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame

    def agg(self, *exprs: DuckDBExpr) -> DuckDBLazyFrame:
        agg_columns = list(chain(self._keys, self._evaluate_exprs(exprs)))
        return self.compliant._with_native(
            self.compliant.native.aggregate(agg_columns)  # type: ignore[arg-type]
        ).rename(dict(zip(self._keys, self._output_key_names)))

164
lib/python3.11/site-packages/narwhals/_duckdb/namespace.py
Normal file
@ -0,0 +1,164 @@
from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any

from duckdb import CoalesceOperator, Expression
from duckdb.typing import BIGINT, VARCHAR

from narwhals._duckdb.dataframe import DuckDBLazyFrame
from narwhals._duckdb.expr import DuckDBExpr
from narwhals._duckdb.selectors import DuckDBSelectorNamespace
from narwhals._duckdb.utils import (
    DeferredTimeZone,
    F,
    concat_str,
    function,
    lit,
    narwhals_to_native_dtype,
    when,
)
from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._sql.namespace import SQLNamespace
from narwhals._sql.when_then import SQLThen, SQLWhen
from narwhals._utils import Implementation

if TYPE_CHECKING:
    from collections.abc import Iterable

    from duckdb import DuckDBPyRelation  # noqa: F401

    from narwhals._utils import Version
    from narwhals.typing import ConcatMethod, IntoDType, NonNestedLiteral


class DuckDBNamespace(
    SQLNamespace[DuckDBLazyFrame, DuckDBExpr, "DuckDBPyRelation", Expression]
):
    _implementation: Implementation = Implementation.DUCKDB

    def __init__(self, *, version: Version) -> None:
        self._version = version

    @property
    def selectors(self) -> DuckDBSelectorNamespace:
        return DuckDBSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[DuckDBExpr]:
        return DuckDBExpr

    @property
    def _lazyframe(self) -> type[DuckDBLazyFrame]:
        return DuckDBLazyFrame

    def _function(self, name: str, *args: Expression) -> Expression:  # type: ignore[override]
        return function(name, *args)

    def _lit(self, value: Any) -> Expression:
        return lit(value)

    def _when(
        self,
        condition: Expression,
        value: Expression,
        otherwise: Expression | None = None,
    ) -> Expression:
        if otherwise is None:
            return when(condition, value)
        return when(condition, value).otherwise(otherwise)

    def _coalesce(self, *exprs: Expression) -> Expression:
        return CoalesceOperator(*exprs)

    def concat(
        self, items: Iterable[DuckDBLazyFrame], *, how: ConcatMethod
    ) -> DuckDBLazyFrame:
        # Materialize first so a generator input is not exhausted before reuse.
        items = list(items)
        native_items = [item._native_frame for item in items]
        first = items[0]
        schema = first.schema
        if how == "vertical" and not all(x.schema == schema for x in items[1:]):
            msg = "inputs should all have the same schema"
            raise TypeError(msg)
        res = reduce(lambda x, y: x.union(y), native_items)
        return first._with_native(res)
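
    # With `ignore_nulls=False`, separators are interleaved manually and the whole
    # result is masked to NULL as soon as any input column is NULL.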
    def concat_str(
        self, *exprs: DuckDBExpr, separator: str, ignore_nulls: bool
    ) -> DuckDBExpr:
        def func(df: DuckDBLazyFrame) -> list[Expression]:
            cols = list(chain.from_iterable(expr(df) for expr in exprs))
            if not ignore_nulls:
                null_mask_result = reduce(operator.or_, (s.isnull() for s in cols))
                cols_separated = [
                    y
                    for x in [
                        (col.cast(VARCHAR),)
                        if i == len(cols) - 1
                        else (col.cast(VARCHAR), lit(separator))
                        for i, col in enumerate(cols)
                    ]
                    for y in x
                ]
                return [when(~null_mask_result, concat_str(*cols_separated))]
            return [concat_str(*cols, separator=separator)]

        return self._expr(
            call=func,
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )
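
    # Row-wise mean: sum of NULL-coalesced values divided by the count of
    # non-NULL values in that row.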
    def mean_horizontal(self, *exprs: DuckDBExpr) -> DuckDBExpr:
        def func(cols: Iterable[Expression]) -> Expression:
            cols = list(cols)
            return reduce(
                operator.add, (CoalesceOperator(col, lit(0)) for col in cols)
            ) / reduce(operator.add, (col.isnotnull().cast(BIGINT) for col in cols))

        return self._expr._from_elementwise_horizontal_op(func, *exprs)

    def when(self, predicate: DuckDBExpr) -> DuckDBWhen:
        return DuckDBWhen.from_expr(predicate, context=self)

    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> DuckDBExpr:
        def func(df: DuckDBLazyFrame) -> list[Expression]:
            tz = DeferredTimeZone(df.native)
            if dtype is not None:
                target = narwhals_to_native_dtype(dtype, self._version, tz)
                return [lit(value).cast(target)]
            return [lit(value)]

        return self._expr(
            func,
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

    def len(self) -> DuckDBExpr:
        def func(_df: DuckDBLazyFrame) -> list[Expression]:
            return [F("count")]

        return self._expr(
            call=func,
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )


class DuckDBWhen(SQLWhen["DuckDBLazyFrame", Expression, DuckDBExpr]):
    @property
    def _then(self) -> type[DuckDBThen]:
        return DuckDBThen


class DuckDBThen(SQLThen["DuckDBLazyFrame", Expression, DuckDBExpr], DuckDBExpr): ...

33
lib/python3.11/site-packages/narwhals/_duckdb/selectors.py
Normal file
@ -0,0 +1,33 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._duckdb.expr import DuckDBExpr

if TYPE_CHECKING:
    from duckdb import Expression  # noqa: F401

    from narwhals._duckdb.dataframe import DuckDBLazyFrame  # noqa: F401
    from narwhals._duckdb.expr import DuckDBWindowFunction


class DuckDBSelectorNamespace(LazySelectorNamespace["DuckDBLazyFrame", "Expression"]):
    @property
    def _selector(self) -> type[DuckDBSelector]:
        return DuckDBSelector


class DuckDBSelector(  # type: ignore[misc]
    CompliantSelector["DuckDBLazyFrame", "Expression"], DuckDBExpr
):
    _window_function: DuckDBWindowFunction | None = None

    def _to_expr(self) -> DuckDBExpr:
        return DuckDBExpr(
            self._call,
            self._window_function,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

44
lib/python3.11/site-packages/narwhals/_duckdb/series.py
Normal file
@ -0,0 +1,44 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._duckdb.utils import DeferredTimeZone, native_to_narwhals_dtype
from narwhals.dependencies import get_duckdb

if TYPE_CHECKING:
    from types import ModuleType

    import duckdb
    from typing_extensions import Never, Self

    from narwhals._utils import Version
    from narwhals.dtypes import DType


class DuckDBInterchangeSeries:
    def __init__(self, df: duckdb.DuckDBPyRelation, version: Version) -> None:
        self._native_series = df
        self._version = version

    def __narwhals_series__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        return get_duckdb()  # type: ignore[no-any-return]

    @property
    def dtype(self) -> DType:
        return native_to_narwhals_dtype(
            self._native_series.types[0],
            self._version,
            DeferredTimeZone(self._native_series),
        )

    def __getattr__(self, attr: str) -> Never:
        msg = (  # pragma: no cover
            f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
            "If you would like to see this kind of object better supported in "
            "Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)  # pragma: no cover

18
lib/python3.11/site-packages/narwhals/_duckdb/typing.py
Normal file
@ -0,0 +1,18 @@
from __future__ import annotations

from typing import TYPE_CHECKING, TypedDict

if TYPE_CHECKING:
    from collections.abc import Sequence

    from duckdb import Expression


class WindowExpressionKwargs(TypedDict, total=False):
    partition_by: Sequence[str | Expression]
    order_by: Sequence[str | Expression]
    rows_start: int | None
    rows_end: int | None
    descending: Sequence[bool]
    nulls_last: Sequence[bool]
    ignore_nulls: bool

370
lib/python3.11/site-packages/narwhals/_duckdb/utils.py
Normal file
@ -0,0 +1,370 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING

import duckdb
import duckdb.typing as duckdb_dtypes
from duckdb.typing import DuckDBPyType

from narwhals._utils import Version, isinstance_or_issubclass, zip_strict
from narwhals.exceptions import ColumnNotFoundError

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    from duckdb import DuckDBPyRelation, Expression

    from narwhals._compliant.typing import CompliantLazyFrameAny
    from narwhals._duckdb.dataframe import DuckDBLazyFrame
    from narwhals._duckdb.expr import DuckDBExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType, TimeUnit


UNITS_DICT = {
    "y": "year",
    "q": "quarter",
    "mo": "month",
    "d": "day",
    "h": "hour",
    "m": "minute",
    "s": "second",
    "ms": "millisecond",
    "us": "microsecond",
    "ns": "nanosecond",
}
DESCENDING_TO_ORDER = {True: "desc", False: "asc"}
NULLS_LAST_TO_NULLS_POS = {True: "nulls last", False: "nulls first"}

col = duckdb.ColumnExpression
"""Alias for `duckdb.ColumnExpression`."""

lit = duckdb.ConstantExpression
"""Alias for `duckdb.ConstantExpression`."""

when = duckdb.CaseExpression
"""Alias for `duckdb.CaseExpression`."""

F = duckdb.FunctionExpression
"""Alias for `duckdb.FunctionExpression`."""


def concat_str(*exprs: Expression, separator: str = "") -> Expression:
    """Concatenate many strings, NULL inputs are skipped.

    Wraps [concat] and [concat_ws] `FunctionExpression`(s).

    Arguments:
        exprs: Native columns.
        separator: String that will be used to separate the values of each column.

    Returns:
        A new native expression.

    [concat]: https://duckdb.org/docs/stable/sql/functions/char.html#concatstring-
    [concat_ws]: https://duckdb.org/docs/stable/sql/functions/char.html#concat_wsseparator-string-
    """
    return F("concat_ws", lit(separator), *exprs) if separator else F("concat", *exprs)
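

# For example, concat_str(col("a"), col("b"), separator="-") renders as
# concat_ws('-', a, b) in the generated SQL, while the separator-less form
# renders as concat(a, b).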
def evaluate_exprs(
    df: DuckDBLazyFrame, /, *exprs: DuckDBExpr
) -> list[tuple[str, Expression]]:
    native_results: list[tuple[str, Expression]] = []
    for expr in exprs:
        native_series_list = expr._call(df)
        output_names = expr._evaluate_output_names(df)
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))
    return native_results


class DeferredTimeZone:
    """Object which gets passed between `native_to_narwhals_dtype` calls.

    DuckDB stores the time zone in the connection, rather than in the dtypes, so
    this ensures that when calculating the schema of a dataframe with multiple
    timezone-aware columns, the connection's time zone is only fetched once.

    Note: we cannot make the time zone a cached `DuckDBLazyFrame` property because
    the time zone can be modified after `DuckDBLazyFrame` creation:

    ```python
    df = nw.from_native(rel)
    print(df.collect_schema())
    rel.query("set timezone = 'Asia/Kolkata'")
    print(df.collect_schema())  # should change to reflect new time zone
    ```
    """

    _cached_time_zone: str | None = None

    def __init__(self, rel: DuckDBPyRelation) -> None:
        self._rel = rel

    @property
    def time_zone(self) -> str:
        """Fetch relation time zone (if it wasn't calculated already)."""
        if self._cached_time_zone is None:
            self._cached_time_zone = fetch_rel_time_zone(self._rel)
        return self._cached_time_zone


def native_to_narwhals_dtype(
    duckdb_dtype: DuckDBPyType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DType:
    duckdb_dtype_id = duckdb_dtype.id
    dtypes = version.dtypes

    # Handle nested data types first
    if duckdb_dtype_id == "list":
        return dtypes.List(
            native_to_narwhals_dtype(duckdb_dtype.child, version, deferred_time_zone)
        )

    if duckdb_dtype_id == "struct":
        children = duckdb_dtype.children
        return dtypes.Struct(
            [
                dtypes.Field(
                    name=child[0],
                    dtype=native_to_narwhals_dtype(child[1], version, deferred_time_zone),
                )
                for child in children
            ]
        )

    if duckdb_dtype_id == "array":
        child, size = duckdb_dtype.children
        shape: list[int] = [size[1]]

        while child[1].id == "array":
            child, size = child[1].children
            shape.insert(0, size[1])

        inner = native_to_narwhals_dtype(child[1], version, deferred_time_zone)
        return dtypes.Array(inner=inner, shape=tuple(shape))

    if duckdb_dtype_id == "enum":
        if version is Version.V1:
            return dtypes.Enum()  # type: ignore[call-arg]
        categories = duckdb_dtype.children[0][1]
        return dtypes.Enum(categories=categories)

    if duckdb_dtype_id == "timestamp with time zone":
        return dtypes.Datetime(time_zone=deferred_time_zone.time_zone)

    return _non_nested_native_to_narwhals_dtype(duckdb_dtype_id, version)


def fetch_rel_time_zone(rel: duckdb.DuckDBPyRelation) -> str:
    result = rel.query(
        "duckdb_settings()", "select value from duckdb_settings() where name = 'TimeZone'"
    ).fetchone()
    assert result is not None  # noqa: S101
    return result[0]  # type: ignore[no-any-return]


@lru_cache(maxsize=16)
def _non_nested_native_to_narwhals_dtype(duckdb_dtype_id: str, version: Version) -> DType:
    dtypes = version.dtypes
    return {
        "hugeint": dtypes.Int128(),
        "bigint": dtypes.Int64(),
        "integer": dtypes.Int32(),
        "smallint": dtypes.Int16(),
        "tinyint": dtypes.Int8(),
        "uhugeint": dtypes.UInt128(),
        "ubigint": dtypes.UInt64(),
        "uinteger": dtypes.UInt32(),
        "usmallint": dtypes.UInt16(),
        "utinyint": dtypes.UInt8(),
        "double": dtypes.Float64(),
        "float": dtypes.Float32(),
        "varchar": dtypes.String(),
        "date": dtypes.Date(),
        "timestamp_s": dtypes.Datetime("s"),
        "timestamp_ms": dtypes.Datetime("ms"),
        "timestamp": dtypes.Datetime(),
        "timestamp_ns": dtypes.Datetime("ns"),
        "boolean": dtypes.Boolean(),
        "interval": dtypes.Duration(),
        "decimal": dtypes.Decimal(),
        "time": dtypes.Time(),
        "blob": dtypes.Binary(),
    }.get(duckdb_dtype_id, dtypes.Unknown())


dtypes = Version.MAIN.dtypes
NW_TO_DUCKDB_DTYPES: Mapping[type[DType], DuckDBPyType] = {
    dtypes.Float64: duckdb_dtypes.DOUBLE,
    dtypes.Float32: duckdb_dtypes.FLOAT,
    dtypes.Binary: duckdb_dtypes.BLOB,
    dtypes.String: duckdb_dtypes.VARCHAR,
    dtypes.Boolean: duckdb_dtypes.BOOLEAN,
    dtypes.Date: duckdb_dtypes.DATE,
    dtypes.Time: duckdb_dtypes.TIME,
    dtypes.Int8: duckdb_dtypes.TINYINT,
    dtypes.Int16: duckdb_dtypes.SMALLINT,
    dtypes.Int32: duckdb_dtypes.INTEGER,
    dtypes.Int64: duckdb_dtypes.BIGINT,
    dtypes.Int128: DuckDBPyType("INT128"),
    dtypes.UInt8: duckdb_dtypes.UTINYINT,
    dtypes.UInt16: duckdb_dtypes.USMALLINT,
    dtypes.UInt32: duckdb_dtypes.UINTEGER,
    dtypes.UInt64: duckdb_dtypes.UBIGINT,
    dtypes.UInt128: DuckDBPyType("UINT128"),
}
TIME_UNIT_TO_TIMESTAMP: Mapping[TimeUnit, DuckDBPyType] = {
    "s": duckdb_dtypes.TIMESTAMP_S,
    "ms": duckdb_dtypes.TIMESTAMP_MS,
    "us": duckdb_dtypes.TIMESTAMP,
    "ns": duckdb_dtypes.TIMESTAMP_NS,
}
UNSUPPORTED_DTYPES = (dtypes.Decimal, dtypes.Categorical)


def narwhals_to_native_dtype(  # noqa: PLR0912, C901
    dtype: IntoDType, version: Version, deferred_time_zone: DeferredTimeZone
) -> DuckDBPyType:
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if duckdb_type := NW_TO_DUCKDB_DTYPES.get(base_type):
        return duckdb_type
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            return DuckDBPyType(f"ENUM{dtype.categories!r}")
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        tu = dtype.time_unit
        tz = dtype.time_zone
        if not tz:
            return TIME_UNIT_TO_TIMESTAMP[tu]
        if tu != "us":
            msg = f"Only microsecond precision is supported for timezone-aware `Datetime` in DuckDB, got {tu} precision"
            raise ValueError(msg)
        if tz != (rel_tz := deferred_time_zone.time_zone):  # pragma: no cover
            msg = f"Only the connection time zone {rel_tz} is supported, got: {tz}."
            raise ValueError(msg)
        # TODO(unassigned): cover once https://github.com/narwhals-dev/narwhals/issues/2742 addressed
        return duckdb_dtypes.TIMESTAMP_TZ  # pragma: no cover
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        if (tu := dtype.time_unit) != "us":  # pragma: no cover
            msg = f"Only microsecond-precision Duration is supported, got {tu} precision"
            raise NotImplementedError(msg)
        return duckdb_dtypes.INTERVAL
    if isinstance_or_issubclass(dtype, dtypes.List):
        inner = narwhals_to_native_dtype(dtype.inner, version, deferred_time_zone)
        return duckdb.list_type(inner)
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = {
            field.name: narwhals_to_native_dtype(field.dtype, version, deferred_time_zone)
            for field in dtype.fields
        }
        return duckdb.struct_type(fields)
    if isinstance(dtype, dtypes.Array):
        nw_inner: IntoDType = dtype
        while isinstance(nw_inner, dtypes.Array):
            nw_inner = nw_inner.inner
        duckdb_inner = narwhals_to_native_dtype(nw_inner, version, deferred_time_zone)
        duckdb_shape_fmt = "".join(f"[{item}]" for item in dtype.shape)
        return DuckDBPyType(f"{duckdb_inner}{duckdb_shape_fmt}")
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for DuckDB."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def parse_into_expression(into_expression: str | Expression) -> Expression:
    return col(into_expression) if isinstance(into_expression, str) else into_expression


def generate_partition_by_sql(*partition_by: str | Expression) -> str:
    if not partition_by:
        return ""
    by_sql = ", ".join([f"{parse_into_expression(x)}" for x in partition_by])
    return f"partition by {by_sql}"


def join_column_names(*names: str) -> str:
    return ", ".join(str(col(name)) for name in names)


def generate_order_by_sql(
    *order_by: str | Expression, descending: Sequence[bool], nulls_last: Sequence[bool]
) -> str:
    if not order_by:
        return ""
    by_sql = ", ".join(
        f"{parse_into_expression(x)} {DESCENDING_TO_ORDER[_descending]} {NULLS_LAST_TO_NULLS_POS[_nulls_last]}"
        for x, _descending, _nulls_last in zip_strict(order_by, descending, nulls_last)
    )
    return f"order by {by_sql}"


def window_expression(
    expr: Expression,
    partition_by: Sequence[str | Expression] = (),
    order_by: Sequence[str | Expression] = (),
    rows_start: int | None = None,
    rows_end: int | None = None,
    *,
    descending: Sequence[bool] | None = None,
    nulls_last: Sequence[bool] | None = None,
    ignore_nulls: bool = False,
) -> Expression:
    # TODO(unassigned): Replace with `duckdb.WindowExpression` when they release it.
    # https://github.com/duckdb/duckdb/discussions/14725#discussioncomment-11200348
    try:
        from duckdb import SQLExpression
    except ImportError as exc:  # pragma: no cover
        msg = f"DuckDB>=1.3.0 is required for this operation. Found: DuckDB {duckdb.__version__}"
        raise NotImplementedError(msg) from exc
    pb = generate_partition_by_sql(*partition_by)
    descending = descending or [False] * len(order_by)
    nulls_last = nulls_last or [False] * len(order_by)
    ob = generate_order_by_sql(*order_by, descending=descending, nulls_last=nulls_last)

    if rows_start is not None and rows_end is not None:
        rows = f"rows between {-rows_start} preceding and {rows_end} following"
    elif rows_end is not None:
        rows = f"rows between unbounded preceding and {rows_end} following"
    elif rows_start is not None:
        rows = f"rows between {-rows_start} preceding and unbounded following"
    else:
        rows = ""

    func = f"{str(expr).removesuffix(')')} ignore nulls)" if ignore_nulls else str(expr)
    return SQLExpression(f"{func} over ({pb} {ob} {rows})")
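

# For example, window_expression(F("row_number"), order_by=["i"]) produces SQL
# of the form `row_number() over (order by i asc nulls first)`, built as a raw
# SQLExpression until DuckDB exposes a native window-expression API.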
def catch_duckdb_exception(
    exception: Exception, frame: CompliantLazyFrameAny, /
) -> ColumnNotFoundError | Exception:
    if isinstance(exception, duckdb.BinderException) and any(
        msg in str(exception)
        for msg in (
            "not found in FROM clause",
            "this column cannot be referenced before it is defined",
        )
    ):
        return ColumnNotFoundError.from_available_column_names(
            available_columns=frame.columns
        )
    # Just return exception as-is.
    return exception


def function(name: str, *args: Expression) -> Expression:
    if name == "isnull":
        return args[0].isnull()
    return F(name, *args)

94
lib/python3.11/site-packages/narwhals/_duration.py
Normal file
@ -0,0 +1,94 @@
"""Tools for working with the Polars duration string language."""

from __future__ import annotations

import datetime as dt
import re
from typing import TYPE_CHECKING, Literal, cast, get_args

if TYPE_CHECKING:
    from collections.abc import Container, Mapping

    from typing_extensions import TypeAlias

__all__ = ["IntervalUnit"]

IntervalUnit: TypeAlias = Literal["ns", "us", "ms", "s", "m", "h", "d", "mo", "q", "y"]
"""A Polars duration string interval unit.

- 'ns': nanosecond.
- 'us': microsecond.
- 'ms': millisecond.
- 's': second.
- 'm': minute.
- 'h': hour.
- 'd': day.
- 'mo': month.
- 'q': quarter.
- 'y': year.
"""
TimedeltaKwd: TypeAlias = Literal[
    "days", "hours", "minutes", "seconds", "milliseconds", "microseconds"
]

PATTERN_INTERVAL: re.Pattern[str] = re.compile(
    r"^(?P<multiple>-?\d+)(?P<unit>ns|us|ms|mo|m|s|h|d|q|y)\Z"
)
MONTH_MULTIPLES = frozenset([1, 2, 3, 4, 6, 12])
QUARTER_MULTIPLES = frozenset([1, 2, 4])
UNIT_TO_TIMEDELTA: Mapping[IntervalUnit, TimedeltaKwd] = {
    "d": "days",
    "h": "hours",
    "m": "minutes",
    "s": "seconds",
    "ms": "milliseconds",
    "us": "microseconds",
}


class Interval:
    def __init__(self, multiple: int, unit: IntervalUnit, /) -> None:
        self.multiple: int = multiple
        self.unit: IntervalUnit = unit

    def to_timedelta(
        self, *, unsupported: Container[IntervalUnit] = frozenset(("ns", "mo", "q", "y"))
    ) -> dt.timedelta:
        if self.unit in unsupported:  # pragma: no cover
            msg = f"Creating timedelta with {self.unit} unit is not supported."
            raise NotImplementedError(msg)
        kwd = UNIT_TO_TIMEDELTA[self.unit]
        # error: Keywords must be strings (bad mypy)
        return dt.timedelta(**{kwd: self.multiple})  # type: ignore[misc]

    @classmethod
    def parse(cls, every: str) -> Interval:
        multiple, unit = cls._parse(every)
        if unit == "mo" and multiple not in MONTH_MULTIPLES:
            msg = f"Only the following multiples are supported for 'mo' unit: {MONTH_MULTIPLES}.\nGot: {multiple}."
            raise ValueError(msg)
        if unit == "q" and multiple not in QUARTER_MULTIPLES:
            msg = f"Only the following multiples are supported for 'q' unit: {QUARTER_MULTIPLES}.\nGot: {multiple}."
            raise ValueError(msg)
        if unit == "y" and multiple != 1:
            msg = (
                f"Only multiple 1 is currently supported for 'y' unit.\nGot: {multiple}."
            )
            raise ValueError(msg)
        return cls(multiple, unit)
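
    # For example, Interval.parse("3mo") yields Interval(3, "mo"), while
    # Interval.parse("5mo") raises ValueError: 5 is not an allowed month multiple.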
    @classmethod
    def parse_no_constraints(cls, every: str) -> Interval:
        return cls(*cls._parse(every))

    @staticmethod
    def _parse(every: str) -> tuple[int, IntervalUnit]:
        if match := PATTERN_INTERVAL.match(every):
            multiple = int(match["multiple"])
            unit = cast("IntervalUnit", match["unit"])
            return multiple, unit
        msg = (
            f"Invalid `every` string: {every}. Expected string of kind <number><unit>, "
            f"where 'unit' is one of: {get_args(IntervalUnit)}."
        )
        raise ValueError(msg)

42
lib/python3.11/site-packages/narwhals/_enum.py
Normal file
@ -0,0 +1,42 @@
from __future__ import annotations

# ruff: noqa: ARG004
from enum import Enum
from typing import Any


class NoAutoEnum(Enum):
    """Enum base class that prohibits the use of enum.auto() for value assignment.

    This behavior is achieved by overriding the value generation mechanism.

    Examples:
        >>> from enum import auto
        >>> from narwhals._enum import NoAutoEnum
        >>>
        >>> class Colors(NoAutoEnum):
        ...     RED = 1
        ...     GREEN = 2
        >>> Colors.RED
        <Colors.RED: 1>

        >>> class ColorsWithAuto(NoAutoEnum):
        ...     RED = 1
        ...     GREEN = auto()
        Traceback (most recent call last):
        ...
        ValueError: Creating values with `auto()` is not allowed. Please provide a value manually instead.

    Raises:
        ValueError: If `auto()` is attempted to be used for any enum member value.
    """

    @staticmethod
    def _generate_next_value_(
        name: str, start: int, count: int, last_values: list[Any]
    ) -> Any:
        msg = "Creating values with `auto()` is not allowed. Please provide a value manually instead."
        raise ValueError(msg)


__all__ = ["NoAutoEnum"]

60
lib/python3.11/site-packages/narwhals/_exceptions.py
Normal file
@ -0,0 +1,60 @@
from __future__ import annotations

from warnings import warn


def find_stacklevel() -> int:
    """Find the first place in the stack that is not inside narwhals.

    Returns:
        Stacklevel.

    Taken from:
    https://github.com/pandas-dev/pandas/blob/ab89c53f48df67709a533b6a95ce3d911871a0a8/pandas/util/_exceptions.py#L30-L51
    """
    import inspect
    from pathlib import Path

    import narwhals as nw

    pkg_dir = str(Path(nw.__file__).parent)

    # https://stackoverflow.com/questions/17407119/python-inspect-stack-is-slow
    frame = inspect.currentframe()
    n = 0
    try:
        while frame:
            fname = inspect.getfile(frame)
            if fname.startswith(pkg_dir) or (
                (qualname := getattr(frame.f_code, "co_qualname", None))
                # ignore @singledispatch wrappers
                and qualname.startswith("singledispatch.")
            ):
                frame = frame.f_back
                n += 1
            else:  # pragma: no cover
                break
        else:  # pragma: no cover
            pass
    finally:
        # https://docs.python.org/3/library/inspect.html
        # > Though the cycle detector will catch these, destruction of the frames
        # > (and local variables) can be made deterministic by removing the cycle
        # > in a finally clause.
        del frame
    return n


def issue_deprecation_warning(message: str, _version: str) -> None:  # pragma: no cover
    """Issue a deprecation warning.

    Arguments:
        message: The message associated with the warning.
        _version: Narwhals version when the warning was introduced. Just used for internal
            bookkeeping.
    """
    warn(message=message, category=DeprecationWarning, stacklevel=find_stacklevel())


def issue_warning(message: str, category: type[Warning]) -> None:
    warn(message=message, category=category, stacklevel=find_stacklevel())

615
lib/python3.11/site-packages/narwhals/_expression_parsing.py
Normal file
@ -0,0 +1,615 @@
# Utilities for expression parsing
# Useful for backends which don't have any concept of expressions, such
# as pandas or PyArrow.
# ! Any change to this module will trigger the pyspark and pyspark-connect tests in CI
from __future__ import annotations

from enum import Enum, auto
from itertools import chain
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar

from narwhals._utils import is_compliant_expr, zip_strict
from narwhals.dependencies import is_narwhals_series, is_numpy_array, is_numpy_array_1d
from narwhals.exceptions import InvalidOperationError, MultiOutputExpressionError

if TYPE_CHECKING:
    from collections.abc import Sequence

    from typing_extensions import Never, TypeIs

    from narwhals._compliant import CompliantExpr, CompliantFrameT
    from narwhals._compliant.typing import (
        AliasNames,
        CompliantExprAny,
        CompliantFrameAny,
        CompliantNamespaceAny,
        EvalNames,
    )
    from narwhals.expr import Expr
    from narwhals.series import Series
    from narwhals.typing import IntoExpr, NonNestedLiteral, _1DArray

T = TypeVar("T")


def is_expr(obj: Any) -> TypeIs[Expr]:
    """Check whether `obj` is a Narwhals Expr."""
    from narwhals.expr import Expr

    return isinstance(obj, Expr)


def is_series(obj: Any) -> TypeIs[Series[Any]]:
    """Check whether `obj` is a Narwhals Series."""
    from narwhals.series import Series

    return isinstance(obj, Series)


def is_into_expr_eager(obj: Any) -> TypeIs[Expr | Series[Any] | str | _1DArray]:
    from narwhals.expr import Expr
    from narwhals.series import Series

    return isinstance(obj, (Series, Expr, str)) or is_numpy_array_1d(obj)


def combine_evaluate_output_names(
    *exprs: CompliantExpr[CompliantFrameT, Any],
) -> EvalNames[CompliantFrameT]:
    # Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1, expr2)` takes the
    # first name of `expr1`.
    if not is_compliant_expr(exprs[0]):  # pragma: no cover
        msg = f"Safety assertion failed, expected expression, got: {type(exprs[0])}. Please report a bug."
        raise AssertionError(msg)

    def evaluate_output_names(df: CompliantFrameT) -> Sequence[str]:
        return exprs[0]._evaluate_output_names(df)[:1]

    return evaluate_output_names
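

# For example, `nw.sum_horizontal(nw.col("a"), nw.col("b"))` produces a single
# column named "a", per the left-hand rule above.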
def combine_alias_output_names(*exprs: CompliantExprAny) -> AliasNames | None:
|
||||
# Follow left-hand-rule for naming. E.g. `nw.sum_horizontal(expr1.alias(alias), expr2)` takes the
|
||||
# aliasing function of `expr1` and apply it to the first output name of `expr1`.
|
||||
if exprs[0]._alias_output_names is None:
|
||||
return None
|
||||
|
||||
def alias_output_names(names: Sequence[str]) -> Sequence[str]:
|
||||
return exprs[0]._alias_output_names(names)[:1] # type: ignore[misc]
|
||||
|
||||
return alias_output_names
|
||||
|
||||
|
||||
def evaluate_output_names_and_aliases(
|
||||
expr: CompliantExprAny, df: CompliantFrameAny, exclude: Sequence[str]
|
||||
) -> tuple[Sequence[str], Sequence[str]]:
|
||||
output_names = expr._evaluate_output_names(df)
|
||||
aliases = (
|
||||
output_names
|
||||
if expr._alias_output_names is None
|
||||
else expr._alias_output_names(output_names)
|
||||
)
|
||||
if exclude:
|
||||
assert expr._metadata is not None # noqa: S101
|
||||
if expr._metadata.expansion_kind.is_multi_unnamed():
|
||||
output_names, aliases = zip_strict(
|
||||
*[
|
||||
(x, alias)
|
||||
for x, alias in zip_strict(output_names, aliases)
|
||||
if x not in exclude
|
||||
]
|
||||
)
|
||||
return output_names, aliases
|
||||
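The left-hand naming rule implemented by the two `combine_*` helpers above is observable from the public API. A minimal sketch, assuming `narwhals` and `pandas` are installed (column names and values here are illustrative):

import pandas as pd

import narwhals as nw

df = nw.from_native(pd.DataFrame({"a": [1, 2], "b": [3, 4]}))
# The horizontal sum takes the first output name of its left-most input,
# so the resulting column is named "a".
print(df.select(nw.sum_horizontal(nw.col("a"), nw.col("b"))).columns)  # ['a']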
class ExprKind(Enum):
    """Describe which kind of expression we are dealing with."""

    LITERAL = auto()
    """e.g. `nw.lit(1)`"""

    AGGREGATION = auto()
    """Reduces to a single value, not affected by row order, e.g. `nw.col('a').mean()`"""

    ORDERABLE_AGGREGATION = auto()
    """Reduces to a single value, affected by row order, e.g. `nw.col('a').arg_max()`"""

    ELEMENTWISE = auto()
    """Preserves length, can operate without context of the surrounding rows, e.g. `nw.col('a').abs()`."""

    ORDERABLE_WINDOW = auto()
    """Depends on the rows around it and on their order, e.g. `diff`."""

    WINDOW = auto()
    """Depends on the rows around it and possibly their order, e.g. `rank`."""

    FILTRATION = auto()
    """Changes length, not affected by row order, e.g. `drop_nulls`."""

    ORDERABLE_FILTRATION = auto()
    """Changes length, affected by row order, e.g. `tail`."""

    OVER = auto()
    """Results from calling `.over` on an expression."""

    UNKNOWN = auto()
    """Based on the information we have, we can't determine the ExprKind."""

    @property
    def is_scalar_like(self) -> bool:
        return self in {ExprKind.LITERAL, ExprKind.AGGREGATION}

    @property
    def is_orderable_window(self) -> bool:
        return self in {ExprKind.ORDERABLE_WINDOW, ExprKind.ORDERABLE_AGGREGATION}

    @classmethod
    def from_expr(cls, obj: Expr) -> ExprKind:
        meta = obj._metadata
        if meta.is_literal:
            return ExprKind.LITERAL
        if meta.is_scalar_like:
            return ExprKind.AGGREGATION
        if meta.is_elementwise:
            return ExprKind.ELEMENTWISE
        return ExprKind.UNKNOWN

    @classmethod
    def from_into_expr(
        cls, obj: IntoExpr | NonNestedLiteral | _1DArray, *, str_as_lit: bool
    ) -> ExprKind:
        if is_expr(obj):
            return cls.from_expr(obj)
        if (
            is_narwhals_series(obj)
            or is_numpy_array(obj)
            or (isinstance(obj, str) and not str_as_lit)
        ):
            return ExprKind.ELEMENTWISE
        return ExprKind.LITERAL
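A quick sketch of how `ExprKind.from_expr` classifies expressions, following the branches above; assumes `narwhals` is importable (the expected values in comments are read off the code, not measured output):

import narwhals as nw
from narwhals._expression_parsing import ExprKind

print(ExprKind.from_expr(nw.lit(1)))           # ExprKind.LITERAL
print(ExprKind.from_expr(nw.col("a").mean()))  # ExprKind.AGGREGATION
print(ExprKind.from_expr(nw.col("a").abs()))   # ExprKind.ELEMENTWISE
# `diff` is neither literal, scalar-like, nor elementwise, so from the
# last-node metadata alone it falls through to UNKNOWN.
print(ExprKind.from_expr(nw.col("a").diff()))  # ExprKind.UNKNOWN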
def is_scalar_like(
    obj: ExprKind,
) -> TypeIs[Literal[ExprKind.LITERAL, ExprKind.AGGREGATION]]:
    return obj.is_scalar_like


class ExpansionKind(Enum):
    """Describe what kind of expansion the expression performs."""

    SINGLE = auto()
    """e.g. `nw.col('a')`, `nw.sum_horizontal(nw.all())`"""

    MULTI_NAMED = auto()
    """e.g. `nw.col('a', 'b')`"""

    MULTI_UNNAMED = auto()
    """e.g. `nw.all()`, `nw.nth(0, 1)`"""

    def is_multi_unnamed(self) -> bool:
        return self is ExpansionKind.MULTI_UNNAMED

    def is_multi_output(self) -> bool:
        return self in {ExpansionKind.MULTI_NAMED, ExpansionKind.MULTI_UNNAMED}

    def __and__(self, other: ExpansionKind) -> Literal[ExpansionKind.MULTI_UNNAMED]:
        if self is ExpansionKind.MULTI_UNNAMED and other is ExpansionKind.MULTI_UNNAMED:
            # e.g. nw.selectors.all() - nw.selectors.numeric().
            return ExpansionKind.MULTI_UNNAMED
        # Don't attempt anything more complex, keep it simple and raise in the face of ambiguity.
        msg = f"Unsupported ExpansionKind combination, got {self} and {other}, please report a bug."  # pragma: no cover
        raise AssertionError(msg)  # pragma: no cover


class ExprMetadata:
    """Expression metadata.

    Parameters:
        expansion_kind: What kind of expansion the expression performs.
        has_windows: Whether it already contains window functions.
        is_elementwise: Whether it can operate row-by-row without context
            of the other rows around it.
        is_literal: Whether it is just a literal wrapped in an expression.
        is_scalar_like: Whether it is a literal or an aggregation.
        last_node: The ExprKind of the last node.
        n_orderable_ops: The number of order-dependent operations. In the
            lazy case, this number must be `0` by the time the expression
            is evaluated.
        preserves_length: Whether the expression preserves the input length.
    """

    __slots__ = (
        "expansion_kind",
        "has_windows",
        "is_elementwise",
        "is_literal",
        "is_scalar_like",
        "last_node",
        "n_orderable_ops",
        "preserves_length",
    )

    def __init__(
        self,
        expansion_kind: ExpansionKind,
        last_node: ExprKind,
        *,
        has_windows: bool = False,
        n_orderable_ops: int = 0,
        preserves_length: bool = True,
        is_elementwise: bool = True,
        is_scalar_like: bool = False,
        is_literal: bool = False,
    ) -> None:
        if is_literal:
            assert is_scalar_like  # noqa: S101  # debug assertion
        if is_elementwise:
            assert preserves_length  # noqa: S101  # debug assertion
        self.expansion_kind: ExpansionKind = expansion_kind
        self.last_node: ExprKind = last_node
        self.has_windows: bool = has_windows
        self.n_orderable_ops: int = n_orderable_ops
        self.is_elementwise: bool = is_elementwise
        self.preserves_length: bool = preserves_length
        self.is_scalar_like: bool = is_scalar_like
        self.is_literal: bool = is_literal

    def __init_subclass__(cls, /, *args: Any, **kwds: Any) -> Never:  # pragma: no cover
        msg = f"Cannot subclass {cls.__name__!r}"
        raise TypeError(msg)

    def __repr__(self) -> str:  # pragma: no cover
        return (
            f"ExprMetadata(\n"
            f"  expansion_kind: {self.expansion_kind},\n"
            f"  last_node: {self.last_node},\n"
            f"  has_windows: {self.has_windows},\n"
            f"  n_orderable_ops: {self.n_orderable_ops},\n"
            f"  is_elementwise: {self.is_elementwise},\n"
            f"  preserves_length: {self.preserves_length},\n"
            f"  is_scalar_like: {self.is_scalar_like},\n"
            f"  is_literal: {self.is_literal},\n"
            ")"
        )

    @property
    def is_filtration(self) -> bool:
        return not self.preserves_length and not self.is_scalar_like

    def with_aggregation(self) -> ExprMetadata:
        if self.is_scalar_like:
            msg = "Can't apply aggregations to scalar-like expressions."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.AGGREGATION,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=False,
        )

    def with_orderable_aggregation(self) -> ExprMetadata:
        # Deprecated, used only in stable.v1.
        if self.is_scalar_like:  # pragma: no cover
            msg = "Can't apply aggregations to scalar-like expressions."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_AGGREGATION,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=True,
            is_literal=False,
        )

    def with_elementwise_op(self) -> ExprMetadata:
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ELEMENTWISE,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=self.preserves_length,
            is_elementwise=self.is_elementwise,
            is_scalar_like=self.is_scalar_like,
            is_literal=self.is_literal,
        )

    def with_window(self) -> ExprMetadata:
        # Window function which may (but doesn't have to) be used with `over(order_by=...)`.
        if self.is_scalar_like:
            msg = "Can't apply window (e.g. `rank`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.WINDOW,
            has_windows=self.has_windows,
            # The function isn't order-dependent (but users can still use `order_by` if they wish!),
            # so we don't increment `n_orderable_ops`.
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=self.preserves_length,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_orderable_window(self) -> ExprMetadata:
        # Window function which must be used with `over(order_by=...)`.
        if self.is_scalar_like:
            msg = "Can't apply orderable window (e.g. `diff`, `shift`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_WINDOW,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=self.preserves_length,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_ordered_over(self) -> ExprMetadata:
        if self.has_windows:
            msg = "Cannot nest `over` statements."
            raise InvalidOperationError(msg)
        if self.is_elementwise or self.is_filtration:
            msg = (
                "Cannot use `over` on expressions which are elementwise\n"
                "(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
            )
            raise InvalidOperationError(msg)
        n_orderable_ops = self.n_orderable_ops
        if not n_orderable_ops and self.last_node is not ExprKind.WINDOW:
            msg = (
                "Cannot use `order_by` in `over` on expression which isn't orderable.\n"
                "If your expression is orderable, then make sure that `over(order_by=...)`\n"
                "comes immediately after the order-dependent expression.\n\n"
                "Hint: instead of\n"
                "  - `(nw.col('price').diff() + 1).over(order_by='date')`\n"
                "write:\n"
                "  + `nw.col('price').diff().over(order_by='date') + 1`\n"
            )
            raise InvalidOperationError(msg)
        if self.last_node.is_orderable_window:
            n_orderable_ops -= 1
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.OVER,
            has_windows=True,
            n_orderable_ops=n_orderable_ops,
            preserves_length=True,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )
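The hint in the error message above is easiest to see as runnable code. A minimal sketch, assuming `narwhals` and `pandas` are installed; the column names are illustrative:

import pandas as pd

import narwhals as nw

df = nw.from_native(pd.DataFrame({"price": [10, 11, 13], "date": [1, 2, 3]}))
# OK: `over(order_by=...)` immediately follows the order-dependent `diff`,
# so `n_orderable_ops` is decremented back to 0.
df.select(nw.col("price").diff().over(order_by="date") + 1)
# Discouraged form from the hint: the `diff` is wrapped in an elementwise
# `+ 1` first, which leaves an orderable op dangling; per the class
# docstring, lazy backends reject that at evaluation time.
# df.select((nw.col("price").diff() + 1).over(order_by="date"))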
    def with_partitioned_over(self) -> ExprMetadata:
        if self.has_windows:
            msg = "Cannot nest `over` statements."
            raise InvalidOperationError(msg)
        if self.is_elementwise or self.is_filtration:
            msg = (
                "Cannot use `over` on expressions which are elementwise\n"
                "(e.g. `abs`) or which change length (e.g. `drop_nulls`)."
            )
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.OVER,
            has_windows=True,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=True,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_filtration(self) -> ExprMetadata:
        if self.is_scalar_like:
            msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.FILTRATION,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    def with_orderable_filtration(self) -> ExprMetadata:
        if self.is_scalar_like:
            msg = "Can't apply filtration (e.g. `drop_nulls`) to scalar-like expression."
            raise InvalidOperationError(msg)
        return ExprMetadata(
            self.expansion_kind,
            ExprKind.ORDERABLE_FILTRATION,
            has_windows=self.has_windows,
            n_orderable_ops=self.n_orderable_ops + 1,
            preserves_length=False,
            is_elementwise=False,
            is_scalar_like=False,
            is_literal=False,
        )

    @staticmethod
    def aggregation() -> ExprMetadata:
        return ExprMetadata(
            ExpansionKind.SINGLE,
            ExprKind.AGGREGATION,
            is_elementwise=False,
            preserves_length=False,
            is_scalar_like=True,
        )

    @staticmethod
    def literal() -> ExprMetadata:
        return ExprMetadata(
            ExpansionKind.SINGLE,
            ExprKind.LITERAL,
            is_elementwise=False,
            preserves_length=False,
            is_literal=True,
            is_scalar_like=True,
        )

    @staticmethod
    def selector_single() -> ExprMetadata:
        # e.g. `nw.col('a')`, `nw.nth(0)`
        return ExprMetadata(ExpansionKind.SINGLE, ExprKind.ELEMENTWISE)

    @staticmethod
    def selector_multi_named() -> ExprMetadata:
        # e.g. `nw.col('a', 'b')`
        return ExprMetadata(ExpansionKind.MULTI_NAMED, ExprKind.ELEMENTWISE)

    @staticmethod
    def selector_multi_unnamed() -> ExprMetadata:
        # e.g. `nw.all()`
        return ExprMetadata(ExpansionKind.MULTI_UNNAMED, ExprKind.ELEMENTWISE)

    @classmethod
    def from_binary_op(cls, lhs: Expr, rhs: IntoExpr, /) -> ExprMetadata:
        # We may be able to allow multi-output rhs in the future:
        # https://github.com/narwhals-dev/narwhals/issues/2244.
        return combine_metadata(
            lhs, rhs, str_as_lit=True, allow_multi_output=False, to_single_output=False
        )

    @classmethod
    def from_horizontal_op(cls, *exprs: IntoExpr) -> ExprMetadata:
        return combine_metadata(
            *exprs, str_as_lit=False, allow_multi_output=True, to_single_output=True
        )


def combine_metadata(
    *args: IntoExpr | object | None,
    str_as_lit: bool,
    allow_multi_output: bool,
    to_single_output: bool,
) -> ExprMetadata:
    """Combine metadata from `args`.

    Arguments:
        args: Arguments, maybe expressions, literals, or Series.
        str_as_lit: Whether to interpret strings as literals or as column names.
        allow_multi_output: Whether to allow multi-output inputs.
        to_single_output: Whether the result is always single-output, regardless
            of the inputs (e.g. `nw.sum_horizontal`).
    """
    n_filtrations = 0
    result_expansion_kind = ExpansionKind.SINGLE
    result_has_windows = False
    result_n_orderable_ops = 0
    # result preserves length if at least one input does
    result_preserves_length = False
    # result is elementwise if all inputs are elementwise
    result_is_elementwise = True
    # result is scalar-like if all inputs are scalar-like
    result_is_scalar_like = True
    # result is literal if all inputs are literal
    result_is_literal = True

    for i, arg in enumerate(args):
        if (isinstance(arg, str) and not str_as_lit) or is_series(arg):
            result_preserves_length = True
            result_is_scalar_like = False
            result_is_literal = False
        elif is_expr(arg):
            metadata = arg._metadata
            if metadata.expansion_kind.is_multi_output():
                expansion_kind = metadata.expansion_kind
                if i > 0 and not allow_multi_output:
                    # The left-most argument is always allowed to be multi-output.
                    msg = (
                        "Multi-output expressions (e.g. nw.col('a', 'b'), nw.all()) "
                        "are not supported in this context."
                    )
                    raise MultiOutputExpressionError(msg)
                if not to_single_output:
                    result_expansion_kind = (
                        result_expansion_kind & expansion_kind
                        if i > 0
                        else expansion_kind
                    )

            result_has_windows |= metadata.has_windows
            result_n_orderable_ops += metadata.n_orderable_ops
            result_preserves_length |= metadata.preserves_length
            result_is_elementwise &= metadata.is_elementwise
            result_is_scalar_like &= metadata.is_scalar_like
            result_is_literal &= metadata.is_literal
            n_filtrations += int(metadata.is_filtration)

    if n_filtrations > 1:
        msg = "Length-changing expressions can only be used in isolation, or followed by an aggregation"
        raise InvalidOperationError(msg)
    if result_preserves_length and n_filtrations:
        msg = "Cannot combine length-changing expressions with length-preserving ones or aggregations"
        raise InvalidOperationError(msg)

    return ExprMetadata(
        result_expansion_kind,
        # n-ary operations align positionally, and so the last node is elementwise.
        ExprKind.ELEMENTWISE,
        has_windows=result_has_windows,
        n_orderable_ops=result_n_orderable_ops,
        preserves_length=result_preserves_length,
        is_elementwise=result_is_elementwise,
        is_scalar_like=result_is_scalar_like,
        is_literal=result_is_literal,
    )
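The two filtration rules at the end of `combine_metadata` can be exercised from the public API. A sketch, assuming `narwhals` and `pandas` are installed:

import pandas as pd

import narwhals as nw

df = nw.from_native(pd.DataFrame({"a": [1.0, None, 3.0]}))
# OK: a length-changing expression used in isolation.
df.select(nw.col("a").drop_nulls())
# InvalidOperationError: combines a length-changing expression with a
# length-preserving one, so the two outputs cannot be aligned.
# df.select(nw.col("a").drop_nulls() + nw.col("a"))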
def check_expressions_preserve_length(*args: IntoExpr, function_name: str) -> None:
    # Raise if any argument in `args` isn't length-preserving.
    # For Series input, we don't raise (yet), we let such checks happen later,
    # as this function works lazily and so can't evaluate lengths.
    from narwhals.series import Series

    if not all(
        (is_expr(x) and x._metadata.preserves_length) or isinstance(x, (str, Series))
        for x in args
    ):
        msg = f"Expressions which aggregate or change length cannot be passed to '{function_name}'."
        raise InvalidOperationError(msg)


def all_exprs_are_scalar_like(*args: IntoExpr, **kwargs: IntoExpr) -> bool:
    # Return whether every argument in `args` and `kwargs` is an aggregation or literal.
    # For Series input, we don't check (yet), we let such checks happen later,
    # as this function works lazily and so can't evaluate lengths.
    exprs = chain(args, kwargs.values())
    return all(is_expr(x) and x._metadata.is_scalar_like for x in exprs)


def apply_n_ary_operation(
    plx: CompliantNamespaceAny,
    n_ary_function: Callable[..., CompliantExprAny],
    *comparands: IntoExpr | NonNestedLiteral | _1DArray,
    str_as_lit: bool,
) -> CompliantExprAny:
    parse = plx.parse_into_expr
    compliant_exprs = (parse(into, str_as_lit=str_as_lit) for into in comparands)
    kinds = [
        ExprKind.from_into_expr(comparand, str_as_lit=str_as_lit)
        for comparand in comparands
    ]

    broadcast = any(not kind.is_scalar_like for kind in kinds)
    compliant_exprs = (
        compliant_expr.broadcast(kind)
        if broadcast and is_compliant_expr(compliant_expr) and is_scalar_like(kind)
        else compliant_expr
        for compliant_expr, kind in zip_strict(compliant_exprs, kinds)
    )
    return n_ary_function(*compliant_exprs)
432
lib/python3.11/site-packages/narwhals/_ibis/dataframe.py
Normal file
@ -0,0 +1,432 @@
from __future__ import annotations

import operator
from io import BytesIO
from typing import TYPE_CHECKING, Any, Literal, cast

import ibis
import ibis.expr.types as ir

from narwhals._ibis.utils import evaluate_exprs, native_to_narwhals_dtype
from narwhals._sql.dataframe import SQLLazyFrame
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    Version,
    not_implemented,
    parse_columns_to_drop,
    zip_strict,
)
from narwhals.exceptions import ColumnNotFoundError, InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence
    from pathlib import Path
    from types import ModuleType

    import pandas as pd
    import pyarrow as pa
    from ibis.expr.operations import Binary
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._ibis.expr import IbisExpr
    from narwhals._ibis.group_by import IbisGroupBy
    from narwhals._ibis.namespace import IbisNamespace
    from narwhals._ibis.series import IbisInterchangeSeries
    from narwhals._typing import _EagerAllowedImpl
    from narwhals._utils import _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.stable.v1 import DataFrame as DataFrameV1
    from narwhals.typing import AsofJoinStrategy, JoinStrategy, LazyUniqueKeepStrategy

JoinPredicates: TypeAlias = "Sequence[ir.BooleanColumn] | Sequence[str]"


class IbisLazyFrame(
    SQLLazyFrame["IbisExpr", "ir.Table", "LazyFrame[ir.Table] | DataFrameV1[ir.Table]"],
    ValidateBackendVersion,
):
    _implementation = Implementation.IBIS

    def __init__(
        self, df: ir.Table, *, version: Version, validate_backend_version: bool = False
    ) -> None:
        self._native_frame: ir.Table = df
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:
            self._validate_backend_version()

    @staticmethod
    def _is_native(obj: ir.Table | Any) -> TypeIs[ir.Table]:
        return isinstance(obj, ir.Table)

    @classmethod
    def from_native(cls, data: ir.Table, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    def to_narwhals(self) -> LazyFrame[ir.Table] | DataFrameV1[ir.Table]:
        if self._version is Version.V1:
            from narwhals.stable.v1 import DataFrame

            return DataFrame(self, level="interchange")
        return self._version.lazyframe(self, level="lazy")

    def __narwhals_dataframe__(self) -> Self:  # pragma: no cover
        # Keep around for backcompat.
        if self._version is not Version.V1:
            msg = "__narwhals_dataframe__ is not implemented for IbisLazyFrame"
            raise AttributeError(msg)
        return self

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        return ibis

    def __narwhals_namespace__(self) -> IbisNamespace:
        from narwhals._ibis.namespace import IbisNamespace

        return IbisNamespace(version=self._version)

    def get_column(self, name: str) -> IbisInterchangeSeries:
        from narwhals._ibis.series import IbisInterchangeSeries

        return IbisInterchangeSeries(self.native.select(name), version=self._version)

    def _iter_columns(self) -> Iterator[ir.Expr]:
        for name in self.columns:
            yield self.native[name]

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if backend is None or backend is Implementation.PYARROW:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self.native.to_pyarrow(),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.to_pandas(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                self.native.to_polars(),
                validate_backend_version=True,
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover
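`collect` is reached through the public API as below. A sketch, assuming `ibis` and `pyarrow` are installed; the table contents are illustrative:

import ibis

import narwhals as nw

t = ibis.memtable({"a": [1, 2, 3]})
lf = nw.from_native(t)  # wraps an IbisLazyFrame
# No backend requested, so the PyArrow branch above handles materialisation.
print(lf.select(nw.col("a") + 1).collect().to_native())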
    def head(self, n: int) -> Self:
        return self._with_native(self.native.head(n))

    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: IbisExpr) -> Self:
        selection = [
            cast("ir.Scalar", val.name(name))
            for name, val in evaluate_exprs(self, *exprs)
        ]
        return self._with_native(self.native.aggregate(selection))

    def select(self, *exprs: IbisExpr) -> Self:
        selection = [val.name(name) for name, val in evaluate_exprs(self, *exprs)]
        if not selection:
            msg = "At least one expression must be provided to `select` with the Ibis backend."
            raise ValueError(msg)

        t = self.native.select(*selection)
        return self._with_native(t)

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
        selection = (col for col in self.columns if col not in columns_to_drop)
        return self._with_native(self.native.select(*selection))

    def lazy(self, backend: None = None, **_: None) -> Self:
        # The `backend` argument has no effect, but we keep it here for
        # backwards compatibility because in `narwhals.stable.v1`
        # the function `.from_native()` will return a DataFrame for Ibis.

        if backend is not None:  # pragma: no cover
            msg = "`backend` argument is not supported for Ibis"
            raise ValueError(msg)
        return self

    def with_columns(self, *exprs: IbisExpr) -> Self:
        new_columns_map = dict(evaluate_exprs(self, *exprs))
        return self._with_native(self.native.mutate(**new_columns_map))

    def filter(self, predicate: IbisExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        mask = cast("ir.BooleanValue", predicate(self)[0])
        return self._with_native(self.native.filter(mask))

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            # Note: prefer `self._cached_schema` over `functools.cached_property`
            # due to Python 3.13 failures.
            self._cached_schema = {
                name: native_to_narwhals_dtype(dtype, self._version)
                for name, dtype in self.native.schema().fields.items()
            }
        return self._cached_schema

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else list(self.native.columns)
            )
        return self._cached_columns

    def to_pandas(self) -> pd.DataFrame:
        # Only used if version is V1; keep around for backcompat.
        return self.native.to_pandas()

    def to_arrow(self) -> pa.Table:
        # Only used if version is V1; keep around for backcompat.
        return self.native.to_pyarrow()

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    def _with_native(self, df: ir.Table) -> Self:
        return self.__class__(df, version=self._version)

    def group_by(
        self, keys: Sequence[str] | Sequence[IbisExpr], *, drop_null_keys: bool
    ) -> IbisGroupBy:
        from narwhals._ibis.group_by import IbisGroupBy

        return IbisGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def rename(self, mapping: Mapping[str, str]) -> Self:
        def _rename(col: str) -> str:
            return mapping.get(col, col)

        return self._with_native(self.native.rename(_rename))

    @staticmethod
    def _join_drop_duplicate_columns(df: ir.Table, columns: Iterable[str], /) -> ir.Table:
        """Ibis adds a suffix to the right table's column even when it matches the left one during a join."""
        duplicates = set(df.columns).intersection(columns)
        return df.drop(*duplicates) if duplicates else df

    def join(
        self,
        other: Self,
        *,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        how_native = "outer" if how == "full" else how
        rname = "{name}" + suffix
        if other == self:
            # Ibis does not support self-references unless created as a view
            other = self._with_native(other.native.view())
        if how_native == "cross":
            joined = self.native.join(other.native, how=how_native, rname=rname)
            return self._with_native(joined)
        # help mypy
        assert left_on is not None  # noqa: S101
        assert right_on is not None  # noqa: S101
        predicates = self._convert_predicates(other, left_on, right_on)
        joined = self.native.join(other.native, predicates, how=how_native, rname=rname)
        if how_native == "left":
            right_names = (n + suffix for n in right_on)
            joined = self._join_drop_duplicate_columns(joined, right_names)
            it = (cast("Binary", p.op()) for p in predicates if not isinstance(p, str))
            to_drop = []
            for pred in it:
                right = pred.right.name
                # Mirrors how polars works.
                if right not in self.columns and pred.left.name != right:
                    to_drop.append(right)
            if to_drop:
                joined = joined.drop(*to_drop)
        return self._with_native(joined)
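A sketch of the left-join column handling above, assuming `ibis` is installed; the names and values are illustrative:

import ibis

import narwhals as nw

left = nw.from_native(ibis.memtable({"id": [1, 2], "val": [10, 20]}))
right = nw.from_native(ibis.memtable({"id": [1, 2], "val": [30, 40]}))
# The join key is deduplicated; the overlapping non-key column on the
# right gets the suffix, mirroring polars.
out = left.join(right, on="id", how="left", suffix="_right")
print(out.collect().columns)  # ['id', 'val', 'val_right']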
    def join_asof(
        self,
        other: Self,
        *,
        left_on: str,
        right_on: str,
        by_left: Sequence[str] | None,
        by_right: Sequence[str] | None,
        strategy: AsofJoinStrategy,
        suffix: str,
    ) -> Self:
        rname = "{name}" + suffix
        strategy_op = {"backward": operator.ge, "forward": operator.le}
        predicates: JoinPredicates = []
        if op := strategy_op.get(strategy):
            on: ir.BooleanColumn = op(self.native[left_on], other.native[right_on])
        else:
            msg = "Only `backward` and `forward` strategies are currently supported for Ibis"
            raise NotImplementedError(msg)
        if by_left is not None and by_right is not None:
            predicates = self._convert_predicates(other, by_left, by_right)
        joined = self.native.asof_join(other.native, on, predicates, rname=rname)
        joined = self._join_drop_duplicate_columns(joined, [right_on + suffix])
        if by_right is not None:
            right_names = (n + suffix for n in by_right)
            joined = self._join_drop_duplicate_columns(joined, right_names)
        return self._with_native(joined)

    def _convert_predicates(
        self, other: Self, left_on: Sequence[str], right_on: Sequence[str]
    ) -> JoinPredicates:
        if left_on == right_on:
            return left_on
        return [
            cast("ir.BooleanColumn", (self.native[left] == other.native[right]))
            for left, right in zip_strict(left_on, right_on)
        ]

    def collect_schema(self) -> dict[str, DType]:
        return {
            name: native_to_narwhals_dtype(dtype, self._version)
            for name, dtype in self.native.schema().fields.items()
        }

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset_ := subset if keep == "any" else (subset or self.columns):
            # Sanitise input
            if any(x not in self.columns for x in subset_):
                msg = f"Columns {set(subset_).difference(self.columns)} not found in {self.columns}."
                raise ColumnNotFoundError(msg)

            mapped_keep: dict[str, Literal["first"] | None] = {
                "any": "first",
                "none": None,
            }
            to_keep = mapped_keep[keep]
            return self._with_native(self.native.distinct(on=subset_, keep=to_keep))
        return self._with_native(self.native.distinct(on=subset))
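The `mapped_keep` table above translates narwhals semantics onto `ibis.Table.distinct`: `keep="any"` retains one row per key, while `keep="none"` drops every duplicated row entirely. A sketch, assuming `ibis` is installed:

import ibis

import narwhals as nw

lf = nw.from_native(ibis.memtable({"a": [1, 1, 2]}))
# keep="any": one row per distinct key -> 2 rows.
print(lf.unique(subset=["a"], keep="any").collect().shape)
# keep="none": duplicated keys are dropped entirely -> 1 row.
print(lf.unique(subset=["a"], keep="none").collect().shape)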
    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            descending = [descending for _ in range(len(by))]

        sort_cols: list[Any] = []

        for i in range(len(by)):
            direction_fn = ibis.desc if descending[i] else ibis.asc
            col = direction_fn(by[i], nulls_first=not nulls_last)
            sort_cols.append(col)

        return self._with_native(self.native.order_by(*sort_cols))

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        if isinstance(reverse, bool):
            reverse = [reverse] * len(list(by))
        sort_cols = []

        for is_reverse, by_col in zip_strict(reverse, by):
            direction_fn = ibis.asc if is_reverse else ibis.desc
            col = direction_fn(by_col, nulls_first=False)
            sort_cols.append(cast("ir.Column", col))

        return self._with_native(self.native.order_by(*sort_cols).head(k))

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        subset_ = subset if subset is not None else self.columns
        return self._with_native(self.native.drop_null(subset_))

    def explode(self, columns: Sequence[str]) -> Self:
        dtypes = self._version.dtypes
        schema = self.collect_schema()
        for col in columns:
            dtype = schema[col]

            if dtype != dtypes.List:
                msg = (
                    f"`explode` operation not supported for dtype `{dtype}`, "
                    "expected List type"
                )
                raise InvalidOperationError(msg)

        if len(columns) != 1:
            msg = (
                "Exploding on multiple columns is not supported with Ibis backend since "
                "we cannot guarantee that the exploded columns have matching element counts."
            )
            raise NotImplementedError(msg)

        return self._with_native(self.native.unnest(columns[0], keep_empty=True))

    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        import ibis.selectors as s

        index_: Sequence[str] = [] if index is None else index
        on_: Sequence[str] = (
            [c for c in self.columns if c not in index_] if on is None else on
        )

        # Discard columns not in the index
        final_columns = list(dict.fromkeys([*index_, variable_name, value_name]))

        unpivoted = self.native.pivot_longer(
            s.cols(*on_), names_to=variable_name, values_to=value_name
        )
        return self._with_native(unpivoted.select(*final_columns))

    def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
        to_select = [
            ibis.row_number().over(ibis.window(order_by=order_by)).name(name),
            ibis.selectors.all(),
        ]
        return self._with_native(self.native.select(*to_select))

    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        if isinstance(file, BytesIO):  # pragma: no cover
            msg = "Writing to BytesIO is not supported for Ibis backend."
            raise NotImplementedError(msg)
        self.native.to_parquet(file)

    gather_every = not_implemented.deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    tail = not_implemented.deprecated(
        "`LazyFrame.tail` is deprecated and will be removed in a future version."
    )

    # Intentionally not implemented, as Ibis does its own expression rewriting.
    _evaluate_window_expr = not_implemented()
347
lib/python3.11/site-packages/narwhals/_ibis/expr.py
Normal file
@ -0,0 +1,347 @@
from __future__ import annotations

import operator
from functools import partial
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar, cast

import ibis

from narwhals._ibis.expr_dt import IbisExprDateTimeNamespace
from narwhals._ibis.expr_list import IbisExprListNamespace
from narwhals._ibis.expr_str import IbisExprStringNamespace
from narwhals._ibis.expr_struct import IbisExprStructNamespace
from narwhals._ibis.utils import is_floating, lit, narwhals_to_native_dtype
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, not_implemented, zip_strict

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence

    import ibis.expr.types as ir
    from typing_extensions import Self

    from narwhals._compliant import WindowInputs
    from narwhals._compliant.typing import (
        AliasNames,
        EvalNames,
        EvalSeries,
        WindowFunction,
    )
    from narwhals._expression_parsing import ExprKind, ExprMetadata
    from narwhals._ibis.dataframe import IbisLazyFrame
    from narwhals._ibis.namespace import IbisNamespace
    from narwhals._utils import _LimitedContext
    from narwhals.typing import IntoDType, RankMethod, RollingInterpolationMethod

ExprT = TypeVar("ExprT", bound=ir.Value)
IbisWindowFunction = WindowFunction[IbisLazyFrame, ir.Value]
IbisWindowInputs = WindowInputs[ir.Value]


class IbisExpr(SQLExpr["IbisLazyFrame", "ir.Value"]):
    _implementation = Implementation.IBIS

    def __init__(
        self,
        call: EvalSeries[IbisLazyFrame, ir.Value],
        window_function: IbisWindowFunction | None = None,
        *,
        evaluate_output_names: EvalNames[IbisLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        implementation: Implementation = Implementation.IBIS,
    ) -> None:
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._metadata: ExprMetadata | None = None
        self._window_function: IbisWindowFunction | None = window_function

    @property
    def window_function(self) -> IbisWindowFunction:
        def default_window_func(
            df: IbisLazyFrame, window_inputs: IbisWindowInputs
        ) -> Sequence[ir.Value]:
            return [
                expr.over(
                    ibis.window(
                        group_by=window_inputs.partition_by,
                        order_by=self._sort(*window_inputs.order_by),
                    )
                )
                for expr in self(df)
            ]

        return self._window_function or default_window_func

    def _window_expression(
        self,
        expr: ir.Value,
        partition_by: Sequence[str | ir.Value] = (),
        order_by: Sequence[str | ir.Column] = (),
        rows_start: int | None = None,
        rows_end: int | None = None,
        *,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> ir.Value:
        if rows_start is not None and rows_end is not None:
            rows_between = {"preceding": -rows_start, "following": rows_end}
        elif rows_end is not None:
            rows_between = {"following": rows_end}
        elif rows_start is not None:  # pragma: no cover
            rows_between = {"preceding": -rows_start}
        else:
            rows_between = {}
        window = ibis.window(
            group_by=partition_by,
            order_by=self._sort(*order_by, descending=descending, nulls_last=nulls_last),
            **rows_between,
        )
        return expr.over(window)

    def __narwhals_namespace__(self) -> IbisNamespace:  # pragma: no cover
        from narwhals._ibis.namespace import IbisNamespace

        return IbisNamespace(version=self._version)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        # Ibis does its own broadcasting.
        return self

    def _sort(
        self,
        *cols: ir.Column | str,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Iterator[ir.Column]:
        descending = descending or [False] * len(cols)
        nulls_last = nulls_last or [False] * len(cols)
        mapping = {
            (False, False): partial(ibis.asc, nulls_first=True),
            (False, True): partial(ibis.asc, nulls_first=False),
            (True, False): partial(ibis.desc, nulls_first=True),
            (True, True): partial(ibis.desc, nulls_first=False),
        }
        yield from (
            cast("ir.Column", mapping[(_desc, _nulls_last)](col))
            for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last)
        )

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[IbisLazyFrame],
        /,
        *,
        context: _LimitedContext,
    ) -> Self:
        def func(df: IbisLazyFrame) -> Sequence[ir.Column]:
            return [df.native[name] for name in evaluate_column_names(df)]

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: IbisLazyFrame) -> Sequence[ir.Column]:
            return [df.native[i] for i in column_indices]

        return cls(
            func,
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
        )

    def _with_binary(self, op: Callable[..., ir.Value], other: Self | Any) -> Self:
        return self._with_callable(op, other=other)

    def _with_elementwise(
        self, op: Callable[..., ir.Value], /, **expressifiable_args: Self | Any
    ) -> Self:
        return self._with_callable(op, **expressifiable_args)

    @classmethod
    def _alias_native(cls, expr: ExprT, name: str, /) -> ExprT:
        return cast("ExprT", expr.name(name))

    def __invert__(self) -> Self:
        invert = cast("Callable[..., ir.Value]", operator.invert)
        return self._with_callable(invert)

    def all(self) -> Self:
        return self._with_callable(lambda expr: expr.all().fill_null(lit(True)))

    def any(self) -> Self:
        return self._with_callable(lambda expr: expr.any().fill_null(lit(False)))

    def quantile(
        self, quantile: float, interpolation: RollingInterpolationMethod
    ) -> Self:
        if interpolation != "linear":
            msg = "Only linear interpolation methods are supported for Ibis quantile."
            raise NotImplementedError(msg)
        return self._with_callable(lambda expr: expr.quantile(quantile))

    def clip(self, lower_bound: Any, upper_bound: Any) -> Self:
        def _clip(
            expr: ir.NumericValue, lower: Any | None = None, upper: Any | None = None
        ) -> ir.NumericValue:
            return expr.clip(lower=lower, upper=upper)

        if lower_bound is None:
            return self._with_callable(_clip, upper=upper_bound)
        if upper_bound is None:
            return self._with_callable(_clip, lower=lower_bound)
        return self._with_callable(_clip, lower=lower_bound, upper=upper_bound)

    def n_unique(self) -> Self:
        return self._with_callable(
            lambda expr: expr.nunique() + expr.isnull().any().cast("int8")
        )
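Note the null adjustment in `n_unique` above: polars counts null as a distinct value, while Ibis' `nunique` ignores nulls, so the presence of any null adds one to the count. A sketch, assuming `ibis` is installed:

import ibis

import narwhals as nw

lf = nw.from_native(ibis.memtable({"a": [1, 1, None]}))
# Two distinct values once null is counted: {1, null}.
print(lf.select(nw.col("a").n_unique()).collect().to_native())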
    def len(self) -> Self:
        def func(df: IbisLazyFrame) -> Sequence[ir.IntegerScalar]:
            return [df.native.count() for _ in self._evaluate_output_names(df)]

        return self.__class__(
            func,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )

    def std(self, ddof: int) -> Self:
        def _std(expr: ir.NumericColumn, ddof: int) -> ir.Value:
            if ddof == 0:
                return expr.std(how="pop")
            if ddof == 1:
                return expr.std(how="sample")
            n_samples = expr.count()
            std_pop = expr.std(how="pop")
            ddof_lit = lit(ddof)
            return std_pop * n_samples.sqrt() / (n_samples - ddof_lit).sqrt()

        return self._with_callable(lambda expr: _std(expr, ddof))
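The rescaling in the `ddof >= 2` branch uses the identity `std_ddof = std_pop * sqrt(n / (n - ddof))`, which holds because both estimators share the same sum of squared deviations and differ only in the divisor. A quick numerical check, assuming `numpy` is available:

import numpy as np

x = np.array([1.0, 2.0, 4.0, 8.0])
n = x.size
for ddof in (0, 1, 2):
    direct = x.std(ddof=ddof)
    rescaled = x.std(ddof=0) * np.sqrt(n) / np.sqrt(n - ddof)
    assert np.isclose(direct, rescaled)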
    def var(self, ddof: int) -> Self:
        def _var(expr: ir.NumericColumn, ddof: int) -> ir.Value:
            if ddof == 0:
                return expr.var(how="pop")
            if ddof == 1:
                return expr.var(how="sample")
            n_samples = expr.count()
            var_pop = expr.var(how="pop")
            ddof_lit = lit(ddof)
            return var_pop * n_samples / (n_samples - ddof_lit)

        return self._with_callable(lambda expr: _var(expr, ddof))

    def null_count(self) -> Self:
        return self._with_callable(lambda expr: expr.isnull().sum())

    def is_nan(self) -> Self:
        def func(expr: ir.FloatingValue | Any) -> ir.Value:
            otherwise = expr.isnan() if is_floating(expr.type()) else False
            return ibis.ifelse(expr.isnull(), None, otherwise)

        return self._with_callable(func)

    def is_finite(self) -> Self:
        return self._with_callable(lambda expr: ~(expr.isinf() | expr.isnan()))

    def is_in(self, other: Sequence[Any]) -> Self:
        return self._with_callable(lambda expr: expr.isin(other))

    def fill_null(self, value: Self | Any, strategy: Any, limit: int | None) -> Self:
        # Ibis doesn't yet allow ignoring nulls in first/last with window functions, which makes forward/backward
        # strategies inconsistent when there are nulls present: https://github.com/ibis-project/ibis/issues/9539
        if strategy is not None:
            msg = "`strategy` is not supported for the Ibis backend"
            raise NotImplementedError(msg)
        if limit is not None:
            msg = "`limit` is not supported for the Ibis backend"  # pragma: no cover
            raise NotImplementedError(msg)

        def _fill_null(expr: ir.Value, value: ir.Scalar) -> ir.Value:
            return expr.fill_null(value)

        return self._with_callable(_fill_null, value=value)

    def cast(self, dtype: IntoDType) -> Self:
        def _func(expr: ir.Column) -> ir.Value:
            native_dtype = narwhals_to_native_dtype(dtype, self._version)
            # ibis `cast` overloads do not include DataType, only literals
            return expr.cast(native_dtype)  # type: ignore[unused-ignore]

        return self._with_callable(_func)

    def is_unique(self) -> Self:
        return self._with_callable(
            lambda expr: expr.isnull().count().over(ibis.window(group_by=(expr))) == 1
        )

    def rank(self, method: RankMethod, *, descending: bool) -> Self:
        def _rank(expr: ir.Column) -> ir.Value:
            order_by = next(self._sort(expr, descending=[descending], nulls_last=[True]))
            window = ibis.window(order_by=order_by)

            if method == "dense":
                rank_ = order_by.dense_rank()
            elif method == "ordinal":
                rank_ = ibis.row_number().over(window)
            else:
                rank_ = order_by.rank()

            # Ibis uses 0-based ranking. Add 1 to match polars' 1-based rank.
            rank_ = rank_ + lit(1)

            # For "max" and "average", adjust using the count of rows in the partition.
            if method == "max":
                # Define a window partitioned by expr (i.e. each distinct value)
                partition = ibis.window(group_by=[expr])
                cnt = expr.count().over(partition)
                rank_ = rank_ + cnt - lit(1)
            elif method == "average":
                partition = ibis.window(group_by=[expr])
                cnt = expr.count().over(partition)
                avg = cast("ir.NumericValue", (cnt - lit(1)) / lit(2.0))
                rank_ = rank_ + avg

            return ibis.cases((expr.notnull(), rank_))

        return self._with_callable(_rank)
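The adjustments above reproduce polars' tie-breaking on top of the base SQL ("min") rank. Working through the arithmetic by hand for ascending values `[1, 1, 2]` gives the expected outputs in the comments below; a sketch assuming `ibis` is installed:

import ibis

import narwhals as nw

lf = nw.from_native(ibis.memtable({"a": [1, 1, 2]}))
for method in ("min", "dense", "ordinal", "max", "average"):
    print(method, lf.select(nw.col("a").rank(method=method)).collect().to_native())
# Expected per the arithmetic above:
#   min -> [1, 1, 3], dense -> [1, 1, 2], ordinal -> [1, 2, 3],
#   max -> [2, 2, 3], average -> [1.5, 1.5, 3.0]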
    @property
    def str(self) -> IbisExprStringNamespace:
        return IbisExprStringNamespace(self)

    @property
    def dt(self) -> IbisExprDateTimeNamespace:
        return IbisExprDateTimeNamespace(self)

    @property
    def list(self) -> IbisExprListNamespace:
        return IbisExprListNamespace(self)

    @property
    def struct(self) -> IbisExprStructNamespace:
        return IbisExprStructNamespace(self)

    # NOTE: https://github.com/ibis-project/ibis/issues/10542
    cum_prod = not_implemented()

    # NOTE: https://github.com/ibis-project/ibis/issues/11176
    skew = not_implemented()
    kurtosis = not_implemented()

    _count_star = not_implemented()

    # Intentionally not implemented, as Ibis does its own expression rewriting.
    _push_down_window_function = not_implemented()
83
lib/python3.11/site-packages/narwhals/_ibis/expr_dt.py
Normal file
@ -0,0 +1,83 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable

from narwhals._duration import Interval
from narwhals._ibis.utils import (
    UNITS_DICT_BUCKET,
    UNITS_DICT_TRUNCATE,
    timedelta_to_ibis_interval,
)
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    import ibis.expr.types as ir

    from narwhals._ibis.expr import IbisExpr
    from narwhals._ibis.utils import BucketUnit, TruncateUnit


class IbisExprDateTimeNamespace(SQLExprDateTimeNamesSpace["IbisExpr"]):
    def millisecond(self) -> IbisExpr:
        return self.compliant._with_callable(lambda expr: expr.millisecond())

    def microsecond(self) -> IbisExpr:
        return self.compliant._with_callable(lambda expr: expr.microsecond())

    def to_string(self, format: str) -> IbisExpr:
        return self.compliant._with_callable(lambda expr: expr.strftime(format))

    def weekday(self) -> IbisExpr:
        # Ibis uses 0-6 for Monday-Sunday. Add 1 to match polars.
        return self.compliant._with_callable(lambda expr: expr.day_of_week.index() + 1)

    def _bucket(self, kwds: dict[BucketUnit, Any], /) -> Callable[..., ir.TimestampValue]:
        def fn(expr: ir.TimestampValue) -> ir.TimestampValue:
            return expr.bucket(**kwds)

        return fn

    def _truncate(self, unit: TruncateUnit, /) -> Callable[..., ir.TimestampValue]:
        def fn(expr: ir.TimestampValue) -> ir.TimestampValue:
            return expr.truncate(unit)

        return fn

    def truncate(self, every: str) -> IbisExpr:
        interval = Interval.parse(every)
        multiple, unit = interval.multiple, interval.unit
        if unit == "q":
            multiple, unit = 3 * multiple, "mo"
        if multiple != 1:
            if self.compliant._backend_version < (7, 1):  # pragma: no cover
                msg = "Truncating datetimes with multiples of the unit is only supported in Ibis >= 7.1."
                raise NotImplementedError(msg)
            fn = self._bucket({UNITS_DICT_BUCKET[unit]: multiple})
        else:
            fn = self._truncate(UNITS_DICT_TRUNCATE[unit])
        return self.compliant._with_callable(fn)
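`truncate` normalises quarters to months, then picks between Ibis' `truncate` (for unit multiples of 1) and `bucket` (for larger multiples). The quarter normalisation in isolation, as a sketch reusing the module's own `Interval` helper:

from narwhals._duration import Interval

interval = Interval.parse("2q")
multiple, unit = interval.multiple, interval.unit
if unit == "q":
    multiple, unit = 3 * multiple, "mo"
# "2q" is handled as 6-month buckets.
print(multiple, unit)  # 6 mo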
    def offset_by(self, every: str) -> IbisExpr:
        interval = Interval.parse_no_constraints(every)
        unit = interval.unit
        if unit in {"y", "q", "mo", "d", "ns"}:
            msg = f"Offsetting by {unit} is not yet supported for Ibis."
            raise NotImplementedError(msg)
        offset = timedelta_to_ibis_interval(interval.to_timedelta())
        return self.compliant._with_callable(lambda expr: expr.add(offset))

    def replace_time_zone(self, time_zone: str | None) -> IbisExpr:
        if time_zone is None:
            return self.compliant._with_callable(lambda expr: expr.cast("timestamp"))
        msg = "`replace_time_zone` with non-null `time_zone` not yet implemented for Ibis"  # pragma: no cover
        raise NotImplementedError(msg)

    nanosecond = not_implemented()
    total_minutes = not_implemented()
    total_seconds = not_implemented()
    total_milliseconds = not_implemented()
    total_microseconds = not_implemented()
    total_nanoseconds = not_implemented()
    convert_time_zone = not_implemented()
    timestamp = not_implemented()
29
lib/python3.11/site-packages/narwhals/_ibis/expr_list.py
Normal file
@ -0,0 +1,29 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace

if TYPE_CHECKING:
    import ibis.expr.types as ir

    from narwhals._ibis.expr import IbisExpr
    from narwhals.typing import NonNestedLiteral


class IbisExprListNamespace(LazyExprNamespace["IbisExpr"], ListNamespace["IbisExpr"]):
    def len(self) -> IbisExpr:
        return self.compliant._with_callable(lambda expr: expr.length())

    def unique(self) -> IbisExpr:
        return self.compliant._with_callable(lambda expr: expr.unique())

    def contains(self, item: NonNestedLiteral) -> IbisExpr:
        return self.compliant._with_callable(lambda expr: expr.contains(item))

    def get(self, index: int) -> IbisExpr:
        def _get(expr: ir.ArrayColumn) -> ir.Column:
            return expr[index]

        return self.compliant._with_callable(_get)
83
lib/python3.11/site-packages/narwhals/_ibis/expr_str.py
Normal file
@ -0,0 +1,83 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable

from ibis.expr.datatypes import Timestamp

from narwhals._sql.expr_str import SQLExprStringNamespace
from narwhals._utils import _is_naive_format, not_implemented

if TYPE_CHECKING:
    import ibis.expr.types as ir
    from typing_extensions import TypeAlias

    from narwhals._ibis.expr import IbisExpr

    IntoStringValue: TypeAlias = "str | ir.StringValue"


class IbisExprStringNamespace(SQLExprStringNamespace["IbisExpr"]):
    def strip_chars(self, characters: str | None) -> IbisExpr:
        if characters is not None:
            msg = "Ibis does not support `characters` argument in `str.strip_chars`"
            raise NotImplementedError(msg)

        return self.compliant._with_callable(lambda expr: expr.strip())

    def _replace_all(
        self, pattern: IntoStringValue, value: IntoStringValue
    ) -> Callable[..., ir.StringValue]:
        def fn(expr: ir.StringColumn) -> ir.StringValue:
            return expr.re_replace(pattern, value)

        return fn

    def _replace_all_literal(
        self, pattern: IntoStringValue, value: IntoStringValue
    ) -> Callable[..., ir.StringValue]:
        def fn(expr: ir.StringColumn) -> ir.StringValue:
            return expr.replace(pattern, value)  # pyright: ignore[reportArgumentType]

        return fn

    def replace_all(
        self, pattern: str, value: str | IbisExpr, *, literal: bool
    ) -> IbisExpr:
        fn = self._replace_all_literal if literal else self._replace_all
        if isinstance(value, str):
            return self.compliant._with_callable(fn(pattern, value))
        return self.compliant._with_elementwise(
            lambda expr, value: fn(pattern, value)(expr), value=value
        )
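`literal=True` routes to Ibis' plain `replace`, while the default goes through the regex-based `re_replace`. A sketch, assuming `ibis` is installed:

import ibis

import narwhals as nw

lf = nw.from_native(ibis.memtable({"s": ["a1b2"]}))
# Regex replacement: strips the digits -> "ab".
print(lf.select(nw.col("s").str.replace_all(r"\d", "")).collect().to_native())
# Literal replacement: only the exact substring "a1" matches -> "xb2".
print(
    lf.select(nw.col("s").str.replace_all("a1", "x", literal=True))
    .collect()
    .to_native()
)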
|
||||
def _to_datetime(self, format: str) -> Callable[..., ir.TimestampValue]:
|
||||
def fn(expr: ir.StringColumn) -> ir.TimestampValue:
|
||||
return expr.as_timestamp(format)
|
||||
|
||||
return fn
|
||||
|
||||
def _to_datetime_naive(self, format: str) -> Callable[..., ir.TimestampValue]:
|
||||
def fn(expr: ir.StringColumn) -> ir.TimestampValue:
|
||||
dtype: Any = Timestamp(timezone=None)
|
||||
return expr.as_timestamp(format).cast(dtype)
|
||||
|
||||
return fn
|
||||
|
||||
def to_datetime(self, format: str | None) -> IbisExpr:
|
||||
if format is None:
|
||||
msg = "Cannot infer format with Ibis backend"
|
||||
raise NotImplementedError(msg)
|
||||
fn = self._to_datetime_naive if _is_naive_format(format) else self._to_datetime
|
||||
return self.compliant._with_callable(fn(format))
|
||||
|
||||
def to_date(self, format: str | None) -> IbisExpr:
|
||||
if format is None:
|
||||
msg = "Cannot infer format with Ibis backend"
|
||||
raise NotImplementedError(msg)
|
||||
|
||||
def fn(expr: ir.StringColumn) -> ir.DateValue:
|
||||
return expr.as_date(format)
|
||||
|
||||
return self.compliant._with_callable(fn)
|
||||
|
||||
replace = not_implemented()
|
||||
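Aside (not part of the commit): the split between `_replace_all` and `_replace_all_literal` mirrors ibis's two string-replacement APIs, and `to_datetime` insists on an explicit format. A sketch of the underlying calls, assuming a recent `ibis` is installed:

import ibis

t = ibis.memtable({"s": ["a.b", "axb"], "d": ["2020-01-01", "2021-06-30"]})
literal = t.s.replace(".", "-")    # replace_all(..., literal=True): only "a.b" changes
regex = t.s.re_replace(".", "-")   # replace_all(..., literal=False): every character matches
ts = t.d.as_timestamp("%Y-%m-%d")  # to_datetime with a naive format string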
19
lib/python3.11/site-packages/narwhals/_ibis/expr_struct.py
Normal file
@ -0,0 +1,19 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StructNamespace

if TYPE_CHECKING:
    import ibis.expr.types as ir

    from narwhals._ibis.expr import IbisExpr


class IbisExprStructNamespace(LazyExprNamespace["IbisExpr"], StructNamespace["IbisExpr"]):
    def field(self, name: str) -> IbisExpr:
        def func(expr: ir.StructColumn) -> ir.Column:
            return expr[name]

        return self.compliant._with_callable(func).alias(name)
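Aside (not part of the commit): `field` is plain bracket access on a struct value, sketched here with `ibis` (assumed installed):

import ibis

s = ibis.struct({"x": 1, "y": "a"})
x = s["x"]  # what `func` returns for field("x"), before the .alias("x")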
32
lib/python3.11/site-packages/narwhals/_ibis/group_by.py
Normal file
@ -0,0 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._sql.group_by import SQLGroupBy

if TYPE_CHECKING:
    from collections.abc import Sequence

    import ibis.expr.types as ir  # noqa: F401

    from narwhals._ibis.dataframe import IbisLazyFrame
    from narwhals._ibis.expr import IbisExpr


class IbisGroupBy(SQLGroupBy["IbisLazyFrame", "IbisExpr", "ir.Value"]):
    def __init__(
        self,
        df: IbisLazyFrame,
        keys: Sequence[str] | Sequence[IbisExpr],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame

    def agg(self, *exprs: IbisExpr) -> IbisLazyFrame:
        native = self.compliant.native
        return self.compliant._with_native(
            native.group_by(self._keys).aggregate(*self._evaluate_exprs(exprs))
        ).rename(dict(zip(self._keys, self._output_key_names)))
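Aside (not part of the commit): `agg` builds the usual ibis group-by chain, then renames the key columns back to their requested output names. The native pattern it emits, assuming `ibis` is installed:

import ibis

t = ibis.memtable({"g": ["x", "x", "y"], "v": [1, 2, 3]})
out = t.group_by(["g"]).aggregate(v_sum=t.v.sum())  # then .rename(...) maps keys back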
160
lib/python3.11/site-packages/narwhals/_ibis/namespace.py
Normal file
@ -0,0 +1,160 @@
from __future__ import annotations

import operator
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any

import ibis
import ibis.expr.types as ir

from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._ibis.dataframe import IbisLazyFrame
from narwhals._ibis.expr import IbisExpr
from narwhals._ibis.selectors import IbisSelectorNamespace
from narwhals._ibis.utils import function, lit, narwhals_to_native_dtype
from narwhals._sql.namespace import SQLNamespace
from narwhals._sql.when_then import SQLThen, SQLWhen
from narwhals._utils import Implementation, requires

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence

    from narwhals._utils import Version
    from narwhals.typing import ConcatMethod, IntoDType, PythonLiteral


class IbisNamespace(SQLNamespace[IbisLazyFrame, IbisExpr, "ir.Table", "ir.Value"]):
    _implementation: Implementation = Implementation.IBIS

    def __init__(self, *, version: Version) -> None:
        self._version = version

    @property
    def selectors(self) -> IbisSelectorNamespace:
        return IbisSelectorNamespace.from_namespace(self)

    @property
    def _expr(self) -> type[IbisExpr]:
        return IbisExpr

    @property
    def _lazyframe(self) -> type[IbisLazyFrame]:
        return IbisLazyFrame

    def _function(self, name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
        return function(name, *args)

    def _lit(self, value: Any) -> ir.Value:
        return lit(value)

    def _when(
        self, condition: ir.Value, value: ir.Value, otherwise: ir.Expr | None = None
    ) -> ir.Value:
        if otherwise is None:
            return ibis.cases((condition, value))
        return ibis.cases((condition, value), else_=otherwise)  # pragma: no cover

    def _coalesce(self, *exprs: ir.Value) -> ir.Value:
        return ibis.coalesce(*exprs)

    def concat(
        self, items: Iterable[IbisLazyFrame], *, how: ConcatMethod
    ) -> IbisLazyFrame:
        if how == "diagonal":
            msg = "diagonal concat not supported for Ibis. Please join instead."
            raise NotImplementedError(msg)

        items = list(items)
        native_items = [item.native for item in items]
        schema = items[0].schema
        if not all(x.schema == schema for x in items[1:]):
            msg = "inputs should all have the same schema"
            raise TypeError(msg)
        return self._lazyframe.from_native(ibis.union(*native_items), context=self)

    def concat_str(
        self, *exprs: IbisExpr, separator: str, ignore_nulls: bool
    ) -> IbisExpr:
        def func(df: IbisLazyFrame) -> list[ir.Value]:
            cols = list(chain.from_iterable(expr(df) for expr in exprs))
            cols_casted = [s.cast("string") for s in cols]

            if not ignore_nulls:
                result = cols_casted[0]
                for col in cols_casted[1:]:
                    result = result + separator + col
            else:
                result = lit(separator).join(cols_casted)

            return [result]

        return self._expr(
            call=func,
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            version=self._version,
        )

    def mean_horizontal(self, *exprs: IbisExpr) -> IbisExpr:
        def func(cols: Iterable[ir.Value]) -> ir.Value:
            cols = list(cols)
            return reduce(operator.add, (col.fill_null(lit(0)) for col in cols)) / reduce(
                operator.add, (col.isnull().ifelse(lit(0), lit(1)) for col in cols)
            )

        return self._expr._from_elementwise_horizontal_op(func, *exprs)

    @requires.backend_version((10, 0))
    def when(self, predicate: IbisExpr) -> IbisWhen:
        return IbisWhen.from_expr(predicate, context=self)

    def lit(self, value: Any, dtype: IntoDType | None) -> IbisExpr:
        def func(_df: IbisLazyFrame) -> Sequence[ir.Value]:
            ibis_dtype = narwhals_to_native_dtype(dtype, self._version) if dtype else None
            return [lit(value, ibis_dtype)]

        return self._expr(
            func,
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            version=self._version,
        )

    def len(self) -> IbisExpr:
        def func(_df: IbisLazyFrame) -> list[ir.Value]:
            return [_df.native.count()]

        return self._expr(
            call=func,
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            version=self._version,
        )


class IbisWhen(SQLWhen["IbisLazyFrame", "ir.Value", IbisExpr]):
    lit = lit

    @property
    def _then(self) -> type[IbisThen]:
        return IbisThen

    def __call__(self, df: IbisLazyFrame) -> Sequence[ir.Value]:
        is_expr = self._condition._is_expr
        condition = df._evaluate_expr(self._condition)
        then_ = self._then_value
        then = df._evaluate_expr(then_) if is_expr(then_) else lit(then_)
        other_ = self._otherwise_value
        if other_ is None:
            result = ibis.cases((condition, then))
        else:
            otherwise = df._evaluate_expr(other_) if is_expr(other_) else lit(other_)
            result = ibis.cases((condition, then), else_=otherwise)
        return [result]


class IbisThen(SQLThen["IbisLazyFrame", "ir.Value", IbisExpr], IbisExpr): ...
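Aside (not part of the commit): `mean_horizontal` computes sum-of-null-filled-values over count-of-non-nulls, so nulls are excluded from both numerator and denominator. The same arithmetic in plain ibis (assumed installed):

import ibis

t = ibis.memtable({"a": [1.0, None], "b": [3.0, 4.0]})
num = t.a.fill_null(0) + t.b.fill_null(0)                    # 4.0, 4.0
den = t.a.isnull().ifelse(0, 1) + t.b.isnull().ifelse(0, 1)  # 2, 1
mean_h = num / den                                           # 2.0, 4.0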
32
lib/python3.11/site-packages/narwhals/_ibis/selectors.py
Normal file
@ -0,0 +1,32 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import CompliantSelector, LazySelectorNamespace
from narwhals._ibis.expr import IbisExpr

if TYPE_CHECKING:
    import ibis.expr.types as ir  # noqa: F401

    from narwhals._ibis.dataframe import IbisLazyFrame  # noqa: F401
    from narwhals._ibis.expr import IbisWindowFunction


class IbisSelectorNamespace(LazySelectorNamespace["IbisLazyFrame", "ir.Value"]):
    @property
    def _selector(self) -> type[IbisSelector]:
        return IbisSelector


class IbisSelector(  # type: ignore[misc]
    CompliantSelector["IbisLazyFrame", "ir.Value"], IbisExpr
):
    _window_function: IbisWindowFunction | None = None

    def _to_expr(self) -> IbisExpr:
        return IbisExpr(
            self._call,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
        )
41
lib/python3.11/site-packages/narwhals/_ibis/series.py
Normal file
@ -0,0 +1,41 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, NoReturn

from narwhals._ibis.utils import native_to_narwhals_dtype
from narwhals.dependencies import get_ibis

if TYPE_CHECKING:
    from types import ModuleType

    from typing_extensions import Self

    from narwhals._utils import Version
    from narwhals.dtypes import DType


class IbisInterchangeSeries:
    def __init__(self, df: Any, version: Version) -> None:
        self._native_series = df
        self._version = version

    def __narwhals_series__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        return get_ibis()

    @property
    def dtype(self) -> DType:
        return native_to_narwhals_dtype(
            self._native_series.schema().types[0], self._version
        )

    def __getattr__(self, attr: str) -> NoReturn:
        msg = (
            f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
            "If you would like to see this kind of object better supported in "
            "Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)
270
lib/python3.11/site-packages/narwhals/_ibis/utils.py
Normal file
@ -0,0 +1,270 @@
from __future__ import annotations

from functools import lru_cache
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import ibis
import ibis.expr.datatypes as ibis_dtypes

from narwhals._utils import Version, isinstance_or_issubclass

if TYPE_CHECKING:
    from collections.abc import Mapping
    from datetime import timedelta

    import ibis.expr.types as ir
    from ibis.common.temporal import TimestampUnit
    from ibis.expr.datatypes import DataType as IbisDataType
    from typing_extensions import TypeAlias, TypeIs

    from narwhals._duration import IntervalUnit
    from narwhals._ibis.dataframe import IbisLazyFrame
    from narwhals._ibis.expr import IbisExpr
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType, PythonLiteral

Incomplete: TypeAlias = Any
"""Marker for upstream issues."""


@overload
def lit(value: bool, dtype: None = ...) -> ir.BooleanScalar: ...  # noqa: FBT001
@overload
def lit(value: int, dtype: None = ...) -> ir.IntegerScalar: ...
@overload
def lit(value: float, dtype: None = ...) -> ir.FloatingScalar: ...
@overload
def lit(value: str, dtype: None = ...) -> ir.StringScalar: ...
@overload
def lit(value: PythonLiteral | ir.Value, dtype: None = ...) -> ir.Scalar: ...
@overload
def lit(value: Any, dtype: Any) -> Incomplete: ...
def lit(value: Any, dtype: Any | None = None) -> Incomplete:
    """Alias for `ibis.literal`."""
    literal: Incomplete = ibis.literal
    return literal(value, dtype)


BucketUnit: TypeAlias = Literal[
    "years",
    "quarters",
    "months",
    "days",
    "hours",
    "minutes",
    "seconds",
    "milliseconds",
    "microseconds",
    "nanoseconds",
]
TruncateUnit: TypeAlias = Literal[
    "Y", "Q", "M", "W", "D", "h", "m", "s", "ms", "us", "ns"
]

UNITS_DICT_BUCKET: Mapping[IntervalUnit, BucketUnit] = {
    "y": "years",
    "q": "quarters",
    "mo": "months",
    "d": "days",
    "h": "hours",
    "m": "minutes",
    "s": "seconds",
    "ms": "milliseconds",
    "us": "microseconds",
    "ns": "nanoseconds",
}

UNITS_DICT_TRUNCATE: Mapping[IntervalUnit, TruncateUnit] = {
    "y": "Y",
    "q": "Q",
    "mo": "M",
    "d": "D",
    "h": "h",
    "m": "m",
    "s": "s",
    "ms": "ms",
    "us": "us",
    "ns": "ns",
}

FUNCTION_REMAPPING = {
    "starts_with": "startswith",
    "ends_with": "endswith",
    "regexp_matches": "re_search",
    "str_split": "split",
    "dayofyear": "day_of_year",
    "to_date": "date",
}


def evaluate_exprs(df: IbisLazyFrame, /, *exprs: IbisExpr) -> list[tuple[str, ir.Value]]:
    native_results: list[tuple[str, ir.Value]] = []
    for expr in exprs:
        native_series_list = expr(df)
        output_names = expr._evaluate_output_names(df)
        if expr._alias_output_names is not None:
            output_names = expr._alias_output_names(output_names)
        if len(output_names) != len(native_series_list):  # pragma: no cover
            msg = f"Internal error: got output names {output_names}, but only got {len(native_series_list)} results"
            raise AssertionError(msg)
        native_results.extend(zip(output_names, native_series_list))
    return native_results


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(ibis_dtype: IbisDataType, version: Version) -> DType:  # noqa: C901, PLR0912
    dtypes = version.dtypes
    if ibis_dtype.is_int64():
        return dtypes.Int64()
    if ibis_dtype.is_int32():
        return dtypes.Int32()
    if ibis_dtype.is_int16():
        return dtypes.Int16()
    if ibis_dtype.is_int8():
        return dtypes.Int8()
    if ibis_dtype.is_uint64():
        return dtypes.UInt64()
    if ibis_dtype.is_uint32():
        return dtypes.UInt32()
    if ibis_dtype.is_uint16():
        return dtypes.UInt16()
    if ibis_dtype.is_uint8():
        return dtypes.UInt8()
    if ibis_dtype.is_boolean():
        return dtypes.Boolean()
    if ibis_dtype.is_float64():
        return dtypes.Float64()
    if ibis_dtype.is_float32():
        return dtypes.Float32()
    if ibis_dtype.is_string():
        return dtypes.String()
    if ibis_dtype.is_date():
        return dtypes.Date()
    if is_timestamp(ibis_dtype):
        _unit = cast("TimestampUnit", ibis_dtype.unit)
        return dtypes.Datetime(time_unit=_unit.value, time_zone=ibis_dtype.timezone)
    if is_interval(ibis_dtype):
        _time_unit = ibis_dtype.unit.value
        if _time_unit not in {"ns", "us", "ms", "s"}:  # pragma: no cover
            msg = f"Unsupported interval unit: {_time_unit}"
            raise NotImplementedError(msg)
        return dtypes.Duration(_time_unit)
    if is_array(ibis_dtype):
        if ibis_dtype.length:
            return dtypes.Array(
                native_to_narwhals_dtype(ibis_dtype.value_type, version),
                ibis_dtype.length,
            )
        return dtypes.List(native_to_narwhals_dtype(ibis_dtype.value_type, version))
    if is_struct(ibis_dtype):
        return dtypes.Struct(
            [
                dtypes.Field(name, native_to_narwhals_dtype(dtype, version))
                for name, dtype in ibis_dtype.items()
            ]
        )
    if ibis_dtype.is_decimal():  # pragma: no cover
        return dtypes.Decimal()
    if ibis_dtype.is_time():
        return dtypes.Time()
    if ibis_dtype.is_binary():
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover


def is_timestamp(obj: IbisDataType) -> TypeIs[ibis_dtypes.Timestamp]:
    return obj.is_timestamp()


def is_interval(obj: IbisDataType) -> TypeIs[ibis_dtypes.Interval]:
    return obj.is_interval()


def is_array(obj: IbisDataType) -> TypeIs[ibis_dtypes.Array[Any]]:
    return obj.is_array()


def is_struct(obj: IbisDataType) -> TypeIs[ibis_dtypes.Struct]:
    return obj.is_struct()


def is_floating(obj: IbisDataType) -> TypeIs[ibis_dtypes.Floating]:
    return obj.is_floating()


dtypes = Version.MAIN.dtypes
NW_TO_IBIS_DTYPES: Mapping[type[DType], IbisDataType] = {
    dtypes.Float64: ibis_dtypes.Float64(),
    dtypes.Float32: ibis_dtypes.Float32(),
    dtypes.Binary: ibis_dtypes.Binary(),
    dtypes.String: ibis_dtypes.String(),
    dtypes.Boolean: ibis_dtypes.Boolean(),
    dtypes.Date: ibis_dtypes.Date(),
    dtypes.Time: ibis_dtypes.Time(),
    dtypes.Int8: ibis_dtypes.Int8(),
    dtypes.Int16: ibis_dtypes.Int16(),
    dtypes.Int32: ibis_dtypes.Int32(),
    dtypes.Int64: ibis_dtypes.Int64(),
    dtypes.UInt8: ibis_dtypes.UInt8(),
    dtypes.UInt16: ibis_dtypes.UInt16(),
    dtypes.UInt32: ibis_dtypes.UInt32(),
    dtypes.UInt64: ibis_dtypes.UInt64(),
    dtypes.Decimal: ibis_dtypes.Decimal(),
}
# Enum support: https://github.com/ibis-project/ibis/issues/10991
UNSUPPORTED_DTYPES = (dtypes.Int128, dtypes.UInt128, dtypes.Categorical, dtypes.Enum)


def narwhals_to_native_dtype(dtype: IntoDType, version: Version) -> IbisDataType:
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if ibis_type := NW_TO_IBIS_DTYPES.get(base_type):
        return ibis_type
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return ibis_dtypes.Timestamp.from_unit(dtype.time_unit, timezone=dtype.time_zone)
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return ibis_dtypes.Interval(unit=dtype.time_unit)  # pyright: ignore[reportArgumentType]
    if isinstance_or_issubclass(dtype, dtypes.List):
        inner = narwhals_to_native_dtype(dtype.inner, version)
        return ibis_dtypes.Array(value_type=inner)
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = [
            (field.name, narwhals_to_native_dtype(field.dtype, version))
            for field in dtype.fields
        ]
        return ibis_dtypes.Struct.from_tuples(fields)
    if isinstance_or_issubclass(dtype, dtypes.Array):
        inner = narwhals_to_native_dtype(dtype.inner, version)
        return ibis_dtypes.Array(value_type=inner, length=dtype.size)
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for Ibis."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)


def timedelta_to_ibis_interval(td: timedelta) -> ibis.expr.types.temporal.IntervalScalar:
    return ibis.interval(days=td.days, seconds=td.seconds, microseconds=td.microseconds)


def function(name: str, *args: ir.Value | PythonLiteral) -> ir.Value:
    # Workaround SQL vs Ibis differences.
    if name == "row_number":
        return ibis.row_number() + lit(1)
    if name == "least":
        return ibis.least(*args)
    if name == "greatest":
        return ibis.greatest(*args)
    expr = args[0]
    if name == "var_pop":
        return cast("ir.NumericColumn", expr).var(how="pop")
    if name == "var_samp":
        return cast("ir.NumericColumn", expr).var(how="sample")
    if name == "stddev_pop":
        return cast("ir.NumericColumn", expr).std(how="pop")
    if name == "stddev_samp":
        return cast("ir.NumericColumn", expr).std(how="sample")
    if name == "substr":
        # Ibis is 0-indexed here, SQL is 1-indexed
        return cast("ir.StringColumn", expr).substr(args[1] - 1, *args[2:])  # type: ignore[operator]  # pyright: ignore[reportArgumentType]
    return getattr(expr, FUNCTION_REMAPPING.get(name, name))(*args[1:])
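Aside (not part of the commit): the `substr` branch in `function` papers over an off-by-one between dialects, since SQL's SUBSTR is 1-indexed while ibis's `StringValue.substr` is 0-indexed. A sketch, assuming `ibis` is installed:

import ibis

t = ibis.memtable({"s": ["hello"]})
sql_style = t.s.substr(2 - 1, 3)  # SQL substr(s, 2, 3) -> "ell"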
159
lib/python3.11/site-packages/narwhals/_interchange/dataframe.py
Normal file
@ -0,0 +1,159 @@
from __future__ import annotations

import enum
from typing import TYPE_CHECKING, Any, NoReturn

from narwhals._utils import Version, parse_version

if TYPE_CHECKING:
    import pandas as pd
    import pyarrow as pa
    from typing_extensions import Self, TypeIs

    from narwhals._interchange.series import InterchangeSeries
    from narwhals.dtypes import DType
    from narwhals.stable.v1.typing import DataFrameLike


class DtypeKind(enum.IntEnum):
    # https://data-apis.org/dataframe-protocol/latest/API.html
    INT = 0
    UINT = 1
    FLOAT = 2
    BOOL = 20
    STRING = 21  # UTF-8
    DATETIME = 22
    CATEGORICAL = 23


def map_interchange_dtype_to_narwhals_dtype(  # noqa: C901, PLR0911, PLR0912
    interchange_dtype: tuple[DtypeKind, int, Any, Any],
) -> DType:
    dtypes = Version.V1.dtypes
    if interchange_dtype[0] == DtypeKind.INT:
        if interchange_dtype[1] == 64:
            return dtypes.Int64()
        if interchange_dtype[1] == 32:
            return dtypes.Int32()
        if interchange_dtype[1] == 16:
            return dtypes.Int16()
        if interchange_dtype[1] == 8:
            return dtypes.Int8()
        msg = "Invalid bit width for INT"  # pragma: no cover
        raise AssertionError(msg)
    if interchange_dtype[0] == DtypeKind.UINT:
        if interchange_dtype[1] == 64:
            return dtypes.UInt64()
        if interchange_dtype[1] == 32:
            return dtypes.UInt32()
        if interchange_dtype[1] == 16:
            return dtypes.UInt16()
        if interchange_dtype[1] == 8:
            return dtypes.UInt8()
        msg = "Invalid bit width for UINT"  # pragma: no cover
        raise AssertionError(msg)
    if interchange_dtype[0] == DtypeKind.FLOAT:
        if interchange_dtype[1] == 64:
            return dtypes.Float64()
        if interchange_dtype[1] == 32:
            return dtypes.Float32()
        msg = "Invalid bit width for FLOAT"  # pragma: no cover
        raise AssertionError(msg)
    if interchange_dtype[0] == DtypeKind.BOOL:
        return dtypes.Boolean()
    if interchange_dtype[0] == DtypeKind.STRING:
        return dtypes.String()
    if interchange_dtype[0] == DtypeKind.DATETIME:
        return dtypes.Datetime()
    if interchange_dtype[0] == DtypeKind.CATEGORICAL:  # pragma: no cover
        # upstream issue: https://github.com/ibis-project/ibis/issues/9570
        return dtypes.Categorical()
    msg = f"Invalid dtype, got: {interchange_dtype}"  # pragma: no cover
    raise AssertionError(msg)


class InterchangeFrame:
    _version = Version.V1

    def __init__(self, df: DataFrameLike) -> None:
        self._interchange_frame = df.__dataframe__()

    def __narwhals_dataframe__(self) -> Self:
        return self

    def __native_namespace__(self) -> NoReturn:
        msg = (
            "Cannot access native namespace for interchange-level dataframes with unknown backend. "
            "If you would like to see this kind of object supported in Narwhals, please "
            "open a feature request at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)

    def get_column(self, name: str) -> InterchangeSeries:
        from narwhals._interchange.series import InterchangeSeries

        return InterchangeSeries(self._interchange_frame.get_column_by_name(name))

    def to_pandas(self) -> pd.DataFrame:
        import pandas as pd  # ignore-banned-import()

        if parse_version(pd) < (1, 5, 0):  # pragma: no cover
            msg = (
                "Conversion to pandas is achieved via interchange protocol which requires"
                f" 'pandas>=1.5.0' to be installed, found {pd.__version__}"
            )
            raise NotImplementedError(msg)
        return pd.api.interchange.from_dataframe(self._interchange_frame)

    def to_arrow(self) -> pa.Table:
        from pyarrow.interchange.from_dataframe import (  # ignore-banned-import()
            from_dataframe,
        )

        return from_dataframe(self._interchange_frame)

    @property
    def schema(self) -> dict[str, DType]:
        return {
            column_name: map_interchange_dtype_to_narwhals_dtype(
                self._interchange_frame.get_column_by_name(column_name).dtype
            )
            for column_name in self._interchange_frame.column_names()
        }

    @property
    def columns(self) -> list[str]:
        return list(self._interchange_frame.column_names())

    def __getattr__(self, attr: str) -> NoReturn:
        msg = (
            f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
            "Hint: you probably called `nw.from_native` on an object which isn't fully "
            "supported by Narwhals, yet implements `__dataframe__`. If you would like to "
            "see this kind of object supported in Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)

    def simple_select(self, *column_names: str) -> Self:
        frame = self._interchange_frame.select_columns_by_name(list(column_names))
        if not hasattr(frame, "_df"):  # pragma: no cover
            msg = (
                "Expected interchange object to implement `_df` property to allow for recovering original object.\n"
                "See https://github.com/data-apis/dataframe-api/issues/360."
            )
            raise NotImplementedError(msg)
        return self.__class__(frame._df)

    def select(self, *exprs: str) -> Self:  # pragma: no cover
        msg = (
            "`select`-ing not by name is not supported for interchange-only level.\n\n"
            "If you would like to see this kind of object better supported in "
            "Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)


def supports_dataframe_interchange(obj: Any) -> TypeIs[DataFrameLike]:
    return hasattr(obj, "__dataframe__")
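Aside (not part of the commit): `map_interchange_dtype_to_narwhals_dtype` consumes the `(kind, bit-width, format, endianness)` tuples defined by the dataframe interchange protocol. What such a tuple looks like in practice, assuming pandas >= 1.5 is installed:

import pandas as pd

ichange = pd.DataFrame({"a": [1, 2]}).__dataframe__()
print(ichange.get_column_by_name("a").dtype)  # e.g. (<DtypeKind.INT: 0>, 64, 'l', '=')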
47
lib/python3.11/site-packages/narwhals/_interchange/series.py
Normal file
@ -0,0 +1,47 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, NoReturn

from narwhals._interchange.dataframe import map_interchange_dtype_to_narwhals_dtype
from narwhals._utils import Version

if TYPE_CHECKING:
    from typing_extensions import Self

    from narwhals.dtypes import DType


class InterchangeSeries:
    _version = Version.V1

    def __init__(self, df: Any) -> None:
        self._native_series = df

    def __narwhals_series__(self) -> Self:
        return self

    def __native_namespace__(self) -> NoReturn:
        msg = (
            "Cannot access native namespace for interchange-level series with unknown backend. "
            "If you would like to see this kind of object supported in Narwhals, please "
            "open a feature request at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)

    @property
    def dtype(self) -> DType:
        return map_interchange_dtype_to_narwhals_dtype(self._native_series.dtype)

    @property
    def native(self) -> Any:
        return self._native_series

    def __getattr__(self, attr: str) -> NoReturn:
        msg = (  # pragma: no cover
            f"Attribute {attr} is not supported for interchange-level dataframes.\n\n"
            "Hint: you probably called `nw.from_native` on an object which isn't fully "
            "supported by Narwhals, yet implements `__dataframe__`. If you would like to "
            "see this kind of object supported in Narwhals, please open a feature request "
            "at https://github.com/narwhals-dev/narwhals/issues."
        )
        raise NotImplementedError(msg)
409
lib/python3.11/site-packages/narwhals/_namespace.py
Normal file
@ -0,0 +1,409 @@
"""Narwhals-level equivalent of `CompliantNamespace`."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
Callable,
|
||||
Generic,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypeVar,
|
||||
overload,
|
||||
)
|
||||
|
||||
from narwhals._compliant.typing import CompliantNamespaceAny, CompliantNamespaceT_co
|
||||
from narwhals._utils import Implementation, Version
|
||||
from narwhals.dependencies import (
|
||||
get_cudf,
|
||||
get_modin,
|
||||
get_pandas,
|
||||
get_polars,
|
||||
get_pyarrow,
|
||||
is_dask_dataframe,
|
||||
is_duckdb_relation,
|
||||
is_ibis_table,
|
||||
is_pyspark_connect_dataframe,
|
||||
is_pyspark_dataframe,
|
||||
is_sqlframe_dataframe,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Collection, Sized
|
||||
from typing import ClassVar
|
||||
|
||||
import duckdb
|
||||
import pandas as pd
|
||||
import polars as pl
|
||||
import pyarrow as pa
|
||||
import pyspark.sql as pyspark_sql
|
||||
from pyspark.sql.connect.dataframe import DataFrame as PySparkConnectDataFrame
|
||||
from typing_extensions import Self, TypeAlias, TypeIs
|
||||
|
||||
from narwhals._arrow.namespace import ArrowNamespace
|
||||
from narwhals._dask.namespace import DaskNamespace
|
||||
from narwhals._duckdb.namespace import DuckDBNamespace
|
||||
from narwhals._ibis.namespace import IbisNamespace
|
||||
from narwhals._pandas_like.namespace import PandasLikeNamespace
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
from narwhals._spark_like.dataframe import SQLFrameDataFrame
|
||||
from narwhals._spark_like.namespace import SparkLikeNamespace
|
||||
from narwhals._typing import (
|
||||
Arrow,
|
||||
Backend,
|
||||
Dask,
|
||||
DuckDB,
|
||||
EagerAllowed,
|
||||
Ibis,
|
||||
IntoBackend,
|
||||
PandasLike,
|
||||
Polars,
|
||||
SparkLike,
|
||||
)
|
||||
from narwhals.typing import NativeDataFrame, NativeLazyFrame, NativeSeries
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
_Guard: TypeAlias = "Callable[[Any], TypeIs[T]]"
|
||||
|
||||
EagerAllowedNamespace: TypeAlias = "Namespace[PandasLikeNamespace] | Namespace[ArrowNamespace] | Namespace[PolarsNamespace]"
|
||||
|
||||
class _BasePandasLike(Sized, Protocol):
|
||||
index: Any
|
||||
"""`mypy` doesn't like the asymmetric `property` setter in `pandas`."""
|
||||
|
||||
def __getitem__(self, key: Any, /) -> Any: ...
|
||||
def __mul__(self, other: float | Collection[float] | Self) -> Self: ...
|
||||
def __floordiv__(self, other: float | Collection[float] | Self) -> Self: ...
|
||||
@property
|
||||
def loc(self) -> Any: ...
|
||||
@property
|
||||
def shape(self) -> tuple[int, ...]: ...
|
||||
def set_axis(self, labels: Any, *, axis: Any = ..., copy: bool = ...) -> Self: ...
|
||||
def copy(self, deep: bool = ...) -> Self: ... # noqa: FBT001
|
||||
def rename(self, *args: Any, inplace: Literal[False], **kwds: Any) -> Self:
|
||||
"""`inplace=False` is required to avoid (incorrect?) default overloads."""
|
||||
...
|
||||
|
||||
class _BasePandasLikeFrame(NativeDataFrame, _BasePandasLike, Protocol): ...
|
||||
|
||||
class _BasePandasLikeSeries(NativeSeries, _BasePandasLike, Protocol):
|
||||
def where(self, cond: Any, other: Any = ..., **kwds: Any) -> Any: ...
|
||||
|
||||
class _NativeDask(Protocol):
|
||||
_partition_type: type[pd.DataFrame]
|
||||
|
||||
class _CuDFDataFrame(_BasePandasLikeFrame, Protocol):
|
||||
def to_pylibcudf(self, *args: Any, **kwds: Any) -> Any: ...
|
||||
|
||||
class _CuDFSeries(_BasePandasLikeSeries, Protocol):
|
||||
def to_pylibcudf(self, *args: Any, **kwds: Any) -> Any: ...
|
||||
|
||||
class _NativeIbis(Protocol):
|
||||
def sql(self, *args: Any, **kwds: Any) -> Any: ...
|
||||
def __pyarrow_result__(self, *args: Any, **kwds: Any) -> Any: ...
|
||||
def __pandas_result__(self, *args: Any, **kwds: Any) -> Any: ...
|
||||
def __polars_result__(self, *args: Any, **kwds: Any) -> Any: ...
|
||||
|
||||
class _ModinDataFrame(_BasePandasLikeFrame, Protocol):
|
||||
_pandas_class: type[pd.DataFrame]
|
||||
|
||||
class _ModinSeries(_BasePandasLikeSeries, Protocol):
|
||||
_pandas_class: type[pd.Series[Any]]
|
||||
|
||||
_NativePolars: TypeAlias = "pl.DataFrame | pl.LazyFrame | pl.Series"
|
||||
_NativeArrow: TypeAlias = "pa.Table | pa.ChunkedArray[Any]"
|
||||
_NativeDuckDB: TypeAlias = "duckdb.DuckDBPyRelation"
|
||||
_NativePandas: TypeAlias = "pd.DataFrame | pd.Series[Any]"
|
||||
_NativeModin: TypeAlias = "_ModinDataFrame | _ModinSeries"
|
||||
_NativeCuDF: TypeAlias = "_CuDFDataFrame | _CuDFSeries"
|
||||
_NativePandasLikeSeries: TypeAlias = "pd.Series[Any] | _CuDFSeries | _ModinSeries"
|
||||
_NativePandasLikeDataFrame: TypeAlias = (
|
||||
"pd.DataFrame | _CuDFDataFrame | _ModinDataFrame"
|
||||
)
|
||||
_NativePandasLike: TypeAlias = "_NativePandasLikeDataFrame |_NativePandasLikeSeries"
|
||||
_NativeSQLFrame: TypeAlias = "SQLFrameDataFrame"
|
||||
_NativePySpark: TypeAlias = "pyspark_sql.DataFrame"
|
||||
_NativePySparkConnect: TypeAlias = "PySparkConnectDataFrame"
|
||||
_NativeSparkLike: TypeAlias = (
|
||||
"_NativeSQLFrame | _NativePySpark | _NativePySparkConnect"
|
||||
)
|
||||
|
||||
NativeKnown: TypeAlias = "_NativePolars | _NativeArrow | _NativePandasLike | _NativeSparkLike | _NativeDuckDB | _NativeDask | _NativeIbis"
|
||||
NativeUnknown: TypeAlias = "NativeDataFrame | NativeSeries | NativeLazyFrame"
|
||||
NativeAny: TypeAlias = "NativeKnown | NativeUnknown"
|
||||
|
||||
__all__ = ["Namespace"]
|
||||
|
||||
|
||||
class Namespace(Generic[CompliantNamespaceT_co]):
|
||||
_compliant_namespace: CompliantNamespaceT_co
|
||||
_version: ClassVar[Version] = Version.MAIN
|
||||
|
||||
def __init__(self, namespace: CompliantNamespaceT_co, /) -> None:
|
||||
self._compliant_namespace = namespace
|
||||
|
||||
def __init_subclass__(cls, *args: Any, version: Version, **kwds: Any) -> None:
|
||||
super().__init_subclass__(*args, **kwds)
|
||||
|
||||
if isinstance(version, Version):
|
||||
cls._version = version
|
||||
else:
|
||||
msg = f"Expected {Version} but got {type(version).__name__!r}"
|
||||
raise TypeError(msg)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Namespace[{type(self.compliant).__name__}]"
|
||||
|
||||
@property
|
||||
def compliant(self) -> CompliantNamespaceT_co:
|
||||
return self._compliant_namespace
|
||||
|
||||
@property
|
||||
def implementation(self) -> Implementation:
|
||||
return self.compliant._implementation
|
||||
|
||||
@property
|
||||
def version(self) -> Version:
|
||||
return self._version
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(cls, backend: PandasLike, /) -> Namespace[PandasLikeNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(cls, backend: Polars, /) -> Namespace[PolarsNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(cls, backend: Arrow, /) -> Namespace[ArrowNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(cls, backend: SparkLike, /) -> Namespace[SparkLikeNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(cls, backend: DuckDB, /) -> Namespace[DuckDBNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(cls, backend: Dask, /) -> Namespace[DaskNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(cls, backend: Ibis, /) -> Namespace[IbisNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(cls, backend: EagerAllowed, /) -> EagerAllowedNamespace: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_backend(
|
||||
cls, backend: IntoBackend[Backend], /
|
||||
) -> Namespace[CompliantNamespaceAny]: ...
|
||||
|
||||
@classmethod
|
||||
def from_backend(
|
||||
cls: type[Namespace[Any]], backend: IntoBackend[Backend], /
|
||||
) -> Namespace[Any]:
|
||||
"""Instantiate from native namespace module, string, or Implementation.
|
||||
|
||||
Arguments:
|
||||
backend: native namespace module, string, or Implementation.
|
||||
|
||||
Examples:
|
||||
>>> from narwhals._namespace import Namespace
|
||||
>>> Namespace.from_backend("polars")
|
||||
Namespace[PolarsNamespace]
|
||||
"""
|
||||
impl = Implementation.from_backend(backend)
|
||||
backend_version = impl._backend_version() # noqa: F841
|
||||
version = cls._version
|
||||
ns: CompliantNamespaceAny
|
||||
if impl.is_pandas_like():
|
||||
from narwhals._pandas_like.namespace import PandasLikeNamespace
|
||||
|
||||
ns = PandasLikeNamespace(implementation=impl, version=version)
|
||||
|
||||
elif impl.is_polars():
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
|
||||
ns = PolarsNamespace(version=version)
|
||||
elif impl.is_pyarrow():
|
||||
from narwhals._arrow.namespace import ArrowNamespace
|
||||
|
||||
ns = ArrowNamespace(version=version)
|
||||
elif impl.is_spark_like():
|
||||
from narwhals._spark_like.namespace import SparkLikeNamespace
|
||||
|
||||
ns = SparkLikeNamespace(implementation=impl, version=version)
|
||||
elif impl.is_duckdb():
|
||||
from narwhals._duckdb.namespace import DuckDBNamespace
|
||||
|
||||
ns = DuckDBNamespace(version=version)
|
||||
elif impl.is_dask():
|
||||
from narwhals._dask.namespace import DaskNamespace
|
||||
|
||||
ns = DaskNamespace(version=version)
|
||||
elif impl.is_ibis():
|
||||
from narwhals._ibis.namespace import IbisNamespace
|
||||
|
||||
ns = IbisNamespace(version=version)
|
||||
else:
|
||||
msg = "Not supported Implementation" # pragma: no cover
|
||||
raise AssertionError(msg)
|
||||
return cls(ns)
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls, native: _NativePolars, /
|
||||
) -> Namespace[PolarsNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls, native: _NativePandas, /
|
||||
) -> Namespace[PandasLikeNamespace[pd.DataFrame, pd.Series[Any]]]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(cls, native: _NativeArrow, /) -> Namespace[ArrowNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls, native: _NativeSparkLike, /
|
||||
) -> Namespace[SparkLikeNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls, native: _NativeDuckDB, /
|
||||
) -> Namespace[DuckDBNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(cls, native: _NativeDask, /) -> Namespace[DaskNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(cls, native: _NativeIbis, /) -> Namespace[IbisNamespace]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls, native: _NativeModin, /
|
||||
) -> Namespace[PandasLikeNamespace[_ModinDataFrame, _ModinSeries]]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls, native: _NativeCuDF, /
|
||||
) -> Namespace[PandasLikeNamespace[_CuDFDataFrame, _CuDFSeries]]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls, native: _NativePandasLike, /
|
||||
) -> Namespace[PandasLikeNamespace[Any, Any]]: ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls, native: NativeUnknown, /
|
||||
) -> Namespace[CompliantNamespaceAny]: ...
|
||||
|
||||
@classmethod
|
||||
def from_native_object(
|
||||
cls: type[Namespace[Any]], native: NativeAny, /
|
||||
) -> Namespace[Any]:
|
||||
impl: Backend
|
||||
if is_native_polars(native):
|
||||
impl = Implementation.POLARS
|
||||
elif is_native_pandas(native):
|
||||
impl = Implementation.PANDAS
|
||||
elif is_native_arrow(native):
|
||||
impl = Implementation.PYARROW
|
||||
elif is_native_spark_like(native):
|
||||
impl = (
|
||||
Implementation.SQLFRAME
|
||||
if is_native_sqlframe(native)
|
||||
else Implementation.PYSPARK_CONNECT
|
||||
if is_native_pyspark_connect(native)
|
||||
else Implementation.PYSPARK
|
||||
)
|
||||
elif is_native_dask(native): # pragma: no cover
|
||||
impl = Implementation.DASK
|
||||
elif is_native_duckdb(native):
|
||||
impl = Implementation.DUCKDB
|
||||
elif is_native_cudf(native): # pragma: no cover
|
||||
impl = Implementation.CUDF
|
||||
elif is_native_modin(native): # pragma: no cover
|
||||
impl = Implementation.MODIN
|
||||
elif is_native_ibis(native):
|
||||
impl = Implementation.IBIS
|
||||
else:
|
||||
msg = f"Unsupported type: {type(native).__qualname__!r}"
|
||||
raise TypeError(msg)
|
||||
return cls.from_backend(impl)
|
||||
|
||||
|
||||
def is_native_polars(obj: Any) -> TypeIs[_NativePolars]:
|
||||
return (pl := get_polars()) is not None and isinstance(
|
||||
obj, (pl.DataFrame, pl.Series, pl.LazyFrame)
|
||||
)
|
||||
|
||||
|
||||
def is_native_arrow(obj: Any) -> TypeIs[_NativeArrow]:
|
||||
return (pa := get_pyarrow()) is not None and isinstance(
|
||||
obj, (pa.Table, pa.ChunkedArray)
|
||||
)
|
||||
|
||||
|
||||
def is_native_dask(obj: Any) -> TypeIs[_NativeDask]:
|
||||
return is_dask_dataframe(obj)
|
||||
|
||||
|
||||
is_native_duckdb: _Guard[_NativeDuckDB] = is_duckdb_relation
|
||||
is_native_sqlframe: _Guard[_NativeSQLFrame] = is_sqlframe_dataframe
|
||||
is_native_pyspark: _Guard[_NativePySpark] = is_pyspark_dataframe
|
||||
is_native_pyspark_connect: _Guard[_NativePySparkConnect] = is_pyspark_connect_dataframe
|
||||
|
||||
|
||||
def is_native_pandas(obj: Any) -> TypeIs[_NativePandas]:
|
||||
return (pd := get_pandas()) is not None and isinstance(obj, (pd.DataFrame, pd.Series))
|
||||
|
||||
|
||||
def is_native_modin(obj: Any) -> TypeIs[_NativeModin]:
|
||||
return (mpd := get_modin()) is not None and isinstance(
|
||||
obj, (mpd.DataFrame, mpd.Series)
|
||||
) # pragma: no cover
|
||||
|
||||
|
||||
def is_native_cudf(obj: Any) -> TypeIs[_NativeCuDF]:
|
||||
return (cudf := get_cudf()) is not None and isinstance(
|
||||
obj, (cudf.DataFrame, cudf.Series)
|
||||
) # pragma: no cover
|
||||
|
||||
|
||||
def is_native_pandas_like(obj: Any) -> TypeIs[_NativePandasLike]:
|
||||
return (
|
||||
is_native_pandas(obj) or is_native_cudf(obj) or is_native_modin(obj)
|
||||
) # pragma: no cover
|
||||
|
||||
|
||||
def is_native_spark_like(obj: Any) -> TypeIs[_NativeSparkLike]:
|
||||
return (
|
||||
is_native_sqlframe(obj)
|
||||
or is_native_pyspark(obj)
|
||||
or is_native_pyspark_connect(obj)
|
||||
)
|
||||
|
||||
|
||||
def is_native_ibis(obj: Any) -> TypeIs[_NativeIbis]:
|
||||
return is_ibis_table(obj)
|
||||
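Aside (not part of the commit): the two constructors converge on the same compliant namespace, whether you start from a backend name or a native object. A usage sketch, assuming `narwhals` and `polars` are installed:

import polars as pl
from narwhals._namespace import Namespace

by_name = Namespace.from_backend("polars")
by_object = Namespace.from_native_object(pl.DataFrame({"a": [1]}))
assert type(by_name.compliant) is type(by_object.compliant)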
1175
lib/python3.11/site-packages/narwhals/_pandas_like/dataframe.py
Normal file
File diff suppressed because it is too large
345
lib/python3.11/site-packages/narwhals/_pandas_like/expr.py
Normal file
@ -0,0 +1,345 @@
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from narwhals._compliant import EagerExpr
|
||||
from narwhals._expression_parsing import evaluate_output_names_and_aliases
|
||||
from narwhals._pandas_like.group_by import PandasLikeGroupBy
|
||||
from narwhals._pandas_like.series import PandasLikeSeries
|
||||
from narwhals._utils import generate_temporary_column_name
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from collections.abc import Sequence
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
from narwhals._compliant.typing import AliasNames, EvalNames, EvalSeries, ScalarKwargs
|
||||
from narwhals._expression_parsing import ExprMetadata
|
||||
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
|
||||
from narwhals._pandas_like.namespace import PandasLikeNamespace
|
||||
from narwhals._utils import Implementation, Version, _LimitedContext
|
||||
from narwhals.typing import PythonLiteral
|
||||
|
||||
WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT = {
|
||||
"cum_sum": "cumsum",
|
||||
"cum_min": "cummin",
|
||||
"cum_max": "cummax",
|
||||
"cum_prod": "cumprod",
|
||||
# Pandas cumcount starts counting from 0 while Polars starts from 1
|
||||
# Pandas cumcount counts nulls while Polars does not
|
||||
# So, instead of using "cumcount" we use "cumsum" on notna() to get the same result
|
||||
"cum_count": "cumsum",
|
||||
"rolling_sum": "sum",
|
||||
"rolling_mean": "mean",
|
||||
"rolling_std": "std",
|
||||
"rolling_var": "var",
|
||||
"shift": "shift",
|
||||
"rank": "rank",
|
||||
"diff": "diff",
|
||||
"fill_null": "fillna",
|
||||
"quantile": "quantile",
|
||||
"ewm_mean": "mean",
|
||||
}
|
||||
|
||||
|
||||
def window_kwargs_to_pandas_equivalent(
|
||||
function_name: str, kwargs: ScalarKwargs
|
||||
) -> dict[str, PythonLiteral]:
|
||||
if function_name == "shift":
|
||||
assert "n" in kwargs # noqa: S101
|
||||
pandas_kwargs: dict[str, PythonLiteral] = {"periods": kwargs["n"]}
|
||||
elif function_name == "rank":
|
||||
assert "method" in kwargs # noqa: S101
|
||||
assert "descending" in kwargs # noqa: S101
|
||||
_method = kwargs["method"]
|
||||
pandas_kwargs = {
|
||||
"method": "first" if _method == "ordinal" else _method,
|
||||
"ascending": not kwargs["descending"],
|
||||
"na_option": "keep",
|
||||
"pct": False,
|
||||
}
|
||||
elif function_name.startswith("cum_"): # Cumulative operation
|
||||
pandas_kwargs = {"skipna": True}
|
||||
elif function_name.startswith("rolling_"): # Rolling operation
|
||||
assert "min_samples" in kwargs # noqa: S101
|
||||
assert "window_size" in kwargs # noqa: S101
|
||||
assert "center" in kwargs # noqa: S101
|
||||
pandas_kwargs = {
|
||||
"min_periods": kwargs["min_samples"],
|
||||
"window": kwargs["window_size"],
|
||||
"center": kwargs["center"],
|
||||
}
|
||||
elif function_name in {"std", "var"}:
|
||||
assert "ddof" in kwargs # noqa: S101
|
||||
pandas_kwargs = {"ddof": kwargs["ddof"]}
|
||||
elif function_name == "fill_null":
|
||||
assert "strategy" in kwargs # noqa: S101
|
||||
assert "limit" in kwargs # noqa: S101
|
||||
pandas_kwargs = {"strategy": kwargs["strategy"], "limit": kwargs["limit"]}
|
||||
elif function_name == "quantile":
|
||||
assert "quantile" in kwargs # noqa: S101
|
||||
assert "interpolation" in kwargs # noqa: S101
|
||||
pandas_kwargs = {
|
||||
"q": kwargs["quantile"],
|
||||
"interpolation": kwargs["interpolation"],
|
||||
}
|
||||
elif function_name.startswith("ewm_"):
|
||||
assert "com" in kwargs # noqa: S101
|
||||
assert "span" in kwargs # noqa: S101
|
||||
assert "half_life" in kwargs # noqa: S101
|
||||
assert "alpha" in kwargs # noqa: S101
|
||||
assert "adjust" in kwargs # noqa: S101
|
||||
assert "min_samples" in kwargs # noqa: S101
|
||||
assert "ignore_nulls" in kwargs # noqa: S101
|
||||
|
||||
pandas_kwargs = {
|
||||
"com": kwargs["com"],
|
||||
"span": kwargs["span"],
|
||||
"halflife": kwargs["half_life"],
|
||||
"alpha": kwargs["alpha"],
|
||||
"adjust": kwargs["adjust"],
|
||||
"min_periods": kwargs["min_samples"],
|
||||
"ignore_na": kwargs["ignore_nulls"],
|
||||
}
|
||||
else: # sum, len, ...
|
||||
pandas_kwargs = {}
|
||||
return pandas_kwargs
|
||||
|
||||
|
||||
class PandasLikeExpr(EagerExpr["PandasLikeDataFrame", PandasLikeSeries]):
|
||||
def __init__(
|
||||
self,
|
||||
call: EvalSeries[PandasLikeDataFrame, PandasLikeSeries],
|
||||
*,
|
||||
depth: int,
|
||||
function_name: str,
|
||||
evaluate_output_names: EvalNames[PandasLikeDataFrame],
|
||||
alias_output_names: AliasNames | None,
|
||||
implementation: Implementation,
|
||||
version: Version,
|
||||
scalar_kwargs: ScalarKwargs | None = None,
|
||||
) -> None:
|
||||
self._call = call
|
||||
self._depth = depth
|
||||
self._function_name = function_name
|
||||
self._evaluate_output_names = evaluate_output_names
|
||||
self._alias_output_names = alias_output_names
|
||||
self._implementation = implementation
|
||||
self._version = version
|
||||
self._scalar_kwargs = scalar_kwargs or {}
|
||||
self._metadata: ExprMetadata | None = None
|
||||
|
||||
def __narwhals_namespace__(self) -> PandasLikeNamespace:
|
||||
from narwhals._pandas_like.namespace import PandasLikeNamespace
|
||||
|
||||
return PandasLikeNamespace(self._implementation, version=self._version)
|
||||
|
||||
@classmethod
|
||||
def from_column_names(
|
||||
cls: type[Self],
|
||||
evaluate_column_names: EvalNames[PandasLikeDataFrame],
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
function_name: str = "",
|
||||
) -> Self:
|
||||
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
|
||||
try:
|
||||
return [
|
||||
PandasLikeSeries(
|
||||
df._native_frame[column_name],
|
||||
implementation=df._implementation,
|
||||
version=df._version,
|
||||
)
|
||||
for column_name in evaluate_column_names(df)
|
||||
]
|
||||
except KeyError as e:
|
||||
if error := df._check_columns_exist(evaluate_column_names(df)):
|
||||
raise error from e
|
||||
raise
|
||||
|
||||
return cls(
|
||||
func,
|
||||
depth=0,
|
||||
function_name=function_name,
|
||||
evaluate_output_names=evaluate_column_names,
|
||||
alias_output_names=None,
|
||||
implementation=context._implementation,
|
||||
version=context._version,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
|
||||
def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
|
||||
native = df.native
|
||||
return [
|
||||
PandasLikeSeries.from_native(native.iloc[:, i], context=df)
|
||||
for i in column_indices
|
||||
]
|
||||
|
||||
return cls(
|
||||
func,
|
||||
depth=0,
|
||||
function_name="nth",
|
||||
evaluate_output_names=cls._eval_names_indices(column_indices),
|
||||
alias_output_names=None,
|
||||
implementation=context._implementation,
|
||||
version=context._version,
|
||||
)
|
||||
|
||||
def ewm_mean(
|
||||
self,
|
||||
*,
|
||||
com: float | None,
|
||||
span: float | None,
|
||||
half_life: float | None,
|
||||
alpha: float | None,
|
||||
adjust: bool,
|
||||
min_samples: int,
|
||||
ignore_nulls: bool,
|
||||
) -> Self:
|
||||
return self._reuse_series(
|
||||
"ewm_mean",
|
||||
scalar_kwargs={
|
||||
"com": com,
|
||||
"span": span,
|
||||
"half_life": half_life,
|
||||
"alpha": alpha,
|
||||
"adjust": adjust,
|
||||
"min_samples": min_samples,
|
||||
"ignore_nulls": ignore_nulls,
|
||||
},
|
||||
)
|
||||
|
||||
def over( # noqa: C901, PLR0915
|
||||
self, partition_by: Sequence[str], order_by: Sequence[str]
|
||||
) -> Self:
|
||||
if not partition_by:
|
||||
# e.g. `nw.col('a').cum_sum().order_by(key)`
|
||||
# We can always easily support this as it doesn't require grouping.
|
||||
assert order_by # noqa: S101
|
||||
|
||||
def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:
|
||||
token = generate_temporary_column_name(8, df.columns)
|
||||
df = df.with_row_index(token, order_by=None).sort(
|
||||
*order_by, descending=False, nulls_last=False
|
||||
)
|
||||
results = self(df.drop([token], strict=True))
|
||||
sorting_indices = df.get_column(token)
|
||||
for s in results:
|
||||
s._scatter_in_place(sorting_indices, s)
|
||||
return results
|
||||
elif not self._is_elementary():
|
||||
msg = (
|
||||
"Only elementary expressions are supported for `.over` in pandas-like backends.\n\n"
|
||||
"Please see: "
|
||||
"https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/"
|
||||
                )
                raise NotImplementedError(msg)
        else:
            function_name = PandasLikeGroupBy._leaf_name(self)
            pandas_function_name = WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT.get(
                function_name, PandasLikeGroupBy._REMAP_AGGS.get(function_name)
            )
            if pandas_function_name is None:
                msg = (
                    f"Unsupported function: {function_name} in `over` context.\n\n"
                    f"Supported functions are {', '.join(WINDOW_FUNCTIONS_TO_PANDAS_EQUIVALENT)}\n"
                    f"and {', '.join(PandasLikeGroupBy._REMAP_AGGS)}."
                )
                raise NotImplementedError(msg)
            pandas_kwargs = window_kwargs_to_pandas_equivalent(
                function_name, self._scalar_kwargs
            )

            def func(df: PandasLikeDataFrame) -> Sequence[PandasLikeSeries]:  # noqa: C901, PLR0912, PLR0914, PLR0915
                output_names, aliases = evaluate_output_names_and_aliases(self, df, [])
                if function_name == "cum_count":
                    plx = self.__narwhals_namespace__()
                    df = df.with_columns(~plx.col(*output_names).is_null())

                if function_name.startswith("cum_"):
                    assert "reverse" in self._scalar_kwargs  # noqa: S101
                    reverse = self._scalar_kwargs["reverse"]
                else:
                    assert "reverse" not in self._scalar_kwargs  # noqa: S101
                    reverse = False

                if order_by:
                    columns = list(set(partition_by).union(output_names).union(order_by))
                    token = generate_temporary_column_name(8, columns)
                    df = (
                        df.simple_select(*columns)
                        .with_row_index(token, order_by=None)
                        .sort(*order_by, descending=reverse, nulls_last=reverse)
                    )
                    sorting_indices = df.get_column(token)
                elif reverse:
                    columns = list(set(partition_by).union(output_names))
                    df = df.simple_select(*columns)._gather_slice(slice(None, None, -1))
                grouped = df._native_frame.groupby(partition_by)
                if function_name.startswith("rolling"):
                    rolling = grouped[list(output_names)].rolling(**pandas_kwargs)
                    assert pandas_function_name is not None  # help mypy  # noqa: S101
                    if pandas_function_name in {"std", "var"}:
                        assert "ddof" in self._scalar_kwargs  # noqa: S101
                        res_native = getattr(rolling, pandas_function_name)(
                            ddof=self._scalar_kwargs["ddof"]
                        )
                    else:
                        res_native = getattr(rolling, pandas_function_name)()
                elif function_name.startswith("ewm"):
                    if self._implementation.is_pandas() and (
                        self._implementation._backend_version()
                    ) < (1, 2):  # pragma: no cover
                        msg = (
                            "Exponentially weighted calculation is not available in over "
                            f"context for pandas versions older than 1.2.0, found {self._implementation._backend_version()}."
                        )
                        raise NotImplementedError(msg)
                    ewm = grouped[list(output_names)].ewm(**pandas_kwargs)
                    assert pandas_function_name is not None  # help mypy  # noqa: S101
                    res_native = getattr(ewm, pandas_function_name)()
                elif function_name == "fill_null":
                    assert "strategy" in self._scalar_kwargs  # noqa: S101
                    assert "limit" in self._scalar_kwargs  # noqa: S101
                    df_grouped = grouped[list(output_names)]
                    if self._scalar_kwargs["strategy"] == "forward":
                        res_native = df_grouped.ffill(limit=self._scalar_kwargs["limit"])
                    elif self._scalar_kwargs["strategy"] == "backward":
                        res_native = df_grouped.bfill(limit=self._scalar_kwargs["limit"])
                    else:  # pragma: no cover
                        # This is deprecated in pandas. Indeed, `nw.col('a').fill_null(3).over('b')`
                        # does not seem very useful, and DuckDB doesn't support it either.
                        msg = "`fill_null` with `over` without `strategy` specified is not supported."
                        raise NotImplementedError(msg)
                elif function_name == "len":
                    if len(output_names) != 1:  # pragma: no cover
                        msg = "Safety check failed, please report a bug."
                        raise AssertionError(msg)
                    res_native = grouped.transform("size").to_frame(aliases[0])
                else:
                    res_native = grouped[list(output_names)].transform(
                        pandas_function_name, **pandas_kwargs
                    )
                result_frame = df._with_native(res_native).rename(
                    dict(zip(output_names, aliases))
                )
                results = [result_frame.get_column(name) for name in aliases]
                if order_by:
                    for s in results:
                        s._scatter_in_place(sorting_indices, s)
                    return results
                if reverse:
                    return [s._gather_slice(slice(None, None, -1)) for s in results]
                return results

        return self.__class__(
            func,
            depth=self._depth + 1,
            function_name=self._function_name + "->over",
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            implementation=self._implementation,
            version=self._version,
        )
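# For orientation, the window machinery above is what the public `over` API
# dispatches to on this backend. A minimal sketch (toy frame invented here;
# assumes pandas and narwhals are importable):
#
#     >>> import pandas as pd
#     >>> import narwhals as nw
#     >>> df = nw.from_native(pd.DataFrame({"g": ["a", "a", "b"], "x": [1, 2, 3]}))
#     >>> # `cum_sum().over("g")` takes the `cum_*` branch above.
#     >>> df.with_columns(nw.col("x").cum_sum().over("g").alias("x_cum")).to_native()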
365
lib/python3.11/site-packages/narwhals/_pandas_like/group_by.py
Normal file
@ -0,0 +1,365 @@
from __future__ import annotations

import warnings
from functools import lru_cache
from itertools import chain
from operator import methodcaller
from typing import TYPE_CHECKING, Any, ClassVar, Literal

from narwhals._compliant import EagerGroupBy
from narwhals._exceptions import issue_warning
from narwhals._expression_parsing import evaluate_output_names_and_aliases
from narwhals._utils import zip_strict
from narwhals.dependencies import is_pandas_like_dataframe

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Iterator, Mapping, Sequence

    import pandas as pd
    from pandas.api.typing import DataFrameGroupBy as _NativeGroupBy
    from typing_extensions import TypeAlias, Unpack

    from narwhals._compliant.typing import NarwhalsAggregation, ScalarKwargs
    from narwhals._pandas_like.dataframe import PandasLikeDataFrame
    from narwhals._pandas_like.expr import PandasLikeExpr

    NativeGroupBy: TypeAlias = "_NativeGroupBy[tuple[str, ...], Literal[True]]"

NativeApply: TypeAlias = "Callable[[pd.DataFrame], pd.Series[Any]]"
InefficientNativeAggregation: TypeAlias = Literal["cov", "skew"]
NativeAggregation: TypeAlias = Literal[
    "any",
    "all",
    "count",
    "first",
    "idxmax",
    "idxmin",
    "last",
    "max",
    "mean",
    "median",
    "min",
    "mode",
    "nunique",
    "prod",
    "quantile",
    "sem",
    "size",
    "std",
    "sum",
    "var",
    InefficientNativeAggregation,
]
"""https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html#built-in-aggregation-methods"""

_NativeAgg: TypeAlias = "Callable[[Any], pd.DataFrame | pd.Series[Any]]"
"""Equivalent to a partial method call on `DataFrameGroupBy`."""


NonStrHashable: TypeAlias = Any
"""Because `pandas` allows *"names"* like that 😭"""


@lru_cache(maxsize=32)
def _native_agg(name: NativeAggregation, /, **kwds: Unpack[ScalarKwargs]) -> _NativeAgg:
    if name == "nunique":
        return methodcaller(name, dropna=False)
    if not kwds or kwds.get("ddof") == 1:
        return methodcaller(name)
    return methodcaller(name, **kwds)
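# What `_native_agg` produces, concretely: `methodcaller` turns a method name
# plus kwargs into a callable. A sketch with invented data:
#
#     >>> from operator import methodcaller
#     >>> import pandas as pd
#     >>> grouped = pd.DataFrame({"g": [1, 1, 2], "x": [1.0, 2.0, 3.0]}).groupby("g")["x"]
#     >>> agg = methodcaller("std", ddof=0)  # what `_native_agg("std", ddof=0)` returns
#     >>> agg(grouped)  # equivalent to grouped.std(ddof=0)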
class AggExpr:
    """Wrapper storing the intermediate state per-`PandasLikeExpr`.

    There are a lot of edge cases to handle, so aim to evaluate as little
    as possible - and store anything that's needed twice.

    Warning:
        While a `PandasLikeExpr` can be reused - this wrapper is valid **only**
        in a single `.agg(...)` operation.
    """

    expr: PandasLikeExpr
    output_names: Sequence[str]
    aliases: Sequence[str]

    def __init__(self, expr: PandasLikeExpr) -> None:
        self.expr = expr
        self.output_names = ()
        self.aliases = ()
        self._leaf_name: NarwhalsAggregation | Any = ""

    def with_expand_names(self, group_by: PandasLikeGroupBy, /) -> AggExpr:
        """**Mutating operation**.

        Stores the results of `evaluate_output_names_and_aliases`.
        """
        df = group_by.compliant
        exclude = group_by.exclude
        self.output_names, self.aliases = evaluate_output_names_and_aliases(
            self.expr, df, exclude
        )
        return self

    def _getitem_aggs(
        self, group_by: PandasLikeGroupBy, /
    ) -> pd.DataFrame | pd.Series[Any]:
        """Evaluate the wrapped expression as a group_by operation."""
        result: pd.DataFrame | pd.Series[Any]
        names = self.output_names
        if self.is_len() and self.is_top_level_function():
            result = group_by._grouped.size()
        elif self.is_len():
            result_single = group_by._grouped.size()
            ns = group_by.compliant.__narwhals_namespace__()
            result = ns._concat_horizontal(
                [ns.from_native(result_single).alias(name).native for name in names]
            )
        elif self.is_mode():
            compliant = group_by.compliant
            if (keep := self.kwargs.get("keep")) != "any":  # pragma: no cover
                msg = (
                    f"`Expr.mode(keep='{keep}')` is not implemented in group by context for "
                    f"backend {compliant._implementation}\n\n"
                    "Hint: Use `nw.col(...).mode(keep='any')` instead."
                )
                raise NotImplementedError(msg)

            cols = list(names)
            native = compliant.native
            keys, kwargs = group_by._keys, group_by._kwargs

            # Implementation based on the following suggestion:
            # https://github.com/pandas-dev/pandas/issues/19254#issuecomment-778661578
            ns = compliant.__narwhals_namespace__()
            result = ns._concat_horizontal(
                [
                    native.groupby([*keys, col], **kwargs)
                    .size()
                    .sort_values(ascending=False)
                    .reset_index(col)
                    .groupby(keys, **kwargs)[col]
                    .head(1)
                    .sort_index()
                    for col in cols
                ]
            )
        else:
            select = names[0] if len(names) == 1 else list(names)
            result = self.native_agg()(group_by._grouped[select])
        if is_pandas_like_dataframe(result):
            result.columns = list(self.aliases)
        else:
            result.name = self.aliases[0]
        return result

    def is_len(self) -> bool:
        return self.leaf_name == "len"

    def is_mode(self) -> bool:
        return self.leaf_name == "mode"

    def is_top_level_function(self) -> bool:
        # e.g. `nw.len()`.
        return self.expr._depth == 0

    @property
    def kwargs(self) -> ScalarKwargs:
        return self.expr._scalar_kwargs

    @property
    def leaf_name(self) -> NarwhalsAggregation | Any:
        if name := self._leaf_name:
            return name
        self._leaf_name = PandasLikeGroupBy._leaf_name(self.expr)
        return self._leaf_name

    def native_agg(self) -> _NativeAgg:
        """Return a partial `DataFrameGroupBy` method, missing only `self`."""
        return _native_agg(
            PandasLikeGroupBy._remap_expr_name(self.leaf_name), **self.kwargs
        )


class PandasLikeGroupBy(
    EagerGroupBy["PandasLikeDataFrame", "PandasLikeExpr", NativeAggregation]
):
    _REMAP_AGGS: ClassVar[Mapping[NarwhalsAggregation, NativeAggregation]] = {
        "sum": "sum",
        "mean": "mean",
        "median": "median",
        "max": "max",
        "min": "min",
        "mode": "mode",
        "std": "std",
        "var": "var",
        "len": "size",
        "n_unique": "nunique",
        "count": "count",
        "quantile": "quantile",
        "all": "all",
        "any": "any",
    }
    _original_columns: tuple[str, ...]
    """Column names *prior* to any aliasing in `ParseKeysGroupBy`."""

    _keys: list[str]
    """Stores the **aliased** version of group keys from `ParseKeysGroupBy`."""

    _output_key_names: list[str]
    """Stores the **original** version of group keys."""

    _kwargs: Mapping[str, bool]
    """Stores keyword arguments for `DataFrame.groupby` other than `by`."""

    @property
    def exclude(self) -> tuple[str, ...]:
        """Group keys to ignore when expanding multi-output aggregations."""
        return self._exclude

    def __init__(
        self,
        df: PandasLikeDataFrame,
        keys: Sequence[PandasLikeExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._original_columns = tuple(df.columns)
        self._drop_null_keys = drop_null_keys
        self._compliant_frame, self._keys, self._output_key_names = self._parse_keys(
            df, keys
        )
        self._exclude: tuple[str, ...] = (*self._keys, *self._output_key_names)
        # Drop index to avoid potential collisions:
        # https://github.com/narwhals-dev/narwhals/issues/1907.
        native = self.compliant.native
        if set(native.index.names).intersection(self.compliant.columns):
            native = native.reset_index(drop=True)

        self._kwargs = {
            "sort": False,
            "as_index": True,
            "dropna": drop_null_keys,
            "observed": True,
        }
        self._grouped: NativeGroupBy = native.groupby(self._keys.copy(), **self._kwargs)

    def agg(self, *exprs: PandasLikeExpr) -> PandasLikeDataFrame:
        all_aggs_are_simple = True
        agg_exprs: list[AggExpr] = []
        for expr in exprs:
            agg_exprs.append(AggExpr(expr).with_expand_names(self))
            if not self._is_simple(expr):
                all_aggs_are_simple = False

        if all_aggs_are_simple:
            result: pd.DataFrame
            if agg_exprs:
                ns = self.compliant.__narwhals_namespace__()
                result = ns._concat_horizontal(self._getitem_aggs(agg_exprs))
            else:
                result = self.compliant.__native_namespace__().DataFrame(
                    list(self._grouped.groups), columns=self._keys
                )
        elif self.compliant.native.empty:
            raise empty_results_error()
        else:
            result = self._apply_aggs(exprs)
        # NOTE: Keep `inplace=True` to avoid making a redundant copy.
        # This may need updating, depending on https://github.com/pandas-dev/pandas/pull/51466/files
        result.reset_index(inplace=True)  # noqa: PD002
        return self._select_results(result, agg_exprs)

    def _select_results(
        self, df: pd.DataFrame, /, agg_exprs: Sequence[AggExpr]
    ) -> PandasLikeDataFrame:
        """Responsible for remapping temp column names back to original.

        See `ParseKeysGroupBy`.
        """
        new_names = chain.from_iterable(e.aliases for e in agg_exprs)
        return (
            self.compliant._with_native(df, validate_column_names=False)
            .simple_select(*self._keys, *new_names)
            .rename(dict(zip(self._keys, self._output_key_names)))
        )

    def _getitem_aggs(
        self, exprs: Iterable[AggExpr], /
    ) -> list[pd.DataFrame | pd.Series[Any]]:
        return [e._getitem_aggs(self) for e in exprs]

    def _apply_aggs(self, exprs: Iterable[PandasLikeExpr]) -> pd.DataFrame:
        """Stub issue for `include_groups` [pandas-dev/pandas-stubs#1270].

        - [User guide] mentions `include_groups` 4 times without deprecation.
        - [`DataFrameGroupBy.apply`] doc says the default value of `True` is deprecated since `2.2.0`.
        - `False` is explicitly the only *non-deprecated* option, but entirely omitted since [pandas-dev/pandas-stubs#1268].

        [pandas-dev/pandas-stubs#1270]: https://github.com/pandas-dev/pandas-stubs/issues/1270
        [User guide]: https://pandas.pydata.org/pandas-docs/stable/user_guide/groupby.html
        [`DataFrameGroupBy.apply`]: https://pandas.pydata.org/pandas-docs/stable/reference/api/pandas.core.groupby.DataFrameGroupBy.apply.html
        [pandas-dev/pandas-stubs#1268]: https://github.com/pandas-dev/pandas-stubs/pull/1268
        """
        warn_complex_group_by()
        impl = self.compliant._implementation
        func = self._apply_exprs_function(exprs)
        apply = self._grouped.apply
        if impl.is_pandas() and impl._backend_version() >= (2, 2):
            return apply(func, include_groups=False)  # type: ignore[call-overload]
        return apply(func)  # pragma: no cover

    def _apply_exprs_function(self, exprs: Iterable[PandasLikeExpr]) -> NativeApply:
        ns = self.compliant.__narwhals_namespace__()
        into_series = ns._series.from_iterable

        def fn(df: pd.DataFrame) -> pd.Series[Any]:
            compliant = self.compliant._with_native(df)
            results = (
                (keys.native.iloc[0], keys.name)
                for expr in exprs
                for keys in expr(compliant)
            )
            out_group, out_names = zip_strict(*results) if results else ([], [])
            return into_series(out_group, index=out_names, context=ns).native

        return fn

    def __iter__(self) -> Iterator[tuple[Any, PandasLikeDataFrame]]:
        with warnings.catch_warnings():
            warnings.filterwarnings(
                "ignore",
                message=".*a length 1 tuple will be returned",
                category=FutureWarning,
            )
            with_native = self.compliant._with_native
            for key, group in self._grouped:
                yield (key, with_native(group).simple_select(*self._original_columns))


def empty_results_error() -> ValueError:
    """Don't even attempt this, it's way too inconsistent across pandas versions."""
    msg = (
        "No results for group-by aggregation.\n\n"
        "Hint: you were probably trying to apply a non-elementary aggregation with a "
        "pandas-like API.\n"
        "Please rewrite your query such that group-by aggregations "
        "are elementary. For example, instead of:\n\n"
        "    df.group_by('a').agg(nw.col('b').round(2).mean())\n\n"
        "use:\n\n"
        "    df.with_columns(nw.col('b').round(2)).group_by('a').agg(nw.col('b').mean())\n\n"
    )
    return ValueError(msg)


def warn_complex_group_by() -> None:
    issue_warning(
        "Found complex group-by expression, which can't be expressed efficiently with the "
        "pandas API. If you can, please rewrite your query such that group-by aggregations "
        "are simple (e.g. mean, std, min, max, ...). \n\n"
        "Please see: "
        "https://narwhals-dev.github.io/narwhals/concepts/improve_group_by_operation/",
        UserWarning,
    )
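# End-to-end, the fast and slow paths above look like this from the public API
# (a sketch; data invented):
#
#     >>> import pandas as pd
#     >>> import narwhals as nw
#     >>> df = nw.from_native(pd.DataFrame({"a": [1, 1, 2], "b": [4.0, 5.0, 6.0]}))
#     >>> df.group_by("a").agg(nw.col("b").mean())  # elementary: hits `_REMAP_AGGS`
#     >>> # A non-elementary expression instead falls back to `_apply_aggs` and
#     >>> # triggers `warn_complex_group_by`:
#     >>> df.group_by("a").agg((nw.col("b") - nw.col("b").mean()).sum())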
441
lib/python3.11/site-packages/narwhals/_pandas_like/namespace.py
Normal file
@ -0,0 +1,441 @@
from __future__ import annotations

import operator
import warnings
from functools import reduce
from itertools import chain
from typing import TYPE_CHECKING, Any, Literal, Protocol, overload

from narwhals._compliant import CompliantThen, EagerNamespace, EagerWhen
from narwhals._expression_parsing import (
    combine_alias_output_names,
    combine_evaluate_output_names,
)
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
from narwhals._pandas_like.expr import PandasLikeExpr
from narwhals._pandas_like.selectors import PandasSelectorNamespace
from narwhals._pandas_like.series import PandasLikeSeries
from narwhals._pandas_like.typing import NativeDataFrameT, NativeSeriesT
from narwhals._pandas_like.utils import is_non_nullable_boolean
from narwhals._utils import zip_strict

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence

    from typing_extensions import TypeAlias

    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._utils import Implementation, Version
    from narwhals.typing import IntoDType, NonNestedLiteral


Incomplete: TypeAlias = Any
"""Escape hatch, but leaving a trace that this isn't ideal."""


_Vertical: TypeAlias = Literal[0]
_Horizontal: TypeAlias = Literal[1]
Axis: TypeAlias = Literal[_Vertical, _Horizontal]

VERTICAL: _Vertical = 0
HORIZONTAL: _Horizontal = 1


class PandasLikeNamespace(
    EagerNamespace[
        PandasLikeDataFrame,
        PandasLikeSeries,
        PandasLikeExpr,
        NativeDataFrameT,
        NativeSeriesT,
    ]
):
    @property
    def _dataframe(self) -> type[PandasLikeDataFrame]:
        return PandasLikeDataFrame

    @property
    def _expr(self) -> type[PandasLikeExpr]:
        return PandasLikeExpr

    @property
    def _series(self) -> type[PandasLikeSeries]:
        return PandasLikeSeries

    @property
    def selectors(self) -> PandasSelectorNamespace:
        return PandasSelectorNamespace.from_namespace(self)

    def __init__(self, implementation: Implementation, version: Version) -> None:
        self._implementation = implementation
        self._version = version

    def coalesce(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
        def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
            align = self._series._align_full_broadcast
            series = align(*(s for _expr in exprs for s in _expr(df)))
            return [
                reduce(lambda x, y: x.fill_null(y, strategy=None, limit=None), series)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="coalesce",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )
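    # A sketch of the `coalesce` reduction above, from the public API (data invented):
    #
    #     >>> import pandas as pd
    #     >>> import narwhals as nw
    #     >>> df = nw.from_native(pd.DataFrame({"a": [None, 2.0], "b": [10.0, None]}))
    #     >>> # left-to-right `fill_null`: take "a", filling its nulls from "b"
    #     >>> df.select(nw.coalesce(nw.col("a"), nw.col("b"))).to_native()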
    def lit(self, value: NonNestedLiteral, dtype: IntoDType | None) -> PandasLikeExpr:
        def _lit_pandas_series(df: PandasLikeDataFrame) -> PandasLikeSeries:
            pandas_series = self._series.from_iterable(
                data=[value],
                name="literal",
                index=df._native_frame.index[0:1],
                context=self,
            )
            if dtype:
                return pandas_series.cast(dtype)
            return pandas_series

        return PandasLikeExpr(
            lambda df: [_lit_pandas_series(df)],
            depth=0,
            function_name="lit",
            evaluate_output_names=lambda _df: ["literal"],
            alias_output_names=None,
            implementation=self._implementation,
            version=self._version,
        )

    def len(self) -> PandasLikeExpr:
        return PandasLikeExpr(
            lambda df: [
                self._series.from_iterable(
                    [len(df._native_frame)], name="len", index=[0], context=self
                )
            ],
            depth=0,
            function_name="len",
            evaluate_output_names=lambda _df: ["len"],
            alias_output_names=None,
            implementation=self._implementation,
            version=self._version,
        )

    # --- horizontal ---
    def sum_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
        def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
            align = self._series._align_full_broadcast
            it = chain.from_iterable(expr(df) for expr in exprs)
            series = align(*it)
            native_series = (s.fill_null(0, None, None) for s in series)
            return [reduce(operator.add, native_series)]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="sum_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def all_horizontal(
        self, *exprs: PandasLikeExpr, ignore_nulls: bool
    ) -> PandasLikeExpr:
        def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
            align = self._series._align_full_broadcast
            series = [s for _expr in exprs for s in _expr(df)]
            if not ignore_nulls and any(
                s.native.dtype == "object" and s.is_null().any() for s in series
            ):
                # classical NumPy boolean columns don't support missing values, so
                # only do the full scan with `is_null` if we have `object` dtype.
                msg = "Cannot use `ignore_nulls=False` in `all_horizontal` for non-nullable NumPy-backed pandas Series when nulls are present."
                raise ValueError(msg)
            it = (
                (
                    # NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
                    s if is_non_nullable_boolean(s) else s.fill_null(True, None, None)
                    for s in series
                )
                if ignore_nulls
                else iter(series)
            )
            return [reduce(operator.and_, align(*it))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="all_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def any_horizontal(
        self, *exprs: PandasLikeExpr, ignore_nulls: bool
    ) -> PandasLikeExpr:
        def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
            align = self._series._align_full_broadcast
            series = [s for _expr in exprs for s in _expr(df)]
            if not ignore_nulls and any(
                s.native.dtype == "object" and s.is_null().any() for s in series
            ):
                # classical NumPy boolean columns don't support missing values, so
                # only do the full scan with `is_null` if we have `object` dtype.
                msg = "Cannot use `ignore_nulls=False` in `any_horizontal` for non-nullable NumPy-backed pandas Series when nulls are present."
                raise ValueError(msg)
            it = (
                (
                    # NumPy-backed 'bool' dtype can't contain nulls so doesn't need filling.
                    s if is_non_nullable_boolean(s) else s.fill_null(False, None, None)
                    for s in series
                )
                if ignore_nulls
                else iter(series)
            )
            return [reduce(operator.or_, align(*it))]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="any_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )
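    # The two horizontal boolean reductions above differ only in fill value and
    # operator (`and_`/True vs `or_`/False). A sketch (data invented; `ignore_nulls`
    # is keyword-only in this vendored version):
    #
    #     >>> import pandas as pd
    #     >>> import narwhals as nw
    #     >>> df = nw.from_native(pd.DataFrame({"a": [True, False], "b": [False, False]}))
    #     >>> df.select(nw.any_horizontal(nw.col("a"), nw.col("b"), ignore_nulls=True)).to_native()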
    def mean_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
        def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            align = self._series._align_full_broadcast
            series = align(
                *(s.fill_null(0, strategy=None, limit=None) for s in expr_results)
            )
            non_na = align(*(1 - s.is_null() for s in expr_results))
            return [reduce(operator.add, series) / reduce(operator.add, non_na)]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="mean_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )
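    # The arithmetic above in plain pandas terms: null-filled row sum divided by
    # the per-row count of non-null values (data invented):
    #
    #     >>> import pandas as pd
    #     >>> a = pd.Series([1.0, None, 3.0])
    #     >>> b = pd.Series([2.0, 4.0, None])
    #     >>> (a.fillna(0) + b.fillna(0)) / ((1 - a.isna()) + (1 - b.isna()))
    #     0    1.5
    #     1    4.0
    #     2    3.0
    #     dtype: float64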
    def min_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
        def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
            it = chain.from_iterable(expr(df) for expr in exprs)
            align = self._series._align_full_broadcast
            series = align(*it)

            return [
                PandasLikeSeries(
                    self.concat(
                        (s.to_frame() for s in series), how="horizontal"
                    )._native_frame.min(axis=1),
                    implementation=self._implementation,
                    version=self._version,
                ).alias(series[0].name)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="min_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    def max_horizontal(self, *exprs: PandasLikeExpr) -> PandasLikeExpr:
        def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
            it = chain.from_iterable(expr(df) for expr in exprs)
            align = self._series._align_full_broadcast
            series = align(*it)

            return [
                PandasLikeSeries(
                    self.concat(
                        (s.to_frame() for s in series), how="horizontal"
                    ).native.max(axis=1),
                    implementation=self._implementation,
                    version=self._version,
                ).alias(series[0].name)
            ]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="max_horizontal",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )

    @property
    def _concat(self) -> _NativeConcat[NativeDataFrameT, NativeSeriesT]:
        """Concatenate pandas objects along a particular axis.

        Return the **native** equivalent of `pd.concat`.
        """
        return self._implementation.to_native_namespace().concat

    def _concat_diagonal(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFrameT:
        if self._implementation.is_pandas() and self._backend_version < (3,):
            return self._concat(dfs, axis=VERTICAL, copy=False)
        return self._concat(dfs, axis=VERTICAL)

    def _concat_horizontal(
        self, dfs: Sequence[NativeDataFrameT | NativeSeriesT], /
    ) -> NativeDataFrameT:
        if self._implementation.is_cudf():
            with warnings.catch_warnings():
                warnings.filterwarnings(
                    "ignore",
                    message="The behavior of array concatenation with empty entries is deprecated",
                    category=FutureWarning,
                )
                return self._concat(dfs, axis=HORIZONTAL)
        elif self._implementation.is_pandas() and self._backend_version < (3,):
            return self._concat(dfs, axis=HORIZONTAL, copy=False)
        return self._concat(dfs, axis=HORIZONTAL)

    def _concat_vertical(self, dfs: Sequence[NativeDataFrameT], /) -> NativeDataFrameT:
        cols_0 = dfs[0].columns
        for i, df in enumerate(dfs[1:], start=1):
            cols_current = df.columns
            if not (
                (len(cols_current) == len(cols_0)) and (cols_current == cols_0).all()
            ):
                msg = (
                    "unable to vstack, column names don't match:\n"
                    f"   - dataframe 0: {cols_0.to_list()}\n"
                    f"   - dataframe {i}: {cols_current.to_list()}\n"
                )
                raise TypeError(msg)
        if self._implementation.is_pandas() and self._backend_version < (3,):
            return self._concat(dfs, axis=VERTICAL, copy=False)
        return self._concat(dfs, axis=VERTICAL)

    def when(self, predicate: PandasLikeExpr) -> PandasWhen[NativeSeriesT]:
        return PandasWhen[NativeSeriesT].from_expr(predicate, context=self)

    def concat_str(
        self, *exprs: PandasLikeExpr, separator: str, ignore_nulls: bool
    ) -> PandasLikeExpr:
        string = self._version.dtypes.String()

        def func(df: PandasLikeDataFrame) -> list[PandasLikeSeries]:
            expr_results = [s for _expr in exprs for s in _expr(df)]
            align = self._series._align_full_broadcast
            series = align(*(s.cast(string) for s in expr_results))
            null_mask = align(*(s.is_null() for s in expr_results))

            if not ignore_nulls:
                null_mask_result = reduce(operator.or_, null_mask)
                result = reduce(lambda x, y: x + separator + y, series).zip_with(
                    ~null_mask_result, None
                )
            else:
                # NOTE: Trying to help `mypy` later
                # error: Cannot determine type of "values"  [has-type]
                values: list[PandasLikeSeries]
                init_value, *values = [
                    s.zip_with(~nm, "") for s, nm in zip_strict(series, null_mask)
                ]

                sep_array = init_value.from_iterable(
                    data=[separator] * len(init_value),
                    name="sep",
                    index=init_value.native.index,
                    context=self,
                )
                separators = (sep_array.zip_with(~nm, "") for nm in null_mask[:-1])
                result = reduce(
                    operator.add,
                    (s + v for s, v in zip_strict(separators, values)),
                    init_value,
                )

            return [result]

        return self._expr._from_callable(
            func=func,
            depth=max(x._depth for x in exprs) + 1,
            function_name="concat_str",
            evaluate_output_names=combine_evaluate_output_names(*exprs),
            alias_output_names=combine_alias_output_names(*exprs),
            context=self,
        )
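# `concat_str` null semantics from the public API (data invented): with
# `ignore_nulls=False`, handled above via `zip_with(~null_mask_result, None)`,
# any null component nulls the whole concatenated row.
#
#     >>> import pandas as pd
#     >>> import narwhals as nw
#     >>> df = nw.from_native(pd.DataFrame({"first": ["ann", None], "last": ["lee", "ng"]}))
#     >>> df.select(nw.concat_str(nw.col("first"), nw.col("last"), separator=" ")).to_native()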
class _NativeConcat(Protocol[NativeDataFrameT, NativeSeriesT]):
    @overload
    def __call__(
        self,
        objs: Iterable[NativeDataFrameT],
        *,
        axis: _Vertical,
        copy: bool | None = ...,
    ) -> NativeDataFrameT: ...
    @overload
    def __call__(
        self, objs: Iterable[NativeSeriesT], *, axis: _Vertical, copy: bool | None = ...
    ) -> NativeSeriesT: ...
    @overload
    def __call__(
        self,
        objs: Iterable[NativeDataFrameT | NativeSeriesT],
        *,
        axis: _Horizontal,
        copy: bool | None = ...,
    ) -> NativeDataFrameT: ...
    @overload
    def __call__(
        self,
        objs: Iterable[NativeDataFrameT | NativeSeriesT],
        *,
        axis: Axis,
        copy: bool | None = ...,
    ) -> NativeDataFrameT | NativeSeriesT: ...

    def __call__(
        self,
        objs: Iterable[NativeDataFrameT | NativeSeriesT],
        *,
        axis: Axis,
        copy: bool | None = None,
    ) -> NativeDataFrameT | NativeSeriesT: ...


class PandasWhen(
    EagerWhen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, NativeSeriesT]
):
    @property
    # Signature of "_then" incompatible with supertype "CompliantWhen"
    # ArrowWhen seems to follow the same pattern, but no mypy complaint there?
    def _then(self) -> type[PandasThen]:  # type: ignore[override]
        return PandasThen

    def _if_then_else(
        self,
        when: NativeSeriesT,
        then: NativeSeriesT,
        otherwise: NativeSeriesT | NonNestedLiteral,
    ) -> NativeSeriesT:
        where: Incomplete = then.where
        return where(when) if otherwise is None else where(when, otherwise)


class PandasThen(
    CompliantThen[PandasLikeDataFrame, PandasLikeSeries, PandasLikeExpr, PandasWhen],
    PandasLikeExpr,
):
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "whenthen"
@ -0,0 +1,38 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import CompliantSelector, EagerSelectorNamespace
from narwhals._pandas_like.expr import PandasLikeExpr

if TYPE_CHECKING:
    from narwhals._compliant.typing import ScalarKwargs
    from narwhals._pandas_like.dataframe import PandasLikeDataFrame  # noqa: F401
    from narwhals._pandas_like.series import PandasLikeSeries  # noqa: F401


class PandasSelectorNamespace(
    EagerSelectorNamespace["PandasLikeDataFrame", "PandasLikeSeries"]
):
    @property
    def _selector(self) -> type[PandasSelector]:
        return PandasSelector


class PandasSelector(  # type: ignore[misc]
    CompliantSelector["PandasLikeDataFrame", "PandasLikeSeries"], PandasLikeExpr
):
    _depth: int = 0
    _scalar_kwargs: ScalarKwargs = {}  # noqa: RUF012
    _function_name: str = "selector"

    def _to_expr(self) -> PandasLikeExpr:
        return PandasLikeExpr(
            self._call,
            depth=self._depth,
            function_name=self._function_name,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            implementation=self._implementation,
            version=self._version,
        )
1160
lib/python3.11/site-packages/narwhals/_pandas_like/series.py
Normal file
File diff suppressed because it is too large
@ -0,0 +1,17 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant.any_namespace import CatNamespace
from narwhals._pandas_like.utils import PandasLikeSeriesNamespace

if TYPE_CHECKING:
    from narwhals._pandas_like.series import PandasLikeSeries


class PandasLikeSeriesCatNamespace(
    PandasLikeSeriesNamespace, CatNamespace["PandasLikeSeries"]
):
    def get_categories(self) -> PandasLikeSeries:
        s = self.native
        return self.with_native(type(s)(s.cat.categories, name=s.name))
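# Behaviour sketch for `get_categories` above (data invented):
#
#     >>> import pandas as pd
#     >>> import narwhals as nw
#     >>> s = nw.from_native(pd.Series(["a", "b", "a"], dtype="category"), series_only=True)
#     >>> s.cat.get_categories().to_native()  # the distinct categories: ['a', 'b']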
290
lib/python3.11/site-packages/narwhals/_pandas_like/series_dt.py
Normal file
@ -0,0 +1,290 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from narwhals._compliant.any_namespace import DateTimeNamespace
from narwhals._constants import (
    EPOCH_YEAR,
    MS_PER_SECOND,
    NS_PER_SECOND,
    SECONDS_PER_DAY,
    US_PER_SECOND,
)
from narwhals._duration import Interval
from narwhals._pandas_like.utils import (
    ALIAS_DICT,
    UNITS_DICT,
    PandasLikeSeriesNamespace,
    calculate_timestamp_date,
    calculate_timestamp_datetime,
    get_dtype_backend,
    int_dtype_mapper,
    is_dtype_pyarrow,
)

if TYPE_CHECKING:
    from datetime import timedelta

    import pandas as pd

    from narwhals._pandas_like.series import PandasLikeSeries
    from narwhals.typing import TimeUnit


class PandasLikeSeriesDateTimeNamespace(
    PandasLikeSeriesNamespace, DateTimeNamespace["PandasLikeSeries"]
):
    def date(self) -> PandasLikeSeries:
        result = self.with_native(self.native.dt.date)
        if str(result.dtype).lower() == "object":
            msg = (
                "Accessing `date` on the default pandas backend "
                "will return a Series of type `object`."
                "\nThis differs from polars API and will prevent `.dt` chaining. "
                "Please switch to the `pyarrow` backend:"
                '\ndf.convert_dtypes(dtype_backend="pyarrow")'
            )
            raise NotImplementedError(msg)
        return result

    def year(self) -> PandasLikeSeries:
        return self.with_native(self.native.dt.year)

    def month(self) -> PandasLikeSeries:
        return self.with_native(self.native.dt.month)

    def day(self) -> PandasLikeSeries:
        return self.with_native(self.native.dt.day)

    def hour(self) -> PandasLikeSeries:
        return self.with_native(self.native.dt.hour)

    def minute(self) -> PandasLikeSeries:
        return self.with_native(self.native.dt.minute)

    def second(self) -> PandasLikeSeries:
        return self.with_native(self.native.dt.second)

    def millisecond(self) -> PandasLikeSeries:
        return self.microsecond() // 1000

    def microsecond(self) -> PandasLikeSeries:
        if self.backend_version < (3, 0, 0) and self._is_pyarrow():
            # crazy workaround for https://github.com/pandas-dev/pandas/issues/59154
            import pyarrow.compute as pc  # ignore-banned-import()

            from narwhals._arrow.utils import lit

            arr_ns = self.native.array
            arr = arr_ns.__arrow_array__()
            result_arr = pc.add(
                pc.multiply(pc.millisecond(arr), lit(1_000)), pc.microsecond(arr)
            )
            result = type(self.native)(type(arr_ns)(result_arr), name=self.native.name)
            return self.with_native(result)

        return self.with_native(self.native.dt.microsecond)

    def nanosecond(self) -> PandasLikeSeries:
        return self.microsecond() * 1_000 + self.native.dt.nanosecond

    def ordinal_day(self) -> PandasLikeSeries:
        year_start = self.native.dt.year
        result = (
            self.native.to_numpy().astype("datetime64[D]")
            - (year_start.to_numpy() - EPOCH_YEAR).astype("datetime64[Y]")
        ).astype("int32") + 1
        dtype = "Int64[pyarrow]" if self._is_pyarrow() else "int32"
        return self.with_native(
            type(self.native)(result, dtype=dtype, name=year_start.name)
        )

    def weekday(self) -> PandasLikeSeries:
        # Pandas is 0-6 while Polars is 1-7
        return self.with_native(self.native.dt.weekday) + 1

    def _is_pyarrow(self) -> bool:
        return is_dtype_pyarrow(self.native.dtype)

    def _get_total_seconds(self) -> Any:
        if hasattr(self.native.dt, "total_seconds"):
            return self.native.dt.total_seconds()
        return (  # pragma: no cover
            self.native.dt.days * SECONDS_PER_DAY
            + self.native.dt.seconds
            + (self.native.dt.microseconds / US_PER_SECOND)
            + (self.native.dt.nanoseconds / NS_PER_SECOND)
        )

    def total_minutes(self) -> PandasLikeSeries:
        s = self._get_total_seconds()
        # this calculates the sign of each series element
        s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
        s_abs = s.abs() // 60
        if ~s.isna().any():
            s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
        return self.with_native(s_abs * s_sign)
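    # The sign/floor trick shared by all the `total_*` methods above, on a plain
    # Series (values invented): floor the magnitude, then restore the sign, so
    # -90s truncates to -1 minute rather than flooring to -2.
    #
    #     >>> import pandas as pd
    #     >>> s = pd.Series([-90.0, 90.0])  # total seconds
    #     >>> s_sign = 2 * (s > 0).astype("int64") - 1
    #     >>> s.abs() // 60 * s_sign
    #     0   -1.0
    #     1    1.0
    #     dtype: float64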
    def total_seconds(self) -> PandasLikeSeries:
        s = self._get_total_seconds()
        # this calculates the sign of each series element
        s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
        s_abs = s.abs() // 1
        if ~s.isna().any():
            s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
        return self.with_native(s_abs * s_sign)

    def total_milliseconds(self) -> PandasLikeSeries:
        s = self._get_total_seconds() * MS_PER_SECOND
        # this calculates the sign of each series element
        s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
        s_abs = s.abs() // 1
        if ~s.isna().any():
            s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
        return self.with_native(s_abs * s_sign)

    def total_microseconds(self) -> PandasLikeSeries:
        s = self._get_total_seconds() * US_PER_SECOND
        # this calculates the sign of each series element
        s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
        s_abs = s.abs() // 1
        if ~s.isna().any():
            s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
        return self.with_native(s_abs * s_sign)

    def total_nanoseconds(self) -> PandasLikeSeries:
        s = self._get_total_seconds() * NS_PER_SECOND
        # this calculates the sign of each series element
        s_sign = 2 * (s > 0).astype(int_dtype_mapper(s.dtype)) - 1
        s_abs = s.abs() // 1
        if ~s.isna().any():
            s_abs = s_abs.astype(int_dtype_mapper(s.dtype))
        return self.with_native(s_abs * s_sign)

    def to_string(self, format: str) -> PandasLikeSeries:
        # Polars' parser treats `'%.f'` as pandas does `'.%f'`
        # PyArrow interprets `'%S'` as "seconds, plus fractional seconds"
        # and doesn't support `%f`
        if not self._is_pyarrow():
            format = format.replace("%S%.f", "%S.%f")
        else:
            format = format.replace("%S.%f", "%S").replace("%S%.f", "%S")
        return self.with_native(self.native.dt.strftime(format))

    def replace_time_zone(self, time_zone: str | None) -> PandasLikeSeries:
        de_zone = self.native.dt.tz_localize(None)
        result = de_zone.dt.tz_localize(time_zone) if time_zone is not None else de_zone
        return self.with_native(result)

    def convert_time_zone(self, time_zone: str) -> PandasLikeSeries:
        if self.compliant.dtype.time_zone is None:  # type: ignore[attr-defined]
            result = self.native.dt.tz_localize("UTC").dt.tz_convert(time_zone)
        else:
            result = self.native.dt.tz_convert(time_zone)
        return self.with_native(result)

    def timestamp(self, time_unit: TimeUnit) -> PandasLikeSeries:
        s = self.native
        dtype = self.compliant.dtype
        mask_na = s.isna()
        dtypes = self.version.dtypes
        if dtype == dtypes.Date:
            # Date is only supported in pandas dtypes if pyarrow-backed
            s_cast = s.astype("Int32[pyarrow]")
            result = calculate_timestamp_date(s_cast, time_unit)
        elif isinstance(dtype, dtypes.Datetime):
            fn = (
                s.view
                if (self.implementation.is_pandas() and self.backend_version < (2,))
                else s.astype
            )
            s_cast = fn("Int64[pyarrow]") if self._is_pyarrow() else fn("int64")
            result = calculate_timestamp_datetime(s_cast, dtype.time_unit, time_unit)
        else:
            msg = "Input should be either of Date or Datetime type"
            raise TypeError(msg)
        result[mask_na] = None
        return self.with_native(result)

    def truncate(self, every: str) -> PandasLikeSeries:
        interval = Interval.parse(every)
        multiple, unit = interval.multiple, interval.unit
        native = self.native
        if self.implementation.is_cudf():
            if multiple != 1:
                msg = f"Only multiple `1` is supported for cuDF, got: {multiple}."
                raise NotImplementedError(msg)
            return self.with_native(self.native.dt.floor(ALIAS_DICT.get(unit, unit)))
        dtype_backend = get_dtype_backend(native.dtype, self.compliant._implementation)
        if unit in {"mo", "q", "y"}:
            if self.implementation.is_cudf():
                msg = f"Truncating to {unit} is not supported yet for cuDF."
                raise NotImplementedError(msg)
            if dtype_backend == "pyarrow":
                import pyarrow.compute as pc  # ignore-banned-import

                ca = native.array._pa_array
                result_arr = pc.floor_temporal(ca, multiple, UNITS_DICT[unit])
            else:
                if unit == "q":
                    multiple *= 3
                    np_unit = "M"
                elif unit == "mo":
                    np_unit = "M"
                else:
                    np_unit = "Y"
                arr = native.values  # noqa: PD011
                arr_dtype = arr.dtype
                result_arr = arr.astype(f"datetime64[{multiple}{np_unit}]").astype(
                    arr_dtype
                )
            result_native = type(native)(
                result_arr, dtype=native.dtype, index=native.index, name=native.name
            )
            return self.with_native(result_native)
        return self.with_native(
            self.native.dt.floor(f"{multiple}{ALIAS_DICT.get(unit, unit)}")
        )
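    # The NumPy cast used in `truncate` above: casting to a coarser datetime64
    # unit floors, and casting back restores the resolution (dates invented):
    #
    #     >>> import numpy as np
    #     >>> arr = np.array(["2024-05-17", "2024-11-02"], dtype="datetime64[ns]")
    #     >>> arr.astype("datetime64[3M]").astype(arr.dtype)  # quarter starts
    #     array(['2024-04-01T00:00:00.000000000', '2024-10-01T00:00:00.000000000'],
    #           dtype='datetime64[ns]')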
    def offset_by(self, by: str) -> PandasLikeSeries:
        native = self.native
        pdx = self.compliant.__native_namespace__()
        if self._is_pyarrow():
            import pyarrow as pa  # ignore-banned-import

            compliant = self.compliant
            ca = pa.chunked_array([compliant.to_arrow()])  # type: ignore[arg-type]
            result = (
                compliant._version.namespace.from_backend("pyarrow")
                .compliant.from_native(ca)
                .dt.offset_by(by)
                .native
            )
            result_pd = native.__class__(
                result, dtype=native.dtype, index=native.index, name=native.name
            )
        else:
            interval = Interval.parse_no_constraints(by)
            multiple, unit = interval.multiple, interval.unit
            if unit == "q":
                multiple *= 3
                unit = "mo"
            offset: pd.DateOffset | timedelta
            if unit == "y":
                offset = pdx.DateOffset(years=multiple)
            elif unit == "mo":
                offset = pdx.DateOffset(months=multiple)
            elif unit == "ns":
                offset = pdx.Timedelta(multiple, unit=UNITS_DICT[unit])
            else:
                offset = interval.to_timedelta()
            dtype = self.compliant.dtype
            datetime_dtype = self.version.dtypes.Datetime
            if unit == "d" and isinstance(dtype, datetime_dtype) and dtype.time_zone:
                native_without_timezone = native.dt.tz_localize(None)
                result_pd = native_without_timezone + offset
                result_pd = result_pd.dt.tz_localize(dtype.time_zone)
            else:
                result_pd = native + offset

        return self.with_native(result_pd)
@ -0,0 +1,42 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant.any_namespace import ListNamespace
from narwhals._pandas_like.utils import (
    PandasLikeSeriesNamespace,
    get_dtype_backend,
    narwhals_to_native_dtype,
)
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    from narwhals._pandas_like.series import PandasLikeSeries


class PandasLikeSeriesListNamespace(
    PandasLikeSeriesNamespace, ListNamespace["PandasLikeSeries"]
):
    def len(self) -> PandasLikeSeries:
        result = self.native.list.len()
        implementation = self.implementation
        backend_version = self.backend_version
        if implementation.is_pandas() and backend_version < (3, 0):  # pragma: no cover
            # `result` is a new object so it's safe to do this inplace.
            result.index = self.native.index
        dtype = narwhals_to_native_dtype(
            self.version.dtypes.UInt32(),
            get_dtype_backend(result.dtype, implementation),
            implementation,
            self.version,
        )
        return self.with_native(result.astype(dtype)).alias(self.native.name)

    unique = not_implemented()

    contains = not_implemented()

    def get(self, index: int) -> PandasLikeSeries:
        result = self.native.list[index]
        result.name = self.native.name
        return self.with_native(result)
@ -0,0 +1,92 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any

from narwhals._compliant.any_namespace import StringNamespace
from narwhals._pandas_like.utils import PandasLikeSeriesNamespace, is_dtype_pyarrow

if TYPE_CHECKING:
    from narwhals._pandas_like.series import PandasLikeSeries


class PandasLikeSeriesStringNamespace(
    PandasLikeSeriesNamespace, StringNamespace["PandasLikeSeries"]
):
    def len_chars(self) -> PandasLikeSeries:
        return self.with_native(self.native.str.len())

    def replace(
        self, pattern: str, value: str, *, literal: bool, n: int
    ) -> PandasLikeSeries:
        try:
            series = self.native.str.replace(
                pat=pattern, repl=value, n=n, regex=not literal
            )
        except TypeError as e:
            if not isinstance(value, str):
                msg = f"{self.compliant._implementation} backed `.str.replace` only supports str replacement values"
                raise TypeError(msg) from e
            raise
        return self.with_native(series)

    def replace_all(self, pattern: str, value: str, *, literal: bool) -> PandasLikeSeries:
        return self.replace(pattern, value, literal=literal, n=-1)

    def strip_chars(self, characters: str | None) -> PandasLikeSeries:
        return self.with_native(self.native.str.strip(characters))

    def starts_with(self, prefix: str) -> PandasLikeSeries:
        return self.with_native(self.native.str.startswith(prefix))

    def ends_with(self, suffix: str) -> PandasLikeSeries:
        return self.with_native(self.native.str.endswith(suffix))

    def contains(self, pattern: str, *, literal: bool) -> PandasLikeSeries:
        return self.with_native(self.native.str.contains(pat=pattern, regex=not literal))

    def slice(self, offset: int, length: int | None) -> PandasLikeSeries:
        stop = offset + length if length else None
        return self.with_native(self.native.str.slice(start=offset, stop=stop))

    def split(self, by: str) -> PandasLikeSeries:
        implementation = self.implementation
        if not implementation.is_cudf() and not is_dtype_pyarrow(self.native.dtype):
            msg = (
                "This operation requires a pyarrow-backed series. "
                "Please refer to https://narwhals-dev.github.io/narwhals/api-reference/narwhals/#narwhals.maybe_convert_dtypes "
                "and ensure you are using dtype_backend='pyarrow'. "
                "Additionally, make sure you have pandas version 1.5+ and pyarrow installed. "
            )
            raise TypeError(msg)
        return self.with_native(self.native.str.split(pat=by))

    def to_datetime(self, format: str | None) -> PandasLikeSeries:
        # If we know inputs are timezone-aware, we can pass `utc=True` for better performance.
        if format and any(x in format for x in ("%z", "Z")):
            return self.with_native(self._to_datetime(format, utc=True))
        result = self.with_native(self._to_datetime(format, utc=False))
        if (tz := getattr(result.dtype, "time_zone", None)) and tz != "UTC":
            return result.dt.convert_time_zone("UTC")
        return result

    def _to_datetime(self, format: str | None, *, utc: bool) -> Any:
        result = self.implementation.to_native_namespace().to_datetime(
            self.native, format=format, utc=utc
        )
        return (
            result.convert_dtypes(dtype_backend="pyarrow")
            if is_dtype_pyarrow(self.native.dtype)
            else result
        )

    def to_date(self, format: str | None) -> PandasLikeSeries:
        return self.to_datetime(format=format).dt.date()

    def to_uppercase(self) -> PandasLikeSeries:
        return self.with_native(self.native.str.upper())

    def to_lowercase(self) -> PandasLikeSeries:
        return self.with_native(self.native.str.lower())

    def zfill(self, width: int) -> PandasLikeSeries:
        return self.with_native(self.native.str.zfill(width))
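# A couple of the string methods above, exercised via the public API (data invented):
#
#     >>> import pandas as pd
#     >>> import narwhals as nw
#     >>> s = nw.from_native(pd.Series(["narwhals", "pandas"]), series_only=True)
#     >>> s.str.slice(0, 4).to_native()  # maps to `str.slice(start=0, stop=4)` above
#     >>> s.str.starts_with("na").to_native()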
@ -0,0 +1,16 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant.any_namespace import StructNamespace
from narwhals._pandas_like.utils import PandasLikeSeriesNamespace

if TYPE_CHECKING:
    from narwhals._pandas_like.series import PandasLikeSeries


class PandasLikeSeriesStructNamespace(
    PandasLikeSeriesNamespace, StructNamespace["PandasLikeSeries"]
):
    def field(self, name: str) -> PandasLikeSeries:
        return self.with_native(self.native.struct.field(name)).alias(name)
43
lib/python3.11/site-packages/narwhals/_pandas_like/typing.py
Normal file
@ -0,0 +1,43 @@
from __future__ import annotations  # pragma: no cover

from typing import TYPE_CHECKING  # pragma: no cover

from narwhals._typing_compat import TypeVar

if TYPE_CHECKING:
    from typing import Any

    import pandas as pd
    from typing_extensions import TypeAlias

    from narwhals._namespace import (
        _CuDFDataFrame,
        _CuDFSeries,
        _ModinDataFrame,
        _ModinSeries,
        _NativePandasLikeDataFrame,
    )
    from narwhals._pandas_like.expr import PandasLikeExpr
    from narwhals._pandas_like.series import PandasLikeSeries

    IntoPandasLikeExpr: TypeAlias = "PandasLikeExpr | PandasLikeSeries"

NativeSeriesT = TypeVar(
    "NativeSeriesT",
    "pd.Series[Any]",
    "_CuDFSeries",
    "_ModinSeries",
    default="pd.Series[Any]",
)
NativeDataFrameT = TypeVar(
    "NativeDataFrameT", bound="_NativePandasLikeDataFrame", default="pd.DataFrame"
)
NativeNDFrameT = TypeVar(
    "NativeNDFrameT",
    "pd.DataFrame",
    "pd.Series[Any]",
    "_CuDFDataFrame",
    "_CuDFSeries",
    "_ModinDataFrame",
    "_ModinSeries",
)
668
lib/python3.11/site-packages/narwhals/_pandas_like/utils.py
Normal file
@ -0,0 +1,668 @@
from __future__ import annotations

import functools
import operator
import re
from typing import TYPE_CHECKING, Any, Callable, Literal, TypeVar

import pandas as pd

from narwhals._compliant import EagerSeriesNamespace
from narwhals._constants import (
    MS_PER_SECOND,
    NS_PER_MICROSECOND,
    NS_PER_MILLISECOND,
    NS_PER_SECOND,
    SECONDS_PER_DAY,
    US_PER_SECOND,
)
from narwhals._utils import (
    Implementation,
    Version,
    _DeferredIterable,
    check_columns_exist,
    isinstance_or_issubclass,
)
from narwhals.exceptions import ShapeError

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping
    from types import ModuleType

    from pandas._typing import Dtype as PandasDtype
    from pandas.core.dtypes.dtypes import BaseMaskedDtype
    from typing_extensions import TypeAlias, TypeIs

    from narwhals._duration import IntervalUnit
    from narwhals._pandas_like.expr import PandasLikeExpr
    from narwhals._pandas_like.series import PandasLikeSeries
    from narwhals._pandas_like.typing import (
        NativeDataFrameT,
        NativeNDFrameT,
        NativeSeriesT,
    )
    from narwhals.dtypes import DType
    from narwhals.typing import DTypeBackend, IntoDType, TimeUnit, _1DArray

    ExprT = TypeVar("ExprT", bound=PandasLikeExpr)
    UnitCurrent: TypeAlias = TimeUnit
    UnitTarget: TypeAlias = TimeUnit
    BinOpBroadcast: TypeAlias = Callable[[Any, int], Any]
    IntoRhs: TypeAlias = int


PANDAS_LIKE_IMPLEMENTATION = {
    Implementation.PANDAS,
    Implementation.CUDF,
    Implementation.MODIN,
}
PD_DATETIME_RGX = r"""^
    datetime64\[
        (?P<time_unit>s|ms|us|ns)          # Match time unit: s, ms, us, or ns
        (?:,                               # Begin non-capturing group for optional timezone
            \s*                            # Optional whitespace after comma
            (?P<time_zone>                 # Start named group for timezone
                [a-zA-Z\/]+                # Match timezone name, e.g., UTC, America/New_York
                (?:[+-]\d{2}:\d{2})?       # Optional offset in format +HH:MM or -HH:MM
                |                          # OR
                pytz\.FixedOffset\(\d+\)   # Match pytz.FixedOffset with integer offset in parentheses
            )                              # End time_zone group
        )?                                 # End optional timezone group
    \]                                     # Closing bracket for datetime64
$"""
PATTERN_PD_DATETIME = re.compile(PD_DATETIME_RGX, re.VERBOSE)
PA_DATETIME_RGX = r"""^
    timestamp\[
        (?P<time_unit>s|ms|us|ns)          # Match time unit: s, ms, us, or ns
        (?:,                               # Begin non-capturing group for optional timezone
            \s?tz=                         # Match "tz=" prefix
            (?P<time_zone>                 # Start named group for timezone
                [a-zA-Z\/]*                # Match timezone name (e.g., UTC, America/New_York)
                (?:                        # Begin optional non-capturing group for offset
                    [+-]\d{2}:\d{2}        # Match offset in format +HH:MM or -HH:MM
                )?                         # End optional offset group
            )                              # End time_zone group
        )?                                 # End optional timezone group
    \]                                     # Closing bracket for timestamp
    \[pyarrow\]                            # Literal string "[pyarrow]"
$"""
PATTERN_PA_DATETIME = re.compile(PA_DATETIME_RGX, re.VERBOSE)
PD_DURATION_RGX = r"""^
    timedelta64\[
        (?P<time_unit>s|ms|us|ns)          # Match time unit: s, ms, us, or ns
    \]                                     # Closing bracket for timedelta64
$"""

PATTERN_PD_DURATION = re.compile(PD_DURATION_RGX, re.VERBOSE)
PA_DURATION_RGX = r"""^
    duration\[
        (?P<time_unit>s|ms|us|ns)          # Match time unit: s, ms, us, or ns
    \]                                     # Closing bracket for duration
    \[pyarrow\]                            # Literal string "[pyarrow]"
$"""
PATTERN_PA_DURATION = re.compile(PA_DURATION_RGX, re.VERBOSE)
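# Quick sanity checks for the dtype-string patterns above (dtype strings invented
# for illustration):
#
#     >>> PATTERN_PD_DATETIME.match("datetime64[ns, Europe/Amsterdam]") is not None
#     True
#     >>> PATTERN_PA_DATETIME.match("timestamp[us, tz=UTC][pyarrow]") is not None
#     True
#     >>> PATTERN_PD_DURATION.match("timedelta64[ms]") is not None
#     True
#     >>> PATTERN_PD_DATETIME.match("int64") is None
#     True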
NativeIntervalUnit: TypeAlias = Literal[
|
||||
"year",
|
||||
"quarter",
|
||||
"month",
|
||||
"week",
|
||||
"day",
|
||||
"hour",
|
||||
"minute",
|
||||
"second",
|
||||
"millisecond",
|
||||
"microsecond",
|
||||
"nanosecond",
|
||||
]
|
||||
ALIAS_DICT = {"d": "D", "m": "min"}
|
||||
UNITS_DICT: Mapping[IntervalUnit, NativeIntervalUnit] = {
|
||||
"y": "year",
|
||||
"q": "quarter",
|
||||
"mo": "month",
|
||||
"d": "day",
|
||||
"h": "hour",
|
||||
"m": "minute",
|
||||
"s": "second",
|
||||
"ms": "millisecond",
|
||||
"us": "microsecond",
|
||||
"ns": "nanosecond",
|
||||
}
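# Illustrative lookup: Narwhals interval tokens map onto the native names, e.g.
# UNITS_DICT["mo"] == "month", while ALIAS_DICT rewrites "d" -> "D" and
# "m" -> "min" for pandas frequency aliases.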

PANDAS_VERSION = Implementation.PANDAS._backend_version()
"""Static backend version for `pandas`.

Always available if we reached here, due to a module-level import.
"""


def is_pandas_or_modin(implementation: Implementation) -> bool:
    return implementation in {Implementation.PANDAS, Implementation.MODIN}


def align_and_extract_native(
    lhs: PandasLikeSeries, rhs: PandasLikeSeries | object
) -> tuple[pd.Series[Any] | object, pd.Series[Any] | object]:
    """Validate RHS of binary operation.

    If the comparison isn't supported, return `NotImplemented` so that the
    "right-hand-side" operation (e.g. `__radd__`) can be tried.
    """
    from narwhals._pandas_like.series import PandasLikeSeries

    lhs_index = lhs.native.index

    if lhs._broadcast and isinstance(rhs, PandasLikeSeries) and not rhs._broadcast:
        return lhs.native.iloc[0], rhs.native

    if isinstance(rhs, PandasLikeSeries):
        if rhs._broadcast:
            return (lhs.native, rhs.native.iloc[0])
        if rhs.native.index is not lhs_index:
            return (
                lhs.native,
                set_index(rhs.native, lhs_index, implementation=rhs._implementation),
            )
        return (lhs.native, rhs.native)

    if isinstance(rhs, list):
        msg = "Expected Series or scalar, got list."
        raise TypeError(msg)
    # `rhs` must be scalar, so just leave it as-is
    return lhs.native, rhs
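# A minimal sketch of the broadcast rule above (variable names are hypothetical):
# an aggregation is kept as a length-1 series flagged `_broadcast`; paired with a
# regular column it is collapsed to a scalar so pandas aligns on the column's index:
#
#     lhs_native, rhs_native = align_and_extract_native(total, column)
#     # -> (scalar extracted from `total`, `column.native` unchanged)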


def set_index(
    obj: NativeNDFrameT, index: Any, *, implementation: Implementation
) -> NativeNDFrameT:
    """Wrapper around pandas' set_axis to set object index.

    We can set `copy` / `inplace` based on implementation/version.
    """
    if isinstance(index, implementation.to_native_namespace().Index) and (
        expected_len := len(index)
    ) != (actual_len := len(obj)):
        msg = f"Expected object of length {expected_len}, got length: {actual_len}"
        raise ShapeError(msg)
    if implementation is Implementation.CUDF:
        obj = obj.copy(deep=False)
        obj.index = index
        return obj
    if implementation is Implementation.PANDAS and (
        (1, 5) <= implementation._backend_version() < (3,)
    ):  # pragma: no cover
        return obj.set_axis(index, axis=0, copy=False)
    return obj.set_axis(index, axis=0)  # pragma: no cover


def rename(
    obj: NativeNDFrameT, *args: Any, implementation: Implementation, **kwargs: Any
) -> NativeNDFrameT:
    """Wrapper around pandas' rename so that we can set `copy` based on implementation/version."""
    if implementation is Implementation.PANDAS and (
        implementation._backend_version() >= (3,)
    ):  # pragma: no cover
        return obj.rename(*args, **kwargs, inplace=False)
    return obj.rename(*args, **kwargs, copy=False, inplace=False)
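# Usage sketch (illustrative; assumes a 3-row native pandas object):
#
#     reindexed = set_index(df, pd.RangeIndex(3), implementation=Implementation.PANDAS)
#
# Passing a native Index whose length differs from `len(obj)` raises `ShapeError`
# rather than letting pandas misalign silently.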


@functools.lru_cache(maxsize=16)
def non_object_native_to_narwhals_dtype(native_dtype: Any, version: Version) -> DType:  # noqa: C901, PLR0912
    dtype = str(native_dtype)

    dtypes = version.dtypes
    if dtype in {"int64", "Int64", "Int64[pyarrow]", "int64[pyarrow]"}:
        return dtypes.Int64()
    if dtype in {"int32", "Int32", "Int32[pyarrow]", "int32[pyarrow]"}:
        return dtypes.Int32()
    if dtype in {"int16", "Int16", "Int16[pyarrow]", "int16[pyarrow]"}:
        return dtypes.Int16()
    if dtype in {"int8", "Int8", "Int8[pyarrow]", "int8[pyarrow]"}:
        return dtypes.Int8()
    if dtype in {"uint64", "UInt64", "UInt64[pyarrow]", "uint64[pyarrow]"}:
        return dtypes.UInt64()
    if dtype in {"uint32", "UInt32", "UInt32[pyarrow]", "uint32[pyarrow]"}:
        return dtypes.UInt32()
    if dtype in {"uint16", "UInt16", "UInt16[pyarrow]", "uint16[pyarrow]"}:
        return dtypes.UInt16()
    if dtype in {"uint8", "UInt8", "UInt8[pyarrow]", "uint8[pyarrow]"}:
        return dtypes.UInt8()
    if dtype in {
        "float64",
        "Float64",
        "Float64[pyarrow]",
        "float64[pyarrow]",
        "double[pyarrow]",
    }:
        return dtypes.Float64()
    if dtype in {
        "float32",
        "Float32",
        "Float32[pyarrow]",
        "float32[pyarrow]",
        "float[pyarrow]",
    }:
        return dtypes.Float32()
    if dtype in {
        # "there is no problem which can't be solved by adding an extra string type" pandas
        "string",
        "string[python]",
        "string[pyarrow]",
        "string[pyarrow_numpy]",
        "large_string[pyarrow]",
        "str",
    }:
        return dtypes.String()
    if dtype in {"bool", "boolean", "boolean[pyarrow]", "bool[pyarrow]"}:
        return dtypes.Boolean()
    if dtype.startswith("dictionary<"):
        return dtypes.Categorical()
    if dtype == "category":
        return native_categorical_to_narwhals_dtype(native_dtype, version)
    if (match_ := PATTERN_PD_DATETIME.match(dtype)) or (
        match_ := PATTERN_PA_DATETIME.match(dtype)
    ):
        dt_time_unit: TimeUnit = match_.group("time_unit")  # type: ignore[assignment]
        dt_time_zone: str | None = match_.group("time_zone")
        return dtypes.Datetime(dt_time_unit, dt_time_zone)
    if (match_ := PATTERN_PD_DURATION.match(dtype)) or (
        match_ := PATTERN_PA_DURATION.match(dtype)
    ):
        du_time_unit: TimeUnit = match_.group("time_unit")  # type: ignore[assignment]
        return dtypes.Duration(du_time_unit)
    if dtype == "date32[day][pyarrow]":
        return dtypes.Date()
    if dtype.startswith("decimal") and dtype.endswith("[pyarrow]"):
        return dtypes.Decimal()
    if dtype.startswith("time") and dtype.endswith("[pyarrow]"):
        return dtypes.Time()
    if dtype.startswith("binary") and dtype.endswith("[pyarrow]"):
        return dtypes.Binary()
    return dtypes.Unknown()  # pragma: no cover
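# Illustrative routing (string forms of native dtypes -> Narwhals dtypes):
#
#     non_object_native_to_narwhals_dtype(pd.Series([1]).dtype, Version.MAIN)   # -> Int64
#     non_object_native_to_narwhals_dtype("datetime64[ms, UTC]", Version.MAIN)  # -> Datetime("ms", "UTC")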


def object_native_to_narwhals_dtype(
    series: PandasLikeSeries | None, version: Version, implementation: Implementation
) -> DType:
    dtypes = version.dtypes
    if implementation is Implementation.CUDF:
        # Per conversations with their maintainers, they don't support arbitrary
        # objects, so we can just return String.
        return dtypes.String()

    infer = pd.api.types.infer_dtype
    # Arbitrary limit of 100 elements to use to sniff dtype.
    inferred_dtype = "empty" if series is None else infer(series.head(100), skipna=True)
    if inferred_dtype == "string":
        return dtypes.String()
    if inferred_dtype == "empty" and version is not Version.V1:
        # Default to String for empty Series.
        return dtypes.String()
    if inferred_dtype == "empty":
        # But preserve returning Object in V1.
        return dtypes.Object()
    return dtypes.Object()


def native_categorical_to_narwhals_dtype(
    native_dtype: pd.CategoricalDtype,
    version: Version,
    implementation: Literal[Implementation.CUDF] | None = None,
) -> DType:
    dtypes = version.dtypes
    if version is Version.V1:
        return dtypes.Categorical()
    if native_dtype.ordered:
        into_iter = (
            _cudf_categorical_to_list(native_dtype)
            if implementation is Implementation.CUDF
            else native_dtype.categories.to_list
        )
        return dtypes.Enum(_DeferredIterable(into_iter))
    return dtypes.Categorical()


def _cudf_categorical_to_list(
    native_dtype: Any,
) -> Callable[[], list[Any]]:  # pragma: no cover
    # NOTE: https://docs.rapids.ai/api/cudf/stable/user_guide/api_docs/api/cudf.core.dtypes.categoricaldtype/#cudf.core.dtypes.CategoricalDtype
    def fn() -> list[Any]:
        return native_dtype.categories.to_arrow().to_pylist()

    return fn


def native_to_narwhals_dtype(
    native_dtype: Any,
    version: Version,
    implementation: Implementation,
    *,
    allow_object: bool = False,
) -> DType:
    str_dtype = str(native_dtype)

    if str_dtype.startswith(("large_list", "list", "struct", "fixed_size_list")):
        from narwhals._arrow.utils import (
            native_to_narwhals_dtype as arrow_native_to_narwhals_dtype,
        )

        if hasattr(native_dtype, "to_arrow"):  # pragma: no cover
            # cudf, cudf.pandas
            return arrow_native_to_narwhals_dtype(native_dtype.to_arrow(), version)
        return arrow_native_to_narwhals_dtype(native_dtype.pyarrow_dtype, version)
    if str_dtype == "category" and implementation.is_cudf():
        # https://github.com/rapidsai/cudf/issues/18536
        # https://github.com/rapidsai/cudf/issues/14027
        return native_categorical_to_narwhals_dtype(
            native_dtype, version, Implementation.CUDF
        )
    if str_dtype != "object":
        return non_object_native_to_narwhals_dtype(native_dtype, version)
    if implementation is Implementation.DASK:
        # Per conversations with their maintainers, they don't support arbitrary
        # objects, so we can just return String.
        return version.dtypes.String()
    if allow_object:
        return object_native_to_narwhals_dtype(None, version, implementation)
    msg = (
        "Unreachable code, object dtype should be handled separately"  # pragma: no cover
    )
    raise AssertionError(msg)


if Implementation.PANDAS._backend_version() >= (1, 2):

    def is_dtype_numpy_nullable(dtype: Any) -> TypeIs[BaseMaskedDtype]:
        """Return `True` if `dtype` is `"numpy_nullable"`."""
        # NOTE: We need a sentinel as the positive case is `BaseMaskedDtype.base = None`
        # See https://github.com/narwhals-dev/narwhals/pull/2740#discussion_r2171667055
        sentinel = object()
        return (
            isinstance(dtype, pd.api.extensions.ExtensionDtype)
            and getattr(dtype, "base", sentinel) is None
        )

else:  # pragma: no cover

    def is_dtype_numpy_nullable(dtype: Any) -> TypeIs[BaseMaskedDtype]:
        # NOTE: `base` attribute was added between 1.1-1.2
        # Checking by isinstance requires using an import path that is no longer valid
        # `1.1`: https://github.com/pandas-dev/pandas/blob/b5958ee1999e9aead1938c0bba2b674378807b3d/pandas/core/arrays/masked.py#L37
        # `1.2`: https://github.com/pandas-dev/pandas/blob/7c48ff4409c622c582c56a5702373f726de08e96/pandas/core/arrays/masked.py#L41
        # `1.5`: https://github.com/pandas-dev/pandas/blob/35b0d1dcadf9d60722c055ee37442dc76a29e64c/pandas/core/dtypes/dtypes.py#L1609
        if isinstance(dtype, pd.api.extensions.ExtensionDtype):
            from pandas.core.arrays.masked import (  # type: ignore[attr-defined]
                BaseMaskedDtype as OldBaseMaskedDtype,  # pyright: ignore[reportAttributeAccessIssue]
            )

            return isinstance(dtype, OldBaseMaskedDtype)
        return False


def get_dtype_backend(dtype: Any, implementation: Implementation) -> DTypeBackend:
    """Get dtype backend for pandas type.

    Matches pandas' `dtype_backend` argument in `convert_dtypes`.
    """
    if implementation is Implementation.CUDF:
        return None
    if is_dtype_pyarrow(dtype):
        return "pyarrow"
    return "numpy_nullable" if is_dtype_numpy_nullable(dtype) else None


# NOTE: Use this to avoid annotating inline
def iter_dtype_backends(
    dtypes: Iterable[Any], implementation: Implementation
) -> Iterator[DTypeBackend]:
    """Yield a `DTypeBackend` per-dtype.

    Matches pandas' `dtype_backend` argument in `convert_dtypes`.
    """
    return (get_dtype_backend(dtype, implementation) for dtype in dtypes)


@functools.lru_cache(maxsize=16)
def is_dtype_pyarrow(dtype: Any) -> TypeIs[pd.ArrowDtype]:
    return hasattr(pd, "ArrowDtype") and isinstance(dtype, pd.ArrowDtype)
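# Illustrative classification (assumes pandas with pyarrow installed; `pa`/`np`
# stand for the usual pyarrow/numpy imports):
#
#     get_dtype_backend(pd.ArrowDtype(pa.int64()), Implementation.PANDAS)  # -> "pyarrow"
#     get_dtype_backend(pd.Int64Dtype(), Implementation.PANDAS)            # -> "numpy_nullable"
#     get_dtype_backend(np.dtype("int64"), Implementation.PANDAS)          # -> None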


dtypes = Version.MAIN.dtypes
NW_TO_PD_DTYPES_INVARIANT: Mapping[type[DType], str] = {
    # TODO(Unassigned): is there no pyarrow-backed categorical?
    # or at least, convert_dtypes(dtype_backend='pyarrow') doesn't
    # convert to it?
    dtypes.Categorical: "category",
    dtypes.Object: "object",
}
NW_TO_PD_DTYPES_BACKEND: Mapping[type[DType], Mapping[DTypeBackend, str | type[Any]]] = {
    dtypes.Float64: {
        "pyarrow": "Float64[pyarrow]",
        "numpy_nullable": "Float64",
        None: "float64",
    },
    dtypes.Float32: {
        "pyarrow": "Float32[pyarrow]",
        "numpy_nullable": "Float32",
        None: "float32",
    },
    dtypes.Int64: {"pyarrow": "Int64[pyarrow]", "numpy_nullable": "Int64", None: "int64"},
    dtypes.Int32: {"pyarrow": "Int32[pyarrow]", "numpy_nullable": "Int32", None: "int32"},
    dtypes.Int16: {"pyarrow": "Int16[pyarrow]", "numpy_nullable": "Int16", None: "int16"},
    dtypes.Int8: {"pyarrow": "Int8[pyarrow]", "numpy_nullable": "Int8", None: "int8"},
    dtypes.UInt64: {
        "pyarrow": "UInt64[pyarrow]",
        "numpy_nullable": "UInt64",
        None: "uint64",
    },
    dtypes.UInt32: {
        "pyarrow": "UInt32[pyarrow]",
        "numpy_nullable": "UInt32",
        None: "uint32",
    },
    dtypes.UInt16: {
        "pyarrow": "UInt16[pyarrow]",
        "numpy_nullable": "UInt16",
        None: "uint16",
    },
    dtypes.UInt8: {"pyarrow": "UInt8[pyarrow]", "numpy_nullable": "UInt8", None: "uint8"},
    dtypes.String: {"pyarrow": "string[pyarrow]", "numpy_nullable": "string", None: str},
    dtypes.Boolean: {
        "pyarrow": "boolean[pyarrow]",
        "numpy_nullable": "boolean",
        None: "bool",
    },
}
UNSUPPORTED_DTYPES = (dtypes.Decimal,)


def narwhals_to_native_dtype(  # noqa: C901, PLR0912
    dtype: IntoDType,
    dtype_backend: DTypeBackend,
    implementation: Implementation,
    version: Version,
) -> str | PandasDtype:
    if dtype_backend not in {None, "pyarrow", "numpy_nullable"}:
        msg = f"Expected one of {{None, 'pyarrow', 'numpy_nullable'}}, got: '{dtype_backend}'"
        raise ValueError(msg)
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if pd_type := NW_TO_PD_DTYPES_INVARIANT.get(base_type):
        return pd_type
    if into_pd_type := NW_TO_PD_DTYPES_BACKEND.get(base_type):
        return into_pd_type[dtype_backend]
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        # Pandas does not support "ms" or "us" time units before version 2.0
        if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
            2,
        ):  # pragma: no cover
            dt_time_unit = "ns"
        else:
            dt_time_unit = dtype.time_unit

        if dtype_backend == "pyarrow":
            tz_part = f", tz={tz}" if (tz := dtype.time_zone) else ""
            return f"timestamp[{dt_time_unit}{tz_part}][pyarrow]"
        tz_part = f", {tz}" if (tz := dtype.time_zone) else ""
        return f"datetime64[{dt_time_unit}{tz_part}]"
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        if is_pandas_or_modin(implementation) and PANDAS_VERSION < (
            2,
        ):  # pragma: no cover
            du_time_unit = "ns"
        else:
            du_time_unit = dtype.time_unit
        return (
            f"duration[{du_time_unit}][pyarrow]"
            if dtype_backend == "pyarrow"
            else f"timedelta64[{du_time_unit}]"
        )
    if isinstance_or_issubclass(dtype, dtypes.Date):
        try:
            import pyarrow as pa  # ignore-banned-import # noqa: F401
        except ModuleNotFoundError as exc:  # pragma: no cover
            # BUG: Never re-raised?
            msg = "'pyarrow>=13.0.0' is required for `Date` dtype."
            raise ModuleNotFoundError(msg) from exc
        return "date32[pyarrow]"
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            ns = implementation.to_native_namespace()
            return ns.CategoricalDtype(dtype.categories, ordered=True)
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if issubclass(
        base_type, (dtypes.Struct, dtypes.Array, dtypes.List, dtypes.Time, dtypes.Binary)
    ):
        return narwhals_to_native_arrow_dtype(dtype, implementation, version)
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for {implementation}."
        raise NotImplementedError(msg)
    msg = f"Unknown dtype: {dtype}"  # pragma: no cover
    raise AssertionError(msg)
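# Illustrative round-trip under the mappings above (assumes pandas>=2.0; `nw`
# stands for the public narwhals namespace):
#
#     narwhals_to_native_dtype(
#         nw.Datetime("us", "UTC"), "pyarrow", Implementation.PANDAS, Version.MAIN
#     )
#     # -> "timestamp[us, tz=UTC][pyarrow]", which PATTERN_PA_DATETIME parses
#     # back to Datetime("us", "UTC") in non_object_native_to_narwhals_dtype.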


def narwhals_to_native_arrow_dtype(
    dtype: IntoDType, implementation: Implementation, version: Version
) -> pd.ArrowDtype:
    if is_pandas_or_modin(implementation) and PANDAS_VERSION >= (2, 2):
        try:
            import pyarrow as pa  # ignore-banned-import # noqa: F401
        except ImportError as exc:  # pragma: no cover
            msg = f"Unable to convert to {dtype} due to the following exception: {exc.msg}"
            raise ImportError(msg) from exc
        from narwhals._arrow.utils import narwhals_to_native_dtype as _to_arrow_dtype

        return pd.ArrowDtype(_to_arrow_dtype(dtype, version))
    msg = (  # pragma: no cover
        f"Converting to {dtype} dtype is not supported for implementation "
        f"{implementation} and version {version}."
    )
    raise NotImplementedError(msg)


def int_dtype_mapper(dtype: Any) -> str:
    if "pyarrow" in str(dtype):
        return "Int64[pyarrow]"
    if str(dtype).lower() != str(dtype):  # pragma: no cover
        return "Int64"
    return "int64"


_TIMESTAMP_DATETIME_OP_FACTOR: Mapping[
    tuple[UnitCurrent, UnitTarget], tuple[BinOpBroadcast, IntoRhs]
] = {
    ("ns", "us"): (operator.floordiv, 1_000),
    ("ns", "ms"): (operator.floordiv, 1_000_000),
    ("us", "ns"): (operator.mul, NS_PER_MICROSECOND),
    ("us", "ms"): (operator.floordiv, 1_000),
    ("ms", "ns"): (operator.mul, NS_PER_MILLISECOND),
    ("ms", "us"): (operator.mul, 1_000),
    ("s", "ns"): (operator.mul, NS_PER_SECOND),
    ("s", "us"): (operator.mul, US_PER_SECOND),
    ("s", "ms"): (operator.mul, MS_PER_SECOND),
}


def calculate_timestamp_datetime(
    s: NativeSeriesT, current: TimeUnit, time_unit: TimeUnit
) -> NativeSeriesT:
    if current == time_unit:
        return s
    if item := _TIMESTAMP_DATETIME_OP_FACTOR.get((current, time_unit)):
        fn, factor = item
        return fn(s, factor)
    msg = (  # pragma: no cover
        f"unexpected time unit {current}, please report an issue at "
        "https://github.com/narwhals-dev/narwhals"
    )
    raise AssertionError(msg)


_TIMESTAMP_DATE_FACTOR: Mapping[TimeUnit, int] = {
    "ns": NS_PER_SECOND,
    "us": US_PER_SECOND,
    "ms": MS_PER_SECOND,
    "s": 1,
}


def calculate_timestamp_date(s: NativeSeriesT, time_unit: TimeUnit) -> NativeSeriesT:
    return s * SECONDS_PER_DAY * _TIMESTAMP_DATE_FACTOR[time_unit]
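# Worked example (illustrative): "us" -> "ms" floor-divides by 1_000, while a
# date counted in days scales by SECONDS_PER_DAY times the per-second factor:
#
#     calculate_timestamp_datetime(1_500_000, "us", "ms")  # -> 1_500
#     calculate_timestamp_date(1, "ms")                    # -> 86_400_000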


def select_columns_by_name(
    df: NativeDataFrameT,
    column_names: list[str] | _1DArray,  # NOTE: Cannot be a tuple!
    implementation: Implementation,
) -> NativeDataFrameT | Any:
    """Select columns by name.

    Prefer this over `df.loc[:, column_names]` as it's
    generally more performant.
    """
    if len(column_names) == df.shape[1] and (df.columns == column_names).all():
        return df
    if (df.columns.dtype.kind == "b") or (
        implementation is Implementation.PANDAS
        and implementation._backend_version() < (1, 5)
    ):
        # See https://github.com/narwhals-dev/narwhals/issues/1349#issuecomment-2470118122
        # for why we need this
        if error := check_columns_exist(column_names, available=df.columns.tolist()):
            raise error
        return df.loc[:, column_names]
    try:
        return df[column_names]
    except KeyError as e:
        if error := check_columns_exist(column_names, available=df.columns.tolist()):
            raise error from e
        raise


def is_non_nullable_boolean(s: PandasLikeSeries) -> bool:
    # cuDF booleans are nullable but the native dtype is still 'bool'.
    return (
        s._implementation
        in {Implementation.PANDAS, Implementation.MODIN, Implementation.DASK}
        and s.native.dtype == "bool"
    )


def import_array_module(implementation: Implementation, /) -> ModuleType:
    """Return the numpy or cupy module, depending on the given implementation."""
    if implementation in {Implementation.PANDAS, Implementation.MODIN}:
        import numpy as np

        return np
    if implementation is Implementation.CUDF:
        import cupy as cp  # ignore-banned-import # cuDF dependency.

        return cp
    msg = f"Expected pandas/modin/cudf, got: {implementation}"  # pragma: no cover
    raise AssertionError(msg)


class PandasLikeSeriesNamespace(EagerSeriesNamespace["PandasLikeSeries", Any]): ...
678
lib/python3.11/site-packages/narwhals/_polars/dataframe.py
Normal file
@@ -0,0 +1,678 @@
from __future__ import annotations

from collections.abc import Iterator, Mapping, Sequence, Sized
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeVar, cast, overload

import polars as pl

from narwhals._polars.namespace import PolarsNamespace
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import (
    catch_polars_exception,
    extract_args_kwargs,
    native_to_narwhals_dtype,
)
from narwhals._utils import (
    Implementation,
    _into_arrow_table,
    convert_str_slice_to_int_slice,
    is_compliant_series,
    is_index_selector,
    is_range,
    is_sequence_like,
    is_slice_index,
    is_slice_none,
    parse_columns_to_drop,
    requires,
)
from narwhals.dependencies import is_numpy_array_1d
from narwhals.exceptions import ColumnNotFoundError

if TYPE_CHECKING:
    from collections.abc import Iterable
    from types import ModuleType
    from typing import Callable

    import pandas as pd
    import pyarrow as pa
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny, CompliantLazyFrameAny
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.group_by import PolarsGroupBy, PolarsLazyGroupBy
    from narwhals._spark_like.utils import SparkSession
    from narwhals._translate import IntoArrowTable
    from narwhals._typing import _EagerAllowedImpl, _LazyAllowedImpl
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import DataFrame, LazyFrame
    from narwhals.dtypes import DType
    from narwhals.typing import (
        IntoSchema,
        JoinStrategy,
        MultiColSelector,
        MultiIndexSelector,
        PivotAgg,
        SingleIndexSelector,
        _2DArray,
    )

T = TypeVar("T")
R = TypeVar("R")

Method: TypeAlias = "Callable[..., R]"
"""Generic alias representing all methods implemented via `__getattr__`.

Where `R` is the return type.
"""

# DataFrame methods where PolarsDataFrame just defers to Polars.DataFrame directly.
INHERITED_METHODS = frozenset(
    [
        "clone",
        "drop_nulls",
        "estimated_size",
        "explode",
        "filter",
        "gather_every",
        "head",
        "is_unique",
        "item",
        "iter_rows",
        "join_asof",
        "rename",
        "row",
        "rows",
        "sample",
        "select",
        "sink_parquet",
        "sort",
        "tail",
        "to_arrow",
        "to_pandas",
        "unique",
        "with_columns",
        "write_csv",
        "write_parquet",
    ]
)

NativePolarsFrame = TypeVar("NativePolarsFrame", pl.DataFrame, pl.LazyFrame)
|
||||
|
||||
|
||||
class PolarsBaseFrame(Generic[NativePolarsFrame]):
|
||||
drop_nulls: Method[Self]
|
||||
explode: Method[Self]
|
||||
filter: Method[Self]
|
||||
gather_every: Method[Self]
|
||||
head: Method[Self]
|
||||
join_asof: Method[Self]
|
||||
rename: Method[Self]
|
||||
select: Method[Self]
|
||||
sort: Method[Self]
|
||||
tail: Method[Self]
|
||||
unique: Method[Self]
|
||||
with_columns: Method[Self]
|
||||
|
||||
_native_frame: NativePolarsFrame
|
||||
_implementation = Implementation.POLARS
|
||||
_version: Version
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
df: NativePolarsFrame,
|
||||
*,
|
||||
version: Version,
|
||||
validate_backend_version: bool = False,
|
||||
) -> None:
|
||||
self._native_frame = df
|
||||
self._version = version
|
||||
if validate_backend_version:
|
||||
self._validate_backend_version()
|
||||
|
||||
def _validate_backend_version(self) -> None:
|
||||
"""Raise if installed version below `nw._utils.MIN_VERSIONS`.
|
||||
|
||||
**Only use this when moving between backends.**
|
||||
Otherwise, the validation will have taken place already.
|
||||
"""
|
||||
_ = self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def native(self) -> NativePolarsFrame:
|
||||
return self._native_frame
|
||||
|
||||
@property
|
||||
def columns(self) -> list[str]:
|
||||
return self.native.columns
|
||||
|
||||
def __narwhals_namespace__(self) -> PolarsNamespace:
|
||||
return PolarsNamespace(version=self._version)
|
||||
|
||||
def __native_namespace__(self) -> ModuleType:
|
||||
if self._implementation is Implementation.POLARS:
|
||||
return self._implementation.to_native_namespace()
|
||||
|
||||
msg = f"Expected polars, got: {type(self._implementation)}" # pragma: no cover
|
||||
raise AssertionError(msg)
|
||||
|
||||
def _with_native(self, df: NativePolarsFrame) -> Self:
|
||||
return self.__class__(df, version=self._version)
|
||||
|
||||
def _with_version(self, version: Version) -> Self:
|
||||
return self.__class__(self.native, version=version)
|
||||
|
||||
@classmethod
|
||||
def from_native(cls, data: NativePolarsFrame, /, *, context: _LimitedContext) -> Self:
|
||||
return cls(data, version=context._version)
|
||||
|
||||
def simple_select(self, *column_names: str) -> Self:
|
||||
return self._with_native(self.native.select(*column_names))
|
||||
|
||||
def aggregate(self, *exprs: Any) -> Self:
|
||||
return self.select(*exprs)
|
||||
|
||||
@property
|
||||
def schema(self) -> dict[str, DType]:
|
||||
return self.collect_schema()
|
||||
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self:
|
||||
how_native = (
|
||||
"outer" if (self._backend_version < (0, 20, 29) and how == "full") else how
|
||||
)
|
||||
return self._with_native(
|
||||
self.native.join(
|
||||
other=other.native,
|
||||
how=how_native, # type: ignore[arg-type]
|
||||
left_on=left_on,
|
||||
right_on=right_on,
|
||||
suffix=suffix,
|
||||
)
|
||||
)
|
||||
|
||||
def top_k(
|
||||
self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
|
||||
) -> Self:
|
||||
if self._backend_version < (1, 0, 0):
|
||||
return self._with_native(
|
||||
self.native.top_k(
|
||||
k=k,
|
||||
by=by,
|
||||
descending=reverse, # type: ignore[call-arg]
|
||||
)
|
||||
)
|
||||
return self._with_native(self.native.top_k(k=k, by=by, reverse=reverse))
|
||||
|
||||
def unpivot(
|
||||
self,
|
||||
on: Sequence[str] | None,
|
||||
index: Sequence[str] | None,
|
||||
variable_name: str,
|
||||
value_name: str,
|
||||
) -> Self:
|
||||
if self._backend_version < (1, 0, 0):
|
||||
return self._with_native(
|
||||
self.native.melt(
|
||||
id_vars=index,
|
||||
value_vars=on,
|
||||
variable_name=variable_name,
|
||||
value_name=value_name,
|
||||
)
|
||||
)
|
||||
return self._with_native(
|
||||
self.native.unpivot(
|
||||
on=on, index=index, variable_name=variable_name, value_name=value_name
|
||||
)
|
||||
)
|
||||
|
||||
def collect_schema(self) -> dict[str, DType]:
|
||||
df = self.native
|
||||
schema = df.schema if self._backend_version < (1,) else df.collect_schema()
|
||||
return {
|
||||
name: native_to_narwhals_dtype(dtype, self._version)
|
||||
for name, dtype in schema.items()
|
||||
}
|
||||
|
||||
def with_row_index(self, name: str, order_by: Sequence[str] | None) -> Self:
|
||||
frame = self.native
|
||||
if order_by is None:
|
||||
result = frame.with_row_index(name)
|
||||
else:
|
||||
end = pl.count() if self._backend_version < (0, 20, 5) else pl.len()
|
||||
result = frame.select(
|
||||
pl.int_range(start=0, end=end).sort_by(order_by).alias(name), pl.all()
|
||||
)
|
||||
|
||||
return self._with_native(result)
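# Illustrative: when `order_by` is given, the row index is emulated by sorting a
# `pl.int_range` expression over those columns, so the index reflects that
# ordering rather than physical row position.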
|
||||
|
||||
|
||||
class PolarsDataFrame(PolarsBaseFrame[pl.DataFrame]):
|
||||
clone: Method[Self]
|
||||
collect: Method[CompliantDataFrameAny]
|
||||
estimated_size: Method[int | float]
|
||||
gather_every: Method[Self]
|
||||
item: Method[Any]
|
||||
iter_rows: Method[Iterator[tuple[Any, ...]] | Iterator[Mapping[str, Any]]]
|
||||
is_unique: Method[PolarsSeries]
|
||||
row: Method[tuple[Any, ...]]
|
||||
rows: Method[Sequence[tuple[Any, ...]] | Sequence[Mapping[str, Any]]]
|
||||
sample: Method[Self]
|
||||
to_arrow: Method[pa.Table]
|
||||
to_pandas: Method[pd.DataFrame]
|
||||
# NOTE: `write_csv` requires an `@overload` for `str | None`
|
||||
# Can't do that here 😟
|
||||
write_csv: Method[Any]
|
||||
write_parquet: Method[None]
|
||||
|
||||
@classmethod
|
||||
def from_arrow(cls, data: IntoArrowTable, /, *, context: _LimitedContext) -> Self:
|
||||
if context._implementation._backend_version() >= (1, 3):
|
||||
native = pl.DataFrame(data)
|
||||
else: # pragma: no cover
|
||||
native = cast("pl.DataFrame", pl.from_arrow(_into_arrow_table(data, context)))
|
||||
return cls.from_native(native, context=context)
|
||||
|
||||
@classmethod
|
||||
def from_dict(
|
||||
cls,
|
||||
data: Mapping[str, Any],
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext,
|
||||
schema: IntoSchema | None,
|
||||
) -> Self:
|
||||
from narwhals.schema import Schema
|
||||
|
||||
pl_schema = Schema(schema).to_polars() if schema is not None else schema
|
||||
return cls.from_native(pl.from_dict(data, pl_schema), context=context)
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: pl.DataFrame | Any) -> TypeIs[pl.DataFrame]:
|
||||
return isinstance(obj, pl.DataFrame)
|
||||
|
||||
@classmethod
|
||||
def from_numpy(
|
||||
cls,
|
||||
data: _2DArray,
|
||||
/,
|
||||
*,
|
||||
context: _LimitedContext, # NOTE: Maybe only `Implementation`?
|
||||
schema: IntoSchema | Sequence[str] | None,
|
||||
) -> Self:
|
||||
from narwhals.schema import Schema
|
||||
|
||||
pl_schema = (
|
||||
Schema(schema).to_polars()
|
||||
if isinstance(schema, (Mapping, Schema))
|
||||
else schema
|
||||
)
|
||||
return cls.from_native(pl.from_numpy(data, pl_schema), context=context)
|
||||
|
||||
def to_narwhals(self) -> DataFrame[pl.DataFrame]:
|
||||
return self._version.dataframe(self, level="full")
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsDataFrame"
|
||||
|
||||
def __narwhals_dataframe__(self) -> Self:
|
||||
return self
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: pl.Series) -> PolarsSeries: ...
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: pl.DataFrame) -> Self: ...
|
||||
|
||||
@overload
|
||||
def _from_native_object(self, obj: T) -> T: ...
|
||||
|
||||
def _from_native_object(
|
||||
self, obj: pl.Series | pl.DataFrame | T
|
||||
) -> Self | PolarsSeries | T:
|
||||
if isinstance(obj, pl.Series):
|
||||
return PolarsSeries.from_native(obj, context=self)
|
||||
if self._is_native(obj):
|
||||
return self._with_native(obj)
|
||||
# scalar
|
||||
return obj
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.native)
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr not in INHERITED_METHODS: # pragma: no cover
|
||||
msg = f"{self.__class__.__name__} has not attribute '{attr}'."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
try:
|
||||
return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))
|
||||
except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover
|
||||
msg = f"{e!s}\n\nHint: Did you mean one of these columns: {self.columns}?"
|
||||
raise ColumnNotFoundError(msg) from e
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
return func
|
||||
|
||||
def __array__(
|
||||
self, dtype: Any | None = None, *, copy: bool | None = None
|
||||
) -> _2DArray:
|
||||
if self._backend_version < (0, 20, 28) and copy is not None:
|
||||
msg = "`copy` in `__array__` is only supported for 'polars>=0.20.28'"
|
||||
raise NotImplementedError(msg)
|
||||
if self._backend_version < (0, 20, 28):
|
||||
return self.native.__array__(dtype)
# NOTE: assumption — per the guard above, polars>=0.20.28 accepts `copy`,
# so it can be forwarded to the native call.
return self.native.__array__(dtype, copy=copy)
|
||||
|
||||
def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _2DArray:
|
||||
return self.native.to_numpy()
|
||||
|
||||
@property
|
||||
def shape(self) -> tuple[int, int]:
|
||||
return self.native.shape
|
||||
|
||||
def __getitem__( # noqa: C901, PLR0912
|
||||
self,
|
||||
item: tuple[
|
||||
SingleIndexSelector | MultiIndexSelector[PolarsSeries],
|
||||
MultiColSelector[PolarsSeries],
|
||||
],
|
||||
) -> Any:
|
||||
rows, columns = item
|
||||
if self._backend_version > (0, 20, 30):
|
||||
rows_native = rows.native if is_compliant_series(rows) else rows
|
||||
columns_native = columns.native if is_compliant_series(columns) else columns
|
||||
selector = rows_native, columns_native
|
||||
selected = self.native.__getitem__(selector) # type: ignore[index]
|
||||
return self._from_native_object(selected)
|
||||
else: # pragma: no cover # noqa: RET505
|
||||
# TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
|
||||
# Polars version we support
|
||||
# This mostly mirrors the logic in `EagerDataFrame.__getitem__`.
|
||||
rows = list(rows) if isinstance(rows, tuple) else rows
|
||||
columns = list(columns) if isinstance(columns, tuple) else columns
|
||||
if is_numpy_array_1d(columns):
|
||||
columns = columns.tolist()
|
||||
|
||||
native = self.native
|
||||
if not is_slice_none(columns):
|
||||
if isinstance(columns, Sized) and len(columns) == 0:
|
||||
return self.select()
|
||||
if is_index_selector(columns):
|
||||
if is_slice_index(columns) or is_range(columns):
|
||||
native = native.select(
|
||||
self.columns[slice(columns.start, columns.stop, columns.step)]
|
||||
)
|
||||
# NOTE: `mypy` loses track of `PolarsSeries` when `is_compliant_series` is used here
|
||||
# `pyright` is fine
|
||||
elif isinstance(columns, PolarsSeries):
|
||||
native = native[:, columns.native.to_list()]
|
||||
else:
|
||||
native = native[:, columns]
|
||||
elif isinstance(columns, slice):
|
||||
native = native.select(
|
||||
self.columns[
|
||||
slice(*convert_str_slice_to_int_slice(columns, self.columns))
|
||||
]
|
||||
)
|
||||
elif is_compliant_series(columns):
|
||||
native = native.select(columns.native.to_list())
|
||||
elif is_sequence_like(columns):
|
||||
native = native.select(columns)
|
||||
else:
|
||||
msg = f"Unreachable code, got unexpected type: {type(columns)}"
|
||||
raise AssertionError(msg)
|
||||
|
||||
if not is_slice_none(rows):
|
||||
if isinstance(rows, int):
|
||||
native = native[[rows], :]
|
||||
elif isinstance(rows, (slice, range)):
|
||||
native = native[rows, :]
|
||||
elif is_compliant_series(rows):
|
||||
native = native[rows.native, :]
|
||||
elif is_sequence_like(rows):
|
||||
native = native[rows, :]
|
||||
else:
|
||||
msg = f"Unreachable code, got unexpected type: {type(rows)}"
|
||||
raise AssertionError(msg)
|
||||
|
||||
return self._with_native(native)
|
||||
|
||||
def get_column(self, name: str) -> PolarsSeries:
|
||||
return PolarsSeries.from_native(self.native.get_column(name), context=self)
|
||||
|
||||
def iter_columns(self) -> Iterator[PolarsSeries]:
|
||||
for series in self.native.iter_columns():
|
||||
yield PolarsSeries.from_native(series, context=self)
|
||||
|
||||
def lazy(
|
||||
self,
|
||||
backend: _LazyAllowedImpl | None = None,
|
||||
*,
|
||||
session: SparkSession | None = None,
|
||||
) -> CompliantLazyFrameAny:
|
||||
if backend is None or backend is Implementation.POLARS:
|
||||
return PolarsLazyFrame.from_native(self.native.lazy(), context=self)
|
||||
if backend is Implementation.DUCKDB:
|
||||
import duckdb # ignore-banned-import
|
||||
|
||||
from narwhals._duckdb.dataframe import DuckDBLazyFrame
|
||||
|
||||
_df = self.native
|
||||
return DuckDBLazyFrame(
|
||||
duckdb.table("_df"), validate_backend_version=True, version=self._version
|
||||
)
|
||||
if backend is Implementation.DASK:
|
||||
import dask.dataframe as dd # ignore-banned-import
|
||||
|
||||
from narwhals._dask.dataframe import DaskLazyFrame
|
||||
|
||||
return DaskLazyFrame(
|
||||
dd.from_pandas(self.native.to_pandas()),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
)
|
||||
if backend is Implementation.IBIS:
|
||||
import ibis # ignore-banned-import
|
||||
|
||||
from narwhals._ibis.dataframe import IbisLazyFrame
|
||||
|
||||
return IbisLazyFrame(
|
||||
ibis.memtable(self.native, columns=self.columns),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
if backend.is_spark_like():
|
||||
from narwhals._spark_like.dataframe import SparkLikeLazyFrame
|
||||
|
||||
if session is None:
|
||||
msg = "Spark like backends require `session` to be not None."
|
||||
raise ValueError(msg)
|
||||
|
||||
return SparkLikeLazyFrame._from_compliant_dataframe(
|
||||
self, # pyright: ignore[reportArgumentType]
|
||||
session=session,
|
||||
implementation=backend,
|
||||
version=self._version,
|
||||
)
|
||||
|
||||
raise AssertionError # pragma: no cover
|
||||
|
||||
@overload
|
||||
def to_dict(self, *, as_series: Literal[True]) -> dict[str, PolarsSeries]: ...
|
||||
|
||||
@overload
|
||||
def to_dict(self, *, as_series: Literal[False]) -> dict[str, list[Any]]: ...
|
||||
|
||||
def to_dict(
|
||||
self, *, as_series: bool
|
||||
) -> dict[str, PolarsSeries] | dict[str, list[Any]]:
|
||||
if as_series:
|
||||
return {
|
||||
name: PolarsSeries.from_native(col, context=self)
|
||||
for name, col in self.native.to_dict().items()
|
||||
}
|
||||
return self.native.to_dict(as_series=False)
|
||||
|
||||
def group_by(
|
||||
self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
|
||||
) -> PolarsGroupBy:
|
||||
from narwhals._polars.group_by import PolarsGroupBy
|
||||
|
||||
return PolarsGroupBy(self, keys, drop_null_keys=drop_null_keys)
|
||||
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
|
||||
to_drop = parse_columns_to_drop(self, columns, strict=strict)
|
||||
return self._with_native(self.native.drop(to_drop))
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def pivot(
|
||||
self,
|
||||
on: Sequence[str],
|
||||
*,
|
||||
index: Sequence[str] | None,
|
||||
values: Sequence[str] | None,
|
||||
aggregate_function: PivotAgg | None,
|
||||
sort_columns: bool,
|
||||
separator: str,
|
||||
) -> Self:
|
||||
try:
|
||||
result = self.native.pivot(
|
||||
on,
|
||||
index=index,
|
||||
values=values,
|
||||
aggregate_function=aggregate_function,
|
||||
sort_columns=sort_columns,
|
||||
separator=separator,
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
return self._from_native_object(result)
|
||||
|
||||
def to_polars(self) -> pl.DataFrame:
|
||||
return self.native
|
||||
|
||||
def join(
|
||||
self,
|
||||
other: Self,
|
||||
*,
|
||||
how: JoinStrategy,
|
||||
left_on: Sequence[str] | None,
|
||||
right_on: Sequence[str] | None,
|
||||
suffix: str,
|
||||
) -> Self:
|
||||
try:
|
||||
return super().join(
|
||||
other=other, how=how, left_on=left_on, right_on=right_on, suffix=suffix
|
||||
)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
def top_k(
|
||||
self, k: int, *, by: str | Iterable[str], reverse: bool | Sequence[bool]
|
||||
) -> Self:
|
||||
try:
|
||||
return super().top_k(k=k, by=by, reverse=reverse)
|
||||
except Exception as e: # noqa: BLE001 # pragma: no cover
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
|
||||
class PolarsLazyFrame(PolarsBaseFrame[pl.LazyFrame]):
|
||||
sink_parquet: Method[None]
|
||||
|
||||
@staticmethod
|
||||
def _is_native(obj: pl.LazyFrame | Any) -> TypeIs[pl.LazyFrame]:
|
||||
return isinstance(obj, pl.LazyFrame)
|
||||
|
||||
def to_narwhals(self) -> LazyFrame[pl.LazyFrame]:
|
||||
return self._version.lazyframe(self, level="lazy")
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsLazyFrame"
|
||||
|
||||
def __narwhals_lazyframe__(self) -> Self:
|
||||
return self
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
if attr not in INHERITED_METHODS: # pragma: no cover
|
||||
msg = f"{self.__class__.__name__} has not attribute '{attr}'."
|
||||
raise AttributeError(msg)
|
||||
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
try:
|
||||
return self._with_native(getattr(self.native, attr)(*pos, **kwds))
|
||||
except pl.exceptions.ColumnNotFoundError as e: # pragma: no cover
|
||||
raise ColumnNotFoundError(str(e)) from e
|
||||
|
||||
return func
|
||||
|
||||
def _iter_columns(self) -> Iterator[PolarsSeries]: # pragma: no cover
|
||||
yield from self.collect(Implementation.POLARS).iter_columns()
|
||||
|
||||
def collect_schema(self) -> dict[str, DType]:
|
||||
try:
|
||||
return super().collect_schema()
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
def collect(
|
||||
self, backend: _EagerAllowedImpl | None, **kwargs: Any
|
||||
) -> CompliantDataFrameAny:
|
||||
try:
|
||||
result = self.native.collect(**kwargs)
|
||||
except Exception as e: # noqa: BLE001
|
||||
raise catch_polars_exception(e) from None
|
||||
|
||||
if backend is None or backend is Implementation.POLARS:
|
||||
return PolarsDataFrame.from_native(result, context=self)
|
||||
|
||||
if backend is Implementation.PANDAS:
|
||||
from narwhals._pandas_like.dataframe import PandasLikeDataFrame
|
||||
|
||||
return PandasLikeDataFrame(
|
||||
result.to_pandas(),
|
||||
implementation=Implementation.PANDAS,
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
validate_column_names=False,
|
||||
)
|
||||
|
||||
if backend is Implementation.PYARROW:
|
||||
from narwhals._arrow.dataframe import ArrowDataFrame
|
||||
|
||||
return ArrowDataFrame(
|
||||
result.to_arrow(),
|
||||
validate_backend_version=True,
|
||||
version=self._version,
|
||||
validate_column_names=False,
|
||||
)
|
||||
|
||||
msg = f"Unsupported `backend` value: {backend}" # pragma: no cover
|
||||
raise ValueError(msg) # pragma: no cover
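# Usage sketch (illustrative): collecting to another eager backend wraps the
# materialized result in that backend's compliant frame:
#
#     compliant = lazy_frame.collect(Implementation.PYARROW)  # -> ArrowDataFrame
#     compliant = lazy_frame.collect(None)                    # -> PolarsDataFrame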
|
||||
|
||||
def group_by(
|
||||
self, keys: Sequence[str] | Sequence[PolarsExpr], *, drop_null_keys: bool
|
||||
) -> PolarsLazyGroupBy:
|
||||
from narwhals._polars.group_by import PolarsLazyGroupBy
|
||||
|
||||
return PolarsLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)
|
||||
|
||||
def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
|
||||
if self._backend_version < (1, 0, 0):
|
||||
return self._with_native(self.native.drop(columns))
|
||||
return self._with_native(self.native.drop(columns, strict=strict))
479
lib/python3.11/site-packages/narwhals/_polars/expr.py
Normal file
@@ -0,0 +1,479 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Callable, Literal

import polars as pl

from narwhals._polars.utils import (
    PolarsAnyNamespace,
    PolarsCatNamespace,
    PolarsDateTimeNamespace,
    PolarsListNamespace,
    PolarsStringNamespace,
    PolarsStructNamespace,
    extract_args_kwargs,
    extract_native,
    narwhals_to_native_dtype,
)
from narwhals._utils import Implementation, requires

if TYPE_CHECKING:
    from collections.abc import Mapping, Sequence

    from typing_extensions import Self

    from narwhals._expression_parsing import ExprKind, ExprMetadata
    from narwhals._polars.dataframe import Method
    from narwhals._polars.namespace import PolarsNamespace
    from narwhals._utils import Version
    from narwhals.typing import IntoDType, ModeKeepStrategy, NumericLiteral
|
||||
|
||||
|
||||
class PolarsExpr:
|
||||
# CompliantExpr
|
||||
_implementation: Implementation = Implementation.POLARS
|
||||
_version: Version
|
||||
_native_expr: pl.Expr
|
||||
_metadata: ExprMetadata | None = None
|
||||
_evaluate_output_names: Any
|
||||
_alias_output_names: Any
|
||||
__call__: Any
|
||||
|
||||
# CompliantExpr + builtin descriptor
|
||||
# TODO @dangotbanned: Remove in #2713
|
||||
@classmethod
|
||||
def from_column_names(cls, *_: Any, **__: Any) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
def from_column_indices(cls, *_: Any, **__: Any) -> Self:
|
||||
raise NotImplementedError
|
||||
|
||||
@staticmethod
|
||||
def _eval_names_indices(*_: Any) -> Any:
|
||||
raise NotImplementedError
|
||||
|
||||
def __narwhals_expr__(self) -> Self: # pragma: no cover
|
||||
return self
|
||||
|
||||
def __narwhals_namespace__(self) -> PolarsNamespace: # pragma: no cover
|
||||
from narwhals._polars.namespace import PolarsNamespace
|
||||
|
||||
return PolarsNamespace(version=self._version)
|
||||
|
||||
def __init__(self, expr: pl.Expr, version: Version) -> None:
|
||||
self._native_expr = expr
|
||||
self._version = version
|
||||
|
||||
@property
|
||||
def _backend_version(self) -> tuple[int, ...]:
|
||||
return self._implementation._backend_version()
|
||||
|
||||
@property
|
||||
def native(self) -> pl.Expr:
|
||||
return self._native_expr
|
||||
|
||||
def __repr__(self) -> str: # pragma: no cover
|
||||
return "PolarsExpr"
|
||||
|
||||
def _with_native(self, expr: pl.Expr) -> Self:
|
||||
return self.__class__(expr, self._version)
|
||||
|
||||
def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
|
||||
# Let Polars do its thing.
|
||||
return self
|
||||
|
||||
def __getattr__(self, attr: str) -> Any:
|
||||
def func(*args: Any, **kwargs: Any) -> Any:
|
||||
pos, kwds = extract_args_kwargs(args, kwargs)
|
||||
return self._with_native(getattr(self.native, attr)(*pos, **kwds))
|
||||
|
||||
return func
|
||||
|
||||
def _renamed_min_periods(self, min_samples: int, /) -> dict[str, Any]:
|
||||
name = "min_periods" if self._backend_version < (1, 21, 0) else "min_samples"
|
||||
return {name: min_samples}
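# Illustrative: polars renamed `min_periods` to `min_samples` in 1.21, so call
# sites splat this dict instead of hard-coding either keyword, e.g.
#
#     self.native.rolling_sum(window_size=3, **self._renamed_min_periods(1))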
|
||||
|
||||
def cast(self, dtype: IntoDType) -> Self:
|
||||
dtype_pl = narwhals_to_native_dtype(dtype, self._version)
|
||||
return self._with_native(self.native.cast(dtype_pl))
|
||||
|
||||
def ewm_mean(
|
||||
self,
|
||||
*,
|
||||
com: float | None,
|
||||
span: float | None,
|
||||
half_life: float | None,
|
||||
alpha: float | None,
|
||||
adjust: bool,
|
||||
min_samples: int,
|
||||
ignore_nulls: bool,
|
||||
) -> Self:
|
||||
native = self.native.ewm_mean(
|
||||
com=com,
|
||||
span=span,
|
||||
half_life=half_life,
|
||||
alpha=alpha,
|
||||
adjust=adjust,
|
||||
ignore_nulls=ignore_nulls,
|
||||
**self._renamed_min_periods(min_samples),
|
||||
)
|
||||
if self._backend_version < (1,): # pragma: no cover
|
||||
native = pl.when(~self.native.is_null()).then(native).otherwise(None)
|
||||
return self._with_native(native)
|
||||
|
||||
def is_nan(self) -> Self:
|
||||
if self._backend_version >= (1, 18):
|
||||
native = self.native.is_nan()
|
||||
else: # pragma: no cover
|
||||
native = pl.when(self.native.is_not_null()).then(self.native.is_nan())
|
||||
return self._with_native(native)
|
||||
|
||||
def over(self, partition_by: Sequence[str], order_by: Sequence[str]) -> Self:
|
||||
if self._backend_version < (1, 9):
|
||||
if order_by:
|
||||
msg = "`order_by` in Polars requires version 1.10 or greater"
|
||||
raise NotImplementedError(msg)
|
||||
native = self.native.over(partition_by or pl.lit(1))
|
||||
else:
|
||||
native = self.native.over(
|
||||
partition_by or pl.lit(1), order_by=order_by or None
|
||||
)
|
||||
return self._with_native(native)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def rolling_var(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_var(
|
||||
window_size=window_size, center=center, ddof=ddof, **kwds
|
||||
)
|
||||
return self._with_native(native)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def rolling_std(
|
||||
self, window_size: int, *, min_samples: int, center: bool, ddof: int
|
||||
) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_std(
|
||||
window_size=window_size, center=center, ddof=ddof, **kwds
|
||||
)
|
||||
return self._with_native(native)
|
||||
|
||||
def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_sum(window_size=window_size, center=center, **kwds)
|
||||
return self._with_native(native)
|
||||
|
||||
def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
|
||||
kwds = self._renamed_min_periods(min_samples)
|
||||
native = self.native.rolling_mean(window_size=window_size, center=center, **kwds)
|
||||
return self._with_native(native)
|
||||
|
||||
def map_batches(
|
||||
self,
|
||||
function: Callable[[Any], Any],
|
||||
return_dtype: IntoDType | None,
|
||||
*,
|
||||
returns_scalar: bool,
|
||||
) -> Self:
|
||||
pl_version = self._backend_version
|
||||
return_dtype_pl = (
|
||||
narwhals_to_native_dtype(return_dtype, self._version)
|
||||
if return_dtype is not None
|
||||
else None
|
||||
if pl_version < (1, 32)
|
||||
else pl.self_dtype()
|
||||
)
|
||||
kwargs = {} if pl_version < (0, 20, 31) else {"returns_scalar": returns_scalar}
|
||||
native = self.native.map_batches(function, return_dtype_pl, **kwargs)
|
||||
return self._with_native(native)
|
||||
|
||||
@requires.backend_version((1,))
|
||||
def replace_strict(
|
||||
self,
|
||||
old: Sequence[Any] | Mapping[Any, Any],
|
||||
new: Sequence[Any],
|
||||
*,
|
||||
return_dtype: IntoDType | None,
|
||||
) -> Self:
|
||||
return_dtype_pl = (
|
||||
narwhals_to_native_dtype(return_dtype, self._version)
|
||||
if return_dtype
|
||||
else None
|
||||
)
|
||||
native = self.native.replace_strict(old, new, return_dtype=return_dtype_pl)
|
||||
return self._with_native(native)
|
||||
|
||||
def __eq__(self, other: object) -> Self: # type: ignore[override]
|
||||
return self._with_native(self.native.__eq__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __ne__(self, other: object) -> Self: # type: ignore[override]
|
||||
return self._with_native(self.native.__ne__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __ge__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__ge__(extract_native(other)))
|
||||
|
||||
def __gt__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__gt__(extract_native(other)))
|
||||
|
||||
def __le__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__le__(extract_native(other)))
|
||||
|
||||
def __lt__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__lt__(extract_native(other)))
|
||||
|
||||
def __and__(self, other: PolarsExpr | bool | Any) -> Self:
|
||||
return self._with_native(self.native.__and__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __or__(self, other: PolarsExpr | bool | Any) -> Self:
|
||||
return self._with_native(self.native.__or__(extract_native(other))) # type: ignore[operator]
|
||||
|
||||
def __add__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__add__(extract_native(other)))
|
||||
|
||||
def __sub__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__sub__(extract_native(other)))
|
||||
|
||||
def __mul__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__mul__(extract_native(other)))
|
||||
|
||||
def __pow__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__pow__(extract_native(other)))
|
||||
|
||||
def __truediv__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__truediv__(extract_native(other)))
|
||||
|
||||
def __floordiv__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__floordiv__(extract_native(other)))
|
||||
|
||||
def __mod__(self, other: Any) -> Self:
|
||||
return self._with_native(self.native.__mod__(extract_native(other)))
|
||||
|
||||
def __invert__(self) -> Self:
|
||||
return self._with_native(self.native.__invert__())
|
||||
|
||||
def cum_count(self, *, reverse: bool) -> Self:
|
||||
return self._with_native(self.native.cum_count(reverse=reverse))
|
||||
|
||||
def is_close(
|
||||
self,
|
||||
other: Self | NumericLiteral,
|
||||
*,
|
||||
abs_tol: float,
|
||||
rel_tol: float,
|
||||
nans_equal: bool,
|
||||
) -> Self:
|
||||
left = self.native
|
||||
right = other.native if isinstance(other, PolarsExpr) else pl.lit(other)
|
||||
|
||||
if self._backend_version < (1, 32, 0):
|
||||
lower_bound = right.abs()
|
||||
tolerance = (left.abs().clip(lower_bound) * rel_tol).clip(abs_tol)
|
||||
|
||||
# Values are close if abs_diff <= tolerance, and both finite
|
||||
abs_diff = (left - right).abs()
|
||||
all_ = pl.all_horizontal
|
||||
is_close = all_((abs_diff <= tolerance), left.is_finite(), right.is_finite())
|
||||
|
||||
# Handle infinity cases: infinities are "close" only if they have the same sign
|
||||
is_same_inf = all_(
|
||||
left.is_infinite(), right.is_infinite(), (left.sign() == right.sign())
|
||||
)
|
||||
|
||||
# Handle nan cases:
|
||||
# * nans_equals = True => if both values are NaN, then True
|
||||
# * nans_equals = False => if any value is NaN, then False
|
||||
left_is_nan, right_is_nan = left.is_nan(), right.is_nan()
|
||||
either_nan = left_is_nan | right_is_nan
|
||||
result = (is_close | is_same_inf) & either_nan.not_()
|
||||
|
||||
if nans_equal:
|
||||
result = result | (left_is_nan & right_is_nan)
|
||||
else:
|
||||
result = left.is_close(
|
||||
right, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
|
||||
)
|
||||
return self._with_native(result)
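# Sketch of the pre-1.32 fallback above: with abs_tol=0.1 and rel_tol=0.0,
# |1.05 - 1.0| = 0.05 <= 0.1 and both operands are finite, so the result is
# True; NaNs only compare equal when `nans_equal` is set.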

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        result = self.native.mode()
        return self._with_native(result.first() if keep == "any" else result)

    @property
    def dt(self) -> PolarsExprDateTimeNamespace:
        return PolarsExprDateTimeNamespace(self)

    @property
    def str(self) -> PolarsExprStringNamespace:
        return PolarsExprStringNamespace(self)

    @property
    def cat(self) -> PolarsExprCatNamespace:
        return PolarsExprCatNamespace(self)

    @property
    def name(self) -> PolarsExprNameNamespace:
        return PolarsExprNameNamespace(self)

    @property
    def list(self) -> PolarsExprListNamespace:
        return PolarsExprListNamespace(self)

    @property
    def struct(self) -> PolarsExprStructNamespace:
        return PolarsExprStructNamespace(self)

    # Polars
    abs: Method[Self]
    all: Method[Self]
    any: Method[Self]
    alias: Method[Self]
    arg_max: Method[Self]
    arg_min: Method[Self]
    arg_true: Method[Self]
    clip: Method[Self]
    count: Method[Self]
    cum_max: Method[Self]
    cum_min: Method[Self]
    cum_prod: Method[Self]
    cum_sum: Method[Self]
    diff: Method[Self]
    drop_nulls: Method[Self]
    exp: Method[Self]
    fill_null: Method[Self]
    fill_nan: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    is_between: Method[Self]
    is_duplicated: Method[Self]
    is_finite: Method[Self]
    is_first_distinct: Method[Self]
    is_in: Method[Self]
    is_last_distinct: Method[Self]
    is_null: Method[Self]
    is_unique: Method[Self]
    kurtosis: Method[Self]
    len: Method[Self]
    log: Method[Self]
    max: Method[Self]
    mean: Method[Self]
    median: Method[Self]
    min: Method[Self]
    n_unique: Method[Self]
    null_count: Method[Self]
    quantile: Method[Self]
    rank: Method[Self]
    round: Method[Self]
    sample: Method[Self]
    shift: Method[Self]
    skew: Method[Self]
    sqrt: Method[Self]
    std: Method[Self]
    sum: Method[Self]
    sort: Method[Self]
    tail: Method[Self]
    unique: Method[Self]
    var: Method[Self]
    __rfloordiv__: Method[Self]
    __rsub__: Method[Self]
    __rmod__: Method[Self]
    __rpow__: Method[Self]
    __rtruediv__: Method[Self]


class PolarsExprNamespace(PolarsAnyNamespace[PolarsExpr, pl.Expr]):
    def __init__(self, expr: PolarsExpr) -> None:
        self._expr = expr

    @property
    def compliant(self) -> PolarsExpr:
        return self._expr

    @property
    def native(self) -> pl.Expr:
        return self._expr.native


class PolarsExprDateTimeNamespace(
    PolarsExprNamespace, PolarsDateTimeNamespace[PolarsExpr, pl.Expr]
): ...


class PolarsExprStringNamespace(
    PolarsExprNamespace, PolarsStringNamespace[PolarsExpr, pl.Expr]
):
    def zfill(self, width: int) -> PolarsExpr:
        backend_version = self.compliant._backend_version
        native_result = self.native.str.zfill(width)

        if backend_version < (0, 20, 5):  # pragma: no cover
            # Reason:
            # `TypeError: argument 'length': 'Expr' object cannot be interpreted as an integer`
            # in `native_expr.str.slice(1, length)`
            msg = "`zfill` is only available in 'polars>=0.20.5', found version '0.20.4'."
            raise NotImplementedError(msg)

        if backend_version <= (1, 30, 0):
            length = self.native.str.len_chars()
            less_than_width = length < width
            plus = "+"
            starts_with_plus = self.native.str.starts_with(plus)
            native_result = (
                pl.when(starts_with_plus & less_than_width)
                .then(
                    self.native.str.slice(1, length)
                    .str.zfill(width - 1)
                    .str.pad_start(width, plus)
                )
                .otherwise(native_result)
            )

        return self.compliant._with_native(native_result)
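
    # E.g. zfilling "+1" to width 5 on polars <= 1.30: the sign is sliced off,
    # "1" is zfilled to width 4 giving "0001", then "+" is padded back on,
    # yielding "+0001" -- the same result `str.zfill` produces for signed strings.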


class PolarsExprCatNamespace(
    PolarsExprNamespace, PolarsCatNamespace[PolarsExpr, pl.Expr]
): ...


class PolarsExprNameNamespace(PolarsExprNamespace):
    _accessor = "name"
    keep: Method[PolarsExpr]
    map: Method[PolarsExpr]
    prefix: Method[PolarsExpr]
    suffix: Method[PolarsExpr]
    to_lowercase: Method[PolarsExpr]
    to_uppercase: Method[PolarsExpr]


class PolarsExprListNamespace(
    PolarsExprNamespace, PolarsListNamespace[PolarsExpr, pl.Expr]
):
    def len(self) -> PolarsExpr:
        native_expr = self.native
        native_result = native_expr.list.len()

        if self.compliant._backend_version < (1, 16):  # pragma: no cover
            native_result = (
                pl.when(~native_expr.is_null()).then(native_result).cast(pl.UInt32())
            )
        elif self.compliant._backend_version < (1, 17):  # pragma: no cover
            native_result = native_result.cast(pl.UInt32())

        return self.compliant._with_native(native_result)

    def contains(self, item: Any) -> PolarsExpr:
        if self.compliant._backend_version < (1, 28):
            result: pl.Expr = pl.when(self.native.is_not_null()).then(
                self.native.list.contains(item)
            )
        else:
            result = self.native.list.contains(item)
        return self.compliant._with_native(result)
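
    # The `pl.when(self.native.is_not_null())` wrapper keeps null rows null on
    # older Polars: a `when` without an `otherwise` produces null wherever the
    # condition is false, rather than a (possibly spurious) boolean.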


class PolarsExprStructNamespace(
    PolarsExprNamespace, PolarsStructNamespace[PolarsExpr, pl.Expr]
): ...
76
lib/python3.11/site-packages/narwhals/_polars/group_by.py
Normal file
@ -0,0 +1,76 @@
from __future__ import annotations

from typing import TYPE_CHECKING, cast

from narwhals._utils import is_sequence_of

if TYPE_CHECKING:
    from collections.abc import Iterator, Sequence

    from polars.dataframe.group_by import GroupBy as NativeGroupBy
    from polars.lazyframe.group_by import LazyGroupBy as NativeLazyGroupBy

    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr


class PolarsGroupBy:
    _compliant_frame: PolarsDataFrame
    _grouped: NativeGroupBy

    @property
    def compliant(self) -> PolarsDataFrame:
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsDataFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._keys = list(keys)
        self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df
        self._grouped = (
            self.compliant.native.group_by(keys)
            if is_sequence_of(keys, str)
            else self.compliant.native.group_by(arg.native for arg in keys)
        )

    def agg(self, *aggs: PolarsExpr) -> PolarsDataFrame:
        agg_result = self._grouped.agg(arg.native for arg in aggs)
        return self.compliant._with_native(agg_result)

    def __iter__(self) -> Iterator[tuple[tuple[str, ...], PolarsDataFrame]]:
        for key, df in self._grouped:
            yield tuple(cast("str", key)), self.compliant._with_native(df)
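
    # Each native group arrives as a `(key, frame)` pair; the key is passed
    # through `tuple(...)` to normalise its type, and the native frame is
    # re-wrapped as a PolarsDataFrame before being yielded.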


class PolarsLazyGroupBy:
    _compliant_frame: PolarsLazyFrame
    _grouped: NativeLazyGroupBy

    @property
    def compliant(self) -> PolarsLazyFrame:
        return self._compliant_frame

    def __init__(
        self,
        df: PolarsLazyFrame,
        keys: Sequence[PolarsExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        self._keys = list(keys)
        self._compliant_frame = df.drop_nulls(keys) if drop_null_keys else df
        self._grouped = (
            self.compliant.native.group_by(keys)
            if is_sequence_of(keys, str)
            else self.compliant.native.group_by(arg.native for arg in keys)
        )

    def agg(self, *aggs: PolarsExpr) -> PolarsLazyFrame:
        agg_result = self._grouped.agg(arg.native for arg in aggs)
        return self.compliant._with_native(agg_result)
281
lib/python3.11/site-packages/narwhals/_polars/namespace.py
Normal file
@ -0,0 +1,281 @@
from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Any, Literal, cast, overload

import polars as pl

from narwhals._expression_parsing import is_expr, is_series
from narwhals._polars.expr import PolarsExpr
from narwhals._polars.series import PolarsSeries
from narwhals._polars.utils import extract_args_kwargs, narwhals_to_native_dtype
from narwhals._utils import Implementation, requires, zip_strict
from narwhals.dependencies import is_numpy_array_2d
from narwhals.dtypes import DType

if TYPE_CHECKING:
    from collections.abc import Iterable, Sequence
    from datetime import timezone

    from narwhals._compliant import CompliantSelectorNamespace, CompliantWhen
    from narwhals._polars.dataframe import Method, PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.typing import FrameT
    from narwhals._utils import Version, _LimitedContext
    from narwhals.expr import Expr
    from narwhals.series import Series
    from narwhals.typing import (
        Into1DArray,
        IntoDType,
        IntoSchema,
        NonNestedLiteral,
        TimeUnit,
        _1DArray,
        _2DArray,
    )


class PolarsNamespace:
    all: Method[PolarsExpr]
    coalesce: Method[PolarsExpr]
    col: Method[PolarsExpr]
    exclude: Method[PolarsExpr]
    sum_horizontal: Method[PolarsExpr]
    min_horizontal: Method[PolarsExpr]
    max_horizontal: Method[PolarsExpr]

    when: Method[CompliantWhen[PolarsDataFrame, PolarsSeries, PolarsExpr]]

    _implementation: Implementation = Implementation.POLARS
    _version: Version

    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    def __init__(self, *, version: Version) -> None:
        self._version = version

    def __getattr__(self, attr: str) -> Any:
        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._expr(getattr(pl, attr)(*pos, **kwds), version=self._version)

        return func

    @property
    def _dataframe(self) -> type[PolarsDataFrame]:
        from narwhals._polars.dataframe import PolarsDataFrame

        return PolarsDataFrame

    @property
    def _lazyframe(self) -> type[PolarsLazyFrame]:
        from narwhals._polars.dataframe import PolarsLazyFrame

        return PolarsLazyFrame

    @property
    def _expr(self) -> type[PolarsExpr]:
        return PolarsExpr

    @property
    def _series(self) -> type[PolarsSeries]:
        return PolarsSeries

    def parse_into_expr(
        self,
        data: Expr | NonNestedLiteral | Series[pl.Series] | _1DArray,
        /,
        *,
        str_as_lit: bool,
    ) -> PolarsExpr | None:
        if data is None:
            # NOTE: To avoid `pl.lit(None)` failing this `None` check
            # https://github.com/pola-rs/polars/blob/58dd8e5770f16a9bef9009a1c05f00e15a5263c7/py-polars/polars/expr/expr.py#L2870-L2872
            return data
        if is_expr(data):
            expr = data._to_compliant_expr(self)
            assert isinstance(expr, self._expr)  # noqa: S101
            return expr
        if isinstance(data, str) and not str_as_lit:
            return self.col(data)
        return self.lit(data.to_native() if is_series(data) else data, None)

    @overload
    def from_native(self, data: pl.DataFrame, /) -> PolarsDataFrame: ...
    @overload
    def from_native(self, data: pl.LazyFrame, /) -> PolarsLazyFrame: ...
    @overload
    def from_native(self, data: pl.Series, /) -> PolarsSeries: ...
    def from_native(
        self, data: pl.DataFrame | pl.LazyFrame | pl.Series | Any, /
    ) -> PolarsDataFrame | PolarsLazyFrame | PolarsSeries:
        if self._dataframe._is_native(data):
            return self._dataframe.from_native(data, context=self)
        if self._series._is_native(data):
            return self._series.from_native(data, context=self)
        if self._lazyframe._is_native(data):
            return self._lazyframe.from_native(data, context=self)
        msg = f"Unsupported type: {type(data).__name__!r}"  # pragma: no cover
        raise TypeError(msg)  # pragma: no cover

    @overload
    def from_numpy(self, data: Into1DArray, /, schema: None = ...) -> PolarsSeries: ...

    @overload
    def from_numpy(
        self, data: _2DArray, /, schema: IntoSchema | Sequence[str] | None
    ) -> PolarsDataFrame: ...

    def from_numpy(
        self,
        data: Into1DArray | _2DArray,
        /,
        schema: IntoSchema | Sequence[str] | None = None,
    ) -> PolarsDataFrame | PolarsSeries:
        if is_numpy_array_2d(data):
            return self._dataframe.from_numpy(data, schema=schema, context=self)
        return self._series.from_numpy(data, context=self)  # pragma: no cover

    @requires.backend_version(
        (1, 0, 0), "Please use `col` for columns selection instead."
    )
    def nth(self, *indices: int) -> PolarsExpr:
        return self._expr(pl.nth(*indices), version=self._version)

    def len(self) -> PolarsExpr:
        if self._backend_version < (0, 20, 5):
            return self._expr(pl.count().alias("len"), self._version)
        return self._expr(pl.len(), self._version)

    def all_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
        it = (expr.fill_null(True) for expr in exprs) if ignore_nulls else iter(exprs)
        return self._expr(pl.all_horizontal(*(expr.native for expr in it)), self._version)

    def any_horizontal(self, *exprs: PolarsExpr, ignore_nulls: bool) -> PolarsExpr:
        it = (expr.fill_null(False) for expr in exprs) if ignore_nulls else iter(exprs)
        return self._expr(pl.any_horizontal(*(expr.native for expr in it)), self._version)
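
    # With `ignore_nulls=True`, nulls are replaced by each reduction's
    # identity element (True for `all`, False for `any`) before folding, so
    # e.g. any_horizontal(null, False) yields False rather than Kleene null.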

    def concat(
        self,
        items: Iterable[FrameT],
        *,
        how: Literal["vertical", "horizontal", "diagonal"],
    ) -> PolarsDataFrame | PolarsLazyFrame:
        result = pl.concat((item.native for item in items), how=how)
        if isinstance(result, pl.DataFrame):
            return self._dataframe(result, version=self._version)
        return self._lazyframe.from_native(result, context=self)

    def lit(self, value: Any, dtype: IntoDType | None) -> PolarsExpr:
        if dtype is not None:
            return self._expr(
                pl.lit(value, dtype=narwhals_to_native_dtype(dtype, self._version)),
                version=self._version,
            )
        return self._expr(pl.lit(value), version=self._version)

    def mean_horizontal(self, *exprs: PolarsExpr) -> PolarsExpr:
        if self._backend_version < (0, 20, 8):
            return self._expr(
                pl.sum_horizontal(e._native_expr for e in exprs)
                / pl.sum_horizontal(1 - e.is_null()._native_expr for e in exprs),
                version=self._version,
            )

        return self._expr(
            pl.mean_horizontal(e._native_expr for e in exprs), version=self._version
        )
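
    # The pre-0.20.8 fallback computes a null-aware mean by hand: the row-wise
    # sum divided by the count of non-null entries, e.g. (1, null, 3) -> 4 / 2 = 2.0.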

    def concat_str(
        self, *exprs: PolarsExpr, separator: str, ignore_nulls: bool
    ) -> PolarsExpr:
        pl_exprs: list[pl.Expr] = [expr._native_expr for expr in exprs]

        if self._backend_version < (0, 20, 6):
            null_mask = [expr.is_null() for expr in pl_exprs]
            sep = pl.lit(separator)

            if not ignore_nulls:
                null_mask_result = pl.any_horizontal(*null_mask)
                output_expr = pl.reduce(
                    lambda x, y: x.cast(pl.String()) + sep + y.cast(pl.String()),  # type: ignore[arg-type,return-value]
                    pl_exprs,
                )
                result = pl.when(~null_mask_result).then(output_expr)
            else:
                init_value, *values = [
                    pl.when(nm).then(pl.lit("")).otherwise(expr.cast(pl.String()))
                    for expr, nm in zip_strict(pl_exprs, null_mask)
                ]
                separators = [
                    pl.when(~nm).then(sep).otherwise(pl.lit("")) for nm in null_mask[:-1]
                ]

                result = pl.fold(  # type: ignore[assignment]
                    acc=init_value,
                    function=operator.add,
                    exprs=[s + v for s, v in zip_strict(separators, values)],
                )

            return self._expr(result, version=self._version)

        return self._expr(
            pl.concat_str(pl_exprs, separator=separator, ignore_nulls=ignore_nulls),
            version=self._version,
        )
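
    # In the legacy `ignore_nulls=True` branch, a null entry contributes an
    # empty string and the separator that would follow it is dropped, so
    # folding ("a", null, "c") with separator "-" yields "a-c", not "a--c".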

    # NOTE: Implementation is too different to annotate correctly (vs other `*SelectorNamespace`)
    # 1. Others have lots of private stuff for code reuse
    #    i. None of that is useful here
    # 2. We don't have a `PolarsSelector` abstraction, and just use `PolarsExpr`
    @property
    def selectors(self) -> CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]:
        return cast(
            "CompliantSelectorNamespace[PolarsDataFrame, PolarsSeries]",
            PolarsSelectorNamespace(self),
        )


class PolarsSelectorNamespace:
    _implementation = Implementation.POLARS

    def __init__(self, context: _LimitedContext, /) -> None:
        self._version = context._version

    def by_dtype(self, dtypes: Iterable[DType]) -> PolarsExpr:
        native_dtypes = [
            narwhals_to_native_dtype(dtype, self._version).__class__
            if isinstance(dtype, type) and issubclass(dtype, DType)
            else narwhals_to_native_dtype(dtype, self._version)
            for dtype in dtypes
        ]
        return PolarsExpr(pl.selectors.by_dtype(native_dtypes), version=self._version)

    def matches(self, pattern: str) -> PolarsExpr:
        return PolarsExpr(pl.selectors.matches(pattern=pattern), version=self._version)

    def numeric(self) -> PolarsExpr:
        return PolarsExpr(pl.selectors.numeric(), version=self._version)

    def boolean(self) -> PolarsExpr:
        return PolarsExpr(pl.selectors.boolean(), version=self._version)

    def string(self) -> PolarsExpr:
        return PolarsExpr(pl.selectors.string(), version=self._version)

    def categorical(self) -> PolarsExpr:
        return PolarsExpr(pl.selectors.categorical(), version=self._version)

    def all(self) -> PolarsExpr:
        return PolarsExpr(pl.selectors.all(), version=self._version)

    def datetime(
        self,
        time_unit: TimeUnit | Iterable[TimeUnit] | None,
        time_zone: str | timezone | Iterable[str | timezone | None] | None,
    ) -> PolarsExpr:
        return PolarsExpr(
            pl.selectors.datetime(time_unit=time_unit, time_zone=time_zone),  # type: ignore[arg-type]
            version=self._version,
        )
795
lib/python3.11/site-packages/narwhals/_polars/series.py
Normal file
@ -0,0 +1,795 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Any, ClassVar, cast, overload

import polars as pl

from narwhals._polars.utils import (
    BACKEND_VERSION,
    SERIES_ACCEPTS_PD_INDEX,
    SERIES_RESPECTS_DTYPE,
    PolarsAnyNamespace,
    PolarsCatNamespace,
    PolarsDateTimeNamespace,
    PolarsListNamespace,
    PolarsStringNamespace,
    PolarsStructNamespace,
    catch_polars_exception,
    extract_args_kwargs,
    extract_native,
    narwhals_to_native_dtype,
    native_to_narwhals_dtype,
)
from narwhals._utils import Implementation, requires
from narwhals.dependencies import is_numpy_array_1d, is_pandas_index

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence
    from types import ModuleType
    from typing import Literal, TypeVar

    import pandas as pd
    import pyarrow as pa
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._polars.dataframe import Method, PolarsDataFrame
    from narwhals._polars.namespace import PolarsNamespace
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dtypes import DType
    from narwhals.series import Series
    from narwhals.typing import (
        Into1DArray,
        IntoDType,
        ModeKeepStrategy,
        MultiIndexSelector,
        NonNestedLiteral,
        NumericLiteral,
        _1DArray,
    )

    T = TypeVar("T")
    IncludeBreakpoint: TypeAlias = Literal[False, True]

Incomplete: TypeAlias = Any

# Series methods where PolarsSeries just defers to Polars.Series directly.
INHERITED_METHODS = frozenset(
    [
        "__add__",
        "__and__",
        "__floordiv__",
        "__invert__",
        "__iter__",
        "__mod__",
        "__mul__",
        "__or__",
        "__pow__",
        "__radd__",
        "__rand__",
        "__rfloordiv__",
        "__rmod__",
        "__rmul__",
        "__ror__",
        "__rsub__",
        "__rtruediv__",
        "__sub__",
        "__truediv__",
        "abs",
        "all",
        "any",
        "arg_max",
        "arg_min",
        "arg_true",
        "clip",
        "count",
        "cum_max",
        "cum_min",
        "cum_prod",
        "cum_sum",
        "diff",
        "drop_nulls",
        "exp",
        "fill_null",
        "fill_nan",
        "filter",
        "gather_every",
        "head",
        "is_between",
        "is_close",
        "is_duplicated",
        "is_empty",
        "is_finite",
        "is_first_distinct",
        "is_in",
        "is_last_distinct",
        "is_null",
        "is_sorted",
        "is_unique",
        "item",
        "kurtosis",
        "len",
        "log",
        "max",
        "mean",
        "min",
        "mode",
        "n_unique",
        "null_count",
        "quantile",
        "rank",
        "round",
        "sample",
        "shift",
        "skew",
        "sqrt",
        "std",
        "sum",
        "tail",
        "to_arrow",
        "to_frame",
        "to_list",
        "to_pandas",
        "unique",
        "var",
        "zip_with",
    ]
)


class PolarsSeries:
    _implementation: Implementation = Implementation.POLARS
    _native_series: pl.Series
    _version: Version

    _HIST_EMPTY_SCHEMA: ClassVar[Mapping[IncludeBreakpoint, Sequence[str]]] = {
        True: ["breakpoint", "count"],
        False: ["count"],
    }

    def __init__(self, series: pl.Series, *, version: Version) -> None:
        self._native_series = series
        self._version = version

    @property
    def _backend_version(self) -> tuple[int, ...]:
        return self._implementation._backend_version()

    def __repr__(self) -> str:  # pragma: no cover
        return "PolarsSeries"

    def __narwhals_namespace__(self) -> PolarsNamespace:
        from narwhals._polars.namespace import PolarsNamespace

        return PolarsNamespace(version=self._version)

    def __narwhals_series__(self) -> Self:
        return self

    def __native_namespace__(self) -> ModuleType:
        if self._implementation is Implementation.POLARS:
            return self._implementation.to_native_namespace()

        msg = f"Expected polars, got: {type(self._implementation)}"  # pragma: no cover
        raise AssertionError(msg)

    def _with_version(self, version: Version) -> Self:
        return self.__class__(self.native, version=version)

    @classmethod
    def from_iterable(
        cls,
        data: Iterable[Any],
        *,
        context: _LimitedContext,
        name: str = "",
        dtype: IntoDType | None = None,
    ) -> Self:
        version = context._version
        dtype_pl = narwhals_to_native_dtype(dtype, version) if dtype else None
        values: Incomplete = data
        if SERIES_RESPECTS_DTYPE:
            native = pl.Series(name, values, dtype=dtype_pl)
        else:  # pragma: no cover
            if (not SERIES_ACCEPTS_PD_INDEX) and is_pandas_index(values):
                values = values.to_series()
            native = pl.Series(name, values)
            if dtype_pl:
                native = native.cast(dtype_pl)
        return cls.from_native(native, context=context)

    @staticmethod
    def _is_native(obj: pl.Series | Any) -> TypeIs[pl.Series]:
        return isinstance(obj, pl.Series)

    @classmethod
    def from_native(cls, data: pl.Series, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version)

    @classmethod
    def from_numpy(cls, data: Into1DArray, /, *, context: _LimitedContext) -> Self:
        native = pl.Series(data if is_numpy_array_1d(data) else [data])
        return cls.from_native(native, context=context)

    def to_narwhals(self) -> Series[pl.Series]:
        return self._version.series(self, level="full")

    def _with_native(self, series: pl.Series) -> Self:
        return self.__class__(series, version=self._version)

    @overload
    def _from_native_object(self, series: pl.Series) -> Self: ...

    @overload
    def _from_native_object(self, series: pl.DataFrame) -> PolarsDataFrame: ...

    @overload
    def _from_native_object(self, series: T) -> T: ...

    def _from_native_object(
        self, series: pl.Series | pl.DataFrame | T
    ) -> Self | PolarsDataFrame | T:
        if self._is_native(series):
            return self._with_native(series)
        if isinstance(series, pl.DataFrame):
            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame.from_native(series, context=self)
        # scalar
        return series

    def __getattr__(self, attr: str) -> Any:
        if attr not in INHERITED_METHODS:
            msg = f"{self.__class__.__name__} has no attribute '{attr}'."
            raise AttributeError(msg)

        def func(*args: Any, **kwargs: Any) -> Any:
            pos, kwds = extract_args_kwargs(args, kwargs)
            return self._from_native_object(getattr(self.native, attr)(*pos, **kwds))

        return func
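
    # Deferral happens lazily at call time: e.g. `PolarsSeries.abs()` resolves
    # to `pl.Series.abs`, with any narwhals wrappers among the arguments first
    # unwrapped to their native counterparts by `extract_args_kwargs`.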

    def __len__(self) -> int:
        return len(self.native)

    @property
    def name(self) -> str:
        return self.native.name

    @property
    def dtype(self) -> DType:
        return native_to_narwhals_dtype(self.native.dtype, self._version)

    @property
    def native(self) -> pl.Series:
        return self._native_series

    def alias(self, name: str) -> Self:
        return self._from_native_object(self.native.alias(name))

    def __getitem__(self, item: MultiIndexSelector[Self]) -> Any | Self:
        if isinstance(item, PolarsSeries):
            return self._from_native_object(self.native.__getitem__(item.native))
        return self._from_native_object(self.native.__getitem__(item))

    def cast(self, dtype: IntoDType) -> Self:
        dtype_pl = narwhals_to_native_dtype(dtype, self._version)
        return self._with_native(self.native.cast(dtype_pl))

    @requires.backend_version((1,))
    def replace_strict(
        self,
        old: Sequence[Any] | Mapping[Any, Any],
        new: Sequence[Any],
        *,
        return_dtype: IntoDType | None,
    ) -> Self:
        ser = self.native
        dtype = (
            narwhals_to_native_dtype(return_dtype, self._version)
            if return_dtype
            else None
        )
        return self._with_native(ser.replace_strict(old, new, return_dtype=dtype))

    def to_numpy(self, dtype: Any = None, *, copy: bool | None = None) -> _1DArray:
        return self.__array__(dtype, copy=copy)

    def __array__(self, dtype: Any, *, copy: bool | None) -> _1DArray:
        if self._backend_version < (0, 20, 29):
            return self.native.__array__(dtype=dtype)
        return self.native.__array__(dtype=dtype, copy=copy)

    def __eq__(self, other: object) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__eq__(extract_native(other)))

    def __ne__(self, other: object) -> Self:  # type: ignore[override]
        return self._with_native(self.native.__ne__(extract_native(other)))

    # NOTE: These need to be anything that can't match `PolarsExpr`, due to overload order
    def __ge__(self, other: Self) -> Self:
        return self._with_native(self.native.__ge__(extract_native(other)))

    def __gt__(self, other: Self) -> Self:
        return self._with_native(self.native.__gt__(extract_native(other)))

    def __le__(self, other: Self) -> Self:
        return self._with_native(self.native.__le__(extract_native(other)))

    def __lt__(self, other: Self) -> Self:
        return self._with_native(self.native.__lt__(extract_native(other)))

    def __rpow__(self, other: PolarsSeries | Any) -> Self:
        result = self.native.__rpow__(extract_native(other))
        if self._backend_version < (1, 16, 1):
            # Explicitly set alias to work around https://github.com/pola-rs/polars/issues/20071
            result = result.alias(self.name)
        return self._with_native(result)

    def is_nan(self) -> Self:
        try:
            native_is_nan = self.native.is_nan()
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None
        if self._backend_version < (1, 18):  # pragma: no cover
            select = pl.when(self.native.is_not_null()).then(native_is_nan)
            return self._with_native(pl.select(select)[self.name])
        return self._with_native(native_is_nan)

    def median(self) -> Any:
        from narwhals.exceptions import InvalidOperationError

        if not self.dtype.is_numeric():
            msg = "`median` operation not supported for non-numeric input type."
            raise InvalidOperationError(msg)

        return self.native.median()

    def to_dummies(self, *, separator: str, drop_first: bool) -> PolarsDataFrame:
        from narwhals._polars.dataframe import PolarsDataFrame

        if self._backend_version < (0, 20, 15):
            has_nulls = self.native.is_null().any()
            result = self.native.to_dummies(separator=separator)
            output_columns = result.columns
            if drop_first:
                _ = output_columns.pop(int(has_nulls))

            result = result.select(output_columns)
        else:
            result = self.native.to_dummies(separator=separator, drop_first=drop_first)
        result = result.with_columns(pl.all().cast(pl.Int8))
        return PolarsDataFrame.from_native(result, context=self)

    def ewm_mean(
        self,
        *,
        com: float | None,
        span: float | None,
        half_life: float | None,
        alpha: float | None,
        adjust: bool,
        min_samples: int,
        ignore_nulls: bool,
    ) -> Self:
        extra_kwargs = (
            {"min_periods": min_samples}
            if self._backend_version < (1, 21, 0)
            else {"min_samples": min_samples}
        )

        native_result = self.native.ewm_mean(
            com=com,
            span=span,
            half_life=half_life,
            alpha=alpha,
            adjust=adjust,
            ignore_nulls=ignore_nulls,
            **extra_kwargs,
        )
        if self._backend_version < (1,):  # pragma: no cover
            return self._with_native(
                pl.select(
                    pl.when(~self.native.is_null()).then(native_result).otherwise(None)
                )[self.native.name]
            )

        return self._with_native(native_result)

    @requires.backend_version((1,))
    def rolling_var(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        extra_kwargs: dict[str, Any] = (
            {"min_periods": min_samples}
            if self._backend_version < (1, 21, 0)
            else {"min_samples": min_samples}
        )
        return self._with_native(
            self.native.rolling_var(
                window_size=window_size, center=center, ddof=ddof, **extra_kwargs
            )
        )

    @requires.backend_version((1,))
    def rolling_std(
        self, window_size: int, *, min_samples: int, center: bool, ddof: int
    ) -> Self:
        extra_kwargs: dict[str, Any] = (
            {"min_periods": min_samples}
            if self._backend_version < (1, 21, 0)
            else {"min_samples": min_samples}
        )
        return self._with_native(
            self.native.rolling_std(
                window_size=window_size, center=center, ddof=ddof, **extra_kwargs
            )
        )

    def rolling_sum(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        extra_kwargs: dict[str, Any] = (
            {"min_periods": min_samples}
            if self._backend_version < (1, 21, 0)
            else {"min_samples": min_samples}
        )
        return self._with_native(
            self.native.rolling_sum(
                window_size=window_size, center=center, **extra_kwargs
            )
        )

    def rolling_mean(self, window_size: int, *, min_samples: int, center: bool) -> Self:
        extra_kwargs: dict[str, Any] = (
            {"min_periods": min_samples}
            if self._backend_version < (1, 21, 0)
            else {"min_samples": min_samples}
        )
        return self._with_native(
            self.native.rolling_mean(
                window_size=window_size, center=center, **extra_kwargs
            )
        )

    def sort(self, *, descending: bool, nulls_last: bool) -> Self:
        if self._backend_version < (0, 20, 6):
            result = self.native.sort(descending=descending)

            if nulls_last:
                is_null = result.is_null()
                result = pl.concat([result.filter(~is_null), result.filter(is_null)])
        else:
            result = self.native.sort(descending=descending, nulls_last=nulls_last)

        return self._with_native(result)
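
    # Pre-0.20.6 Polars `sort` had no `nulls_last` flag; nulls sort first by
    # default, so e.g. [2, null, 1] ascending gives [null, 1, 2] and the
    # filter/concat step above rebuilds it as [1, 2, null].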

    def scatter(self, indices: int | Sequence[int], values: Any) -> Self:
        s = self.native.clone().scatter(indices, extract_native(values))
        return self._with_native(s)

    def value_counts(
        self, *, sort: bool, parallel: bool, name: str | None, normalize: bool
    ) -> PolarsDataFrame:
        from narwhals._polars.dataframe import PolarsDataFrame

        if self._backend_version < (1, 0, 0):
            value_name_ = name or ("proportion" if normalize else "count")

            result = self.native.value_counts(sort=sort, parallel=parallel).select(
                **{
                    (self.native.name): pl.col(self.native.name),
                    value_name_: pl.col("count") / pl.sum("count")
                    if normalize
                    else pl.col("count"),
                }
            )
        else:
            result = self.native.value_counts(
                sort=sort, parallel=parallel, name=name, normalize=normalize
            )
        return PolarsDataFrame.from_native(result, context=self)

    def cum_count(self, *, reverse: bool) -> Self:
        return self._with_native(self.native.cum_count(reverse=reverse))

    def __contains__(self, other: Any) -> bool:
        try:
            return self.native.__contains__(other)
        except Exception as e:  # noqa: BLE001
            raise catch_polars_exception(e) from None

    def is_close(
        self,
        other: Self | NumericLiteral,
        *,
        abs_tol: float,
        rel_tol: float,
        nans_equal: bool,
    ) -> PolarsSeries:
        if self._backend_version < (1, 32, 0):
            name = self.name
            ns = self.__narwhals_namespace__()
            other_expr = (
                ns.lit(other.native, None) if isinstance(other, PolarsSeries) else other
            )
            expr = ns.col(name).is_close(
                other_expr, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
            )
            return self.to_frame().select(expr).get_column(name)
        other_series = other.native if isinstance(other, PolarsSeries) else other
        result = self.native.is_close(
            other_series, abs_tol=abs_tol, rel_tol=rel_tol, nans_equal=nans_equal
        )
        return self._with_native(result)

    def mode(self, *, keep: ModeKeepStrategy) -> Self:
        result = self.native.mode()
        return self._with_native(result.head(1) if keep == "any" else result)

    def hist_from_bins(
        self, bins: list[float], *, include_breakpoint: bool
    ) -> PolarsDataFrame:
        if len(bins) <= 1:
            native = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
        elif self.native.is_empty():
            if include_breakpoint:
                native = (
                    pl.Series(bins[1:])
                    .to_frame("breakpoint")
                    .with_columns(count=pl.lit(0, pl.Int64))
                )
            else:
                native = pl.select(count=pl.zeros(len(bins) - 1, pl.Int64))
        else:
            return self._hist_from_data(
                bins=bins, bin_count=None, include_breakpoint=include_breakpoint
            )
        return self.__narwhals_namespace__()._dataframe.from_native(native, context=self)

    def hist_from_bin_count(
        self, bin_count: int, *, include_breakpoint: bool
    ) -> PolarsDataFrame:
        if bin_count == 0:
            native = pl.DataFrame(schema=self._HIST_EMPTY_SCHEMA[include_breakpoint])
        elif self.native.is_empty():
            if include_breakpoint:
                native = pl.select(
                    breakpoint=pl.int_range(1, bin_count + 1) / bin_count,
                    count=pl.lit(0, pl.Int64),
                )
            else:
                native = pl.select(count=pl.zeros(bin_count, pl.Int64))
        else:
            count: int | None
            if BACKEND_VERSION < (1, 15):  # pragma: no cover
                count = None
                bins = self._bins_from_bin_count(bin_count=bin_count)
            else:
                count = bin_count
                bins = None
            return self._hist_from_data(
                bins=bins,  # type: ignore[arg-type]
                bin_count=count,
                include_breakpoint=include_breakpoint,
            )
        return self.__narwhals_namespace__()._dataframe.from_native(native, context=self)

    def _bins_from_bin_count(self, bin_count: int) -> pl.Series:  # pragma: no cover
        """Prepare bins based on backend version compatibility.

        polars <1.15 does not adjust the bins when they have equivalent min/max
        polars <1.5 with bin_count=...
            returns bins that range from -inf to +inf and has bin_count + 1 bins.
        for compat: convert `bin_count=` call to `bins=`
        """
        lower = cast("float", self.native.min())
        upper = cast("float", self.native.max())

        if lower == upper:
            lower -= 0.5
            upper += 0.5

        width = (upper - lower) / bin_count
        return pl.int_range(0, bin_count + 1, eager=True) * width + lower
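
    # E.g. a series spanning [0, 8] with bin_count=4 gives width 2 and bins
    # [0, 2, 4, 6, 8]; the +/-0.5 adjustment only applies to a constant
    # series, where min == max would otherwise produce zero-width bins.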

    def _hist_from_data(
        self, bins: list[float] | None, bin_count: int | None, *, include_breakpoint: bool
    ) -> PolarsDataFrame:
        """Calculate histogram from non-empty data and post-process the results based on the backend version."""
        from narwhals._polars.dataframe import PolarsDataFrame

        series = self.native

        # Polars inconsistently handles NaN values when computing histograms
        # against predefined bins: https://github.com/pola-rs/polars/issues/21082
        if BACKEND_VERSION < (1, 15) or bins is not None:
            series = series.fill_nan(None)

        df = series.hist(
            bins,
            bin_count=bin_count,
            include_category=False,
            include_breakpoint=include_breakpoint,
        )

        # Apply post-processing corrections

        # Handle column naming
        if not include_breakpoint:
            col_name = df.columns[0]
            df = df.select(pl.col(col_name).alias("count"))
        elif BACKEND_VERSION < (1, 0):  # pragma: no cover
            df = df.rename({"break_point": "breakpoint"})

        if bins is not None:  # pragma: no cover
            # polars<1.6 implicitly adds -inf and inf to either end of bins
            if BACKEND_VERSION < (1, 6):
                r = pl.int_range(0, len(df))
                df = df.filter((r > 0) & (r < len(df) - 1))
            # polars<1.27 makes the lowest bin a left/right closed interval
            if BACKEND_VERSION < (1, 27):
                df = (
                    df.slice(0, 1)
                    .with_columns(pl.col("count") + ((pl.lit(series) == bins[0]).sum()))
                    .vstack(df.slice(1))
                )

        return PolarsDataFrame.from_native(df, context=self)

    def to_polars(self) -> pl.Series:
        return self.native

    @property
    def dt(self) -> PolarsSeriesDateTimeNamespace:
        return PolarsSeriesDateTimeNamespace(self)

    @property
    def str(self) -> PolarsSeriesStringNamespace:
        return PolarsSeriesStringNamespace(self)

    @property
    def cat(self) -> PolarsSeriesCatNamespace:
        return PolarsSeriesCatNamespace(self)

    @property
    def struct(self) -> PolarsSeriesStructNamespace:
        return PolarsSeriesStructNamespace(self)

    __add__: Method[Self]
    __and__: Method[Self]
    __floordiv__: Method[Self]
    __invert__: Method[Self]
    __iter__: Method[Iterator[Any]]
    __mod__: Method[Self]
    __mul__: Method[Self]
    __or__: Method[Self]
    __pow__: Method[Self]
    __radd__: Method[Self]
    __rand__: Method[Self]
    __rfloordiv__: Method[Self]
    __rmod__: Method[Self]
    __rmul__: Method[Self]
    __ror__: Method[Self]
    __rsub__: Method[Self]
    __rtruediv__: Method[Self]
    __sub__: Method[Self]
    __truediv__: Method[Self]
    abs: Method[Self]
    all: Method[bool]
    any: Method[bool]
    arg_max: Method[int]
    arg_min: Method[int]
    arg_true: Method[Self]
    clip: Method[Self]
    count: Method[int]
    cum_max: Method[Self]
    cum_min: Method[Self]
    cum_prod: Method[Self]
    cum_sum: Method[Self]
    diff: Method[Self]
    drop_nulls: Method[Self]
    exp: Method[Self]
    fill_null: Method[Self]
    fill_nan: Method[Self]
    filter: Method[Self]
    gather_every: Method[Self]
    head: Method[Self]
    is_between: Method[Self]
    is_duplicated: Method[Self]
    is_empty: Method[bool]
    is_finite: Method[Self]
    is_first_distinct: Method[Self]
    is_in: Method[Self]
    is_last_distinct: Method[Self]
    is_null: Method[Self]
    is_sorted: Method[bool]
    is_unique: Method[Self]
    item: Method[Any]
    kurtosis: Method[float | None]
    len: Method[int]
    log: Method[Self]
    max: Method[Any]
    mean: Method[float]
    min: Method[Any]
    n_unique: Method[int]
    null_count: Method[int]
    quantile: Method[float]
    rank: Method[Self]
    round: Method[Self]
    sample: Method[Self]
    shift: Method[Self]
    skew: Method[float | None]
    sqrt: Method[Self]
    std: Method[float]
    sum: Method[float]
    tail: Method[Self]
    to_arrow: Method[pa.Array[Any]]
    to_frame: Method[PolarsDataFrame]
    to_list: Method[list[Any]]
    to_pandas: Method[pd.Series[Any]]
    unique: Method[Self]
    var: Method[float]
    zip_with: Method[Self]

    @property
    def list(self) -> PolarsSeriesListNamespace:
        return PolarsSeriesListNamespace(self)


class PolarsSeriesNamespace(PolarsAnyNamespace[PolarsSeries, pl.Series]):
    def __init__(self, series: PolarsSeries) -> None:
        self._series = series

    @property
    def compliant(self) -> PolarsSeries:
        return self._series

    @property
    def native(self) -> pl.Series:
        return self._series.native

    @property
    def name(self) -> str:
        return self.compliant.name

    def __narwhals_namespace__(self) -> PolarsNamespace:
        return self.compliant.__narwhals_namespace__()

    def to_frame(self) -> PolarsDataFrame:
        return self.compliant.to_frame()


class PolarsSeriesDateTimeNamespace(
    PolarsSeriesNamespace, PolarsDateTimeNamespace[PolarsSeries, pl.Series]
): ...


class PolarsSeriesStringNamespace(
    PolarsSeriesNamespace, PolarsStringNamespace[PolarsSeries, pl.Series]
):
    def zfill(self, width: int) -> PolarsSeries:
        name = self.name
        ns = self.__narwhals_namespace__()
        return self.to_frame().select(ns.col(name).str.zfill(width)).get_column(name)


class PolarsSeriesCatNamespace(
    PolarsSeriesNamespace, PolarsCatNamespace[PolarsSeries, pl.Series]
): ...


class PolarsSeriesListNamespace(
    PolarsSeriesNamespace, PolarsListNamespace[PolarsSeries, pl.Series]
):
    def len(self) -> PolarsSeries:
        name = self.name
        ns = self.__narwhals_namespace__()
        return self.to_frame().select(ns.col(name).list.len()).get_column(name)

    def contains(self, item: NonNestedLiteral) -> PolarsSeries:
        name = self.name
        ns = self.__narwhals_namespace__()
        return self.to_frame().select(ns.col(name).list.contains(item)).get_column(name)


class PolarsSeriesStructNamespace(
    PolarsSeriesNamespace, PolarsStructNamespace[PolarsSeries, pl.Series]
): ...
25
lib/python3.11/site-packages/narwhals/_polars/typing.py
Normal file
@ -0,0 +1,25 @@
from __future__ import annotations  # pragma: no cover

from typing import (
    TYPE_CHECKING,  # pragma: no cover
    Union,  # pragma: no cover
)

if TYPE_CHECKING:
    import sys
    from typing import Literal, TypeVar

    if sys.version_info >= (3, 10):
        from typing import TypeAlias
    else:
        from typing_extensions import TypeAlias

    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    IntoPolarsExpr: TypeAlias = Union[PolarsExpr, PolarsSeries]
    FrameT = TypeVar("FrameT", PolarsDataFrame, PolarsLazyFrame)
    NativeAccessor: TypeAlias = Literal[
        "arr", "cat", "dt", "list", "meta", "name", "str", "bin", "struct"
    ]
351
lib/python3.11/site-packages/narwhals/_polars/utils.py
Normal file
@ -0,0 +1,351 @@
from __future__ import annotations

import abc
from functools import lru_cache
from typing import TYPE_CHECKING, Any, ClassVar, Final, Protocol, TypeVar, overload

import polars as pl

from narwhals._duration import Interval
from narwhals._utils import (
    Implementation,
    Version,
    _DeferredIterable,
    _StoresCompliant,
    _StoresNative,
    deep_getattr,
    isinstance_or_issubclass,
)
from narwhals.exceptions import (
    ColumnNotFoundError,
    ComputeError,
    DuplicateError,
    InvalidOperationError,
    NarwhalsError,
    ShapeError,
)

if TYPE_CHECKING:
    from collections.abc import Callable, Iterable, Iterator, Mapping

    from typing_extensions import TypeIs

    from narwhals._polars.dataframe import Method
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries
    from narwhals._polars.typing import NativeAccessor
    from narwhals.dtypes import DType
    from narwhals.typing import IntoDType

T = TypeVar("T")
NativeT = TypeVar(
    "NativeT", bound="pl.DataFrame | pl.LazyFrame | pl.Series | pl.Expr"
)

NativeT_co = TypeVar("NativeT_co", "pl.Series", "pl.Expr", covariant=True)
CompliantT_co = TypeVar("CompliantT_co", "PolarsSeries", "PolarsExpr", covariant=True)
CompliantT = TypeVar("CompliantT", "PolarsSeries", "PolarsExpr")

BACKEND_VERSION = Implementation.POLARS._backend_version()
"""Static backend version for `polars`."""

SERIES_RESPECTS_DTYPE: Final[bool] = BACKEND_VERSION >= (0, 20, 26)
"""`pl.Series(dtype=...)` fixed in https://github.com/pola-rs/polars/pull/15962

Includes `SERIES_ACCEPTS_PD_INDEX`.
"""

SERIES_ACCEPTS_PD_INDEX: Final[bool] = BACKEND_VERSION >= (0, 20, 7)
"""`pl.Series(values: pd.Index)` fixed in https://github.com/pola-rs/polars/pull/14087"""


@overload
def extract_native(obj: _StoresNative[NativeT]) -> NativeT: ...
@overload
def extract_native(obj: T) -> T: ...
def extract_native(obj: _StoresNative[NativeT] | T) -> NativeT | T:
    return obj.native if _is_compliant_polars(obj) else obj


def _is_compliant_polars(
    obj: _StoresNative[NativeT] | Any,
) -> TypeIs[_StoresNative[NativeT]]:
    from narwhals._polars.dataframe import PolarsDataFrame, PolarsLazyFrame
    from narwhals._polars.expr import PolarsExpr
    from narwhals._polars.series import PolarsSeries

    return isinstance(obj, (PolarsDataFrame, PolarsLazyFrame, PolarsSeries, PolarsExpr))


def extract_args_kwargs(
    args: Iterable[Any], kwds: Mapping[str, Any], /
) -> tuple[Iterator[Any], dict[str, Any]]:
    it_args = (extract_native(arg) for arg in args)
    return it_args, {k: extract_native(v) for k, v in kwds.items()}


@lru_cache(maxsize=16)
def native_to_narwhals_dtype(  # noqa: C901, PLR0912
    dtype: pl.DataType, version: Version
) -> DType:
    dtypes = version.dtypes
    if dtype == pl.Float64:
        return dtypes.Float64()
    if dtype == pl.Float32:
        return dtypes.Float32()
    if hasattr(pl, "Int128") and dtype == pl.Int128:  # pragma: no cover
        # Not available for Polars pre 1.8.0
        return dtypes.Int128()
    if dtype == pl.Int64:
        return dtypes.Int64()
    if dtype == pl.Int32:
        return dtypes.Int32()
    if dtype == pl.Int16:
        return dtypes.Int16()
    if dtype == pl.Int8:
        return dtypes.Int8()
    if hasattr(pl, "UInt128") and dtype == pl.UInt128:  # pragma: no cover
        # Not available for Polars pre 1.8.0
        return dtypes.UInt128()
    if dtype == pl.UInt64:
        return dtypes.UInt64()
    if dtype == pl.UInt32:
        return dtypes.UInt32()
    if dtype == pl.UInt16:
        return dtypes.UInt16()
    if dtype == pl.UInt8:
        return dtypes.UInt8()
    if dtype == pl.String:
        return dtypes.String()
    if dtype == pl.Boolean:
        return dtypes.Boolean()
    if dtype == pl.Object:
        return dtypes.Object()
    if dtype == pl.Categorical:
        return dtypes.Categorical()
    if isinstance_or_issubclass(dtype, pl.Enum):
        if version is Version.V1:
            return dtypes.Enum()  # type: ignore[call-arg]
        categories = _DeferredIterable(dtype.categories.to_list)
        return dtypes.Enum(categories)
    if dtype == pl.Date:
        return dtypes.Date()
    if isinstance_or_issubclass(dtype, pl.Datetime):
        return (
            dtypes.Datetime()
            if dtype is pl.Datetime
            else dtypes.Datetime(dtype.time_unit, dtype.time_zone)
        )
    if isinstance_or_issubclass(dtype, pl.Duration):
        return (
            dtypes.Duration()
            if dtype is pl.Duration
            else dtypes.Duration(dtype.time_unit)
        )
    if isinstance_or_issubclass(dtype, pl.Struct):
        fields = [
            dtypes.Field(name, native_to_narwhals_dtype(tp, version))
            for name, tp in dtype
        ]
        return dtypes.Struct(fields)
    if isinstance_or_issubclass(dtype, pl.List):
        return dtypes.List(native_to_narwhals_dtype(dtype.inner, version))
    if isinstance_or_issubclass(dtype, pl.Array):
        outer_shape = dtype.width if BACKEND_VERSION < (0, 20, 30) else dtype.size
        return dtypes.Array(native_to_narwhals_dtype(dtype.inner, version), outer_shape)
    if dtype == pl.Decimal:
        return dtypes.Decimal()
    if dtype == pl.Time:
        return dtypes.Time()
    if dtype == pl.Binary:
        return dtypes.Binary()
    return dtypes.Unknown()


dtypes = Version.MAIN.dtypes
NW_TO_PL_DTYPES: Mapping[type[DType], pl.DataType] = {
    dtypes.Float64: pl.Float64(),
    dtypes.Float32: pl.Float32(),
    dtypes.Binary: pl.Binary(),
    dtypes.String: pl.String(),
    dtypes.Boolean: pl.Boolean(),
    dtypes.Categorical: pl.Categorical(),
    dtypes.Date: pl.Date(),
    dtypes.Time: pl.Time(),
    dtypes.Int8: pl.Int8(),
    dtypes.Int16: pl.Int16(),
    dtypes.Int32: pl.Int32(),
    dtypes.Int64: pl.Int64(),
    dtypes.UInt8: pl.UInt8(),
    dtypes.UInt16: pl.UInt16(),
    dtypes.UInt32: pl.UInt32(),
    dtypes.UInt64: pl.UInt64(),
    dtypes.Object: pl.Object(),
    dtypes.Unknown: pl.Unknown(),
}
UNSUPPORTED_DTYPES = (dtypes.Decimal,)


def narwhals_to_native_dtype(  # noqa: C901
    dtype: IntoDType, version: Version
) -> pl.DataType:
    dtypes = version.dtypes
    base_type = dtype.base_type()
    if pl_type := NW_TO_PL_DTYPES.get(base_type):
        return pl_type
    if dtype == dtypes.Int128 and hasattr(pl, "Int128"):
        # Not available for Polars pre 1.8.0
        return pl.Int128()
    if isinstance_or_issubclass(dtype, dtypes.Enum):
        if version is Version.V1:
            msg = "Converting to Enum is not supported in narwhals.stable.v1"
            raise NotImplementedError(msg)
        if isinstance(dtype, dtypes.Enum):
            return pl.Enum(dtype.categories)
        msg = "Can not cast / initialize Enum without categories present"
        raise ValueError(msg)
    if isinstance_or_issubclass(dtype, dtypes.Datetime):
        return pl.Datetime(dtype.time_unit, dtype.time_zone)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.Duration):
        return pl.Duration(dtype.time_unit)  # type: ignore[arg-type]
    if isinstance_or_issubclass(dtype, dtypes.List):
        return pl.List(narwhals_to_native_dtype(dtype.inner, version))
    if isinstance_or_issubclass(dtype, dtypes.Struct):
        fields = [
            pl.Field(field.name, narwhals_to_native_dtype(field.dtype, version))
            for field in dtype.fields
        ]
        return pl.Struct(fields)
    if isinstance_or_issubclass(dtype, dtypes.Array):  # pragma: no cover
        size = dtype.size
        kwargs = {"width": size} if BACKEND_VERSION < (0, 20, 30) else {"shape": size}
        return pl.Array(narwhals_to_native_dtype(dtype.inner, version), **kwargs)
    if issubclass(base_type, UNSUPPORTED_DTYPES):
        msg = f"Converting to {base_type.__name__} dtype is not supported for Polars."
        raise NotImplementedError(msg)
    return pl.Unknown()  # pragma: no cover


def _is_polars_exception(exception: Exception) -> bool:
    if BACKEND_VERSION >= (1,):
        # Old versions of Polars didn't have PolarsError.
        return isinstance(exception, pl.exceptions.PolarsError)
    # Last attempt, for old Polars versions.
    return "polars.exceptions" in str(type(exception))  # pragma: no cover


def _is_cudf_exception(exception: Exception) -> bool:
    # These exceptions are raised when running polars on GPUs via cuDF
    return str(exception).startswith("CUDF failure")


def catch_polars_exception(exception: Exception) -> NarwhalsError | Exception:
    if isinstance(exception, pl.exceptions.ColumnNotFoundError):
        return ColumnNotFoundError(str(exception))
    if isinstance(exception, pl.exceptions.ShapeError):
        return ShapeError(str(exception))
    if isinstance(exception, pl.exceptions.InvalidOperationError):
        return InvalidOperationError(str(exception))
    if isinstance(exception, pl.exceptions.DuplicateError):
        return DuplicateError(str(exception))
    if isinstance(exception, pl.exceptions.ComputeError):
        return ComputeError(str(exception))
    if _is_polars_exception(exception) or _is_cudf_exception(exception):
        return NarwhalsError(str(exception))  # pragma: no cover
    # Just return exception as-is.
    return exception
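
# E.g. a `pl.exceptions.ColumnNotFoundError` raised by a native call is
# re-surfaced as the backend-agnostic `narwhals.exceptions.ColumnNotFoundError`,
# while unrecognised exceptions pass through unchanged.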


class PolarsAnyNamespace(
    _StoresCompliant[CompliantT_co],
    _StoresNative[NativeT_co],
    Protocol[CompliantT_co, NativeT_co],
):
    _accessor: ClassVar[NativeAccessor]

    def __getattr__(self, attr: str) -> Callable[..., CompliantT_co]:
        def func(*args: Any, **kwargs: Any) -> CompliantT_co:
            pos, kwds = extract_args_kwargs(args, kwargs)
            method = deep_getattr(self.native, self._accessor, attr)
            return self.compliant._with_native(method(*pos, **kwds))

        return func
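
    # Namespaced methods dispatch through the stored accessor name: with
    # `_accessor = "dt"`, calling `.year()` resolves to
    # `deep_getattr(native, "dt", "year")`, i.e. `native.dt.year`.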


class PolarsDateTimeNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    _accessor: ClassVar[NativeAccessor] = "dt"

    def truncate(self, every: str) -> CompliantT:
        # Ensure consistent error message is raised.
        Interval.parse(every)
        return self.__getattr__("truncate")(every)

    def offset_by(self, by: str) -> CompliantT:
        # Ensure consistent error message is raised.
        Interval.parse_no_constraints(by)
        return self.__getattr__("offset_by")(by)

    to_string: Method[CompliantT]
    replace_time_zone: Method[CompliantT]
    convert_time_zone: Method[CompliantT]
    timestamp: Method[CompliantT]
    date: Method[CompliantT]
    year: Method[CompliantT]
    month: Method[CompliantT]
    day: Method[CompliantT]
    hour: Method[CompliantT]
    minute: Method[CompliantT]
    second: Method[CompliantT]
    millisecond: Method[CompliantT]
    microsecond: Method[CompliantT]
    nanosecond: Method[CompliantT]
    ordinal_day: Method[CompliantT]
    weekday: Method[CompliantT]
    total_minutes: Method[CompliantT]
    total_seconds: Method[CompliantT]
    total_milliseconds: Method[CompliantT]
    total_microseconds: Method[CompliantT]
    total_nanoseconds: Method[CompliantT]


class PolarsStringNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    _accessor: ClassVar[NativeAccessor] = "str"

    # NOTE: Use `abstractmethod` if we have defs to implement, but also `Method` usage
    @abc.abstractmethod
    def zfill(self, width: int) -> CompliantT: ...

    len_chars: Method[CompliantT]
    replace: Method[CompliantT]
    replace_all: Method[CompliantT]
    strip_chars: Method[CompliantT]
    starts_with: Method[CompliantT]
    ends_with: Method[CompliantT]
    contains: Method[CompliantT]
    slice: Method[CompliantT]
    split: Method[CompliantT]
    to_date: Method[CompliantT]
    to_datetime: Method[CompliantT]
    to_lowercase: Method[CompliantT]
    to_uppercase: Method[CompliantT]


class PolarsCatNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    _accessor: ClassVar[NativeAccessor] = "cat"
    get_categories: Method[CompliantT]


class PolarsListNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    _accessor: ClassVar[NativeAccessor] = "list"

    @abc.abstractmethod
    def len(self) -> CompliantT: ...

    get: Method[CompliantT]

    unique: Method[CompliantT]


class PolarsStructNamespace(PolarsAnyNamespace[CompliantT, NativeT_co]):
    _accessor: ClassVar[NativeAccessor] = "struct"
    field: Method[CompliantT]
@ -0,0 +1 @@
# ! Any change to this module will trigger the pyspark and pyspark-connect tests in CI
601
lib/python3.11/site-packages/narwhals/_spark_like/dataframe.py
Normal file
@ -0,0 +1,601 @@
from __future__ import annotations

from functools import reduce
from operator import and_
from typing import TYPE_CHECKING, Any

from narwhals._exceptions import issue_warning
from narwhals._namespace import is_native_spark_like
from narwhals._spark_like.utils import (
    catch_pyspark_connect_exception,
    catch_pyspark_sql_exception,
    evaluate_exprs,
    import_functions,
    import_native_dtypes,
    import_window,
    native_to_narwhals_dtype,
)
from narwhals._sql.dataframe import SQLLazyFrame
from narwhals._utils import (
    Implementation,
    ValidateBackendVersion,
    generate_temporary_column_name,
    not_implemented,
    parse_columns_to_drop,
    zip_strict,
)
from narwhals.exceptions import InvalidOperationError

if TYPE_CHECKING:
    from collections.abc import Iterable, Iterator, Mapping, Sequence
    from io import BytesIO
    from pathlib import Path
    from types import ModuleType

    import pyarrow as pa
    from sqlframe.base.column import Column
    from sqlframe.base.dataframe import BaseDataFrame
    from sqlframe.base.window import Window
    from typing_extensions import Self, TypeAlias, TypeIs

    from narwhals._compliant.typing import CompliantDataFrameAny
    from narwhals._spark_like.expr import SparkLikeExpr
    from narwhals._spark_like.group_by import SparkLikeLazyGroupBy
    from narwhals._spark_like.namespace import SparkLikeNamespace
    from narwhals._spark_like.utils import SparkSession
    from narwhals._typing import _EagerAllowedImpl
    from narwhals._utils import Version, _LimitedContext
    from narwhals.dataframe import LazyFrame
    from narwhals.dtypes import DType
    from narwhals.typing import JoinStrategy, LazyUniqueKeepStrategy

    SQLFrameDataFrame = BaseDataFrame[Any, Any, Any, Any, Any]

Incomplete: TypeAlias = Any  # pragma: no cover
"""Marker for working code that fails type checking."""


class SparkLikeLazyFrame(
    SQLLazyFrame["SparkLikeExpr", "SQLFrameDataFrame", "LazyFrame[SQLFrameDataFrame]"],
    ValidateBackendVersion,
):
    def __init__(
        self,
        native_dataframe: SQLFrameDataFrame,
        *,
        version: Version,
        implementation: Implementation,
        validate_backend_version: bool = False,
    ) -> None:
        self._native_frame: SQLFrameDataFrame = native_dataframe
        self._implementation = implementation
        self._version = version
        self._cached_schema: dict[str, DType] | None = None
        self._cached_columns: list[str] | None = None
        if validate_backend_version:  # pragma: no cover
            self._validate_backend_version()

    @property
    def _backend_version(self) -> tuple[int, ...]:  # pragma: no cover
        return self._implementation._backend_version()

    @property
    def _F(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        if TYPE_CHECKING:
            from sqlframe.base import functions

            return functions
        return import_functions(self._implementation)

    @property
    def _native_dtypes(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        if TYPE_CHECKING:
            from sqlframe.base import types

            return types
        return import_native_dtypes(self._implementation)

    @property
    def _Window(self) -> type[Window]:
        if TYPE_CHECKING:
            from sqlframe.base.window import Window

            return Window
        return import_window(self._implementation)

    @staticmethod
    def _is_native(obj: SQLFrameDataFrame | Any) -> TypeIs[SQLFrameDataFrame]:
        return is_native_spark_like(obj)

    @classmethod
    def from_native(cls, data: SQLFrameDataFrame, /, *, context: _LimitedContext) -> Self:
        return cls(data, version=context._version, implementation=context._implementation)

    def to_narwhals(self) -> LazyFrame[SQLFrameDataFrame]:
        return self._version.lazyframe(self, level="lazy")

    def __native_namespace__(self) -> ModuleType:  # pragma: no cover
        return self._implementation.to_native_namespace()

    def __narwhals_namespace__(self) -> SparkLikeNamespace:
        from narwhals._spark_like.namespace import SparkLikeNamespace

        return SparkLikeNamespace(
            version=self._version, implementation=self._implementation
        )

    def __narwhals_lazyframe__(self) -> Self:
        return self

    def _with_version(self, version: Version) -> Self:
        return self.__class__(
            self.native, version=version, implementation=self._implementation
        )

    def _with_native(self, df: SQLFrameDataFrame) -> Self:
        return self.__class__(
            df, version=self._version, implementation=self._implementation
        )

    def _to_arrow_schema(self) -> pa.Schema:  # pragma: no cover
        import pyarrow as pa  # ignore-banned-import

        from narwhals._arrow.utils import narwhals_to_native_dtype

        schema: list[tuple[str, pa.DataType]] = []
        nw_schema = self.collect_schema()
        native_schema = self.native.schema
        for key, value in nw_schema.items():
            try:
                native_dtype = narwhals_to_native_dtype(value, self._version)
            except Exception as exc:  # noqa: BLE001,PERF203
                native_spark_dtype = native_schema[key].dataType  # type: ignore[index]
                # If we can't convert the type, just set it to `pa.null`, and warn.
                # Avoid the warning if we're starting from PySpark's void type.
                # We can avoid the check when we introduce `nw.Null` dtype.
                null_type = self._native_dtypes.NullType  # pyright: ignore[reportAttributeAccessIssue]
                if not isinstance(native_spark_dtype, null_type):
                    issue_warning(
                        f"Could not convert dtype {native_spark_dtype} to PyArrow dtype, {exc!r}",
                        UserWarning,
                    )
                schema.append((key, pa.null()))
            else:
                schema.append((key, native_dtype))
        return pa.schema(schema)

    def _collect_to_arrow(self) -> pa.Table:
        if self._implementation.is_pyspark() and self._backend_version < (4,):
            import pyarrow as pa  # ignore-banned-import

            try:
                return pa.Table.from_batches(self.native._collect_as_arrow())
            except ValueError as exc:
                if "at least one RecordBatch" in str(exc):
                    # Empty dataframe

                    data: dict[str, list[Any]] = {k: [] for k in self.columns}
                    pa_schema = self._to_arrow_schema()
                    return pa.Table.from_pydict(data, schema=pa_schema)
                raise  # pragma: no cover
        elif self._implementation.is_pyspark_connect() and self._backend_version < (4,):
            import pyarrow as pa  # ignore-banned-import

            pa_schema = self._to_arrow_schema()
            return pa.Table.from_pandas(self.native.toPandas(), schema=pa_schema)
        else:
            return self.native.toArrow()

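    # Note (illustrative, not part of the module): the branches above pick the
    # cheapest available route to Arrow -- old PySpark exposes only the private
    # `_collect_as_arrow`, Spark Connect < 4 round-trips through pandas, and
    # newer backends provide `toArrow()` directly.
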
    def _iter_columns(self) -> Iterator[Column]:
        for col in self.columns:
            yield self._F.col(col)

    @property
    def columns(self) -> list[str]:
        if self._cached_columns is None:
            self._cached_columns = (
                list(self.schema)
                if self._cached_schema is not None
                else self.native.columns
            )
        return self._cached_columns

    def _collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if backend is Implementation.PANDAS:
            from narwhals._pandas_like.dataframe import PandasLikeDataFrame

            return PandasLikeDataFrame(
                self.native.toPandas(),
                implementation=Implementation.PANDAS,
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is None or backend is Implementation.PYARROW:
            from narwhals._arrow.dataframe import ArrowDataFrame

            return ArrowDataFrame(
                self._collect_to_arrow(),
                validate_backend_version=True,
                version=self._version,
                validate_column_names=True,
            )

        if backend is Implementation.POLARS:
            import polars as pl  # ignore-banned-import

            from narwhals._polars.dataframe import PolarsDataFrame

            return PolarsDataFrame(
                pl.from_arrow(self._collect_to_arrow()),  # type: ignore[arg-type]
                validate_backend_version=True,
                version=self._version,
            )

        msg = f"Unsupported `backend` value: {backend}"  # pragma: no cover
        raise ValueError(msg)  # pragma: no cover

    def collect(
        self, backend: _EagerAllowedImpl | None, **kwargs: Any
    ) -> CompliantDataFrameAny:
        if self._implementation.is_pyspark_connect():
            try:
                return self._collect(backend, **kwargs)
            except Exception as e:  # noqa: BLE001
                raise catch_pyspark_connect_exception(e) from None
        return self._collect(backend, **kwargs)

    def simple_select(self, *column_names: str) -> Self:
        return self._with_native(self.native.select(*column_names))

    def aggregate(self, *exprs: SparkLikeExpr) -> Self:
        new_columns = evaluate_exprs(self, *exprs)

        new_columns_list = [col.alias(col_name) for col_name, col in new_columns]
        if self._implementation.is_pyspark():
            try:
                return self._with_native(self.native.agg(*new_columns_list))
            except Exception as e:  # noqa: BLE001
                raise catch_pyspark_sql_exception(e, self) from None
        return self._with_native(self.native.agg(*new_columns_list))

    def select(self, *exprs: SparkLikeExpr) -> Self:
        new_columns = evaluate_exprs(self, *exprs)
        new_columns_list = [col.alias(col_name) for (col_name, col) in new_columns]
        if self._implementation.is_pyspark():  # pragma: no cover
            try:
                return self._with_native(self.native.select(*new_columns_list))
            except Exception as e:  # noqa: BLE001
                raise catch_pyspark_sql_exception(e, self) from None
        return self._with_native(self.native.select(*new_columns_list))

    def with_columns(self, *exprs: SparkLikeExpr) -> Self:
        new_columns = evaluate_exprs(self, *exprs)
        if self._implementation.is_pyspark():  # pragma: no cover
            try:
                return self._with_native(self.native.withColumns(dict(new_columns)))
            except Exception as e:  # noqa: BLE001
                raise catch_pyspark_sql_exception(e, self) from None

        return self._with_native(self.native.withColumns(dict(new_columns)))

    def filter(self, predicate: SparkLikeExpr) -> Self:
        # `[0]` is safe as the predicate's expression only returns a single column
        condition = predicate._call(self)[0]
        if self._implementation.is_pyspark():
            try:
                return self._with_native(self.native.where(condition))
            except Exception as e:  # noqa: BLE001
                raise catch_pyspark_sql_exception(e, self) from None
        return self._with_native(self.native.where(condition))

    @property
    def schema(self) -> dict[str, DType]:
        if self._cached_schema is None:
            self._cached_schema = {
                field.name: native_to_narwhals_dtype(
                    field.dataType,
                    self._version,
                    self._native_dtypes,
                    self.native.sparkSession,
                )
                for field in self.native.schema
            }
        return self._cached_schema

    def collect_schema(self) -> dict[str, DType]:
        return self.schema

    def drop(self, columns: Sequence[str], *, strict: bool) -> Self:
        columns_to_drop = parse_columns_to_drop(self, columns, strict=strict)
        return self._with_native(self.native.drop(*columns_to_drop))

    def head(self, n: int) -> Self:
        return self._with_native(self.native.limit(n))

    def group_by(
        self, keys: Sequence[str] | Sequence[SparkLikeExpr], *, drop_null_keys: bool
    ) -> SparkLikeLazyGroupBy:
        from narwhals._spark_like.group_by import SparkLikeLazyGroupBy

        return SparkLikeLazyGroupBy(self, keys, drop_null_keys=drop_null_keys)

    def sort(self, *by: str, descending: bool | Sequence[bool], nulls_last: bool) -> Self:
        if isinstance(descending, bool):
            descending = [descending] * len(by)

        if nulls_last:
            sort_funcs = (
                self._F.desc_nulls_last if d else self._F.asc_nulls_last
                for d in descending
            )
        else:
            sort_funcs = (
                self._F.desc_nulls_first if d else self._F.asc_nulls_first
                for d in descending
            )

        sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
        return self._with_native(self.native.sort(*sort_cols))

    def top_k(self, k: int, *, by: Iterable[str], reverse: bool | Sequence[bool]) -> Self:
        by = list(by)
        if isinstance(reverse, bool):
            reverse = [reverse] * len(by)
        sort_funcs = (
            self._F.desc_nulls_last if not d else self._F.asc_nulls_last for d in reverse
        )
        sort_cols = [sort_f(col) for col, sort_f in zip_strict(by, sort_funcs)]
        return self._with_native(self.native.sort(*sort_cols).limit(k))

    def drop_nulls(self, subset: Sequence[str] | None) -> Self:
        subset = list(subset) if subset else None
        return self._with_native(self.native.dropna(subset=subset))

    def rename(self, mapping: Mapping[str, str]) -> Self:
        rename_mapping = {
            colname: mapping.get(colname, colname) for colname in self.columns
        }
        return self._with_native(
            self.native.select(
                [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
            )
        )

    def unique(
        self, subset: Sequence[str] | None, *, keep: LazyUniqueKeepStrategy
    ) -> Self:
        if subset and (error := self._check_columns_exist(subset)):
            raise error
        subset = list(subset) if subset else None
        if keep == "none":
            tmp = generate_temporary_column_name(8, self.columns)
            window = self._Window.partitionBy(subset or self.columns)
            df = (
                self.native.withColumn(tmp, self._F.count("*").over(window))
                .filter(self._F.col(tmp) == self._F.lit(1))
                .drop(self._F.col(tmp))
            )
            return self._with_native(df)
        return self._with_native(self.native.dropDuplicates(subset=subset))

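    # Sketch of the `keep="none"` branch above (illustrative): counting rows per
    # key over a window and keeping only the groups of size one drops every
    # member of a duplicated key, which `dropDuplicates` alone cannot express,
    # e.g. keys [a, a, b] reduce to [b].
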
    def join(
        self,
        other: Self,
        how: JoinStrategy,
        left_on: Sequence[str] | None,
        right_on: Sequence[str] | None,
        suffix: str,
    ) -> Self:
        left_columns = self.columns
        right_columns = other.columns

        right_on_: list[str] = list(right_on) if right_on is not None else []
        left_on_: list[str] = list(left_on) if left_on is not None else []

        # Create a mapping for the columns on `other`:
        # `right_on` columns will be renamed as `left_on`, and the remaining
        # columns will either get the suffix or be left unchanged.
        right_cols_to_rename = (
            [c for c in right_columns if c not in right_on_]
            if how != "full"
            else right_columns
        )

        rename_mapping = {
            **dict(zip(right_on_, left_on_)),
            **{
                colname: f"{colname}{suffix}" if colname in left_columns else colname
                for colname in right_cols_to_rename
            },
        }
        other_native = other.native.select(
            [self._F.col(old).alias(new) for old, new in rename_mapping.items()]
        )

        # If `how` is in {"semi", "anti"}, the resulting columns are the same as
        # the left columns. Otherwise, we add the right columns with the new
        # mapping, while keeping the original order of `right_columns`.
        col_order = left_columns.copy()

        if how in {"inner", "left", "cross"}:
            col_order.extend(
                rename_mapping[colname]
                for colname in right_columns
                if colname not in right_on_
            )
        elif how == "full":
            col_order.extend(rename_mapping.values())

        right_on_remapped = [rename_mapping[c] for c in right_on_]
        on_ = (
            reduce(
                and_,
                (
                    getattr(self.native, left_key) == getattr(other_native, right_key)
                    for left_key, right_key in zip_strict(left_on_, right_on_remapped)
                ),
            )
            if how == "full"
            else None
            if how == "cross"
            else left_on_
        )
        how_native = "full_outer" if how == "full" else how
        return self._with_native(
            self.native.join(other_native, on=on_, how=how_native).select(col_order)
        )

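    # Illustrative example of the rename logic above: joining with
    # left_on=["id"], right_on=["key"], suffix="_right" renames the right-hand
    # `key` to `id`, and a right-hand column that clashes with a left-hand name,
    # say `value`, becomes `value_right`; `col_order` then keeps left columns
    # first, followed by the renamed right columns.
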
    def explode(self, columns: Sequence[str]) -> Self:
        dtypes = self._version.dtypes

        schema = self.collect_schema()
        for col_to_explode in columns:
            dtype = schema[col_to_explode]

            if dtype != dtypes.List:
                msg = (
                    f"`explode` operation not supported for dtype `{dtype}`, "
                    "expected List type"
                )
                raise InvalidOperationError(msg)

        column_names = self.columns

        if len(columns) != 1:
            msg = (
                "Exploding on multiple columns is not supported with SparkLike backend since "
                "we cannot guarantee that the exploded columns have matching element counts."
            )
            raise NotImplementedError(msg)

        if self._implementation.is_pyspark() or self._implementation.is_pyspark_connect():
            return self._with_native(
                self.native.select(
                    *[
                        self._F.col(col_name).alias(col_name)
                        if col_name != columns[0]
                        else self._F.explode_outer(col_name).alias(col_name)
                        for col_name in column_names
                    ]
                )
            )
        if self._implementation.is_sqlframe():
            # Not every sqlframe dialect supports the `explode_outer` function
            # (see https://github.com/eakmanrq/sqlframe/blob/3cb899c515b101ff4c197d84b34fae490d0ed257/sqlframe/base/functions.py#L2288-L2289),
            # therefore we simply explode the array column (which ignores nulls and
            # zero-sized arrays) and append rows for those specific conditions with
            # nulls, to match polars behavior.

            def null_condition(col_name: str) -> Column:
                return self._F.isnull(col_name) | (self._F.array_size(col_name) == 0)

            return self._with_native(
                self.native.select(
                    *[
                        self._F.col(col_name).alias(col_name)
                        if col_name != columns[0]
                        else self._F.explode(col_name).alias(col_name)
                        for col_name in column_names
                    ]
                ).union(
                    self.native.filter(null_condition(columns[0])).select(
                        *[
                            self._F.col(col_name).alias(col_name)
                            if col_name != columns[0]
                            else self._F.lit(None).alias(col_name)
                            for col_name in column_names
                        ]
                    )
                )
            )
        msg = "Unreachable code, please report an issue at https://github.com/narwhals-dev/narwhals/issues"  # pragma: no cover
        raise AssertionError(msg)

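    # Illustrative example of the sqlframe fallback above: exploding
    # [[1, 2], [], None] yields rows 1 and 2 from `explode`, and the `union`
    # re-appends one null row for the empty array and one for the null entry,
    # emulating `explode_outer`.
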
    def unpivot(
        self,
        on: Sequence[str] | None,
        index: Sequence[str] | None,
        variable_name: str,
        value_name: str,
    ) -> Self:
        if self._implementation.is_sqlframe():
            if variable_name == "":
                msg = "`variable_name` cannot be empty string for sqlframe backend."
                raise NotImplementedError(msg)

            if value_name == "":
                msg = "`value_name` cannot be empty string for sqlframe backend."
                raise NotImplementedError(msg)
        else:  # pragma: no cover
            pass

        ids = tuple(index) if index else ()
        values = (
            tuple(set(self.columns).difference(set(ids))) if on is None else tuple(on)
        )
        unpivoted_native_frame = self.native.unpivot(
            ids=ids,
            values=values,
            variableColumnName=variable_name,
            valueColumnName=value_name,
        )
        if index is None:
            unpivoted_native_frame = unpivoted_native_frame.drop(*ids)
        return self._with_native(unpivoted_native_frame)

    def with_row_index(self, name: str, order_by: Sequence[str]) -> Self:
        if order_by is None:
            msg = "Cannot use `with_row_index` without `order_by` for PySpark-like"
            raise TypeError(msg)
        row_index_expr = (
            self._F.row_number().over(
                self._Window.partitionBy(self._F.lit(1)).orderBy(*order_by)
            )
            - 1
        ).alias(name)
        return self._with_native(self.native.select(row_index_expr, *self.columns))

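    # Note (illustrative): Spark frames have no stable implicit row order, so
    # the index above is defined by `row_number()` over the caller-supplied
    # `order_by`, shifted by one so that it starts at 0.
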
    def sink_parquet(self, file: str | Path | BytesIO) -> None:
        self.native.write.parquet(file)

    @classmethod
    def _from_compliant_dataframe(
        cls,
        frame: CompliantDataFrameAny,
        /,
        *,
        session: SparkSession,
        implementation: Implementation,
        version: Version,
    ) -> SparkLikeLazyFrame:
        from importlib.util import find_spec

        impl = implementation
        is_spark_v4 = (not impl.is_sqlframe()) and impl._backend_version() >= (4, 0, 0)
        if is_spark_v4:  # pragma: no cover
            # pyspark.sql requires pyarrow to be installed from v4.0.0,
            # and since v4.0.0 the input to `createDataFrame` can be a PyArrow Table.
            data: Any = frame.to_arrow()
        elif find_spec("pandas"):
            data = frame.to_pandas()
        else:  # pragma: no cover
            data = tuple(frame.iter_rows(named=True, buffer_size=512))

        return cls(
            session.createDataFrame(data),
            version=version,
            implementation=implementation,
            validate_backend_version=True,
        )

    gather_every = not_implemented.deprecated(
        "`LazyFrame.gather_every` is deprecated and will be removed in a future version."
    )
    join_asof = not_implemented()
    tail = not_implemented.deprecated(
        "`LazyFrame.tail` is deprecated and will be removed in a future version."
    )
391
lib/python3.11/site-packages/narwhals/_spark_like/expr.py
Normal file
@ -0,0 +1,391 @@
from __future__ import annotations

import operator
from typing import TYPE_CHECKING, Any, Callable, ClassVar, Literal, cast

from narwhals._expression_parsing import ExprKind, ExprMetadata
from narwhals._spark_like.expr_dt import SparkLikeExprDateTimeNamespace
from narwhals._spark_like.expr_list import SparkLikeExprListNamespace
from narwhals._spark_like.expr_str import SparkLikeExprStringNamespace
from narwhals._spark_like.expr_struct import SparkLikeExprStructNamespace
from narwhals._spark_like.utils import (
    import_functions,
    import_native_dtypes,
    import_window,
    narwhals_to_native_dtype,
    true_divide,
)
from narwhals._sql.expr import SQLExpr
from narwhals._utils import Implementation, Version, not_implemented, zip_strict

if TYPE_CHECKING:
    from collections.abc import Iterator, Mapping, Sequence

    from sqlframe.base.column import Column
    from sqlframe.base.window import Window, WindowSpec
    from typing_extensions import Self, TypeAlias

    from narwhals._compliant import WindowInputs
    from narwhals._compliant.typing import (
        AliasNames,
        EvalNames,
        EvalSeries,
        WindowFunction,
    )
    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.namespace import SparkLikeNamespace
    from narwhals._utils import _LimitedContext
    from narwhals.typing import FillNullStrategy, IntoDType, NonNestedLiteral, RankMethod

    NativeRankMethod: TypeAlias = Literal["rank", "dense_rank", "row_number"]
    SparkWindowFunction = WindowFunction[SparkLikeLazyFrame, Column]
    SparkWindowInputs = WindowInputs[Column]


class SparkLikeExpr(SQLExpr["SparkLikeLazyFrame", "Column"]):
    def __init__(
        self,
        call: EvalSeries[SparkLikeLazyFrame, Column],
        window_function: SparkWindowFunction | None = None,
        *,
        evaluate_output_names: EvalNames[SparkLikeLazyFrame],
        alias_output_names: AliasNames | None,
        version: Version,
        implementation: Implementation,
    ) -> None:
        self._call = call
        self._evaluate_output_names = evaluate_output_names
        self._alias_output_names = alias_output_names
        self._version = version
        self._implementation = implementation
        self._metadata: ExprMetadata | None = None
        self._window_function: SparkWindowFunction | None = window_function

    _REMAP_RANK_METHOD: ClassVar[Mapping[RankMethod, NativeRankMethod]] = {
        "min": "rank",
        "max": "rank",
        "average": "rank",
        "dense": "dense_rank",
        "ordinal": "row_number",
    }

    def _count_star(self) -> Column:
        return self._F.count("*")

    def _window_expression(
        self,
        expr: Column,
        partition_by: Sequence[str | Column] = (),
        order_by: Sequence[str | Column] = (),
        rows_start: int | None = None,
        rows_end: int | None = None,
        *,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Column:
        window = self.partition_by(*partition_by)
        if order_by:
            window = window.orderBy(
                *self._sort(*order_by, descending=descending, nulls_last=nulls_last)
            )
        if rows_start is not None and rows_end is not None:
            window = window.rowsBetween(rows_start, rows_end)
        elif rows_end is not None:
            window = window.rowsBetween(self._Window.unboundedPreceding, rows_end)
        elif rows_start is not None:  # pragma: no cover
            window = window.rowsBetween(rows_start, self._Window.unboundedFollowing)
        return expr.over(window)

    def broadcast(self, kind: Literal[ExprKind.AGGREGATION, ExprKind.LITERAL]) -> Self:
        if kind is ExprKind.LITERAL:
            return self
        return self.over([self._F.lit(1)], [])

    @property
    def _F(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        if TYPE_CHECKING:
            from sqlframe.base import functions

            return functions
        return import_functions(self._implementation)

    @property
    def _native_dtypes(self):  # type: ignore[no-untyped-def] # noqa: ANN202
        if TYPE_CHECKING:
            from sqlframe.base import types

            return types
        return import_native_dtypes(self._implementation)

    @property
    def _Window(self) -> type[Window]:
        if TYPE_CHECKING:
            from sqlframe.base.window import Window

            return Window
        return import_window(self._implementation)

    def _sort(
        self,
        *cols: Column | str,
        descending: Sequence[bool] | None = None,
        nulls_last: Sequence[bool] | None = None,
    ) -> Iterator[Column]:
        F = self._F
        descending = descending or [False] * len(cols)
        nulls_last = nulls_last or [False] * len(cols)
        mapping = {
            (False, False): F.asc_nulls_first,
            (False, True): F.asc_nulls_last,
            (True, False): F.desc_nulls_first,
            (True, True): F.desc_nulls_last,
        }
        yield from (
            mapping[(_desc, _nulls_last)](col)
            for col, _desc, _nulls_last in zip_strict(cols, descending, nulls_last)
        )

    def partition_by(self, *cols: Column | str) -> WindowSpec:
        """Wraps `Window().partitionBy`, with default and `WindowInputs` handling."""
        return self._Window.partitionBy(*cols or [self._F.lit(1)])

    def __narwhals_namespace__(self) -> SparkLikeNamespace:  # pragma: no cover
        from narwhals._spark_like.namespace import SparkLikeNamespace

        return SparkLikeNamespace(
            version=self._version, implementation=self._implementation
        )

    @classmethod
    def _alias_native(cls, expr: Column, name: str) -> Column:
        return expr.alias(name)

    @classmethod
    def from_column_names(
        cls: type[Self],
        evaluate_column_names: EvalNames[SparkLikeLazyFrame],
        /,
        *,
        context: _LimitedContext,
    ) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            return [df._F.col(col_name) for col_name in evaluate_column_names(df)]

        return cls(
            func,
            evaluate_output_names=evaluate_column_names,
            alias_output_names=None,
            version=context._version,
            implementation=context._implementation,
        )

    @classmethod
    def from_column_indices(cls, *column_indices: int, context: _LimitedContext) -> Self:
        def func(df: SparkLikeLazyFrame) -> list[Column]:
            columns = df.columns
            return [df._F.col(columns[i]) for i in column_indices]

        return cls(
            func,
            evaluate_output_names=cls._eval_names_indices(column_indices),
            alias_output_names=None,
            version=context._version,
            implementation=context._implementation,
        )

    def __truediv__(self, other: SparkLikeExpr) -> Self:
        def _truediv(expr: Column, other: Column) -> Column:
            return true_divide(self._F, expr, other)

        return self._with_binary(_truediv, other)

    def __rtruediv__(self, other: SparkLikeExpr) -> Self:
        def _rtruediv(expr: Column, other: Column) -> Column:
            return true_divide(self._F, other, expr)

        return self._with_binary(_rtruediv, other).alias("literal")

    def __floordiv__(self, other: SparkLikeExpr) -> Self:
        def _floordiv(expr: Column, other: Column) -> Column:
            return self._F.floor(true_divide(self._F, expr, other))

        return self._with_binary(_floordiv, other)

    def __rfloordiv__(self, other: SparkLikeExpr) -> Self:
        def _rfloordiv(expr: Column, other: Column) -> Column:
            return self._F.floor(true_divide(self._F, other, expr))

        return self._with_binary(_rfloordiv, other).alias("literal")

    def __invert__(self) -> Self:
        invert = cast("Callable[..., Column]", operator.invert)
        return self._with_elementwise(invert)

    def cast(self, dtype: IntoDType) -> Self:
        def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
            spark_dtype = narwhals_to_native_dtype(
                dtype, self._version, self._native_dtypes, df.native.sparkSession
            )
            return [expr.cast(spark_dtype) for expr in self(df)]

        def window_f(
            df: SparkLikeLazyFrame, inputs: SparkWindowInputs
        ) -> Sequence[Column]:
            spark_dtype = narwhals_to_native_dtype(
                dtype, self._version, self._native_dtypes, df.native.sparkSession
            )
            return [expr.cast(spark_dtype) for expr in self.window_function(df, inputs)]

        return self.__class__(
            func,
            window_f,
            evaluate_output_names=self._evaluate_output_names,
            alias_output_names=self._alias_output_names,
            version=self._version,
            implementation=self._implementation,
        )

    def median(self) -> Self:
        def _median(expr: Column) -> Column:
            if self._implementation in {
                Implementation.PYSPARK,
                Implementation.PYSPARK_CONNECT,
            } and Implementation.PYSPARK._backend_version() < (3, 4):  # pragma: no cover
                # Use percentile_approx with default accuracy parameter (10000)
                return self._F.percentile_approx(expr.cast("double"), 0.5)

            return self._F.median(expr)

        return self._with_callable(_median)

    def null_count(self) -> Self:
        def _null_count(expr: Column) -> Column:
            return self._F.count_if(self._F.isnull(expr))

        return self._with_callable(_null_count)

    def std(self, ddof: int) -> Self:
        F = self._F
        if ddof == 0:
            return self._with_callable(F.stddev_pop)
        if ddof == 1:
            return self._with_callable(F.stddev_samp)

        def func(expr: Column) -> Column:
            n_rows = F.count(expr)
            return F.stddev_samp(expr) * F.sqrt((n_rows - 1) / (n_rows - ddof))

        return self._with_callable(func)

    def var(self, ddof: int) -> Self:
        F = self._F
        if ddof == 0:
            return self._with_callable(F.var_pop)
        if ddof == 1:
            return self._with_callable(F.var_samp)

        def func(expr: Column) -> Column:
            n_rows = F.count(expr)
            return F.var_samp(expr) * (n_rows - 1) / (n_rows - ddof)

        return self._with_callable(func)

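    # Derivation note (illustrative): for a general `ddof`, the statistic is
    # rescaled from the ddof=1 estimator, since
    # sum((x - mean)^2) / (n - ddof) == var_samp * (n - 1) / (n - ddof);
    # `std` applies the square root of the same correction factor.
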
    def is_finite(self) -> Self:
        def _is_finite(expr: Column) -> Column:
            # A value is finite if it is neither NaN nor infinite, while NULLs
            # should be preserved.
            is_finite_condition = (
                ~self._F.isnan(expr)
                & (expr != self._F.lit(float("inf")))
                & (expr != self._F.lit(float("-inf")))
            )
            return self._F.when(~self._F.isnull(expr), is_finite_condition).otherwise(
                None
            )

        return self._with_elementwise(_is_finite)

    def is_in(self, values: Sequence[Any]) -> Self:
        def _is_in(expr: Column) -> Column:
            return expr.isin(values) if values else self._F.lit(False)

        return self._with_elementwise(_is_in)

    def len(self) -> Self:
        def _len(_expr: Column) -> Column:
            # Use count(*) to count all rows, including nulls
            return self._F.count("*")

        return self._with_callable(_len)

    def skew(self) -> Self:
        return self._with_callable(self._F.skewness)

    def kurtosis(self) -> Self:
        return self._with_callable(self._F.kurtosis)

    def n_unique(self) -> Self:
        def _n_unique(expr: Column) -> Column:
            return self._F.count_distinct(expr) + self._F.max(
                self._F.isnull(expr).cast(self._native_dtypes.IntegerType())
            )

        return self._with_callable(_n_unique)

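    # Note (illustrative): `count_distinct` ignores NULLs, so `_n_unique` adds
    # `max(isnull(expr))`, which is 1 if any NULL is present and 0 otherwise,
    # counting NULL as its own distinct value.
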
    def is_nan(self) -> Self:
        def _is_nan(expr: Column) -> Column:
            return self._F.when(self._F.isnull(expr), None).otherwise(self._F.isnan(expr))

        return self._with_elementwise(_is_nan)

    def fill_null(
        self,
        value: Self | NonNestedLiteral,
        strategy: FillNullStrategy | None,
        limit: int | None,
    ) -> Self:
        if strategy is not None:

            def _fill_with_strategy(
                df: SparkLikeLazyFrame, inputs: SparkWindowInputs
            ) -> Sequence[Column]:
                fn = self._F.last_value if strategy == "forward" else self._F.first_value
                if strategy == "forward":
                    start = self._Window.unboundedPreceding if limit is None else -limit
                    end = self._Window.currentRow
                else:
                    start = self._Window.currentRow
                    end = self._Window.unboundedFollowing if limit is None else limit
                return [
                    fn(expr, ignoreNulls=True).over(
                        self.partition_by(*inputs.partition_by)
                        .orderBy(*self._sort(*inputs.order_by))
                        .rowsBetween(start, end)
                    )
                    for expr in self(df)
                ]

            return self._with_window_function(_fill_with_strategy)

        def _fill_constant(expr: Column, value: Column) -> Column:
            return self._F.ifnull(expr, value)

        return self._with_elementwise(_fill_constant, value=value)

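    # Sketch of the strategy fill above (illustrative): a forward fill takes
    # `last_value(..., ignoreNulls=True)` over a frame that ends at the current
    # row (bounded below by `-limit` when given); a backward fill mirrors it
    # with `first_value` over a frame that starts at the current row.
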
    @property
    def str(self) -> SparkLikeExprStringNamespace:
        return SparkLikeExprStringNamespace(self)

    @property
    def dt(self) -> SparkLikeExprDateTimeNamespace:
        return SparkLikeExprDateTimeNamespace(self)

    @property
    def list(self) -> SparkLikeExprListNamespace:
        return SparkLikeExprListNamespace(self)

    @property
    def struct(self) -> SparkLikeExprStructNamespace:
        return SparkLikeExprStructNamespace(self)

    quantile = not_implemented()
192
lib/python3.11/site-packages/narwhals/_spark_like/expr_dt.py
Normal file
@ -0,0 +1,192 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._constants import US_PER_SECOND
from narwhals._duration import Interval
from narwhals._spark_like.utils import (
    UNITS_DICT,
    fetch_session_time_zone,
    strptime_to_pyspark_format,
)
from narwhals._sql.expr_dt import SQLExprDateTimeNamesSpace
from narwhals._utils import not_implemented

if TYPE_CHECKING:
    from collections.abc import Sequence

    from sqlframe.base.column import Column

    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeExprDateTimeNamespace(SQLExprDateTimeNamesSpace["SparkLikeExpr"]):
    def _weekday(self, expr: Column) -> Column:
        # PySpark's dayofweek returns 1-7 for Sunday-Saturday
        return (self.compliant._F.dayofweek(expr) + 6) % 7

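    # Note (illustrative): with the shift above, Monday maps to 1 through
    # Saturday to 6, and Sunday to 0.
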
    def to_string(self, format: str) -> SparkLikeExpr:
        F = self.compliant._F

        def _to_string(expr: Column) -> Column:
            # Handle special formats
            if format == "%G-W%V":
                return self._format_iso_week(expr)
            if format == "%G-W%V-%u":
                return self._format_iso_week_with_day(expr)

            format_, suffix = self._format_microseconds(expr, format)

            # Convert the Python format to a PySpark format
            pyspark_fmt = strptime_to_pyspark_format(format_)

            result = F.date_format(expr, pyspark_fmt)
            if "T" in format_:
                # `strptime_to_pyspark_format` replaces "T" with " ", since pyspark
                # does not support the literal "T" in `date_format`.
                # If no other spaces are in the given format, then we can revert this
                # operation; otherwise we raise an exception.
                if " " not in format_:
                    result = F.replace(result, F.lit(" "), F.lit("T"))
                else:  # pragma: no cover
                    msg = (
                        "`dt.to_string` with a format that contains both spaces and "
                        "the literal 'T' is not supported for spark-like backends."
                    )
                    raise NotImplementedError(msg)

            return F.concat(result, *suffix)

        return self.compliant._with_elementwise(_to_string)

    def millisecond(self) -> SparkLikeExpr:
        def _millisecond(expr: Column) -> Column:
            return self.compliant._F.floor(
                (self.compliant._F.unix_micros(expr) % US_PER_SECOND) / 1000
            )

        return self.compliant._with_elementwise(_millisecond)

    def microsecond(self) -> SparkLikeExpr:
        def _microsecond(expr: Column) -> Column:
            return self.compliant._F.unix_micros(expr) % US_PER_SECOND

        return self.compliant._with_elementwise(_microsecond)

    def nanosecond(self) -> SparkLikeExpr:
        def _nanosecond(expr: Column) -> Column:
            return (self.compliant._F.unix_micros(expr) % US_PER_SECOND) * 1000

        return self.compliant._with_elementwise(_nanosecond)

    def weekday(self) -> SparkLikeExpr:
        return self.compliant._with_elementwise(self._weekday)

    def truncate(self, every: str) -> SparkLikeExpr:
        interval = Interval.parse(every)
        multiple, unit = interval.multiple, interval.unit
        if multiple != 1:
            msg = f"Only multiple 1 is currently supported for Spark-like.\nGot {multiple!s}."
            raise ValueError(msg)
        if unit == "ns":
            msg = "Truncating to nanoseconds is not yet supported for Spark-like."
            raise NotImplementedError(msg)
        format = UNITS_DICT[unit]

        def _truncate(expr: Column) -> Column:
            return self.compliant._F.date_trunc(format, expr)

        return self.compliant._with_elementwise(_truncate)

    def offset_by(self, by: str) -> SparkLikeExpr:
        interval = Interval.parse_no_constraints(by)
        multiple, unit = interval.multiple, interval.unit
        if unit == "ns":  # pragma: no cover
            msg = "Offsetting by nanoseconds is not yet supported for Spark-like."
            raise NotImplementedError(msg)

        F = self.compliant._F

        def _offset_by(expr: Column) -> Column:
            # https://github.com/eakmanrq/sqlframe/issues/441
            return F.timestamp_add(  # pyright: ignore[reportAttributeAccessIssue]
                UNITS_DICT[unit], F.lit(multiple), expr
            )

        return self.compliant._with_callable(_offset_by)

    def _no_op_time_zone(self, time_zone: str) -> SparkLikeExpr:  # pragma: no cover
        def func(df: SparkLikeLazyFrame) -> Sequence[Column]:
            native_series_list = self.compliant(df)
            conn_time_zone = fetch_session_time_zone(df.native.sparkSession)
            if conn_time_zone != time_zone:
                msg = (
                    "PySpark stores the time zone in the session, rather than in the "
                    f"data type, so changing the time zone to anything other than {conn_time_zone} "
                    "(the current session time zone) is not supported."
                )
                raise NotImplementedError(msg)
            return native_series_list

        return self.compliant.__class__(
            func,
            evaluate_output_names=self.compliant._evaluate_output_names,
            alias_output_names=self.compliant._alias_output_names,
            version=self.compliant._version,
            implementation=self.compliant._implementation,
        )

    def convert_time_zone(self, time_zone: str) -> SparkLikeExpr:  # pragma: no cover
        return self._no_op_time_zone(time_zone)

    def replace_time_zone(
        self, time_zone: str | None
    ) -> SparkLikeExpr:  # pragma: no cover
        if time_zone is None:
            return self.compliant._with_elementwise(
                lambda expr: expr.cast("timestamp_ntz")
            )
        return self._no_op_time_zone(time_zone)

    def _format_iso_week_with_day(self, expr: Column) -> Column:
        """Format datetime as ISO week string with day."""
        F = self.compliant._F

        year = F.date_format(expr, "yyyy")
        week = F.lpad(F.weekofyear(expr).cast("string"), 2, "0")
        day = self._weekday(expr)
        return F.concat(year, F.lit("-W"), week, F.lit("-"), day.cast("string"))

    def _format_iso_week(self, expr: Column) -> Column:
        """Format datetime as ISO week string."""
        F = self.compliant._F

        year = F.date_format(expr, "yyyy")
        week = F.lpad(F.weekofyear(expr).cast("string"), 2, "0")
        return F.concat(year, F.lit("-W"), week)

    def _format_microseconds(
        self, expr: Column, format: str
    ) -> tuple[str, tuple[Column, ...]]:
        """Format microseconds if present in the format, else it's a no-op."""
        F = self.compliant._F

        suffix: tuple[Column, ...]
        if format.endswith((".%f", "%.f")):
            import re

            micros = F.unix_micros(expr) % US_PER_SECOND
            micros_str = F.lpad(micros.cast("string"), 6, "0")
            suffix = (F.lit("."), micros_str)
            format_ = re.sub(r"(.%|%.)f$", "", format)
            return format_, suffix

        return format, ()

    timestamp = not_implemented()
    total_seconds = not_implemented()
    total_minutes = not_implemented()
    total_milliseconds = not_implemented()
    total_microseconds = not_implemented()
    total_nanoseconds = not_implemented()
@ -0,0 +1,35 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import ListNamespace

if TYPE_CHECKING:
    from sqlframe.base.column import Column

    from narwhals._spark_like.expr import SparkLikeExpr
    from narwhals.typing import NonNestedLiteral


class SparkLikeExprListNamespace(
    LazyExprNamespace["SparkLikeExpr"], ListNamespace["SparkLikeExpr"]
):
    def len(self) -> SparkLikeExpr:
        return self.compliant._with_elementwise(self.compliant._F.array_size)

    def unique(self) -> SparkLikeExpr:
        return self.compliant._with_elementwise(self.compliant._F.array_distinct)

    def contains(self, item: NonNestedLiteral) -> SparkLikeExpr:
        def func(expr: Column) -> Column:
            F = self.compliant._F
            return F.array_contains(expr, F.lit(item))

        return self.compliant._with_elementwise(func)

    def get(self, index: int) -> SparkLikeExpr:
        def _get(expr: Column) -> Column:
            return expr.getItem(index)

        return self.compliant._with_elementwise(_get)
@ -0,0 +1,36 @@
from __future__ import annotations

from functools import partial
from typing import TYPE_CHECKING

from narwhals._spark_like.utils import strptime_to_pyspark_format
from narwhals._sql.expr_str import SQLExprStringNamespace
from narwhals._utils import _is_naive_format, not_implemented

if TYPE_CHECKING:
    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeExprStringNamespace(SQLExprStringNamespace["SparkLikeExpr"]):
    def to_datetime(self, format: str | None) -> SparkLikeExpr:
        F = self.compliant._F
        if not format:
            function = F.to_timestamp
        elif _is_naive_format(format):
            function = partial(
                F.to_timestamp_ntz, format=F.lit(strptime_to_pyspark_format(format))
            )
        else:
            format = strptime_to_pyspark_format(format)
            function = partial(F.to_timestamp, format=format)
        return self.compliant._with_elementwise(
            lambda expr: function(F.replace(expr, F.lit("T"), F.lit(" ")))
        )

    def to_date(self, format: str | None) -> SparkLikeExpr:
        F = self.compliant._F
        return self.compliant._with_elementwise(
            lambda expr: F.to_date(expr, format=strptime_to_pyspark_format(format))
        )

    replace = not_implemented()
@ -0,0 +1,21 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._compliant import LazyExprNamespace
from narwhals._compliant.any_namespace import StructNamespace

if TYPE_CHECKING:
    from sqlframe.base.column import Column

    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeExprStructNamespace(
    LazyExprNamespace["SparkLikeExpr"], StructNamespace["SparkLikeExpr"]
):
    def field(self, name: str) -> SparkLikeExpr:
        def func(expr: Column) -> Column:
            return expr.getField(name)

        return self.compliant._with_elementwise(func).alias(name)
@ -0,0 +1,37 @@
from __future__ import annotations

from typing import TYPE_CHECKING

from narwhals._sql.group_by import SQLGroupBy

if TYPE_CHECKING:
    from collections.abc import Sequence

    from sqlframe.base.column import Column  # noqa: F401

    from narwhals._spark_like.dataframe import SparkLikeLazyFrame
    from narwhals._spark_like.expr import SparkLikeExpr


class SparkLikeLazyGroupBy(SQLGroupBy["SparkLikeLazyFrame", "SparkLikeExpr", "Column"]):
    def __init__(
        self,
        df: SparkLikeLazyFrame,
        keys: Sequence[SparkLikeExpr] | Sequence[str],
        /,
        *,
        drop_null_keys: bool,
    ) -> None:
        frame, self._keys, self._output_key_names = self._parse_keys(df, keys=keys)
        self._compliant_frame = frame.drop_nulls(self._keys) if drop_null_keys else frame

    def agg(self, *exprs: SparkLikeExpr) -> SparkLikeLazyFrame:
        result = (
            self.compliant.native.groupBy(*self._keys).agg(*agg_columns)
            if (agg_columns := list(self._evaluate_exprs(exprs)))
            else self.compliant.native.select(*self._keys).dropDuplicates()
        )

        return self.compliant._with_native(result).rename(
            dict(zip(self._keys, self._output_key_names))
        )
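
# Note (illustrative): when `agg` receives no expressions, the branch above
# reduces the group-by to its distinct key combinations via
# `select(*keys).dropDuplicates()`, then renames keys back to their output names.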