11 changes: 11 additions & 0 deletions docs/source/user-guide/common-operations/udf-and-udfa.rst
@@ -160,6 +160,17 @@ also see how the inputs to ``update`` and ``merge`` differ.

df.aggregate([], [my_udaf(col("a"), col("b")).alias("col_diff")])

FAQ
^^^

**How do I return a list from a UDAF?**
Return a list-valued scalar and declare list types for both the return type and
the state fields. Returning a ``pyarrow.Array`` from ``evaluate`` is not
supported unless you convert it to a list scalar. For example, ``evaluate`` can
return ``pa.scalar([...], type=pa.list_(pa.timestamp("ms")))`` for a UDAF
registered with ``return_type=pa.list_(pa.timestamp("ms"))`` and
``state_type=[pa.list_(pa.timestamp("ms"))]``.
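
For instance, a minimal accumulator that collects its inputs into a single
list-valued result might look like the sketch below; the class name
``CollectValues`` and the timestamp type are illustrative, mirroring the test
suite::

    import pyarrow as pa

    from datafusion import Accumulator, udaf

    class CollectValues(Accumulator):
        """Collect every input value and return them as one list."""

        def __init__(self) -> None:
            self._values: list = []

        def update(self, values: pa.Array) -> None:
            self._values.extend(values.to_pylist())

        def merge(self, states: list[pa.Array]) -> None:
            # Each state is a one-row list array produced by ``state``.
            for state in states[0].to_pylist():
                if state is not None:
                    self._values.extend(state)

        def state(self) -> list[pa.Scalar]:
            return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ms")))]

        def evaluate(self) -> pa.Scalar:
            return pa.scalar(self._values, type=pa.list_(pa.timestamp("ms")))

    collect = udaf(
        CollectValues,                    # accumulator class
        pa.timestamp("ms"),               # input type
        pa.list_(pa.timestamp("ms")),     # return type
        [pa.list_(pa.timestamp("ms"))],   # state types
        volatility="immutable",
    )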

Window Functions
----------------

16 changes: 15 additions & 1 deletion python/datafusion/user_defined.py
@@ -310,7 +310,21 @@ def merge(self, states: list[pa.Array]) -> None:

@abstractmethod
def evaluate(self) -> pa.Scalar:
"""Return the resultant value."""
"""Return the resultant value.

If you need to return a list, wrap it in a scalar with the correct
list type, for example::

import pyarrow as pa

return pa.scalar(
[pa.scalar("2024-01-01T00:00:00Z")],
type=pa.list_(pa.timestamp("ms")),
)

Returning a ``pyarrow.Array`` from ``evaluate`` is not supported unless
you explicitly convert it to a list-valued scalar.
"""


class AggregateUDFExportable(Protocol):
48 changes: 48 additions & 0 deletions python/tests/test_udaf.py
@@ -17,6 +17,8 @@

from __future__ import annotations

from datetime import datetime, timezone

import pyarrow as pa
import pyarrow.compute as pc
import pytest
@@ -58,6 +60,25 @@ def state(self) -> list[pa.Scalar]:
return [self._sum]


class CollectTimestamps(Accumulator):
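    """Collect every input timestamp and emit them as one list-valued scalar."""
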
def __init__(self):
self._values: list[datetime] = []

def state(self) -> list[pa.Scalar]:
return [pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))]

def update(self, values: pa.Array) -> None:
self._values.extend(values.to_pylist())

def merge(self, states: list[pa.Array]) -> None:
for state in states[0].to_pylist():
if state is not None:
self._values.extend(state)

def evaluate(self) -> pa.Scalar:
return pa.scalar(self._values, type=pa.list_(pa.timestamp("ns")))


@pytest.fixture
def df(ctx):
# create a RecordBatch and a new DataFrame from it
@@ -217,3 +238,30 @@ def test_register_udaf(ctx, df) -> None:
df_result = ctx.sql("select summarize(b) from test_table")

assert df_result.collect()[0][0][0].as_py() == 14.0


def test_udaf_list_timestamp_return(ctx) -> None:
timestamps = [
datetime(2024, 1, 1, tzinfo=timezone.utc),
datetime(2024, 1, 2, tzinfo=timezone.utc),
]
batch = pa.RecordBatch.from_arrays(
[pa.array(timestamps, type=pa.timestamp("ns"))],
names=["ts"],
)
df = ctx.create_dataframe([[batch]], name="timestamp_table")

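    # udaf(accumulator_class, input_type, return_type, state_types, volatility);
    # the return type and the single state field are both list<timestamp[ns]>.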
collect = udaf(
CollectTimestamps,
pa.timestamp("ns"),
pa.list_(pa.timestamp("ns")),
[pa.list_(pa.timestamp("ns"))],
volatility="immutable",
)

result = df.aggregate([], [collect(column("ts"))]).collect()[0]

assert result.column(0) == pa.array(
[timestamps],
type=pa.list_(pa.timestamp("ns")),
)
42 changes: 24 additions & 18 deletions src/udaf.rs
@@ -17,7 +17,7 @@

use std::sync::Arc;

-use datafusion::arrow::array::{Array, ArrayRef};
+use datafusion::arrow::array::ArrayRef;
use datafusion::arrow::datatypes::DataType;
use datafusion::arrow::pyarrow::{PyArrowType, ToPyArrow};
use datafusion::common::ScalarValue;
@@ -32,7 +32,7 @@ use pyo3::types::{PyCapsule, PyTuple};
use crate::common::data_type::PyScalarValue;
use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, validate_pycapsule};
+use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};

#[derive(Debug)]
struct RustAccumulator {
@@ -47,24 +47,30 @@ impl RustAccumulator {

impl Accumulator for RustAccumulator {
fn state(&mut self) -> Result<Vec<ScalarValue>> {
-        Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("state")?
-                .extract::<Vec<PyScalarValue>>()
+        Python::attach(|py| -> PyResult<Vec<ScalarValue>> {
+            let values = self.accum.bind(py).call_method0("state")?;
+            let mut scalars = Vec::new();
+            for item in values.try_iter()? {
+                let item: Bound<'_, PyAny> = item?;
+                let scalar = match item.extract::<PyScalarValue>() {
+                    Ok(py_scalar) => py_scalar.0,
+                    Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
+                };
+                scalars.push(scalar);
+            }
+            Ok(scalars)
        })
-        .map(|v| v.into_iter().map(|x| x.0).collect())
.map_err(|e| DataFusionError::Execution(format!("{e}")))
}

fn evaluate(&mut self) -> Result<ScalarValue> {
-        Python::attach(|py| {
-            self.accum
-                .bind(py)
-                .call_method0("evaluate")?
-                .extract::<PyScalarValue>()
+        Python::attach(|py| -> PyResult<ScalarValue> {
+            let value = self.accum.bind(py).call_method0("evaluate")?;
+            match value.extract::<PyScalarValue>() {
+                Ok(py_scalar) => Ok(py_scalar.0),
+                Err(_) => py_obj_to_scalar_value(py, value.unbind()),
+            }
        })
-        .map(|v| v.0)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
}

@@ -73,7 +79,7 @@ impl Accumulator for RustAccumulator {
// 1. cast args to Pyarrow array
let py_args = values
.iter()
-                .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+                .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;

@@ -94,7 +100,7 @@
.iter()
.map(|state| {
state
-                        .into_data()
+                        .to_data()
.to_pyarrow(py)
.map_err(|e| DataFusionError::Execution(format!("{e}")))
})
@@ -119,7 +125,7 @@ impl Accumulator for RustAccumulator {
// 1. cast args to Pyarrow array
let py_args = values
.iter()
-                .map(|arg| arg.into_data().to_pyarrow(py).unwrap())
+                .map(|arg| arg.to_data().to_pyarrow(py).unwrap())
.collect::<Vec<_>>();
let py_args = PyTuple::new(py, py_args).map_err(to_datafusion_err)?;

@@ -144,7 +150,7 @@ }
}

pub fn to_rust_accumulator(accum: Py<PyAny>) -> AccumulatorFactoryFunction {
-    Arc::new(move |_| -> Result<Box<dyn Accumulator>> {
+    Arc::new(move |_args| -> Result<Box<dyn Accumulator>> {
let accum = Python::attach(|py| {
accum
.call0(py)
39 changes: 39 additions & 0 deletions src/utils.rs
@@ -19,6 +19,10 @@ use std::future::Future;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

use datafusion::arrow::array::{make_array, ArrayData, ListArray};
use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
use datafusion::arrow::datatypes::Field;
use datafusion::arrow::pyarrow::FromPyArrow;
use datafusion::common::ScalarValue;
use datafusion::datasource::TableProvider;
use datafusion::execution::context::SessionContext;
@@ -203,6 +207,41 @@ pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
// convert Python object to PyScalarValue to ScalarValue

let pa = py.import("pyarrow")?;
let scalar_attr = pa.getattr("Scalar")?;
let scalar_type = scalar_attr.downcast::<PyType>()?;
let array_attr = pa.getattr("Array")?;
let array_type = array_attr.downcast::<PyType>()?;
let chunked_array_attr = pa.getattr("ChunkedArray")?;
let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;

let obj_ref = obj.bind(py);

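    // A pyarrow Scalar converts directly into a DataFusion ScalarValue.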
if obj_ref.is_instance(scalar_type)? {
let py_scalar = PyScalarValue::extract_bound(obj_ref)
.map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
return Ok(py_scalar.into());
}

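    // An Array or ChunkedArray has no single-scalar form, so wrap it in a
    // one-row ListArray and return it as ScalarValue::List.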
if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
let array_obj = if obj_ref.is_instance(chunked_array_type)? {
obj_ref.call_method0("combine_chunks")?.unbind()
} else {
obj_ref.clone().unbind()
};
let array_bound = array_obj.bind(py);
let array_data = ArrayData::from_pyarrow_bound(array_bound)
.map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;
let array = make_array(array_data);
let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32]));
let list_array = Arc::new(ListArray::new(
Arc::new(Field::new_list_field(array.data_type().clone(), true)),
offsets,
array,
None,
));

return Ok(ScalarValue::List(list_array));
}

// Convert Python object to PyArrow scalar
let scalar = pa.call_method1("scalar", (obj,))?;