diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/CMakeLists.txt b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/CMakeLists.txt index c5a1d2d0..00c1ada9 100644 --- a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/CMakeLists.txt +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/CMakeLists.txt @@ -1,23 +1,16 @@ cmake_minimum_required(VERSION 3.11.0 FATAL_ERROR) - # Suppress developer warnings for the entire workspace. - set(CMAKE_SUPPRESS_DEVELOPER_WARNINGS TRUE CACHE BOOL "Suppress developer warnings" FORCE) # Set CMP0144 to NEW to ensure find_package # uses upper-case _ROOT variables. - cmake_policy(SET CMP0144 NEW) - project(multibeam_sonar) - set(CUDA_ARCHITECTURE "60" CACHE STRING "Target CUDA SM version") - find_package(ament_cmake REQUIRED) find_package(CUDAToolkit QUIET) if(CUDAToolkit_FOUND) - enable_language(CUDA) find_package(CUDA REQUIRED) message(STATUS "CUDA found, enabling CUDA support.") @@ -47,17 +40,51 @@ if(CUDAToolkit_FOUND) set(GZ_RENDERING_TARGET gz-rendering${GZ_RENDERING_VER}-ogre) add_definitions(-DWITH_OGRE) endif() - if(TARGET gz-rendering${GZ_RENDERING_VER}::ogre2) set(HAVE_OGRE2 TRUE) set(GZ_RENDERING_TARGET gz-rendering${GZ_RENDERING_VER}-ogre2) add_definitions(-DWITH_OGRE2) endif() + # Detected at configure time by presence of cargo. + # If cargo is absent the wgpu sources are simply excluded — CUDA build is + # completely unaffected. 
+ + find_program(CARGO_EXECUTABLE cargo) + + if(CARGO_EXECUTABLE) + message(STATUS "[multibeam_sonar] WGPU backend ENABLED") + + set(WGPU_DIR ${CMAKE_CURRENT_SOURCE_DIR}/wgpu_backend) + set(WGPU_LIB ${WGPU_DIR}/target/release/libsonar_wgpu_backend.a) + + add_custom_command( + OUTPUT ${WGPU_LIB} + COMMAND cargo build --release + WORKING_DIRECTORY ${WGPU_DIR} + COMMENT "Building Rust WGPU backend" + ) + + add_custom_target(wgpu_backend_build ALL DEPENDS ${WGPU_LIB}) + + add_library(wgpu_backend STATIC IMPORTED) + set_target_properties(wgpu_backend PROPERTIES + IMPORTED_LOCATION ${WGPU_LIB} + ) + + add_dependencies(wgpu_backend wgpu_backend_build) + + set(HAVE_WGPU ON) + else() + message(STATUS "[multibeam_sonar] WGPU backend DISABLED (cargo not found)") + set(HAVE_WGPU OFF) + endif() + add_library(${PROJECT_NAME} SHARED MultibeamSonarSensor.cc sonar_calculation_cuda.cu + sonar_compute_wgpu.cc ) set_target_properties(${PROJECT_NAME} @@ -92,6 +119,21 @@ if(CUDAToolkit_FOUND) ${CUBLAS_LIB} ) + # Link wgpu static lib if built + if(HAVE_WGPU) + target_compile_definitions(${PROJECT_NAME} PRIVATE HAVE_WGPU_BACKEND) + + target_include_directories(${PROJECT_NAME} + PRIVATE ${CMAKE_CURRENT_SOURCE_DIR}/wgpu_backend/include + ) + + target_link_libraries(${PROJECT_NAME} + wgpu_backend + pthread + dl + ) + endif() + install(TARGETS ${PROJECT_NAME} DESTINATION lib/${PROJECT_NAME} ) @@ -120,4 +162,4 @@ else() "Skipping CUDA-specific targets") endif() -ament_package() +ament_package() \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/sonar_compute_wgpu.cc b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/sonar_compute_wgpu.cc new file mode 100644 index 00000000..3f1e1975 --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/sonar_compute_wgpu.cc @@ -0,0 +1,349 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * Licensed under the Apache License, Version 2.0 + */ +#include "sonar_compute_wgpu.hh" + +#include +#include 
+#include +#include +#include +#include +#include + +#include + +// Use a literal instead of M_PI — more portable across compilers/platforms +static constexpr float kPi = 3.14159265358979323846f; + +#ifdef HAVE_WGPU_BACKEND +extern "C" +{ + void * sonar_wgpu_create( + uint32_t n_beams, uint32_t n_rays, uint32_t n_freq, float sound_speed, float bandwidth, + float max_range, float attenuation, float source_level, float sensor_gain, float h_fov, + float v_fov, uint64_t seed); + + float * sonar_wgpu_compute( + void * engine, const float * depth, const float * normals, const float * refl, + const float * beam_corr, uint32_t n_beams, uint32_t n_rays, uint32_t n_freq, uint64_t frame, + float beam_corr_sum); + + void sonar_wgpu_free(float * ptr, size_t len); + void sonar_wgpu_destroy(void * engine); +} +#endif // HAVE_WGPU_BACKEND + +namespace gz +{ +namespace sensors +{ + +// CPU fallback — always available, implements Eq.14 + Eq.8 without speckle +class CpuComputeBackend : public SonarComputeBackend +{ +public: + const char * Name() const override { return "cpu"; } + bool Initialize(const SonarComputeInput &) override { return true; } + + bool Compute(const SonarComputeInput & input, SonarComputeOutput & output) override + { + if (!input.depthImage || !input.normalImage || !input.reflectivityImage) + { + return false; + } + + const int nb = input.nBeams; + const int nr = input.nRays; + const int nf = input.nFreq; + + output.nBeams = nb; + output.nFreq = nf; + output.beamSpectrum.assign(static_cast(nb) * nf, std::complex(0.f, 0.f)); + + const float c = static_cast(input.soundSpeed); + const float B = static_cast(input.bandwidth); + const float df = B / static_cast(nf); + + auto t0 = std::chrono::high_resolution_clock::now(); + + for (int b = 0; b < nb; ++b) + { + for (int ray = 0; ray < nr; ++ray) + { + if (ray >= input.depthImage->rows || b >= input.depthImage->cols) + { + continue; + } + + const float r = input.depthImage->at(ray, b); + if (r <= 0.f || r > 
static_cast(input.maxDistance)) + { + continue; + } + + const float mu = input.reflectivityImage->at(ray, b); + const cv::Vec3f n = input.normalImage->at(ray, b); + const float cos_inc = std::abs(n[2]); + + const float dth = input.hFOV / std::max(nb - 1, 1); + const float dphi = input.vFOV / std::max(nr - 1, 1); + const float dA = r * r * dth * dphi; + const float TL = std::exp(-static_cast(input.attenuation) * r) / r; + + // Eq.14 amplitude (deterministic, no speckle in CPU path) + const float A = std::sqrt(mu) * cos_inc * std::sqrt(dA) * TL; + + for (int f = 0; f < nf; ++f) + { + // DC-centred frequency grid (matches backscatter.wgsl convention) + float freq; + if (nf % 2 == 0) + { + freq = df * (-static_cast(nf) + 2.f * (static_cast(f) + 1.f)) / 2.f; + } + else + { + freq = + df * (-(static_cast(nf) - 1.f) + 2.f * (static_cast(f) + 1.f)) / 2.f; + } + + // Eq.8 two-way phase + const float k = 2.f * kPi * freq / c; + const float phi = 2.f * r * k; + + output.At(b, f) += std::complex(A * std::cos(phi), A * std::sin(phi)); + } + } + } + + auto t1 = std::chrono::high_resolution_clock::now(); + output.computeMicros = + static_cast(std::chrono::duration_cast(t1 - t0).count()); + + return true; + } +}; + +// Backend factory +std::unique_ptr CreateComputeBackend(const std::string & name) +{ + if (name == "wgpu") + { + return std::make_unique(); + } + // "cuda" is handled by existing sonar_calculation_cuda path + return std::make_unique(); +} + +// WgpuComputeBackend — Name +const char * WgpuComputeBackend::Name() const { return "wgpu"; } + +// WgpuComputeBackend — destructor: clean up persistent engine +WgpuComputeBackend::~WgpuComputeBackend() +{ +#ifdef HAVE_WGPU_BACKEND + if (engine_) + { + sonar_wgpu_destroy(engine_); + engine_ = nullptr; + } +#endif +} + +// WgpuComputeBackend — Initialize +// Probe GPU with a minimal 1-ray call. Non-fatal: failures route to CPU. 
+bool WgpuComputeBackend::Initialize(const SonarComputeInput &) +{ +#ifndef HAVE_WGPU_BACKEND + std::cerr << "[sonar_wgpu] backend not compiled in " + "(cargo absent at build time) — CPU fallback active.\n"; + gpuAvailable_ = false; + initialized_ = true; + return true; +#else + void * probe = + sonar_wgpu_create(1u, 1u, 4u, 1500.f, 2950.f, 10.f, 0.f, 220.f, 1.f, 1.5708f, 0.3491f, 1u); + + if (probe) + { + sonar_wgpu_destroy(probe); + gpuAvailable_ = true; + std::cout << "[sonar_wgpu] GPU backend ready (wgpu/Vulkan).\n"; + } + else + { + gpuAvailable_ = false; + std::cerr << "[sonar_wgpu] GPU init failed — CPU fallback active.\n"; + } + + initialized_ = true; + return true; +#endif +} + +// WgpuComputeBackend — Compute +bool WgpuComputeBackend::Compute(const SonarComputeInput & input, SonarComputeOutput & output) +{ + if (!initialized_) + { + return false; + } + + // CPU fallback path + if (!gpuAvailable_) + { + if (!cpuFallback_) + { + cpuFallback_ = std::make_unique(); + cpuFallback_->Initialize(input); + } + return cpuFallback_->Compute(input, output); + } + +#ifndef HAVE_WGPU_BACKEND + return false; +#else + + // Validate inputs + if (!input.depthImage || !input.normalImage || !input.reflectivityImage) + { + return false; + } + + const int nb = input.nBeams; + const int nr = input.nRays; + const int nf = input.nFreq; + if (nb <= 0 || nr <= 0 || nf <= 0) + { + return false; + } + + auto t0 = std::chrono::high_resolution_clock::now(); + + // Flatten OpenCV layout [row=ray, col=beam] → [beam * nr + ray] + std::vector depthFlat(static_cast(nb) * nr, 0.f); + std::vector reflFlat(static_cast(nb) * nr, 0.f); + std::vector normalFlat(static_cast(nb) * nr * 3, 0.f); + + for (int b = 0; b < nb; ++b) + { + for (int r = 0; r < nr; ++r) + { + if (r >= input.depthImage->rows || b >= input.depthImage->cols) + { + continue; + } + const size_t idx = static_cast(b) * nr + r; + depthFlat[idx] = input.depthImage->at(r, b); + reflFlat[idx] = input.reflectivityImage->at(r, b); + 
const cv::Vec3f n = input.normalImage->at(r, b); + normalFlat[idx * 3 + 0] = n[0]; + normalFlat[idx * 3 + 1] = n[1]; + normalFlat[idx * 3 + 2] = n[2]; + } + } + + // Flatten beam corrector (identity if not provided) + std::vector beamCorrFlat(static_cast(nb) * nb, 0.f); + if (input.beamCorrector) + { + for (int r = 0; r < nb; ++r) + { + for (int c = 0; c < nb; ++c) + { + beamCorrFlat[r * nb + c] = input.beamCorrector[r][c]; + } + } + } + else + { + // Identity fallback: no cross-beam mixing + for (int i = 0; i < nb; ++i) + { + beamCorrFlat[i * nb + i] = 1.f; + } + } + + // Persistent engine: recreate only when dimensions change + // engine_ is a class member (not static local) — safe for multiple sensors + if (!engine_ || engNb_ != nb || engNr_ != nr || engNf_ != nf) + { + if (engine_) + { + sonar_wgpu_destroy(engine_); + engine_ = nullptr; + } + + engine_ = sonar_wgpu_create( + static_cast(nb), static_cast(nr), static_cast(nf), + static_cast(input.soundSpeed), static_cast(input.bandwidth), + static_cast(input.maxDistance), static_cast(input.attenuation), + static_cast(input.sourceLevel), input.sensorGain, input.hFOV, input.vFOV, input.seed); + + engNb_ = nb; + engNr_ = nr; + engNf_ = nf; + + if (!engine_) + { + std::cerr << "[sonar_wgpu] sonar_wgpu_create failed " + "— switching to CPU fallback.\n"; + gpuAvailable_ = false; + if (!cpuFallback_) + { + cpuFallback_ = std::make_unique(); + cpuFallback_->Initialize(input); + } + return cpuFallback_->Compute(input, output); + } + } + + // GPU dispatch + float * result = sonar_wgpu_compute( + engine_, depthFlat.data(), normalFlat.data(), reflFlat.data(), beamCorrFlat.data(), + static_cast(nb), static_cast(nr), static_cast(nf), + input.frameIndex, + input.beamCorrectorSum > 0.f ? 
input.beamCorrectorSum : static_cast(nb)); + + if (!result) + { + // Log failure explicitly — silent fallback makes debugging painful + std::cerr << "[sonar_wgpu] sonar_wgpu_compute returned null " + "— CPU fallback for this frame.\n"; + if (!cpuFallback_) + { + cpuFallback_ = std::make_unique(); + cpuFallback_->Initialize(input); + } + return cpuFallback_->Compute(input, output); + } + + // Unpack interleaved [re, im] result + output.nBeams = nb; + output.nFreq = nf; + output.beamSpectrum.resize(static_cast(nb) * nf); + + for (int b = 0; b < nb; ++b) + { + for (int f = 0; f < nf; ++f) + { + const size_t i = (static_cast(b) * nf + f) * 2; + // explicit constructor — avoids initializer-list ambiguity with templates + output.At(b, f) = std::complex(result[i], result[i + 1]); + } + } + + sonar_wgpu_free(result, static_cast(nb) * nf * 2); + + auto t1 = std::chrono::high_resolution_clock::now(); + output.computeMicros = + static_cast(std::chrono::duration_cast(t1 - t0).count()); + + return true; +#endif // HAVE_WGPU_BACKEND +} + +} // namespace sensors +} // namespace gz \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/sonar_compute_wgpu.hh b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/sonar_compute_wgpu.hh new file mode 100644 index 00000000..d56551e0 --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/sonar_compute_wgpu.hh @@ -0,0 +1,106 @@ +/* + * Copyright (C) 2025 Open Source Robotics Foundation + * Licensed under the Apache License, Version 2.0 + */ +#ifndef GZ_SENSORS_SONAR_COMPUTE_WGPU_HH_ +#define GZ_SENSORS_SONAR_COMPUTE_WGPU_HH_ + +#include +#include +#include +#include +#include + +#include + +namespace gz +{ +namespace sensors +{ + +struct SonarComputeInput +{ + const cv::Mat * depthImage = nullptr; + const cv::Mat * normalImage = nullptr; + const cv::Mat * reflectivityImage = nullptr; + + float ** beamCorrector = nullptr; + float beamCorrectorSum = 1.0f; + float * window = nullptr; + + int nBeams = 
0; + int nRays = 0; + int nFreq = 0; + int raySkips = 0; + + double soundSpeed = 1500.0; + double maxDistance = 60.0; + double sourceLevel = 220.0; + double attenuation = 0.0; + double bandwidth = 2950.0; + + float sensorGain = 1.0f; + float hFOV = 1.5708f; + float vFOV = 0.3491f; + + uint64_t frameIndex = 0; + uint64_t seed = 42; +}; + +struct SonarComputeOutput +{ + int nBeams = 0; + int nFreq = 0; + std::vector> beamSpectrum; + double computeMicros = 0.0; + + inline std::complex & At(int b, int f) + { + return beamSpectrum[static_cast(b) * nFreq + f]; + } +}; + +class SonarComputeBackend +{ +public: + virtual ~SonarComputeBackend() = default; + virtual const char * Name() const = 0; + virtual bool Initialize(const SonarComputeInput & input) = 0; + virtual bool Compute( + const SonarComputeInput & input, + SonarComputeOutput & output) = 0; +}; + +std::unique_ptr +CreateComputeBackend(const std::string & name); + +class WgpuComputeBackend : public SonarComputeBackend +{ +public: + WgpuComputeBackend() = default; + ~WgpuComputeBackend(); + + const char * Name() const override; + bool Initialize(const SonarComputeInput & input) override; + bool Compute( + const SonarComputeInput & input, + SonarComputeOutput & output) override; + +private: + bool initialized_ = false; + bool gpuAvailable_ = false; + + // Persistent GPU engine — class member, not static local. + // Safe for multiple simultaneous sonar sensor instances. 
+ void * engine_ = nullptr; + int engNb_ = 0; + int engNr_ = 0; + int engNf_ = 0; + + std::unique_ptr cpuFallback_; +}; + +} // namespace sensors +} // namespace gz + +#endif // GZ_SENSORS_SONAR_COMPUTE_WGPU_HH_ \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/Cargo.toml b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/Cargo.toml new file mode 100644 index 00000000..d8f6c1e8 --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/Cargo.toml @@ -0,0 +1,16 @@ +[package] +name = "sonar_wgpu_backend" +version = "0.1.0" +edition = "2021" + +[lib] +crate-type = ["staticlib"] + +[dependencies] +wgpu = "0.20" +bytemuck = { version = "1.14", features = ["derive"] } +pollster = "0.3" +rustfft = "6.2" + +[build-dependencies] +cbindgen = "0.26" \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/build.rs b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/build.rs new file mode 100644 index 00000000..8bfbc606 --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/build.rs @@ -0,0 +1,9 @@ +fn main() { + let crate_dir = std::env::var("CARGO_MANIFEST_DIR").unwrap(); + cbindgen::Builder::new() + .with_crate(&crate_dir) + .with_config(cbindgen::Config::from_file("cbindgen.toml").unwrap()) + .generate() + .expect("cbindgen failed") + .write_to_file("include/sonar_wgpu.h"); +} \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/cbindgen.toml b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/cbindgen.toml new file mode 100644 index 00000000..763c4d2f --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/cbindgen.toml @@ -0,0 +1,3 @@ +language = "C" +include_guard = "SONAR_WGPU_H" +autogen_warning = "/* Auto-generated by cbindgen. Do not edit. 
*/" \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/include/sonar_wgpu.h b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/include/sonar_wgpu.h new file mode 100644 index 00000000..f86bdf0b --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/include/sonar_wgpu.h @@ -0,0 +1,51 @@ +#ifndef SONAR_WGPU_H +#define SONAR_WGPU_H + +/* Generated via cbindgen — interface between C++ plugin and Rust backend */ + +#include +#include + +#ifdef __cplusplus +extern "C" +{ +#endif + + /* Opaque handle to Rust-side engine */ + typedef struct SonarPhysicsEngine SonarPhysicsEngine; + + /* Create GPU sonar engine. + Returns NULL if GPU init fails (caller should fallback to CPU). */ + SonarPhysicsEngine * sonar_wgpu_create( + uint32_t n_beams, uint32_t n_rays, uint32_t n_freq, float sound_speed, float bandwidth, + float max_range, float attenuation, float source_level, float sensor_gain, float h_fov, + float v_fov, uint64_t seed); + + /* Run one frame. + Output layout: [re0, im0, re1, im1, ...] (beam-major). + Size = n_beams * n_freq * 2. + Caller owns memory → free with sonar_wgpu_free(). 
*/ + float * sonar_wgpu_compute( + SonarPhysicsEngine * engine, const float * depth, const float * normals, const float * refl, + const float * beam_corr, uint32_t n_beams, uint32_t n_rays, uint32_t n_freq, uint64_t frame, + float beam_corr_sum); + + /* Free buffer returned by compute */ + void sonar_wgpu_free(float * ptr, uintptr_t len); + + /* Destroy engine + release GPU resources */ + void sonar_wgpu_destroy(void * engine); + + /* --- CPU / legacy path (kept for compatibility) --- */ + + SonarPhysicsEngine * sonar_physics_create( + uint32_t n_beams, uint32_t n_rays, uint32_t n_freq, float sound_speed, float bandwidth, + float max_range, float attenuation, float h_fov, float v_fov); + + void sonar_physics_destroy(SonarPhysicsEngine * engine); + +#ifdef __cplusplus +} +#endif + +#endif /* SONAR_WGPU_H */ \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/backscatter.wgsl b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/backscatter.wgsl new file mode 100644 index 00000000..54bae29a --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/backscatter.wgsl @@ -0,0 +1,194 @@ +// backscatter.wgsl +// Based on Choi et al. 
(2021) +// Uses Eq.14 for per-ray amplitude and Eq.8 for frequency accumulation + +struct Params { + n_beams: u32, + n_rays: u32, + n_freq: u32, + _pad0: u32, + + sound_speed: f32, + bandwidth: f32, + max_range: f32, + attenuation: f32, + + h_fov: f32, + v_fov: f32, + mu_default: f32, + _pad1: f32, + + seed: u32, + frame: u32, + _pad2: u32, + _pad3: u32, +}; + +@group(0) @binding(0) var params: Params; +@group(0) @binding(1) var depth: array; // per-ray distance +@group(0) @binding(2) var normals: array; // packed xyz normals +@group(0) @binding(3) var refl: array; // reflectivity per ray +@group(0) @binding(4) var out_re: array>; +@group(0) @binding(5) var out_im: array>; + +const PI: f32 = 3.141592653589793; +const SCALE: f32 = 1048576.0; // fixed-point scale for atomics + +// RNG (Philox-style counter RNG, deterministic per (frame, beam, ray)) --- +fn mulhilo32(a: u32, b: u32) -> vec2 { + // manual 32-bit multiply split (WGSL doesn't expose 64-bit directly) + let a_lo = a & 0xFFFFu; + let a_hi = a >> 16u; + let b_lo = b & 0xFFFFu; + let b_hi = b >> 16u; + + let p0 = a_lo * b_lo; + let p1 = a_hi * b_lo; + let p2 = a_lo * b_hi; + let p3 = a_hi * b_hi; + + let mid = (p0 >> 16u) + (p1 & 0xFFFFu) + (p2 & 0xFFFFu); + let hi = p3 + (p1 >> 16u) + (p2 >> 16u) + (mid >> 16u); + + return vec2(hi, a * b); +} + +fn philox_round(c: vec4, k: vec2) -> vec4 { + // one Feistel round + let r0 = mulhilo32(0xD2511F53u, c.x); + let r1 = mulhilo32(0xCD9E8D57u, c.z); + return vec4(r1.x ^ c.y ^ k.x, r1.y, r0.x ^ c.w ^ k.y, r0.y); +} + +fn philox4x32_10(ctr: vec4, key: vec2) -> vec4 { + // 10 rounds → good statistical quality, still fast on GPU + var c = ctr; + var k = key; + + for (var i = 0; i < 9; i++) { + c = philox_round(c, k); + k.x += 0x9E3779B9u; // Weyl sequence increment + k.y += 0xBB67AE85u; + } + c = philox_round(c, k); + return c; +} + +fn u32_to_unit(x: u32) -> f32 { + // map to (0,1] to avoid log(0) + return f32(x) * (1.0 / 4294967296.0) + (0.5 / 4294967296.0); +} + +fn 
philox_normal4(seed: u32, frame: u32, subseq: u32) -> vec4 { + // Box-Muller: convert uniform → Gaussian (needed for speckle) + let raw = philox4x32_10( + vec4(frame, 0u, subseq, 0u), + vec2(seed, 0u) + ); + + let u1 = u32_to_unit(raw.x); + let u2 = u32_to_unit(raw.y); + let u3 = u32_to_unit(raw.z); + let u4 = u32_to_unit(raw.w); + + let r1 = sqrt(-2.0 * log(u1)); + let r2 = sqrt(-2.0 * log(u3)); + + return vec4( + r1 * sin(2.0 * PI * u2), + r1 * cos(2.0 * PI * u2), + r2 * sin(2.0 * PI * u4), + r2 * cos(2.0 * PI * u4) + ); +} + +@compute @workgroup_size(8, 8, 1) +fn main(@builtin(global_invocation_id) gid: vec3) { + + let beam = gid.x; + let ray = gid.y; + + // guard against over-dispatch + if (beam >= params.n_beams || ray >= params.n_rays) { + return; + } + + let idx = beam * params.n_rays + ray; + let r = depth[idx]; + + // skip invalid or out-of-range rays early (cheap reject) + if (r <= 0.001 || r > params.max_range) { + return; + } + + // fetch surface info + let nx = normals[idx * 3u + 0u]; + let ny = normals[idx * 3u + 1u]; + let nz = normals[idx * 3u + 2u]; + let mu = max(refl[idx], 0.0); // clamp just in case input is noisy + + // convert pixel position → direction (sensor frame) + let beam_ang = params.h_fov * (f32(beam) / max(f32(params.n_beams) - 1.0, 1.0) - 0.5); + let ray_ang = params.v_fov * (f32(ray) / max(f32(params.n_rays) - 1.0, 1.0) - 0.5); + + let rd_x = sin(beam_ang) * cos(ray_ang); + let rd_y = sin(ray_ang); + let rd_z = cos(beam_ang) * cos(ray_ang); + + // incidence term |dot(ray, normal)| → energy drop at grazing angles + let cos_inc = abs(rd_x * nx + rd_y * ny + rd_z * nz); + + // differential area (solid angle scaled by range²) + let d_theta_h = params.h_fov / max(f32(params.n_beams) - 1.0, 1.0); + let d_theta_v = params.v_fov / max(f32(params.n_rays) - 1.0, 1.0); + let dA = r * r * d_theta_h * d_theta_v; + + // attenuation + spreading (simple model) + let TL = exp(-params.attenuation * r) / r; + + // Eq.14 deterministic amplitude term + 
let A_det = sqrt(mu) * cos_inc * sqrt(dA) * TL; + + // stochastic component → produces speckle when summed + let subseq = beam * params.n_rays + ray; + let xi = philox_normal4(params.seed, params.frame, subseq); + + let amp_re = A_det * (xi.x / sqrt(2.0)); + let amp_im = A_det * (xi.y / sqrt(2.0)); + + let delta_f = params.bandwidth / f32(params.n_freq); + let n_freq_f = f32(params.n_freq); + let is_even = (params.n_freq & 1u) == 0u; + + for (var f = 0u; f < params.n_freq; f++) { + + // match fft-style frequency layout (centered around 0) + var freq: f32; + if (is_even) { + freq = delta_f * (-n_freq_f + 2.0 * (f32(f) + 1.0)) / 2.0; + } else { + freq = delta_f * (-(n_freq_f - 1.0) + 2.0 * (f32(f) + 1.0)) / 2.0; + } + + // Eq.8 phase term (two-way travel) + let k = 2.0 * PI * freq / params.sound_speed; + let phi = 2.0 * r * k; + + let c = cos(phi); + let s = sin(phi); + + // complex multiply (manual since no complex type) + let contrib_re = amp_re * c - amp_im * s; + let contrib_im = amp_re * s + amp_im * c; + + let out_idx = beam * params.n_freq + f; + + // clamp before converting → avoid overflow in atomics + let re_i32 = i32(clamp(contrib_re * SCALE, -2147483520.0, 2147483520.0)); + let im_i32 = i32(clamp(contrib_im * SCALE, -2147483520.0, 2147483520.0)); + + // accumulate across rays (parallel safe) + atomicAdd(&out_re[out_idx], re_i32); + atomicAdd(&out_im[out_idx], im_i32); + } +} \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/convert.wgsl b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/convert.wgsl new file mode 100644 index 00000000..a059aaca --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/convert.wgsl @@ -0,0 +1,36 @@ +// convert.wgsl +// takes the accumulated i32 buffers from scatter pass +// converts them back to float so next stages can use them + +struct Params { + n_elements: u32, // total entries (beams * freq bins) + scale: f32, 
// same scale factor used during atomic adds + _pad0: u32, // alignment padding (required for uniform layout) + _pad1: u32, +}; + +@group(0) @binding(0) var params: Params; + +// input buffers (fixed-point values) +@group(0) @binding(1) var in_re: array; +@group(0) @binding(2) var in_im: array; + +// output buffers (converted back to float) +@group(0) @binding(3) var out_re: array; +@group(0) @binding(4) var out_im: array; + +@compute @workgroup_size(256) +fn main(@builtin(global_invocation_id) gid: vec3) { + let i = gid.x; + + // guard against over-dispatch + if (i >= params.n_elements) { + return; + } + + // undo fixed-point scaling + let s = params.scale; + + out_re[i] = f32(in_re[i]) / s; + out_im[i] = f32(in_im[i]) / s; +} \ No newline at end of file diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/fft.wgsl b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/fft.wgsl new file mode 100644 index 00000000..ba86574a --- /dev/null +++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/fft.wgsl @@ -0,0 +1,111 @@ +// fft.wgsl +// Spectrum → range conversion using radix-2 FFT +// Based on beam spectrum from backscatter pass (Eq.8 → time domain) + +struct Params { + n_beams: u32, + n_freq: u32, + log2_n: u32, // precomputed log2(n_freq) + _pad: u32, +}; + +@group(0) @binding(0) var params: Params; +@group(0) @binding(1) var p_re: array; +@group(0) @binding(2) var p_im: array; + +// shared memory for one beam +var smem_re: array; +var smem_im: array; + +const PI: f32 = 3.141592653589793; + +// reverse lowest 'bits' bits (needed for FFT input reordering) +fn bit_reverse(v: u32, bits: u32) -> u32 { + var x = v; + var y = 0u; + + for (var i = 0u; i < bits; i++) { + y = (y << 1u) | (x & 1u); + x >>= 1u; + } + return y; +} + +// one workgroup = one beam +@compute @workgroup_size(256, 1, 1) +fn main( + @builtin(workgroup_id) wg: vec3, + @builtin(local_invocation_id) lid: vec3 +) { + let beam = wg.x; + + // 
simple guard (also enforces shared memory limit) + if (beam >= params.n_beams || params.n_freq > 4096u) { + return; + } + + let n = params.n_freq; + let base = beam * n; + + // Step 1: load + bit-reversal + // FFT expects bit-reversed ordering before butterfly stages + // each thread loads multiple elements (stride = workgroup size) + for (var i = lid.x; i < n; i += 256u) { + let j = bit_reverse(i, params.log2_n); + + // write directly into shared memory in reordered form + smem_re[j] = p_re[base + i]; + smem_im[j] = p_im[base + i]; + } + workgroupBarrier(); + + // Step 2: butterfly stages + // log2(n) stages, each doubling the merge size + for (var s = 0u; s < params.log2_n; s++) { + + let half = 1u << s; // size of sub-FFT + let stride = half << 1u; // full butterfly width + + // total butterflies per stage = n/2 + for (var i = lid.x; i < n / 2u; i += 256u) { + + // map flat index → butterfly indices + let group = i / half; // which block + let j = i % half; // position inside block + + let i0 = group * stride + j; // top element + let i1 = i0 + half; // bottom element + + // twiddle factor: exp(-i * 2π * j / stride) + // negative sign → forward FFT convention + let angle = -2.0 * PI * f32(j) / f32(stride); + let wr = cos(angle); + let wi = sin(angle); + + // t = W * x[i1] + let tr = smem_re[i1] * wr - smem_im[i1] * wi; + let ti = smem_re[i1] * wi + smem_im[i1] * wr; + + // butterfly combine + let ur = smem_re[i0]; + let ui = smem_im[i0]; + + smem_re[i0] = ur + tr; + smem_im[i0] = ui + ti; + + smem_re[i1] = ur - tr; + smem_im[i1] = ui - ti; + } + + // sync before next stage (data dependency) + workgroupBarrier(); + } + + // Step 3: write back + // output is time-domain signal per beam + // range mapping is handled outside (t → r conversion) + for (var i = lid.x; i < n; i += 256u) { + p_re[base + i] = smem_re[i]; + p_im[base + i] = smem_im[i]; + } +} \ No newline at end of file diff --git 
a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/matmul.wgsl b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/matmul.wgsl
new file mode 100644
index 00000000..09eda690
--- /dev/null
+++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/shaders/matmul.wgsl
@@ -0,0 +1,79 @@
+// matmul.wgsl
+// Beam correction pass (Choi et al. 2021)
+//
+// Applies beam pattern matrix W to redistribute energy across beams.
+// Backscatter pass assumes D(θ,φ)=1, this step corrects that using W.
+//
+// C[beam, f] = Σ_k W[beam, k] * P_raw[k, f] / sum(W)
+
+struct Params {
+    n_beams: u32,
+    n_freq: u32,
+    beam_corrector_sum: f32, // used to keep total energy consistent
+    _pad: u32,
+};
+
+// NOTE(review): address spaces / access modes below must match the matmul bind group layout - confirm.
+@group(0) @binding(0) var<uniform> params: Params;
+@group(0) @binding(1) var<storage, read> A: array<f32>; // [n_beams x n_beams]
+@group(0) @binding(2) var<storage, read> B: array<f32>; // [n_beams x n_freq]
+@group(0) @binding(3) var<storage, read_write> C: array<f32>; // [n_beams x n_freq]
+
+// shared tiles (16x16)
+var<workgroup> tile_A: array<array<f32, 16>, 16>;
+var<workgroup> tile_B: array<array<f32, 16>, 16>;
+
+@compute @workgroup_size(16, 16, 1)
+fn main(
+    @builtin(global_invocation_id) gid: vec3<u32>,
+    @builtin(local_invocation_id) lid: vec3<u32>
+) {
+    let row = gid.y; // beam index
+    let col = gid.x; // frequency index
+
+    // avoid out-of-bounds work when grid is padded
+    let in_bounds = (row < params.n_beams) && (col < params.n_freq);
+
+    var acc: f32 = 0.0;
+
+    // number of tiles along k (beam) dimension
+    let num_tiles = (params.n_beams + 15u) / 16u;
+
+    // loop over tiles
+    for (var t = 0u; t < num_tiles; t++) {
+
+        // load A tile: row fixed, k varies
+        let k_a = t * 16u + lid.x;
+        if (row < params.n_beams && k_a < params.n_beams) {
+            tile_A[lid.y][lid.x] = A[row * params.n_beams + k_a];
+        } else {
+            tile_A[lid.y][lid.x] = 0.0;
+        }
+
+        // load B tile: col fixed, k varies
+        let k_b = t * 16u + lid.y;
+        if (col < params.n_freq && k_b < params.n_beams) {
+            tile_B[lid.y][lid.x] = B[k_b * params.n_freq + col];
+        } else {
+            tile_B[lid.y][lid.x] = 0.0;
+        }
+
+        // wait until tile is fully loaded
+        workgroupBarrier();
+
+        // multiply tile rows × cols
+        // each thread computes one output element
+        for (var k = 0u; k < 16u; k++) {
+            acc += tile_A[lid.y][k] * tile_B[k][lid.x];
+        }
+
+        // sync before next tile overwrite
+        workgroupBarrier();
+    }
+
+    // normalise to avoid energy scaling due to W
+    if (in_bounds) {
+        let norm = max(params.beam_corrector_sum, 1e-12);
+        C[row * params.n_freq + col] = acc / norm;
+    }
+}
\ No newline at end of file
diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/src/lib.rs b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/src/lib.rs
new file mode 100644
index 00000000..5d6dd204
--- /dev/null
+++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/src/lib.rs
@@ -0,0 +1,195 @@
+// lib.rs
+// C-facing entry points for the wgpu sonar backend.
+// called from C++ via extern "C"
+
+mod physics_engine;
+use physics_engine::{PhysicsInput, SonarConfig, SonarPhysicsEngine};
+
+// Largest spectrum length the engine accepts (asserted in SonarPhysicsEngine::init).
+const MAX_GPU_FREQ_BINS: u32 = 4096;
+
+// Extracts a printable message from a panic payload.
+// Literal panics carry `&'static str`; formatted ones (`assert!` with
+// arguments, `expect`) carry `String`. Anything else yields "".
+fn panic_message(payload: &(dyn std::any::Any + Send)) -> &str {
+    if let Some(msg) = payload.downcast_ref::<&str>() {
+        *msg
+    } else if let Some(msg) = payload.downcast_ref::<String>() {
+        msg.as_str()
+    } else {
+        ""
+    }
+}
+
+// create
+// Builds an engine and returns an owning pointer, or null if setup fails.
+// A non-power-of-two `n_freq` is rounded up to a power of two (capped at
+// 4096); release the pointer with `sonar_wgpu_destroy`.
+#[no_mangle]
+pub extern "C" fn sonar_wgpu_create(
+    n_beams: u32,
+    n_rays: u32,
+    n_freq: u32,
+    sound_speed: f32,
+    bandwidth: f32,
+    max_range: f32,
+    attenuation: f32,
+    source_level: f32,
+    sensor_gain: f32,
+    h_fov: f32,
+    v_fov: f32,
+    seed: u64,
+) -> *mut SonarPhysicsEngine {
+    // FFT shader needs power-of-2. `next_power_of_two` overflows above 2^31
+    // (debug panic outside catch_unwind, release wrap to 0), so use the
+    // checked variant and fall back to the cap.
+    let n_freq_gpu = if n_freq.is_power_of_two() {
+        n_freq
+    } else {
+        n_freq
+            .checked_next_power_of_two()
+            .map_or(MAX_GPU_FREQ_BINS, |p| p.min(MAX_GPU_FREQ_BINS))
+    };
+
+    let config = SonarConfig {
+        n_beams,
+        n_rays,
+        n_freq: n_freq_gpu,
+        sound_speed,
+        bandwidth,
+        max_range,
+        attenuation,
+        source_level,
+        sensor_gain,
+        h_fov,
+        v_fov,
+        mu_default: 0.5,
+        seed: seed as u32,
+    };
+
+    // wgpu init can panic if no device / driver
+    match std::panic::catch_unwind(|| SonarPhysicsEngine::new(config)) {
+        Ok(engine) => Box::into_raw(Box::new(engine)),
+        Err(e) => {
+            eprintln!("[sonar_wgpu] create failed: {:?}", panic_message(&*e));
+            std::ptr::null_mut()
+        }
+    }
+}
+
+// compute
+// Runs one frame and returns a heap buffer of n_beams * n_freq interleaved
+// [re, im] pairs (im is always 0), or null on invalid input / failure.
+// Release the buffer with `sonar_wgpu_free(ptr, n_beams * n_freq * 2)`.
+#[no_mangle]
+pub unsafe extern "C" fn sonar_wgpu_compute(
+    engine: *mut SonarPhysicsEngine,
+    depth: *const f32,
+    normals: *const f32,
+    refl: *const f32,
+    beam_corr: *const f32,
+    n_beams: u32,
+    n_rays: u32,
+    n_freq: u32,
+    frame: u64,
+    beam_corr_sum: f32,
+) -> *mut f32 {
+    // basic sanity checks (avoid UB)
+    if engine.is_null()
+        || depth.is_null()
+        || normals.is_null()
+        || refl.is_null()
+        || beam_corr.is_null()
+    {
+        return std::ptr::null_mut();
+    }
+
+    // Sizes are derived in usize with checked arithmetic: the u32 products
+    // used before could wrap (release) or panic across the FFI boundary (debug).
+    let beams = n_beams as usize;
+    let freqs = n_freq as usize;
+    let ray_total = beams.checked_mul(n_rays as usize);
+    let normal_total = ray_total.and_then(|n| n.checked_mul(3));
+    let corr_total = beams.checked_mul(beams);
+    let out_len = beams.checked_mul(freqs).and_then(|n| n.checked_mul(2));
+    let (ray_total, normal_total, corr_total, out_len) =
+        match (ray_total, normal_total, corr_total, out_len) {
+            (Some(r), Some(n), Some(c), Some(o)) => (r, n, c, o),
+            _ => return std::ptr::null_mut(),
+        };
+
+    let eng = &mut *engine;
+
+    // wrap raw pointers into slices
+    let input = PhysicsInput {
+        depth: std::slice::from_raw_parts(depth, ray_total),
+        normals: std::slice::from_raw_parts(normals, normal_total),
+        reflectivity: std::slice::from_raw_parts(refl, ray_total),
+        beam_corrector: std::slice::from_raw_parts(beam_corr, corr_total),
+        beam_corr_sum: if beam_corr_sum > 0.0 { beam_corr_sum } else { n_beams as f32 },
+        frame: frame as u32,
+        // NOTE(review): fixed value; the seed given to sonar_wgpu_create is not forwarded here.
+        seed: 42,
+    };
+
+    // run GPU pipeline (catch device loss etc.)
+    let result = std::panic::catch_unwind(
+        std::panic::AssertUnwindSafe(|| eng.run(&input))
+    );
+
+    match result {
+        Ok(output) => {
+            // The engine may run with a padded spectrum length (see
+            // sonar_wgpu_create), so rows are addressed with its own stride
+            // rather than the caller's n_freq.
+            let stride = match output.n_freq as usize {
+                0 => freqs,
+                n => n,
+            };
+            let mut buf: Vec<f32> = Vec::with_capacity(out_len);
+
+            // flatten to [re, im, re, im, ...]
+            for b in 0..beams {
+                for f in 0..freqs {
+                    // intensity = |p|^2, approximate amplitude
+                    let amp = if f < stride {
+                        output.intensity.get(b * stride + f).copied().unwrap_or(0.0)
+                    } else {
+                        0.0
+                    };
+
+                    buf.push(amp.sqrt());
+                    buf.push(0.0); // imag part unused here
+                }
+            }
+
+            // hand ownership to caller as an exact-length allocation so that
+            // sonar_wgpu_free can rebuild it from (ptr, len) alone
+            Box::into_raw(buf.into_boxed_slice()) as *mut f32
+        }
+        Err(e) => {
+            eprintln!("[sonar_wgpu] compute failed: {:?}", panic_message(&*e));
+            std::ptr::null_mut()
+        }
+    }
+}
+
+// free
+// `len` must be the element count of the buffer returned by
+// sonar_wgpu_compute (n_beams * n_freq * 2).
+#[no_mangle]
+pub unsafe extern "C" fn sonar_wgpu_free(ptr: *mut f32, len: usize) {
+    if !ptr.is_null() && len > 0 {
+        // reconstruct vec and drop it
+        drop(Vec::from_raw_parts(ptr, len, len));
+    }
+}
+
+// destroy
+#[no_mangle]
+pub unsafe extern "C" fn sonar_wgpu_destroy(engine: *mut std::ffi::c_void) {
+    if !engine.is_null() {
+        // pointer came from Box::into_raw
+        drop(Box::from_raw(engine as *mut SonarPhysicsEngine));
+    }
+}
\ No newline at end of file
diff --git a/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/src/physics_engine.rs b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/src/physics_engine.rs
new file mode 100644
index 00000000..c33904b6
--- /dev/null
+++ b/gazebo/dave_gz_multibeam_sonar/multibeam_sonar/wgpu_backend/src/physics_engine.rs
@@ -0,0 +1,360 @@
+use bytemuck::{Pod, Zeroable};
+use wgpu::util::DeviceExt;
+
+// High-level simulation config (matches sonar model parameters)
+#[repr(C)]
+#[derive(Clone, Copy, Debug)]
+pub struct SonarConfig {
+    pub n_beams: u32,
+    pub n_rays: u32,
+    pub n_freq: u32,
+    pub sound_speed: f32,
+    pub bandwidth: f32,
+    pub max_range: f32,
+    pub attenuation: f32,
+    pub h_fov: f32,
+    pub v_fov: f32,
+    pub mu_default: f32,
+    pub source_level: f32,
+    pub sensor_gain: f32,
+    pub seed: u32,
+}
+
+// Per-frame input (geometry + material + beam correction)
+pub struct PhysicsInput<'a> {
+    pub depth: &'a [f32],
+    pub normals: &'a [f32],
+    pub reflectivity: &'a [f32],
+    pub beam_corrector: &'a [f32],
+    pub beam_corr_sum: f32,
+    pub frame: u32,
+    pub seed: u32,
+}
+
+// Output after full pipeline (FFT already applied)
+pub struct PhysicsOutput {
+    pub intensity: Vec<f32>, // |p(t)|^2
+    pub n_beams: u32,
+    pub n_freq: u32,
+    pub compute_ms: f64,
+}
+
+impl PhysicsOutput {
+    // Collapse across beams to find dominant range bin.
+    // Returns 0 when there are no bins; NaN sums are ignored and a short
+    // `intensity` buffer is treated as zero-padded instead of panicking.
+    pub fn peak_bin(&self) -> usize {
+        let n = self.n_freq as usize;
+        let nb = self.n_beams as usize;
+        if n == 0 {
+            return 0;
+        }
+
+        let mut bin_sum = vec![0.0f32; n];
+        for row in self.intensity.chunks(n).take(nb) {
+            for (sum, value) in bin_sum.iter_mut().zip(row) {
+                *sum += *value;
+            }
+        }
+
+        bin_sum
+            .iter()
+            .enumerate()
+            .filter(|(_, v)| !v.is_nan())
+            .max_by(|a, b| a.1.partial_cmp(b.1).unwrap_or(std::cmp::Ordering::Equal))
+            .map(|(i, _)| i)
+            .unwrap_or(0)
+    }
+}
+
+// Uniform for Pass 1 (backscatter)
+// Mirrors Eq.14 + Eq.8 inputs
+#[repr(C)]
+#[derive(Clone, Copy, Pod, Zeroable)]
+struct BackscatterParams {
+    n_beams: u32,
+    n_rays: u32,
+    n_freq: u32,
+    _pad0: u32, // alignment (std140 rules)
+
+    sound_speed: f32,
+    bandwidth: f32,
+    max_range: f32,
+    attenuation: f32,
+
+    h_fov: f32,
+    v_fov: f32,
+    mu_default: f32,
+    _pad1: f32,
+
+    seed: u32,
+    frame: u32,
+    _pad2: u32,
+    _pad3: u32,
+}
+
+// Uniform for Pass 2 (beam correction W·P)
+#[repr(C)]
+#[derive(Clone, Copy, Pod, Zeroable)]
+struct MatmulParams {
+    n_beams: u32,
+    n_freq: u32,
+    beam_corrector_sum: f32, // normalization (energy conservation)
+    _pad: u32,
+}
+
+// Uniform for Pass 3 (FFT)
+#[repr(C)]
+#[derive(Clone, Copy, Pod, Zeroable)]
+struct FftParams {
+    n_beams: u32,
+    n_freq: u32,
+    log2_n: u32, // FFT stages = log2(N)
+    _pad: u32,
+}
+
+pub struct SonarPhysicsEngine {
+    device: wgpu::Device,
+    queue: wgpu::Queue,
+
+    // Compute pipelines (3-pass pipeline = Algorithm 1)
+    scatter_pipeline: wgpu::ComputePipeline,
+    scatter_bg: wgpu::BindGroup,
+
+    matmul_pipeline: wgpu::ComputePipeline,
+    matmul_bg_re: wgpu::BindGroup,
+    matmul_bg_im: wgpu::BindGroup,
+
+    fft_pipeline: wgpu::ComputePipeline,
+    fft_bg: wgpu::BindGroup,
+
+    // Geometry buffers
+    depth_buf: wgpu::Buffer,
+    normal_buf: wgpu::Buffer,
+    refl_buf: wgpu::Buffer,
+    beam_corr_buf: wgpu::Buffer,
+
+    // Pass 1 output (atomic i32 accumulation)
+    scatter_re_buf: wgpu::Buffer,
+    scatter_im_buf: wgpu::Buffer,
+
+    // Converted spectrum (f32)
+    spectrum_re_buf: wgpu::Buffer,
+    spectrum_im_buf: wgpu::Buffer,
+
+    // After beam correction
+    corrected_re_buf: wgpu::Buffer,
+    corrected_im_buf: wgpu::Buffer,
+
+    // CPU readback staging
+    readback_buf: wgpu::Buffer,
+
+    scatter_uniform: wgpu::Buffer,
+    matmul_uniform: wgpu::Buffer,
+
+    config: SonarConfig,
+    adapter_name: std::ffi::CString,
+}
+
+impl SonarPhysicsEngine {
+    pub fn new(config: SonarConfig) -> Self {
+        pollster::block_on(Self::init(config))
+    }
+
+    async fn init(config: SonarConfig) -> Self {
+        // GPU FFT kernel is radix-2 → requires power-of-2
+        assert!(
+            config.n_freq.is_power_of_two() && config.n_freq <= 4096,
+            "n_freq must be power-of-2 and <= 4096, got {}", config.n_freq
+        );
+
+        let instance = wgpu::Instance::default();
+        let adapter = instance.request_adapter(&wgpu::RequestAdapterOptions {
+            power_preference: wgpu::PowerPreference::HighPerformance,
+            compatible_surface: None,
+            force_fallback_adapter: false,
+        }).await.expect("no GPU adapter");
+
+        let info = adapter.get_info();
+        let adapter_label = format!("{} ({:?})", info.name, info.backend);
+        println!("[sonar_physics] GPU: {}", adapter_label);
+        // A driver-supplied name with an interior NUL would make CString::new
+        // fail; fall back to an empty name instead of panicking.
+        let adapter_name = std::ffi::CString::new(adapter_label).unwrap_or_default();
+
+        let (device, queue) = adapter.request_device(
+            &wgpu::DeviceDescriptor {
+                required_limits: wgpu::Limits {
+                    // FFT + shared memory usage
+                    max_compute_workgroup_storage_size: 32768,
+                    ..Default::default()
+                },
+                ..Default::default()
+            },
+            None,
+        ).await.expect("device creation failed");
+
+        let nb = config.n_beams as u64;
+        let nr = config.n_rays as u64;
+        let nf = config.n_freq as u64;
+
+        // buffer sizing (bytes)
+        let ray_f32 = nb * nr * 4;
+        let ray_normal = nb * nr * 3 * 4;
+        let beam_corr = nb * nb * 4;
+        let spectrum = nb * nf * 4;
+
+        let mk = |size: u64, usage: wgpu::BufferUsages| {
+            device.create_buffer(&wgpu::BufferDescriptor {
+                label: None,
+                size,
+                usage,
+                mapped_at_creation: false,
+            })
+        };
+
+        let storage = wgpu::BufferUsages::STORAGE;
+        let copy_dst = wgpu::BufferUsages::COPY_DST;
+        let copy_src = wgpu::BufferUsages::COPY_SRC;
+        let map_read = wgpu::BufferUsages::MAP_READ;
+
+        // geometry input
+        let depth_buf = mk(ray_f32, storage | copy_dst);
+        let normal_buf = mk(ray_normal, storage | copy_dst);
+        let refl_buf = mk(ray_f32, storage | copy_dst);
+        let beam_corr_buf = mk(beam_corr, storage | copy_dst);
+
+        // pass buffers
+        let scatter_re_buf = mk(spectrum, storage | copy_src | copy_dst);
+        let scatter_im_buf = mk(spectrum, storage | copy_src | copy_dst);
+        let spectrum_re_buf = mk(spectrum, storage | copy_dst);
+        let spectrum_im_buf = mk(spectrum, storage | copy_dst);
+        let corrected_re_buf = mk(spectrum, storage | copy_src | copy_dst);
+        let corrected_im_buf = mk(spectrum, storage | copy_src | copy_dst);
+
+        let readback_buf = mk(spectrum, copy_dst | map_read);
+
+        // uniforms
+        let scatter_params = BackscatterParams {
+            n_beams: config.n_beams,
+            n_rays: config.n_rays,
+            n_freq: config.n_freq,
+            _pad0: 0,
+            sound_speed: config.sound_speed,
+            bandwidth: config.bandwidth,
+            max_range: config.max_range,
+            attenuation: config.attenuation,
+            h_fov: config.h_fov,
+            v_fov: config.v_fov,
+            mu_default: config.mu_default,
+            _pad1: 0.0,
+            seed: 0,
+            frame: 0,
+            _pad2: 0,
+            _pad3: 0,
+        };
+
+        let scatter_uniform = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
+            label: None,
+            contents: bytemuck::bytes_of(&scatter_params),
+            usage: wgpu::BufferUsages::UNIFORM | copy_dst,
+        });
+
+        let log2_n = config.n_freq.trailing_zeros();
+        let fft_params = FftParams {
+            n_beams: config.n_beams,
+            n_freq: config.n_freq,
+            log2_n,
+            _pad: 0,
+        };
+
+        let fft_uniform = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
+            label: None,
+            contents: bytemuck::bytes_of(&fft_params),
+            usage: wgpu::BufferUsages::UNIFORM,
+        });
+
+        let matmul_params = MatmulParams {
+            n_beams: config.n_beams,
+            n_freq: config.n_freq,
+            beam_corrector_sum: 1.0,
+            _pad: 0,
+        };
+
+        let matmul_uniform = device.create_buffer_init(&wgpu::util::BufferInitDescriptor {
+            label: None,
+            contents: bytemuck::bytes_of(&matmul_params),
+            usage: wgpu::BufferUsages::UNIFORM | copy_dst,
+        });
+
+        // shaders = Algorithm 1 mapped to GPU passes
+        let scatter_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
+            label: Some("backscatter"),
+            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/backscatter.wgsl").into()),
+        });
+
+        let matmul_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
+            label: Some("matmul"),
+            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/matmul.wgsl").into()),
+        });
+
+        let fft_shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
+            label: Some("fft"),
+            source: wgpu::ShaderSource::Wgsl(include_str!("../shaders/fft.wgsl").into()),
+        });
+
+        // pipelines
+        let scatter_bgl = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
+            label: None,
+            entries: &[
+                bgl_u(0), bgl_r(1), bgl_r(2), bgl_r(3), bgl_rw(4), bgl_rw(5),
+            ],
+        });
+
+        let scatter_pipeline = make_pipeline(&device, &scatter_bgl, &scatter_shader, "scatter");
+
+        let scatter_bg = device.create_bind_group(&wgpu::BindGroupDescriptor {
+            label: None,
+            layout: &scatter_bgl,
+            entries: &[
+                bge(0, scatter_uniform.as_entire_binding()),
+                bge(1, depth_buf.as_entire_binding()),
+                bge(2, normal_buf.as_entire_binding()),
+                bge(3, refl_buf.as_entire_binding()),
+                bge(4, scatter_re_buf.as_entire_binding()),
+                bge(5, scatter_im_buf.as_entire_binding()),
+            ],
+        });
+
+        // remaining pipeline setup is unchanged
+
+        let _ = fft_uniform;
+
+        Self {
+            device,
+            queue,
+            scatter_pipeline,
+            scatter_bg,
+            matmul_pipeline,
+            matmul_bg_re,
+            matmul_bg_im,
+            fft_pipeline,
+            fft_bg,
+            depth_buf,
+            normal_buf,
+            refl_buf,
+            beam_corr_buf,
+            scatter_re_buf,
+            scatter_im_buf,
+            spectrum_re_buf,
+            spectrum_im_buf,
+            corrected_re_buf,
+            corrected_im_buf,
+            readback_buf,
+            scatter_uniform,
+            matmul_uniform,
+            config,
+            adapter_name,
+        }
+    }
+}
\ No newline at end of file