3 changes: 0 additions & 3 deletions python/datafusion/expr.py

@@ -562,8 +562,6 @@ def literal(value: Any) -> Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string_view())
-        if not isinstance(value, pa.Scalar):
-            value = pa.scalar(value)
         return Expr(expr_internal.RawExpr.literal(value))

     @staticmethod
@@ -576,7 +574,6 @@ def literal_with_metadata(value: Any, metadata: dict[str, str]) -> Expr:
         """
         if isinstance(value, str):
             value = pa.scalar(value, type=pa.string_view())
-        value = value if isinstance(value, pa.Scalar) else pa.scalar(value)

         return Expr(expr_internal.RawExpr.literal_with_metadata(value, metadata))

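With the Python-side fallback removed, `lit` no longer coerces values through `pa.scalar` itself; the conversion to a DataFusion `ScalarValue` now happens in the Rust layer, which also accepts non-PyArrow inputs. A minimal sketch of the calls this enables, mirroring the new tests below (assumes pyarrow and nanoarrow are installed):

```python
import nanoarrow
import pyarrow as pa
from datafusion import lit

lit(1)                                              # plain Python value
lit(pa.scalar(1, type=pa.int32()))                  # pyarrow scalar
lit(nanoarrow.Array([1], nanoarrow.int32())[0])     # nanoarrow scalar
lit(nanoarrow.Array([1, 2, 3], nanoarrow.int32()))  # array becomes a List literal
```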
30 changes: 30 additions & 0 deletions python/tests/test_expr.py

@@ -20,6 +20,8 @@
 from datetime import date, datetime, time, timezone
 from decimal import Decimal

+import arro3.core
+import nanoarrow
 import pyarrow as pa
 import pytest
 from datafusion import (
@@ -980,6 +982,34 @@ def test_literal_metadata(ctx):
     assert expected_field.metadata == actual_field.metadata


+def test_scalar_conversion() -> None:
+    expected_value = lit(1)
+    assert str(expected_value) == "Expr(Int64(1))"
+
+    # Test pyarrow imports
+    assert expected_value == lit(pa.scalar(1))
+    assert expected_value == lit(pa.scalar(1, type=pa.int32()))
+
+    # Test nanoarrow
+    na_scalar = nanoarrow.Array([1], nanoarrow.int32())[0]
+    assert expected_value == lit(na_scalar)
+
+    # Test pyo3
+    arro3_scalar = arro3.core.Scalar(1, type=arro3.core.DataType.int32())
+    assert expected_value == lit(arro3_scalar)
+
+    expected_value = lit([1, 2, 3])
+    assert str(expected_value) == "Expr(List([1, 2, 3]))"
+
+    assert expected_value == lit(pa.scalar([1, 2, 3]))
+
+    na_array = nanoarrow.Array([1, 2, 3], nanoarrow.int32())
+    assert expected_value == lit(na_array)
+
+    arro3_array = arro3.core.Array([1, 2, 3], type=arro3.core.DataType.int32())
+    assert expected_value == lit(arro3_array)
+
+
 def test_ensure_expr():
     e = col("a")
     assert ensure_expr(e) is e.expr
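These tests double as usage examples for the supported libraries. Assuming a development install with the optional `nanoarrow` and `arro3-core` test dependencies, the new test can be run in isolation:

```python
# Hypothetical invocation via pytest's Python API; the equivalent CLI call is
# `pytest python/tests/test_expr.py -k test_scalar_conversion`.
import pytest

pytest.main(["python/tests/test_expr.py", "-k", "test_scalar_conversion"])
```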
6 changes: 3 additions & 3 deletions src/config.rs

@@ -22,8 +22,8 @@ use parking_lot::RwLock;
 use pyo3::prelude::*;
 use pyo3::types::*;

+use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionResult;
-use crate::utils::py_obj_to_scalar_value;
 #[pyclass(name = "Config", module = "datafusion", subclass, frozen)]
 #[derive(Clone)]
 pub(crate) struct PyConfig {
@@ -65,9 +65,9 @@ impl PyConfig {

     /// Set a configuration option
     pub fn set(&self, key: &str, value: Py<PyAny>, py: Python) -> PyDataFusionResult<()> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;
         let mut options = self.config.write();
-        options.set(key, scalar_value.to_string().as_str())?;
+        options.set(key, scalar_value.0.to_string().as_str())?;
         Ok(())
     }

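On the Python side, `Config.set` now accepts any value whose extraction yields a `PyScalarValue`; the Rust code renders it with `to_string()` before handing it to DataFusion's options. A hedged usage sketch (the option key is illustrative):

```python
import pyarrow as pa
from datafusion import Config

config = Config()
# Plain Python values and pyarrow scalars both extract to the same ScalarValue.
config.set("datafusion.execution.batch_size", 4096)
config.set("datafusion.execution.batch_size", pa.scalar(4096))
```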
9 changes: 4 additions & 5 deletions src/dataframe.rs

@@ -48,16 +48,15 @@ use pyo3::pybacked::PyBackedStr;
 use pyo3::types::{PyCapsule, PyList, PyTuple, PyTupleMethods};
 use pyo3::PyErr;

+use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::expr::sort_expr::{to_sort_expressions, PySortExpr};
 use crate::expr::PyExpr;
 use crate::physical_plan::PyExecutionPlan;
 use crate::record_batch::{poll_next_batch, PyRecordBatchStream};
 use crate::sql::logical::PyLogicalPlan;
 use crate::table::{PyTable, TempViewTable};
-use crate::utils::{
-    is_ipython_env, py_obj_to_scalar_value, spawn_future, validate_pycapsule, wait_for_future,
-};
+use crate::utils::{is_ipython_env, spawn_future, validate_pycapsule, wait_for_future};

 /// File-level static CStr for the Arrow array stream capsule name.
 static ARROW_ARRAY_STREAM_NAME: &CStr = cstr!("arrow_array_stream");
@@ -1191,14 +1190,14 @@ impl PyDataFrame {
         columns: Option<Vec<PyBackedStr>>,
         py: Python,
     ) -> PyDataFusionResult<Self> {
-        let scalar_value = py_obj_to_scalar_value(py, value)?;
+        let scalar_value: PyScalarValue = value.extract(py)?;

         let cols = match columns {
             Some(col_names) => col_names.iter().map(|c| c.to_string()).collect(),
             None => Vec::new(), // Empty vector means fill null for all columns
         };

-        let df = self.df.as_ref().clone().fill_null(scalar_value, cols)?;
+        let df = self.df.as_ref().clone().fill_null(scalar_value.0, cols)?;
         Ok(Self::new(df))
     }
 }
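`fill_null` follows the same pattern: the fill value is extracted as a `PyScalarValue` and unwrapped (`.0`) before being passed to DataFusion, with an empty column list meaning all columns. A sketch under the assumption of a standard `SessionContext` setup:

```python
import pyarrow as pa
from datafusion import SessionContext

ctx = SessionContext()
df = ctx.from_pydict({"a": [1, None, 3]})
# The fill value may be a plain Python object or any Arrow-compatible scalar;
# omitting the column list fills nulls in every column.
df = df.fill_null(pa.scalar(0))
```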
117 changes: 107 additions & 10 deletions src/pyarrow_util.rs

@@ -17,30 +17,127 @@

 //! Conversions between PyArrow and DataFusion types

-use arrow::array::{Array, ArrayData};
+use std::sync::Arc;
+
+use arrow::array::{make_array, Array, ArrayData, ListArray};
+use arrow::buffer::OffsetBuffer;
+use arrow::datatypes::Field;
 use arrow::pyarrow::{FromPyArrow, ToPyArrow};
+use datafusion::common::exec_err;
 use datafusion::scalar::ScalarValue;
 use pyo3::types::{PyAnyMethods, PyList};
 use pyo3::{Bound, FromPyObject, PyAny, PyResult, Python};

 use crate::common::data_type::PyScalarValue;
 use crate::errors::PyDataFusionError;

+fn pyobj_extract_scalar_via_capsule(
+    value: &Bound<'_, PyAny>,
+    as_list_array: bool,
+) -> PyResult<PyScalarValue> {
+    let array_data = ArrayData::from_pyarrow_bound(value)?;
+    let array = make_array(array_data);
+
+    if as_list_array {
+        let field = Arc::new(Field::new_list_field(
+            array.data_type().clone(),
+            array.nulls().is_some(),
+        ));
+        let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
+        let list_array = ListArray::new(field, offsets, array, None);
+        Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))))
+    } else {
+        let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+        Ok(PyScalarValue(scalar))
+    }
+}
+
 impl FromPyArrow for PyScalarValue {
     fn from_pyarrow_bound(value: &Bound<'_, PyAny>) -> PyResult<Self> {
         let py = value.py();
-        let typ = value.getattr("type")?;
+        let pyarrow_mod = py.import("pyarrow");

-        // construct pyarrow array from the python value and pyarrow type
-        let factory = py.import("pyarrow")?.getattr("array")?;
-        let args = PyList::new(py, [value])?;
-        let array = factory.call1((args, typ))?;
+        // Is it a PyArrow object?
+        if let Ok(pa) = pyarrow_mod.as_ref() {
+            let scalar_type = pa.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                let typ = value.getattr("type")?;

-        // convert the pyarrow array to rust array using C data interface
-        let array = arrow::array::make_array(ArrayData::from_pyarrow_bound(&array)?);
-        let scalar = ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+                // construct a pyarrow array from the python value and pyarrow type
+                let factory = py.import("pyarrow")?.getattr("array")?;
+                let args = PyList::new(py, [value])?;
+                let array = factory.call1((args, typ))?;

-        Ok(PyScalarValue(scalar))
+                return pyobj_extract_scalar_via_capsule(&array, false);
+            }
+
+            let array_type = pa.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Is it a NanoArrow scalar?
+        if let Ok(na) = py.import("nanoarrow") {
+            let type_name = value.get_type().repr()?;
+            if type_name.contains("nanoarrow")? && type_name.contains("Scalar")? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            let array_type = na.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Is it an arro3 scalar?
+        if let Ok(arro3) = py.import("arro3").and_then(|arro3| arro3.getattr("core")) {
+            let scalar_type = arro3.getattr("Scalar")?;
+            if value.is_instance(&scalar_type)? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            let array_type = arro3.getattr("Array")?;
+            if value.is_instance(&array_type)? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+        }
+
+        // Does it have a PyCapsule interface but isn't from one of our known
+        // libraries? If so, make our best guess: try the type name first, and
+        // if that fails return a single value when the length is 1 and a List
+        // value otherwise.
+        if value.hasattr("__arrow_c_array__")? {
+            let type_name = value.get_type().repr()?;
+            if type_name.contains("Scalar")? {
+                return pyobj_extract_scalar_via_capsule(value, false);
+            }
+            if type_name.contains("Array")? {
+                return pyobj_extract_scalar_via_capsule(value, true);
+            }
+
+            let array_data = ArrayData::from_pyarrow_bound(value)?;
+            let array = make_array(array_data);
+            if array.len() == 1 {
+                let scalar =
+                    ScalarValue::try_from_array(&array, 0).map_err(PyDataFusionError::from)?;
+                return Ok(PyScalarValue(scalar));
+            } else {
+                let field = Arc::new(Field::new_list_field(
+                    array.data_type().clone(),
+                    array.nulls().is_some(),
+                ));
+                let offsets = OffsetBuffer::from_lengths(vec![array.len()]);
+                let list_array = ListArray::new(field, offsets, array, None);
+                return Ok(PyScalarValue(ScalarValue::List(Arc::new(list_array))));
+            }
+        }
+
+        // Last attempt - try to create a PyArrow scalar from a plain Python object
+        if let Ok(pa) = pyarrow_mod.as_ref() {
+            let scalar = pa.call_method1("scalar", (value,))?;
+
+            PyScalarValue::from_pyarrow_bound(&scalar)
+        } else {
+            exec_err!("Unable to import scalar value").map_err(PyDataFusionError::from)?
+        }
     }
 }
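The final `__arrow_c_array__` branch makes the conversion work for any object that speaks the Arrow PyCapsule protocol, even from a library that is not special-cased above. A sketch of a hypothetical wrapper class that would take this "best guess" path (the `Scalar` in its type name routes it to the scalar branch):

```python
import pyarrow as pa
from datafusion import lit

class MyScalar:  # hypothetical wrapper; not part of datafusion or pyarrow
    def __init__(self, value):
        self._array = pa.array([value])

    def __arrow_c_array__(self, requested_schema=None):
        # Delegate capsule export to pyarrow's own implementation.
        return self._array.__arrow_c_array__(requested_schema)

# "Scalar" appears in the type's repr, so the single-value branch is taken.
expr = lit(MyScalar(1))
```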
12 changes: 3 additions & 9 deletions src/udaf.rs

@@ -32,7 +32,7 @@ use pyo3::types::{PyCapsule, PyTuple};
 use crate::common::data_type::PyScalarValue;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionResult};
 use crate::expr::PyExpr;
-use crate::utils::{parse_volatility, py_obj_to_scalar_value, validate_pycapsule};
+use crate::utils::{parse_volatility, validate_pycapsule};

 #[derive(Debug)]
 struct RustAccumulator {
@@ -52,10 +52,7 @@ impl Accumulator for RustAccumulator {
             let mut scalars = Vec::new();
             for item in values.try_iter()? {
                 let item: Bound<'_, PyAny> = item?;
-                let scalar = match item.extract::<PyScalarValue>() {
-                    Ok(py_scalar) => py_scalar.0,
-                    Err(_) => py_obj_to_scalar_value(py, item.unbind())?,
-                };
+                let scalar = item.extract::<PyScalarValue>()?.0;
                 scalars.push(scalar);
             }
             Ok(scalars)
@@ -66,10 +63,7 @@ impl Accumulator for RustAccumulator {
     fn evaluate(&mut self) -> Result<ScalarValue> {
         Python::attach(|py| -> PyResult<ScalarValue> {
             let value = self.accum.bind(py).call_method0("evaluate")?;
-            match value.extract::<PyScalarValue>() {
-                Ok(py_scalar) => Ok(py_scalar.0),
-                Err(_) => py_obj_to_scalar_value(py, value.unbind()),
-            }
+            value.extract::<PyScalarValue>().map(|v| v.0)
         })
         .map_err(|e| DataFusionError::Execution(format!("{e}")))
     }
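For Python accumulators the contract is now stricter and simpler: whatever `state` yields item by item and whatever `evaluate` returns must extract cleanly to a `PyScalarValue`; the separate conversion fallback is gone because the extraction itself covers plain Python objects, pyarrow scalars, and capsule-bearing types. A minimal accumulator sketch under that assumption, following the shape of datafusion's documented `Accumulator` examples:

```python
import pyarrow as pa
import pyarrow.compute as pc
from datafusion import Accumulator

class Summer(Accumulator):
    """Hypothetical sum accumulator illustrating the return-value contract."""

    def __init__(self) -> None:
        self._sum = 0

    def update(self, values: pa.Array) -> None:
        self._sum += pc.sum(values).as_py() or 0

    def merge(self, states: list[pa.Array]) -> None:
        self._sum += pc.sum(states[0]).as_py() or 0

    def state(self) -> list[pa.Scalar]:
        # Each item is extracted individually as a PyScalarValue on the Rust side.
        return [pa.scalar(self._sum)]

    def evaluate(self) -> pa.Scalar:
        # A plain Python int would also extract; a pyarrow scalar is explicit.
        return pa.scalar(self._sum)
```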
1 change: 0 additions & 1 deletion src/udwf.rs

@@ -94,7 +94,6 @@ impl PartitionEvaluator for RustPartitionEvaluator {
     }

     fn evaluate_all(&mut self, values: &[ArrayRef], num_rows: usize) -> Result<ArrayRef> {
-        println!("evaluate all called with number of values {}", values.len());
         Python::attach(|py| {
             let py_values = PyList::new(
                 py,
57 changes: 0 additions & 57 deletions src/utils.rs

@@ -19,11 +19,6 @@ use std::future::Future;
 use std::sync::{Arc, OnceLock};
 use std::time::Duration;

-use datafusion::arrow::array::{make_array, ArrayData, ListArray};
-use datafusion::arrow::buffer::{OffsetBuffer, ScalarBuffer};
-use datafusion::arrow::datatypes::Field;
-use datafusion::arrow::pyarrow::FromPyArrow;
-use datafusion::common::ScalarValue;
 use datafusion::datasource::TableProvider;
 use datafusion::execution::context::SessionContext;
 use datafusion::logical_expr::Volatility;
@@ -37,7 +32,6 @@ use tokio::runtime::Runtime;
 use tokio::task::JoinHandle;
 use tokio::time::sleep;

-use crate::common::data_type::PyScalarValue;
 use crate::context::PySessionContext;
 use crate::errors::{py_datafusion_err, to_datafusion_err, PyDataFusionError, PyDataFusionResult};
 use crate::TokioRuntime;
@@ -203,57 +197,6 @@ pub(crate) fn table_provider_from_pycapsule<'py>(
     }
 }

-pub(crate) fn py_obj_to_scalar_value(py: Python, obj: Py<PyAny>) -> PyResult<ScalarValue> {
-    // convert Python object to PyScalarValue to ScalarValue
-
-    let pa = py.import("pyarrow")?;
-    let scalar_attr = pa.getattr("Scalar")?;
-    let scalar_type = scalar_attr.downcast::<PyType>()?;
-    let array_attr = pa.getattr("Array")?;
-    let array_type = array_attr.downcast::<PyType>()?;
-    let chunked_array_attr = pa.getattr("ChunkedArray")?;
-    let chunked_array_type = chunked_array_attr.downcast::<PyType>()?;
-
-    let obj_ref = obj.bind(py);
-
-    if obj_ref.is_instance(scalar_type)? {
-        let py_scalar = PyScalarValue::extract_bound(obj_ref)
-            .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
-        return Ok(py_scalar.into());
-    }
-
-    if obj_ref.is_instance(array_type)? || obj_ref.is_instance(chunked_array_type)? {
-        let array_obj = if obj_ref.is_instance(chunked_array_type)? {
-            obj_ref.call_method0("combine_chunks")?.unbind()
-        } else {
-            obj_ref.clone().unbind()
-        };
-        let array_bound = array_obj.bind(py);
-        let array_data = ArrayData::from_pyarrow_bound(array_bound)
-            .map_err(|e| PyValueError::new_err(format!("Failed to extract pyarrow array: {e}")))?;
-        let array = make_array(array_data);
-        let offsets = OffsetBuffer::new(ScalarBuffer::from(vec![0, array.len() as i32]));
-        let list_array = Arc::new(ListArray::new(
-            Arc::new(Field::new_list_field(array.data_type().clone(), true)),
-            offsets,
-            array,
-            None,
-        ));
-
-        return Ok(ScalarValue::List(list_array));
-    }
-
-    // Convert Python object to PyArrow scalar
-    let scalar = pa.call_method1("scalar", (obj,))?;
-
-    // Convert PyArrow scalar to PyScalarValue
-    let py_scalar = PyScalarValue::extract_bound(scalar.as_ref())
-        .map_err(|e| PyValueError::new_err(format!("Failed to extract PyScalarValue: {e}")))?;
-
-    // Convert PyScalarValue to ScalarValue
-    Ok(py_scalar.into())
-}
-
 pub(crate) fn extract_logical_extension_codec(
     py: Python,
     obj: Option<Bound<PyAny>>,