diff --git a/Makefile b/Makefile index ab373bb..fb235f9 100644 --- a/Makefile +++ b/Makefile @@ -95,7 +95,7 @@ ipython: ## Runs iPython with --no-sync to ensure rust bindings are preserved .PHONY: check check: ## Run Ruff lint checks. - uv run ruff check applpy test_applpy + uv run --no-sync ruff check applpy test_applpy .PHONY: tidy tidy: ## Run Ruff autoformatter. diff --git a/applpy/rv.py b/applpy/rv.py index 3b4549c..1613a1a 100644 --- a/applpy/rv.py +++ b/applpy/rv.py @@ -330,13 +330,16 @@ def __add__(self, other): from .algebra import convolution from .transform import transform - if "RV" in other.__class__.__name__: + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv + other_fast_rv) + try: return convolution(self, other) except Exception: return convolution(other, self) - else: - raise RVError("Could not compute the convolution") # If the random variable is added to a constant, shift # the random variable if isinstance(other, (float, int)): @@ -357,6 +360,11 @@ def __radd__(self, other): 2. other: a constant or random variable Output: 1. A new random variable """ + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv.__radd__(other_fast_rv)) return self.__add__(other) def __sub__(self, other): @@ -375,7 +383,12 @@ def __sub__(self, other): from .algebra import convolution from .transform import transform - if "RV" in other.__class__.__name__: + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv - other_fast_rv) + gX = [[-x], [-oo, oo]] random_variable = transform(other, gX) return convolution(self, random_variable) @@ -396,6 +409,12 @@ def __rsub__(self, other): 2. other: a constant or random variable Output: 1. A new random variable """ + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv.__rsub__(other_fast_rv)) + # Perform an negative transformation of the random variable neg_self = -self # Add the two components @@ -417,13 +436,16 @@ def __mul__(self, other): from .algebra import product from .transform import transform - if "RV" in other.__class__.__name__: + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv * other_fast_rv) + try: return product(self, other) except Exception: return product(other, self) - else: - raise RVError("Could not compute the product") # If the random variable is multiplied by a constant, scale # the random variable if isinstance(other, (float, int)): @@ -441,6 +463,11 @@ def __rmul__(self, other): 2. other: a constant or random variable Output: 1. A new random variable """ + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv.__rmul__(other_fast_rv)) return self.__mul__(other) def __truediv__(self, other): @@ -459,7 +486,12 @@ def __truediv__(self, other): from .algebra import product from .transform import transform - if "RV" in other.__class__.__name__: + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv / other_fast_rv) + gX = [[1 / x, 1 / x], [-oo, 0, oo]] random_variable = transform(other, gX) return product(self, random_variable) @@ -480,6 +512,12 @@ def __rtruediv__(self, other): 2. other: a constant or random variable Output: 1. A new random variable """ + if isinstance(other, RV): + if self.is_discrete() and other.is_discrete(): + fast_rv = self.to_fast_rv() + other_fast_rv = other.to_fast_rv() + return RV.from_fast_rv(fast_rv.__rtruediv__(other_fast_rv)) + ## Invert the random variable from .transform import transform diff --git a/rust/src/algorithms/algebra.rs b/rust/src/algorithms/algebra.rs index e2edf94..7c2044c 100644 --- a/rust/src/algorithms/algebra.rs +++ b/rust/src/algorithms/algebra.rs @@ -2,6 +2,7 @@ use crate::algorithms::number::Number; use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +use crate::algorithms::shared; /// Computes the product of two independent discrete random variables /// @@ -86,39 +87,11 @@ pub fn product_discrete( } } - // Sort the multiplied support and function values - let mut raw_product_pairs: Vec<_> = raw_product_support - .into_iter() - .zip(raw_product_function) - .collect(); - raw_product_pairs.sort_by(|a, b| { - let first_value = a.0.to_f64(); - let second_value = b.0.to_f64(); - first_value.total_cmp(&second_value) - }); - let (sorted_support, sorted_function): (Vec, Vec) = - raw_product_pairs.into_iter().unzip(); - - // De-duplicate the support. If a value appears multiple times in the - // support, combine the probabilities - let mut product_function = Vec::new(); - let mut product_support = Vec::new(); - for (&s, &probability) in sorted_support.iter().zip(sorted_function.iter()) { - let support_index = product_support - .iter() - .position(|&x: &Number| x.to_f64() == s.to_f64()); - - match support_index { - Some(index) => { - product_function[index] += probability; - } - None => { - product_function.push(probability); - product_support.push(s); - } - } - } + shared::sort_by_support(raw_product_support, raw_product_function)?; + + let (product_support, product_function) = + shared::deduplicate_support(sorted_support, sorted_function)?; let product_rv = RandomVariable { function: product_function, @@ -211,38 +184,11 @@ pub fn convolution_discrete( } } - // Sorts the results by the support values - let mut raw_conv_pairs: Vec<_> = raw_conv_support - .into_iter() - .zip(raw_conv_function) - .collect(); - - raw_conv_pairs.sort_by(|a, b| { - let first_value = a.0.to_f64(); - let second_value = b.0.to_f64(); - first_value.total_cmp(&second_value) - }); - let (sorted_support, sorted_function): (Vec, Vec) = - raw_conv_pairs.into_iter().unzip(); - - // Remove redundant elements from the support - let mut conv_support = Vec::new(); - let mut conv_function = Vec::new(); - - for (&s, &f) in sorted_support.iter().zip(sorted_function.iter()) { - let support_index = conv_support.iter().position(|&x| x == s); - - match support_index { - Some(index) => { - conv_function[index] += f; - } - None => { - conv_support.push(s); - conv_function.push(f); - } - } - } + shared::sort_by_support(raw_conv_support, raw_conv_function)?; + + let (conv_support, conv_function) = + shared::deduplicate_support(sorted_support, sorted_function)?; let sum_rv = RandomVariable { function: conv_function, diff --git a/rust/src/algorithms/mod.rs b/rust/src/algorithms/mod.rs index cffed21..dece185 100644 --- a/rust/src/algorithms/mod.rs +++ b/rust/src/algorithms/mod.rs @@ -4,4 +4,5 @@ pub mod moments; pub mod number; pub mod order_stat; pub mod rv; +pub mod shared; pub mod transform; diff --git a/rust/src/algorithms/rv.rs b/rust/src/algorithms/rv.rs index f5a1f31..543e959 100644 --- a/rust/src/algorithms/rv.rs +++ b/rust/src/algorithms/rv.rs @@ -1,7 +1,7 @@ #![allow(dead_code)] use std::fmt; -use std::ops::{Add, AddAssign, Mul}; +use std::ops::{Add, AddAssign, Div, DivAssign, Mul, MulAssign, Sub, SubAssign}; use num_rational::Rational64; use num_traits::cast::ToPrimitive; @@ -10,6 +10,7 @@ use crate::algorithms::algebra; use crate::algorithms::conversion; use crate::algorithms::moments; use crate::algorithms::number::Number; +use crate::algorithms::transform; #[derive(Debug, Clone, PartialEq)] pub enum FunctionalForm { @@ -79,6 +80,169 @@ impl AddAssign for RandomVariable { } } +impl Div for RandomVariable { + type Output = Result; + + fn div(self, rhs: Self) -> Self::Output { + if rhs.support.iter().any(|value| match value { + Number::Float(x) => *x == 0.0, + Number::Integer(x) => *x == 0, + Number::Rational(x) => *x.numer() == 0, + }) { + return Err("cannot divide by a random variable with zero in its support".to_string()); + } + + let min_support = rhs + .support + .first() + .expect("failed to extract the first number"); + let max_support = rhs + .support + .last() + .expect("failed to extract the last number"); + let transformation = transform::Transformation { + mapping: |x| Number::Integer(1) / x, + min_support: *min_support, + max_support: *max_support, + }; + let inverse_rhs = transform::transform_discrete(&rhs, &[transformation])?; + let div_rv = algebra::product_discrete(&self, &inverse_rhs)?; + Ok(div_rv) + } +} + +impl DivAssign for RandomVariable { + fn div_assign(&mut self, rhs: Self) { + let div_rv = self + .clone() + .div(rhs) + .expect("failed to divide the random variables"); + *self = div_rv; + } +} + +/// Computes the difference of two independent discrete random variables. +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use num_rational::Rational64; +/// +/// let rv1 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(1), Number::Integer(2)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let rv2 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(2), Number::Integer(3)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let difference = (rv1 - rv2).unwrap(); +/// +/// assert_eq!( +/// difference.support, +/// vec![Number::Integer(-2), Number::Integer(-1), Number::Integer(0)] +/// ); +/// assert_eq!( +/// difference.function, +/// vec![ +/// Number::Rational(Rational64::new(1, 4)), +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 4)), +/// ] +/// ); +/// assert!(difference.verify_pdf(None).unwrap()); +/// ``` +impl Sub for RandomVariable { + type Output = Result; + + fn sub(self, rhs: Self) -> Self::Output { + let min_support = rhs + .support + .first() + .expect("failed to extract the first number"); + let max_support = rhs + .support + .last() + .expect("failed to extract the last number"); + let transformation = transform::Transformation { + mapping: |x| x * Number::Integer(-1), + min_support: *min_support, + max_support: *max_support, + }; + let negative_rhs = transform::transform_discrete(&rhs, &[transformation])?; + let sub_rv = algebra::convolution_discrete(&self, &negative_rhs)?; + Ok(sub_rv) + } +} + +/// Updates a random variable in place with the difference of two +/// independent discrete random variables. +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use num_rational::Rational64; +/// +/// let mut rv1 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(1), Number::Integer(2)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let rv2 = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 2)), +/// ], +/// support: vec![Number::Integer(2), Number::Integer(3)], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// rv1 -= rv2; +/// +/// assert_eq!( +/// rv1.support, +/// vec![Number::Integer(-2), Number::Integer(-1), Number::Integer(0)] +/// ); +/// assert_eq!( +/// rv1.function, +/// vec![ +/// Number::Rational(Rational64::new(1, 4)), +/// Number::Rational(Rational64::new(1, 2)), +/// Number::Rational(Rational64::new(1, 4)), +/// ] +/// ); +/// assert!(rv1.verify_pdf(None).unwrap()); +/// ``` +impl SubAssign for RandomVariable { + fn sub_assign(&mut self, rhs: Self) { + let sub_rv = self + .clone() + .sub(rhs) + .expect("failed to subtract the random variables"); + *self = sub_rv.clone(); + } +} + impl Mul for RandomVariable { type Output = Result; @@ -88,6 +252,16 @@ impl Mul for RandomVariable { } } +impl MulAssign for RandomVariable { + fn mul_assign(&mut self, rhs: Self) { + let product_rv = self + .clone() + .mul(rhs) + .expect("failed to multiply the random variables"); + *self = product_rv; + } +} + impl RandomVariable { pub fn verify_pdf(&self, tolerance: Option) -> Result { if self.functional_form != FunctionalForm::Pdf { @@ -402,6 +576,365 @@ mod tests { use super::*; use num_rational::Rational64; + #[test] + fn add_returns_sum_distribution() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + let result = (lhs + rhs).unwrap(); + + assert_eq!( + result.support, + vec![Number::Integer(3), Number::Integer(4), Number::Integer(5)] + ); + assert_eq!( + result.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(result.functional_form, FunctionalForm::Pdf)); + assert!(matches!(result.domain_type, DomainType::Discrete)); + } + + #[test] + fn add_assign_updates_random_variable_in_place() { + let mut lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + lhs += rhs; + + assert_eq!( + lhs.support, + vec![Number::Integer(3), Number::Integer(4), Number::Integer(5)] + ); + assert_eq!( + lhs.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(lhs.functional_form, FunctionalForm::Pdf)); + assert!(matches!(lhs.domain_type, DomainType::Discrete)); + } + + #[test] + fn div_returns_quotient_distribution() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(4)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![ + Number::Rational(Rational64::new(2, 1)), + Number::Rational(Rational64::new(4, 1)), + ], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + let result = (lhs / rhs).unwrap(); + + assert_eq!( + result.support, + vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 1)), + Number::Rational(Rational64::new(2, 1)), + ] + ); + assert_eq!( + result.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(result.functional_form, FunctionalForm::Pdf)); + assert!(matches!(result.domain_type, DomainType::Discrete)); + } + + #[test] + fn div_assign_updates_random_variable_in_place() { + let mut lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(4)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![ + Number::Rational(Rational64::new(2, 1)), + Number::Rational(Rational64::new(4, 1)), + ], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + lhs /= rhs; + + assert_eq!( + lhs.support, + vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 1)), + Number::Rational(Rational64::new(2, 1)), + ] + ); + assert_eq!( + lhs.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(lhs.functional_form, FunctionalForm::Pdf)); + assert!(matches!(lhs.domain_type, DomainType::Discrete)); + } + + #[test] + fn div_returns_error_when_rhs_support_contains_zero() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(0), Number::Integer(1)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = lhs.clone(); + + let err = (lhs / rhs).unwrap_err(); + + assert_eq!( + err, + "cannot divide by a random variable with zero in its support" + ); + } + + #[test] + fn sub_returns_difference_distribution() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + let result = (lhs - rhs).unwrap(); + + assert_eq!( + result.support, + vec![Number::Integer(-2), Number::Integer(-1), Number::Integer(0)] + ); + assert_eq!( + result.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(result.functional_form, FunctionalForm::Pdf)); + assert!(matches!(result.domain_type, DomainType::Discrete)); + } + + #[test] + fn sub_assign_updates_random_variable_in_place() { + let mut lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + lhs -= rhs; + + assert_eq!( + lhs.support, + vec![Number::Integer(-2), Number::Integer(-1), Number::Integer(0)] + ); + assert_eq!( + lhs.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(lhs.functional_form, FunctionalForm::Pdf)); + assert!(matches!(lhs.domain_type, DomainType::Discrete)); + } + + #[test] + fn mul_returns_product_distribution() { + let lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + let result = (lhs * rhs).unwrap(); + + assert_eq!( + result.support, + vec![ + Number::Integer(2), + Number::Integer(3), + Number::Integer(4), + Number::Integer(6), + ] + ); + assert_eq!( + result.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(result.functional_form, FunctionalForm::Pdf)); + assert!(matches!(result.domain_type, DomainType::Discrete)); + } + + #[test] + fn mul_assign_updates_random_variable_in_place() { + let mut lhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(1), Number::Integer(2)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + let rhs = RandomVariable { + function: vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 2)), + ], + support: vec![Number::Integer(2), Number::Integer(3)], + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + + lhs *= rhs; + + assert_eq!( + lhs.support, + vec![ + Number::Integer(2), + Number::Integer(3), + Number::Integer(4), + Number::Integer(6), + ] + ); + assert_eq!( + lhs.function, + vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + assert!(matches!(lhs.functional_form, FunctionalForm::Pdf)); + assert!(matches!(lhs.domain_type, DomainType::Discrete)); + } + #[test] fn verify_pdf_returns_err_for_non_pdf_functional_form() { let rv = RandomVariable { diff --git a/rust/src/algorithms/shared.rs b/rust/src/algorithms/shared.rs new file mode 100644 index 0000000..91d51f9 --- /dev/null +++ b/rust/src/algorithms/shared.rs @@ -0,0 +1,145 @@ +#![allow(dead_code)] + +use crate::algorithms::number::Number; + +/// Sorts the support of a random variable, while keeping the function aligned +/// with the support +pub fn sort_by_support( + support: Vec, + function: Vec, +) -> Result<(Vec, Vec), String> { + if support.len() != function.len() { + return Err("support and function must be the same length".to_string()); + } + + if support.is_empty() { + return Err("support and function cannot be empty".to_string()); + } + + let mut zipped_pairs: Vec<_> = support.into_iter().zip(function).collect(); + + zipped_pairs.sort_by(|a, b| { + let first_value = a.0.to_f64(); + let second_value = b.0.to_f64(); + first_value.total_cmp(&second_value) + }); + + let (sorted_support, sorted_function) = zipped_pairs.into_iter().unzip(); + Ok((sorted_support, sorted_function)) +} + +/// De-duplicates the support by combining probabilities for values that already +/// appear in the support +pub fn deduplicate_support( + support: Vec, + function: Vec, +) -> Result<(Vec, Vec), String> { + if support.len() != function.len() { + return Err("support and function must be the same length".to_string()); + } + + if support.is_empty() { + return Err("support and function cannot be empty".to_string()); + } + + let mut deduped_support = Vec::new(); + let mut deduped_function = Vec::new(); + + for (&s, &probability) in support.iter().zip(function.iter()) { + let support_index = deduped_support + .iter() + .position(|&x: &Number| x.to_f64() == s.to_f64()); + + match support_index { + Some(index) => { + deduped_function[index] += probability; + } + None => { + deduped_support.push(s); + deduped_function.push(probability); + } + } + } + + Ok((deduped_support, deduped_function)) +} + +#[cfg(test)] +mod tests { + use super::*; + use num_rational::Rational64; + + #[test] + fn sort_by_support_rejects_empty_inputs() { + let result = sort_by_support(vec![], vec![]); + + assert!(matches!( + result, + Err(msg) if msg == "support and function cannot be empty" + )); + } + + #[test] + fn deduplicate_support_combines_duplicate_values() { + let support = vec![ + Number::Integer(1), + Number::Integer(1), + Number::Integer(2), + Number::Integer(2), + Number::Integer(3), + ]; + let function = vec![ + Number::Rational(Rational64::new(1, 10)), + Number::Rational(Rational64::new(2, 10)), + Number::Rational(Rational64::new(1, 5)), + Number::Rational(Rational64::new(1, 10)), + Number::Rational(Rational64::new(3, 10)), + ]; + + let (deduped_support, deduped_function) = deduplicate_support(support, function).unwrap(); + + assert_eq!( + deduped_support, + vec![Number::Integer(1), Number::Integer(2), Number::Integer(3)] + ); + assert_eq!( + deduped_function, + vec![ + Number::Rational(Rational64::new(3, 10)), + Number::Rational(Rational64::new(3, 10)), + Number::Rational(Rational64::new(3, 10)), + ] + ); + } + + #[test] + fn deduplicate_support_preserves_first_seen_order() { + let support = vec![ + Number::Integer(2), + Number::Integer(1), + Number::Integer(2), + Number::Integer(3), + ]; + let function = vec![ + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + ]; + + let (deduped_support, deduped_function) = deduplicate_support(support, function).unwrap(); + + assert_eq!( + deduped_support, + vec![Number::Integer(2), Number::Integer(1), Number::Integer(3)] + ); + assert_eq!( + deduped_function, + vec![ + Number::Rational(Rational64::new(1, 2)), + Number::Rational(Rational64::new(1, 4)), + Number::Rational(Rational64::new(1, 4)), + ] + ); + } +} diff --git a/rust/src/algorithms/transform.rs b/rust/src/algorithms/transform.rs index a65a844..640e68d 100644 --- a/rust/src/algorithms/transform.rs +++ b/rust/src/algorithms/transform.rs @@ -2,6 +2,7 @@ use crate::algorithms::number::Number; use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +use crate::algorithms::shared; /// Truncates a discrete random variable by cutting off a portion of the support /// and normalizing total probability of the distribution to 1. @@ -204,26 +205,182 @@ pub fn mixture_discrete( } } - let mut raw_mixture_pair: Vec<_> = raw_mixture_support - .into_iter() - .zip(raw_mixture_function) - .collect(); + let (mixture_support, mixture_function) = + shared::sort_by_support(raw_mixture_support, raw_mixture_function)?; - raw_mixture_pair.sort_by(|a, b| { + let mix_rv = RandomVariable { + function: mixture_function, + support: mixture_support, + functional_form: FunctionalForm::Pdf, + domain_type: DomainType::Discrete, + }; + Ok(mix_rv) +} + +pub struct Transformation { + pub mapping: fn(Number) -> Number, + pub min_support: Number, + pub max_support: Number, +} + +/// Computes a transformation of a discrete random variable +/// +/// # Arugments +/// * `random_variable`- the random variable to transform +/// * `transformation` - the transformation to apply to the random variable +/// +/// # Returns +/// * `transformed_rv` - the transformed random variable +/// +/// # Examples +/// ``` +/// use applpy_rust::algorithms::number::Number; +/// use applpy_rust::algorithms::rv::{DomainType, FunctionalForm, RandomVariable}; +/// use applpy_rust::algorithms::transform::{transform_discrete, Transformation}; +/// use num_rational::Rational64; +/// +/// let rv = RandomVariable { +/// function: vec![ +/// Number::Rational(Rational64::new(1, 10)), +/// Number::Rational(Rational64::new(2, 10)), +/// Number::Rational(Rational64::new(3, 10)), +/// Number::Rational(Rational64::new(4, 10)), +/// ], +/// support: vec![ +/// Number::Integer(1), +/// Number::Integer(2), +/// Number::Integer(3), +/// Number::Integer(4), +/// ], +/// functional_form: FunctionalForm::Pdf, +/// domain_type: DomainType::Discrete, +/// }; +/// +/// let transformed = transform_discrete( +/// &rv, +/// &[ +/// Transformation { +/// mapping: |x| x * Number::Integer(2), +/// min_support: Number::Integer(1), +/// max_support: Number::Integer(3), +/// }, +/// Transformation { +/// mapping: |x| x + Number::Integer(10), +/// min_support: Number::Integer(3), +/// max_support: Number::Integer(5), +/// }, +/// ], +/// ) +/// .unwrap(); +/// +/// assert_eq!( +/// transformed.support, +/// vec![ +/// Number::Integer(2), +/// Number::Integer(4), +/// Number::Integer(6), +/// Number::Integer(14), +/// ] +/// ); +/// assert_eq!( +/// transformed.function, +/// vec![ +/// Number::Rational(Rational64::new(1, 10)), +/// Number::Rational(Rational64::new(2, 10)), +/// Number::Rational(Rational64::new(3, 10)), +/// Number::Rational(Rational64::new(4, 10)), +/// ] +/// ); +/// ``` +pub fn transform_discrete( + random_variable: &RandomVariable, + transformations: &[Transformation], +) -> Result { + let pdf_random_variable = random_variable.to_pdf()?; + let support = pdf_random_variable.support; + let function = pdf_random_variable.function; + + if transformations.is_empty() { + return Err("at least one transformation is required".to_string()); + } + + // Validate that each transformation range is increasing. + for transformation in transformations { + if transformation.min_support >= transformation.max_support { + return Err( + "the max range of the transformation must exceeed the min range".to_string(), + ); + } + } + + // Validate that the transformations are adjacent. + for window in transformations.windows(2) { + let current_transformation = &window[0]; + let next_transformation = &window[1]; + + if current_transformation.max_support != next_transformation.min_support { + return Err("the transformation ranges must be adjacent".to_string()); + } + } + + // Validate that the transformations cover the support + let lowest_support = support.first().expect("unable to extract lowest support"); + let lowest_transform = transformations + .first() + .expect("unable to extract lowest transform") + .min_support; + if *lowest_support < lowest_transform { + return Err( + "the minimum transformation support is higher than the minimum rv support".to_string(), + ); + } + + let highest_support = support.last().expect("unable to extract highest support"); + let highest_transform = transformations + .last() + .expect("unable to extract highest transform") + .max_support; + if *highest_support > highest_transform { + return Err( + "the maxium transformation support is lower than the maximum rv support".to_string(), + ); + } + + // Compute the transformed support and preserve the original probabilities. + let mut raw_transformed_pairs = Vec::new(); + for (&s, &probability) in support.iter().zip(function.iter()) { + for transformation in transformations { + let mapping = &transformation.mapping; + if s >= transformation.min_support && s <= transformation.max_support { + let raw_transformed_support_value = mapping(s); + raw_transformed_pairs.push((raw_transformed_support_value, probability)); + // Break to avoid inserting multiple transformed entries for + // any given support value + break; + } + } + } + + raw_transformed_pairs.sort_by(|a, b| { let first_value = a.0.to_f64(); let second_value = b.0.to_f64(); first_value.total_cmp(&second_value) }); - let (mixture_support, mixture_function) = raw_mixture_pair.into_iter().unzip(); + let (sorted_support, sorted_function): (Vec, Vec) = + raw_transformed_pairs.into_iter().unzip(); - let mix_rv = RandomVariable { - function: mixture_function, - support: mixture_support, + let (transformed_support, transformed_function) = + shared::deduplicate_support(sorted_support, sorted_function)?; + + let transformed_rv = RandomVariable { + function: transformed_function, + support: transformed_support, functional_form: FunctionalForm::Pdf, domain_type: DomainType::Discrete, }; - Ok(mix_rv) + + Ok(transformed_rv) } #[cfg(test)] @@ -355,4 +512,137 @@ mod tests { Err(msg) if msg == "the mix weights must sum to one" )); } + + #[test] + fn transform_discrete_uses_first_range_for_shared_boundary() { + let rv = sample_discrete_rv(); + + let transformed = transform_discrete( + &rv, + &[ + Transformation { + mapping: |x| x * Number::Integer(2), + min_support: Number::Integer(1), + max_support: Number::Integer(3), + }, + Transformation { + mapping: |x| x + Number::Integer(10), + min_support: Number::Integer(3), + max_support: Number::Integer(5), + }, + ], + ) + .unwrap(); + + assert_eq!( + transformed.support, + vec![ + Number::Integer(2), + Number::Integer(4), + Number::Integer(6), + Number::Integer(14), + ] + ); + assert_eq!( + transformed.function, + vec![ + Number::Rational(Rational64::new(1, 10)), + Number::Rational(Rational64::new(2, 10)), + Number::Rational(Rational64::new(3, 10)), + Number::Rational(Rational64::new(4, 10)), + ] + ); + assert!(matches!(transformed.functional_form, FunctionalForm::Pdf)); + assert!(matches!(transformed.domain_type, DomainType::Discrete)); + } + + #[test] + fn transform_discrete_returns_error_for_empty_transformations() { + let rv = sample_discrete_rv(); + let result = transform_discrete(&rv, &[]); + + assert!(matches!( + result, + Err(msg) if msg == "at least one transformation is required" + )); + } + + #[test] + fn transform_discrete_returns_error_for_single_invalid_range() { + let rv = sample_discrete_rv(); + let result = transform_discrete( + &rv, + &[Transformation { + mapping: |x| x, + min_support: Number::Integer(1), + max_support: Number::Integer(1), + }], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the max range of the transformation must exceeed the min range" + )); + } + + #[test] + fn transform_discrete_returns_error_when_ranges_are_not_adjacent() { + let rv = sample_discrete_rv(); + let result = transform_discrete( + &rv, + &[ + Transformation { + mapping: |x| x, + min_support: Number::Integer(1), + max_support: Number::Integer(2), + }, + Transformation { + mapping: |x| x, + min_support: Number::Integer(3), + max_support: Number::Integer(5), + }, + ], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the transformation ranges must be adjacent" + )); + } + + #[test] + fn transform_discrete_returns_error_when_transformations_do_not_cover_min_support() { + let rv = sample_discrete_rv(); + let result = transform_discrete( + &rv, + &[Transformation { + mapping: |x| x, + min_support: Number::Integer(2), + max_support: Number::Integer(5), + }], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the minimum transformation support is higher than the minimum rv support" + )); + } + + #[test] + fn transform_discrete_returns_error_when_transformations_do_not_cover_max_support() { + let rv = sample_discrete_rv(); + let result = transform_discrete( + &rv, + &[Transformation { + mapping: |x| x, + min_support: Number::Integer(1), + max_support: Number::Integer(3), + }], + ); + + assert!(matches!( + result, + Err(msg) if msg == "the maxium transformation support is lower than the maximum rv support" + )); + } } diff --git a/rust/src/python/api.rs b/rust/src/python/api.rs index 77b050d..19ed581 100644 --- a/rust/src/python/api.rs +++ b/rust/src/python/api.rs @@ -291,6 +291,54 @@ impl FastRV { } } + pub fn __radd__(&self, lhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let lhs_rv = lhs.inner.clone(); + + let sum_rv = lhs_rv + self_rv; + + match sum_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + + pub fn __sub__(&self, rhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let rhs_rv = rhs.inner.clone(); + + let difference_rv = self_rv - rhs_rv; + + match difference_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + + pub fn __rsub__(&self, lhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let lhs_rv = lhs.inner.clone(); + + let difference_rv = lhs_rv - self_rv; + + match difference_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + pub fn __mul__(&self, rhs: FastRV) -> PyResult { let self_rv = self.inner.clone(); let rhs_rv = rhs.inner.clone(); @@ -307,6 +355,54 @@ impl FastRV { } } + pub fn __rmul__(&self, lhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let lhs_rv = lhs.inner.clone(); + + let product_rv = lhs_rv * self_rv; + + match product_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + + pub fn __truediv__(&self, rhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let rhs_rv = rhs.inner.clone(); + + let quotient_rv = self_rv / rhs_rv; + + match quotient_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + + pub fn __rtruediv__(&self, lhs: FastRV) -> PyResult { + let self_rv = self.inner.clone(); + let lhs_rv = lhs.inner.clone(); + + let quotient_rv = lhs_rv / self_rv; + + match quotient_rv { + Ok(rv) => { + let fast_rv = + FastRV::new(rv.function, rv.support, rv.functional_form, rv.domain_type); + Ok(fast_rv) + } + Err(s) => Err(PyErr::new::(s)), + } + } + #[getter] pub fn function(&self) -> Vec { self.inner.function.clone() diff --git a/test_applpy/unit/test_rv.py b/test_applpy/unit/test_rv.py index 07fa139..42a59b2 100644 --- a/test_applpy/unit/test_rv.py +++ b/test_applpy/unit/test_rv.py @@ -151,7 +151,7 @@ def test_operator_overloads_for_scalars_and_rvs(): with pytest.raises(NotImplementedError): abs(rv) - with pytest.raises(TypeError): + with pytest.raises(ValueError): discrete / discrete with pytest.raises(RVError, match="integer value"):