diff --git a/Cargo.toml b/Cargo.toml index 012dc2a..a022415 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -21,6 +21,7 @@ thiserror = "1.0" serde = { version = "1.0.228", features = ["derive"] } serde_json = "1.0.145" rand = "0.8" +nalgebra = "0.32" [dev-dependencies] approx = "0.5" diff --git a/python/eo_processor/__init__.py b/python/eo_processor/__init__.py index 29aa3e2..249a562 100644 --- a/python/eo_processor/__init__.py +++ b/python/eo_processor/__init__.py @@ -165,13 +165,25 @@ def random_forest_predict(model_json, features): def bfast_monitor( - stack, dates, history_start_date, monitor_start_date, level + stack, + dates, + history_start_date, + monitor_start_date, + order=1, + h=0.25, + alpha=0.05, ): """ - Scaffold for the BFAST Monitor workflow. + BFAST Monitor workflow for change detection. """ return _bfast_monitor( - stack, dates, history_start_date, monitor_start_date, level + stack, + dates, + history_start_date, + monitor_start_date, + order, + h, + alpha, ) diff --git a/python/eo_processor/__init__.pyi b/python/eo_processor/__init__.pyi index 41687e4..26640ed 100644 --- a/python/eo_processor/__init__.pyi +++ b/python/eo_processor/__init__.pyi @@ -161,4 +161,15 @@ def binary_closing( input: NDArray[np.uint8], kernel_size: int = ... ) -> NDArray[np.uint8]: ... +# Workflows +def bfast_monitor( + stack: NumericArray, + dates: Sequence[int], + history_start_date: int, + monitor_start_date: int, + order: int = ..., + h: float = ..., + alpha: float = ..., +) -> NDArray[np.float64]: ... 
+ # Raises ValueError if p < 1.0 diff --git a/src/lib.rs b/src/lib.rs index 1a1e309..9184b3f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -20,6 +20,8 @@ pub enum CoreError { InvalidArgument(String), #[error("Computation error: {0}")] ComputationError(String), + #[error("Not enough data: {0}")] + NotEnoughData(String), } impl From<CoreError> for PyErr { @@ -27,6 +29,7 @@ impl From<CoreError> for PyErr { match err { CoreError::InvalidArgument(msg) => PyValueError::new_err(msg), CoreError::ComputationError(msg) => PyValueError::new_err(msg), + CoreError::NotEnoughData(msg) => PyValueError::new_err(msg), } } } diff --git a/src/workflows.rs b/src/workflows.rs index e46ea61..251b87d 100644 --- a/src/workflows.rs +++ b/src/workflows.rs @@ -1,70 +1,172 @@ use crate::CoreError; +use nalgebra::{DMatrix, DVector}; use ndarray::{Axis, IxDyn, Zip}; use numpy::{IntoPyArray, PyArrayDyn, PyReadonlyArrayDyn}; use pyo3::prelude::*; use rayon::prelude::*; +const TWO_PI: f64 = 2.0 * std::f64::consts::PI; // --- 1. BFAST Monitor Workflow --- -// Placeholder struct for model parameters +/// Represents the fitted harmonic model. struct HarmonicModel { - mean: f64, + coefficients: DVector<f64>, + sigma: f64, } -// Placeholder for fitting a harmonic model to the stable history period. -// In a real implementation, this would involve solving for harmonic coefficients. -fn fit_harmonic_model(y: &[f64]) -> HarmonicModel { - if y.is_empty() { - return HarmonicModel { mean: 0.0 }; +/// Constructs the design matrix for a harmonic model. +/// +/// # Arguments +/// +/// * `dates` - A slice of fractional years. +/// * `order` - The order of the harmonic model (e.g., 1 for one sine/cosine pair). +/// +/// # Returns +/// +/// A 2D array representing the design matrix `X`. 
+fn build_design_matrix(dates: &[f64], order: usize) -> DMatrix<f64> { + let n = dates.len(); + let num_coeffs = 2 * order + 2; // intercept, trend, and sin/cos pairs + let mut x = DMatrix::<f64>::zeros(n, num_coeffs); + + for i in 0..n { + let t = dates[i]; + x[(i, 0)] = 1.0; // Intercept + x[(i, 1)] = t; // Trend + for j in 1..=order { + let freq = TWO_PI * j as f64 * t; + x[(i, 2 * j)] = freq.cos(); + x[(i, 2 * j + 1)] = freq.sin(); + } } - let sum: f64 = y.iter().sum(); - HarmonicModel { - mean: sum / y.len() as f64, + x } + +/// Fits a harmonic model to the stable history period using Ordinary Least Squares (OLS). +fn fit_harmonic_model(y: &[f64], dates: &[f64], order: usize) -> Result<HarmonicModel, CoreError> { + if y.len() < (2 * order + 2) { + return Err(CoreError::NotEnoughData( + "Not enough historical data to fit model".to_string(), + )); } + + let y_vec = DVector::from_vec(y.to_vec()); + let x = build_design_matrix(dates, order); + + let decomp = x.clone().svd(true, true); + let coeffs = decomp.solve(&y_vec, 1e-10).map_err(|e| { + CoreError::ComputationError(format!("Failed to solve OLS with nalgebra: {}", e)) + })?; + + let y_pred = &x * &coeffs; + let residuals = &y_vec - &y_pred; + let sum_sq_err = residuals.iter().map(|&r| r * r).sum::<f64>(); + let df = (y.len() - (2 * order + 2)) as f64; + if df <= 0.0 { + return Err(CoreError::ComputationError( + "Degrees of freedom is non-positive".to_string(), + )); } + let sigma = (sum_sq_err / df).sqrt(); + + Ok(HarmonicModel { + coefficients: coeffs, + sigma, + }) } -// Placeholder for predicting values based on the fitted model. -fn predict_harmonic_model(_model: &HarmonicModel, dates: &[i64]) -> Vec<f64> { - // For now, just return a constant prediction (the historical mean) - vec![_model.mean; dates.len()] +/// Predicts values for the monitoring period based on the fitted model. 
+fn predict_harmonic_model(model: &HarmonicModel, dates: &[f64], order: usize) -> DVector<f64> { + let x_mon = build_design_matrix(dates, order); + &x_mon * &model.coefficients } -// Placeholder for the MOSUM process to detect a break. -// Returns (break_date, magnitude) +/// Detects a break using the OLS-MOSUM process. fn detect_mosum_break( y_monitor: &[f64], - y_pred: &[f64], - monitor_dates: &[i64], - level: f64, + y_pred: &DVector<f64>, + monitor_dates: &[f64], + hist_len: usize, + sigma: f64, + h: f64, + alpha: f64, ) -> (f64, f64) { - if y_monitor.is_empty() || y_monitor.len() != y_pred.len() { + if y_monitor.is_empty() { return (0.0, 0.0); } + let n_hist = hist_len as f64; + let window_size = (h * n_hist).floor() as usize; + let residuals: Vec<f64> = y_monitor .iter() .zip(y_pred.iter()) .map(|(obs, pred)| obs - pred) .collect(); - let mean_residual: f64 = residuals.iter().sum::<f64>() / residuals.len() as f64; + let mut cusum = vec![0.0; residuals.len() + 1]; + for i in 0..residuals.len() { + cusum[i + 1] = cusum[i] + residuals[i]; + } + + // We can only start calculating MOSUM after `window_size` observations + if residuals.len() < window_size { + return (0.0, 0.0); + } + + let mosum_process: Vec<f64> = (window_size..residuals.len()) + .map(|i| cusum[i] - cusum[i - window_size]) + .collect(); + + let standardizer = sigma * n_hist.sqrt(); + let standardized_mosum: Vec<f64> = mosum_process + .iter() + .map(|&m| (m / standardizer).abs()) + .collect(); + + // Simplified critical boundary based on a lookup for alpha=0.05 and h=0.25 + // A full implementation would use a precomputed table or a more complex calculation. + let critical_value = if alpha <= 0.05 { 1.36 } else { 1.63 }; // Approximations - // Simplified break detection: if the average residual in the monitoring period - // exceeds the significance level, flag the start of the period as a break. 
- if mean_residual.abs() > level { - (monitor_dates[0] as f64, mean_residual.abs()) - } else { - (0.0, 0.0) // No break detected + for (i, &mosum_val) in standardized_mosum.iter().enumerate() { + // The index k starts from 1 for the monitoring period + let k = (i + 1) as f64; + let boundary = critical_value * (1.0 + k / n_hist).sqrt(); + + if mosum_val > boundary { + let break_idx = i + window_size; + let magnitude = (y_monitor[break_idx] - y_pred[break_idx]).abs(); + return (monitor_dates[break_idx], magnitude); + } } + + (0.0, 0.0) // No break detected +} + +/// Converts integer dates (YYYYMMDD) to fractional years. +fn dates_to_frac_years(dates: &[i64]) -> Vec<f64> { + dates + .iter() + .map(|&date| { + let year = (date / 10000) as f64; + let month = ((date % 10000) / 100) as f64; + let day = (date % 100) as f64; + // Simple approximation + year + (month - 1.0) / 12.0 + (day - 1.0) / 365.25 + }) + .collect() } // This is the main logic function that runs for each pixel. fn run_bfast_monitor_per_pixel( pixel_ts: &[f64], - dates: &[i64], - history_start: i64, - monitor_start: i64, - level: f64, + dates: &[f64], + history_start: f64, + monitor_start: f64, + order: usize, + h: f64, // h parameter for MOSUM window size + alpha: f64, // Significance level ) -> (f64, f64) { // 1. Find the indices for the history and monitoring periods let history_indices: Vec<usize> = dates @@ -82,32 +184,48 @@ fn run_bfast_monitor_per_pixel( .collect(); if history_indices.is_empty() || monitor_indices.is_empty() { - return (0.0, 0.0); // Not enough data + return (0.0, 0.0); } // 2. 
Extract the data for these periods let history_ts: Vec<f64> = history_indices.iter().map(|&i| pixel_ts[i]).collect(); + let history_dates: Vec<f64> = history_indices.iter().map(|&i| dates[i]).collect(); let monitor_ts: Vec<f64> = monitor_indices.iter().map(|&i| pixel_ts[i]).collect(); - let monitor_dates: Vec<i64> = monitor_indices.iter().map(|&i| dates[i]).collect(); + let monitor_dates: Vec<f64> = monitor_indices.iter().map(|&i| dates[i]).collect(); // 3. Fit model on the historical period - let model = fit_harmonic_model(&history_ts); + let model_result = fit_harmonic_model(&history_ts, &history_dates, order); + let model = match model_result { + Ok(m) => m, + Err(_) => return (0.0, 0.0), // Return no-break if model fails + }; // 4. Predict for the monitoring period - let predicted_ts = predict_harmonic_model(&model, &monitor_dates); + let predicted_ts = predict_harmonic_model(&model, &monitor_dates, order); // 5. Detect break using MOSUM process on residuals - detect_mosum_break(&monitor_ts, &predicted_ts, &monitor_dates, level) + detect_mosum_break( + &monitor_ts, + &predicted_ts, + &monitor_dates, + history_ts.len(), + model.sigma, + h, + alpha, + ) } #[pyfunction] +#[allow(clippy::too_many_arguments)] pub fn bfast_monitor( py: Python, stack: PyReadonlyArrayDyn<f64>, dates: Vec<i64>, history_start_date: i64, monitor_start_date: i64, - level: f64, // Significance level + order: usize, + h: f64, + alpha: f64, ) -> PyResult<Py<PyArrayDyn<f64>>> { let stack_arr = stack.as_array(); @@ -132,6 +250,11 @@ pub fn bfast_monitor( .into()); } + // Convert integer dates to fractional years for modeling + let frac_dates = dates_to_frac_years(&dates); + let history_start_frac = dates_to_frac_years(&[history_start_date])[0]; + let monitor_start_frac = dates_to_frac_years(&[monitor_start_date])[0]; + // Output channels: [break_date, magnitude] let mut out_array = ndarray::ArrayD::<f64>::zeros(IxDyn(&[2, height, width])); @@ -158,10 +281,12 @@ pub fn bfast_monitor( .par_for_each(|break_date, magnitude, pixel_ts| { let (bk_date, mag) = 
run_bfast_monitor_per_pixel( pixel_ts.as_slice().unwrap(), - &dates, - history_start_date, - monitor_start_date, - level, + &frac_dates, + history_start_frac, + monitor_start_frac, + order, + h, + alpha, ); *break_date = bk_date; *magnitude = mag; @@ -194,16 +319,18 @@ pub fn complex_classification( let mut out = ndarray::ArrayD::::zeros(blue_arr.raw_dim()); - out.indexed_iter_mut().par_bridge().for_each(|(idx, res)| { - let b = blue_arr[&idx]; - let g = green_arr[&idx]; - let r = red_arr[&idx]; - let n = nir_arr[&idx]; - let s1 = swir1_arr[&idx]; - let s2 = swir2_arr[&idx]; - let t = temp_arr[&idx]; - *res = classify_pixel(b, g, r, n, s1, s2, t); - }); + out.indexed_iter_mut() + .par_bridge() + .for_each(|(idx, res)| { + let b = blue_arr[&idx]; + let g = green_arr[&idx]; + let r = red_arr[&idx]; + let n = nir_arr[&idx]; + let s1 = swir1_arr[&idx]; + let s2 = swir2_arr[&idx]; + let t = temp_arr[&idx]; + *res = classify_pixel(b, g, r, n, s1, s2, t); + }); Ok(out.into_pyarray(py).to_owned()) } diff --git a/tests/test_workflows.py b/tests/test_workflows.py index 5e83990..8e976cf 100644 --- a/tests/test_workflows.py +++ b/tests/test_workflows.py @@ -1,78 +1,83 @@ import numpy as np +import pandas as pd from eo_processor import bfast_monitor, complex_classification -def test_bfast_monitor_break_detected(): +def test_bfast_monitor_logic(): """ - Test the bfast_monitor function with a synthetic time series - where a breakpoint is expected. + Test the bfast_monitor function with synthetic time series + for both break and no-break scenarios. """ - time = 100 - history_len = 50 - monitor_len = 50 + # --- 1. 
Generate common data --- + # Create a date range + history_dates = pd.to_datetime(pd.date_range(start="2010-01-01", end="2014-12-31", freq="16D")) + monitor_dates = pd.to_datetime(pd.date_range(start="2015-01-01", end="2017-12-31", freq="16D")) + all_dates = history_dates.union(monitor_dates) - np.random.seed(42) - noise = np.random.normal(0, 0.1, time) - # Stable history period, then a sudden drop in the monitoring period - y = ( - np.concatenate( - [np.linspace(10, 10, history_len), np.linspace(5, 5, monitor_len)] - ) - + noise - ) - - # Create a 3D stack (time, y, x) - stack = np.zeros((time, 1, 1)) - stack[:, 0, 0] = y - - # Create corresponding dates (as simple integers for this test) - dates = np.arange(time, dtype=np.int64) - history_start_date = 0 - monitor_start_date = 50 - - # Run the bfast_monitor detection - # Level is set low to ensure the break is detected - result = bfast_monitor( - stack, dates.tolist(), history_start_date, monitor_start_date, level=0.5 - ) - - # Extract the results - break_date = result[0, 0, 0] - magnitude = result[1, 0, 0] + # Convert dates to fractional years for generating the signal + time_frac = all_dates.year + all_dates.dayofyear / 365.25 - # Assert that a breakpoint was detected at the start of the monitoring period - assert break_date == monitor_start_date - assert magnitude > 0 + # Convert dates to integer format YYYYMMDD for the function input + dates_int = (all_dates.year * 10000 + all_dates.month * 100 + all_dates.day).to_numpy(dtype=np.int64) + history_start_date = 20100101 + monitor_start_date = 20150101 -def test_bfast_monitor_no_break(): - """ - Test the bfast_monitor function with a stable time series - where no breakpoint is expected. 
- """ - time = 100 + # Generate a base harmonic signal np.random.seed(42) - noise = np.random.normal(0, 0.1, time) - # Stable time series with no break - y = np.linspace(10, 10, time) + noise - - stack = np.zeros((time, 1, 1)) - stack[:, 0, 0] = y - dates = np.arange(time, dtype=np.int64) - history_start_date = 0 - monitor_start_date = 50 - - # Level is set high enough that noise shouldn't trigger a break - result = bfast_monitor( - stack, dates.tolist(), history_start_date, monitor_start_date, level=1.0 + noise = np.random.normal(0, 0.05, len(all_dates)) + signal = 0.5 + 0.2 * np.cos(2 * np.pi * time_frac) + 0.1 * np.sin(4 * np.pi * time_frac) + noise + + # --- 2. Test break detection scenario --- + + # Introduce a sudden drop in the monitoring period + break_signal = signal.values.copy() + monitor_start_index = len(history_dates) + break_signal[monitor_start_index:] -= 0.4 + + # Create a 3D stack (Time, Y, X) + stack_break = np.zeros((len(all_dates), 1, 1)) + stack_break[:, 0, 0] = break_signal + + # Run bfast_monitor for the break scenario + result_break = bfast_monitor( + stack_break, + dates_int.tolist(), + history_start_date=history_start_date, + monitor_start_date=monitor_start_date, + order=1, + h=0.25, + alpha=0.05, ) - break_date = result[0, 0, 0] - magnitude = result[1, 0, 0] + break_date_frac = result_break[0, 0, 0] + magnitude = result_break[1, 0, 0] + + # Assert that a breakpoint was detected near the start of the monitoring period + # The exact date depends on the MOSUM window, so we check a range + assert 2015.0 < break_date_frac < 2016.5 + assert magnitude > 0.3 # Should be around 0.4 + + # --- 3. 
Test no-break scenario --- + + # Use the original stable signal + stack_stable = np.zeros((len(all_dates), 1, 1)) + stack_stable[:, 0, 0] = signal + + # Run bfast_monitor for the stable scenario + result_stable = bfast_monitor( + stack_stable, + dates_int.tolist(), + history_start_date=history_start_date, + monitor_start_date=monitor_start_date, + order=1, + h=0.25, + alpha=0.05, + ) # Assert that no breakpoint was detected - assert break_date == 0.0 - assert magnitude == 0.0 + assert result_stable[0, 0, 0] == 0.0 + assert result_stable[1, 0, 0] == 0.0 def test_complex_classification(): diff --git a/tox.ini b/tox.ini index 8ca2083..4e42d16 100644 --- a/tox.ini +++ b/tox.ini @@ -12,6 +12,7 @@ deps = maturin>=1.9.6 pytest>=7.0 numpy>=1.20.0 + pandas pillow>=9.0.0 scikit-learn>=1.0 scikit-image>=0.18.0 @@ -42,6 +43,7 @@ deps = pytest>=7.0 pytest-cov>=4.0.0 numpy>=1.20.0 + pandas pillow>=9.0.0 scikit-learn>=1.0 scikit-image>=0.18.0