diff --git a/docs/source/user-guide/io/csv.rst b/docs/source/user-guide/io/csv.rst index 144b6615c..9c23c291b 100644 --- a/docs/source/user-guide/io/csv.rst +++ b/docs/source/user-guide/io/csv.rst @@ -36,3 +36,25 @@ An alternative is to use :py:func:`~datafusion.context.SessionContext.register_c ctx.register_csv("file", "file.csv") df = ctx.table("file") + +If you require additional control over how to read the CSV file, you can use +:py:class:`~datafusion.options.CsvReadOptions` to set a variety of options. + +.. code-block:: python + + from datafusion import CsvReadOptions + options = ( + CsvReadOptions() + .with_has_header(True) # File contains a header row + .with_delimiter(";") # Use ; as the delimiter instead of , + .with_comment("#") # Skip lines starting with # + .with_escape("\\") # Escape character + .with_null_regex(r"^(null|NULL|N/A)$") # Treat these as NULL + .with_truncated_rows(True) # Allow rows to have incomplete columns + .with_file_compression_type("gzip") # Read gzipped CSV + .with_file_extension(".gz") # File extension other than .csv + ) + df = ctx.read_csv("data.csv.gz", options=options) + +Details for all CSV reading options can be found on the +`DataFusion documentation site `_. diff --git a/examples/csv-read-options.py b/examples/csv-read-options.py new file mode 100644 index 000000000..a5952d950 --- /dev/null +++ b/examples/csv-read-options.py @@ -0,0 +1,96 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Example demonstrating CsvReadOptions usage.""" + +from datafusion import CsvReadOptions, SessionContext + +# Create a SessionContext +ctx = SessionContext() + +# Example 1: Using CsvReadOptions with default values +print("Example 1: Default CsvReadOptions") +options = CsvReadOptions() +df = ctx.read_csv("data.csv", options=options) + +# Example 2: Using CsvReadOptions with custom parameters +print("\nExample 2: Custom CsvReadOptions") +options = CsvReadOptions( + has_header=True, + delimiter=",", + quote='"', + schema_infer_max_records=1000, + file_extension=".csv", +) +df = ctx.read_csv("data.csv", options=options) + +# Example 3: Using the builder pattern (recommended for readability) +print("\nExample 3: Builder pattern") +options = ( + CsvReadOptions() + .with_has_header(True) # noqa: FBT003 + .with_delimiter("|") + .with_quote("'") + .with_schema_infer_max_records(500) + .with_truncated_rows(False) # noqa: FBT003 + .with_newlines_in_values(True) # noqa: FBT003 +) +df = ctx.read_csv("data.csv", options=options) + +# Example 4: Advanced options +print("\nExample 4: Advanced options") +options = ( + CsvReadOptions() + .with_has_header(True) # noqa: FBT003 + .with_delimiter(",") + .with_comment("#") # Skip lines starting with # + .with_escape("\\") # Escape character + .with_null_regex(r"^(null|NULL|N/A)$") # Treat these as NULL + .with_truncated_rows(True) # noqa: FBT003 + .with_file_compression_type("gzip") # Read gzipped CSV + .with_file_extension(".gz") +) +df = ctx.read_csv("data.csv.gz", options=options) + +# Example 5: Register CSV table with options +print("\nExample 5: Register CSV table") +options = CsvReadOptions().with_has_header(True).with_delimiter(",") # noqa: FBT003 +ctx.register_csv("my_table", "data.csv", options=options) +df = ctx.sql("SELECT * FROM my_table") + +# Example 6: Backward compatibility (without options) +print("\nExample 6: Backward compatibility") +# Still works the old way! 
+df = ctx.read_csv("data.csv", has_header=True, delimiter=",") + +print("\nAll examples completed!") +print("\nFor all available options, see the CsvReadOptions documentation:") +print(" - has_header: bool") +print(" - delimiter: str") +print(" - quote: str") +print(" - terminator: str | None") +print(" - escape: str | None") +print(" - comment: str | None") +print(" - newlines_in_values: bool") +print(" - schema: pa.Schema | None") +print(" - schema_infer_max_records: int") +print(" - file_extension: str") +print(" - table_partition_cols: list[tuple[str, pa.DataType]]") +print(" - file_compression_type: str") +print(" - file_sort_order: list[list[SortExpr]]") +print(" - null_regex: str | None") +print(" - truncated_rows: bool") diff --git a/python/datafusion/__init__.py b/python/datafusion/__init__.py index 784d4ccc6..2e6f81166 100644 --- a/python/datafusion/__init__.py +++ b/python/datafusion/__init__.py @@ -54,6 +54,7 @@ from .dataframe_formatter import configure_formatter from .expr import Expr, WindowFrame from .io import read_avro, read_csv, read_json, read_parquet +from .options import CsvReadOptions from .plan import ExecutionPlan, LogicalPlan from .record_batch import RecordBatch, RecordBatchStream from .user_defined import ( @@ -75,6 +76,7 @@ "AggregateUDF", "Catalog", "Config", + "CsvReadOptions", "DFSchema", "DataFrame", "DataFrameWriteOptions", @@ -106,6 +108,7 @@ "lit", "literal", "object_store", + "options", "read_avro", "read_csv", "read_json", diff --git a/python/datafusion/context.py b/python/datafusion/context.py index be647feff..7b92c082b 100644 --- a/python/datafusion/context.py +++ b/python/datafusion/context.py @@ -34,6 +34,11 @@ from datafusion.catalog import Catalog from datafusion.dataframe import DataFrame from datafusion.expr import sort_list_to_raw_sort_list +from datafusion.options import ( + DEFAULT_MAX_INFER_SCHEMA, + CsvReadOptions, + _convert_table_partition_cols, +) from datafusion.record_batch import RecordBatchStream from ._internal import RuntimeEnvBuilder as RuntimeEnvBuilderInternal @@ -584,7 +589,7 @@ def register_listing_table( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) self.ctx.register_listing_table( name, str(path), @@ -905,7 +910,7 @@ def register_parquet( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) self.ctx.register_parquet( name, str(path), @@ -924,9 +929,10 @@ def register_csv( schema: pa.Schema | None = None, has_header: bool = True, delimiter: str = ",", - schema_infer_max_records: int = 1000, + schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA, file_extension: str = ".csv", file_compression_type: str | None = None, + options: CsvReadOptions | None = None, ) -> None: """Register a CSV file as a table. @@ -946,18 +952,46 @@ def register_csv( file_extension: File extension; only files with this extension are selected for data input. file_compression_type: File compression type. + options: Set advanced options for CSV reading. This cannot be + combined with any of the other options in this method. 
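+
+        Example:
+            A minimal sketch of passing pre-built options (the file name is
+            illustrative)::
+
+                from datafusion import CsvReadOptions
+
+                options = CsvReadOptions(delimiter=";", has_header=False)
+                ctx.register_csv("my_table", "data.csv", options=options)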
""" - path = [str(p) for p in path] if isinstance(path, list) else str(path) + path_arg = [str(p) for p in path] if isinstance(path, list) else str(path) + + if options is not None and ( + schema is not None + or not has_header + or delimiter != "," + or schema_infer_max_records != DEFAULT_MAX_INFER_SCHEMA + or file_extension != ".csv" + or file_compression_type is not None + ): + message = ( + "Combining CsvReadOptions parameter with additional options " + "is not supported. Use CsvReadOptions to set parameters." + ) + warnings.warn( + message, + category=UserWarning, + stacklevel=2, + ) + + options = ( + options + if options is not None + else CsvReadOptions( + schema=schema, + has_header=has_header, + delimiter=delimiter, + schema_infer_max_records=schema_infer_max_records, + file_extension=file_extension, + file_compression_type=file_compression_type, + ) + ) self.ctx.register_csv( name, - path, - schema, - has_header, - delimiter, - schema_infer_max_records, - file_extension, - file_compression_type, + path_arg, + options.to_inner(), ) def register_json( @@ -988,7 +1022,7 @@ def register_json( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) self.ctx.register_json( name, str(path), @@ -1021,7 +1055,7 @@ def register_avro( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) self.ctx.register_avro( name, str(path), schema, file_extension, table_partition_cols ) @@ -1101,7 +1135,7 @@ def read_json( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) return DataFrame( self.ctx.read_json( str(path), @@ -1119,10 +1153,11 @@ def read_csv( schema: pa.Schema | None = None, has_header: bool = True, delimiter: str = ",", - schema_infer_max_records: int = 1000, + schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA, file_extension: str = ".csv", table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, + options: CsvReadOptions | None = None, ) -> DataFrame: """Read a CSV data source. @@ -1140,26 +1175,51 @@ def read_csv( selected for data input. table_partition_cols: Partition columns. file_compression_type: File compression type. + options: Set advanced options for CSV reading. This cannot be + combined with any of the other options in this method. Returns: DataFrame representation of the read CSV files """ - if table_partition_cols is None: - table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + path_arg = [str(p) for p in path] if isinstance(path, list) else str(path) + + if options is not None and ( + schema is not None + or not has_header + or delimiter != "," + or schema_infer_max_records != DEFAULT_MAX_INFER_SCHEMA + or file_extension != ".csv" + or table_partition_cols is not None + or file_compression_type is not None + ): + message = ( + "Combining CsvReadOptions parameter with additional options " + "is not supported. Use CsvReadOptions to set parameters." 
+ ) + warnings.warn( + message, + category=UserWarning, + stacklevel=2, + ) - path = [str(p) for p in path] if isinstance(path, list) else str(path) + options = ( + options + if options is not None + else CsvReadOptions( + schema=schema, + has_header=has_header, + delimiter=delimiter, + schema_infer_max_records=schema_infer_max_records, + file_extension=file_extension, + table_partition_cols=table_partition_cols, + file_compression_type=file_compression_type, + ) + ) return DataFrame( self.ctx.read_csv( - path, - schema, - has_header, - delimiter, - schema_infer_max_records, - file_extension, - table_partition_cols, - file_compression_type, + path_arg, + options.to_inner(), ) ) @@ -1197,7 +1257,7 @@ def read_parquet( """ if table_partition_cols is None: table_partition_cols = [] - table_partition_cols = self._convert_table_partition_cols(table_partition_cols) + table_partition_cols = _convert_table_partition_cols(table_partition_cols) file_sort_order = self._convert_file_sort_order(file_sort_order) return DataFrame( self.ctx.read_parquet( @@ -1231,7 +1291,7 @@ def read_avro( """ if file_partition_cols is None: file_partition_cols = [] - file_partition_cols = self._convert_table_partition_cols(file_partition_cols) + file_partition_cols = _convert_table_partition_cols(file_partition_cols) return DataFrame( self.ctx.read_avro(str(path), schema, file_partition_cols, file_extension) ) diff --git a/python/datafusion/io.py b/python/datafusion/io.py index 67dbc730f..4f9c3c516 100644 --- a/python/datafusion/io.py +++ b/python/datafusion/io.py @@ -31,6 +31,8 @@ from datafusion.dataframe import DataFrame from datafusion.expr import Expr + from .options import CsvReadOptions + def read_parquet( path: str | pathlib.Path, @@ -126,6 +128,7 @@ def read_csv( file_extension: str = ".csv", table_partition_cols: list[tuple[str, str | pa.DataType]] | None = None, file_compression_type: str | None = None, + options: CsvReadOptions | None = None, ) -> DataFrame: """Read a CSV data source. @@ -147,15 +150,12 @@ def read_csv( selected for data input. table_partition_cols: Partition columns. file_compression_type: File compression type. + options: Set advanced options for CSV reading. This cannot be + combined with any of the other options in this method. Returns: DataFrame representation of the read CSV files """ - if table_partition_cols is None: - table_partition_cols = [] - - path = [str(p) for p in path] if isinstance(path, list) else str(path) - return SessionContext.global_ctx().read_csv( path, schema, @@ -165,6 +165,7 @@ def read_csv( file_extension, table_partition_cols, file_compression_type, + options, ) diff --git a/python/datafusion/options.py b/python/datafusion/options.py new file mode 100644 index 000000000..ec19f37d0 --- /dev/null +++ b/python/datafusion/options.py @@ -0,0 +1,284 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Options for reading various file formats.""" + +from __future__ import annotations + +import warnings +from typing import TYPE_CHECKING + +import pyarrow as pa + +from datafusion.expr import sort_list_to_raw_sort_list + +if TYPE_CHECKING: + from datafusion.expr import SortExpr + +from ._internal import options + +__all__ = ["CsvReadOptions"] + +DEFAULT_MAX_INFER_SCHEMA = 1000 + + +class CsvReadOptions: + """Options for reading CSV files. + + This class provides a builder pattern for configuring CSV reading options. + All methods starting with ``with_`` return ``self`` to allow method chaining. + """ + + def __init__( + self, + *, + has_header: bool = True, + delimiter: str = ",", + quote: str = '"', + terminator: str | None = None, + escape: str | None = None, + comment: str | None = None, + newlines_in_values: bool = False, + schema: pa.Schema | None = None, + schema_infer_max_records: int = DEFAULT_MAX_INFER_SCHEMA, + file_extension: str = ".csv", + table_partition_cols: list[tuple[str, pa.DataType]] | None = None, + file_compression_type: str = "", + file_sort_order: list[list[SortExpr]] | None = None, + null_regex: str | None = None, + truncated_rows: bool = False, + ) -> None: + """Initialize CsvReadOptions. + + Args: + has_header: Does the CSV file have a header row? If schema inference + is run on a file with no headers, default column names are created. + delimiter: Column delimiter character. Must be a single ASCII character. + quote: Quote character for fields containing delimiters or newlines. + Must be a single ASCII character. + terminator: Optional line terminator character. If ``None``, uses CRLF. + Must be a single ASCII character. + escape: Optional escape character for quotes. Must be a single ASCII + character. + comment: If specified, lines beginning with this character are ignored. + Must be a single ASCII character. + newlines_in_values: Whether newlines in quoted values are supported. + Parsing newlines in quoted values may be affected by execution + behavior such as parallel file scanning. Setting this to ``True`` + ensures that newlines in values are parsed successfully, which may + reduce performance. + schema: Optional PyArrow schema representing the CSV files. If ``None``, + the CSV reader will try to infer it based on data in the file. + schema_infer_max_records: Maximum number of rows to read from CSV files + for schema inference if needed. + file_extension: File extension; only files with this extension are + selected for data input. + table_partition_cols: Partition columns as a list of tuples of + (column_name, data_type). + file_compression_type: File compression type. Supported values are + ``"gzip"``, ``"bz2"``, ``"xz"``, ``"zstd"``, or empty string for + uncompressed. + file_sort_order: Optional sort order of the files as a list of sort + expressions per file. + null_regex: Optional regex pattern to match null values in the CSV. + truncated_rows: Whether to allow truncated rows when parsing. By default + this is ``False`` and will error if the CSV rows have different + lengths. When set to ``True``, it will allow records with less than + the expected number of columns and fill the missing columns with + nulls. If the record's schema is not nullable, it will still return + an error. 
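+
+        Example:
+            A sketch combining several of these options at construction time
+            (the values shown are illustrative); each keyword also has a
+            corresponding ``with_*`` builder method::
+
+                options = CsvReadOptions(
+                    has_header=False,
+                    delimiter="|",
+                    file_compression_type="gzip",
+                    file_extension=".gz",
+                    null_regex=r"^(null|NULL|N/A)$",
+                    truncated_rows=True,
+                )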
+ """ + validate_single_character("delimiter", delimiter) + validate_single_character("quote", quote) + validate_single_character("terminator", terminator) + validate_single_character("escape", escape) + validate_single_character("comment", comment) + + self.has_header = has_header + self.delimiter = delimiter + self.quote = quote + self.terminator = terminator + self.escape = escape + self.comment = comment + self.newlines_in_values = newlines_in_values + self.schema = schema + self.schema_infer_max_records = schema_infer_max_records + self.file_extension = file_extension + self.table_partition_cols = table_partition_cols or [] + self.file_compression_type = file_compression_type + self.file_sort_order = file_sort_order or [] + self.null_regex = null_regex + self.truncated_rows = truncated_rows + + def with_has_header(self, has_header: bool) -> CsvReadOptions: + """Configure whether the CSV has a header row.""" + self.has_header = has_header + return self + + def with_delimiter(self, delimiter: str) -> CsvReadOptions: + """Configure the column delimiter.""" + self.delimiter = delimiter + return self + + def with_quote(self, quote: str) -> CsvReadOptions: + """Configure the quote character.""" + self.quote = quote + return self + + def with_terminator(self, terminator: str | None) -> CsvReadOptions: + """Configure the line terminator character.""" + self.terminator = terminator + return self + + def with_escape(self, escape: str | None) -> CsvReadOptions: + """Configure the escape character.""" + self.escape = escape + return self + + def with_comment(self, comment: str | None) -> CsvReadOptions: + """Configure the comment character.""" + self.comment = comment + return self + + def with_newlines_in_values(self, newlines_in_values: bool) -> CsvReadOptions: + """Configure whether newlines in values are supported.""" + self.newlines_in_values = newlines_in_values + return self + + def with_schema(self, schema: pa.Schema | None) -> CsvReadOptions: + """Configure the schema.""" + self.schema = schema + return self + + def with_schema_infer_max_records( + self, schema_infer_max_records: int + ) -> CsvReadOptions: + """Configure maximum records for schema inference.""" + self.schema_infer_max_records = schema_infer_max_records + return self + + def with_file_extension(self, file_extension: str) -> CsvReadOptions: + """Configure the file extension filter.""" + self.file_extension = file_extension + return self + + def with_table_partition_cols( + self, table_partition_cols: list[tuple[str, pa.DataType]] + ) -> CsvReadOptions: + """Configure table partition columns.""" + self.table_partition_cols = table_partition_cols + return self + + def with_file_compression_type(self, file_compression_type: str) -> CsvReadOptions: + """Configure file compression type.""" + self.file_compression_type = file_compression_type + return self + + def with_file_sort_order( + self, file_sort_order: list[list[SortExpr]] + ) -> CsvReadOptions: + """Configure file sort order.""" + self.file_sort_order = file_sort_order + return self + + def with_null_regex(self, null_regex: str | None) -> CsvReadOptions: + """Configure null value regex pattern.""" + self.null_regex = null_regex + return self + + def with_truncated_rows(self, truncated_rows: bool) -> CsvReadOptions: + """Configure whether to allow truncated rows.""" + self.truncated_rows = truncated_rows + return self + + def to_inner(self) -> options.CsvReadOptions: + """Convert this object into the underlying Rust structure. + + This is intended for internal use only. 
+ """ + file_sort_order = ( + [] + if self.file_sort_order is None + else [ + sort_list_to_raw_sort_list(sort_list) + for sort_list in self.file_sort_order + ] + ) + + return options.CsvReadOptions( + has_header=self.has_header, + delimiter=ord(self.delimiter[0]) if self.delimiter else ord(","), + quote=ord(self.quote[0]) if self.quote else ord('"'), + terminator=ord(self.terminator[0]) if self.terminator else None, + escape=ord(self.escape[0]) if self.escape else None, + comment=ord(self.comment[0]) if self.comment else None, + newlines_in_values=self.newlines_in_values, + schema=self.schema, + schema_infer_max_records=self.schema_infer_max_records, + file_extension=self.file_extension, + table_partition_cols=_convert_table_partition_cols( + self.table_partition_cols + ), + file_compression_type=self.file_compression_type or "", + file_sort_order=file_sort_order, + null_regex=self.null_regex, + truncated_rows=self.truncated_rows, + ) + + +def validate_single_character(name: str, value: str | None) -> None: + if value is not None and len(value) != 1: + message = f"{name} must be a single character" + raise ValueError(message) + + +def _convert_table_partition_cols( + table_partition_cols: list[tuple[str, str | pa.DataType]], +) -> list[tuple[str, pa.DataType]]: + warn = False + converted_table_partition_cols = [] + + for col, data_type in table_partition_cols: + if isinstance(data_type, str): + warn = True + if data_type == "string": + converted_data_type = pa.string() + elif data_type == "int": + converted_data_type = pa.int32() + else: + message = ( + f"Unsupported literal data type '{data_type}' for partition " + "column. Supported types are 'string' and 'int'" + ) + raise ValueError(message) + else: + converted_data_type = data_type + + converted_table_partition_cols.append((col, converted_data_type)) + + if warn: + message = ( + "using literals for table_partition_cols data types is deprecated," + "use pyarrow types instead" + ) + warnings.warn( + message, + category=DeprecationWarning, + stacklevel=2, + ) + + return converted_table_partition_cols diff --git a/python/tests/test_context.py b/python/tests/test_context.py index bd65305ed..5853f9feb 100644 --- a/python/tests/test_context.py +++ b/python/tests/test_context.py @@ -22,6 +22,7 @@ import pyarrow.dataset as ds import pytest from datafusion import ( + CsvReadOptions, DataFrame, RuntimeEnvBuilder, SessionConfig, @@ -626,6 +627,8 @@ def test_read_csv_list(ctx): def test_read_csv_compressed(ctx, tmp_path): test_data_path = pathlib.Path("testing/data/csv/aggregate_test_100.csv") + expected = ctx.read_csv(test_data_path).collect() + # File compression type gzip_path = tmp_path / "aggregate_test_100.csv.gz" @@ -636,7 +639,13 @@ def test_read_csv_compressed(ctx, tmp_path): gzipped_file.writelines(csv_file) csv_df = ctx.read_csv(gzip_path, file_extension=".gz", file_compression_type="gz") - csv_df.select(column("c1")).show() + assert csv_df.collect() == expected + + csv_df = ctx.read_csv( + gzip_path, + options=CsvReadOptions(file_extension=".gz", file_compression_type="gz"), + ) + assert csv_df.collect() == expected def test_read_parquet(ctx): @@ -710,3 +719,154 @@ def test_create_dataframe_with_global_ctx(batch): result = df.collect()[0].column(0) assert result == pa.array([4, 5, 6]) + + +def test_csv_read_options_builder_pattern(): + """Test CsvReadOptions builder pattern.""" + from datafusion import CsvReadOptions + + options = ( + CsvReadOptions() + .with_has_header(False) # noqa: FBT003 + .with_delimiter("|") + .with_quote("'") + 
.with_schema_infer_max_records(2000) + .with_truncated_rows(True) # noqa: FBT003 + .with_newlines_in_values(True) # noqa: FBT003 + .with_file_extension(".tsv") + ) + assert options.has_header is False + assert options.delimiter == "|" + assert options.quote == "'" + assert options.schema_infer_max_records == 2000 + assert options.truncated_rows is True + assert options.newlines_in_values is True + assert options.file_extension == ".tsv" + + +def read_csv_with_options_inner( + tmp_path: pathlib.Path, + csv_content: str, + options: CsvReadOptions, + expected: pa.RecordBatch, + as_read: bool, + global_ctx: bool, +) -> None: + from datafusion import SessionContext + + # Create a test CSV file + group_dir = tmp_path / "group=a" + group_dir.mkdir(exist_ok=True) + + csv_path = group_dir / "test.csv" + csv_path.write_text(csv_content) + + ctx = SessionContext() + + if as_read: + if global_ctx: + from datafusion.io import read_csv + + df = read_csv(str(tmp_path), options=options) + else: + df = ctx.read_csv(str(tmp_path), options=options) + else: + ctx.register_csv("test_table", str(tmp_path), options=options) + df = ctx.sql("SELECT * FROM test_table") + df.show() + + # Verify the data + result = df.collect() + assert len(result) == 1 + assert result[0] == expected + + +@pytest.mark.parametrize( + ("as_read", "global_ctx"), + [ + (True, True), + (True, False), + (False, False), + ], +) +def test_read_csv_with_options(tmp_path, as_read, global_ctx): + """Test reading CSV with CsvReadOptions.""" + + csv_content = "Alice;30;|New York; NY|\nBob;25\n#Charlie;35;Paris\nPhil;75;Detroit' MI\nKarin;50;|Stockholm\nSweden|" # noqa: E501 + + # Some of the read options are difficult to test in combination + # such as schema and schema_infer_max_records so run multiple tests + # file_sort_order doesn't impact reading, but included here to ensure + # all options parse correctly + options = CsvReadOptions( + has_header=False, + delimiter=";", + quote="|", + terminator="\n", + escape="\\", + comment="#", + newlines_in_values=True, + schema_infer_max_records=1, + null_regex="[pP]+aris", + truncated_rows=True, + file_sort_order=[[column("column_1").sort(), column("column_2")], ["column_3"]], + ) + + expected = pa.RecordBatch.from_arrays( + [ + pa.array(["Alice", "Bob", "Phil", "Karin"]), + pa.array([30, 25, 75, 50]), + pa.array(["New York; NY", None, "Detroit' MI", "Stockholm\nSweden"]), + ], + names=["column_1", "column_2", "column_3"], + ) + + read_csv_with_options_inner( + tmp_path, csv_content, options, expected, as_read, global_ctx + ) + + schema = pa.schema( + [ + pa.field("name", pa.string(), nullable=False), + pa.field("age", pa.float32(), nullable=False), + pa.field("location", pa.string(), nullable=True), + ] + ) + options.with_schema(schema) + + expected = pa.RecordBatch.from_arrays( + [ + pa.array(["Alice", "Bob", "Phil", "Karin"]), + pa.array([30.0, 25.0, 75.0, 50.0]), + pa.array(["New York; NY", None, "Detroit' MI", "Stockholm\nSweden"]), + ], + schema=schema, + ) + + read_csv_with_options_inner( + tmp_path, csv_content, options, expected, as_read, global_ctx + ) + + csv_content = "name,age\nAlice,30\nBob,25\nCharlie,35\nDiego,40\nEmily,15" + + expected = pa.RecordBatch.from_arrays( + [ + pa.array(["Alice", "Bob", "Charlie", "Diego", "Emily"]), + pa.array([30, 25, 35, 40, 15]), + pa.array(["a", "a", "a", "a", "a"]), + ], + schema=pa.schema( + [ + pa.field("name", pa.string(), nullable=True), + pa.field("age", pa.int64(), nullable=True), + pa.field("group", pa.string(), nullable=False), + ] + ), + ) + 
options = CsvReadOptions( + table_partition_cols=[("group", pa.string())], + ) + + read_csv_with_options_inner( + tmp_path, csv_content, options, expected, as_read, global_ctx + ) diff --git a/python/tests/test_sql.py b/python/tests/test_sql.py index 85afd021f..48c374660 100644 --- a/python/tests/test_sql.py +++ b/python/tests/test_sql.py @@ -92,7 +92,7 @@ def test_register_csv(ctx, tmp_path): result = pa.Table.from_batches(result) assert result.schema == alternative_schema - with pytest.raises(ValueError, match="Delimiter must be a single character"): + with pytest.raises(ValueError, match="delimiter must be a single character"): ctx.register_csv("csv4", path, delimiter="wrong") with pytest.raises( diff --git a/src/context.rs b/src/context.rs index 1cd04ac2f..f28c5982c 100644 --- a/src/context.rs +++ b/src/context.rs @@ -64,9 +64,9 @@ use crate::dataframe::PyDataFrame; use crate::dataset::Dataset; use crate::errors::{py_datafusion_err, PyDataFusionError, PyDataFusionResult}; use crate::expr::sort_expr::PySortExpr; +use crate::options::PyCsvReadOptions; use crate::physical_plan::PyExecutionPlan; use crate::record_batch::PyRecordBatchStream; -use crate::sql::exceptions::py_value_err; use crate::sql::logical::PyLogicalPlan; use crate::sql::util::replace_placeholders_with_strings; use crate::store::StorageContexts; @@ -724,41 +724,20 @@ impl PySessionContext { Ok(()) } - #[allow(clippy::too_many_arguments)] #[pyo3(signature = (name, path, - schema=None, - has_header=true, - delimiter=",", - schema_infer_max_records=1000, - file_extension=".csv", - file_compression_type=None))] + options=None))] pub fn register_csv( &self, name: &str, path: &Bound<'_, PyAny>, - schema: Option>, - has_header: bool, - delimiter: &str, - schema_infer_max_records: usize, - file_extension: &str, - file_compression_type: Option, + options: Option<&PyCsvReadOptions>, py: Python, ) -> PyDataFusionResult<()> { - let delimiter = delimiter.as_bytes(); - if delimiter.len() != 1 { - return Err(PyDataFusionError::PythonError(py_value_err( - "Delimiter must be a single character", - ))); - } - - let mut options = CsvReadOptions::new() - .has_header(has_header) - .delimiter(delimiter[0]) - .schema_infer_max_records(schema_infer_max_records) - .file_extension(file_extension) - .file_compression_type(parse_file_compression_type(file_compression_type)?); - options.schema = schema.as_ref().map(|x| &x.0); + let options = options + .map(|opts| opts.try_into()) + .transpose()? 
+ .unwrap_or_default(); if path.is_instance_of::() { let paths = path.extract::>()?; @@ -978,48 +957,19 @@ impl PySessionContext { Ok(PyDataFrame::new(df)) } - #[allow(clippy::too_many_arguments)] #[pyo3(signature = ( path, - schema=None, - has_header=true, - delimiter=",", - schema_infer_max_records=1000, - file_extension=".csv", - table_partition_cols=vec![], - file_compression_type=None))] + options=None))] pub fn read_csv( &self, path: &Bound<'_, PyAny>, - schema: Option>, - has_header: bool, - delimiter: &str, - schema_infer_max_records: usize, - file_extension: &str, - table_partition_cols: Vec<(String, PyArrowType)>, - file_compression_type: Option, + options: Option<&PyCsvReadOptions>, py: Python, ) -> PyDataFusionResult { - let delimiter = delimiter.as_bytes(); - if delimiter.len() != 1 { - return Err(PyDataFusionError::PythonError(py_value_err( - "Delimiter must be a single character", - ))); - }; - - let mut options = CsvReadOptions::new() - .has_header(has_header) - .delimiter(delimiter[0]) - .schema_infer_max_records(schema_infer_max_records) - .file_extension(file_extension) - .table_partition_cols( - table_partition_cols - .into_iter() - .map(|(name, ty)| (name, ty.0)) - .collect::>(), - ) - .file_compression_type(parse_file_compression_type(file_compression_type)?); - options.schema = schema.as_ref().map(|x| &x.0); + let options = options + .map(|opts| opts.try_into()) + .transpose()? + .unwrap_or_default(); if path.is_instance_of::() { let paths = path.extract::>()?; diff --git a/src/lib.rs b/src/lib.rs index eda50fe10..081366b20 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -43,6 +43,7 @@ pub mod errors; pub mod expr; #[allow(clippy::borrow_deref_ref)] mod functions; +mod options; pub mod physical_plan; mod pyarrow_filter_expression; pub mod pyarrow_util; @@ -126,6 +127,10 @@ fn _internal(py: Python, m: Bound<'_, PyModule>) -> PyResult<()> { store::init_module(&store)?; m.add_submodule(&store)?; + let options = PyModule::new(py, "options")?; + options::init_module(&options)?; + m.add_submodule(&options)?; + // Register substrait as a submodule #[cfg(feature = "substrait")] setup_substrait_module(py, &m)?; diff --git a/src/options.rs b/src/options.rs new file mode 100644 index 000000000..a37664b2e --- /dev/null +++ b/src/options.rs @@ -0,0 +1,142 @@ +use arrow::datatypes::{DataType, Schema}; +use arrow::pyarrow::PyArrowType; +use datafusion::prelude::CsvReadOptions; +use pyo3::prelude::{PyModule, PyModuleMethods}; +use pyo3::{pyclass, pymethods, Bound, PyResult}; + +use crate::context::parse_file_compression_type; +use crate::errors::PyDataFusionError; +use crate::expr::sort_expr::PySortExpr; + +/// Options for reading CSV files +#[pyclass(name = "CsvReadOptions", module = "datafusion.options", frozen)] +pub struct PyCsvReadOptions { + pub has_header: bool, + pub delimiter: u8, + pub quote: u8, + pub terminator: Option, + pub escape: Option, + pub comment: Option, + pub newlines_in_values: bool, + pub schema: Option>, + pub schema_infer_max_records: usize, + pub file_extension: String, + pub table_partition_cols: Vec<(String, PyArrowType)>, + pub file_compression_type: String, + pub file_sort_order: Vec>, + pub null_regex: Option, + pub truncated_rows: bool, +} + +#[pymethods] +impl PyCsvReadOptions { + #[allow(clippy::too_many_arguments)] + #[pyo3(signature = ( + has_header=true, + delimiter=b',', + quote=b'"', + terminator=None, + escape=None, + comment=None, + newlines_in_values=false, + schema=None, + schema_infer_max_records=1000, + file_extension=".csv".to_string(), + 
table_partition_cols=vec![], + file_compression_type="".to_string(), + file_sort_order=vec![], + null_regex=None, + truncated_rows=false + ))] + #[new] + fn new( + has_header: bool, + delimiter: u8, + quote: u8, + terminator: Option, + escape: Option, + comment: Option, + newlines_in_values: bool, + schema: Option>, + schema_infer_max_records: usize, + file_extension: String, + table_partition_cols: Vec<(String, PyArrowType)>, + file_compression_type: String, + file_sort_order: Vec>, + null_regex: Option, + truncated_rows: bool, + ) -> Self { + Self { + has_header, + delimiter, + quote, + terminator, + escape, + comment, + newlines_in_values, + schema, + schema_infer_max_records, + file_extension, + table_partition_cols, + file_compression_type, + file_sort_order, + null_regex, + truncated_rows, + } + } +} + +impl<'a> TryFrom<&'a PyCsvReadOptions> for CsvReadOptions<'a> { + type Error = PyDataFusionError; + + fn try_from(value: &'a PyCsvReadOptions) -> Result, Self::Error> { + let partition_cols: Vec<(String, DataType)> = value + .table_partition_cols + .iter() + .map(|(name, dtype)| (name.clone(), dtype.0.clone())) + .collect(); + + let compression = parse_file_compression_type(Some(value.file_compression_type.clone()))?; + + let sort_order: Vec> = value + .file_sort_order + .iter() + .map(|inner| { + inner + .iter() + .map(|sort_expr| sort_expr.sort.clone()) + .collect() + }) + .collect(); + + // Explicit struct initialization to catch upstream changes + let mut options = CsvReadOptions { + has_header: value.has_header, + delimiter: value.delimiter, + quote: value.quote, + terminator: value.terminator, + escape: value.escape, + comment: value.comment, + newlines_in_values: value.newlines_in_values, + schema: None, // Will be set separately due to lifetime constraints + schema_infer_max_records: value.schema_infer_max_records, + file_extension: value.file_extension.as_str(), + table_partition_cols: partition_cols, + file_compression_type: compression, + file_sort_order: sort_order, + null_regex: value.null_regex.clone(), + truncated_rows: value.truncated_rows, + }; + + // Set schema separately to handle the lifetime + options.schema = value.schema.as_ref().map(|s| &s.0); + + Ok(options) + } +} + +pub(crate) fn init_module(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + + Ok(()) +}
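
As a quick illustration of the Python-side validation introduced above (a sketch
against the API in this diff; the multi-character delimiter is deliberately
invalid):

    from datafusion import CsvReadOptions

    # Single-character options are validated when the CsvReadOptions object is
    # constructed, so a bad delimiter fails before any file is opened.
    try:
        CsvReadOptions(delimiter="wrong")
    except ValueError as err:
        print(err)  # delimiter must be a single character

    # Valid values are handed to the Rust-side CsvReadOptions via to_inner().
    options = CsvReadOptions(delimiter="|", quote="'")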