Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ ipython: ## Runs iPython with --no-sync to ensure rust bindings are preserved

.PHONY: check
check: ## Run Ruff lint checks.
uv run ruff check applpy test_applpy
uv run --no-sync ruff check applpy test_applpy

.PHONY: tidy
tidy: ## Run Ruff autoformatter.
Expand Down
54 changes: 46 additions & 8 deletions applpy/rv.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,13 +330,16 @@ def __add__(self, other):
from .algebra import convolution
from .transform import transform

if "RV" in other.__class__.__name__:
if isinstance(other, RV):
if self.is_discrete() and other.is_discrete():
fast_rv = self.to_fast_rv()
other_fast_rv = other.to_fast_rv()
return RV.from_fast_rv(fast_rv + other_fast_rv)

try:
return convolution(self, other)
except Exception:
return convolution(other, self)
else:
raise RVError("Could not compute the convolution")
# If the random variable is added to a constant, shift
# the random variable
if isinstance(other, (float, int)):
Expand All @@ -357,6 +360,11 @@ def __radd__(self, other):
2. other: a constant or random variable
Output: 1. A new random variable
"""
if isinstance(other, RV):
if self.is_discrete() and other.is_discrete():
fast_rv = self.to_fast_rv()
other_fast_rv = other.to_fast_rv()
return RV.from_fast_rv(fast_rv.__radd__(other_fast_rv))
return self.__add__(other)

def __sub__(self, other):
Expand All @@ -375,7 +383,12 @@ def __sub__(self, other):
from .algebra import convolution
from .transform import transform

if "RV" in other.__class__.__name__:
if isinstance(other, RV):
if self.is_discrete() and other.is_discrete():
fast_rv = self.to_fast_rv()
other_fast_rv = other.to_fast_rv()
return RV.from_fast_rv(fast_rv - other_fast_rv)

gX = [[-x], [-oo, oo]]
random_variable = transform(other, gX)
return convolution(self, random_variable)
Expand All @@ -396,6 +409,12 @@ def __rsub__(self, other):
2. other: a constant or random variable
Output: 1. A new random variable
"""
if isinstance(other, RV):
if self.is_discrete() and other.is_discrete():
fast_rv = self.to_fast_rv()
other_fast_rv = other.to_fast_rv()
return RV.from_fast_rv(fast_rv.__rsub__(other_fast_rv))

# Perform an negative transformation of the random variable
neg_self = -self
# Add the two components
Expand All @@ -417,13 +436,16 @@ def __mul__(self, other):
from .algebra import product
from .transform import transform

if "RV" in other.__class__.__name__:
if isinstance(other, RV):
if self.is_discrete() and other.is_discrete():
fast_rv = self.to_fast_rv()
other_fast_rv = other.to_fast_rv()
return RV.from_fast_rv(fast_rv * other_fast_rv)

try:
return product(self, other)
except Exception:
return product(other, self)
else:
raise RVError("Could not compute the product")
# If the random variable is multiplied by a constant, scale
# the random variable
if isinstance(other, (float, int)):
Expand All @@ -441,6 +463,11 @@ def __rmul__(self, other):
2. other: a constant or random variable
Output: 1. A new random variable
"""
if isinstance(other, RV):
if self.is_discrete() and other.is_discrete():
fast_rv = self.to_fast_rv()
other_fast_rv = other.to_fast_rv()
return RV.from_fast_rv(fast_rv.__rmul__(other_fast_rv))
return self.__mul__(other)

def __truediv__(self, other):
Expand All @@ -459,7 +486,12 @@ def __truediv__(self, other):
from .algebra import product
from .transform import transform

if "RV" in other.__class__.__name__:
if isinstance(other, RV):
if self.is_discrete() and other.is_discrete():
fast_rv = self.to_fast_rv()
other_fast_rv = other.to_fast_rv()
return RV.from_fast_rv(fast_rv / other_fast_rv)

gX = [[1 / x, 1 / x], [-oo, 0, oo]]
random_variable = transform(other, gX)
return product(self, random_variable)
Expand All @@ -480,6 +512,12 @@ def __rtruediv__(self, other):
2. other: a constant or random variable
Output: 1. A new random variable
"""
if isinstance(other, RV):
if self.is_discrete() and other.is_discrete():
fast_rv = self.to_fast_rv()
other_fast_rv = other.to_fast_rv()
return RV.from_fast_rv(fast_rv.__rtruediv__(other_fast_rv))

## Invert the random variable
from .transform import transform

Expand Down
72 changes: 9 additions & 63 deletions rust/src/algorithms/algebra.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

use crate::algorithms::number::Number;
use crate::algorithms::rv::{DomainType, FunctionalForm, RandomVariable};
use crate::algorithms::shared;

/// Computes the product of two independent discrete random variables
///
Expand Down Expand Up @@ -86,39 +87,11 @@ pub fn product_discrete(
}
}

// Sort the multiplied support and function values
let mut raw_product_pairs: Vec<_> = raw_product_support
.into_iter()
.zip(raw_product_function)
.collect();
raw_product_pairs.sort_by(|a, b| {
let first_value = a.0.to_f64();
let second_value = b.0.to_f64();
first_value.total_cmp(&second_value)
});

let (sorted_support, sorted_function): (Vec<Number>, Vec<Number>) =
raw_product_pairs.into_iter().unzip();

// De-duplicate the support. If a value appears multiple times in the
// support, combine the probabilities
let mut product_function = Vec::new();
let mut product_support = Vec::new();
for (&s, &probability) in sorted_support.iter().zip(sorted_function.iter()) {
let support_index = product_support
.iter()
.position(|&x: &Number| x.to_f64() == s.to_f64());

match support_index {
Some(index) => {
product_function[index] += probability;
}
None => {
product_function.push(probability);
product_support.push(s);
}
}
}
shared::sort_by_support(raw_product_support, raw_product_function)?;

let (product_support, product_function) =
shared::deduplicate_support(sorted_support, sorted_function)?;

let product_rv = RandomVariable {
function: product_function,
Expand Down Expand Up @@ -211,38 +184,11 @@ pub fn convolution_discrete(
}
}

// Sorts the results by the support values
let mut raw_conv_pairs: Vec<_> = raw_conv_support
.into_iter()
.zip(raw_conv_function)
.collect();

raw_conv_pairs.sort_by(|a, b| {
let first_value = a.0.to_f64();
let second_value = b.0.to_f64();
first_value.total_cmp(&second_value)
});

let (sorted_support, sorted_function): (Vec<Number>, Vec<Number>) =
raw_conv_pairs.into_iter().unzip();

// Remove redundant elements from the support
let mut conv_support = Vec::new();
let mut conv_function = Vec::new();

for (&s, &f) in sorted_support.iter().zip(sorted_function.iter()) {
let support_index = conv_support.iter().position(|&x| x == s);

match support_index {
Some(index) => {
conv_function[index] += f;
}
None => {
conv_support.push(s);
conv_function.push(f);
}
}
}
shared::sort_by_support(raw_conv_support, raw_conv_function)?;

let (conv_support, conv_function) =
shared::deduplicate_support(sorted_support, sorted_function)?;

let sum_rv = RandomVariable {
function: conv_function,
Expand Down
1 change: 1 addition & 0 deletions rust/src/algorithms/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ pub mod moments;
pub mod number;
pub mod order_stat;
pub mod rv;
pub mod shared;
pub mod transform;
Loading
Loading