From f226c87a2a3e6a87882f206a44af168d6d1e116e Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sat, 28 Mar 2026 21:29:29 -0400 Subject: [PATCH 01/16] initial pass on discrete transform --- rust/src/algorithms/transform.rs | 87 ++++++++++++++++++++++++++++++++ 1 file changed, 87 insertions(+) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index a65a844..a15efff 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -226,6 +226,93 @@ pub fn mixture_discrete( Ok(mix_rv) } +pub struct Transformation +{ + pub mapping: Box Number>, + pub support: (Number, Number), + +} + + +/// Computes a transformation of a discrete random variable +/// +/// # Arugments +/// * `random_variable`- the random variable to transform +/// * `transformation` - the transformation to apply to the random variable +/// +/// # Returns +/// * `transformed_rv` - the transformed random variable +pub fn transform_discrete ( + random_variable: &RandomVariable, + transformations: &[Transformation], +) -> Result +{ + let pdf_random_variable = random_variable.to_pdf()?; + let support = pdf_random_variable.support; + let function = pdf_random_variable.function; + + // Compute the transformed + let mut raw_transformed_support = Vec::new(); + for &s in support.iter() { + for transformation in transformations { + let mapping = &transformation.mapping; + let (trans_min, trans_max) = &transformation.support; + if s >= *trans_min && s <= *trans_max { + let raw_transformed_support_value = mapping(s); + raw_transformed_support.push(raw_transformed_support_value); + // Break to avoid inserting multiple transformed entries for + // any given support value + break + } + } + } + + // Sort the transformed support and functions + let mut raw_transformed_pairs: Vec<(Number, Number)> = raw_transformed_support + .into_iter() + .zip(function) + .collect(); + + raw_transformed_pairs.sort_by(|a, b| { + let first_value = a.0.to_f64(); + let second_value = b.0.to_f64(); + first_value.total_cmp(&second_value) + }); + + let (sorted_support, sorted_function): (Vec, Vec) = + raw_transformed_pairs.into_iter().unzip(); + + // De-duplicate the support. If a a vlue appears multiple times in the + // support, combine the probabilities + let mut transformed_function = Vec::new(); + let mut transformed_support = Vec::new(); + for (&s, &probability) in sorted_support.iter().zip(sorted_function.iter()) { + let support_index = transformed_support + .iter() + .position(|&x: &Number| x.to_f64() == s.to_f64()); + + match support_index { + Some(index) => { + transformed_function[index] += probability; + } + None => { + transformed_support.push(s); + transformed_function.push(probability); + } + } + } + + let transformed_rv = RandomVariable { + function: transformed_function, + support: transformed_support, + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + Ok(transformed_rv) +} + + #[cfg(test)] mod tests { use super::*; From 5a7f7caf430fefdf3117286dac4eb91495db5b5f Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sat, 28 Mar 2026 21:51:22 -0400 Subject: [PATCH 02/16] add validation --- rust/src/algorithms/transform.rs | 53 ++++++++++++++++++++++++++++++++ 1 file changed, 53 insertions(+) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index a15efff..3a9d445 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -251,6 +251,59 @@ pub fn transform_discrete ( let support = pdf_random_variable.support; let function = pdf_random_variable.function; + // Validate that the transformations are increasing and overlapping + for window in transformations.windows(2) { + let current_transformation = &window[0]; + let next_transformation = &window[1]; + + let (trans_min, trans_max) = ¤t_transformation.support; + let (next_trans_min, next_trans_max) = &next_transformation.support; + + if *trans_min >= *trans_max { + return Err( + "the max range of the transformation must exceeed the min range" + .to_string() + ); + } + + if *next_trans_min >= *next_trans_max { + return Err( + "the max range of the transformation must exceeed the min range" + .to_string() + ); + } + + if trans_max != next_trans_min { + return Err( + "the transformation ranges must be overlapping" + .to_string() + ); + } + } + + // Validate that the transformations cover the support + let lowest_support = support.first().expect("unable to extract lowest support"); + let lowest_transform = transformations.first() + .expect("unable to extract lowest transform") + .support.0; + if *lowest_support > lowest_transform { + return Err( + "the minimum transformation support is higher than the minimum rv support" + .to_string() + ); + } + + let highest_support = support.last().expect("unable to extract highest support"); + let highest_transform = transformations.last() + .expect("unable to extract highest transform") + .support.1; + if *highest_support > highest_transform { + return Err( + "the maxium transformation support is lower than the maximum rv support" + .to_string() + ); + } + // Compute the transformed let mut raw_transformed_support = Vec::new(); for &s in support.iter() { From c9b93e237248a45a8f3e29594236ddab41841233 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sat, 28 Mar 2026 22:01:05 -0400 Subject: [PATCH 03/16] validation and tests for tranform --- rust/src/algorithms/transform.rs | 267 +++++++++++++++++++++++++------ 1 file changed, 220 insertions(+), 47 deletions(-) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 3a9d445..a553291 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -226,14 +226,11 @@ pub fn mixture_discrete( Ok(mix_rv) } -pub struct Transformation -{ +pub struct Transformation { pub mapping: Box Number>, pub support: (Number, Number), - } - /// Computes a transformation of a discrete random variable /// /// # Arugments @@ -242,90 +239,141 @@ pub struct Transformation /// /// # Returns /// * `transformed_rv` - the transformed random variable -pub fn transform_discrete ( +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use applpy_rust::algorithms::transform::{transform_discrete, Transformation}; +/// use num_rational::Rational64; +/// +/// let rv = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 10)), +/// Number::Rational(Rational64::new(2, 10)), +/// Number::Rational(Rational64::new(3, 10)), +/// Number::Rational(Rational64::new(4, 10)), +/// ], +/// support: vec![ +/// Number::Integer(1), +/// Number::Integer(2), +/// Number::Integer(3), +/// Number::Integer(4), +/// ], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let transformed = transform_discrete( +/// &rv, +/// &[ +/// Transformation { +/// mapping: Box::new(|x| x * Number::Integer(2)), +/// support: (Number::Integer(1), Number::Integer(3)), +/// }, +/// Transformation { +/// mapping: Box::new(|x| x + Number::Integer(10)), +/// support: (Number::Integer(3), Number::Integer(5)), +/// }, +/// ], +/// ) +/// .unwrap(); +/// +/// assert_eq!( +/// transformed.support, +/// vec![ +/// Number::Integer(2), +/// Number::Integer(4), +/// Number::Integer(6), +/// Number::Integer(14), +/// ] +/// ); +/// assert_eq!( +/// transformed.function, +/// vec![ +/// Number::Rational(Rational64::new(1, 10)), +/// Number::Rational(Rational64::new(2, 10)), +/// Number::Rational(Rational64::new(3, 10)), +/// Number::Rational(Rational64::new(4, 10)), +/// ] +/// ); +/// ``` +pub fn transform_discrete( random_variable: &RandomVariable, transformations: &[Transformation], -) -> Result -{ +) -> Result { let pdf_random_variable = random_variable.to_pdf()?; let support = pdf_random_variable.support; let function = pdf_random_variable.function; - // Validate that the transformations are increasing and overlapping - for window in transformations.windows(2) { - let current_transformation = &window[0]; - let next_transformation = &window[1]; - - let (trans_min, trans_max) = ¤t_transformation.support; - let (next_trans_min, next_trans_max) = &next_transformation.support; + if transformations.is_empty() { + return Err("at least one transformation is required".to_string()); + } + // Validate that each transformation range is increasing. + for transformation in transformations { + let (trans_min, trans_max) = &transformation.support; if *trans_min >= *trans_max { return Err( - "the max range of the transformation must exceeed the min range" - .to_string() + "the max range of the transformation must exceeed the min range".to_string(), ); } + } - if *next_trans_min >= *next_trans_max { - return Err( - "the max range of the transformation must exceeed the min range" - .to_string() - ); - } + // Validate that the transformations are adjacent. + for window in transformations.windows(2) { + let current_transformation = &window[0]; + let next_transformation = &window[1]; + + let (_, trans_max) = ¤t_transformation.support; + let (next_trans_min, _) = &next_transformation.support; if trans_max != next_trans_min { - return Err( - "the transformation ranges must be overlapping" - .to_string() - ); + return Err("the transformation ranges must be adjacent".to_string()); } } // Validate that the transformations cover the support let lowest_support = support.first().expect("unable to extract lowest support"); - let lowest_transform = transformations.first() + let lowest_transform = transformations + .first() .expect("unable to extract lowest transform") - .support.0; - if *lowest_support > lowest_transform { + .support + .0; + if *lowest_support < lowest_transform { return Err( - "the minimum transformation support is higher than the minimum rv support" - .to_string() + "the minimum transformation support is higher than the minimum rv support".to_string(), ); } let highest_support = support.last().expect("unable to extract highest support"); - let highest_transform = transformations.last() + let highest_transform = transformations + .last() .expect("unable to extract highest transform") - .support.1; + .support + .1; if *highest_support > highest_transform { return Err( - "the maxium transformation support is lower than the maximum rv support" - .to_string() + "the maxium transformation support is lower than the maximum rv support".to_string(), ); } - // Compute the transformed - let mut raw_transformed_support = Vec::new(); - for &s in support.iter() { + // Compute the transformed support and preserve the original probabilities. + let mut raw_transformed_pairs = Vec::new(); + for (&s, &probability) in support.iter().zip(function.iter()) { for transformation in transformations { let mapping = &transformation.mapping; let (trans_min, trans_max) = &transformation.support; if s >= *trans_min && s <= *trans_max { let raw_transformed_support_value = mapping(s); - raw_transformed_support.push(raw_transformed_support_value); + raw_transformed_pairs.push((raw_transformed_support_value, probability)); // Break to avoid inserting multiple transformed entries for // any given support value - break + break; } } } - // Sort the transformed support and functions - let mut raw_transformed_pairs: Vec<(Number, Number)> = raw_transformed_support - .into_iter() - .zip(function) - .collect(); - raw_transformed_pairs.sort_by(|a, b| { let first_value = a.0.to_f64(); let second_value = b.0.to_f64(); @@ -365,7 +413,6 @@ pub fn transform_discrete ( Ok(transformed_rv) } - #[cfg(test)] mod tests { use super::*; @@ -495,4 +542,130 @@ mod tests { Err(msg) if msg == "the mix weights must sum to one" )); } + + #[test] + fn transform_discrete_uses_first_range_for_shared_boundary() { + let rv = sample_discrete_rv(); + + let transformed = transform_discrete( + &rv, + &[ + Transformation { + mapping: Box::new(|x| x * Number::Integer(2)), + support: (Number::Integer(1), Number::Integer(3)), + }, + Transformation { + mapping: Box::new(|x| x + Number::Integer(10)), + support: (Number::Integer(3), Number::Integer(5)), + }, + ], + ) + .unwrap(); + + assert_eq!( + transformed.support, + vec![ + Number::Integer(2), + Number::Integer(4), + Number::Integer(6), + Number::Integer(14), + ] + ); + assert_eq!( + transformed.function, + vec![ + Number::Rational(Rational64::new(1, 10)), + Number::Rational(Rational64::new(2, 10)), + Number::Rational(Rational64::new(3, 10)), + Number::Rational(Rational64::new(4, 10)), + ] + ); + assert!(matches!(transformed.functional_form, FunctionalForm::Pdf)); + assert!(matches!(transformed.domain_type, DomainType::Discrete)); + } + + #[test] + fn transform_discrete_returns_error_for_empty_transformations() { + let rv = sample_discrete_rv(); + let result = transform_discrete(&rv, &[]); + + assert!(matches!( + result, + Err(msg) if msg == "at least one transformation is required" + )); + } + + #[test] + fn transform_discrete_returns_error_for_single_invalid_range() { + let rv = sample_discrete_rv(); + let result = transform_discrete( + &rv, + &[Transformation { + mapping: Box::new(|x| x), + support: (Number::Integer(1), Number::Integer(1)), + }], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the max range of the transformation must exceeed the min range" + )); + } + + #[test] + fn transform_discrete_returns_error_when_ranges_are_not_adjacent() { + let rv = sample_discrete_rv(); + let result = transform_discrete( + &rv, + &[ + Transformation { + mapping: Box::new(|x| x), + support: (Number::Integer(1), Number::Integer(2)), + }, + Transformation { + mapping: Box::new(|x| x), + support: (Number::Integer(3), Number::Integer(5)), + }, + ], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the transformation ranges must be adjacent" + )); + } + + #[test] + fn transform_discrete_returns_error_when_transformations_do_not_cover_min_support() { + let rv = sample_discrete_rv(); + let result = transform_discrete( + &rv, + &[Transformation { + mapping: Box::new(|x| x), + support: (Number::Integer(2), Number::Integer(5)), + }], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the minimum transformation support is higher than the minimum rv support" + )); + } + + #[test] + fn transform_discrete_returns_error_when_transformations_do_not_cover_max_support() { + let rv = sample_discrete_rv(); + let result = transform_discrete( + &rv, + &[Transformation { + mapping: Box::new(|x| x), + support: (Number::Integer(1), Number::Integer(3)), + }], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the maxium transformation support is lower than the maximum rv support" + )); + } } From dd0a81b08a077b6c8369c3d9bf0c69f882742d05 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sun, 29 Mar 2026 21:36:07 -0400 Subject: [PATCH 04/16] remove box for the transformation struct --- rust/src/algorithms/transform.rs | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index a553291..1d15564 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -227,7 +227,7 @@ pub fn mixture_discrete( } pub struct Transformation { - pub mapping: Box Number>, + pub mapping: fn(Number) -> Number, pub support: (Number, Number), } @@ -268,11 +268,11 @@ pub struct Transformation { /// &rv, /// &[ /// Transformation { -/// mapping: Box::new(|x| x * Number::Integer(2)), +/// mapping: |x| x * Number::Integer(2), /// support: (Number::Integer(1), Number::Integer(3)), /// }, /// Transformation { -/// mapping: Box::new(|x| x + Number::Integer(10)), +/// mapping: |x| x + Number::Integer(10), /// support: (Number::Integer(3), Number::Integer(5)), /// }, /// ], @@ -551,11 +551,11 @@ mod tests { &rv, &[ Transformation { - mapping: Box::new(|x| x * Number::Integer(2)), + mapping: |x| x * Number::Integer(2), support: (Number::Integer(1), Number::Integer(3)), }, Transformation { - mapping: Box::new(|x| x + Number::Integer(10)), + mapping: |x| x + Number::Integer(10), support: (Number::Integer(3), Number::Integer(5)), }, ], @@ -601,7 +601,7 @@ mod tests { let result = transform_discrete( &rv, &[Transformation { - mapping: Box::new(|x| x), + mapping: |x| x, support: (Number::Integer(1), Number::Integer(1)), }], ); @@ -619,11 +619,11 @@ mod tests { &rv, &[ Transformation { - mapping: Box::new(|x| x), + mapping: |x| x, support: (Number::Integer(1), Number::Integer(2)), }, Transformation { - mapping: Box::new(|x| x), + mapping: |x| x, support: (Number::Integer(3), Number::Integer(5)), }, ], @@ -641,7 +641,7 @@ mod tests { let result = transform_discrete( &rv, &[Transformation { - mapping: Box::new(|x| x), + mapping: |x| x, support: (Number::Integer(2), Number::Integer(5)), }], ); @@ -658,7 +658,7 @@ mod tests { let result = transform_discrete( &rv, &[Transformation { - mapping: Box::new(|x| x), + mapping: |x| x, support: (Number::Integer(1), Number::Integer(3)), }], ); From a99fe94f2298a4f33cc6b7670d46a37829a24d80 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sun, 29 Mar 2026 21:48:49 -0400 Subject: [PATCH 05/16] add sub for random variable --- rust/src/algorithms/rv.rs | 21 ++++++++++++++++++++- 1 file changed, 20 insertions(+), 1 deletion(-) diff --git a/rust/src/algorithms/rv.rs b/rust/src/algorithms/rv.rs index f5a1f31..ae0287b 100644 --- a/rust/src/algorithms/rv.rs +++ b/rust/src/algorithms/rv.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use std::fmt; -use std::ops::{Add, AddAssign, Mul}; +use std::ops::{Add, AddAssign, Mul, Sub}; use num_rational::Rational64; use num_traits::cast::ToPrimitive; @@ -10,6 +10,7 @@ use crate::algorithms::algebra; use crate::algorithms::conversion; use crate::algorithms::moments; use crate::algorithms::number::Number; +use crate::algorithms::transform; #[derive(Debug, Clone, PartialEq)] pub enum FunctionalForm { @@ -79,6 +80,24 @@ impl AddAssign for RandomVariable { } } +impl Sub for RandomVariable { + type Output = Result; + + fn sub(self, rhs: Self) -> Self::Output { + let min_support = self.support.first() + .expect("failed to extract the first number"); + let max_support = self.support.last() + .expect("failed to extract the last number"); + let transformation = transform::Transformation { + mapping: |x| x * Number::Integer(-1), + support: (*min_support, *max_support), + }; + let negative_rhs = transform::transform_discrete(&rhs, &[transformation])?; + let sub_rv = algebra::convolution_discrete(&self, &negative_rhs)?; + Ok(sub_rv) + } +} + impl Mul for RandomVariable { type Output = Result; From 1c3eec1bcdb352b5a1b14abafa51f5e911bc2ea8 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sun, 29 Mar 2026 21:59:59 -0400 Subject: [PATCH 06/16] add tests for sub and sub assign --- rust/src/algorithms/rv.rs | 187 +++++++++++++++++++++++++++++++++++++- 1 file changed, 184 insertions(+), 3 deletions(-) diff --git a/rust/src/algorithms/rv.rs b/rust/src/algorithms/rv.rs index ae0287b..f78528a 100644 --- a/rust/src/algorithms/rv.rs +++ b/rust/src/algorithms/rv.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use std::fmt; -use std::ops::{Add, AddAssign, Mul, Sub}; +use std::ops::{Add, AddAssign, Mul, Sub, SubAssign}; use num_rational::Rational64; use num_traits::cast::ToPrimitive; @@ -80,13 +80,61 @@ impl AddAssign for RandomVariable { } } +/// Computes the difference of two independent discrete random variables. +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use num_rational::Rational64; +/// +/// let rv1 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(1), Number::Integer(2)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let rv2 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(2), Number::Integer(3)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let difference = (rv1 - rv2).unwrap(); +/// +/// assert_eq!( +/// difference.support, +/// vec![Number::Integer(-2), Number::Integer(-1), Number::Integer(0)] +/// ); +/// assert_eq!( +/// difference.function, +/// vec![ +/// Number::Rational(Rational64::new(1, 4)), +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 4)), +/// ] +/// ); +/// assert!(difference.verify_pdf(None).unwrap()); +/// ``` impl Sub for RandomVariable { type Output = Result; fn sub(self, rhs: Self) -> Self::Output { - let min_support = self.support.first() + let min_support = rhs + .support + .first() .expect("failed to extract the first number"); - let max_support = self.support.last() + let max_support = rhs + .support + .last() .expect("failed to extract the last number"); let transformation = transform::Transformation { mapping: |x| x * Number::Integer(-1), @@ -98,6 +146,61 @@ impl Sub for RandomVariable { } } +/// Updates a random variable in place with the difference of two +/// independent discrete random variables. +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use num_rational::Rational64; +/// +/// let mut rv1 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(1), Number::Integer(2)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let rv2 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(2), Number::Integer(3)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// rv1 -= rv2; +/// +/// assert_eq!( +/// rv1.support, +/// vec![Number::Integer(-2), Number::Integer(-1), Number::Integer(0)] +/// ); +/// assert_eq!( +/// rv1.function, +/// vec![ +/// Number::Rational(Rational64::new(1, 4)), +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 4)), +/// ] +/// ); +/// assert!(rv1.verify_pdf(None).unwrap()); +/// ``` +impl SubAssign for RandomVariable { + fn sub_assign(&mut self, rhs: Self) { + let sub_rv = self + .clone() + .sub(rhs) + .expect("failed to subtract the random variables"); + *self = sub_rv.clone(); + } +} + impl Mul for RandomVariable { type Output = Result; @@ -421,6 +524,84 @@ mod tests { use super::*; use num_rational::Rational64; + #[test] + fn sub_returns_difference_distribution() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + let result = (lhs - rhs).unwrap(); + + assert_eq!( + result.support, + vec![Number::Integer(-2), Number::Integer(-1), Number::Integer(0)] + ); + assert_eq!( + result.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(result.functional_form, FunctionalForm::Pdf)); + assert!(matches!(result.domain_type, DomainType::Discrete)); + } + + #[test] + fn sub_assign_updates_random_variable_in_place() { + let mut lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + lhs -= rhs; + + assert_eq!( + lhs.support, + vec![Number::Integer(-2), Number::Integer(-1), Number::Integer(0)] + ); + assert_eq!( + lhs.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(lhs.functional_form, FunctionalForm::Pdf)); + assert!(matches!(lhs.domain_type, DomainType::Discrete)); + } + #[test] fn verify_pdf_returns_err_for_non_pdf_functional_form() { let rv = RandomVariable { From e8d27cbfc4a3b7fbf0195837f427a5a75cd336a9 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sun, 29 Mar 2026 22:02:34 -0400 Subject: [PATCH 07/16] split transformation support into two attributes --- rust/src/algorithms/rv.rs | 3 +- rust/src/algorithms/transform.rs | 47 +++++++++++++++++--------------- 2 files changed, 27 insertions(+), 23 deletions(-) diff --git a/rust/src/algorithms/rv.rs b/rust/src/algorithms/rv.rs index f78528a..dba9f70 100644 --- a/rust/src/algorithms/rv.rs +++ b/rust/src/algorithms/rv.rs @@ -138,7 +138,8 @@ impl Sub for RandomVariable { .expect("failed to extract the last number"); let transformation = transform::Transformation { mapping: |x| x * Number::Integer(-1), - support: (*min_support, *max_support), + min_support: *min_support, + max_support: *max_support, }; let negative_rhs = transform::transform_discrete(&rhs, &[transformation])?; let sub_rv = algebra::convolution_discrete(&self, &negative_rhs)?; diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 1d15564..391429c 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -228,7 +228,8 @@ pub fn mixture_discrete( pub struct Transformation { pub mapping: fn(Number) -> Number, - pub support: (Number, Number), + pub min_support: Number, + pub max_support: Number, } /// Computes a transformation of a discrete random variable @@ -269,11 +270,13 @@ pub struct Transformation { /// &[ /// Transformation { /// mapping: |x| x * Number::Integer(2), -/// support: (Number::Integer(1), Number::Integer(3)), +/// min_support: Number::Integer(1), +/// max_support: Number::Integer(3), /// }, /// Transformation { /// mapping: |x| x + Number::Integer(10), -/// support: (Number::Integer(3), Number::Integer(5)), +/// min_support: Number::Integer(3), +/// max_support: Number::Integer(5), /// }, /// ], /// ) @@ -312,8 +315,7 @@ pub fn transform_discrete( // Validate that each transformation range is increasing. for transformation in transformations { - let (trans_min, trans_max) = &transformation.support; - if *trans_min >= *trans_max { + if transformation.min_support >= transformation.max_support { return Err( "the max range of the transformation must exceeed the min range".to_string(), ); @@ -325,10 +327,7 @@ pub fn transform_discrete( let current_transformation = &window[0]; let next_transformation = &window[1]; - let (_, trans_max) = ¤t_transformation.support; - let (next_trans_min, _) = &next_transformation.support; - - if trans_max != next_trans_min { + if current_transformation.max_support != next_transformation.min_support { return Err("the transformation ranges must be adjacent".to_string()); } } @@ -338,8 +337,7 @@ pub fn transform_discrete( let lowest_transform = transformations .first() .expect("unable to extract lowest transform") - .support - .0; + .min_support; if *lowest_support < lowest_transform { return Err( "the minimum transformation support is higher than the minimum rv support".to_string(), @@ -350,8 +348,7 @@ pub fn transform_discrete( let highest_transform = transformations .last() .expect("unable to extract highest transform") - .support - .1; + .max_support; if *highest_support > highest_transform { return Err( "the maxium transformation support is lower than the maximum rv support".to_string(), @@ -363,8 +360,7 @@ pub fn transform_discrete( for (&s, &probability) in support.iter().zip(function.iter()) { for transformation in transformations { let mapping = &transformation.mapping; - let (trans_min, trans_max) = &transformation.support; - if s >= *trans_min && s <= *trans_max { + if s >= transformation.min_support && s <= transformation.max_support { let raw_transformed_support_value = mapping(s); raw_transformed_pairs.push((raw_transformed_support_value, probability)); // Break to avoid inserting multiple transformed entries for @@ -552,11 +548,13 @@ mod tests { &[ Transformation { mapping: |x| x * Number::Integer(2), - support: (Number::Integer(1), Number::Integer(3)), + min_support: Number::Integer(1), + max_support: Number::Integer(3), }, Transformation { mapping: |x| x + Number::Integer(10), - support: (Number::Integer(3), Number::Integer(5)), + min_support: Number::Integer(3), + max_support: Number::Integer(5), }, ], ) @@ -602,7 +600,8 @@ mod tests { &rv, &[Transformation { mapping: |x| x, - support: (Number::Integer(1), Number::Integer(1)), + min_support: Number::Integer(1), + max_support: Number::Integer(1), }], ); @@ -620,11 +619,13 @@ mod tests { &[ Transformation { mapping: |x| x, - support: (Number::Integer(1), Number::Integer(2)), + min_support: Number::Integer(1), + max_support: Number::Integer(2), }, Transformation { mapping: |x| x, - support: (Number::Integer(3), Number::Integer(5)), + min_support: Number::Integer(3), + max_support: Number::Integer(5), }, ], ); @@ -642,7 +643,8 @@ mod tests { &rv, &[Transformation { mapping: |x| x, - support: (Number::Integer(2), Number::Integer(5)), + min_support: Number::Integer(2), + max_support: Number::Integer(5), }], ); @@ -659,7 +661,8 @@ mod tests { &rv, &[Transformation { mapping: |x| x, - support: (Number::Integer(1), Number::Integer(3)), + min_support: Number::Integer(1), + max_support: Number::Integer(3), }], ); From c2e0e5ca467241f3ed98d1bf3839be188b2673ca Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sun, 29 Mar 2026 22:06:11 -0400 Subject: [PATCH 08/16] add Div for RandomVariable --- rust/src/algorithms/rv.rs | 25 ++++++++++++++++++++++++- 1 file changed, 24 insertions(+), 1 deletion(-) diff --git a/rust/src/algorithms/rv.rs b/rust/src/algorithms/rv.rs index dba9f70..e0baeef 100644 --- a/rust/src/algorithms/rv.rs +++ b/rust/src/algorithms/rv.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use std::fmt; -use std::ops::{Add, AddAssign, Mul, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign}; use num_rational::Rational64; use num_traits::cast::ToPrimitive; @@ -80,6 +80,29 @@ impl AddAssign for RandomVariable { } } +impl Div for RandomVariable { + type Output = Result; + + fn div(self, rhs: Self) -> Self::Output { + let min_support = rhs + .support + .first() + .expect("failed to extract the first number"); + let max_support = rhs + .support + .last() + .expect("failed to extract the last number"); + let transformation = transform::Transformation { + mapping: |x| Number::Integer(1) / x, + min_support: *min_support, + max_support: *max_support, + }; + let inverse_rhs = transform::transform_discrete(&rhs, &[transformation])?; + let div_rv = algebra::product_discrete(&self, &inverse_rhs)?; + Ok(div_rv) + } +} + /// Computes the difference of two independent discrete random variables. /// /// # Examples From adb65a617e064906dfd255e511aef68e239259af Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sun, 29 Mar 2026 22:10:24 -0400 Subject: [PATCH 09/16] added tests for RandomVariable operators --- rust/src/algorithms/rv.rs | 169 ++++++++++++++++++++++++++++++++++++++ 1 file changed, 169 insertions(+) diff --git a/rust/src/algorithms/rv.rs b/rust/src/algorithms/rv.rs index e0baeef..8346ed7 100644 --- a/rust/src/algorithms/rv.rs +++ b/rust/src/algorithms/rv.rs @@ -548,6 +548,130 @@ mod tests { use super::*; use num_rational::Rational64; + #[test] + fn add_returns_sum_distribution() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + let result = (lhs + rhs).unwrap(); + + assert_eq!( + result.support, + vec![Number::Integer(3), Number::Integer(4), Number::Integer(5)] + ); + assert_eq!( + result.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(result.functional_form, FunctionalForm::Pdf)); + assert!(matches!(result.domain_type, DomainType::Discrete)); + } + + #[test] + fn add_assign_updates_random_variable_in_place() { + let mut lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + lhs += rhs; + + assert_eq!( + lhs.support, + vec![Number::Integer(3), Number::Integer(4), Number::Integer(5)] + ); + assert_eq!( + lhs.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(lhs.functional_form, FunctionalForm::Pdf)); + assert!(matches!(lhs.domain_type, DomainType::Discrete)); + } + + #[test] + fn div_returns_quotient_distribution() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(4)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![ + Number::Rational(Rational64::new(2, 1)), + Number::Rational(Rational64::new(4, 1)), + ], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + let result = (lhs / rhs).unwrap(); + + assert_eq!( + result.support, + vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 1)), + Number::Rational(Rational64::new(2, 1)), + ] + ); + assert_eq!( + result.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(result.functional_form, FunctionalForm::Pdf)); + assert!(matches!(result.domain_type, DomainType::Discrete)); + } + #[test] fn sub_returns_difference_distribution() { let lhs = RandomVariable { @@ -626,6 +750,51 @@ mod tests { assert!(matches!(lhs.domain_type, DomainType::Discrete)); } + #[test] + fn mul_returns_product_distribution() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + let result = (lhs * rhs).unwrap(); + + assert_eq!( + result.support, + vec![ + Number::Integer(2), + Number::Integer(3), + Number::Integer(4), + Number::Integer(6), + ] + ); + assert_eq!( + result.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(result.functional_form, FunctionalForm::Pdf)); + assert!(matches!(result.domain_type, DomainType::Discrete)); + } + #[test] fn verify_pdf_returns_err_for_non_pdf_functional_form() { let rv = RandomVariable { From ab666a131219d58e52fdf068412be0dda8bb1350 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Sun, 29 Mar 2026 22:12:33 -0400 Subject: [PATCH 10/16] add MulAssign and DivAssign --- rust/src/algorithms/rv.rs | 113 +++++++++++++++++++++++++++++++++++++- 1 file changed, 112 insertions(+), 1 deletion(-) diff --git a/rust/src/algorithms/rv.rs b/rust/src/algorithms/rv.rs index 8346ed7..f6538f3 100644 --- a/rust/src/algorithms/rv.rs +++ b/rust/src/algorithms/rv.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use std::fmt; -use std::ops::{Add, AddAssign, Div, Mul, Sub, SubAssign}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use num_rational::Rational64; use num_traits::cast::ToPrimitive; @@ -103,6 +103,16 @@ impl Div for RandomVariable { } } +impl DivAssign for RandomVariable { + fn div_assign(&mut self, rhs: Self) { + let div_rv = self + .clone() + .div(rhs) + .expect("failed to divide the random variables"); + *self = div_rv; + } +} + /// Computes the difference of two independent discrete random variables. /// /// # Examples @@ -234,6 +244,16 @@ impl Mul for RandomVariable { } } +impl MulAssign for RandomVariable { + fn mul_assign(&mut self, rhs: Self) { + let product_rv = self + .clone() + .mul(rhs) + .expect("failed to multiply the random variables"); + *self = product_rv; + } +} + impl RandomVariable { pub fn verify_pdf(&self, tolerance: Option) -> Result { if self.functional_form != FunctionalForm::Pdf { @@ -672,6 +692,52 @@ mod tests { assert!(matches!(result.domain_type, DomainType::Discrete)); } + #[test] + fn div_assign_updates_random_variable_in_place() { + let mut lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(4)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![ + Number::Rational(Rational64::new(2, 1)), + Number::Rational(Rational64::new(4, 1)), + ], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + lhs /= rhs; + + assert_eq!( + lhs.support, + vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 1)), + Number::Rational(Rational64::new(2, 1)), + ] + ); + assert_eq!( + lhs.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(lhs.functional_form, FunctionalForm::Pdf)); + assert!(matches!(lhs.domain_type, DomainType::Discrete)); + } + #[test] fn sub_returns_difference_distribution() { let lhs = RandomVariable { @@ -795,6 +861,51 @@ mod tests { assert!(matches!(result.domain_type, DomainType::Discrete)); } + #[test] + fn mul_assign_updates_random_variable_in_place() { + let mut lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + lhs *= rhs; + + assert_eq!( + lhs.support, + vec![ + Number::Integer(2), + Number::Integer(3), + Number::Integer(4), + Number::Integer(6), + ] + ); + assert_eq!( + lhs.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(lhs.functional_form, FunctionalForm::Pdf)); + assert!(matches!(lhs.domain_type, DomainType::Discrete)); + } + #[test] fn verify_pdf_returns_err_for_non_pdf_functional_form() { let rv = RandomVariable { From 0e35e7f9c28bdde9e984b7889cb99c485e2f8daa Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Mon, 30 Mar 2026 14:39:02 -0400 Subject: [PATCH 11/16] add algebra special methods to FastRV --- rust/src/python/api.rs | 96 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 96 insertions(+) diff --git a/rust/src/python/api.rs b/rust/src/python/api.rs index 77b050d..19ed581 100644 --- a/rust/src/python/api.rs +++ b/rust/src/python/api.rs @@ -291,6 +291,54 @@ impl FastRV { } } + pub fn __radd__(&self, lhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let lhs_rv = lhs.inner.clone(); + + let sum_rv = lhs_rv + self_rv; + + match sum_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + + pub fn __sub__(&self, rhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let rhs_rv = rhs.inner.clone(); + + let difference_rv = self_rv - rhs_rv; + + match difference_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + + pub fn __rsub__(&self, lhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let lhs_rv = lhs.inner.clone(); + + let difference_rv = lhs_rv - self_rv; + + match difference_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + pub fn __mul__(&self, rhs: FastRV) -> PyResult { let self_rv = self.inner.clone(); let rhs_rv = rhs.inner.clone(); @@ -307,6 +355,54 @@ impl FastRV { } } + pub fn __rmul__(&self, lhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let lhs_rv = lhs.inner.clone(); + + let product_rv = lhs_rv * self_rv; + + match product_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + + pub fn __truediv__(&self, rhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let rhs_rv = rhs.inner.clone(); + + let quotient_rv = self_rv / rhs_rv; + + match quotient_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + + pub fn __rtruediv__(&self, lhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let lhs_rv = lhs.inner.clone(); + + let quotient_rv = lhs_rv / self_rv; + + match quotient_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + #[getter] pub fn function(&self) -> Vec { self.inner.function.clone() From 8fe37ed19d248b1c8de27e27a60f74672d3fd018 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Mon, 30 Mar 2026 14:50:45 -0400 Subject: [PATCH 12/16] use fast rv for algebra --- applpy/rv.py | 54 ++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 46 insertions(+), 8 deletions(-) diff --git a/applpy/rv.py b/applpy/rv.py index 3b4549c..1613a1a 100644 --- a/applpy/rv.py +++ b/applpy/rv.py @@ -330,13 +330,16 @@ def __add__(self, other): from .algebra import convolution from .transform import transform - if "RV" in other.__class__.__name__: + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv + other_fast_rv) + try: return convolution(self, other) except Exception: return convolution(other, self) - else: - raise RVError("Could not compute the convolution") # If the random variable is added to a constant, shift # the random variable if isinstance(other, (float, int)): @@ -357,6 +360,11 @@ def __radd__(self, other): 2. other: a constant or random variable Output: 1. A new random variable """ + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv.__radd__(other_fast_rv)) return self.__add__(other) def __sub__(self, other): @@ -375,7 +383,12 @@ def __sub__(self, other): from .algebra import convolution from .transform import transform - if "RV" in other.__class__.__name__: + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv - other_fast_rv) + gX = [[-x], [-oo, oo]] random_variable = transform(other, gX) return convolution(self, random_variable) @@ -396,6 +409,12 @@ def __rsub__(self, other): 2. other: a constant or random variable Output: 1. A new random variable """ + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv.__rsub__(other_fast_rv)) + # Perform an negative transformation of the random variable neg_self = -self # Add the two components @@ -417,13 +436,16 @@ def __mul__(self, other): from .algebra import product from .transform import transform - if "RV" in other.__class__.__name__: + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv * other_fast_rv) + try: return product(self, other) except Exception: return product(other, self) - else: - raise RVError("Could not compute the product") # If the random variable is multiplied by a constant, scale # the random variable if isinstance(other, (float, int)): @@ -441,6 +463,11 @@ def __rmul__(self, other): 2. other: a constant or random variable Output: 1. A new random variable """ + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv.__rmul__(other_fast_rv)) return self.__mul__(other) def __truediv__(self, other): @@ -459,7 +486,12 @@ def __truediv__(self, other): from .algebra import product from .transform import transform - if "RV" in other.__class__.__name__: + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv / other_fast_rv) + gX = [[1 / x, 1 / x], [-oo, 0, oo]] random_variable = transform(other, gX) return product(self, random_variable) @@ -480,6 +512,12 @@ def __rtruediv__(self, other): 2. other: a constant or random variable Output: 1. A new random variable """ + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv.__rtruediv__(other_fast_rv)) + ## Invert the random variable from .transform import transform From 52844a29bfc6a86011cccf33a6444ec73eb7b187 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Mon, 30 Mar 2026 15:10:28 -0400 Subject: [PATCH 13/16] raise python error on divide by zero --- Makefile | 2 +- rust/src/algorithms/rv.rs | 29 +++++++++++++++++++++++++++++ test_applpy/unit/test_rv.py | 2 +- 3 files changed, 31 insertions(+), 2 deletions(-) diff --git a/Makefile b/Makefile index ab373bb..fb235f9 100644 --- a/Makefile +++ b/Makefile @@ -95,7 +95,7 @@ ipython: ## Runs iPython with --no-sync to ensure rust bindings are preserved .PHONY: check check: ## Run Ruff lint checks. - uv run ruff check applpy test_applpy + uv run --no-sync ruff check applpy test_applpy .PHONY: tidy tidy: ## Run Ruff autoformatter. diff --git a/rust/src/algorithms/rv.rs b/rust/src/algorithms/rv.rs index f6538f3..543e959 100644 --- a/rust/src/algorithms/rv.rs +++ b/rust/src/algorithms/rv.rs @@ -84,6 +84,14 @@ impl Div for RandomVariable { type Output = Result; fn div(self, rhs: Self) -> Self::Output { + if rhs.support.iter().any(|value| match value { + Number::Float(x) => *x == 0.0, + Number::Integer(x) => *x == 0, + Number::Rational(x) => *x.numer() == 0, + }) { + return Err("cannot divide by a random variable with zero in its support".to_string()); + } + let min_support = rhs .support .first() @@ -738,6 +746,27 @@ mod tests { assert!(matches!(lhs.domain_type, DomainType::Discrete)); } + #[test] + fn div_returns_error_when_rhs_support_contains_zero() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(0), Number::Integer(1)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = lhs.clone(); + + let err = (lhs / rhs).unwrap_err(); + + assert_eq!( + err, + "cannot divide by a random variable with zero in its support" + ); + } + #[test] fn sub_returns_difference_distribution() { let lhs = RandomVariable { diff --git a/test_applpy/unit/test_rv.py b/test_applpy/unit/test_rv.py index 07fa139..42a59b2 100644 --- a/test_applpy/unit/test_rv.py +++ b/test_applpy/unit/test_rv.py @@ -151,7 +151,7 @@ def test_operator_overloads_for_scalars_and_rvs(): with pytest.raises(NotImplementedError): abs(rv) - with pytest.raises(TypeError): + with pytest.raises(ValueError): discrete / discrete with pytest.raises(RVError, match="integer value"): From 873492cc7ce68d0d9ce69dceb4b99686e3273c91 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Mon, 30 Mar 2026 15:25:22 -0400 Subject: [PATCH 14/16] refactor sorting into a shared function --- rust/src/algorithms/algebra.rs | 28 +++------------------------- rust/src/algorithms/mod.rs | 1 + rust/src/algorithms/shared.rs | 25 +++++++++++++++++++++++++ rust/src/algorithms/transform.rs | 15 +++------------ 4 files changed, 32 insertions(+), 37 deletions(-) create mode 100644 rust/src/algorithms/shared.rs diff --git a/rust/src/algorithms/algebra.rs b/rust/src/algorithms/algebra.rs index e2edf94..18cae92 100644 --- a/rust/src/algorithms/algebra.rs +++ b/rust/src/algorithms/algebra.rs @@ -2,6 +2,7 @@ use crate::algorithms::number::Number; use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +use crate::algorithms::shared; /// Computes the product of two independent discrete random variables /// @@ -86,19 +87,8 @@ pub fn product_discrete( } } - // Sort the multiplied support and function values - let mut raw_product_pairs: Vec<_> = raw_product_support - .into_iter() - .zip(raw_product_function) - .collect(); - raw_product_pairs.sort_by(|a, b| { - let first_value = a.0.to_f64(); - let second_value = b.0.to_f64(); - first_value.total_cmp(&second_value) - }); - let (sorted_support, sorted_function): (Vec, Vec) = - raw_product_pairs.into_iter().unzip(); + shared::sort_by_support(raw_product_support, raw_product_function)?; // De-duplicate the support. If a value appears multiple times in the // support, combine the probabilities @@ -211,20 +201,8 @@ pub fn convolution_discrete( } } - // Sorts the results by the support values - let mut raw_conv_pairs: Vec<_> = raw_conv_support - .into_iter() - .zip(raw_conv_function) - .collect(); - - raw_conv_pairs.sort_by(|a, b| { - let first_value = a.0.to_f64(); - let second_value = b.0.to_f64(); - first_value.total_cmp(&second_value) - }); - let (sorted_support, sorted_function): (Vec, Vec) = - raw_conv_pairs.into_iter().unzip(); + shared::sort_by_support(raw_conv_support, raw_conv_function)?; // Remove redundant elements from the support let mut conv_support = Vec::new(); diff --git a/rust/src/algorithms/mod.rs b/rust/src/algorithms/mod.rs index cffed21..dece185 100644 --- a/rust/src/algorithms/mod.rs +++ b/rust/src/algorithms/mod.rs @@ -4,4 +4,5 @@ pub mod moments; pub mod number; pub mod order_stat; pub mod rv; +pub mod shared; pub mod transform; diff --git a/rust/src/algorithms/shared.rs b/rust/src/algorithms/shared.rs new file mode 100644 index 0000000..00b4efc --- /dev/null +++ b/rust/src/algorithms/shared.rs @@ -0,0 +1,25 @@ +#![allow(dead_code)] + +use crate::algorithms::number::Number; + +/// Sorts the support of a random variable, while keeping the function aligned +/// with the support +pub fn sort_by_support( + support: Vec, + function: Vec, +) -> Result<(Vec, Vec), String> { + if support.len() != function.len() { + return Err("support and function must be the same length".to_string()); + } + + let mut zipped_pairs: Vec<_> = support.into_iter().zip(function).collect(); + + zipped_pairs.sort_by(|a, b| { + let first_value = a.0.to_f64(); + let second_value = b.0.to_f64(); + first_value.total_cmp(&second_value) + }); + + let (sorted_support, sorted_function) = zipped_pairs.into_iter().unzip(); + Ok((sorted_support, sorted_function)) +} diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index 391429c..d5a6c08 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -2,6 +2,7 @@ use crate::algorithms::number::Number; use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +use crate::algorithms::shared; /// Truncates a discrete random variable by cutting off a portion of the support /// and normalizing total probability of the distribution to 1. @@ -204,18 +205,8 @@ pub fn mixture_discrete( } } - let mut raw_mixture_pair: Vec<_> = raw_mixture_support - .into_iter() - .zip(raw_mixture_function) - .collect(); - - raw_mixture_pair.sort_by(|a, b| { - let first_value = a.0.to_f64(); - let second_value = b.0.to_f64(); - first_value.total_cmp(&second_value) - }); - - let (mixture_support, mixture_function) = raw_mixture_pair.into_iter().unzip(); + let (mixture_support, mixture_function) = + shared::sort_by_support(raw_mixture_support, raw_mixture_function)?; let mix_rv = RandomVariable { function: mixture_function, From ed2f38b80dcd3a51488f9e74458d557a25b4f961 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Mon, 30 Mar 2026 15:39:00 -0400 Subject: [PATCH 15/16] refactor shared sorting and deduped logic --- rust/src/algorithms/algebra.rs | 40 ++++---------------------------- rust/src/algorithms/shared.rs | 40 ++++++++++++++++++++++++++++++++ rust/src/algorithms/transform.rs | 21 ++--------------- 3 files changed, 46 insertions(+), 55 deletions(-) diff --git a/rust/src/algorithms/algebra.rs b/rust/src/algorithms/algebra.rs index 18cae92..7c2044c 100644 --- a/rust/src/algorithms/algebra.rs +++ b/rust/src/algorithms/algebra.rs @@ -90,25 +90,8 @@ pub fn product_discrete( let (sorted_support, sorted_function): (Vec, Vec) = shared::sort_by_support(raw_product_support, raw_product_function)?; - // De-duplicate the support. If a value appears multiple times in the - // support, combine the probabilities - let mut product_function = Vec::new(); - let mut product_support = Vec::new(); - for (&s, &probability) in sorted_support.iter().zip(sorted_function.iter()) { - let support_index = product_support - .iter() - .position(|&x: &Number| x.to_f64() == s.to_f64()); - - match support_index { - Some(index) => { - product_function[index] += probability; - } - None => { - product_function.push(probability); - product_support.push(s); - } - } - } + let (product_support, product_function) = + shared::deduplicate_support(sorted_support, sorted_function)?; let product_rv = RandomVariable { function: product_function, @@ -204,23 +187,8 @@ pub fn convolution_discrete( let (sorted_support, sorted_function): (Vec, Vec) = shared::sort_by_support(raw_conv_support, raw_conv_function)?; - // Remove redundant elements from the support - let mut conv_support = Vec::new(); - let mut conv_function = Vec::new(); - - for (&s, &f) in sorted_support.iter().zip(sorted_function.iter()) { - let support_index = conv_support.iter().position(|&x| x == s); - - match support_index { - Some(index) => { - conv_function[index] += f; - } - None => { - conv_support.push(s); - conv_function.push(f); - } - } - } + let (conv_support, conv_function) = + shared::deduplicate_support(sorted_support, sorted_function)?; let sum_rv = RandomVariable { function: conv_function, diff --git a/rust/src/algorithms/shared.rs b/rust/src/algorithms/shared.rs index 00b4efc..6cb96d1 100644 --- a/rust/src/algorithms/shared.rs +++ b/rust/src/algorithms/shared.rs @@ -12,6 +12,10 @@ pub fn sort_by_support( return Err("support and function must be the same length".to_string()); } + if support.is_empty() { + return Err("support and function cannot be empty".to_string()); + } + let mut zipped_pairs: Vec<_> = support.into_iter().zip(function).collect(); zipped_pairs.sort_by(|a, b| { @@ -23,3 +27,39 @@ pub fn sort_by_support( let (sorted_support, sorted_function) = zipped_pairs.into_iter().unzip(); Ok((sorted_support, sorted_function)) } + +/// De-duplicates the support by combining probabilities for values that already +/// appear in the support +pub fn deduplicate_support( + support: Vec, + function: Vec, +) -> Result<(Vec, Vec), String> { + if support.len() != function.len() { + return Err("support and function must be the same length".to_string()); + } + + if support.is_empty() { + return Err("support and function cannot be empty".to_string()); + } + + let mut deduped_support = Vec::new(); + let mut deduped_function = Vec::new(); + + for (&s, &probability) in support.iter().zip(function.iter()) { + let support_index = deduped_support + .iter() + .position(|&x: &Number| x.to_f64() == s.to_f64()); + + match support_index { + Some(index) => { + deduped_function[index] += probability; + } + None => { + deduped_support.push(s); + deduped_function.push(probability); + } + } + } + + Ok((deduped_support, deduped_function)) +} diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index d5a6c08..640e68d 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -370,25 +370,8 @@ pub fn transform_discrete( let (sorted_support, sorted_function): (Vec, Vec) = raw_transformed_pairs.into_iter().unzip(); - // De-duplicate the support. If a a vlue appears multiple times in the - // support, combine the probabilities - let mut transformed_function = Vec::new(); - let mut transformed_support = Vec::new(); - for (&s, &probability) in sorted_support.iter().zip(sorted_function.iter()) { - let support_index = transformed_support - .iter() - .position(|&x: &Number| x.to_f64() == s.to_f64()); - - match support_index { - Some(index) => { - transformed_function[index] += probability; - } - None => { - transformed_support.push(s); - transformed_function.push(probability); - } - } - } + let (transformed_support, transformed_function) = + shared::deduplicate_support(sorted_support, sorted_function)?; let transformed_rv = RandomVariable { function: transformed_function, From 3c6d23e490a4966c366a5c882e67795847abf635 Mon Sep 17 00:00:00 2001 From: Matt Robinson Date: Mon, 30 Mar 2026 15:40:28 -0400 Subject: [PATCH 16/16] tests for shared module --- rust/src/algorithms/shared.rs | 80 +++++++++++++++++++++++++++++++++++ 1 file changed, 80 insertions(+) diff --git a/rust/src/algorithms/shared.rs b/rust/src/algorithms/shared.rs index 6cb96d1..91d51f9 100644 --- a/rust/src/algorithms/shared.rs +++ b/rust/src/algorithms/shared.rs @@ -63,3 +63,83 @@ pub fn deduplicate_support( Ok((deduped_support, deduped_function)) } + +#[cfg(test)] +mod tests { + use super::*; + use num_rational::Rational64; + + #[test] + fn sort_by_support_rejects_empty_inputs() { + let result = sort_by_support(vec![], vec![]); + + assert!(matches!( + result, + Err(msg) if msg == "support and function cannot be empty" + )); + } + + #[test] + fn deduplicate_support_combines_duplicate_values() { + let support = vec![ + Number::Integer(1), + Number::Integer(1), + Number::Integer(2), + Number::Integer(2), + Number::Integer(3), + ]; + let function = vec![ + Number::Rational(Rational64::new(1, 10)), + Number::Rational(Rational64::new(2, 10)), + Number::Rational(Rational64::new(1, 5)), + Number::Rational(Rational64::new(1, 10)), + Number::Rational(Rational64::new(3, 10)), + ]; + + let (deduped_support, deduped_function) = deduplicate_support(support, function).unwrap(); + + assert_eq!( + deduped_support, + vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)] + ); + assert_eq!( + deduped_function, + vec![ + Number::Rational(Rational64::new(3, 10)), + Number::Rational(Rational64::new(3, 10)), + Number::Rational(Rational64::new(3, 10)), + ] + ); + } + + #[test] + fn deduplicate_support_preserves_first_seen_order() { + let support = vec![ + Number::Integer(2), + Number::Integer(1), + Number::Integer(2), + Number::Integer(3), + ]; + let function = vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + ]; + + let (deduped_support, deduped_function) = deduplicate_support(support, function).unwrap(); + + assert_eq!( + deduped_support, + vec![Number::Integer(2), Number::Integer(1), Number::Integer(3)] + ); + assert_eq!( + deduped_function, + vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + } +}