From a9b7e2e320575962b7a8a5b43f68f35403185015 Mon Sep 17 00:00:00 2001 From: ArizmendiWan <2311602492@qq.com> Date: Thu, 19 Mar 2026 17:44:21 -0500 Subject: [PATCH] Add KD-tree conflict detection benchmark slice Signed-off-by: ArizmendiWan <2311602492@qq.com> --- .gitignore | 1 + mapf/Cargo.toml | 6 + mapf/benches/conflict_detection.rs | 217 ++++++++++ mapf/examples/conflict_detection_report.rs | 227 ++++++++++ mapf/src/motion/environment.rs | 8 + mapf/src/negotiation/mod.rs | 473 +++++++++++++++++++-- notes/conflict_detection_benchmark.md | 45 ++ 7 files changed, 949 insertions(+), 28 deletions(-) create mode 100644 mapf/benches/conflict_detection.rs create mode 100644 mapf/examples/conflict_detection_report.rs create mode 100644 notes/conflict_detection_benchmark.md diff --git a/.gitignore b/.gitignore index 96ef6c0..61b3e2f 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ /target Cargo.lock +.cargo_home/ diff --git a/mapf/Cargo.toml b/mapf/Cargo.toml index 6920707..3084115 100644 --- a/mapf/Cargo.toml +++ b/mapf/Cargo.toml @@ -27,6 +27,12 @@ smallvec = "1.10" serde = { version="1.0", features = ["derive"] } serde_yaml = "0.9" slotmap = "1.0" +kdtree = "0.7" [dev-dependencies] approx = "0.5" +criterion = { version = "0.5", default-features = false, features = ["cargo_bench_support"] } + +[[bench]] +name = "conflict_detection" +harness = false diff --git a/mapf/benches/conflict_detection.rs b/mapf/benches/conflict_detection.rs new file mode 100644 index 0000000..8baa7a0 --- /dev/null +++ b/mapf/benches/conflict_detection.rs @@ -0,0 +1,217 @@ +use criterion::{black_box, criterion_group, criterion_main, BenchmarkId, Criterion}; +use mapf::{ + algorithm::path::{DecisionPoint, MetaTrajectory}, + domain::Cost, + graph::occupancy::Cell, + motion::CircularProfile, + motion::{r2::Positioned, se2::WaypointSE2, Trajectory}, + negotiation::{detect_conflicts_for_proposals, ConflictDetectionAlgorithm, Proposal}, + premade::StateSippSE2, +}; +use std::{collections::HashMap, f64::consts::PI}; + +#[derive(Clone, Copy)] +enum SyntheticScenarioFamily { + LowOverlap, + MediumOverlap, + HighOverlap, +} + +fn bench_conflict_detection(c: &mut Criterion) { + let mut group = c.benchmark_group("conflict_detection"); + + for family in [ + SyntheticScenarioFamily::LowOverlap, + SyntheticScenarioFamily::MediumOverlap, + SyntheticScenarioFamily::HighOverlap, + ] { + for (agents, waypoints) in [(16, 16), (32, 24), (64, 32)] { + let case = build_synthetic_proposals(agents, waypoints, family); + let label = format!("{}-a{}-w{}", family_name(family), agents, waypoints); + + group.bench_with_input( + BenchmarkId::new("baseline", &label), + &case, + |b, (proposals, profiles)| { + b.iter(|| { + black_box(detect_conflicts_for_proposals( + proposals, + profiles, + ConflictDetectionAlgorithm::Baseline, + )) + }); + }, + ); + + group.bench_with_input( + BenchmarkId::new("kdtree", &label), + &case, + |b, (proposals, profiles)| { + b.iter(|| { + black_box(detect_conflicts_for_proposals( + proposals, + profiles, + ConflictDetectionAlgorithm::KdTree, + )) + }); + }, + ); + } + } + + group.finish(); +} + +criterion_group!(benches, bench_conflict_detection); +criterion_main!(benches); + +fn family_name(family: SyntheticScenarioFamily) -> &'static str { + match family { + SyntheticScenarioFamily::LowOverlap => "low", + SyntheticScenarioFamily::MediumOverlap => "medium", + SyntheticScenarioFamily::HighOverlap => "high", + } +} + +fn build_synthetic_proposals( + agent_count: usize, + waypoint_count: usize, + family: SyntheticScenarioFamily, +) -> (HashMap, Vec) { + let mut proposals = HashMap::new(); + let profiles = vec![CircularProfile::new(0.35, 0.0, 0.0).unwrap(); agent_count]; + + for agent in 0..agent_count { + let control_points = control_points_for_family(agent, agent_count, family); + let start_time = start_time_for_family(agent, family); + let waypoints = sample_polyline( + &control_points, + waypoint_count, + start_time, + start_time + 20.0, + ); + let trajectory = Trajectory::from_iter(waypoints.iter().copied()).unwrap(); + + let states: Vec<_> = waypoints + .iter() + .copied() + .map(|waypoint| StateSippSE2::new(Cell::from_point(waypoint.point(), 1.0), waypoint)) + .collect(); + + let meta = MetaTrajectory { + trajectory, + decision_points: states + .iter() + .cloned() + .enumerate() + .map(|(index, state)| DecisionPoint { index, state }) + .collect(), + initial_state: states.first().cloned().unwrap(), + final_state: states.last().cloned().unwrap(), + }; + + proposals.insert( + agent, + Proposal { + meta, + cost: Cost(agent as f64), + }, + ); + } + + (proposals, profiles) +} + +fn control_points_for_family( + agent: usize, + agent_count: usize, + family: SyntheticScenarioFamily, +) -> Vec<(f64, f64)> { + match family { + SyntheticScenarioFamily::LowOverlap => { + let y = (agent as f64 - agent_count as f64 / 2.0) * 4.0; + vec![(-12.0, y), (12.0, y)] + } + SyntheticScenarioFamily::MediumOverlap => { + let lane = agent % 6; + let offset = lane as f64 - 2.5; + if agent % 2 == 0 { + let y = offset * 1.5; + vec![(-12.0, y), (-2.0, y), (2.0, y), (12.0, y)] + } else { + let x = offset * 1.5; + vec![(x, -12.0), (x, -2.0), (x, 2.0), (x, 12.0)] + } + } + SyntheticScenarioFamily::HighOverlap => { + let theta = 2.0 * PI * agent as f64 / agent_count as f64; + let (sin_t, cos_t) = theta.sin_cos(); + vec![ + (10.0 * cos_t, 10.0 * sin_t), + (2.0 * cos_t, 2.0 * sin_t), + (0.0, 0.0), + (-2.0 * cos_t, -2.0 * sin_t), + (-10.0 * cos_t, -10.0 * sin_t), + ] + } + } +} + +fn start_time_for_family(agent: usize, family: SyntheticScenarioFamily) -> f64 { + match family { + SyntheticScenarioFamily::LowOverlap => agent as f64 * 0.05, + SyntheticScenarioFamily::MediumOverlap => (agent % 4) as f64 * 0.15, + SyntheticScenarioFamily::HighOverlap => (agent % 8) as f64 * 0.05, + } +} + +fn sample_polyline( + control_points: &[(f64, f64)], + waypoint_count: usize, + start_time: f64, + finish_time: f64, +) -> Vec { + let mut lengths = Vec::new(); + let mut cumulative = Vec::new(); + let mut total_length = 0.0; + + for window in control_points.windows(2) { + let dx = window[1].0 - window[0].0; + let dy = window[1].1 - window[0].1; + let length = (dx * dx + dy * dy).sqrt(); + lengths.push(length); + cumulative.push(total_length); + total_length += length.max(1e-9); + } + + let mut waypoints = Vec::with_capacity(waypoint_count); + for index in 0..waypoint_count { + let progress = if waypoint_count == 1 { + 0.0 + } else { + index as f64 / (waypoint_count - 1) as f64 + }; + let target_length = total_length * progress; + + let mut segment = lengths.len() - 1; + for (candidate, start_length) in cumulative.iter().enumerate() { + let end_length = *start_length + lengths[candidate]; + if target_length <= end_length || candidate == lengths.len() - 1 { + segment = candidate; + break; + } + } + + let seg_start = control_points[segment]; + let seg_end = control_points[segment + 1]; + let seg_length = lengths[segment].max(1e-9); + let local_progress = (target_length - cumulative[segment]) / seg_length; + let x = seg_start.0 + (seg_end.0 - seg_start.0) * local_progress; + let y = seg_start.1 + (seg_end.1 - seg_start.1) * local_progress; + let yaw = (seg_end.1 - seg_start.1).atan2(seg_end.0 - seg_start.0); + let time = start_time + (finish_time - start_time) * progress; + waypoints.push(WaypointSE2::new_f64(time, x, y, yaw)); + } + + waypoints +} diff --git a/mapf/examples/conflict_detection_report.rs b/mapf/examples/conflict_detection_report.rs new file mode 100644 index 0000000..f229216 --- /dev/null +++ b/mapf/examples/conflict_detection_report.rs @@ -0,0 +1,227 @@ +use mapf::{ + algorithm::path::{DecisionPoint, MetaTrajectory}, + domain::Cost, + graph::occupancy::Cell, + motion::CircularProfile, + motion::{r2::Positioned, se2::WaypointSE2, Trajectory}, + negotiation::{ + detect_conflicts_for_proposals, ConflictDetectionAlgorithm, ConflictDetectionReport, + Proposal, + }, + premade::StateSippSE2, +}; +use std::{collections::HashMap, f64::consts::PI, time::Instant}; + +#[derive(Clone, Copy)] +enum SyntheticScenarioFamily { + LowOverlap, + MediumOverlap, + HighOverlap, +} + +fn main() { + println!( + "family,agents,waypoints,baseline_ms,kdtree_ms,baseline_pairs,kdtree_pairs,bbox_pairs_baseline,bbox_pairs_kdtree,conflicts" + ); + + for family in [ + SyntheticScenarioFamily::LowOverlap, + SyntheticScenarioFamily::MediumOverlap, + SyntheticScenarioFamily::HighOverlap, + ] { + for (agents, waypoints) in [(16, 16), (32, 24), (64, 32)] { + let (proposals, profiles) = build_synthetic_proposals(agents, waypoints, family); + + let started = Instant::now(); + let baseline = detect_conflicts_for_proposals( + &proposals, + &profiles, + ConflictDetectionAlgorithm::Baseline, + ); + let baseline_elapsed = started.elapsed(); + + let started = Instant::now(); + let kdtree = detect_conflicts_for_proposals( + &proposals, + &profiles, + ConflictDetectionAlgorithm::KdTree, + ); + let kdtree_elapsed = started.elapsed(); + + assert_eq!(conflict_signature(&baseline), conflict_signature(&kdtree)); + + println!( + "{},{},{},{:.3},{:.3},{},{},{},{},{}", + family_name(family), + agents, + waypoints, + baseline_elapsed.as_secs_f64() * 1000.0, + kdtree_elapsed.as_secs_f64() * 1000.0, + baseline.stats.pair_enumerations, + kdtree.stats.pair_enumerations, + baseline.stats.bbox_candidate_pairs, + kdtree.stats.bbox_candidate_pairs, + baseline.stats.conflicts_found + ); + } + } +} + +fn family_name(family: SyntheticScenarioFamily) -> &'static str { + match family { + SyntheticScenarioFamily::LowOverlap => "low", + SyntheticScenarioFamily::MediumOverlap => "medium", + SyntheticScenarioFamily::HighOverlap => "high", + } +} + +fn conflict_signature(report: &ConflictDetectionReport) -> Vec { + report + .conflicts + .iter() + .map(|conflict| format!("{conflict:?}")) + .collect() +} + +fn build_synthetic_proposals( + agent_count: usize, + waypoint_count: usize, + family: SyntheticScenarioFamily, +) -> (HashMap, Vec) { + let mut proposals = HashMap::new(); + let profiles = vec![CircularProfile::new(0.35, 0.0, 0.0).unwrap(); agent_count]; + + for agent in 0..agent_count { + let control_points = control_points_for_family(agent, agent_count, family); + let start_time = start_time_for_family(agent, family); + let waypoints = sample_polyline( + &control_points, + waypoint_count, + start_time, + start_time + 20.0, + ); + let trajectory = Trajectory::from_iter(waypoints.iter().copied()).unwrap(); + + let states: Vec<_> = waypoints + .iter() + .copied() + .map(|waypoint| StateSippSE2::new(Cell::from_point(waypoint.point(), 1.0), waypoint)) + .collect(); + + let meta = MetaTrajectory { + trajectory, + decision_points: states + .iter() + .cloned() + .enumerate() + .map(|(index, state)| DecisionPoint { index, state }) + .collect(), + initial_state: states.first().cloned().unwrap(), + final_state: states.last().cloned().unwrap(), + }; + + proposals.insert( + agent, + Proposal { + meta, + cost: Cost(agent as f64), + }, + ); + } + + (proposals, profiles) +} + +fn control_points_for_family( + agent: usize, + agent_count: usize, + family: SyntheticScenarioFamily, +) -> Vec<(f64, f64)> { + match family { + SyntheticScenarioFamily::LowOverlap => { + let y = (agent as f64 - agent_count as f64 / 2.0) * 4.0; + vec![(-12.0, y), (12.0, y)] + } + SyntheticScenarioFamily::MediumOverlap => { + let lane = agent % 6; + let offset = lane as f64 - 2.5; + if agent % 2 == 0 { + let y = offset * 1.5; + vec![(-12.0, y), (-2.0, y), (2.0, y), (12.0, y)] + } else { + let x = offset * 1.5; + vec![(x, -12.0), (x, -2.0), (x, 2.0), (x, 12.0)] + } + } + SyntheticScenarioFamily::HighOverlap => { + let theta = 2.0 * PI * agent as f64 / agent_count as f64; + let (sin_t, cos_t) = theta.sin_cos(); + vec![ + (10.0 * cos_t, 10.0 * sin_t), + (2.0 * cos_t, 2.0 * sin_t), + (0.0, 0.0), + (-2.0 * cos_t, -2.0 * sin_t), + (-10.0 * cos_t, -10.0 * sin_t), + ] + } + } +} + +fn start_time_for_family(agent: usize, family: SyntheticScenarioFamily) -> f64 { + match family { + SyntheticScenarioFamily::LowOverlap => agent as f64 * 0.05, + SyntheticScenarioFamily::MediumOverlap => (agent % 4) as f64 * 0.15, + SyntheticScenarioFamily::HighOverlap => (agent % 8) as f64 * 0.05, + } +} + +fn sample_polyline( + control_points: &[(f64, f64)], + waypoint_count: usize, + start_time: f64, + finish_time: f64, +) -> Vec { + let mut lengths = Vec::new(); + let mut cumulative = Vec::new(); + let mut total_length = 0.0; + + for window in control_points.windows(2) { + let dx = window[1].0 - window[0].0; + let dy = window[1].1 - window[0].1; + let length = (dx * dx + dy * dy).sqrt(); + lengths.push(length); + cumulative.push(total_length); + total_length += length.max(1e-9); + } + + let mut waypoints = Vec::with_capacity(waypoint_count); + for index in 0..waypoint_count { + let progress = if waypoint_count == 1 { + 0.0 + } else { + index as f64 / (waypoint_count - 1) as f64 + }; + let target_length = total_length * progress; + + let mut segment = lengths.len() - 1; + for (candidate, start_length) in cumulative.iter().enumerate() { + let end_length = *start_length + lengths[candidate]; + if target_length <= end_length || candidate == lengths.len() - 1 { + segment = candidate; + break; + } + } + + let seg_start = control_points[segment]; + let seg_end = control_points[segment + 1]; + let seg_length = lengths[segment].max(1e-9); + let local_progress = (target_length - cumulative[segment]) / seg_length; + let x = seg_start.0 + (seg_end.0 - seg_start.0) * local_progress; + let y = seg_start.1 + (seg_end.1 - seg_start.1) * local_progress; + let yaw = (seg_end.1 - seg_start.1).atan2(seg_end.0 - seg_start.0); + let time = start_time + (finish_time - start_time) * progress; + waypoints.push(WaypointSE2::new_f64(time, x, y, yaw)); + } + + waypoints +} diff --git a/mapf/src/motion/environment.rs b/mapf/src/motion/environment.rs index 1f15c57..c4714f1 100644 --- a/mapf/src/motion/environment.rs +++ b/mapf/src/motion/environment.rs @@ -622,4 +622,12 @@ impl BoundingBox { max: self.max + Vector2::from_element(r), } } + + pub fn center(&self) -> Point { + Point::from((self.min + self.max) / 2.0) + } + + pub fn circumscribed_radius(&self) -> f64 { + (self.max - self.center().coords).norm() + } } diff --git a/mapf/src/negotiation/mod.rs b/mapf/src/negotiation/mod.rs index 49f7f78..4e69e92 100644 --- a/mapf/src/negotiation/mod.rs +++ b/mapf/src/negotiation/mod.rs @@ -38,6 +38,7 @@ use crate::{ premade::{SippSE2, StateSippSE2}, util::triangular_for, }; +use kdtree::{distance::squared_euclidean, KdTree}; use std::{ cmp::Reverse, collections::{BinaryHeap, HashMap, HashSet}, @@ -714,6 +715,18 @@ impl std::fmt::Debug for Conflict { } } +impl Conflict { + pub fn time(&self) -> TimePoint { + self.time + } + + pub fn agent_pair(&self) -> [usize; 2] { + let a = self.segments[0].agent; + let b = self.segments[1].agent; + [a.min(b), a.max(b)] + } +} + pub type SippDecisionRange = DecisionRange>; pub type DecisionRangePair = (SippDecisionRange, SippDecisionRange); @@ -723,41 +736,215 @@ pub struct Segment { range: SippDecisionRange, } +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum ConflictDetectionAlgorithm { + Baseline, + KdTree, +} + +#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)] +pub struct ConflictDetectionStats { + pub trajectory_count: usize, + pub pair_enumerations: usize, + pub bbox_candidate_pairs: usize, + pub conflicts_found: usize, +} + +#[derive(Debug, Clone)] +pub struct ConflictDetectionReport { + pub conflicts: Vec, + pub stats: ConflictDetectionStats, +} + +#[derive(Clone, Copy)] +struct ProposalEnvelope { + agent: usize, + bbox: BoundingBox, + center: [f64; 2], + radius: f64, +} + +pub fn detect_conflicts_for_proposals( + proposals: &HashMap, + profiles: &[CircularProfile], + algorithm: ConflictDetectionAlgorithm, +) -> ConflictDetectionReport { + let ordered = collect_ordered_proposals(proposals, profiles); + match algorithm { + ConflictDetectionAlgorithm::Baseline => detect_conflicts_baseline(&ordered), + ConflictDetectionAlgorithm::KdTree => detect_conflicts_kdtree(&ordered), + } +} + fn reasses_conflicts( proposals: &HashMap, profiles: &Vec, ) -> Vec { + detect_conflicts_for_proposals(proposals, profiles, ConflictDetectionAlgorithm::Baseline) + .conflicts +} + +type OrderedProposalRef<'a> = (usize, &'a Proposal, &'a CircularProfile); + +fn collect_ordered_proposals<'a>( + proposals: &'a HashMap, + profiles: &'a [CircularProfile], +) -> Vec> { + let mut ordered: Vec<_> = proposals + .iter() + .map(|(agent, proposal)| (*agent, proposal, profiles.get(*agent).unwrap())) + .collect(); + ordered.sort_unstable_by_key(|(agent, _, _)| *agent); + ordered +} + +fn conflict_for_pair( + agent_a: usize, + mt_a: &MetaTrajectory>, + profile_a: &CircularProfile, + agent_b: usize, + mt_b: &MetaTrajectory>, + profile_b: &CircularProfile, +) -> Option { + let (range_a, range_b) = find_first_conflict(mt_a, profile_a, mt_b, profile_b)?; + Some(Conflict { + time: TimePoint::min( + mt_a.decision_start_time(&range_a), + mt_b.decision_start_time(&range_b), + ), + segments: [ + Segment { + agent: agent_a, + range: range_a, + }, + Segment { + agent: agent_b, + range: range_b, + }, + ], + }) +} + +fn detect_conflicts_baseline(ordered: &[OrderedProposalRef<'_>]) -> ConflictDetectionReport { + let mut stats = ConflictDetectionStats { + trajectory_count: ordered.len(), + ..Default::default() + }; let mut conflicts = Vec::new(); - triangular_for( - proposals.iter().map(|(i, p)| (i, &p.meta)), - |(i_a, mt_a), (i_b, mt_b)| { - let profile_a = profiles.get(**i_a).unwrap(); - let profile_b = profiles.get(*i_b).unwrap(); - let (range_a, range_b) = match find_first_conflict(mt_a, profile_a, mt_b, profile_b) { - Some(r) => r, - None => return, - }; - conflicts.push(Conflict { - time: TimePoint::min( - mt_a.decision_start_time(&range_a), - mt_b.decision_start_time(&range_b), - ), - segments: [ - Segment { - agent: **i_a, - range: range_a, - }, - Segment { - agent: *i_b, - range: range_b, - }, - ], - }); - }, - ); + for i in 0..ordered.len() { + let (agent_a, proposal_a, profile_a) = ordered[i]; + for (agent_b, proposal_b, profile_b) in ordered[i + 1..].iter().copied() { + stats.pair_enumerations += 1; + + let bbox_a = BoundingBox::for_trajectory(profile_a, &proposal_a.meta.trajectory); + let bbox_b = BoundingBox::for_trajectory(profile_b, &proposal_b.meta.trajectory); + if !bbox_a.overlaps(Some(bbox_b)) { + continue; + } + + stats.bbox_candidate_pairs += 1; + if let Some(conflict) = conflict_for_pair( + agent_a, + &proposal_a.meta, + profile_a, + agent_b, + &proposal_b.meta, + profile_b, + ) { + conflicts.push(conflict); + } + } + } + + conflicts.sort_by_key(|conflict| { + let [a, b] = conflict.agent_pair(); + (conflict.time().nanos_since_zero, a, b) + }); + stats.conflicts_found = conflicts.len(); + ConflictDetectionReport { conflicts, stats } +} + +fn detect_conflicts_kdtree(ordered: &[OrderedProposalRef<'_>]) -> ConflictDetectionReport { + let mut stats = ConflictDetectionStats { + trajectory_count: ordered.len(), + ..Default::default() + }; + let mut conflicts = Vec::new(); + + if ordered.len() < 2 { + return ConflictDetectionReport { conflicts, stats }; + } + + let envelopes: Vec<_> = ordered + .iter() + .map(|(agent, proposal, profile)| { + let bbox = BoundingBox::for_trajectory(profile, &proposal.meta.trajectory); + let center = bbox.center(); + ProposalEnvelope { + agent: *agent, + bbox, + center: [center.x, center.y], + radius: bbox.circumscribed_radius(), + } + }) + .collect(); + + let max_radius = envelopes.iter().fold(0.0_f64, |r, env| r.max(env.radius)); + let mut tree = KdTree::new(2); + for (index, envelope) in envelopes.iter().enumerate() { + tree.add(envelope.center, index).unwrap(); + } + + let mut visited_pairs = HashSet::new(); + for (index, envelope_a) in envelopes.iter().enumerate() { + let query_radius = envelope_a.radius + max_radius; + let neighbors = tree + .within(&envelope_a.center, query_radius.powi(2), &squared_euclidean) + .unwrap(); + + for neighbor in neighbors { + let other_index = *neighbor.1; + if other_index <= index { + continue; + } + + let envelope_b = envelopes[other_index]; + let pair = ( + envelope_a.agent.min(envelope_b.agent), + envelope_a.agent.max(envelope_b.agent), + ); + if !visited_pairs.insert(pair) { + continue; + } + + stats.pair_enumerations += 1; + if !envelope_a.bbox.overlaps(Some(envelope_b.bbox)) { + continue; + } + + stats.bbox_candidate_pairs += 1; + let (agent_a, proposal_a, profile_a) = ordered[index]; + let (agent_b, proposal_b, profile_b) = ordered[other_index]; + if let Some(conflict) = conflict_for_pair( + agent_a, + &proposal_a.meta, + profile_a, + agent_b, + &proposal_b.meta, + profile_b, + ) { + conflicts.push(conflict); + } + } + } - conflicts + conflicts.sort_by_key(|conflict| { + let [a, b] = conflict.agent_pair(); + (conflict.time().nanos_since_zero, a, b) + }); + stats.conflicts_found = conflicts.len(); + ConflictDetectionReport { conflicts, stats } } fn organize_negotiations( @@ -1204,3 +1391,233 @@ fn find_spillover_conflict( None } + +#[cfg(test)] +mod tests { + use super::*; + use crate::{ + algorithm::path::DecisionPoint, + motion::{r2::Positioned, se2::WaypointSE2, Trajectory}, + }; + use std::f64::consts::PI; + + #[derive(Clone, Copy)] + enum SyntheticScenarioFamily { + LowOverlap, + MediumOverlap, + HighOverlap, + } + + fn build_synthetic_proposals( + agent_count: usize, + waypoint_count: usize, + family: SyntheticScenarioFamily, + ) -> (HashMap, Vec) { + assert!(agent_count > 1); + assert!(waypoint_count >= 2); + + let mut proposals = HashMap::new(); + let profiles = vec![CircularProfile::new(0.35, 0.0, 0.0).unwrap(); agent_count]; + + for agent in 0..agent_count { + let control_points = control_points_for_family(agent, agent_count, family); + let start_time = start_time_for_family(agent, family); + let waypoints = sample_polyline( + &control_points, + waypoint_count, + start_time, + start_time + 20.0, + ); + let trajectory = Trajectory::from_iter(waypoints.iter().copied()).unwrap(); + + let states: Vec<_> = waypoints + .iter() + .copied() + .map(|waypoint| { + StateSippSE2::new(Cell::from_point(waypoint.point(), 1.0), waypoint) + }) + .collect(); + + let meta = MetaTrajectory { + trajectory, + decision_points: states + .iter() + .cloned() + .enumerate() + .map(|(index, state)| DecisionPoint { index, state }) + .collect(), + initial_state: states.first().cloned().unwrap(), + final_state: states.last().cloned().unwrap(), + }; + + proposals.insert( + agent, + Proposal { + meta, + cost: Cost(agent as f64), + }, + ); + } + + (proposals, profiles) + } + + fn control_points_for_family( + agent: usize, + agent_count: usize, + family: SyntheticScenarioFamily, + ) -> Vec<(f64, f64)> { + match family { + SyntheticScenarioFamily::LowOverlap => { + let y = (agent as f64 - agent_count as f64 / 2.0) * 4.0; + vec![(-12.0, y), (12.0, y)] + } + SyntheticScenarioFamily::MediumOverlap => { + let lane = agent % 6; + let offset = lane as f64 - 2.5; + if agent % 2 == 0 { + let y = offset * 1.5; + vec![(-12.0, y), (-2.0, y), (2.0, y), (12.0, y)] + } else { + let x = offset * 1.5; + vec![(x, -12.0), (x, -2.0), (x, 2.0), (x, 12.0)] + } + } + SyntheticScenarioFamily::HighOverlap => { + let theta = 2.0 * PI * agent as f64 / agent_count as f64; + let (sin_t, cos_t) = theta.sin_cos(); + vec![ + (10.0 * cos_t, 10.0 * sin_t), + (2.0 * cos_t, 2.0 * sin_t), + (0.0, 0.0), + (-2.0 * cos_t, -2.0 * sin_t), + (-10.0 * cos_t, -10.0 * sin_t), + ] + } + } + } + + fn start_time_for_family(agent: usize, family: SyntheticScenarioFamily) -> f64 { + match family { + SyntheticScenarioFamily::LowOverlap => agent as f64 * 0.05, + SyntheticScenarioFamily::MediumOverlap => (agent % 4) as f64 * 0.15, + SyntheticScenarioFamily::HighOverlap => (agent % 8) as f64 * 0.05, + } + } + + fn sample_polyline( + control_points: &[(f64, f64)], + waypoint_count: usize, + start_time: f64, + finish_time: f64, + ) -> Vec { + let mut lengths = Vec::new(); + let mut cumulative = Vec::new(); + let mut total_length = 0.0; + + for window in control_points.windows(2) { + let dx = window[1].0 - window[0].0; + let dy = window[1].1 - window[0].1; + let length = (dx * dx + dy * dy).sqrt(); + lengths.push(length); + cumulative.push(total_length); + total_length += length.max(1e-9); + } + + let mut waypoints = Vec::with_capacity(waypoint_count); + for index in 0..waypoint_count { + let progress = if waypoint_count == 1 { + 0.0 + } else { + index as f64 / (waypoint_count - 1) as f64 + }; + let target_length = total_length * progress; + + let mut segment = lengths.len() - 1; + for (candidate, start_length) in cumulative.iter().enumerate() { + let end_length = *start_length + lengths[candidate]; + if target_length <= end_length || candidate == lengths.len() - 1 { + segment = candidate; + break; + } + } + + let seg_start = control_points[segment]; + let seg_end = control_points[segment + 1]; + let seg_length = lengths[segment].max(1e-9); + let local_progress = (target_length - cumulative[segment]) / seg_length; + let x = seg_start.0 + (seg_end.0 - seg_start.0) * local_progress; + let y = seg_start.1 + (seg_end.1 - seg_start.1) * local_progress; + let yaw = (seg_end.1 - seg_start.1).atan2(seg_end.0 - seg_start.0); + let time = start_time + (finish_time - start_time) * progress; + waypoints.push(WaypointSE2::new_f64(time, x, y, yaw)); + } + + waypoints + } + + fn conflict_signature(report: &ConflictDetectionReport) -> Vec { + report + .conflicts + .iter() + .map(|conflict| format!("{conflict:?}")) + .collect() + } + + #[test] + fn kdtree_matches_baseline_for_low_overlap() { + let (proposals, profiles) = + build_synthetic_proposals(24, 16, SyntheticScenarioFamily::LowOverlap); + let baseline = detect_conflicts_for_proposals( + &proposals, + &profiles, + ConflictDetectionAlgorithm::Baseline, + ); + let kdtree = detect_conflicts_for_proposals( + &proposals, + &profiles, + ConflictDetectionAlgorithm::KdTree, + ); + + assert_eq!(conflict_signature(&baseline), conflict_signature(&kdtree)); + assert!(kdtree.stats.pair_enumerations < baseline.stats.pair_enumerations); + } + + #[test] + fn kdtree_matches_baseline_for_medium_overlap() { + let (proposals, profiles) = + build_synthetic_proposals(32, 24, SyntheticScenarioFamily::MediumOverlap); + let baseline = detect_conflicts_for_proposals( + &proposals, + &profiles, + ConflictDetectionAlgorithm::Baseline, + ); + let kdtree = detect_conflicts_for_proposals( + &proposals, + &profiles, + ConflictDetectionAlgorithm::KdTree, + ); + + assert_eq!(conflict_signature(&baseline), conflict_signature(&kdtree)); + assert!(kdtree.stats.bbox_candidate_pairs <= baseline.stats.bbox_candidate_pairs); + } + + #[test] + fn kdtree_matches_baseline_for_high_overlap() { + let (proposals, profiles) = + build_synthetic_proposals(20, 24, SyntheticScenarioFamily::HighOverlap); + let baseline = detect_conflicts_for_proposals( + &proposals, + &profiles, + ConflictDetectionAlgorithm::Baseline, + ); + let kdtree = detect_conflicts_for_proposals( + &proposals, + &profiles, + ConflictDetectionAlgorithm::KdTree, + ); + + assert_eq!(conflict_signature(&baseline), conflict_signature(&kdtree)); + assert_eq!(baseline.stats.conflicts_found, kdtree.stats.conflicts_found); + } +} diff --git a/notes/conflict_detection_benchmark.md b/notes/conflict_detection_benchmark.md new file mode 100644 index 0000000..3924f70 --- /dev/null +++ b/notes/conflict_detection_benchmark.md @@ -0,0 +1,45 @@ +# Conflict Detection First Slice + +This first slice keeps the production negotiation path on the baseline algorithm and adds an experimental KD-tree broad phase for comparison. + +## What changed + +- `mapf::negotiation::detect_conflicts_for_proposals(...)` can now run: + - `ConflictDetectionAlgorithm::Baseline` + - `ConflictDetectionAlgorithm::KdTree` +- The KD-tree path prunes proposal pairs using trajectory bounding boxes before calling the exact `find_first_conflict(...)` logic. +- Existing negotiation behavior still uses the baseline detector. + +## Why this slice is useful + +- It isolates the performance-sensitive conflict-detection pass without changing planner semantics. +- It gives a reproducible benchmark path for a mentor conversation. +- It proves whether a KD-tree-style broad phase is promising before attempting deeper integration. + +## Commands + +Build and test: + +```bash +CARGO_HOME=$PWD/.cargo_home cargo test -p mapf +``` + +Print a baseline vs KD-tree report: + +```bash +CARGO_HOME=$PWD/.cargo_home cargo run --release -p mapf --example conflict_detection_report +``` + +Run criterion benchmarks: + +```bash +CARGO_HOME=$PWD/.cargo_home cargo bench -p mapf --bench conflict_detection +``` + +## Interpretation + +- `pair_enumerations`: how many proposal pairs the algorithm examined. +- `bbox_candidate_pairs`: how many pairs survived broad-phase pruning and needed exact checking. +- `conflicts`: how many exact conflicts were found. + +The KD-tree path is expected to reduce `pair_enumerations` and usually reduce `bbox_candidate_pairs`, while returning the same conflict set as the baseline.