diff --git a/Cargo.toml b/Cargo.toml index a2eef74..9a2ffd6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -20,6 +20,7 @@ ndarray-stats = "0.5" thiserror = "1.0" serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" +rand = "0.8" [dev-dependencies] approx = "0.5" diff --git a/python/eo_processor/__init__.py b/python/eo_processor/__init__.py index 9d9fc3e..14cbe89 100644 --- a/python/eo_processor/__init__.py +++ b/python/eo_processor/__init__.py @@ -50,6 +50,7 @@ detect_breakpoints as _detect_breakpoints, complex_classification as _complex_classification, random_forest_predict as _random_forest_predict, + random_forest_train as _random_forest_train, ) from ._core import texture_entropy as _texture_entropy import logging @@ -132,9 +133,19 @@ "complex_classification", "texture_entropy", "random_forest_predict", + "random_forest_train", ] +def random_forest_train(features, labels, n_estimators=100, min_samples_split=2, max_depth=None, max_features=None): + """ + Train a random forest model. + """ + if max_features is None: + max_features = int(np.sqrt(features.shape[1])) + return _random_forest_train(features, labels, n_estimators, min_samples_split, max_depth, max_features) + + def random_forest_predict(model_json, features): """ Predict using a random forest model. 
diff --git a/src/classification.rs b/src/classification.rs index 628454e..5c61421 100644 --- a/src/classification.rs +++ b/src/classification.rs @@ -1,7 +1,11 @@ +use rand::seq::SliceRandom; +use rand::thread_rng; use rayon::prelude::*; use serde::{Deserialize, Serialize}; use std::collections::HashMap; +type SplitData = (Vec<Vec<f64>>, Vec<f64>, Vec<Vec<f64>>, Vec<f64>); + #[derive(Serialize, Deserialize, Debug)] pub enum DecisionNode { Leaf { @@ -17,12 +21,211 @@ pub enum DecisionNode { #[derive(Serialize, Deserialize, Debug)] pub struct DecisionTree { - root: DecisionNode, + root: Option<DecisionNode>, + max_depth: Option<i32>, + min_samples_split: i32, } impl DecisionTree { + pub fn new(max_depth: Option<i32>, min_samples_split: i32) -> Self { + DecisionTree { + root: None, + max_depth, + min_samples_split, + } + } + + pub fn fit( + &mut self, + features: &[Vec<f64>], + labels: &[f64], + n_features_to_consider: usize, + ) { + self.root = Some(self.build_tree(features, labels, 0, n_features_to_consider)); + } + + fn build_tree( + &self, + features: &[Vec<f64>], + labels: &[f64], + depth: i32, + n_features_to_consider: usize, + ) -> DecisionNode { + if let Some(max_depth) = self.max_depth { + if depth >= max_depth { + return DecisionNode::Leaf { + class_prediction: self.calculate_leaf_value(labels), + }; + } + } + + if labels.len() < self.min_samples_split as usize { + return DecisionNode::Leaf { + class_prediction: self.calculate_leaf_value(labels), + }; + } + + if labels.iter().all(|&l| l == labels[0]) { + return DecisionNode::Leaf { + class_prediction: labels[0], + }; + } + + let best_split = self.find_best_split(features, labels, n_features_to_consider); + + if let Some((feature_index, threshold)) = best_split { + let (left_features, left_labels, right_features, right_labels) = + self.split_data(features, labels, feature_index, threshold); + + if left_labels.is_empty() || right_labels.is_empty() { + return DecisionNode::Leaf { + class_prediction: self.calculate_leaf_value(labels), + }; + } + + let left_child = 
self.build_tree( + &left_features, + &left_labels, + depth + 1, + n_features_to_consider, + ); + let right_child = self.build_tree( + &right_features, + &right_labels, + depth + 1, + n_features_to_consider, + ); + + DecisionNode::Node { + feature_index, + threshold, + left: Box::new(left_child), + right: Box::new(right_child), + } + } else { + DecisionNode::Leaf { + class_prediction: self.calculate_leaf_value(labels), + } + } + } + + fn find_best_split( + &self, + features: &[Vec<f64>], + labels: &[f64], + n_features_to_consider: usize, + ) -> Option<(usize, f64)> { + let mut best_gain = -1.0; + let mut best_split: Option<(usize, f64)> = None; + let n_features = features[0].len(); + let current_gini = self.gini_impurity(labels); + + let mut feature_indices: Vec<usize> = (0..n_features).collect(); + feature_indices.shuffle(&mut thread_rng()); + let feature_subset = &feature_indices[..n_features_to_consider.min(n_features)]; + + for &feature_index in feature_subset { + let mut unique_thresholds = features + .iter() + .map(|row| row[feature_index]) + .collect::<Vec<f64>>(); + unique_thresholds.sort_by(|a, b| a.partial_cmp(b).unwrap()); + unique_thresholds.dedup(); + + for &threshold in &unique_thresholds { + let (left_labels, right_labels) = + self.split_labels(labels, features, feature_index, threshold); + + if left_labels.is_empty() || right_labels.is_empty() { + continue; + } + + let p_left = left_labels.len() as f64 / labels.len() as f64; + let p_right = right_labels.len() as f64 / labels.len() as f64; + let gain = current_gini + - p_left * self.gini_impurity(&left_labels) + - p_right * self.gini_impurity(&right_labels); + + if gain > best_gain { + best_gain = gain; + best_split = Some((feature_index, threshold)); + } + } + } + best_split + } + + fn split_data( + &self, + features: &[Vec<f64>], + labels: &[f64], + feature_index: usize, + threshold: f64, + ) -> SplitData { + let mut left_features = Vec::new(); + let mut left_labels = Vec::new(); + let mut right_features = Vec::new(); + let 
mut right_labels = Vec::new(); + + for (i, row) in features.iter().enumerate() { + if row[feature_index] <= threshold { + left_features.push(row.clone()); + left_labels.push(labels[i]); + } else { + right_features.push(row.clone()); + right_labels.push(labels[i]); + } + } + (left_features, left_labels, right_features, right_labels) + } + + fn split_labels( + &self, + labels: &[f64], + features: &[Vec<f64>], + feature_index: usize, + threshold: f64, + ) -> (Vec<f64>, Vec<f64>) { + let mut left_labels = Vec::new(); + let mut right_labels = Vec::new(); + for (i, label) in labels.iter().enumerate() { + if features[i][feature_index] <= threshold { + left_labels.push(*label); + } else { + right_labels.push(*label); + } + } + (left_labels, right_labels) + } + + fn gini_impurity(&self, labels: &[f64]) -> f64 { + let mut counts = HashMap::new(); + for &label in labels { + *counts.entry(label as i64).or_insert(0) += 1; + } + + let mut impurity = 1.0; + for &count in counts.values() { + let prob = count as f64 / labels.len() as f64; + impurity -= prob.powi(2); + } + impurity + } + + fn calculate_leaf_value(&self, labels: &[f64]) -> f64 { + let mut counts = HashMap::new(); + for &label in labels { + *counts.entry(label as i64).or_insert(0) += 1; + } + counts + .into_iter() + .max_by_key(|&(_, count)| count) + .map(|(val, _)| val as f64) + .unwrap_or(0.0) + } + pub fn predict(&self, features: &[f64]) -> f64 { - let mut current_node = &self.root; + let mut current_node = self.root.as_ref().expect("Tree is not trained yet."); loop { match current_node { DecisionNode::Leaf { class_prediction } => { @@ -45,12 +248,66 @@ impl DecisionTree { } } +use rand::Rng; + #[derive(Serialize, Deserialize, Debug)] pub struct RandomForest { trees: Vec<DecisionTree>, + n_estimators: i32, + max_depth: Option<i32>, + min_samples_split: i32, + max_features: Option<usize>, } impl RandomForest { + pub fn new( + n_estimators: i32, + max_depth: Option<i32>, + min_samples_split: i32, + max_features: Option<usize>, + ) -> Self { + RandomForest { + trees: 
Vec::with_capacity(n_estimators as usize), + n_estimators, + max_depth, + min_samples_split, + max_features, + } + } + + pub fn fit(&mut self, features: &[Vec<f64>], labels: &[f64]) { + let n_samples = features.len(); + let n_features = features[0].len(); + let n_features_to_consider = self + .max_features + .unwrap_or_else(|| (n_features as f64).sqrt() as usize); + + self.trees = (0..self.n_estimators) + .into_par_iter() + .map(|_| { + let mut rng = rand::thread_rng(); + let sample_indices: Vec<usize> = (0..n_samples) + .map(|_| rng.gen_range(0..n_samples)) + .collect(); + + let bootstrapped_features: Vec<Vec<f64>> = sample_indices + .iter() + .map(|&i| features[i].clone()) + .collect(); + let bootstrapped_labels: Vec<f64> = + sample_indices.iter().map(|&i| labels[i]).collect(); + + let mut tree = DecisionTree::new(self.max_depth, self.min_samples_split); + tree.fit( + &bootstrapped_features, + &bootstrapped_labels, + n_features_to_consider, + ); + tree + }) + .collect(); + } + pub fn predict(&self, features: &[f64]) -> Option<f64> { if self.trees.is_empty() { return None; @@ -74,17 +331,48 @@ impl RandomForest { } } -use numpy::{PyReadonlyArray2, PyArray1}; +use numpy::{PyArray1, PyReadonlyArray1, PyReadonlyArray2}; use pyo3::prelude::*; +#[pyfunction] +pub fn random_forest_train( + _py: Python, + features: PyReadonlyArray2<f64>, + labels: PyReadonlyArray1<f64>, + n_estimators: i32, + min_samples_split: i32, + max_depth: Option<i32>, + max_features: Option<usize>, +) -> PyResult<String> { + let features_array = features.as_array(); + let labels_array = labels.as_array(); + + let features_vec: Vec<Vec<f64>> = features_array + .outer_iter() + .map(|row| row.to_vec()) + .collect(); + let labels_vec: Vec<f64> = labels_array.to_vec(); + + let mut forest = RandomForest::new(n_estimators, max_depth, min_samples_split, max_features); + forest.fit(&features_vec, &labels_vec); + + serde_json::to_string(&forest).map_err(|e| { + PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Failed to serialize model: {}", e)) + }) +} + #[pyfunction] pub fn random_forest_predict<'py>( py: 
Python<'py>, model_json: &str, features: PyReadonlyArray2<f64>, ) -> PyResult<&'py PyArray1<f64>> { - let forest: RandomForest = serde_json::from_str(model_json) - .map_err(|e| PyErr::new::<pyo3::exceptions::PyValueError, _>(format!("Failed to deserialize model: {}", e)))?; + let forest: RandomForest = serde_json::from_str(model_json).map_err(|e| { + PyErr::new::<pyo3::exceptions::PyValueError, _>(format!( + "Failed to deserialize model: {}", + e + )) + })?; let features_array = features.as_array(); let n_samples = features_array.shape()[0]; diff --git a/src/lib.rs b/src/lib.rs index 9b068fb..8e35638 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -5,10 +5,10 @@ pub mod morphology; pub mod processes; pub mod spatial; pub mod temporal; +pub mod texture; pub mod trends; pub mod workflows; pub mod zonal; -pub mod texture; use pyo3::exceptions::PyValueError; use pyo3::prelude::*; @@ -103,6 +103,7 @@ fn _core(_py: Python, m: &PyModule) -> PyResult<()> { // --- Classification --- m.add_function(wrap_pyfunction!(classification::random_forest_predict, m)?)?; + m.add_function(wrap_pyfunction!(classification::random_forest_train, m)?)?; Ok(()) } diff --git a/src/workflows.rs b/src/workflows.rs index a0a14e6..9e71d5c 100644 --- a/src/workflows.rs +++ b/src/workflows.rs @@ -141,18 +141,16 @@ pub fn complex_classification( let mut out = ndarray::ArrayD::<f64>::zeros(blue_arr.raw_dim()); - out.indexed_iter_mut() - .par_bridge() - .for_each(|(idx, res)| { - let b = blue_arr[&idx]; - let g = green_arr[&idx]; - let r = red_arr[&idx]; - let n = nir_arr[&idx]; - let s1 = swir1_arr[&idx]; - let s2 = swir2_arr[&idx]; - let t = temp_arr[&idx]; - *res = classify_pixel(b, g, r, n, s1, s2, t); - }); + out.indexed_iter_mut().par_bridge().for_each(|(idx, res)| { + let b = blue_arr[&idx]; + let g = green_arr[&idx]; + let r = red_arr[&idx]; + let n = nir_arr[&idx]; + let s1 = swir1_arr[&idx]; + let s2 = swir2_arr[&idx]; + let t = temp_arr[&idx]; + *res = classify_pixel(b, g, r, n, s1, s2, t); + }); Ok(out.into_pyarray(py).to_owned()) } @@ -216,4 +214,3 @@ fn classify_pixel(b: f64, 
g: f64, r: f64, n: f64, s1: f64, s2: f64, t: f64) -> u UNCLASSIFIED } - diff --git a/tests/test_classification.py b/tests/test_classification.py index 21e810c..2ee2476 100644 --- a/tests/test_classification.py +++ b/tests/test_classification.py @@ -1,7 +1,9 @@ import numpy as np from sklearn.ensemble import RandomForestClassifier from sklearn.datasets import make_classification -from eo_processor import random_forest_predict +from sklearn.model_selection import train_test_split +from sklearn.metrics import accuracy_score +from eo_processor import random_forest_predict, random_forest_train from .utils import sklearn_to_json def test_random_forest_predict(): @@ -30,3 +32,36 @@ def test_random_forest_predict(): sklearn_predictions = clf.predict(X) assert np.array_equal(predictions, sklearn_predictions) + +def test_random_forest_train_and_predict(): + """Test the full train-and-predict cycle.""" + # Generate synthetic data + X, y = make_classification( + n_samples=200, + n_features=15, + n_informative=10, + n_redundant=2, + n_classes=2, + random_state=42, + shuffle=True, + ) + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert labels to float64 + y_train = y_train.astype(np.float64) + + # Train the model using the Rust implementation + model_json = random_forest_train( + X_train, + y_train, + n_estimators=100, + min_samples_split=3, + max_depth=10, + ) + + # Perform inference + predictions = random_forest_predict(model_json, X_test) + + # Check accuracy + accuracy = accuracy_score(y_test, predictions) + assert accuracy >= 0.75, f"Accuracy of {accuracy:.2f} is below the threshold of 0.75" diff --git a/tests/utils.py b/tests/utils.py index f478822..b30fab3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -23,6 +23,16 @@ def build_tree(node_id): } } - trees.append({"root": build_tree(0)}) + trees.append({ + "root": build_tree(0), + "max_depth": model.max_depth, + "min_samples_split": model.min_samples_split, + }) - 
return json.dumps({"trees": trees}) + return json.dumps({ + "trees": trees, + "n_estimators": model.n_estimators, + "max_depth": model.max_depth, + "min_samples_split": model.min_samples_split, + "max_features": model.max_features if isinstance(model.max_features, int) else None, + })