diff --git a/python/datafusion/expr.py b/python/datafusion/expr.py
index 695fe7c49..9df58f52a 100644
--- a/python/datafusion/expr.py
+++ b/python/datafusion/expr.py
@@ -562,8 +562,6 @@ def literal(value: Any) -> Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string_view())
-        if not isinstance(value, pa.Scalar):
-            value = pa.scalar(value)
         return Expr(expr_internal.RawExpr.literal(value))
 
     @staticmethod
@@ -576,7 +574,6 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string_view())
-        value = value if isinstance(value, pa.Scalar) else pa.scalar(value)
 
         return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))
 
diff --git a/python/tests/test_expr.py b/python/tests/test_expr.py
index 39e48f7c3..6ff3f4004 100644
--- a/python/tests/test_expr.py
+++ b/python/tests/test_expr.py
@@ -20,6 +20,8 @@
 from datetime import date, datetime, time, timezone
 from decimal import Decimal
 
+import arro3.core
+import nanoarrow
 import pyarrow as pa
 import pytest
 from datafusion import (
@@ -980,6 +982,34 @@ def test_literal_metadata(ctx):
     assert expected_field.metadata == actual_field.metadata
 
 
+def test_scalar_conversion() -> None:
+    expected_value = lit(1)
+    assert str(expected_value) == "Expr(Int64(1))"
+
+    # Test pyarrow scalars
+    assert expected_value == lit(pa.scalar(1))
+    assert expected_value == lit(pa.scalar(1, type=pa.int32()))
+
+    # Test nanoarrow
+    na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
+    assert expected_value == lit(na_scalar)
+
+    # Test arro3
+    arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
+    assert expected_value == lit(arro3_scalar)
+
+    expected_value = lit([1, 2, 3])
+    assert str(expected_value) == "Expr(List([1, 2, 3]))"
+
+    assert expected_value == lit(pa.scalar([1, 2, 3]))
+
+    na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
+    assert expected_value == lit(na_array)
+
+    arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
+    assert expected_value == lit(arro3_array)
+
+
 def test_ensure_expr():
     e = col("a")
     assert ensure_expr(e) is e.expr
 
diff --git a/src/config.rs b/src/config.rs
index 583dea7ef..38936e6c5 100644
--- a/src/config.rs
+++ b/src/config.rs
@@ -22,8 +22,8 @@ use parking_lot::RwLock;
 use pyo3::prelude::*;
 use pyo3::types::*;
 
+use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionResult;
-use crate::utils::py_obj_to_scalar_value;
 
 #[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
 #[derive(Clone)]
 pub(crate) struct PyConfig {
@@ -65,9 +65,9 @@ impl PyConfig {
     /// Set a configuration option
     pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> PyDataFusionResult<()> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;
         let mut options = self.config.write();
-        options.set(key, scalar_value.to_string().as_str())?;
+        options.set(key, scalar_value.0.to_string().as_str())?;
         Ok(())
     }
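Note: Config.set now routes through the standard PyScalarValue extraction instead of the bespoke py_obj_to_scalar_value helper, so plain Python values and PyArrow scalars are both accepted. A minimal usage sketch, assuming stock DataFusion option keys (the keys shown are illustrative):

    import pyarrow as pa
    from datafusion import Config

    config = Config()
    # A plain Python int is extracted to a ScalarValue on the Rust side
    config.set("datafusion.execution.batch_size", 4096)
    # A PyArrow scalar takes the FromPyArrow path of PyScalarValue
    config.set("datafusion.catalog.has_header", pa.scalar(True))
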
diff --git a/src/dataframe.rs b/src/dataframe.rs
index 94105d7ea..0b6eaf2a0 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -48,6 +48,7 @@ use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 use pyo3::PyErr;
 
+use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::expr::sort_expr::{to_sort_expressions, PySortExpr};
 use crate::expr::PyExpr;
@@ -55,9 +56,7 @@ use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
 use crate::sql::logical::PyLogicalPlan;
 use crate::table::{PyTable, TempViewTable};
-use crate::utils::{
-    is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future,
-};
+use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};
 
 /// File-level static CStr for the Arrow array stream capsule name.
 static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
@@ -1191,14 +1190,14 @@ impl PyDataFrame {
         columns: Option<Vec<String>>,
         py: Python,
     ) -> PyDataFusionResult<Self> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;
 
         let cols = match columns {
             Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
             None => Vec::new(), // Empty vector means fill null for all columns
         };
-        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
+        let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?;
 
         Ok(Self::new(df))
     }
 }
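Note: DataFrame.fill_null uses the same extraction, so the fill value may be a plain Python object, a PyArrow scalar, or any other Arrow value that PyScalarValue accepts. A minimal sketch (the column data is illustrative):

    import pyarrow as pa
    from datafusion import SessionContext

    ctx = SessionContext()
    df = ctx.from_pydict({"a": [1, None, 3]})
    # Plain Python value, applied only to column "a"
    df = df.fill_null(0, ["a"])
    # PyArrow scalar, applied to every column
    df = df.fill_null(pa.scalar(0))
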
diff --git a/src/pyarrow_util.rs b/src/pyarrow_util.rs
index 264cfd342..6221be1ad 100644
--- a/src/pyarrow_util.rs
+++ b/src/pyarrow_util.rs
@@ -17,8 +17,13 @@
 
 //! Conversions between PyArrow and DataFusion types
 
-use arrow::array::{Array, ArrayData};
+use std::sync::Arc;
+
+use arrow::array::{make_array, Array, ArrayData, ListArray};
+use arrow::buffer::OffsetBuffer;
+use arrow::datatypes::Field;
 use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use datafusion::common::exec_err;
 use datafusion::scalar::ScalarValue;
 use pyo3::types::{PyAnyMethods, PyList};
 use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
@@ -26,21 +31,113 @@ use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};
 use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionError;
 
+fn pyobj_extract_scalar_via_capsule(
+    value: &Bound<'_, PyAny>,
+    as_list_array: bool,
+) -> PyResult<PyScalarValue> {
+    let array_data = ArrayData::from_pyarrow_bound(value)?;
+    let array = make_array(array_data);
+
+    if as_list_array {
+        let field = Arc::new(Field::new_list_field(
+            array.data_type().clone(),
+            array.nulls().is_some(),
+        ));
+        let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
+        let list_array = ListArray::new(field, offsets, array, None);
+        Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))))
+    } else {
+        let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+        Ok(PyScalarValue(scalar))
+    }
+}
+
 impl FromPyArrow for PyScalarValue {
     fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
         let py = value.py();
-        let typ = value.getattr("type")?;
+        let pyarrow_mod = py.import("pyarrow");
 
-        // construct pyarrow array from the python value and pyarrow type
-        let factory = py.import("pyarrow")?.getattr("array")?;
-        let args = PyList::new(py, [value])?;
-        let array = factory.call1((args, typ))?;
+        // Is it a PyArrow object?
+        if let Ok(pa) = pyarrow_mod.as_ref() {
+            let scalar_type = pa.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                let typ = value.getattr("type")?;
 
-        // convert the pyarrow array to rust array using C data interface
-        let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
-        let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+                // construct pyarrow array from the python value and pyarrow type
+                let factory = py.import("pyarrow")?.getattr("array")?;
+                let args = PyList::new(py, [value])?;
+                let array = factory.call1((args, typ))?;
 
-        Ok(PyScalarValue(scalar))
+                return pyobj_extract_scalar_via_capsule(&array, false);
+            }
+
+            let array_type = pa.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Is it a nanoarrow scalar or array?
+        if let Ok(na) = py.import("nanoarrow") {
+            let type_name = value.get_type().repr()?;
+            if type_name.contains("nanoarrow")? && type_name.contains("Scalar")? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            let array_type = na.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Is it an arro3 scalar or array?
+        if let Ok(arro3) = py.import("arro3").and_then(|arro3| arro3.getattr("core")) {
+            let scalar_type = arro3.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            let array_type = arro3.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Does it have a PyCapsule interface but isn't from one of our known
+        // libraries? If so, make a best guess: check the type name, and if
+        // that fails return a single value when the length is 1 and a List
+        // value otherwise.
+        if value.hasattr("__arrow_c_array__")? {
+            let type_name = value.get_type().repr()?;
+            if type_name.contains("Scalar")? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            if type_name.contains("Array")? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+
+            let array_data = ArrayData::from_pyarrow_bound(value)?;
+            let array = make_array(array_data);
+            if array.len() == 1 {
+                let scalar =
+                    ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+                return Ok(PyScalarValue(scalar));
+            } else {
+                let field = Arc::new(Field::new_list_field(
+                    array.data_type().clone(),
+                    array.nulls().is_some(),
+                ));
+                let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
+                let list_array = ListArray::new(field, offsets, array, None);
+                return Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))));
+            }
+        }
+
+        // Last attempt - try to create a PyArrow scalar from a plain Python object
+        if let Ok(pa) = pyarrow_mod.as_ref() {
+            let scalar = pa.call_method1("scalar", (value,))?;
+
+            PyScalarValue::from_pyarrow_bound(&scalar)
+        } else {
+            exec_err!("Unable to import scalar value").map_err(PyDataFusionError::from)?
+        }
     }
 }
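Note: the final __arrow_c_array__ branch means extraction is not limited to pyarrow, nanoarrow, and arro3; any object exposing the Arrow PyCapsule interface can become a literal. A minimal sketch with a hypothetical wrapper class (CapsuleHolder is illustrative, and deliberately avoids "Scalar" or "Array" in its name so the length-based best guess applies):

    import pyarrow as pa
    from datafusion import lit

    class CapsuleHolder:
        """Delegates the Arrow PyCapsule protocol to a wrapped pyarrow.Array."""

        def __init__(self, array: pa.Array) -> None:
            self._array = array

        def __arrow_c_array__(self, requested_schema=None):
            return self._array.__arrow_c_array__(requested_schema)

    lit(CapsuleHolder(pa.array([42])))       # length 1 -> scalar literal
    lit(CapsuleHolder(pa.array([1, 2, 3])))  # length > 1 -> List literal
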
diff --git a/src/udaf.rs b/src/udaf.rs
index 883170adf..24ef1f6d3 100644
--- a/src/udaf.rs
+++ b/src/udaf.rs
@@ -32,7 +32,7 @@ use pyo3::types::{PyCapsule, PyTuple};
 use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
 use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};
+use crate::utils::{parse_volatility, validate_pycapsule};
 
 #[derive(Debug)]
 struct RustAccumulator {
@@ -52,10 +52,7 @@ impl Accumulator for RustAccumulator {
         let mut scalars = Vec::new();
         for item in values.try_iter()? {
             let item: Bound<'_, PyAny> = item?;
-            let scalar = match item.extract::<PyScalarValue>() {
-                Ok(py_scalar) => py_scalar.0,
-                Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
-            };
+            let scalar = item.extract::<PyScalarValue>()?.0;
             scalars.push(scalar);
         }
         Ok(scalars)
@@ -66,10 +63,7 @@ impl Accumulator for RustAccumulator {
     fn evaluate(&mut self) -> Result<ScalarValue> {
         Python::attach(|py| -> PyResult<ScalarValue> {
             let value = self.accum.bind(py).call_method0("evaluate")?;
-            match value.extract::<PyScalarValue>() {
-                Ok(py_scalar) => Ok(py_scalar.0),
-                Err(_) => py_obj_to_scalar_value(py, value.unbind()),
-            }
+            value.extract::<PyScalarValue>().map(|v| v.0)
         })
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
diff --git a/src/udwf.rs b/src/udwf.rs
index 86310609c..6b4f07c36 100644
--- a/src/udwf.rs
+++ b/src/udwf.rs
@@ -94,7 +94,6 @@ impl PartitionEvaluator for RustPartitionEvaluator {
     }
 
     fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
-        println!("evaluate all called with number of values {}", values.len());
         Python::attach(|py| {
             let py_values = PyList::new(
                 py,
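Note: with the py_obj_to_scalar_value fallback removed from udaf.rs, accumulator state items and evaluate results must extract cleanly as PyScalarValue, which still covers both PyArrow scalars and plain Python values. A minimal accumulator sketch in the style of the documented datafusion.Accumulator interface (the class is illustrative):

    import pyarrow as pa
    import pyarrow.compute as pc
    from datafusion import Accumulator

    class Summation(Accumulator):
        def __init__(self) -> None:
            self._sum = pa.scalar(0.0)

        def update(self, values: pa.Array) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(values).as_py())

        def merge(self, states: list[pa.Array]) -> None:
            self._sum = pa.scalar(self._sum.as_py() + pc.sum(states[0]).as_py())

        def state(self) -> list[pa.Scalar]:
            # Each item is extracted via PyScalarValue in RustAccumulator::state
            return [self._sum]

        def evaluate(self) -> pa.Scalar:
            # Extracted with value.extract::<PyScalarValue>() in evaluate()
            return self._sum
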
diff --git a/src/utils.rs b/src/utils.rs
index 3b97ffb88..4b45f29bf 100644
--- a/src/utils.rs
+++ b/src/utils.rs
@@ -19,11 +19,6 @@ use std::future::Future;
 use std::sync::{Arc, OnceLock};
 use std::time::Duration;
 
-use datafusion::arrow::array::{make_array, ArrayData, ListArray};
-use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
-use datafusion::arrow::datatypes::Field;
-use datafusion::arrow::pyarrow::FromPyArrow;
-use datafusion::common::ScalarValue;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionContext;
 use datafusion::logical_expr::Volatility;
@@ -37,7 +32,6 @@ use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tokio::time::sleep;
 
-use crate::common::data_type::PyScalarValue;
 use crate::context::PySessionContext;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::TokioRuntime;
@@ -203,57 +197,6 @@ pub(crate) fn table_provider_from_pycapsule<'py>(
     }
 }
 
-pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
-    // convert Python object to PyScalarValue to ScalarValue
-
-    let pa = py.import("pyarrow")?;
-    let scalar_attr = pa.getattr("Scalar")?;
-    let scalar_type = scalar_attr.downcast::<PyType>()?;
-    let array_attr = pa.getattr("Array")?;
-    let array_type = array_attr.downcast::<PyType>()?;
-    let chunked_array_attr = pa.getattr("ChunkedArray")?;
-    let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
-
-    let obj_ref = obj.bind(py);
-
-    if obj_ref.is_instance(scalar_type)? {
-        let py_scalar = PyScalarValue::extract_bound(obj_ref)
-            .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
-        return Ok(py_scalar.into());
-    }
-
-    if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
-        let array_obj = if obj_ref.is_instance(chunked_array_type)? {
-            obj_ref.call_method0("combine_chunks")?.unbind()
-        } else {
-            obj_ref.clone().unbind()
-        };
-        let array_bound = array_obj.bind(py);
-        let array_data = ArrayData::from_pyarrow_bound(array_bound)
-            .map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;
-        let array = make_array(array_data);
-        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32]));
-        let list_array = Arc::new(ListArray::new(
-            Arc::new(Field::new_list_field(array.data_type().clone(), true)),
-            offsets,
-            array,
-            None,
-        ));
-
-        return Ok(ScalarValue::List(list_array));
-    }
-
-    // Convert Python object to PyArrow scalar
-    let scalar = pa.call_method1("scalar", (obj,))?;
-
-    // Convert PyArrow scalar to PyScalarValue
-    let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
-        .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
-
-    // Convert PyScalarValue to ScalarValue
-    Ok(py_scalar.into())
-}
-
 pub(crate) fn extract_logical_extension_codec(
     py: Python,
     obj: Option<Py<PyAny>>,