diff --git a/src/pathpyG/io/pandas.py b/src/pathpyG/io/pandas.py
index 643769a4..2eb9c667 100644
--- a/src/pathpyG/io/pandas.py
+++ b/src/pathpyG/io/pandas.py
@@ -1,4 +1,4 @@
-""""Functions to read and write graphs from and to pandas DataFrames."""
+"""Functions to read and write graphs from and to pandas DataFrames."""
 
 import ast
 import logging
@@ -25,9 +25,10 @@ def _parse_timestamp(df: pd.DataFrame, timestamp_format: str = "%Y-%m-%d %H:%M:%S", time_rescale: int = 1) -> None:
-    """Parse time stamps in a DataFrame.
+    """Parse time stamps in a [DataFrame][pandas.DataFrame].
 
-    Parses the time stamps in the DataFrame and rescales using the given time rescale factor.
+    Parses the time stamps in the [DataFrame][pandas.DataFrame] in place, converting them to seconds since the earliest timestamp.
+    The time stamps are then rescaled using the given time rescale factor.
 
     The time stamps are expected to be in a column named `t`. If the column is
     of type `object`, it is assumed to contain time stamps in the specified format.
@@ -38,17 +39,17 @@ def _parse_timestamp(df: pd.DataFrame, timestamp_format: str = "%Y-%m-%d %H:%M:%
     """
     # optionally parse time stamps
    if df["t"].dtype == "object" and isinstance(df["t"].values[0], str):
-        # convert time stamps to seconds since epoch
         df["t"] = pd.to_datetime(df["t"], format=timestamp_format)
-        # rescale time stamps
-        df["t"] = df["t"].astype("int64") // time_rescale
-        df["t"] = df["t"] - df["t"].min()  # rescale to start at 0
+        # convert time stamps to seconds since the first timestamp (start at 0)
+        df["t"] = (df["t"] - df["t"].min()).dt.total_seconds().astype(int)
+        # apply the optional rescaling factor
+        df["t"] = df["t"] // time_rescale
     elif df["t"].dtype == "int64" or df["t"].dtype == "float64":
         # rescale time stamps
         df["t"] = df["t"] // time_rescale
     elif pd.api.types.is_datetime64_any_dtype(df["t"]):
-        df["t"] = df["t"].astype("int64") // time_rescale
-        df["t"] = df["t"] - df["t"].min()  # rescale to start at 0
+        df["t"] = (df["t"] - df["t"].min()).dt.total_seconds().astype(int)  # seconds since the first timestamp
+        df["t"] = df["t"] // time_rescale  # apply the optional rescaling factor
     else:
         raise ValueError(
             "Column `t` must be of type `object`, `int64`, `float64`, or a datetime type. "
@@ -398,9 +399,9 @@ def df_to_temporal_graph(
 def graph_to_df(graph: Graph, node_indices: Optional[bool] = False) -> pd.DataFrame:
     """Return a [pandas.DataFrame][] for a given [graph][pathpyG.Graph].
 
-    Contains all edges including edge attributes. Node and network-level 
-    attributes are not included. To facilitate the import into network analysis 
-    tools that only support integer node identifiers, node uids can be replaced 
+    Contains all edges including edge attributes. Node and network-level
+    attributes are not included. To facilitate the import into network analysis
+    tools that only support integer node identifiers, node uids can be replaced
     by a consecutive, zero-based index.
 
     Args:
@@ -422,7 +423,7 @@ def graph_to_df(graph: Graph, node_indices: Optional[bool] = False) -> pd.DataFr
     else:
         vs = graph.mapping.to_ids(to_numpy(graph.data.edge_index[0]))
         ws = graph.mapping.to_ids(to_numpy(graph.data.edge_index[1]))
-    df = pd.DataFrame({**{"v": vs, "w": ws}, **{a: graph.data[a].tolist() for a in graph.edge_attrs()}})
+    df = pd.DataFrame({**{"v": vs, "w": ws}, **{a: to_numpy(graph.data[a]) for a in graph.edge_attrs()}})
     return df
 
 
@@ -430,9 +431,8 @@ def graph_to_df(graph: Graph, node_indices: Optional[bool] = False) -> pd.DataFr
 def temporal_graph_to_df(graph: TemporalGraph, node_indices: Optional[bool] = False) -> pd.DataFrame:
     """Return a [pandas.DataFrame][] for a given [temporal graph][pathpyG.TemporalGraph].
 
-    Contains all edges including edge attributes. Node and network-level 
-    attributes are not included. To facilitate the import into network analysis 
-    tools that only support integer node identifiers, node uids can be replaced 
+    Contains all edges including edge attributes. Node and network-level
+    attributes are not included. To facilitate the import into network analysis
+    tools that only support integer node identifiers, node uids can be replaced
     by a consecutive, zero-based index.
 
-    facilitate the import into network analysis tools that only support integer
diff --git a/src/pathpyG/io/sql.py b/src/pathpyG/io/sql.py
new file mode 100644
index 00000000..de9293c0
--- /dev/null
+++ b/src/pathpyG/io/sql.py
@@ -0,0 +1,115 @@
+"""Module for database I/O operations."""
+
+import logging
+import sqlite3
+from pathlib import Path
+
+import pandas as pd
+
+from pathpyG.core.graph import Graph
+from pathpyG.core.temporal_graph import TemporalGraph
+from pathpyG.io.pandas import add_node_attributes, df_to_graph, df_to_temporal_graph, graph_to_df, temporal_graph_to_df
+
+logger = logging.getLogger("root")
+
+
+def read_sql(
+    db_path: str,
+    edge_table: str = "edges",
+    node_table: str | None = None,
+    source_name: str = "source",
+    target_name: str = "target",
+    time_name: str | None = None,
+    node_name: str = "node_id",
+    timestamp_format: str = "%Y-%m-%d %H:%M:%S",
+    time_rescale: int = 1,
+) -> Graph:
+    """Read a graph from an SQL database file.
+
+    The function reads edge and node data from the specified tables in the database
+    and constructs a [graph][pathpyG.Graph], or a [temporal graph][pathpyG.TemporalGraph]
+    if a time column is provided.
+
+    By default, it looks for an "edges" table and creates a graph whose edge
+    attributes correspond to the remaining columns of that table (all columns
+    except source, target, and time).
+    Additionally, if a node table is specified, node attributes are read from that table
+    and added to the graph.
+
+    Args:
+        db_path: Path to the SQL database file.
+        edge_table: Name of the table containing edges and optional edge attributes. Defaults to "edges".
+        node_table: Name of the table containing nodes and optional node attributes. If None, nodes are inferred from edges. Defaults to None.
+        source_name: Name of the column representing source nodes in the edge table. Defaults to "source".
+        target_name: Name of the column representing target nodes in the edge table. Defaults to "target".
+        time_name: Name of the column representing timestamps in the edge table. If None, edges are considered static. Defaults to None.
+        node_name: Name of the column representing node IDs in the node table. Defaults to "node_id".
+        timestamp_format: Format of the timestamps if time_name is provided. Defaults to "%Y-%m-%d %H:%M:%S".
+        time_rescale: Factor to rescale time values (e.g., to convert microseconds to seconds). Defaults to 1.
+
+    Returns:
+        Graph: The [graph][pathpyG.Graph] read from the database, or the [temporal graph][pathpyG.TemporalGraph] if time_name is provided.
+    """
+    conn = sqlite3.connect(db_path)
+
+    # Read edges; table names cannot be bound as SQL parameters, so they are interpolated directly
+    edge_query = f"SELECT * FROM {edge_table}"
+    edges_df = pd.read_sql_query(edge_query, conn).rename(columns={source_name: "v", target_name: "w"})
+
+    # Create graph
+    g: Graph
+    if time_name and time_name in edges_df.columns:
+        edges_df = edges_df.rename(columns={time_name: "t"})
+        g = df_to_temporal_graph(df=edges_df, timestamp_format=timestamp_format, time_rescale=time_rescale)
+    else:
+        if time_name:
+            logger.warning(f"Column '{time_name}' not found in edge table. Reading as static graph.")
+        g = df_to_graph(df=edges_df)
+
+    # Read and add node attributes if node_table is provided
+    if node_table:
+        node_query = f"SELECT * FROM {node_table}"
+        nodes_df = pd.read_sql_query(node_query, conn).rename(columns={node_name: "v"})
+        add_node_attributes(df=nodes_df, g=g)
+
+    conn.close()
+
+    return g
+
+
+def write_sql(g: Graph, db_path: str | Path, edge_table: str = "edges", node_table: str = "nodes") -> None:
+    """Write a graph to an SQL database file.
+
+    The function writes edge and node data from a [graph][pathpyG.Graph] or a [temporal graph][pathpyG.TemporalGraph]
+    to the specified tables in the database. By default, it writes to the "edges" and "nodes" tables,
+    storing edges and edge attributes, as well as nodes and node attributes, respectively.
+    For a [temporal graph][pathpyG.TemporalGraph], the time attribute is also included in the edges table.
+
+    Args:
+        g: The [graph][pathpyG.Graph] to write to the database.
+        db_path: Path to the SQL database file.
+        edge_table: Name of the table to store edges and edge attributes. Defaults to "edges".
+        node_table: Name of the table to store nodes and node attributes. Defaults to "nodes".
+ """ + if isinstance(db_path, str): + db_path = Path(db_path) + if db_path.exists(): + logger.warning(f"Database file {db_path} already exists and will be overwritten.") + + conn = sqlite3.connect(db_path) + + # Write edges + if isinstance(g, TemporalGraph): + edges_df = temporal_graph_to_df(g) + else: + edges_df = graph_to_df(g) + edges_df.to_sql(edge_table, conn, if_exists="replace", index=False) + + # Write nodes + nodes_df = pd.DataFrame({"v": list(g.nodes)}) + for attr_name in g.node_attrs(): + nodes_df[attr_name] = g.data[attr_name] + nodes_df.to_sql(node_table, conn, if_exists="replace", index=False) + + conn.commit() + conn.close() diff --git a/tests/io/test_pandas.py b/tests/io/test_pandas.py index 860da64f..a5693b63 100644 --- a/tests/io/test_pandas.py +++ b/tests/io/test_pandas.py @@ -93,9 +93,9 @@ def test_parse_timestamp_datetime64(): def test_parse_timestamp_rescale(): df = pd.DataFrame({"t": ["2023-01-01 12:00:00", "2023-01-01 13:00:00"]}) - _parse_timestamp(df, time_rescale=10**9) # convert to seconds + _parse_timestamp(df, time_rescale=60) # convert to minutes # Should be seconds since epoch - assert np.all(df["t"].diff().dropna() == 3600) + assert np.all(df["t"].diff().dropna() == 60) def test_parse_timestamp_invalid_type(): diff --git a/tests/io/test_sql.py b/tests/io/test_sql.py new file mode 100644 index 00000000..70e14201 --- /dev/null +++ b/tests/io/test_sql.py @@ -0,0 +1,219 @@ +"""Unit tests for sql.py module.""" + +import sqlite3 +from pathlib import Path +from unittest.mock import patch + +import pandas as pd + +from pathpyG.core.graph import Graph +from pathpyG.io.sql import read_sql, write_sql + + +class TestReadSql: + """Tests for read_sql function.""" + + def test_read_sql_static_graph(self, tmp_path): + """Test reading a static graph from SQL database.""" + db_path = tmp_path / "test.db" + conn = sqlite3.connect(db_path) + + # Create test edge table + edges_data = {"source": ["A", "B", "C"], "target": ["B", "C", "A"], "weight": [1.0, 2.0, 3.0]} + edges_df = pd.DataFrame(edges_data) + edges_df.to_sql("edges", conn, if_exists="replace", index=False) + conn.close() + + # Read graph + g = read_sql(str(db_path)) + + assert isinstance(g, Graph) + assert len(g.nodes) == 3 + assert len(g.edges) == 3 + assert g.data["edge_weight"].tolist() == [1.0, 2.0, 3.0] + + def test_read_sql_temporal_graph(self, tmp_path): + """Test reading a temporal graph from SQL database.""" + db_path = tmp_path / "test.db" + conn = sqlite3.connect(db_path) + + # Create test edge table with timestamps + edges_data = { + "source": ["A", "B", "C"], + "target": ["B", "C", "A"], + "timestamp": ["2023-01-01 10:00:00", "2023-01-01 10:00:10", "2023-01-01 10:00:25"], + } + edges_df = pd.DataFrame(edges_data) + edges_df.to_sql("edges", conn, if_exists="replace", index=False) + conn.close() + + # Read temporal graph + g = read_sql(str(db_path), time_name="timestamp") + + assert len(g.nodes) == 3 + assert len(g.edges) == 3 + assert "time" in g.data + assert g.data["time"].tolist() == [0, 10, 25] + + def test_read_sql_with_node_table(self, tmp_path): + """Test reading graph with node attributes from separate node table.""" + db_path = tmp_path / "test.db" + conn = sqlite3.connect(db_path) + + # Create edge table + edges_data = {"source": ["A", "B"], "target": ["B", "C"]} + pd.DataFrame(edges_data).to_sql("edges", conn, if_exists="replace", index=False) + + # Create node table with attributes + nodes_data = { + "node_id": ["A", "B", "C"], + "label": ["Node A", "Node B", "Node C"], + "color": ["red", 
"blue", "green"], + } + pd.DataFrame(nodes_data).to_sql("nodes", conn, if_exists="replace", index=False) + conn.close() + + # Read graph + g = read_sql(str(db_path), node_table="nodes") + + assert g.nodes == ["A", "B", "C"] + assert g.data["node_label"].tolist() == ["Node A", "Node B", "Node C"] + assert g.data["node_color"].tolist() == ["red", "blue", "green"] + assert g.data.edge_index.shape[1] == 2 + + def test_read_sql_custom_column_names(self, tmp_path): + """Test reading graph with custom column names.""" + db_path = tmp_path / "test.db" + conn = sqlite3.connect(db_path) + + # Create edge table with custom column names + edges_data = {"from_node": ["X", "Y"], "to_node": ["Y", "Z"]} + pd.DataFrame(edges_data).to_sql("my_edges", conn, if_exists="replace", index=False) + conn.close() + + # Read graph with custom names + g = read_sql(str(db_path), edge_table="my_edges", source_name="from_node", target_name="to_node") + + assert len(g.nodes) == 3 + assert len(g.edges) == 2 + + def test_read_sql_missing_time_column(self, tmp_path): + """Test reading graph when time_name column doesn't exist.""" + db_path = tmp_path / "test.db" + conn = sqlite3.connect(db_path) + + edges_data = {"source": ["A", "B"], "target": ["B", "C"]} + pd.DataFrame(edges_data).to_sql("edges", conn, if_exists="replace", index=False) + conn.close() + + # Should return static graph even if time_name is specified but not present and log a warning + with patch("pathpyG.io.sql.logger") as mock_logger: + g = read_sql(str(db_path), time_name="nonexistent_time") + mock_logger.warning.assert_called_once() + + assert isinstance(g, Graph) + assert len(g.edges) == 2 + + +class TestWriteSql: + """Tests for write_sql function.""" + + def test_write_sql_creates_database(self, tmp_path): + """Test that write_sql creates a database file.""" + db_path = tmp_path / "output.db" + + # Create a simple graph + g = Graph.from_edge_list([("A", "B"), ("B", "C")]) + + # Write to database + write_sql(g, db_path) + + assert db_path.exists() + + def test_write_sql_creates_tables(self, tmp_path): + """Test that write_sql creates edge and node tables.""" + db_path = tmp_path / "output.db" + + g = Graph.from_edge_list([("A", "B"), ("B", "C")]) + write_sql(g, db_path) + + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + + # Check tables exist + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + + assert "edges" in tables + assert "nodes" in tables + conn.close() + + def test_write_sql_custom_table_names(self, tmp_path): + """Test write_sql with custom table names.""" + db_path = tmp_path / "output.db" + + g = Graph.from_edge_list([("A", "B")]) + + write_sql(g, db_path, edge_table="my_edges", node_table="my_nodes") + + conn = sqlite3.connect(db_path) + cursor = conn.cursor() + cursor.execute("SELECT name FROM sqlite_master WHERE type='table'") + tables = [row[0] for row in cursor.fetchall()] + + assert "my_edges" in tables + assert "my_nodes" in tables + conn.close() + + def test_write_sql_overwrite_warning(self, tmp_path): + """Test that write_sql logs warning when overwriting existing database.""" + db_path = tmp_path / "output.db" + + g = Graph.from_edge_list([("A", "B")]) + + # Write once + write_sql(g, db_path) + + # Write again and check for warning + with patch("pathpyG.io.sql.logger") as mock_logger: + write_sql(g, db_path) + mock_logger.warning.assert_called_once() + + def test_write_read_roundtrip(self, tmp_path): + """Test that graph can be written and read back.""" 
+ db_path = tmp_path / "roundtrip.db" + + # Create and write graph + g_original = Graph.from_edge_list([("A", "B"), ("B", "C")]) + g_original.data["edge_weight"] = [1.0, 2.0] + g_original.data["node_label"] = ["Node A", "Node B", "Node C"] + + write_sql(g_original, db_path) + + # Read back + g_read = read_sql(str(db_path), node_table="nodes") + + assert len(g_read.nodes) == len(g_original.nodes) + assert len(g_read.edges) == len(g_original.edges) + assert g_read.data["edge_weight"].tolist() == g_original.data["edge_weight"] + assert g_read.data["node_label"].tolist() == g_original.data["node_label"] + + +class TestReadWriteIntegration: + """Integration tests for read and write operations.""" + + def test_connection_closed_after_read(self, tmp_path): + """Test that database connection is properly closed after read.""" + db_path = tmp_path / "test.db" + conn = sqlite3.connect(db_path) + + edges_data = {"source": ["A"], "target": ["B"]} + pd.DataFrame(edges_data).to_sql("edges", conn, if_exists="replace", index=False) + conn.close() + + # Read graph + read_sql(str(db_path)) + + # Should be able to delete file if connection is closed + Path(db_path).unlink() + assert not Path(db_path).exists()
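
For reviewers, a minimal end-to-end sketch of the new SQL I/O API, assembled from the tests above. It uses only what this diff introduces (`read_sql`, `write_sql`, and the default table and column names); the file names `example.db` and `example_copy.db` are illustrative:

```python
import sqlite3

import pandas as pd

from pathpyG.io.sql import read_sql, write_sql

# Build a small database with the default table and column names:
# an "edges" table with source/target columns plus a "weight" attribute,
# and a "nodes" table keyed by "node_id" with a "color" attribute.
conn = sqlite3.connect("example.db")
pd.DataFrame(
    {"source": ["A", "B", "C"], "target": ["B", "C", "A"], "weight": [1.0, 2.0, 3.0]}
).to_sql("edges", conn, if_exists="replace", index=False)
pd.DataFrame(
    {"node_id": ["A", "B", "C"], "color": ["red", "blue", "green"]}
).to_sql("nodes", conn, if_exists="replace", index=False)
conn.close()

# Static graph: the extra edge column becomes the edge attribute
# "edge_weight", the node-table column becomes "node_color".
g = read_sql("example.db", node_table="nodes")
print(g.data["edge_weight"], g.data["node_color"])

# Write the graph back out; this replaces the "edges" and "nodes"
# tables in the target database file.
write_sql(g, "example_copy.db")

# For event data, passing time_name yields a temporal graph whose "time"
# values count seconds since the first event; time_rescale=60 would
# rescale them to minutes (hypothetical file and column names):
# tg = read_sql("events.db", time_name="timestamp", time_rescale=60)
```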